refactor(compliance): extract business logic to compliance_service, use domain exceptions instead of HTTPException

This commit is contained in:
2026-02-19 17:06:32 +01:00
parent 25fddad17c
commit d305db8794
2 changed files with 344 additions and 289 deletions

View File

@@ -0,0 +1,327 @@
"""Compliance data service.
Extracts query and aggregation logic from the compliance router so
that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations
import csv
import io
from typing import Any
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.compliance import (
ComplianceFramework,
ComplianceControl,
ComplianceControlMapping,
)
from app.models.technique import Technique
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActorTechnique
from app.services.scoring_service import calculate_technique_score
# ── Helpers ───────────────────────────────────────────────────────────
def _classify_control(technique_scores: list[float]) -> str:
"""Classify a control status based on its technique scores."""
if not technique_scores:
return "not_evaluated"
all_above_70 = all(s >= 70 for s in technique_scores)
any_above_30 = any(s >= 30 for s in technique_scores)
all_below_30 = all(s < 30 for s in technique_scores)
all_zero = all(s == 0 for s in technique_scores)
if all_zero:
return "not_evaluated"
if all_above_70:
return "covered"
if all_below_30:
return "not_covered"
if any_above_30:
return "partially_covered"
return "not_covered"
def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, Any]:
    """Build the coverage summary for a single compliance control.

    Looks up the techniques mapped to ``control``, scores each of them with
    ``calculate_technique_score`` and derives the control status from those
    scores via ``_classify_control``.

    Args:
        control: The control to evaluate.
        db: Session used for the mapping and technique queries.

    Returns:
        A dict with the keys ``control_id``, ``title``, ``category``,
        ``status``, ``score`` (mean technique score rounded to one decimal,
        ``0`` when nothing was scored), ``techniques_count`` (number of
        technique rows found), ``techniques_covered`` (techniques scoring at
        least 50) and ``techniques`` (per-technique details ordered by
        ascending score).
    """
    mappings = (
        db.query(ComplianceControlMapping)
        .filter(ComplianceControlMapping.compliance_control_id == control.id)
        .all()
    )

    # Without mappings no technique query is issued; the empty list flows
    # through the code below and produces the "not_evaluated" / zero result.
    techniques = []
    if mappings:
        mapped_ids = [mapping.technique_id for mapping in mappings]
        techniques = db.query(Technique).filter(Technique.id.in_(mapped_ids)).all()

    scores = []
    details = []
    for technique in techniques:
        technique_score = calculate_technique_score(technique, db)["total_score"]
        scores.append(technique_score)
        details.append({
            "mitre_id": technique.mitre_id,
            "name": technique.name,
            "score": technique_score,
            "status": (
                technique.status_global.value
                if technique.status_global
                else "not_evaluated"
            ),
        })

    # NOTE(review): a technique counts as covered from a score of 50 here,
    # while _classify_control uses 70 for "covered" -- confirm this is intended.
    covered_total = sum(1 for value in scores if value >= 50)
    mean_score = round(sum(scores) / len(scores), 1) if scores else 0

    return {
        "control_id": control.control_id,
        "title": control.title,
        "category": control.category,
        "status": _classify_control(scores),
        "score": mean_score,
        "techniques_count": len(techniques),
        "techniques_covered": covered_total,
        # Lowest scores first so the weakest techniques are listed on top.
        "techniques": sorted(details, key=lambda entry: entry["score"]),
    }
# ── Public service functions ───────────────────────────────────────────
def list_frameworks(db: Session) -> list[dict[str, Any]]:
    """Return a summary of every active compliance framework.

    Args:
        db: Session used to query the frameworks and count their controls.

    Returns:
        One dict per active framework with the keys ``id`` (stringified),
        ``name``, ``version``, ``description``, ``url``, ``is_active`` and
        ``controls_count``.
    """
    active_frameworks = (
        db.query(ComplianceFramework)
        # "== True" builds a SQL expression here; it is not a Python truth test.
        .filter(ComplianceFramework.is_active == True)  # noqa: E712
        .all()
    )

    summaries = []
    for framework in active_frameworks:
        controls_of_framework = db.query(ComplianceControl).filter(
            ComplianceControl.framework_id == framework.id
        )
        n_controls = controls_of_framework.count()
        summaries.append(
            dict(
                id=str(framework.id),
                name=framework.name,
                version=framework.version,
                description=framework.description,
                url=framework.url,
                is_active=framework.is_active,
                controls_count=n_controls,
            )
        )
    return summaries
def get_framework(db: Session, framework_id: str) -> ComplianceFramework | None:
    """Look up a single compliance framework by its id.

    Args:
        db: Session used for the lookup.
        framework_id: Value compared against ``ComplianceFramework.id``.

    Returns:
        The first matching framework, or ``None`` when no row matches.
    """
    matching = db.query(ComplianceFramework).filter(
        ComplianceFramework.id == framework_id
    )
    return matching.first()
def get_framework_status(db: Session, framework_id: str) -> dict[str, Any]:
    """Report the compliance status of every control in a framework.

    Args:
        db: Session used for all queries.
        framework_id: Id of the framework to report on.

    Returns:
        A dict with ``framework`` (``id`` and ``name``), ``summary`` (total
        number of controls, one counter per status and
        ``compliance_percentage``) and ``controls`` (one status dict per
        control, ordered by ``control_id``).

    Raises:
        EntityNotFoundError: If the framework does not exist.
    """
    framework = get_framework(db, framework_id)
    if not framework:
        raise EntityNotFoundError("Framework", framework_id)

    controls_query = db.query(ComplianceControl).filter(
        ComplianceControl.framework_id == framework.id
    )
    controls = controls_query.order_by(ComplianceControl.control_id).all()

    control_statuses = [_get_control_status(control, db) for control in controls]

    summary = {"total_controls": len(controls)}
    summary.update(
        dict.fromkeys(
            ("covered", "partially_covered", "not_covered", "not_evaluated"),
            0,
        )
    )
    for entry in control_statuses:
        label = entry["status"]
        if label in summary:
            summary[label] += 1

    # A partially covered control counts as half a covered one:
    # (covered + partially_covered * 0.5) / total * 100
    total = summary["total_controls"]
    weighted = summary["covered"] + summary["partially_covered"] * 0.5
    summary["compliance_percentage"] = (
        round(weighted / total * 100, 1) if total > 0 else 0
    )

    return {
        "framework": {"id": str(framework.id), "name": framework.name},
        "summary": summary,
        "controls": control_statuses,
    }
def build_framework_report_csv(
    db: Session,
    framework_id: str,
) -> tuple[bytes, str]:
    """Render a framework's per-control compliance report as CSV.

    The CSV has one header row followed by one row per control, ordered by
    ``control_id``. The ``technique_ids`` column holds the MITRE ids of the
    control's techniques joined with commas.

    Args:
        db: Session used for all queries.
        framework_id: Id of the framework to report on.

    Returns:
        ``(csv_bytes, filename)`` where ``csv_bytes`` is the UTF-8 encoded
        CSV and ``filename`` is ``compliance_<name>.csv`` with the spaces of
        the framework name replaced by underscores.

    Raises:
        EntityNotFoundError: If the framework does not exist.
    """
    framework = get_framework(db, framework_id)
    if not framework:
        raise EntityNotFoundError("Framework", framework_id)

    ordered_controls = (
        db.query(ComplianceControl)
        .filter(ComplianceControl.framework_id == framework.id)
        .order_by(ComplianceControl.control_id)
        .all()
    )

    def report_rows():
        """Yield one CSV row per control, evaluating controls lazily."""
        for control in ordered_controls:
            info = _get_control_status(control, db)
            yield [
                info["control_id"],
                info["title"],
                info["category"] or "",
                info["status"],
                info["score"],
                info["techniques_count"],
                info["techniques_covered"],
                ",".join(detail["mitre_id"] for detail in info["techniques"]),
            ]

    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow([
        "control_id",
        "title",
        "category",
        "status",
        "score",
        "techniques_total",
        "techniques_covered",
        "technique_ids",
    ])
    writer.writerows(report_rows())

    name_part = framework.name.replace(" ", "_")
    # getvalue() returns the whole buffer regardless of the stream position,
    # so no rewind is needed before reading it back.
    csv_text = buffer.getvalue()
    return csv_text.encode("utf-8"), f"compliance_{name_part}.csv"
def _count_gap_resources(db: Session, mitre_id: str) -> tuple[int, int]:
    """Count the test templates and threat-actor links for one MITRE id.

    Args:
        db: Session used for the count queries.
        mitre_id: MITRE technique id to look up.

    Returns:
        ``(templates_available, threat_actors_using)``. The second value is
        ``0`` when no ``Technique`` row has this MITRE id.
    """
    template_count = (
        db.query(TestTemplate)
        .filter(TestTemplate.mitre_technique_id == mitre_id)
        .count()
    )

    # ThreatActorTechnique is keyed by the technique's row id, so the
    # technique has to be resolved from its MITRE id first.
    technique = (
        db.query(Technique)
        .filter(Technique.mitre_id == mitre_id)
        .first()
    )
    actor_count = 0
    if technique:
        actor_count = (
            db.query(ThreatActorTechnique)
            .filter(ThreatActorTechnique.technique_id == technique.id)
            .count()
        )
    return template_count, actor_count


def get_framework_gaps(db: Session, framework_id: str) -> dict[str, Any]:
    """List the controls of a framework whose techniques lack coverage.

    A control is reported when its status is ``"not_covered"`` or
    ``"partially_covered"`` and at least one of its techniques scores below
    70. Each such technique is returned with the number of available test
    templates and the number of threat-actor links for it.

    Args:
        db: Session used for all queries.
        framework_id: Id of the framework to analyse.

    Returns:
        A dict with ``framework`` (``id`` and ``name``), ``total_gaps`` and
        ``gaps`` (one entry per reported control, ordered by ``control_id``,
        each carrying its ``uncovered_techniques``).

    Raises:
        EntityNotFoundError: If the framework does not exist.
    """
    framework = get_framework(db, framework_id)
    if not framework:
        raise EntityNotFoundError("Framework", framework_id)

    controls = (
        db.query(ComplianceControl)
        .filter(ComplianceControl.framework_id == framework.id)
        .order_by(ComplianceControl.control_id)
        .all()
    )

    # The same technique can be mapped to several controls; remember its
    # counts so the count queries run once per MITRE id within this call.
    counts_by_mitre_id: dict[str, tuple[int, int]] = {}

    gaps = []
    for control in controls:
        status_data = _get_control_status(control, db)
        if status_data["status"] not in ("not_covered", "partially_covered"):
            continue

        uncovered_techniques = []
        for tech_info in status_data["techniques"]:
            if tech_info["score"] < 70:
                mitre_id = tech_info["mitre_id"]
                if mitre_id not in counts_by_mitre_id:
                    counts_by_mitre_id[mitre_id] = _count_gap_resources(db, mitre_id)
                template_count, actor_count = counts_by_mitre_id[mitre_id]
                uncovered_techniques.append({
                    **tech_info,
                    "templates_available": template_count,
                    "threat_actors_using": actor_count,
                })

        if not uncovered_techniques:
            continue

        gaps.append({
            "control_id": status_data["control_id"],
            "title": status_data["title"],
            "category": status_data["category"],
            "status": status_data["status"],
            "score": status_data["score"],
            "uncovered_techniques": uncovered_techniques,
        })

    return {
        "framework": {"id": str(framework.id), "name": framework.name},
        "total_gaps": len(gaps),
        "gaps": gaps,
    }