refactor(reports): extract query and aggregation logic to coverage_report_service, fix N+1 test-count pattern

This commit is contained in:
2026-02-19 15:56:42 +01:00
parent 42a9f4dcd4
commit 8d5c5fa80e
2 changed files with 250 additions and 199 deletions

View File

@@ -1,5 +1,7 @@
"""Reports endpoints — export coverage summaries and test results. """Reports endpoints — export coverage summaries and test results.
Thin HTTP adapter: delegates all data logic to coverage_report_service.
Endpoints Endpoints
--------- ---------
GET /reports/coverage-summary — full coverage JSON report GET /reports/coverage-summary — full coverage JSON report
@@ -15,24 +17,21 @@ from typing import Optional
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User from app.models.user import User
from app.services.coverage_report_service import (
build_coverage_csv_rows,
build_coverage_summary,
build_remediation_status_report,
build_test_results_report,
)
router = APIRouter(prefix="/reports", tags=["reports"]) router = APIRouter(prefix="/reports", tags=["reports"])
# ---------------------------------------------------------------------------
# GET /reports/coverage-summary
# ---------------------------------------------------------------------------
@router.get("/coverage-summary") @router.get("/coverage-summary")
def coverage_summary( def coverage_summary(
tactic: Optional[str] = Query(None, description="Filter by tactic"), tactic: Optional[str] = Query(None, description="Filter by tactic"),
@@ -41,63 +40,7 @@ def coverage_summary(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Full coverage report as JSON — technique-by-technique with test counts.""" """Full coverage report as JSON — technique-by-technique with test counts."""
query = db.query(Technique) return build_coverage_summary(db, tactic=tactic, platform=platform)
if tactic:
from app.utils import escape_like
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
techniques = query.order_by(Technique.mitre_id).all()
rows = []
for t in techniques:
# Count tests per state for this technique
test_counts = (
db.query(Test.state, func.count(Test.id))
.filter(Test.technique_id == t.id)
.group_by(Test.state)
.all()
)
counts = {str(state): count for state, count in test_counts}
# Filter by platform if requested (check if technique platforms contain it)
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
continue
rows.append({
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"platforms": t.platforms,
"status_global": t.status_global,
"total_tests": sum(counts.values()),
"tests_by_state": counts,
})
total = len(rows)
validated = sum(1 for r in rows if r["status_global"] == "validated")
partial = sum(1 for r in rows if r["status_global"] == "partial")
not_covered = sum(1 for r in rows if r["status_global"] == "not_covered")
in_progress = sum(1 for r in rows if r["status_global"] == "in_progress")
not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated")
return {
"generated_at": datetime.utcnow().isoformat(),
"summary": {
"total_techniques": total,
"validated": validated,
"partial": partial,
"not_covered": not_covered,
"in_progress": in_progress,
"not_evaluated": not_evaluated,
"coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
},
"techniques": rows,
}
# ---------------------------------------------------------------------------
# GET /reports/coverage-csv
# ---------------------------------------------------------------------------
@router.get("/coverage-csv") @router.get("/coverage-csv")
@@ -108,57 +51,22 @@ def coverage_csv(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Export coverage as a downloadable CSV.""" """Export coverage as a downloadable CSV."""
query = db.query(Technique) rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
if tactic:
from app.utils import escape_like
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
techniques = query.order_by(Technique.mitre_id).all()
output = io.StringIO() output = io.StringIO()
writer = csv.writer(output) writer = csv.writer(output)
writer.writerow([ for row in rows:
"MITRE ID", "Name", "Tactic", "Platforms", "Status", writer.writerow(row)
"Total Tests", "Validated", "In Progress", "Not Covered",
])
for t in techniques:
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
continue
test_counts = (
db.query(Test.state, func.count(Test.id))
.filter(Test.technique_id == t.id)
.group_by(Test.state)
.all()
)
counts = {str(state): count for state, count in test_counts}
writer.writerow([
t.mitre_id,
t.name,
t.tactic,
", ".join(t.platforms or []),
t.status_global,
sum(counts.values()),
counts.get("validated", 0),
sum(counts.get(s, 0) for s in ["draft", "red_executing", "blue_evaluating", "in_review"]),
counts.get("rejected", 0),
])
output.seek(0) output.seek(0)
filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
return StreamingResponse( return StreamingResponse(
iter([output.getvalue()]), iter([output.getvalue()]),
media_type="text/csv", media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"}, headers={"Content-Disposition": f"attachment; filename={filename}"},
) )
# ---------------------------------------------------------------------------
# GET /reports/test-results
# ---------------------------------------------------------------------------
@router.get("/test-results") @router.get("/test-results")
def test_results( def test_results(
state: Optional[str] = Query(None), state: Optional[str] = Query(None),
@@ -168,68 +76,7 @@ def test_results(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Report of test results with optional filters.""" """Report of test results with optional filters."""
query = db.query(Test) return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
if state:
query = query.filter(Test.state == state)
if date_from:
try:
dt = datetime.fromisoformat(date_from)
query = query.filter(Test.created_at >= dt)
except ValueError:
pass
if date_to:
try:
dt = datetime.fromisoformat(date_to)
query = query.filter(Test.created_at <= dt)
except ValueError:
pass
tests = query.order_by(Test.created_at.desc()).all()
# Summary
total = len(tests)
by_state = {}
by_result = {}
for t in tests:
s = t.state.value if hasattr(t.state, "value") else str(t.state)
by_state[s] = by_state.get(s, 0) + 1
if t.detection_result:
r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result)
by_result[r] = by_result.get(r, 0) + 1
return {
"generated_at": datetime.utcnow().isoformat(),
"filters": {"state": state, "date_from": date_from, "date_to": date_to},
"summary": {
"total_tests": total,
"by_state": by_state,
"by_detection_result": by_result,
},
"tests": [
{
"id": str(t.id),
"name": t.name,
"technique_id": str(t.technique_id),
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
"platform": t.platform,
"attack_success": t.attack_success,
"detection_result": (
t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value")
else str(t.detection_result) if t.detection_result else None
),
"red_validation_status": t.red_validation_status,
"blue_validation_status": t.blue_validation_status,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in tests
],
}
# ---------------------------------------------------------------------------
# GET /reports/remediation-status
# ---------------------------------------------------------------------------
@router.get("/remediation-status") @router.get("/remediation-status")
@@ -239,34 +86,4 @@ def remediation_status(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Report of remediation status across all tests.""" """Report of remediation status across all tests."""
query = db.query(Test).filter(Test.remediation_steps.isnot(None)) return build_remediation_status_report(db, status=status)
if status:
query = query.filter(Test.remediation_status == status)
tests = query.order_by(Test.created_at.desc()).all()
by_status = {}
for t in tests:
s = t.remediation_status or "unset"
by_status[s] = by_status.get(s, 0) + 1
return {
"generated_at": datetime.utcnow().isoformat(),
"summary": {
"total_with_remediation": len(tests),
"by_status": by_status,
},
"tests": [
{
"id": str(t.id),
"name": t.name,
"technique_id": str(t.technique_id),
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
"remediation_status": t.remediation_status,
"remediation_steps": t.remediation_steps,
"remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
}
for t in tests
],
}

View File

@@ -0,0 +1,234 @@
"""Coverage report data service.
Extracts query and aggregation logic from the reports router so
that the router remains a thin HTTP adapter. Fixes the N+1
technique/test-count pattern by using a single grouped query.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.models.technique import Technique
from app.models.test import Test
from app.utils import escape_like
def _technique_test_counts(
    db: Session,
    technique_ids: list,
) -> dict:
    """Return ``{technique_id: {state_str: count}}`` in a single query.

    Runs one grouped query over ``Test`` restricted to *technique_ids*
    rather than one count query per technique.

    Args:
        db: Session used to execute the grouped count query.
        technique_ids: Technique ids whose tests should be counted.

    Returns:
        Mapping from technique id to ``{state: count}``. Techniques for
        which the query returns no rows are absent from the mapping. An
        empty *technique_ids* returns ``{}`` without issuing a query.
    """
    if not technique_ids:
        return {}

    grouped_rows = (
        db.query(Test.technique_id, Test.state, func.count(Test.id))
        .filter(Test.technique_id.in_(technique_ids))
        .group_by(Test.technique_id, Test.state)
        .all()
    )

    result: dict = {}
    for tid, state, count in grouped_rows:
        # Normalise the state the same way the report builders in this
        # module do: use the raw ``.value`` when the column yields an Enum
        # member. ``str()`` on a standard Enum member produces
        # "ClassName.member", which plain-name lookups such as
        # ``counts.get("validated", 0)`` would never match.
        state_key = state.value if hasattr(state, "value") else str(state)
        result.setdefault(tid, {})[state_key] = count
    return result
def build_coverage_summary(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> dict:
    """Assemble the technique-by-technique coverage report as a dict.

    Args:
        db: Session used to load techniques and their test counts.
        tactic: Optional substring matched case-insensitively (``ILIKE``)
            against ``Technique.tactic``; LIKE wildcards are escaped.
        platform: Optional platform name; techniques whose ``platforms``
            list does not contain it (case-insensitive) are dropped.

    Returns:
        A dict with ``generated_at``, a ``summary`` of per-status technique
        totals plus ``coverage_percentage``, and the ``techniques`` rows.
    """
    technique_query = db.query(Technique)
    if tactic:
        like_pattern = f"%{escape_like(tactic)}%"
        technique_query = technique_query.filter(Technique.tactic.ilike(like_pattern))
    techniques = technique_query.order_by(Technique.mitre_id).all()

    # Counts are fetched for every tactic-matched technique; the platform
    # filter is applied afterwards, in Python.
    counts_by_technique = _technique_test_counts(db, [tech.id for tech in techniques])

    wanted_platform = platform.lower() if platform else None
    rows = []
    for tech in techniques:
        if wanted_platform is not None:
            supported = [name.lower() for name in (tech.platforms or [])]
            if wanted_platform not in supported:
                continue
        state_counts = counts_by_technique.get(tech.id, {})
        rows.append(
            {
                "mitre_id": tech.mitre_id,
                "name": tech.name,
                "tactic": tech.tactic,
                "platforms": tech.platforms,
                "status_global": tech.status_global,
                "total_tests": sum(state_counts.values()),
                "tests_by_state": state_counts,
            }
        )

    # One tally per tracked status, in the order the summary exposes them.
    tracked_statuses = (
        "validated",
        "partial",
        "not_covered",
        "in_progress",
        "not_evaluated",
    )
    status_totals = {
        status: sum(1 for row in rows if row["status_global"] == status)
        for status in tracked_statuses
    }

    total = len(rows)
    if total > 0:
        coverage = status_totals["validated"] / total * 100
    else:
        coverage = 0  # no techniques matched: avoid dividing by zero

    summary = {
        "total_techniques": total,
        **status_totals,
        "coverage_percentage": round(coverage, 1),
    }
    return {
        "generated_at": datetime.utcnow().isoformat(),
        "summary": summary,
        "techniques": rows,
    }
def build_coverage_csv_rows(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> list[list]:
    """Produce the rows of the coverage CSV export, header row first.

    Args:
        db: Session used to load techniques and their test counts.
        tactic: Optional substring matched case-insensitively (``ILIKE``)
            against ``Technique.tactic``; LIKE wildcards are escaped.
        platform: Optional platform name; techniques whose ``platforms``
            list does not contain it (case-insensitive) are skipped.

    Returns:
        A list of rows: the header, then one row per retained technique,
        ordered by ``mitre_id``.
    """
    technique_query = db.query(Technique)
    if tactic:
        like_pattern = f"%{escape_like(tactic)}%"
        technique_query = technique_query.filter(Technique.tactic.ilike(like_pattern))
    techniques = technique_query.order_by(Technique.mitre_id).all()

    counts_by_technique = _technique_test_counts(db, [tech.id for tech in techniques])

    table = [
        [
            "MITRE ID", "Name", "Tactic", "Platforms", "Status",
            "Total Tests", "Validated", "In Progress", "Not Covered",
        ]
    ]

    # States that are summed into the "In Progress" column.
    pending_states = ("draft", "red_executing", "blue_evaluating", "in_review")
    wanted_platform = platform.lower() if platform else None

    for tech in techniques:
        platform_names = tech.platforms or []
        if wanted_platform is not None:
            if wanted_platform not in [name.lower() for name in platform_names]:
                continue

        state_counts = counts_by_technique.get(tech.id, {})
        pending_total = 0
        for state_name in pending_states:
            pending_total += state_counts.get(state_name, 0)

        table.append(
            [
                tech.mitre_id,
                tech.name,
                tech.tactic,
                ", ".join(platform_names),
                tech.status_global,
                sum(state_counts.values()),
                state_counts.get("validated", 0),
                pending_total,
                # The "Not Covered" column carries the count of rejected tests.
                state_counts.get("rejected", 0),
            ]
        )
    return table
def build_test_results_report(
    db: Session,
    *,
    state: str | None = None,
    date_from: str | None = None,
    date_to: str | None = None,
) -> dict:
    """Build a test results report with optional filters.

    Args:
        db: Session used to load the tests.
        state: Optional exact match on ``Test.state``.
        date_from: Optional ISO-8601 string; keeps tests with
            ``created_at`` on or after it.
        date_to: Optional ISO-8601 string; keeps tests with
            ``created_at`` on or before it.

    Returns:
        A dict with ``generated_at``, the ``filters`` as received, a
        ``summary`` (total plus tallies by state and by detection result)
        and the serialized ``tests``, newest first.

    A date bound that cannot be parsed is ignored (best effort) and the
    report is built without it; a warning is logged so the dropped filter
    is not silent. The raw value is still echoed under ``filters``.
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)

    def parse_bound(label, raw):
        """Return *raw* parsed as a datetime, or None if absent/malformed."""
        if not raw:
            return None
        try:
            return datetime.fromisoformat(raw)
        except ValueError:
            logger.warning(
                "Ignoring invalid %s filter %r: not an ISO-8601 datetime",
                label,
                raw,
            )
            return None

    def enum_value(member):
        """Return the raw value of an Enum member, or ``str()`` otherwise."""
        return member.value if hasattr(member, "value") else str(member)

    query = db.query(Test)
    if state:
        query = query.filter(Test.state == state)

    lower_bound = parse_bound("date_from", date_from)
    if lower_bound is not None:
        query = query.filter(Test.created_at >= lower_bound)

    upper_bound = parse_bound("date_to", date_to)
    if upper_bound is not None:
        query = query.filter(Test.created_at <= upper_bound)

    tests = query.order_by(Test.created_at.desc()).all()

    by_state: dict[str, int] = {}
    by_result: dict[str, int] = {}
    serialized = []
    for t in tests:
        state_name = enum_value(t.state)
        by_state[state_name] = by_state.get(state_name, 0) + 1

        # A falsy detection_result is reported as None and left out of the tally.
        result_name = enum_value(t.detection_result) if t.detection_result else None
        if result_name is not None:
            by_result[result_name] = by_result.get(result_name, 0) + 1

        serialized.append(
            {
                "id": str(t.id),
                "name": t.name,
                "technique_id": str(t.technique_id),
                "state": state_name,
                "platform": t.platform,
                "attack_success": t.attack_success,
                "detection_result": result_name,
                "red_validation_status": t.red_validation_status,
                "blue_validation_status": t.blue_validation_status,
                "created_at": t.created_at.isoformat() if t.created_at else None,
            }
        )

    return {
        "generated_at": datetime.utcnow().isoformat(),
        "filters": {"state": state, "date_from": date_from, "date_to": date_to},
        "summary": {
            "total_tests": len(tests),
            "by_state": by_state,
            "by_detection_result": by_result,
        },
        "tests": serialized,
    }
def build_remediation_status_report(
    db: Session,
    *,
    status: str | None = None,
) -> dict:
    """Summarise remediation progress for tests that have remediation steps.

    Args:
        db: Session used to load the tests.
        status: Optional exact match on ``Test.remediation_status``.

    Returns:
        A dict with ``generated_at``, a ``summary`` (number of tests and a
        tally per remediation status) and the serialized ``tests``, newest
        first.
    """
    # Only tests with non-NULL remediation steps are part of this report.
    remediation_query = db.query(Test).filter(Test.remediation_steps.isnot(None))
    if status:
        remediation_query = remediation_query.filter(Test.remediation_status == status)
    tests = remediation_query.order_by(Test.created_at.desc()).all()

    status_tally: dict[str, int] = {}
    serialized = []
    for test in tests:
        # Tests without a remediation status are grouped under "unset".
        bucket = test.remediation_status or "unset"
        status_tally.setdefault(bucket, 0)
        status_tally[bucket] += 1

        if hasattr(test.state, "value"):
            state_name = test.state.value
        else:
            state_name = str(test.state)

        assignee = test.remediation_assignee
        serialized.append(
            {
                "id": str(test.id),
                "name": test.name,
                "technique_id": str(test.technique_id),
                "state": state_name,
                "remediation_status": test.remediation_status,
                "remediation_steps": test.remediation_steps,
                "remediation_assignee": str(assignee) if assignee else None,
            }
        )

    return {
        "generated_at": datetime.utcnow().isoformat(),
        "summary": {
            "total_with_remediation": len(tests),
            "by_status": status_tally,
        },
        "tests": serialized,
    }