feat(phase-29): add compliance framework mapping, reports and UI (T-227 to T-229)

This commit is contained in:
2026-02-09 18:41:24 +01:00
parent 12f33307fd
commit 2ac8e7f4a5
12 changed files with 1516 additions and 0 deletions

View File

@@ -0,0 +1,380 @@
"""Compliance endpoints — framework status, reports, and gap analysis.
Provides compliance posture assessment by mapping MITRE ATT&CK technique
coverage to compliance framework controls.
"""
import csv
import io
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session, joinedload
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.models.user import User
from app.models.compliance import (
ComplianceFramework,
ComplianceControl,
ComplianceControlMapping,
)
from app.models.technique import Technique
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActorTechnique
from app.services.scoring_service import calculate_technique_score
from app.services.compliance_import_service import import_nist_800_53_mappings
router = APIRouter(prefix="/compliance", tags=["compliance"])
# ── Helpers ───────────────────────────────────────────────────────────
def _classify_control(technique_scores: list[float]) -> str:
"""Classify a control status based on its technique scores."""
if not technique_scores:
return "not_evaluated"
all_above_70 = all(s >= 70 for s in technique_scores)
any_above_30 = any(s >= 30 for s in technique_scores)
all_below_30 = all(s < 30 for s in technique_scores)
all_zero = all(s == 0 for s in technique_scores)
if all_zero:
return "not_evaluated"
if all_above_70:
return "covered"
if all_below_30:
return "not_covered"
if any_above_30:
return "partially_covered"
return "not_covered"
def _get_control_status(control: ComplianceControl, db: Session) -> dict:
"""Compute the status and score for a single control."""
mappings = (
db.query(ComplianceControlMapping)
.filter(ComplianceControlMapping.compliance_control_id == control.id)
.all()
)
if not mappings:
return {
"control_id": control.control_id,
"title": control.title,
"category": control.category,
"status": "not_evaluated",
"score": 0,
"techniques_count": 0,
"techniques_covered": 0,
"techniques": [],
}
technique_ids = [m.technique_id for m in mappings]
techniques = (
db.query(Technique)
.filter(Technique.id.in_(technique_ids))
.all()
)
tech_details = []
scores = []
covered_count = 0
for tech in techniques:
result = calculate_technique_score(tech, db)
score = result["total_score"]
scores.append(score)
if score >= 50:
covered_count += 1
tech_details.append({
"mitre_id": tech.mitre_id,
"name": tech.name,
"score": score,
"status": tech.status_global.value if tech.status_global else "not_evaluated",
})
# Sort techniques by score ascending (worst first for priority)
tech_details.sort(key=lambda t: t["score"])
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
status = _classify_control(scores)
return {
"control_id": control.control_id,
"title": control.title,
"category": control.category,
"status": status,
"score": avg_score,
"techniques_count": len(techniques),
"techniques_covered": covered_count,
"techniques": tech_details,
}
# ── GET /compliance/frameworks ────────────────────────────────────────
@router.get("/frameworks")
def list_frameworks(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""List all available compliance frameworks."""
frameworks = (
db.query(ComplianceFramework)
.filter(ComplianceFramework.is_active == True)
.all()
)
result = []
for fw in frameworks:
control_count = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == fw.id)
.count()
)
result.append({
"id": str(fw.id),
"name": fw.name,
"version": fw.version,
"description": fw.description,
"url": fw.url,
"is_active": fw.is_active,
"controls_count": control_count,
})
return result
# ── GET /compliance/frameworks/{id}/status ────────────────────────────
@router.get("/frameworks/{framework_id}/status")
def framework_status(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get compliance status for each control in a framework."""
framework = (
db.query(ComplianceFramework)
.filter(ComplianceFramework.id == framework_id)
.first()
)
if not framework:
raise HTTPException(status_code=404, detail="Framework not found")
controls = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == framework.id)
.order_by(ComplianceControl.control_id)
.all()
)
control_statuses = []
summary = {
"total_controls": len(controls),
"covered": 0,
"partially_covered": 0,
"not_covered": 0,
"not_evaluated": 0,
}
for control in controls:
status_data = _get_control_status(control, db)
control_statuses.append(status_data)
status = status_data["status"]
if status in summary:
summary[status] += 1
# Compliance percentage: (covered + partially_covered*0.5) / total * 100
total = summary["total_controls"]
if total > 0:
compliance_pct = round(
(summary["covered"] + summary["partially_covered"] * 0.5) / total * 100,
1,
)
else:
compliance_pct = 0
summary["compliance_percentage"] = compliance_pct
return {
"framework": {"id": str(framework.id), "name": framework.name},
"summary": summary,
"controls": control_statuses,
}
# ── GET /compliance/frameworks/{id}/report ────────────────────────────
@router.get("/frameworks/{framework_id}/report")
def framework_report(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get the full compliance report (same as status but marked as report)."""
return framework_status(framework_id, db=db, current_user=current_user)
# ── GET /compliance/frameworks/{id}/report/csv ────────────────────────
@router.get("/frameworks/{framework_id}/report/csv")
def framework_report_csv(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Export compliance report as CSV."""
framework = (
db.query(ComplianceFramework)
.filter(ComplianceFramework.id == framework_id)
.first()
)
if not framework:
raise HTTPException(status_code=404, detail="Framework not found")
controls = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == framework.id)
.order_by(ComplianceControl.control_id)
.all()
)
output = io.StringIO()
writer = csv.writer(output)
writer.writerow([
"control_id",
"title",
"category",
"status",
"score",
"techniques_total",
"techniques_covered",
"technique_ids",
])
for control in controls:
status_data = _get_control_status(control, db)
technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
writer.writerow([
status_data["control_id"],
status_data["title"],
status_data["category"] or "",
status_data["status"],
status_data["score"],
status_data["techniques_count"],
status_data["techniques_covered"],
technique_ids,
])
output.seek(0)
filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
return StreamingResponse(
io.BytesIO(output.getvalue().encode("utf-8")),
media_type="text/csv",
headers={
"Content-Disposition": f"attachment; filename={filename}",
},
)
# ── GET /compliance/frameworks/{id}/gaps ──────────────────────────────
@router.get("/frameworks/{framework_id}/gaps")
def framework_gaps(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get controls with techniques that are not adequately covered."""
framework = (
db.query(ComplianceFramework)
.filter(ComplianceFramework.id == framework_id)
.first()
)
if not framework:
raise HTTPException(status_code=404, detail="Framework not found")
controls = (
db.query(ComplianceControl)
.filter(ComplianceControl.framework_id == framework.id)
.order_by(ComplianceControl.control_id)
.all()
)
gaps = []
for control in controls:
status_data = _get_control_status(control, db)
if status_data["status"] in ("not_covered", "partially_covered"):
# Find uncovered techniques
uncovered_techniques = []
for tech_info in status_data["techniques"]:
if tech_info["score"] < 70:
# Count available templates
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"])
.count()
)
# Count threat actors using this technique
technique = (
db.query(Technique)
.filter(Technique.mitre_id == tech_info["mitre_id"])
.first()
)
actor_count = 0
if technique:
actor_count = (
db.query(ThreatActorTechnique)
.filter(ThreatActorTechnique.technique_id == technique.id)
.count()
)
uncovered_techniques.append({
**tech_info,
"templates_available": template_count,
"threat_actors_using": actor_count,
})
if uncovered_techniques:
gaps.append({
"control_id": status_data["control_id"],
"title": status_data["title"],
"category": status_data["category"],
"status": status_data["status"],
"score": status_data["score"],
"uncovered_techniques": uncovered_techniques,
})
return {
"framework": {"id": str(framework.id), "name": framework.name},
"total_gaps": len(gaps),
"gaps": gaps,
}
# ── POST /compliance/import/nist-800-53 ──────────────────────────────
@router.post("/import/nist-800-53")
def import_nist(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Import NIST 800-53 Rev 5 mappings (admin only)."""
result = import_nist_800_53_mappings(db)
return result