refactor(compliance): extract business logic to compliance_service, use domain exceptions instead of HTTPException

This commit is contained in:
2026-02-19 17:06:32 +01:00
parent 25fddad17c
commit d305db8794
2 changed files with 344 additions and 289 deletions

View File

@@ -1,29 +1,24 @@
"""Compliance endpoints — framework status, reports, and gap analysis. """Compliance endpoints — framework status, reports, and gap analysis.
Thin HTTP adapter: delegates all data logic to compliance_service.
Provides compliance posture assessment by mapping MITRE ATT&CK technique Provides compliance posture assessment by mapping MITRE ATT&CK technique
coverage to compliance framework controls. coverage to compliance framework controls.
""" """
import csv from fastapi import APIRouter, Depends
import io
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_role from app.dependencies.auth import get_current_user, require_role
from app.models.user import User from app.models.user import User
from app.models.compliance import ( from app.services.compliance_service import (
ComplianceFramework, list_frameworks,
ComplianceControl, get_framework_status,
ComplianceControlMapping, build_framework_report_csv,
get_framework_gaps,
) )
from app.models.technique import Technique
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActorTechnique
from app.services.scoring_service import calculate_technique_score
from app.services.compliance_import_service import ( from app.services.compliance_import_service import (
import_nist_800_53_mappings, import_nist_800_53_mappings,
import_cis_controls_v8_mappings, import_cis_controls_v8_mappings,
@@ -32,126 +27,16 @@ from app.services.compliance_import_service import (
router = APIRouter(prefix="/compliance", tags=["compliance"]) router = APIRouter(prefix="/compliance", tags=["compliance"])
# ── Helpers ───────────────────────────────────────────────────────────
def _classify_control(technique_scores: list[float]) -> str:
"""Classify a control status based on its technique scores."""
if not technique_scores:
return "not_evaluated"
all_above_70 = all(s >= 70 for s in technique_scores)
any_above_30 = any(s >= 30 for s in technique_scores)
all_below_30 = all(s < 30 for s in technique_scores)
all_zero = all(s == 0 for s in technique_scores)
if all_zero:
return "not_evaluated"
if all_above_70:
return "covered"
if all_below_30:
return "not_covered"
if any_above_30:
return "partially_covered"
return "not_covered"
def _get_control_status(control: ComplianceControl, db: Session) -> dict:
"""Compute the status and score for a single control."""
mappings = (
db.query(ComplianceControlMapping)
.filter(ComplianceControlMapping.compliance_control_id == control.id)
.all()
)
if not mappings:
return {
"control_id": control.control_id,
"title": control.title,
"category": control.category,
"status": "not_evaluated",
"score": 0,
"techniques_count": 0,
"techniques_covered": 0,
"techniques": [],
}
technique_ids = [m.technique_id for m in mappings]
techniques = (
db.query(Technique)
.filter(Technique.id.in_(technique_ids))
.all()
)
tech_details = []
scores = []
covered_count = 0
for tech in techniques:
result = calculate_technique_score(tech, db)
score = result["total_score"]
scores.append(score)
if score >= 50:
covered_count += 1
tech_details.append({
"mitre_id": tech.mitre_id,
"name": tech.name,
"score": score,
"status": tech.status_global.value if tech.status_global else "not_evaluated",
})
# Sort techniques by score ascending (worst first for priority)
tech_details.sort(key=lambda t: t["score"])
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
status = _classify_control(scores)
return {
"control_id": control.control_id,
"title": control.title,
"category": control.category,
"status": status,
"score": avg_score,
"techniques_count": len(techniques),
"techniques_covered": covered_count,
"techniques": tech_details,
}
# ── GET /compliance/frameworks ──────────────────────────────────────── # ── GET /compliance/frameworks ────────────────────────────────────────
@router.get("/frameworks") @router.get("/frameworks")
def list_frameworks( def list_frameworks_endpoint(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""List all available compliance frameworks.""" """List all available compliance frameworks."""
frameworks = ( return list_frameworks(db)
db.query(ComplianceFramework)
.filter(ComplianceFramework.is_active == True)
.all()
)
result = []
for fw in frameworks:
control_count = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == fw.id)
.count()
)
result.append({
"id": str(fw.id),
"name": fw.name,
"version": fw.version,
"description": fw.description,
"url": fw.url,
"is_active": fw.is_active,
"controls_count": control_count,
})
return result
# ── GET /compliance/frameworks/{id}/status ──────────────────────────── # ── GET /compliance/frameworks/{id}/status ────────────────────────────
@@ -164,55 +49,7 @@ def framework_status(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Get compliance status for each control in a framework.""" """Get compliance status for each control in a framework."""
framework = ( return get_framework_status(db, framework_id)
db.query(ComplianceFramework)
.filter(ComplianceFramework.id == framework_id)
.first()
)
if not framework:
raise HTTPException(status_code=404, detail="Framework not found")
controls = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == framework.id)
.order_by(ComplianceControl.control_id)
.all()
)
control_statuses = []
summary = {
"total_controls": len(controls),
"covered": 0,
"partially_covered": 0,
"not_covered": 0,
"not_evaluated": 0,
}
for control in controls:
status_data = _get_control_status(control, db)
control_statuses.append(status_data)
status = status_data["status"]
if status in summary:
summary[status] += 1
# Compliance percentage: (covered + partially_covered*0.5) / total * 100
total = summary["total_controls"]
if total > 0:
compliance_pct = round(
(summary["covered"] + summary["partially_covered"] * 0.5) / total * 100,
1,
)
else:
compliance_pct = 0
summary["compliance_percentage"] = compliance_pct
return {
"framework": {"id": str(framework.id), "name": framework.name},
"summary": summary,
"controls": control_statuses,
}
# ── GET /compliance/frameworks/{id}/report ──────────────────────────── # ── GET /compliance/frameworks/{id}/report ────────────────────────────
@@ -225,7 +62,7 @@ def framework_report(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Get the full compliance report (same as status but marked as report).""" """Get the full compliance report (same as status but marked as report)."""
return framework_status(framework_id, db=db, current_user=current_user) return get_framework_status(db, framework_id)
# ── GET /compliance/frameworks/{id}/report/csv ──────────────────────── # ── GET /compliance/frameworks/{id}/report/csv ────────────────────────
@@ -238,53 +75,9 @@ def framework_report_csv(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Export compliance report as CSV.""" """Export compliance report as CSV."""
framework = ( csv_bytes, filename = build_framework_report_csv(db, framework_id)
db.query(ComplianceFramework)
.filter(ComplianceFramework.id == framework_id)
.first()
)
if not framework:
raise HTTPException(status_code=404, detail="Framework not found")
controls = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == framework.id)
.order_by(ComplianceControl.control_id)
.all()
)
output = io.StringIO()
writer = csv.writer(output)
writer.writerow([
"control_id",
"title",
"category",
"status",
"score",
"techniques_total",
"techniques_covered",
"technique_ids",
])
for control in controls:
status_data = _get_control_status(control, db)
technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
writer.writerow([
status_data["control_id"],
status_data["title"],
status_data["category"] or "",
status_data["status"],
status_data["score"],
status_data["techniques_count"],
status_data["techniques_covered"],
technique_ids,
])
output.seek(0)
filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
return StreamingResponse( return StreamingResponse(
io.BytesIO(output.getvalue().encode("utf-8")), iter([csv_bytes]),
media_type="text/csv", media_type="text/csv",
headers={ headers={
"Content-Disposition": f"attachment; filename={filename}", "Content-Disposition": f"attachment; filename={filename}",
@@ -302,75 +95,10 @@ def framework_gaps(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Get controls with techniques that are not adequately covered.""" """Get controls with techniques that are not adequately covered."""
framework = ( return get_framework_gaps(db, framework_id)
db.query(ComplianceFramework)
.filter(ComplianceFramework.id == framework_id)
.first()
)
if not framework:
raise HTTPException(status_code=404, detail="Framework not found")
controls = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == framework.id)
.order_by(ComplianceControl.control_id)
.all()
)
gaps = []
for control in controls:
status_data = _get_control_status(control, db)
if status_data["status"] in ("not_covered", "partially_covered"):
# Find uncovered techniques
uncovered_techniques = []
for tech_info in status_data["techniques"]:
if tech_info["score"] < 70:
# Count available templates
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"])
.count()
)
# Count threat actors using this technique
technique = (
db.query(Technique)
.filter(Technique.mitre_id == tech_info["mitre_id"])
.first()
)
actor_count = 0
if technique:
actor_count = (
db.query(ThreatActorTechnique)
.filter(ThreatActorTechnique.technique_id == technique.id)
.count()
)
uncovered_techniques.append({
**tech_info,
"templates_available": template_count,
"threat_actors_using": actor_count,
})
if uncovered_techniques:
gaps.append({
"control_id": status_data["control_id"],
"title": status_data["title"],
"category": status_data["category"],
"status": status_data["status"],
"score": status_data["score"],
"uncovered_techniques": uncovered_techniques,
})
return {
"framework": {"id": str(framework.id), "name": framework.name},
"total_gaps": len(gaps),
"gaps": gaps,
}
# ── POST /compliance/import/nist-800-53 ────────────────────────────── # ── POST /compliance/import/... ────────────────────────────────────────
@router.post("/import/nist-800-53") @router.post("/import/nist-800-53")

View File

@@ -0,0 +1,327 @@
"""Compliance data service.
Extracts query and aggregation logic from the compliance router so
that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations
import csv
import io
from typing import Any
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.compliance import (
ComplianceFramework,
ComplianceControl,
ComplianceControlMapping,
)
from app.models.technique import Technique
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActorTechnique
from app.services.scoring_service import calculate_technique_score
# ── Helpers ───────────────────────────────────────────────────────────
def _classify_control(technique_scores: list[float]) -> str:
"""Classify a control status based on its technique scores."""
if not technique_scores:
return "not_evaluated"
all_above_70 = all(s >= 70 for s in technique_scores)
any_above_30 = any(s >= 30 for s in technique_scores)
all_below_30 = all(s < 30 for s in technique_scores)
all_zero = all(s == 0 for s in technique_scores)
if all_zero:
return "not_evaluated"
if all_above_70:
return "covered"
if all_below_30:
return "not_covered"
if any_above_30:
return "partially_covered"
return "not_covered"
def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, Any]:
"""Compute the status and score for a single control."""
mappings = (
db.query(ComplianceControlMapping)
.filter(ComplianceControlMapping.compliance_control_id == control.id)
.all()
)
if not mappings:
return {
"control_id": control.control_id,
"title": control.title,
"category": control.category,
"status": "not_evaluated",
"score": 0,
"techniques_count": 0,
"techniques_covered": 0,
"techniques": [],
}
technique_ids = [m.technique_id for m in mappings]
techniques = (
db.query(Technique)
.filter(Technique.id.in_(technique_ids))
.all()
)
tech_details = []
scores = []
covered_count = 0
for tech in techniques:
result = calculate_technique_score(tech, db)
score = result["total_score"]
scores.append(score)
if score >= 50:
covered_count += 1
tech_details.append({
"mitre_id": tech.mitre_id,
"name": tech.name,
"score": score,
"status": tech.status_global.value if tech.status_global else "not_evaluated",
})
# Sort techniques by score ascending (worst first for priority)
tech_details.sort(key=lambda t: t["score"])
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
status = _classify_control(scores)
return {
"control_id": control.control_id,
"title": control.title,
"category": control.category,
"status": status,
"score": avg_score,
"techniques_count": len(techniques),
"techniques_covered": covered_count,
"techniques": tech_details,
}
# ── Public service functions ───────────────────────────────────────────
def list_frameworks(db: Session) -> list[dict[str, Any]]:
"""List all available compliance frameworks with control counts."""
frameworks = (
db.query(ComplianceFramework)
.filter(ComplianceFramework.is_active == True)
.all()
)
result = []
for fw in frameworks:
control_count = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == fw.id)
.count()
)
result.append({
"id": str(fw.id),
"name": fw.name,
"version": fw.version,
"description": fw.description,
"url": fw.url,
"is_active": fw.is_active,
"controls_count": control_count,
})
return result
def get_framework(db: Session, framework_id: str) -> ComplianceFramework | None:
"""Get a framework by ID, or None if not found."""
return (
db.query(ComplianceFramework)
.filter(ComplianceFramework.id == framework_id)
.first()
)
def get_framework_status(db: Session, framework_id: str) -> dict[str, Any]:
"""Get compliance status for each control in a framework.
Raises EntityNotFoundError if the framework does not exist.
"""
framework = get_framework(db, framework_id)
if not framework:
raise EntityNotFoundError("Framework", framework_id)
controls = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == framework.id)
.order_by(ComplianceControl.control_id)
.all()
)
control_statuses = []
summary = {
"total_controls": len(controls),
"covered": 0,
"partially_covered": 0,
"not_covered": 0,
"not_evaluated": 0,
}
for control in controls:
status_data = _get_control_status(control, db)
control_statuses.append(status_data)
status = status_data["status"]
if status in summary:
summary[status] += 1
# Compliance percentage: (covered + partially_covered*0.5) / total * 100
total = summary["total_controls"]
if total > 0:
compliance_pct = round(
(summary["covered"] + summary["partially_covered"] * 0.5) / total * 100,
1,
)
else:
compliance_pct = 0
summary["compliance_percentage"] = compliance_pct
return {
"framework": {"id": str(framework.id), "name": framework.name},
"summary": summary,
"controls": control_statuses,
}
def build_framework_report_csv(
db: Session,
framework_id: str,
) -> tuple[bytes, str]:
"""Build the compliance report CSV content and filename.
Returns (csv_bytes, filename).
Raises EntityNotFoundError if the framework does not exist.
"""
framework = get_framework(db, framework_id)
if not framework:
raise EntityNotFoundError("Framework", framework_id)
controls = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == framework.id)
.order_by(ComplianceControl.control_id)
.all()
)
output = io.StringIO()
writer = csv.writer(output)
writer.writerow([
"control_id",
"title",
"category",
"status",
"score",
"techniques_total",
"techniques_covered",
"technique_ids",
])
for control in controls:
status_data = _get_control_status(control, db)
technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
writer.writerow([
status_data["control_id"],
status_data["title"],
status_data["category"] or "",
status_data["status"],
status_data["score"],
status_data["techniques_count"],
status_data["techniques_covered"],
technique_ids,
])
output.seek(0)
filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
return output.getvalue().encode("utf-8"), filename
def get_framework_gaps(db: Session, framework_id: str) -> dict[str, Any]:
"""Get controls with techniques that are not adequately covered.
Raises EntityNotFoundError if the framework does not exist.
"""
framework = get_framework(db, framework_id)
if not framework:
raise EntityNotFoundError("Framework", framework_id)
controls = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == framework.id)
.order_by(ComplianceControl.control_id)
.all()
)
gaps = []
for control in controls:
status_data = _get_control_status(control, db)
if status_data["status"] in ("not_covered", "partially_covered"):
# Find uncovered techniques
uncovered_techniques = []
for tech_info in status_data["techniques"]:
if tech_info["score"] < 70:
# Count available templates
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"])
.count()
)
# Count threat actors using this technique
technique = (
db.query(Technique)
.filter(Technique.mitre_id == tech_info["mitre_id"])
.first()
)
actor_count = 0
if technique:
actor_count = (
db.query(ThreatActorTechnique)
.filter(ThreatActorTechnique.technique_id == technique.id)
.count()
)
uncovered_techniques.append({
**tech_info,
"templates_available": template_count,
"threat_actors_using": actor_count,
})
if uncovered_techniques:
gaps.append({
"control_id": status_data["control_id"],
"title": status_data["title"],
"category": status_data["category"],
"status": status_data["status"],
"score": status_data["score"],
"uncovered_techniques": uncovered_techniques,
})
return {
"framework": {"id": str(framework.id), "name": framework.name},
"total_gaps": len(gaps),
"gaps": gaps,
}