refactor(reports): extract query and aggregation logic to coverage_report_service, fix N+1 test-count pattern
This commit is contained in:
234
backend/app/services/coverage_report_service.py
Normal file
234
backend/app/services/coverage_report_service.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Coverage report data service.
|
||||
|
||||
Extracts query and aggregation logic from the reports router so
|
||||
that the router remains a thin HTTP adapter. Fixes the N+1
|
||||
technique/test-count pattern by using a single grouped query.
|
||||
|
||||
This module is framework-agnostic: no FastAPI imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
def _technique_test_counts(
|
||||
db: Session,
|
||||
technique_ids: list,
|
||||
) -> dict:
|
||||
"""Return ``{technique_id: {state_str: count}}`` in a single query."""
|
||||
if not technique_ids:
|
||||
return {}
|
||||
rows = (
|
||||
db.query(Test.technique_id, Test.state, func.count(Test.id))
|
||||
.filter(Test.technique_id.in_(technique_ids))
|
||||
.group_by(Test.technique_id, Test.state)
|
||||
.all()
|
||||
)
|
||||
result: dict = {}
|
||||
for tid, state, count in rows:
|
||||
result.setdefault(tid, {})[str(state)] = count
|
||||
return result
|
||||
|
||||
|
||||
def build_coverage_summary(
|
||||
db: Session,
|
||||
*,
|
||||
tactic: str | None = None,
|
||||
platform: str | None = None,
|
||||
) -> dict:
|
||||
"""Build the full coverage summary report as a dict."""
|
||||
query = db.query(Technique)
|
||||
if tactic:
|
||||
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
|
||||
|
||||
techniques = query.order_by(Technique.mitre_id).all()
|
||||
counts_map = _technique_test_counts(db, [t.id for t in techniques])
|
||||
|
||||
rows = []
|
||||
for t in techniques:
|
||||
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
|
||||
continue
|
||||
|
||||
counts = counts_map.get(t.id, {})
|
||||
rows.append({
|
||||
"mitre_id": t.mitre_id,
|
||||
"name": t.name,
|
||||
"tactic": t.tactic,
|
||||
"platforms": t.platforms,
|
||||
"status_global": t.status_global,
|
||||
"total_tests": sum(counts.values()),
|
||||
"tests_by_state": counts,
|
||||
})
|
||||
|
||||
total = len(rows)
|
||||
validated = sum(1 for r in rows if r["status_global"] == "validated")
|
||||
partial = sum(1 for r in rows if r["status_global"] == "partial")
|
||||
not_covered = sum(1 for r in rows if r["status_global"] == "not_covered")
|
||||
in_progress = sum(1 for r in rows if r["status_global"] == "in_progress")
|
||||
not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated")
|
||||
|
||||
return {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"summary": {
|
||||
"total_techniques": total,
|
||||
"validated": validated,
|
||||
"partial": partial,
|
||||
"not_covered": not_covered,
|
||||
"in_progress": in_progress,
|
||||
"not_evaluated": not_evaluated,
|
||||
"coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
|
||||
},
|
||||
"techniques": rows,
|
||||
}
|
||||
|
||||
|
||||
def build_coverage_csv_rows(
|
||||
db: Session,
|
||||
*,
|
||||
tactic: str | None = None,
|
||||
platform: str | None = None,
|
||||
) -> list[list]:
|
||||
"""Build rows for a CSV coverage export (header + data)."""
|
||||
query = db.query(Technique)
|
||||
if tactic:
|
||||
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
|
||||
|
||||
techniques = query.order_by(Technique.mitre_id).all()
|
||||
counts_map = _technique_test_counts(db, [t.id for t in techniques])
|
||||
|
||||
header = [
|
||||
"MITRE ID", "Name", "Tactic", "Platforms", "Status",
|
||||
"Total Tests", "Validated", "In Progress", "Not Covered",
|
||||
]
|
||||
rows = [header]
|
||||
|
||||
in_progress_states = {"draft", "red_executing", "blue_evaluating", "in_review"}
|
||||
|
||||
for t in techniques:
|
||||
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
|
||||
continue
|
||||
|
||||
counts = counts_map.get(t.id, {})
|
||||
rows.append([
|
||||
t.mitre_id,
|
||||
t.name,
|
||||
t.tactic,
|
||||
", ".join(t.platforms or []),
|
||||
t.status_global,
|
||||
sum(counts.values()),
|
||||
counts.get("validated", 0),
|
||||
sum(counts.get(s, 0) for s in in_progress_states),
|
||||
counts.get("rejected", 0),
|
||||
])
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def build_test_results_report(
|
||||
db: Session,
|
||||
*,
|
||||
state: str | None = None,
|
||||
date_from: str | None = None,
|
||||
date_to: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a test results report with optional filters."""
|
||||
query = db.query(Test)
|
||||
|
||||
if state:
|
||||
query = query.filter(Test.state == state)
|
||||
if date_from:
|
||||
try:
|
||||
query = query.filter(Test.created_at >= datetime.fromisoformat(date_from))
|
||||
except ValueError:
|
||||
pass
|
||||
if date_to:
|
||||
try:
|
||||
query = query.filter(Test.created_at <= datetime.fromisoformat(date_to))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
tests = query.order_by(Test.created_at.desc()).all()
|
||||
|
||||
by_state: dict[str, int] = {}
|
||||
by_result: dict[str, int] = {}
|
||||
for t in tests:
|
||||
s = t.state.value if hasattr(t.state, "value") else str(t.state)
|
||||
by_state[s] = by_state.get(s, 0) + 1
|
||||
if t.detection_result:
|
||||
r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result)
|
||||
by_result[r] = by_result.get(r, 0) + 1
|
||||
|
||||
return {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"filters": {"state": state, "date_from": date_from, "date_to": date_to},
|
||||
"summary": {
|
||||
"total_tests": len(tests),
|
||||
"by_state": by_state,
|
||||
"by_detection_result": by_result,
|
||||
},
|
||||
"tests": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"name": t.name,
|
||||
"technique_id": str(t.technique_id),
|
||||
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
|
||||
"platform": t.platform,
|
||||
"attack_success": t.attack_success,
|
||||
"detection_result": (
|
||||
t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value")
|
||||
else str(t.detection_result) if t.detection_result else None
|
||||
),
|
||||
"red_validation_status": t.red_validation_status,
|
||||
"blue_validation_status": t.blue_validation_status,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in tests
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def build_remediation_status_report(
|
||||
db: Session,
|
||||
*,
|
||||
status: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a remediation status report."""
|
||||
query = db.query(Test).filter(Test.remediation_steps.isnot(None))
|
||||
|
||||
if status:
|
||||
query = query.filter(Test.remediation_status == status)
|
||||
|
||||
tests = query.order_by(Test.created_at.desc()).all()
|
||||
|
||||
by_status: dict[str, int] = {}
|
||||
for t in tests:
|
||||
s = t.remediation_status or "unset"
|
||||
by_status[s] = by_status.get(s, 0) + 1
|
||||
|
||||
return {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"summary": {
|
||||
"total_with_remediation": len(tests),
|
||||
"by_status": by_status,
|
||||
},
|
||||
"tests": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"name": t.name,
|
||||
"technique_id": str(t.technique_id),
|
||||
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
|
||||
"remediation_status": t.remediation_status,
|
||||
"remediation_steps": t.remediation_steps,
|
||||
"remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
|
||||
}
|
||||
for t in tests
|
||||
],
|
||||
}
|
||||
Reference in New Issue
Block a user