refactor(reports): extract query and aggregation logic to coverage_report_service, fix N+1 test-count pattern

This commit is contained in:
2026-02-19 15:56:42 +01:00
parent 42a9f4dcd4
commit 8d5c5fa80e
2 changed files with 250 additions and 199 deletions

View File

@@ -1,5 +1,7 @@
"""Reports endpoints — export coverage summaries and test results.
Thin HTTP adapter: delegates all data logic to coverage_report_service.
Endpoints
---------
GET /reports/coverage-summary — full coverage JSON report
@@ -15,24 +17,21 @@ from typing import Optional
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User
from app.services.coverage_report_service import (
build_coverage_csv_rows,
build_coverage_summary,
build_remediation_status_report,
build_test_results_report,
)
router = APIRouter(prefix="/reports", tags=["reports"])
# ---------------------------------------------------------------------------
# GET /reports/coverage-summary
# ---------------------------------------------------------------------------
@router.get("/coverage-summary")
def coverage_summary(
tactic: Optional[str] = Query(None, description="Filter by tactic"),
@@ -41,63 +40,7 @@ def coverage_summary(
current_user: User = Depends(get_current_user),
):
"""Full coverage report as JSON — technique-by-technique with test counts."""
query = db.query(Technique)
if tactic:
from app.utils import escape_like
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
techniques = query.order_by(Technique.mitre_id).all()
rows = []
for t in techniques:
# Count tests per state for this technique
test_counts = (
db.query(Test.state, func.count(Test.id))
.filter(Test.technique_id == t.id)
.group_by(Test.state)
.all()
)
counts = {str(state): count for state, count in test_counts}
# Filter by platform if requested (check if technique platforms contain it)
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
continue
rows.append({
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"platforms": t.platforms,
"status_global": t.status_global,
"total_tests": sum(counts.values()),
"tests_by_state": counts,
})
total = len(rows)
validated = sum(1 for r in rows if r["status_global"] == "validated")
partial = sum(1 for r in rows if r["status_global"] == "partial")
not_covered = sum(1 for r in rows if r["status_global"] == "not_covered")
in_progress = sum(1 for r in rows if r["status_global"] == "in_progress")
not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated")
return {
"generated_at": datetime.utcnow().isoformat(),
"summary": {
"total_techniques": total,
"validated": validated,
"partial": partial,
"not_covered": not_covered,
"in_progress": in_progress,
"not_evaluated": not_evaluated,
"coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
},
"techniques": rows,
}
# ---------------------------------------------------------------------------
# GET /reports/coverage-csv
# ---------------------------------------------------------------------------
return build_coverage_summary(db, tactic=tactic, platform=platform)
@router.get("/coverage-csv")
@@ -108,57 +51,22 @@ def coverage_csv(
current_user: User = Depends(get_current_user),
):
"""Export coverage as a downloadable CSV."""
query = db.query(Technique)
if tactic:
from app.utils import escape_like
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
techniques = query.order_by(Technique.mitre_id).all()
rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
output = io.StringIO()
writer = csv.writer(output)
writer.writerow([
"MITRE ID", "Name", "Tactic", "Platforms", "Status",
"Total Tests", "Validated", "In Progress", "Not Covered",
])
for t in techniques:
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
continue
test_counts = (
db.query(Test.state, func.count(Test.id))
.filter(Test.technique_id == t.id)
.group_by(Test.state)
.all()
)
counts = {str(state): count for state, count in test_counts}
writer.writerow([
t.mitre_id,
t.name,
t.tactic,
", ".join(t.platforms or []),
t.status_global,
sum(counts.values()),
counts.get("validated", 0),
sum(counts.get(s, 0) for s in ["draft", "red_executing", "blue_evaluating", "in_review"]),
counts.get("rejected", 0),
])
for row in rows:
writer.writerow(row)
output.seek(0)
filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"},
headers={"Content-Disposition": f"attachment; filename={filename}"},
)
# ---------------------------------------------------------------------------
# GET /reports/test-results
# ---------------------------------------------------------------------------
@router.get("/test-results")
def test_results(
state: Optional[str] = Query(None),
@@ -168,68 +76,7 @@ def test_results(
current_user: User = Depends(get_current_user),
):
"""Report of test results with optional filters."""
query = db.query(Test)
if state:
query = query.filter(Test.state == state)
if date_from:
try:
dt = datetime.fromisoformat(date_from)
query = query.filter(Test.created_at >= dt)
except ValueError:
pass
if date_to:
try:
dt = datetime.fromisoformat(date_to)
query = query.filter(Test.created_at <= dt)
except ValueError:
pass
tests = query.order_by(Test.created_at.desc()).all()
# Summary
total = len(tests)
by_state = {}
by_result = {}
for t in tests:
s = t.state.value if hasattr(t.state, "value") else str(t.state)
by_state[s] = by_state.get(s, 0) + 1
if t.detection_result:
r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result)
by_result[r] = by_result.get(r, 0) + 1
return {
"generated_at": datetime.utcnow().isoformat(),
"filters": {"state": state, "date_from": date_from, "date_to": date_to},
"summary": {
"total_tests": total,
"by_state": by_state,
"by_detection_result": by_result,
},
"tests": [
{
"id": str(t.id),
"name": t.name,
"technique_id": str(t.technique_id),
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
"platform": t.platform,
"attack_success": t.attack_success,
"detection_result": (
t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value")
else str(t.detection_result) if t.detection_result else None
),
"red_validation_status": t.red_validation_status,
"blue_validation_status": t.blue_validation_status,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in tests
],
}
# ---------------------------------------------------------------------------
# GET /reports/remediation-status
# ---------------------------------------------------------------------------
return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
@router.get("/remediation-status")
@@ -239,34 +86,4 @@ def remediation_status(
current_user: User = Depends(get_current_user),
):
"""Report of remediation status across all tests."""
query = db.query(Test).filter(Test.remediation_steps.isnot(None))
if status:
query = query.filter(Test.remediation_status == status)
tests = query.order_by(Test.created_at.desc()).all()
by_status = {}
for t in tests:
s = t.remediation_status or "unset"
by_status[s] = by_status.get(s, 0) + 1
return {
"generated_at": datetime.utcnow().isoformat(),
"summary": {
"total_with_remediation": len(tests),
"by_status": by_status,
},
"tests": [
{
"id": str(t.id),
"name": t.name,
"technique_id": str(t.technique_id),
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
"remediation_status": t.remediation_status,
"remediation_steps": t.remediation_steps,
"remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
}
for t in tests
],
}
return build_remediation_status_report(db, status=status)

View File

@@ -0,0 +1,234 @@
"""Coverage report data service.
Extracts query and aggregation logic from the reports router so
that the router remains a thin HTTP adapter. Fixes the N+1
technique/test-count pattern by using a single grouped query.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations

import logging
from datetime import datetime

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models.technique import Technique
from app.models.test import Test
from app.utils import escape_like
def _technique_test_counts(
    db: Session,
    technique_ids: list,
) -> dict:
    """Return ``{technique_id: {state_str: count}}`` using one grouped query.

    Args:
        db: Session used to run the aggregate query.
        technique_ids: Technique ids whose tests should be counted.

    Returns:
        A mapping from technique id to a ``{state: count}`` dict. Techniques
        that have no tests do not appear in the mapping. When
        ``technique_ids`` is empty, ``{}`` is returned without querying.
    """
    if not technique_ids:
        return {}

    grouped = (
        db.query(Test.technique_id, Test.state, func.count(Test.id))
        .filter(Test.technique_id.in_(technique_ids))
        .group_by(Test.technique_id, Test.state)
        .all()
    )

    counts_by_technique: dict = {}
    for technique_id, state, count in grouped:
        # ``str()`` on an Enum member yields "ClassName.member", which would
        # never match plain-state lookups such as ``counts.get("validated")``.
        # Prefer the member's value when there is one; plain strings (and
        # None) fall through to ``str()`` exactly as before.
        state_key = str(getattr(state, "value", state))
        counts_by_technique.setdefault(technique_id, {})[state_key] = count
    return counts_by_technique
