Files
Aegis/backend/app/routers/compliance.py
Kitos c2e9c687f4 fix: D3FEND expandable cards, System page cleanup, and multi-source improvements
- Make D3FEND defense cards clickable with expandable details and external link
- Fix D3FEND URLs to use PascalCase technique names matching the ontology
- Remove duplicate Import Atomic Red Team from System page (use Data Sources)
- Add bulk Activate All / Deactivate All buttons with confirmation modal
- Fix template admin list to show both active and inactive templates
- Add PATCH /test-templates/bulk-activate backend endpoint
- Auto-seed data sources on container startup via entrypoint.sh
- Fix SigmaHQ, CALDERA, GTFOBins import issues
- Register D3FEND sync handler in data sources router
- Add CIS Controls v8 compliance framework import
- Expand Test Catalog source filters (CALDERA, LOLBAS, GTFOBins)
- Campaign Generate from Threat Actor now opens actor selector modal
- Add coverage snapshot creation button to Comparison page
- Update README with accurate data source and feature documentation
2026-02-10 13:22:23 +01:00

394 lines
12 KiB
Python

"""Compliance endpoints — framework status, reports, and gap analysis.
Provides compliance posture assessment by mapping MITRE ATT&CK technique
coverage to compliance framework controls.
"""
import csv
import io
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session, joinedload
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.models.user import User
from app.models.compliance import (
ComplianceFramework,
ComplianceControl,
ComplianceControlMapping,
)
from app.models.technique import Technique
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActorTechnique
from app.services.scoring_service import calculate_technique_score
from app.services.compliance_import_service import (
import_nist_800_53_mappings,
import_cis_controls_v8_mappings,
)
# Router for this module: every route path declared below is mounted under the
# "/compliance" prefix and tagged "compliance".
router = APIRouter(prefix="/compliance", tags=["compliance"])
# ── Helpers ───────────────────────────────────────────────────────────
def _classify_control(technique_scores: list[float]) -> str:
"""Classify a control status based on its technique scores."""
if not technique_scores:
return "not_evaluated"
all_above_70 = all(s >= 70 for s in technique_scores)
any_above_30 = any(s >= 30 for s in technique_scores)
all_below_30 = all(s < 30 for s in technique_scores)
all_zero = all(s == 0 for s in technique_scores)
if all_zero:
return "not_evaluated"
if all_above_70:
return "covered"
if all_below_30:
return "not_covered"
if any_above_30:
return "partially_covered"
return "not_covered"
def _get_control_status(control: ComplianceControl, db: Session) -> dict:
    """Compute the coverage status and score for a single compliance control.

    Loads the techniques mapped to ``control`` through
    ``ComplianceControlMapping``, scores each one with
    ``calculate_technique_score`` and aggregates the results.

    Args:
        control: The control to evaluate.
        db: Active database session.

    Returns:
        A dict with ``control_id``, ``title``, ``category``, ``status`` (from
        ``_classify_control``), ``score`` (mean technique score rounded to one
        decimal, or 0 when nothing was scored), ``techniques_count``,
        ``techniques_covered`` (techniques scoring >= 50) and ``techniques``
        (per-technique details ordered lowest score first).
    """
    technique_ids = [
        technique_id
        for (technique_id,) in (
            db.query(ComplianceControlMapping.technique_id)
            .filter(ComplianceControlMapping.compliance_control_id == control.id)
            .all()
        )
    ]
    # Unmapped controls skip the Technique query entirely; with no techniques
    # they yield status "not_evaluated", score 0 and zero counts.
    techniques = (
        db.query(Technique).filter(Technique.id.in_(technique_ids)).all()
        if technique_ids
        else []
    )

    tech_details = []
    for tech in techniques:
        score = calculate_technique_score(tech, db)["total_score"]
        tech_details.append({
            "mitre_id": tech.mitre_id,
            "name": tech.name,
            "score": score,
            "status": tech.status_global.value if tech.status_global else "not_evaluated",
        })

    scores = [detail["score"] for detail in tech_details]
    # A technique counts towards "techniques_covered" at 50+, independently of
    # the 70/30 thresholds _classify_control uses for the control as a whole.
    covered_count = sum(1 for s in scores if s >= 50)
    # Worst first for remediation priority; ties broken by MITRE ID so the
    # order (and the CSV export built from it) does not depend on DB row order.
    tech_details.sort(key=lambda t: (t["score"], t["mitre_id"] or ""))
    avg_score = round(sum(scores) / len(scores), 1) if scores else 0

    return {
        "control_id": control.control_id,
        "title": control.title,
        "category": control.category,
        "status": _classify_control(scores),
        "score": avg_score,
        "techniques_count": len(techniques),
        "techniques_covered": covered_count,
        "techniques": tech_details,
    }
