From d305db87944ee5d9c364142002130def1ac90936 Mon Sep 17 00:00:00 2001
From: Kitos
Date: Thu, 19 Feb 2026 17:06:32 +0100
Subject: [PATCH] refactor(compliance): extract business logic to compliance_service, use domain exceptions instead of HTTPException

---
Reviewer notes (this section is ignored by `git am`):
- Line structure restored; the patch had been collapsed onto 11 lines.
- Router hunks are unchanged from the original patch.
- compliance_service.py: added _require_framework() and
  _get_framework_controls() to replace three copies of the same lookup,
  replaced the per-framework COUNT in list_frameworks() with one grouped
  query, named the score thresholds, removed an unreachable branch in
  _classify_control(), and sanitised the CSV filename because the router
  places it unquoted in the Content-Disposition header.
- Commit and blob ids are those of the original revision; regenerate the
  patch with `git format-patch` after applying.

 backend/app/routers/compliance.py          | 306 ++-----------------
 backend/app/services/compliance_service.py | 340 +++++++++++++++++++++
 2 files changed, 357 insertions(+), 289 deletions(-)
 create mode 100644 backend/app/services/compliance_service.py

diff --git a/backend/app/routers/compliance.py b/backend/app/routers/compliance.py
index 090f4f0..cd642ae 100644
--- a/backend/app/routers/compliance.py
+++ b/backend/app/routers/compliance.py
@@ -1,29 +1,24 @@
 """Compliance endpoints — framework status, reports, and gap analysis.
 
+Thin HTTP adapter: delegates all data logic to compliance_service.
+
 Provides compliance posture assessment by mapping MITRE ATT&CK technique
 coverage to compliance framework controls.
 """
 
-import csv
-import io
-from typing import Optional
-
-from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi import APIRouter, Depends
 from fastapi.responses import StreamingResponse
-from sqlalchemy.orm import Session, joinedload
+from sqlalchemy.orm import Session
 
 from app.database import get_db
 from app.dependencies.auth import get_current_user, require_role
 from app.models.user import User
