"""Compliance endpoints — framework status, reports, and gap analysis. Provides compliance posture assessment by mapping MITRE ATT&CK technique coverage to compliance framework controls. """ import csv import io from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session, joinedload from app.database import get_db from app.dependencies.auth import get_current_user, require_role from app.models.user import User from app.models.compliance import ( ComplianceFramework, ComplianceControl, ComplianceControlMapping, ) from app.models.technique import Technique from app.models.test_template import TestTemplate from app.models.threat_actor import ThreatActorTechnique from app.services.scoring_service import calculate_technique_score from app.services.compliance_import_service import import_nist_800_53_mappings router = APIRouter(prefix="/compliance", tags=["compliance"]) # ── Helpers ─────────────────────────────────────────────────────────── def _classify_control(technique_scores: list[float]) -> str: """Classify a control status based on its technique scores.""" if not technique_scores: return "not_evaluated" all_above_70 = all(s >= 70 for s in technique_scores) any_above_30 = any(s >= 30 for s in technique_scores) all_below_30 = all(s < 30 for s in technique_scores) all_zero = all(s == 0 for s in technique_scores) if all_zero: return "not_evaluated" if all_above_70: return "covered" if all_below_30: return "not_covered" if any_above_30: return "partially_covered" return "not_covered" def _get_control_status(control: ComplianceControl, db: Session) -> dict: """Compute the status and score for a single control.""" mappings = ( db.query(ComplianceControlMapping) .filter(ComplianceControlMapping.compliance_control_id == control.id) .all() ) if not mappings: return { "control_id": control.control_id, "title": control.title, "category": control.category, "status": "not_evaluated", "score": 0, "techniques_count": 0, "techniques_covered": 0, "techniques": [], } technique_ids = [m.technique_id for m in mappings] techniques = ( db.query(Technique) .filter(Technique.id.in_(technique_ids)) .all() ) tech_details = [] scores = [] covered_count = 0 for tech in techniques: result = calculate_technique_score(tech, db) score = result["total_score"] scores.append(score) if score >= 50: covered_count += 1 tech_details.append({ "mitre_id": tech.mitre_id, "name": tech.name, "score": score, "status": tech.status_global.value if tech.status_global else "not_evaluated", }) # Sort techniques by score ascending (worst first for priority) tech_details.sort(key=lambda t: t["score"]) avg_score = round(sum(scores) / len(scores), 1) if scores else 0 status = _classify_control(scores) return { "control_id": control.control_id, "title": control.title, "category": control.category, "status": status, "score": avg_score, "techniques_count": len(techniques), "techniques_covered": covered_count, "techniques": tech_details, } # ── GET /compliance/frameworks ──────────────────────────────────────── @router.get("/frameworks") def list_frameworks( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """List all available compliance frameworks.""" frameworks = ( db.query(ComplianceFramework) .filter(ComplianceFramework.is_active == True) .all() ) result = [] for fw in frameworks: control_count = ( db.query(ComplianceControl) .filter(ComplianceControl.framework_id == fw.id) .count() ) result.append({ "id": str(fw.id), "name": fw.name, "version": fw.version, "description": fw.description, "url": fw.url, "is_active": fw.is_active, "controls_count": control_count, }) return result # ── GET /compliance/frameworks/{id}/status ──────────────────────────── @router.get("/frameworks/{framework_id}/status") def framework_status( framework_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """Get compliance status for each control in a framework.""" framework = ( db.query(ComplianceFramework) .filter(ComplianceFramework.id == framework_id) .first() ) if not framework: raise HTTPException(status_code=404, detail="Framework not found") controls = ( db.query(ComplianceControl) .filter(ComplianceControl.framework_id == framework.id) .order_by(ComplianceControl.control_id) .all() ) control_statuses = [] summary = { "total_controls": len(controls), "covered": 0, "partially_covered": 0, "not_covered": 0, "not_evaluated": 0, } for control in controls: status_data = _get_control_status(control, db) control_statuses.append(status_data) status = status_data["status"] if status in summary: summary[status] += 1 # Compliance percentage: (covered + partially_covered*0.5) / total * 100 total = summary["total_controls"] if total > 0: compliance_pct = round( (summary["covered"] + summary["partially_covered"] * 0.5) / total * 100, 1, ) else: compliance_pct = 0 summary["compliance_percentage"] = compliance_pct return { "framework": {"id": str(framework.id), "name": framework.name}, "summary": summary, "controls": control_statuses, } # ── GET /compliance/frameworks/{id}/report ──────────────────────────── @router.get("/frameworks/{framework_id}/report") def framework_report( framework_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """Get the full compliance report (same as status but marked as report).""" return framework_status(framework_id, db=db, current_user=current_user) # ── GET /compliance/frameworks/{id}/report/csv ──────────────────────── @router.get("/frameworks/{framework_id}/report/csv") def framework_report_csv( framework_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """Export compliance report as CSV.""" framework = ( db.query(ComplianceFramework) .filter(ComplianceFramework.id == framework_id) .first() ) if not framework: raise HTTPException(status_code=404, detail="Framework not found") controls = ( db.query(ComplianceControl) .filter(ComplianceControl.framework_id == framework.id) .order_by(ComplianceControl.control_id) .all() ) output = io.StringIO() writer = csv.writer(output) writer.writerow([ "control_id", "title", "category", "status", "score", "techniques_total", "techniques_covered", "technique_ids", ]) for control in controls: status_data = _get_control_status(control, db) technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"]) writer.writerow([ status_data["control_id"], status_data["title"], status_data["category"] or "", status_data["status"], status_data["score"], status_data["techniques_count"], status_data["techniques_covered"], technique_ids, ]) output.seek(0) filename = f"compliance_{framework.name.replace(' ', '_')}.csv" return StreamingResponse( io.BytesIO(output.getvalue().encode("utf-8")), media_type="text/csv", headers={ "Content-Disposition": f"attachment; filename={filename}", }, ) # ── GET /compliance/frameworks/{id}/gaps ────────────────────────────── @router.get("/frameworks/{framework_id}/gaps") def framework_gaps( framework_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """Get controls with techniques that are not adequately covered.""" framework = ( db.query(ComplianceFramework) .filter(ComplianceFramework.id == framework_id) .first() ) if not framework: raise HTTPException(status_code=404, detail="Framework not found") controls = ( db.query(ComplianceControl) .filter(ComplianceControl.framework_id == framework.id) .order_by(ComplianceControl.control_id) .all() ) gaps = [] for control in controls: status_data = _get_control_status(control, db) if status_data["status"] in ("not_covered", "partially_covered"): # Find uncovered techniques uncovered_techniques = [] for tech_info in status_data["techniques"]: if tech_info["score"] < 70: # Count available templates template_count = ( db.query(TestTemplate) .filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"]) .count() ) # Count threat actors using this technique technique = ( db.query(Technique) .filter(Technique.mitre_id == tech_info["mitre_id"]) .first() ) actor_count = 0 if technique: actor_count = ( db.query(ThreatActorTechnique) .filter(ThreatActorTechnique.technique_id == technique.id) .count() ) uncovered_techniques.append({ **tech_info, "templates_available": template_count, "threat_actors_using": actor_count, }) if uncovered_techniques: gaps.append({ "control_id": status_data["control_id"], "title": status_data["title"], "category": status_data["category"], "status": status_data["status"], "score": status_data["score"], "uncovered_techniques": uncovered_techniques, }) return { "framework": {"id": str(framework.id), "name": framework.name}, "total_gaps": len(gaps), "gaps": gaps, } # ── POST /compliance/import/nist-800-53 ────────────────────────────── @router.post("/import/nist-800-53") def import_nist( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), ): """Import NIST 800-53 Rev 5 mappings (admin only).""" result = import_nist_800_53_mappings(db) return result