# ── GET /compliance/frameworks ────────────────────────────────────────
@router.get("/frameworks")
def list_frameworks(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return every active compliance framework with its number of controls.

    Each entry holds the framework's stringified id, name, version,
    description, url, is_active flag and ``controls_count`` (the number of
    ComplianceControl rows whose framework_id matches).
    """
    active_frameworks = (
        db.query(ComplianceFramework)
        .filter(ComplianceFramework.is_active == True)  # noqa: E712 - SQL expression
        .all()
    )

    def count_controls(fw_id):
        return (
            db.query(ComplianceControl)
            .filter(ComplianceControl.framework_id == fw_id)
            .count()
        )

    return [
        {
            "id": str(fw.id),
            "name": fw.name,
            "version": fw.version,
            "description": fw.description,
            "url": fw.url,
            "is_active": fw.is_active,
            "controls_count": count_controls(fw.id),
        }
        for fw in active_frameworks
    ]
# ── GET /compliance/frameworks/{id}/status ────────────────────────────
@router.get("/frameworks/{framework_id}/status")
def framework_status(
    framework_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return per-control compliance status plus a summary for one framework.

    Controls are listed in ``control_id`` order. The summary counts controls
    per status and adds ``compliance_percentage``: fully covered controls
    weigh 1, partially covered ones 0.5, divided by the total number of
    controls (0 when the framework has none).

    Raises:
        HTTPException: 404 if no framework with ``framework_id`` exists.
    """
    fw = (
        db.query(ComplianceFramework)
        .filter(ComplianceFramework.id == framework_id)
        .first()
    )
    if not fw:
        raise HTTPException(status_code=404, detail="Framework not found")

    ordered_controls = (
        db.query(ComplianceControl)
        .filter(ComplianceControl.framework_id == fw.id)
        .order_by(ComplianceControl.control_id)
        .all()
    )
    control_statuses = [_get_control_status(c, db) for c in ordered_controls]

    summary = {
        "total_controls": len(ordered_controls),
        "covered": 0,
        "partially_covered": 0,
        "not_covered": 0,
        "not_evaluated": 0,
    }
    for entry in control_statuses:
        bucket = entry["status"]
        if bucket in summary:
            summary[bucket] += 1

    denominator = summary["total_controls"]
    weighted = summary["covered"] + summary["partially_covered"] * 0.5
    summary["compliance_percentage"] = (
        round(weighted / denominator * 100, 1) if denominator > 0 else 0
    )

    return {
        "framework": {"id": str(fw.id), "name": fw.name},
        "summary": summary,
        "controls": control_statuses,
    }
# ── GET /compliance/frameworks/{id}/report ────────────────────────────
@router.get("/frameworks/{framework_id}/report")
def framework_report(
    framework_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the full compliance report for a framework.

    Delegates to ``framework_status`` and returns its payload unchanged,
    including the 404 HTTPException it raises for an unknown framework.
    """
    report = framework_status(
        framework_id=framework_id,
        db=db,
        current_user=current_user,
    )
    return report
# ── GET /compliance/frameworks/{id}/report/csv ────────────────────────
def _attachment_disposition(filename: str) -> str:
    """Build an ``attachment`` Content-Disposition header value for *filename*.

    Names made only of URL-unreserved ASCII characters are sent as a quoted
    ``filename="..."`` parameter. Anything else (quotes, ``;``, non-ASCII
    characters, ...) is percent-encoded into the RFC 5987 ``filename*`` form,
    because an unquoted/raw value would corrupt the header or fail to encode
    as latin-1.
    """
    from urllib.parse import quote

    encoded = quote(filename, safe="")
    if encoded == filename:
        return f'attachment; filename="{filename}"'
    return f"attachment; filename*=utf-8''{encoded}"


@router.get("/frameworks/{framework_id}/report/csv")
def framework_report_csv(
    framework_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Export a framework's compliance report as a downloadable CSV file.

    Writes one row per control (ordered by ``control_id``) with its status,
    score, technique counts and a comma-joined list of mapped MITRE IDs.

    Raises:
        HTTPException: 404 if no framework with ``framework_id`` exists.
    """
    framework = (
        db.query(ComplianceFramework)
        .filter(ComplianceFramework.id == framework_id)
        .first()
    )
    if not framework:
        raise HTTPException(status_code=404, detail="Framework not found")
    controls = (
        db.query(ComplianceControl)
        .filter(ComplianceControl.framework_id == framework.id)
        .order_by(ComplianceControl.control_id)
        .all()
    )

    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        "control_id",
        "title",
        "category",
        "status",
        "score",
        "techniques_total",
        "techniques_covered",
        "technique_ids",
    ])
    for control in controls:
        status_data = _get_control_status(control, db)
        technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
        writer.writerow([
            status_data["control_id"],
            status_data["title"],
            status_data["category"] or "",
            status_data["status"],
            status_data["score"],
            status_data["techniques_count"],
            status_data["techniques_covered"],
            technique_ids,
        ])

    filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
    return StreamingResponse(
        io.BytesIO(output.getvalue().encode("utf-8")),
        media_type="text/csv",
        headers={
            "Content-Disposition": _attachment_disposition(filename),
        },
    )
