"""Compliance data service. Extracts query and aggregation logic from the compliance router so that the router remains a thin HTTP adapter. This module is framework-agnostic: no FastAPI imports. """ # Enable future language features for compatibility from __future__ import annotations # Import csv import csv # Import io import io # Import Any from typing from typing import Any # Import Session from sqlalchemy.orm from sqlalchemy.orm import Session # Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError # Import from app.models.compliance from app.models.compliance import ( ComplianceControl, ComplianceControlMapping, ComplianceFramework, ) # Import Technique from app.models.technique from app.models.technique import Technique # Import TestTemplate from app.models.test_template from app.models.test_template import TestTemplate # Import ThreatActorTechnique from app.models.threat_actor from app.models.threat_actor import ThreatActorTechnique # Import calculate_technique_score from app.services.scoring_service from app.services.scoring_service import calculate_technique_score # ── Helpers ─────────────────────────────────────────────────────────── def _classify_control(technique_scores: list[float]) -> str: """Classify a control status based on its technique scores.""" # Check: not technique_scores if not technique_scores: # Return "not_evaluated" return "not_evaluated" # Assign all_above_70 = all(s >= 70 for s in technique_scores) all_above_70 = all(s >= 70 for s in technique_scores) # Assign any_above_30 = any(s >= 30 for s in technique_scores) any_above_30 = any(s >= 30 for s in technique_scores) # Assign all_below_30 = all(s < 30 for s in technique_scores) all_below_30 = all(s < 30 for s in technique_scores) # Assign all_zero = all(s == 0 for s in technique_scores) all_zero = all(s == 0 for s in technique_scores) # Check: all_zero if all_zero: # Return "not_evaluated" return "not_evaluated" # Check: all_above_70 if all_above_70: # Return "covered" return "covered" # Check: all_below_30 if all_below_30: # Return "not_covered" return "not_covered" # Check: any_above_30 if any_above_30: # Return "partially_covered" return "partially_covered" # Return "not_covered" return "not_covered" # Define function _get_control_status def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, Any]: """Compute the status and score for a single control.""" # Assign mappings = ( mappings = ( db.query(ComplianceControlMapping) # Chain .filter() call .filter(ComplianceControlMapping.compliance_control_id == control.id) # Chain .all() call .all() ) # Check: not mappings if not mappings: # Return { return { # Literal argument value "control_id": control.control_id, # Literal argument value "title": control.title, "description": control.description, "category": control.category, # Literal argument value "status": "not_evaluated", # Literal argument value "score": 0, # Literal argument value "techniques_count": 0, # Literal argument value "techniques_covered": 0, # Literal argument value "techniques": [], } # Assign technique_ids = [m.technique_id for m in mappings] technique_ids = [m.technique_id for m in mappings] # Assign techniques = ( techniques = ( db.query(Technique) # Chain .filter() call .filter(Technique.id.in_(technique_ids)) # Chain .all() call .all() ) # Assign tech_details = [] tech_details = [] # Assign scores = [] scores = [] # Assign covered_count = 0 covered_count = 0 # Iterate over techniques for tech in techniques: # Assign result = calculate_technique_score(tech, db) result = calculate_technique_score(tech, db) # Assign score = result["total_score"] score = result["total_score"] # Call scores.append() scores.append(score) # Check: score >= 50 if score >= 50: # Assign covered_count = 1 covered_count += 1 # Call tech_details.append() tech_details.append({ # Literal argument value "mitre_id": tech.mitre_id, # Literal argument value "name": tech.name, # Literal argument value "score": score, # Literal argument value "status": tech.status_global.value if tech.status_global else "not_evaluated", }) # Sort techniques by score ascending (worst first for priority) tech_details.sort(key=lambda t: t["score"]) # Assign avg_score = round(sum(scores) / len(scores), 1) if scores else 0 avg_score = round(sum(scores) / len(scores), 1) if scores else 0 # Assign status = _classify_control(scores) status = _classify_control(scores) # Return { return { # Literal argument value "control_id": control.control_id, # Literal argument value "title": control.title, "description": control.description, "category": control.category, # Literal argument value "status": status, # Literal argument value "score": avg_score, # Literal argument value "techniques_count": len(techniques), # Literal argument value "techniques_covered": covered_count, # Literal argument value "techniques": tech_details, } # ── Public service functions ─────────────────────────────────────────── def list_frameworks(db: Session) -> list[dict[str, Any]]: """List all available compliance frameworks with control counts.""" # Assign frameworks = ( frameworks = ( db.query(ComplianceFramework) # Chain .filter() call .filter(ComplianceFramework.is_active == True) # Chain .all() call .all() ) # Assign result = [] result = [] # Iterate over frameworks for fw in frameworks: # Assign control_count = ( control_count = ( db.query(ComplianceControl) # Chain .filter() call .filter(ComplianceControl.framework_id == fw.id) # Chain .count() call .count() ) # Call result.append() result.append({ # Literal argument value "id": str(fw.id), # Literal argument value "name": fw.name, # Literal argument value "version": fw.version, # Literal argument value "description": fw.description, # Literal argument value "url": fw.url, # Literal argument value "is_active": fw.is_active, # Literal argument value "controls_count": control_count, }) # Return result return result # Define function get_framework def get_framework(db: Session, framework_id: str) -> ComplianceFramework | None: """Get a framework by ID, or None if not found.""" # Return ( return ( db.query(ComplianceFramework) # Chain .filter() call .filter(ComplianceFramework.id == framework_id) # Chain .first() call .first() ) # Define function get_framework_status def get_framework_status(db: Session, framework_id: str) -> dict[str, Any]: """Get compliance status for each control in a framework. Raises EntityNotFoundError if the framework does not exist. """ # Assign framework = get_framework(db, framework_id) framework = get_framework(db, framework_id) # Check: not framework if not framework: # Raise EntityNotFoundError raise EntityNotFoundError("Framework", framework_id) # Assign controls = ( controls = ( db.query(ComplianceControl) # Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) # Chain .order_by() call .order_by(ComplianceControl.control_id) # Chain .all() call .all() ) # Assign control_statuses = [] control_statuses = [] # Assign summary = { summary = { # Literal argument value "total_controls": len(controls), # Literal argument value "covered": 0, # Literal argument value "partially_covered": 0, # Literal argument value "not_covered": 0, # Literal argument value "not_evaluated": 0, } # Iterate over controls for control in controls: # Assign status_data = _get_control_status(control, db) status_data = _get_control_status(control, db) # Call control_statuses.append() control_statuses.append(status_data) # Assign status = status_data["status"] status = status_data["status"] # Check: status in summary if status in summary: # Assign summary[status] = 1 summary[status] += 1 # Compliance percentage: (covered + partially_covered*0.5) / total * 100 total = summary["total_controls"] # Check: total > 0 if total > 0: # Assign compliance_pct = round( compliance_pct = round( (summary["covered"] + summary["partially_covered"] * 0.5) / total * 100, # Literal argument value 1, ) # Fallback: handle remaining cases else: # Assign compliance_pct = 0 compliance_pct = 0 # Assign summary["compliance_percentage"] = compliance_pct summary["compliance_percentage"] = compliance_pct # Return { return { # Literal argument value "framework": {"id": str(framework.id), "name": framework.name}, # Literal argument value "summary": summary, # Literal argument value "controls": control_statuses, } # Define function build_framework_report_csv def build_framework_report_csv( # Entry: db db: Session, # Entry: framework_id framework_id: str, ) -> tuple[bytes, str]: """Build the compliance report CSV content and filename. Returns (csv_bytes, filename). Raises EntityNotFoundError if the framework does not exist. """ # Assign framework = get_framework(db, framework_id) framework = get_framework(db, framework_id) # Check: not framework if not framework: # Raise EntityNotFoundError raise EntityNotFoundError("Framework", framework_id) # Assign controls = ( controls = ( db.query(ComplianceControl) # Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) # Chain .order_by() call .order_by(ComplianceControl.control_id) # Chain .all() call .all() ) # Assign output = io.StringIO() output = io.StringIO() # Assign writer = csv.writer(output) writer = csv.writer(output) # Call writer.writerow() writer.writerow([ # Literal argument value "control_id", # Literal argument value "title", # Literal argument value "category", # Literal argument value "status", # Literal argument value "score", # Literal argument value "techniques_total", # Literal argument value "techniques_covered", # Literal argument value "technique_ids", ]) # Iterate over controls for control in controls: # Assign status_data = _get_control_status(control, db) status_data = _get_control_status(control, db) # Assign technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"]) technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"]) # Call writer.writerow() writer.writerow([ status_data["control_id"], status_data["title"], status_data["category"] or "", status_data["status"], status_data["score"], status_data["techniques_count"], status_data["techniques_covered"], technique_ids, ]) # Call output.seek() output.seek(0) # Assign filename = f"compliance_{framework.name.replace(' ', '_')}.csv" filename = f"compliance_{framework.name.replace(' ', '_')}.csv" # Return output.getvalue().encode("utf-8"), filename return output.getvalue().encode("utf-8"), filename # Define function get_framework_gaps def get_framework_gaps(db: Session, framework_id: str) -> dict[str, Any]: """Get controls with techniques that are not adequately covered. Raises EntityNotFoundError if the framework does not exist. """ # Assign framework = get_framework(db, framework_id) framework = get_framework(db, framework_id) # Check: not framework if not framework: # Raise EntityNotFoundError raise EntityNotFoundError("Framework", framework_id) # Assign controls = ( controls = ( db.query(ComplianceControl) # Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) # Chain .order_by() call .order_by(ComplianceControl.control_id) # Chain .all() call .all() ) # Assign gaps = [] gaps = [] # Iterate over controls for control in controls: # Assign status_data = _get_control_status(control, db) status_data = _get_control_status(control, db) # Check: status_data["status"] in ("not_covered", "partially_covered") if status_data["status"] in ("not_covered", "partially_covered"): # Find uncovered techniques uncovered_techniques = [] # Iterate over status_data["techniques"] for tech_info in status_data["techniques"]: # Check: tech_info["score"] < 70 if tech_info["score"] < 70: # Count available templates template_count = ( db.query(TestTemplate) # Chain .filter() call .filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"]) # Chain .count() call .count() ) # Count threat actors using this technique technique = ( db.query(Technique) # Chain .filter() call .filter(Technique.mitre_id == tech_info["mitre_id"]) # Chain .first() call .first() ) # Assign actor_count = 0 actor_count = 0 # Check: technique if technique: # Assign actor_count = ( actor_count = ( db.query(ThreatActorTechnique) # Chain .filter() call .filter(ThreatActorTechnique.technique_id == technique.id) # Chain .count() call .count() ) # Call uncovered_techniques.append() uncovered_techniques.append({ **tech_info, # Literal argument value "templates_available": template_count, # Literal argument value "threat_actors_using": actor_count, }) # Check: uncovered_techniques if uncovered_techniques: # Call gaps.append() gaps.append({ # Literal argument value "control_id": status_data["control_id"], # Literal argument value "title": status_data["title"], # Literal argument value "category": status_data["category"], # Literal argument value "status": status_data["status"], # Literal argument value "score": status_data["score"], # Literal argument value "uncovered_techniques": uncovered_techniques, }) # Return { return { # Literal argument value "framework": {"id": str(framework.id), "name": framework.name}, # Literal argument value "total_gaps": len(gaps), # Literal argument value "gaps": gaps, }