def build_coverage_summary(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> dict:
    """Build the full coverage summary report as a dict.

    Args:
        db: Session used for the technique and test-count queries.
        tactic: Optional case-insensitive substring filter applied to
            ``Technique.tactic``.
        platform: Optional platform name; a technique is kept only when one
            of its ``platforms`` entries equals it, ignoring case.

    Returns:
        A dict with ``generated_at`` (ISO timestamp from
        ``datetime.utcnow()``), ``summary`` (per-status totals plus
        ``coverage_percentage``) and ``techniques`` (one entry per kept
        technique, ordered by ``mitre_id``).
    """
    query = db.query(Technique)
    if tactic:
        query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
    techniques = query.order_by(Technique.mitre_id).all()

    # Apply the platform filter before counting tests so the grouped query
    # only covers techniques that actually end up in the report.
    if platform:
        wanted = platform.lower()
        techniques = [
            t for t in techniques
            if any(p.lower() == wanted for p in (t.platforms or []))
        ]

    counts_map = _technique_test_counts(db, [t.id for t in techniques])

    rows = []
    for t in techniques:
        counts = counts_map.get(t.id, {})
        rows.append({
            "mitre_id": t.mitre_id,
            "name": t.name,
            "tactic": t.tactic,
            "platforms": t.platforms,
            "status_global": t.status_global,
            "total_tests": sum(counts.values()),
            "tests_by_state": counts,
        })

    total = len(rows)
    # Tally by equality (not by hashing) so the comparison semantics against
    # ``status_global`` stay exactly those of ``==``.
    status_totals = {
        status: sum(1 for r in rows if r["status_global"] == status)
        for status in (
            "validated",
            "partial",
            "not_covered",
            "in_progress",
            "not_evaluated",
        )
    }
    validated = status_totals["validated"]

    return {
        "generated_at": datetime.utcnow().isoformat(),
        "summary": {
            "total_techniques": total,
            "validated": validated,
            "partial": status_totals["partial"],
            "not_covered": status_totals["not_covered"],
            "in_progress": status_totals["in_progress"],
            "not_evaluated": status_totals["not_evaluated"],
            # Guard against division by zero when nothing matched the filters.
            "coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
        },
        "techniques": rows,
    }
def build_coverage_csv_rows(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> list[list]:
    """Build rows for a CSV coverage export (header + data).

    Args:
        db: Session used for the technique and test-count queries.
        tactic: Optional case-insensitive substring filter applied to
            ``Technique.tactic``.
        platform: Optional platform name; a technique is kept only when one
            of its ``platforms`` entries equals it, ignoring case.

    Returns:
        A list whose first element is the header row, followed by one row per
        kept technique ordered by ``mitre_id``.
    """
    query = db.query(Technique)
    if tactic:
        query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
    techniques = query.order_by(Technique.mitre_id).all()

    # Apply the platform filter before counting tests so the grouped query
    # only covers techniques that actually end up in the export.
    if platform:
        wanted = platform.lower()
        techniques = [
            t for t in techniques
            if any(p.lower() == wanted for p in (t.platforms or []))
        ]

    counts_map = _technique_test_counts(db, [t.id for t in techniques])

    header = [
        "MITRE ID", "Name", "Tactic", "Platforms", "Status",
        "Total Tests", "Validated", "In Progress", "Not Covered",
    ]
    # Test states that are summed into the "In Progress" column.
    in_progress_states = ("draft", "red_executing", "blue_evaluating", "in_review")

    rows = [header]
    for t in techniques:
        counts = counts_map.get(t.id, {})
        rows.append([
            t.mitre_id,
            t.name,
            t.tactic,
            ", ".join(t.platforms or []),
            t.status_global,
            sum(counts.values()),
            counts.get("validated", 0),
            sum(counts.get(s, 0) for s in in_progress_states),
            # The "Not Covered" column is populated from the "rejected" count.
            counts.get("rejected", 0),
        ])
    return rows
def _enum_text(value) -> str:
    """Return ``value.value`` when the attribute exists, else ``str(value)``."""
    return value.value if hasattr(value, "value") else str(value)


def _parse_iso_filter(label: str, raw: str | None) -> datetime | None:
    """Parse an optional ISO date filter; log and return None if it is invalid."""
    if not raw:
        return None
    try:
        return datetime.fromisoformat(raw)
    except ValueError:
        # Stay best-effort (the filter is skipped), but leave a trace so a
        # malformed value is not silently treated as "no filter".
        logging.getLogger(__name__).warning(
            "Ignoring invalid %s filter: %r", label, raw
        )
        return None


def build_test_results_report(
    db: Session,
    *,
    state: str | None = None,
    date_from: str | None = None,
    date_to: str | None = None,
) -> dict:
    """Build a test results report with optional filters.

    Args:
        db: Session used to query tests.
        state: Optional exact match on ``Test.state``.
        date_from: Optional ISO-format lower bound (inclusive) on
            ``Test.created_at``. An unparsable value is ignored with a
            logged warning.
        date_to: Optional ISO-format upper bound (inclusive) on
            ``Test.created_at``. An unparsable value is ignored with a
            logged warning.

    Returns:
        A dict with ``generated_at``, the raw ``filters`` as received,
        a ``summary`` (total plus counts by state and by detection result)
        and ``tests`` ordered by ``created_at`` descending.
    """
    query = db.query(Test)
    if state:
        query = query.filter(Test.state == state)

    lower_bound = _parse_iso_filter("date_from", date_from)
    if lower_bound is not None:
        query = query.filter(Test.created_at >= lower_bound)

    # NOTE(review): a date-only ``date_to`` parses to midnight, so tests
    # created later that same day fall outside the bound — confirm intended.
    upper_bound = _parse_iso_filter("date_to", date_to)
    if upper_bound is not None:
        query = query.filter(Test.created_at <= upper_bound)

    tests = query.order_by(Test.created_at.desc()).all()

    by_state: dict[str, int] = {}
    by_result: dict[str, int] = {}
    serialized = []
    for t in tests:
        state_text = _enum_text(t.state)
        by_state[state_text] = by_state.get(state_text, 0) + 1

        # Falsy detection results are reported as None and left out of the tally.
        result_text = _enum_text(t.detection_result) if t.detection_result else None
        if result_text is not None:
            by_result[result_text] = by_result.get(result_text, 0) + 1

        serialized.append({
            "id": str(t.id),
            "name": t.name,
            "technique_id": str(t.technique_id),
            "state": state_text,
            "platform": t.platform,
            "attack_success": t.attack_success,
            "detection_result": result_text,
            "red_validation_status": t.red_validation_status,
            "blue_validation_status": t.blue_validation_status,
            "created_at": t.created_at.isoformat() if t.created_at else None,
        })

    return {
        "generated_at": datetime.utcnow().isoformat(),
        "filters": {"state": state, "date_from": date_from, "date_to": date_to},
        "summary": {
            "total_tests": len(tests),
            "by_state": by_state,
            "by_detection_result": by_result,
        },
        "tests": serialized,
    }
def build_remediation_status_report(
    db: Session,
    *,
    status: str | None = None,
) -> dict:
    """Assemble the remediation status report.

    Only tests whose ``remediation_steps`` is not NULL are considered. When
    ``status`` is given, results are further restricted to tests whose
    ``remediation_status`` equals it. Tests are listed newest first by
    ``created_at``; tests with no remediation status are tallied as "unset".
    """
    criteria = [Test.remediation_steps.isnot(None)]
    if status:
        criteria.append(Test.remediation_status == status)

    matching = (
        db.query(Test)
        .filter(*criteria)
        .order_by(Test.created_at.desc())
        .all()
    )

    # Tally tests per remediation status, in order of first appearance.
    status_totals: dict[str, int] = {}
    for test in matching:
        label = test.remediation_status or "unset"
        status_totals.setdefault(label, 0)
        status_totals[label] += 1

    entries = []
    for test in matching:
        if hasattr(test.state, "value"):
            state_text = test.state.value
        else:
            state_text = str(test.state)

        assignee = None
        if test.remediation_assignee:
            assignee = str(test.remediation_assignee)

        entries.append({
            "id": str(test.id),
            "name": test.name,
            "technique_id": str(test.technique_id),
            "state": state_text,
            "remediation_status": test.remediation_status,
            "remediation_steps": test.remediation_steps,
            "remediation_assignee": assignee,
        })

    summary = {
        "total_with_remediation": len(matching),
        "by_status": status_totals,
    }
    return {
        "generated_at": datetime.utcnow().isoformat(),
        "summary": summary,
        "tests": entries,
    }