-from app.models.compliance import (
-    ComplianceFramework,
-    ComplianceControl,
-    ComplianceControlMapping,
+from app.services.compliance_service import (
+    list_frameworks,
+    get_framework_status,
+    build_framework_report_csv,
+    get_framework_gaps,
 )
-from app.models.technique import Technique
-from app.models.test_template import TestTemplate
-from app.models.threat_actor import ThreatActorTechnique
-from app.services.scoring_service import calculate_technique_score
 from app.services.compliance_import_service import (
     import_nist_800_53_mappings,
     import_cis_controls_v8_mappings,
@@ -32,126 +27,16 @@ from app.services.compliance_import_service import (
 router = APIRouter(prefix="/compliance", tags=["compliance"])
 
 
-# ── Helpers ───────────────────────────────────────────────────────────
-
-
-def _classify_control(technique_scores: list[float]) -> str:
-    """Classify a control status based on its technique scores."""
-    if not technique_scores:
-        return "not_evaluated"
-
-    all_above_70 = all(s >= 70 for s in technique_scores)
-    any_above_30 = any(s >= 30 for s in technique_scores)
-    all_below_30 = all(s < 30 for s in technique_scores)
-    all_zero = all(s == 0 for s in technique_scores)
-
-    if all_zero:
-        return "not_evaluated"
-    if all_above_70:
-        return "covered"
-    if all_below_30:
-        return "not_covered"
-    if any_above_30:
-        return "partially_covered"
-    return "not_covered"
-
-
-def _get_control_status(control: ComplianceControl, db: Session) -> dict:
-    """Compute the status and score for a single control."""
-    mappings = (
-        db.query(ComplianceControlMapping)
-        .filter(ComplianceControlMapping.compliance_control_id == control.id)
-        .all()
-    )
-
-    if not mappings:
-        return {
-            "control_id": control.control_id,
-            "title": control.title,
-            "category": control.category,
-            "status": "not_evaluated",
-            "score": 0,
-            "techniques_count": 0,
-            "techniques_covered": 0,
-            "techniques": [],
-        }
-
-    technique_ids = [m.technique_id for m in mappings]
-    techniques = (
-        db.query(Technique)
-        .filter(Technique.id.in_(technique_ids))
-        .all()
-    )
-
-    tech_details = []
-    scores = []
-    covered_count = 0
-
-    for tech in techniques:
-        result = calculate_technique_score(tech, db)
-        score = result["total_score"]
-        scores.append(score)
-        if score >= 50:
-            covered_count += 1
-
-        tech_details.append({
-            "mitre_id": tech.mitre_id,
-            "name": tech.name,
-            "score": score,
-            "status": tech.status_global.value if tech.status_global else "not_evaluated",
-        })
-
-    # Sort techniques by score ascending (worst first for priority)
-    tech_details.sort(key=lambda t: t["score"])
-
-    avg_score = round(sum(scores) / len(scores), 1) if scores else 0
-    status = _classify_control(scores)
-
-    return {
-        "control_id": control.control_id,
-        "title": control.title,
-        "category": control.category,
-        "status": status,
-        "score": avg_score,
-        "techniques_count": len(techniques),
-        "techniques_covered": covered_count,
-        "techniques": tech_details,
-    }
-
-
 # ── GET /compliance/frameworks ────────────────────────────────────────
 
 
 @router.get("/frameworks")
-def list_frameworks(
+def list_frameworks_endpoint(
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user),
 ):
     """List all available compliance frameworks."""
-    frameworks = (
-        db.query(ComplianceFramework)
-        .filter(ComplianceFramework.is_active == True)
-        .all()
-    )
-
-    result = []
-    for fw in frameworks:
-        control_count = (
-            db.query(ComplianceControl)
-            .filter(ComplianceControl.framework_id == fw.id)
-            .count()
-        )
-        result.append({
-            "id": str(fw.id),
-            "name": fw.name,
-            "version": fw.version,
-            "description": fw.description,
-            "url": fw.url,
-            "is_active": fw.is_active,
-            "controls_count": control_count,
-        })
-
-    return result
+    return list_frameworks(db)
 
 
 # ── GET /compliance/frameworks/{id}/status ────────────────────────────
@@ -164,55 +49,7 @@ def framework_status(
     current_user: User = Depends(get_current_user),
 ):
     """Get compliance status for each control in a framework."""
-    framework = (
-        db.query(ComplianceFramework)
-        .filter(ComplianceFramework.id == framework_id)
-        .first()
-    )
-    if not framework:
-        raise HTTPException(status_code=404, detail="Framework not found")
-
-    controls = (
-        db.query(ComplianceControl)
-        .filter(ComplianceControl.framework_id == framework.id)
-        .order_by(ComplianceControl.control_id)
-        .all()
-    )
-
-    control_statuses = []
-    summary = {
-        "total_controls": len(controls),
-        "covered": 0,
-        "partially_covered": 0,
-        "not_covered": 0,
-        "not_evaluated": 0,
-    }
-
-    for control in controls:
-        status_data = _get_control_status(control, db)
-        control_statuses.append(status_data)
-
-        status = status_data["status"]
-        if status in summary:
-            summary[status] += 1
-
-    # Compliance percentage: (covered + partially_covered*0.5) / total * 100
-    total = summary["total_controls"]
-    if total > 0:
-        compliance_pct = round(
-            (summary["covered"] + summary["partially_covered"] * 0.5) / total * 100,
-            1,
-        )
-    else:
-        compliance_pct = 0
-
-    summary["compliance_percentage"] = compliance_pct
-
-    return {
-        "framework": {"id": str(framework.id), "name": framework.name},
-        "summary": summary,
-        "controls": control_statuses,
-    }
+    return get_framework_status(db, framework_id)
 
 
 # ── GET /compliance/frameworks/{id}/report ────────────────────────────
@@ -225,7 +62,7 @@ def framework_report(
     current_user: User = Depends(get_current_user),
 ):
     """Get the full compliance report (same as status but marked as report)."""
-    return framework_status(framework_id, db=db, current_user=current_user)
+    return get_framework_status(db, framework_id)
 
 
 # ── GET /compliance/frameworks/{id}/report/csv ────────────────────────
@@ -238,53 +75,9 @@ def framework_report_csv(
     current_user: User = Depends(get_current_user),
 ):
     """Export compliance report as CSV."""
-    framework = (
-        db.query(ComplianceFramework)
-        .filter(ComplianceFramework.id == framework_id)
-        .first()
-    )
-    if not framework:
-        raise HTTPException(status_code=404, detail="Framework not found")
-
-    controls = (
-        db.query(ComplianceControl)
-        .filter(ComplianceControl.framework_id == framework.id)
-        .order_by(ComplianceControl.control_id)
-        .all()
-    )
-
-    output = io.StringIO()
-    writer = csv.writer(output)
-    writer.writerow([
-        "control_id",
-        "title",
-        "category",
-        "status",
-        "score",
-        "techniques_total",
-        "techniques_covered",
-        "technique_ids",
-    ])
-
-    for control in controls:
-        status_data = _get_control_status(control, db)
-        technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
-        writer.writerow([
-            status_data["control_id"],
-            status_data["title"],
-            status_data["category"] or "",
-            status_data["status"],
-            status_data["score"],
-            status_data["techniques_count"],
-            status_data["techniques_covered"],
-            technique_ids,
-        ])
-
-    output.seek(0)
-    filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
-
+    csv_bytes, filename = build_framework_report_csv(db, framework_id)
     return StreamingResponse(
-        io.BytesIO(output.getvalue().encode("utf-8")),
+        iter([csv_bytes]),
         media_type="text/csv",
         headers={
             "Content-Disposition": f"attachment; filename={filename}",
@@ -302,75 +95,10 @@ def framework_gaps(
     current_user: User = Depends(get_current_user),
 ):
     """Get controls with techniques that are not adequately covered."""
-    framework = (
-        db.query(ComplianceFramework)
-        .filter(ComplianceFramework.id == framework_id)
-        .first()
-    )
-    if not framework:
-        raise HTTPException(status_code=404, detail="Framework not found")
-
-    controls = (
-        db.query(ComplianceControl)
-        .filter(ComplianceControl.framework_id == framework.id)
-        .order_by(ComplianceControl.control_id)
-        .all()
-    )
-
-    gaps = []
-    for control in controls:
-        status_data = _get_control_status(control, db)
-
-        if status_data["status"] in ("not_covered", "partially_covered"):
-            # Find uncovered techniques
-            uncovered_techniques = []
-            for tech_info in status_data["techniques"]:
-                if tech_info["score"] < 70:
-                    # Count available templates
-                    template_count = (
-                        db.query(TestTemplate)
-                        .filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"])
-                        .count()
-                    )
-
-                    # Count threat actors using this technique
-                    technique = (
-                        db.query(Technique)
-                        .filter(Technique.mitre_id == tech_info["mitre_id"])
-                        .first()
-                    )
-                    actor_count = 0
-                    if technique:
-                        actor_count = (
-                            db.query(ThreatActorTechnique)
-                            .filter(ThreatActorTechnique.technique_id == technique.id)
-                            .count()
-                        )
-
-                    uncovered_techniques.append({
-                        **tech_info,
-                        "templates_available": template_count,
-                        "threat_actors_using": actor_count,
-                    })
-
-            if uncovered_techniques:
-                gaps.append({
-                    "control_id": status_data["control_id"],
-                    "title": status_data["title"],
-                    "category": status_data["category"],
-                    "status": status_data["status"],
-                    "score": status_data["score"],
-                    "uncovered_techniques": uncovered_techniques,
-                })
-
-    return {
-        "framework": {"id": str(framework.id), "name": framework.name},
-        "total_gaps": len(gaps),
-        "gaps": gaps,
-    }
+    return get_framework_gaps(db, framework_id)
 
 
-# ── POST /compliance/import/nist-800-53 ──────────────────────────────
+# ── POST /compliance/import/... ────────────────────────────────────────
 
 
 @router.post("/import/nist-800-53")
diff --git a/backend/app/services/compliance_service.py b/backend/app/services/compliance_service.py
new file mode 100644
index 0000000..696231d
--- /dev/null
+++ b/backend/app/services/compliance_service.py
@@ -0,0 +1,340 @@
+"""Compliance data service.
+
+Extracts query and aggregation logic from the compliance router so
+that the router remains a thin HTTP adapter.
+
+This module is framework-agnostic: no FastAPI imports.
+"""
+
+from __future__ import annotations
+
+import csv
+import io
+import re
+from typing import Any
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from app.domain.errors import EntityNotFoundError
+from app.models.compliance import (
+    ComplianceFramework,
+    ComplianceControl,
+    ComplianceControlMapping,
+)
+from app.models.technique import Technique
+from app.models.test_template import TestTemplate
+from app.models.threat_actor import ThreatActorTechnique
+from app.services.scoring_service import calculate_technique_score
+
+# Score thresholds used when classifying controls and techniques.
+_COVERED_MIN_SCORE = 70  # every technique at/above -> control "covered"
+_PARTIAL_MIN_SCORE = 30  # every technique below -> control "not_covered"
+_TECHNIQUE_COVERED_MIN_SCORE = 50  # technique counts toward techniques_covered
+
+# Anything outside this set is replaced by "_" in the report filename.
+_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._-]")
+
+
+# ── Helpers ───────────────────────────────────────────────────────────
+
+
+def _classify_control(technique_scores: list[float]) -> str:
+    """Classify a control status based on its technique scores.
+
+    Returns "not_evaluated" when there are no scores or all are zero,
+    "covered" when every score is >= 70, "not_covered" when every score
+    is < 30, and "partially_covered" otherwise.
+    """
+    if not technique_scores or all(s == 0 for s in technique_scores):
+        return "not_evaluated"
+    if all(s >= _COVERED_MIN_SCORE for s in technique_scores):
+        return "covered"
+    if all(s < _PARTIAL_MIN_SCORE for s in technique_scores):
+        return "not_covered"
+    # Not all below 30 means at least one score is >= 30.
+    return "partially_covered"
+
+
+def _require_framework(db: Session, framework_id: str) -> ComplianceFramework:
+    """Return the framework with *framework_id* or raise EntityNotFoundError."""
+    framework = get_framework(db, framework_id)
+    if not framework:
+        raise EntityNotFoundError("Framework", framework_id)
+    return framework
+
+
+def _get_framework_controls(
+    db: Session,
+    framework: ComplianceFramework,
+) -> list[ComplianceControl]:
+    """Return the controls of *framework*, ordered by control_id."""
+    return (
+        db.query(ComplianceControl)
+        .filter(ComplianceControl.framework_id == framework.id)
+        .order_by(ComplianceControl.control_id)
+        .all()
+    )
+
+
+def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, Any]:
+    """Compute the status and score for a single control.
+
+    The returned dict lists each mapped technique with its score, sorted
+    worst-first, plus the average score and the derived control status.
+    """
+    mappings = (
+        db.query(ComplianceControlMapping)
+        .filter(ComplianceControlMapping.compliance_control_id == control.id)
+        .all()
+    )
+
+    if not mappings:
+        return {
+            "control_id": control.control_id,
+            "title": control.title,
+            "category": control.category,
+            "status": "not_evaluated",
+            "score": 0,
+            "techniques_count": 0,
+            "techniques_covered": 0,
+            "techniques": [],
+        }
+
+    technique_ids = [m.technique_id for m in mappings]
+    techniques = (
+        db.query(Technique)
+        .filter(Technique.id.in_(technique_ids))
+        .all()
+    )
+
+    tech_details = []
+    scores = []
+    covered_count = 0
+
+    for tech in techniques:
+        result = calculate_technique_score(tech, db)
+        score = result["total_score"]
+        scores.append(score)
+        if score >= _TECHNIQUE_COVERED_MIN_SCORE:
+            covered_count += 1
+
+        tech_details.append({
+            "mitre_id": tech.mitre_id,
+            "name": tech.name,
+            "score": score,
+            "status": tech.status_global.value if tech.status_global else "not_evaluated",
+        })
+
+    # Sort techniques by score ascending (worst first for priority)
+    tech_details.sort(key=lambda t: t["score"])
+
+    avg_score = round(sum(scores) / len(scores), 1) if scores else 0
+    status = _classify_control(scores)
+
+    return {
+        "control_id": control.control_id,
+        "title": control.title,
+        "category": control.category,
+        "status": status,
+        "score": avg_score,
+        "techniques_count": len(techniques),
+        "techniques_covered": covered_count,
+        "techniques": tech_details,
+    }
+
+
+# ── Public service functions ───────────────────────────────────────────
+
+
+def list_frameworks(db: Session) -> list[dict[str, Any]]:
+    """List all active compliance frameworks with their control counts."""
+    frameworks = (
+        db.query(ComplianceFramework)
+        .filter(ComplianceFramework.is_active.is_(True))
+        .all()
+    )
+
+    # One grouped query instead of a COUNT per framework (avoids N+1).
+    control_counts = {
+        framework_id: count
+        for framework_id, count in (
+            db.query(ComplianceControl.framework_id, func.count(ComplianceControl.id))
+            .group_by(ComplianceControl.framework_id)
+            .all()
+        )
+    }
+
+    return [
+        {
+            "id": str(fw.id),
+            "name": fw.name,
+            "version": fw.version,
+            "description": fw.description,
+            "url": fw.url,
+            "is_active": fw.is_active,
+            "controls_count": control_counts.get(fw.id, 0),
+        }
+        for fw in frameworks
+    ]
+
+
+def get_framework(db: Session, framework_id: str) -> ComplianceFramework | None:
+    """Get a framework by ID, or None if not found."""
+    return (
+        db.query(ComplianceFramework)
+        .filter(ComplianceFramework.id == framework_id)
+        .first()
+    )
+
+
+def get_framework_status(db: Session, framework_id: str) -> dict[str, Any]:
+    """Get compliance status for each control in a framework.
+
+    Raises EntityNotFoundError if the framework does not exist.
+    """
+    framework = _require_framework(db, framework_id)
+    controls = _get_framework_controls(db, framework)
+
+    control_statuses = []
+    summary = {
+        "total_controls": len(controls),
+        "covered": 0,
+        "partially_covered": 0,
+        "not_covered": 0,
+        "not_evaluated": 0,
+    }
+
+    for control in controls:
+        status_data = _get_control_status(control, db)
+        control_statuses.append(status_data)
+
+        status = status_data["status"]
+        if status in summary:
+            summary[status] += 1
+
+    # Compliance percentage: (covered + partially_covered*0.5) / total * 100
+    total = summary["total_controls"]
+    if total > 0:
+        compliance_pct = round(
+            (summary["covered"] + summary["partially_covered"] * 0.5) / total * 100,
+            1,
+        )
+    else:
+        compliance_pct = 0
+
+    summary["compliance_percentage"] = compliance_pct
+
+    return {
+        "framework": {"id": str(framework.id), "name": framework.name},
+        "summary": summary,
+        "controls": control_statuses,
+    }
+
+
+def build_framework_report_csv(
+    db: Session,
+    framework_id: str,
+) -> tuple[bytes, str]:
+    """Build the compliance report CSV content and filename.
+
+    Returns (csv_bytes, filename). The filename is derived from the
+    framework name with every character outside [A-Za-z0-9._-] replaced
+    by "_", so it is safe to place in a Content-Disposition header.
+
+    Raises EntityNotFoundError if the framework does not exist.
+    """
+    framework = _require_framework(db, framework_id)
+    controls = _get_framework_controls(db, framework)
+
+    output = io.StringIO()
+    writer = csv.writer(output)
+    writer.writerow([
+        "control_id",
+        "title",
+        "category",
+        "status",
+        "score",
+        "techniques_total",
+        "techniques_covered",
+        "technique_ids",
+    ])
+
+    for control in controls:
+        status_data = _get_control_status(control, db)
+        technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
+        writer.writerow([
+            status_data["control_id"],
+            status_data["title"],
+            status_data["category"] or "",
+            status_data["status"],
+            status_data["score"],
+            status_data["techniques_count"],
+            status_data["techniques_covered"],
+            technique_ids,
+        ])
+
+    safe_name = _UNSAFE_FILENAME_CHARS.sub("_", framework.name)
+    filename = f"compliance_{safe_name}.csv"
+    return output.getvalue().encode("utf-8"), filename
+
+
+def get_framework_gaps(db: Session, framework_id: str) -> dict[str, Any]:
+    """Get controls with techniques that are not adequately covered.
+
+    Raises EntityNotFoundError if the framework does not exist.
+    """
+    framework = _require_framework(db, framework_id)
+    controls = _get_framework_controls(db, framework)
+
+    gaps = []
+    for control in controls:
+        status_data = _get_control_status(control, db)
+        if status_data["status"] not in ("not_covered", "partially_covered"):
+            continue
+
+        uncovered_techniques = []
+        for tech_info in status_data["techniques"]:
+            if tech_info["score"] < _COVERED_MIN_SCORE:
+                # Count available templates
+                template_count = (
+                    db.query(TestTemplate)
+                    .filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"])
+                    .count()
+                )
+
+                # Count threat actors using this technique
+                technique = (
+                    db.query(Technique)
+                    .filter(Technique.mitre_id == tech_info["mitre_id"])
+                    .first()
+                )
+                actor_count = 0
+                if technique:
+                    actor_count = (
+                        db.query(ThreatActorTechnique)
+                        .filter(ThreatActorTechnique.technique_id == technique.id)
+                        .count()
+                    )
+
+                uncovered_techniques.append({
+                    **tech_info,
+                    "templates_available": template_count,
+                    "threat_actors_using": actor_count,
+                })
+
+        if uncovered_techniques:
+            gaps.append({
+                "control_id": status_data["control_id"],
+                "title": status_data["title"],
+                "category": status_data["category"],
+                "status": status_data["status"],
+                "score": status_data["score"],
+                "uncovered_techniques": uncovered_techniques,
+            })
+
+    return {
+        "framework": {"id": str(framework.id), "name": framework.name},
+        "total_gaps": len(gaps),
+        "gaps": gaps,
+    }