# ── GET /compliance/frameworks/{id}/gaps ──────────────────────────────
def _technique_remediation_stats(mitre_id: str, db: Session) -> dict:
    """Count remediation resources for one MITRE technique.

    Returns a dict with ``templates_available`` (TestTemplate rows whose
    ``mitre_technique_id`` equals *mitre_id*) and ``threat_actors_using``
    (ThreatActorTechnique rows linked to the first Technique with that MITRE
    ID, or 0 when no such Technique exists).
    """
    template_count = (
        db.query(TestTemplate)
        .filter(TestTemplate.mitre_technique_id == mitre_id)
        .count()
    )
    technique = (
        db.query(Technique)
        .filter(Technique.mitre_id == mitre_id)
        .first()
    )
    actor_count = 0
    if technique:
        actor_count = (
            db.query(ThreatActorTechnique)
            .filter(ThreatActorTechnique.technique_id == technique.id)
            .count()
        )
    return {
        "templates_available": template_count,
        "threat_actors_using": actor_count,
    }


@router.get("/frameworks/{framework_id}/gaps")
def framework_gaps(
    framework_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List controls whose mapped techniques are not adequately covered.

    A control is reported when its status is ``not_covered`` or
    ``partially_covered`` and at least one mapped technique scores below 70.
    Every such technique is annotated with ``templates_available`` and
    ``threat_actors_using`` counts.

    Raises:
        HTTPException: 404 if no framework with ``framework_id`` exists.
    """
    framework = (
        db.query(ComplianceFramework)
        .filter(ComplianceFramework.id == framework_id)
        .first()
    )
    if not framework:
        raise HTTPException(status_code=404, detail="Framework not found")
    controls = (
        db.query(ComplianceControl)
        .filter(ComplianceControl.framework_id == framework.id)
        .order_by(ComplianceControl.control_id)
        .all()
    )

    # The same technique is usually mapped to several controls; memoise its
    # stats so each MITRE ID costs at most three queries per request rather
    # than three per (control, technique) pair.
    stats_by_mitre_id: dict = {}
    gaps = []
    for control in controls:
        status_data = _get_control_status(control, db)
        if status_data["status"] not in ("not_covered", "partially_covered"):
            continue
        uncovered_techniques = []
        for tech_info in status_data["techniques"]:
            if tech_info["score"] < 70:
                mitre_id = tech_info["mitre_id"]
                if mitre_id not in stats_by_mitre_id:
                    stats_by_mitre_id[mitre_id] = _technique_remediation_stats(mitre_id, db)
                uncovered_techniques.append({**tech_info, **stats_by_mitre_id[mitre_id]})
        if uncovered_techniques:
            gaps.append({
                "control_id": status_data["control_id"],
                "title": status_data["title"],
                "category": status_data["category"],
                "status": status_data["status"],
                "score": status_data["score"],
                "uncovered_techniques": uncovered_techniques,
            })

    return {
        "framework": {"id": str(framework.id), "name": framework.name},
        "total_gaps": len(gaps),
        "gaps": gaps,
    }
# ── POST /compliance/import/nist-800-53 ──────────────────────────────
@router.post("/import/nist-800-53")
def import_nist(
    db: Session = Depends(get_db),
    current_user: User = Depends(require_role("admin")),
):
    """Import NIST 800-53 Rev 5 mappings (admin only).

    Access is gated by the ``require_role("admin")`` dependency; the work is
    delegated to ``import_nist_800_53_mappings`` whose result is returned as-is.
    """
    return import_nist_800_53_mappings(db)
# ── POST /compliance/import/cis-controls-v8 ──────────────────────────
@router.post("/import/cis-controls-v8")
def import_cis(
    db: Session = Depends(get_db),
    current_user: User = Depends(require_role("admin")),
):
    """Import CIS Controls v8 mappings (admin only).

    Access is gated by the ``require_role("admin")`` dependency; the work is
    delegated to ``import_cis_controls_v8_mappings`` whose result is returned as-is.
    """
    return import_cis_controls_v8_mappings(db)