# Source: Aegis/backend/app/services/coverage_report_service.py
"""Coverage report data service.
Extracts query and aggregation logic from the reports router so
that the router remains a thin HTTP adapter. Fixes the N+1
technique/test-count pattern by using a single grouped query.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations

from collections import Counter
from datetime import date, datetime, timedelta, timezone

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models.technique import Technique
from app.models.test import Test
from app.utils import escape_like
def _technique_test_counts(
db: Session,
technique_ids: list,
) -> dict:
"""Return ``{technique_id: {state_str: count}}`` in a single query."""
if not technique_ids:
return {}
rows = (
db.query(Test.technique_id, Test.state, func.count(Test.id))
.filter(Test.technique_id.in_(technique_ids))
.group_by(Test.technique_id, Test.state)
.all()
)
result: dict = {}
for tid, state, count in rows:
result.setdefault(tid, {})[str(state)] = count
return result
# Values of ``Technique.status_global`` tallied by the coverage summary, in the
# order they appear in its "summary" dict.
_SUMMARY_STATUSES = (
    "validated",
    "partial",
    "not_covered",
    "in_progress",
    "not_evaluated",
)

# Test states that the CSV export aggregates into its "In Progress" column.
_CSV_IN_PROGRESS_STATES = frozenset(
    {"draft", "red_executing", "blue_evaluating", "in_review"}
)


def _select_techniques(
    db: Session,
    tactic: str | None,
    platform: str | None,
) -> tuple[list, dict]:
    """Load the techniques matching the report filters plus their test counts.

    The coverage summary and the CSV export both use this helper, so they
    apply identical filtering.

    Args:
        db: Active SQLAlchemy session.
        tactic: Optional filter. It is passed through ``escape_like`` and
            matched with ILIKE as ``%tactic%`` against ``Technique.tactic``.
        platform: Optional filter. A technique is kept only if one of its
            ``platforms`` equals this value, compared after ``.lower()``.

    Returns:
        A ``(techniques, counts_map)`` tuple. *techniques* is ordered by
        ``mitre_id``. *counts_map* is ``_technique_test_counts`` computed for
        exactly those techniques.
    """
    query = db.query(Technique)
    if tactic:
        query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
    techniques = query.order_by(Technique.mitre_id).all()
    if platform:
        # Platform matching happens in Python over ``t.platforms``. Filtering
        # before counting keeps the grouped test-count query limited to the
        # techniques that will actually be reported.
        wanted = platform.lower()
        techniques = [
            t
            for t in techniques
            if any(p.lower() == wanted for p in t.platforms or [])
        ]
    counts_map = _technique_test_counts(db, [t.id for t in techniques])
    return techniques, counts_map


def build_coverage_summary(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> dict:
    """Build the full coverage summary report as a dict.

    Args:
        db: Active SQLAlchemy session.
        tactic: Optional ILIKE substring filter on the technique tactic.
        platform: Optional platform filter (compared lower-cased).

    Returns:
        A dict with keys ``"generated_at"``, ``"summary"`` and
        ``"techniques"``. ``"summary"`` holds:

        - the technique total;
        - a count per ``status_global`` value listed in
          ``_SUMMARY_STATUSES``;
        - ``coverage_percentage``: the share of "validated" techniques,
          rounded to one decimal, or ``0`` when no technique matches.
    """
    techniques, counts_map = _select_techniques(db, tactic, platform)

    rows = []
    for t in techniques:
        counts = counts_map.get(t.id, {})
        rows.append({
            "mitre_id": t.mitre_id,
            "name": t.name,
            "tactic": t.tactic,
            "platforms": t.platforms,
            "status_global": t.status_global,
            "total_tests": sum(counts.values()),
            "tests_by_state": counts,
        })

    total = len(rows)
    status_totals = {
        status: sum(1 for r in rows if r["status_global"] == status)
        for status in _SUMMARY_STATUSES
    }
    validated = status_totals["validated"]
    return {
        # Naive UTC timestamp; same string the deprecated utcnow() produced.
        "generated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        "summary": {
            "total_techniques": total,
            **status_totals,
            "coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
        },
        "techniques": rows,
    }


def build_coverage_csv_rows(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> list[list]:
    """Build rows for a CSV coverage export (header + data).

    Uses the same filters as ``build_coverage_summary``.

    Args:
        db: Active SQLAlchemy session.
        tactic: Optional ILIKE substring filter on the technique tactic.
        platform: Optional platform filter (compared lower-cased).

    Returns:
        A list whose first element is the header row. Each following row
        describes one technique, ordered by ``mitre_id``.
    """
    techniques, counts_map = _select_techniques(db, tactic, platform)

    header = [
        "MITRE ID", "Name", "Tactic", "Platforms", "Status",
        "Total Tests", "Validated", "In Progress", "Not Covered",
    ]
    rows: list[list] = [header]
    for t in techniques:
        counts = counts_map.get(t.id, {})
        rows.append([
            t.mitre_id,
            t.name,
            t.tactic,
            ", ".join(t.platforms or []),
            t.status_global,
            sum(counts.values()),
            counts.get("validated", 0),
            sum(counts.get(s, 0) for s in _CSV_IN_PROGRESS_STATES),
            # NOTE(review): the "Not Covered" column is filled from the
            # "rejected" test count — confirm this mapping is intended.
            counts.get("rejected", 0),
        ])
    return rows
def _enum_label(value) -> str:
    """Return ``value.value`` when *value* has one (e.g. an Enum member), else ``str(value)``."""
    return value.value if hasattr(value, "value") else str(value)


def _parse_iso_datetime(raw: str) -> datetime | None:
    """Parse *raw* with ``datetime.fromisoformat``; return ``None`` if it is invalid."""
    try:
        return datetime.fromisoformat(raw)
    except ValueError:
        return None


def _created_at_upper_bound(raw: str):
    """Return the ``Test.created_at`` upper-bound clause for *raw*, or ``None``.

    A bare ISO date (e.g. ``"2024-01-31"``) includes the whole day:
    ``created_at`` must fall before the following midnight. A plain
    ``<= 2024-01-31`` would compare against 00:00 and drop every test
    created later that day.

    A full datetime keeps the ``created_at <= raw`` semantics.

    Unparseable input yields ``None`` and applies no filter.
    """
    try:
        day = date.fromisoformat(raw)
    except ValueError:
        # Not a bare date: try it as a full datetime instead.
        moment = _parse_iso_datetime(raw)
        return None if moment is None else Test.created_at <= moment
    try:
        next_midnight = datetime(day.year, day.month, day.day) + timedelta(days=1)
    except OverflowError:
        # The day is date.max. "Through the end of it" covers every
        # representable datetime, so no bound is needed.
        return None
    return Test.created_at < next_midnight


def build_test_results_report(
    db: Session,
    *,
    state: str | None = None,
    date_from: str | None = None,
    date_to: str | None = None,
) -> dict:
    """Build a test results report with optional filters.

    Args:
        db: Active SQLAlchemy session.
        state: Optional exact match on ``Test.state``.
        date_from: Optional ISO-8601 inclusive lower bound on ``created_at``.
        date_to: Optional ISO-8601 upper bound on ``created_at``. A bare
            date includes the whole of that day.

    Unparseable date filters are ignored (best-effort), but the raw values
    are still echoed verbatim under ``"filters"``.

    Returns:
        A dict with keys ``"generated_at"``, ``"filters"``, ``"summary"`` and
        ``"tests"``. Tests are ordered newest first and tallied by state and
        by detection result.
    """
    query = db.query(Test)
    if state:
        query = query.filter(Test.state == state)
    if date_from:
        start = _parse_iso_datetime(date_from)
        if start is not None:
            query = query.filter(Test.created_at >= start)
    if date_to:
        end_clause = _created_at_upper_bound(date_to)
        # Identity check on purpose: SQLAlchemy clauses have no truth value.
        if end_clause is not None:
            query = query.filter(end_clause)
    tests = query.order_by(Test.created_at.desc()).all()

    by_state: dict[str, int] = {}
    by_result: dict[str, int] = {}
    items = []
    for t in tests:
        state_label = _enum_label(t.state)
        result_label = _enum_label(t.detection_result) if t.detection_result else None
        by_state[state_label] = by_state.get(state_label, 0) + 1
        if result_label is not None:
            by_result[result_label] = by_result.get(result_label, 0) + 1
        items.append({
            "id": str(t.id),
            "name": t.name,
            "technique_id": str(t.technique_id),
            "state": state_label,
            "platform": t.platform,
            "attack_success": t.attack_success,
            "detection_result": result_label,
            "red_validation_status": t.red_validation_status,
            "blue_validation_status": t.blue_validation_status,
            "created_at": t.created_at.isoformat() if t.created_at else None,
        })

    return {
        # Naive UTC timestamp; same string the deprecated utcnow() produced.
        "generated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        "filters": {"state": state, "date_from": date_from, "date_to": date_to},
        "summary": {
            "total_tests": len(tests),
            "by_state": by_state,
            "by_detection_result": by_result,
        },
        "tests": items,
    }
def build_remediation_status_report(
    db: Session,
    *,
    status: str | None = None,
) -> dict:
    """Build a remediation status report.

    Only tests whose ``remediation_steps`` is not NULL are included.

    Args:
        db: Active SQLAlchemy session.
        status: Optional exact match on ``Test.remediation_status``.

    Returns:
        A dict with keys ``"generated_at"``, ``"summary"`` and ``"tests"``.
        Tests are ordered newest first. ``summary["by_status"]`` counts them
        per remediation status; a falsy status is tallied as ``"unset"``.
    """
    query = db.query(Test).filter(Test.remediation_steps.isnot(None))
    if status:
        query = query.filter(Test.remediation_status == status)
    tests = query.order_by(Test.created_at.desc()).all()

    by_status: dict[str, int] = dict(
        Counter(t.remediation_status or "unset" for t in tests)
    )

    return {
        # Naive UTC timestamp; same string the deprecated utcnow() produced.
        "generated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        "summary": {
            "total_with_remediation": len(tests),
            "by_status": by_status,
        },
        "tests": [
            {
                "id": str(t.id),
                "name": t.name,
                "technique_id": str(t.technique_id),
                "state": t.state.value if hasattr(t.state, "value") else str(t.state),
                "remediation_status": t.remediation_status,
                "remediation_steps": t.remediation_steps,
                "remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
            }
            for t in tests
        ],
    }