"""Compliance data service. Extracts query and aggregation logic from the compliance router so that the router remains a thin HTTP adapter. This module is framework-agnostic: no FastAPI imports. """ from __future__ import annotations import csv import io from typing import Any from sqlalchemy.orm import Session from app.domain.errors import EntityNotFoundError from app.models.compliance import ( ComplianceFramework, ComplianceControl, ComplianceControlMapping, ) from app.models.technique import Technique from app.models.test_template import TestTemplate from app.models.threat_actor import ThreatActorTechnique from app.services.scoring_service import calculate_technique_score # ── Helpers ─────────────────────────────────────────────────────────── def _classify_control(technique_scores: list[float]) -> str: """Classify a control status based on its technique scores.""" if not technique_scores: return "not_evaluated" all_above_70 = all(s >= 70 for s in technique_scores) any_above_30 = any(s >= 30 for s in technique_scores) all_below_30 = all(s < 30 for s in technique_scores) all_zero = all(s == 0 for s in technique_scores) if all_zero: return "not_evaluated" if all_above_70: return "covered" if all_below_30: return "not_covered" if any_above_30: return "partially_covered" return "not_covered" def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, Any]: """Compute the status and score for a single control.""" mappings = ( db.query(ComplianceControlMapping) .filter(ComplianceControlMapping.compliance_control_id == control.id) .all() ) if not mappings: return { "control_id": control.control_id, "title": control.title, "category": control.category, "status": "not_evaluated", "score": 0, "techniques_count": 0, "techniques_covered": 0, "techniques": [], } technique_ids = [m.technique_id for m in mappings] techniques = ( db.query(Technique) .filter(Technique.id.in_(technique_ids)) .all() ) tech_details = [] scores = [] covered_count = 0 for tech in techniques: result = calculate_technique_score(tech, db) score = result["total_score"] scores.append(score) if score >= 50: covered_count += 1 tech_details.append({ "mitre_id": tech.mitre_id, "name": tech.name, "score": score, "status": tech.status_global.value if tech.status_global else "not_evaluated", }) # Sort techniques by score ascending (worst first for priority) tech_details.sort(key=lambda t: t["score"]) avg_score = round(sum(scores) / len(scores), 1) if scores else 0 status = _classify_control(scores) return { "control_id": control.control_id, "title": control.title, "category": control.category, "status": status, "score": avg_score, "techniques_count": len(techniques), "techniques_covered": covered_count, "techniques": tech_details, } # ── Public service functions ─────────────────────────────────────────── def list_frameworks(db: Session) -> list[dict[str, Any]]: """List all available compliance frameworks with control counts.""" frameworks = ( db.query(ComplianceFramework) .filter(ComplianceFramework.is_active == True) .all() ) result = [] for fw in frameworks: control_count = ( db.query(ComplianceControl) .filter(ComplianceControl.framework_id == fw.id) .count() ) result.append({ "id": str(fw.id), "name": fw.name, "version": fw.version, "description": fw.description, "url": fw.url, "is_active": fw.is_active, "controls_count": control_count, }) return result def get_framework(db: Session, framework_id: str) -> ComplianceFramework | None: """Get a framework by ID, or None if not found.""" return ( db.query(ComplianceFramework) .filter(ComplianceFramework.id == framework_id) .first() ) def get_framework_status(db: Session, framework_id: str) -> dict[str, Any]: """Get compliance status for each control in a framework. Raises EntityNotFoundError if the framework does not exist. """ framework = get_framework(db, framework_id) if not framework: raise EntityNotFoundError("Framework", framework_id) controls = ( db.query(ComplianceControl) .filter(ComplianceControl.framework_id == framework.id) .order_by(ComplianceControl.control_id) .all() ) control_statuses = [] summary = { "total_controls": len(controls), "covered": 0, "partially_covered": 0, "not_covered": 0, "not_evaluated": 0, } for control in controls: status_data = _get_control_status(control, db) control_statuses.append(status_data) status = status_data["status"] if status in summary: summary[status] += 1 # Compliance percentage: (covered + partially_covered*0.5) / total * 100 total = summary["total_controls"] if total > 0: compliance_pct = round( (summary["covered"] + summary["partially_covered"] * 0.5) / total * 100, 1, ) else: compliance_pct = 0 summary["compliance_percentage"] = compliance_pct return { "framework": {"id": str(framework.id), "name": framework.name}, "summary": summary, "controls": control_statuses, } def build_framework_report_csv( db: Session, framework_id: str, ) -> tuple[bytes, str]: """Build the compliance report CSV content and filename. Returns (csv_bytes, filename). Raises EntityNotFoundError if the framework does not exist. """ framework = get_framework(db, framework_id) if not framework: raise EntityNotFoundError("Framework", framework_id) controls = ( db.query(ComplianceControl) .filter(ComplianceControl.framework_id == framework.id) .order_by(ComplianceControl.control_id) .all() ) output = io.StringIO() writer = csv.writer(output) writer.writerow([ "control_id", "title", "category", "status", "score", "techniques_total", "techniques_covered", "technique_ids", ]) for control in controls: status_data = _get_control_status(control, db) technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"]) writer.writerow([ status_data["control_id"], status_data["title"], status_data["category"] or "", status_data["status"], status_data["score"], status_data["techniques_count"], status_data["techniques_covered"], technique_ids, ]) output.seek(0) filename = f"compliance_{framework.name.replace(' ', '_')}.csv" return output.getvalue().encode("utf-8"), filename def get_framework_gaps(db: Session, framework_id: str) -> dict[str, Any]: """Get controls with techniques that are not adequately covered. Raises EntityNotFoundError if the framework does not exist. """ framework = get_framework(db, framework_id) if not framework: raise EntityNotFoundError("Framework", framework_id) controls = ( db.query(ComplianceControl) .filter(ComplianceControl.framework_id == framework.id) .order_by(ComplianceControl.control_id) .all() ) gaps = [] for control in controls: status_data = _get_control_status(control, db) if status_data["status"] in ("not_covered", "partially_covered"): # Find uncovered techniques uncovered_techniques = [] for tech_info in status_data["techniques"]: if tech_info["score"] < 70: # Count available templates template_count = ( db.query(TestTemplate) .filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"]) .count() ) # Count threat actors using this technique technique = ( db.query(Technique) .filter(Technique.mitre_id == tech_info["mitre_id"]) .first() ) actor_count = 0 if technique: actor_count = ( db.query(ThreatActorTechnique) .filter(ThreatActorTechnique.technique_id == technique.id) .count() ) uncovered_techniques.append({ **tech_info, "templates_available": template_count, "threat_actors_using": actor_count, }) if uncovered_techniques: gaps.append({ "control_id": status_data["control_id"], "title": status_data["title"], "category": status_data["category"], "status": status_data["status"], "score": status_data["score"], "uncovered_techniques": uncovered_techniques, }) return { "framework": {"id": str(framework.id), "name": framework.name}, "total_gaps": len(gaps), "gaps": gaps, }