diff --git a/backend/app/routers/reports.py b/backend/app/routers/reports.py
index 9a552e5..d065116 100644
--- a/backend/app/routers/reports.py
+++ b/backend/app/routers/reports.py
@@ -1,5 +1,7 @@
 """Reports endpoints — export coverage summaries and test results.
 
+Thin HTTP adapter: delegates all data logic to coverage_report_service.
+
 Endpoints
 ---------
 GET /reports/coverage-summary — full coverage JSON report
@@ -15,24 +17,21 @@ from typing import Optional
 
 from fastapi import APIRouter, Depends, Query
 from fastapi.responses import StreamingResponse
-from sqlalchemy import func
 from sqlalchemy.orm import Session
 
 from app.database import get_db
 from app.dependencies.auth import get_current_user
-from app.models.enums import TestState
-from app.models.technique import Technique
-from app.models.test import Test
 from app.models.user import User
+from app.services.coverage_report_service import (
+    build_coverage_csv_rows,
+    build_coverage_summary,
+    build_remediation_status_report,
+    build_test_results_report,
+)
 
 router = APIRouter(prefix="/reports", tags=["reports"])
 
 
-# ---------------------------------------------------------------------------
-# GET /reports/coverage-summary
-# ---------------------------------------------------------------------------
-
-
 @router.get("/coverage-summary")
 def coverage_summary(
     tactic: Optional[str] = Query(None, description="Filter by tactic"),
@@ -41,63 +40,7 @@ def coverage_summary(
     current_user: User = Depends(get_current_user),
 ):
     """Full coverage report as JSON — technique-by-technique with test counts."""
-    query = db.query(Technique)
-    if tactic:
-        from app.utils import escape_like
-        query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
-
-    techniques = query.order_by(Technique.mitre_id).all()
-
-    rows = []
-    for t in techniques:
-        # Count tests per state for this technique
-        test_counts = (
-            db.query(Test.state, func.count(Test.id))
-            .filter(Test.technique_id == t.id)
-            .group_by(Test.state)
-            .all()
-        )
-        counts = {str(state): count for state, count in test_counts}
-
-        # Filter by platform if requested (check if technique platforms contain it)
-        if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
-            continue
-
-        rows.append({
-            "mitre_id": t.mitre_id,
-            "name": t.name,
-            "tactic": t.tactic,
-            "platforms": t.platforms,
-            "status_global": t.status_global,
-            "total_tests": sum(counts.values()),
-            "tests_by_state": counts,
-        })
-
-    total = len(rows)
-    validated = sum(1 for r in rows if r["status_global"] == "validated")
-    partial = sum(1 for r in rows if r["status_global"] == "partial")
-    not_covered = sum(1 for r in rows if r["status_global"] == "not_covered")
-    in_progress = sum(1 for r in rows if r["status_global"] == "in_progress")
-    not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated")
-
-    return {
-        "generated_at": datetime.utcnow().isoformat(),
-        "summary": {
-            "total_techniques": total,
-            "validated": validated,
-            "partial": partial,
-            "not_covered": not_covered,
-            "in_progress": in_progress,
-            "not_evaluated": not_evaluated,
-            "coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
-        },
-        "techniques": rows,
-    }
-
-
-# ---------------------------------------------------------------------------
-# GET /reports/coverage-csv
-# ---------------------------------------------------------------------------
+    return build_coverage_summary(db, tactic=tactic, platform=platform)
 
 
 @router.get("/coverage-csv")
@@ -108,57 +51,22 @@ def coverage_csv(
     current_user: User = Depends(get_current_user),
 ):
     """Export coverage as a downloadable CSV."""
-    query = db.query(Technique)
-    if tactic:
-        from app.utils import escape_like
-        query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
-
-    techniques = query.order_by(Technique.mitre_id).all()
+    rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
 
     output = io.StringIO()
     writer = csv.writer(output)
-    writer.writerow([
-        "MITRE ID", "Name", "Tactic", "Platforms", "Status",
-        "Total Tests", "Validated", "In Progress", "Not Covered",
-    ])
-
-    for t in techniques:
-        if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
-            continue
-
-        test_counts = (
-            db.query(Test.state, func.count(Test.id))
-            .filter(Test.technique_id == t.id)
-            .group_by(Test.state)
-            .all()
-        )
-        counts = {str(state): count for state, count in test_counts}
-
-        writer.writerow([
-            t.mitre_id,
-            t.name,
-            t.tactic,
-            ", ".join(t.platforms or []),
-            t.status_global,
-            sum(counts.values()),
-            counts.get("validated", 0),
-            sum(counts.get(s, 0) for s in ["draft", "red_executing", "blue_evaluating", "in_review"]),
-            counts.get("rejected", 0),
-        ])
+    for row in rows:
+        writer.writerow(row)
 
     output.seek(0)
+    filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
     return StreamingResponse(
         iter([output.getvalue()]),
         media_type="text/csv",
-        headers={"Content-Disposition": f"attachment; filename=aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"},
+        headers={"Content-Disposition": f"attachment; filename={filename}"},
     )
 
 
-# ---------------------------------------------------------------------------
-# GET /reports/test-results
-# ---------------------------------------------------------------------------
-
-
 @router.get("/test-results")
 def test_results(
     state: Optional[str] = Query(None),
@@ -168,68 +76,7 @@ def test_results(
     current_user: User = Depends(get_current_user),
 ):
     """Report of test results with optional filters."""
-    query = db.query(Test)
-
-    if state:
-        query = query.filter(Test.state == state)
-    if date_from:
-        try:
-            dt = datetime.fromisoformat(date_from)
-            query = query.filter(Test.created_at >= dt)
-        except ValueError:
-            pass
-    if date_to:
-        try:
-            dt = datetime.fromisoformat(date_to)
-            query = query.filter(Test.created_at <= dt)
-        except ValueError:
-            pass
-
-    tests = query.order_by(Test.created_at.desc()).all()
-
-    # Summary
-    total = len(tests)
-    by_state = {}
-    by_result = {}
-    for t in tests:
-        s = t.state.value if hasattr(t.state, "value") else str(t.state)
-        by_state[s] = by_state.get(s, 0) + 1
-        if t.detection_result:
-            r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result)
-            by_result[r] = by_result.get(r, 0) + 1
-
-    return {
-        "generated_at": datetime.utcnow().isoformat(),
-        "filters": {"state": state, "date_from": date_from, "date_to": date_to},
-        "summary": {
-            "total_tests": total,
-            "by_state": by_state,
-            "by_detection_result": by_result,
-        },
-        "tests": [
-            {
-                "id": str(t.id),
-                "name": t.name,
-                "technique_id": str(t.technique_id),
-                "state": t.state.value if hasattr(t.state, "value") else str(t.state),
-                "platform": t.platform,
-                "attack_success": t.attack_success,
-                "detection_result": (
-                    t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value")
-                    else str(t.detection_result) if t.detection_result else None
-                ),
-                "red_validation_status": t.red_validation_status,
-                "blue_validation_status": t.blue_validation_status,
-                "created_at": t.created_at.isoformat() if t.created_at else None,
-            }
-            for t in tests
-        ],
-    }
-
-
-# ---------------------------------------------------------------------------
-# GET /reports/remediation-status
-# ---------------------------------------------------------------------------
+    return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
 
 
 @router.get("/remediation-status")
@@ -239,34 +86,4 @@ def remediation_status(
     current_user: User = Depends(get_current_user),
 ):
     """Report of remediation status across all tests."""
-    query = db.query(Test).filter(Test.remediation_steps.isnot(None))
-
-    if status:
-        query = query.filter(Test.remediation_status == status)
-
-    tests = query.order_by(Test.created_at.desc()).all()
-
-    by_status = {}
-    for t in tests:
-        s = t.remediation_status or "unset"
-        by_status[s] = by_status.get(s, 0) + 1
-
-    return {
-        "generated_at": datetime.utcnow().isoformat(),
-        "summary": {
-            "total_with_remediation": len(tests),
-            "by_status": by_status,
-        },
-        "tests": [
-            {
-                "id": str(t.id),
-                "name": t.name,
-                "technique_id": str(t.technique_id),
-                "state": t.state.value if hasattr(t.state, "value") else str(t.state),
-                "remediation_status": t.remediation_status,
-                "remediation_steps": t.remediation_steps,
-                "remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
-            }
-            for t in tests
-        ],
-    }
+    return build_remediation_status_report(db, status=status)
diff --git a/backend/app/services/coverage_report_service.py b/backend/app/services/coverage_report_service.py
new file mode 100644
index 0000000..c33f1df
--- /dev/null
+++ b/backend/app/services/coverage_report_service.py
@@ -0,0 +1,297 @@
+"""Coverage report data service.
+
+Extracts query and aggregation logic from the reports router so
+that the router remains a thin HTTP adapter. Fixes the N+1
+technique/test-count pattern by using a single grouped query.
+
+This module is framework-agnostic: no FastAPI imports.
+"""
+
+from __future__ import annotations
+
+import logging
+from datetime import datetime, timezone
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from app.models.technique import Technique
+from app.models.test import Test
+from app.utils import escape_like
+
+logger = logging.getLogger(__name__)
+
+# Column titles for the CSV coverage export, in output order.
+_CSV_HEADER = [
+    "MITRE ID", "Name", "Tactic", "Platforms", "Status",
+    "Total Tests", "Validated", "In Progress", "Not Covered",
+]
+
+# Test states that are summed into the "In Progress" CSV column.
+_IN_PROGRESS_STATES = ("draft", "red_executing", "blue_evaluating", "in_review")
+
+
+def _enum_str(value):
+    """Return ``value.value`` for enum-like objects, otherwise ``str(value)``."""
+    return value.value if hasattr(value, "value") else str(value)
+
+
+def _utc_now_iso() -> str:
+    """Return the current UTC time as a naive ISO-8601 string.
+
+    Same format as ``datetime.utcnow().isoformat()`` (no offset suffix), but
+    without calling ``utcnow()``, which is deprecated since Python 3.12.
+    """
+    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
+
+
+def _parse_iso_filter(raw: str | None, field: str) -> datetime | None:
+    """Parse an optional ISO-8601 date filter.
+
+    Returns ``None`` when *raw* is empty or unparseable. A malformed value is
+    skipped rather than raised (best-effort filtering), but it is logged so
+    the dropped filter is not completely silent.
+    """
+    if not raw:
+        return None
+    try:
+        return datetime.fromisoformat(raw)
+    except ValueError:
+        logger.warning("Ignoring invalid %s filter: %r", field, raw)
+        return None
+
+
+def _select_techniques(
+    db: Session,
+    tactic: str | None,
+    platform: str | None,
+) -> list:
+    """Return techniques ordered by MITRE ID, narrowed by the optional filters.
+
+    *tactic* is a case-insensitive substring match done in SQL; *platform* is
+    a case-insensitive exact match against each technique's ``platforms``
+    list, done in Python.
+    """
+    query = db.query(Technique)
+    if tactic:
+        query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
+    techniques = query.order_by(Technique.mitre_id).all()
+
+    if not platform:
+        return techniques
+    wanted = platform.lower()
+    return [
+        t for t in techniques
+        if wanted in (p.lower() for p in (t.platforms or []))
+    ]
+
+
+def _technique_test_counts(
+    db: Session,
+    technique_ids: list,
+) -> dict:
+    """Return ``{technique_id: {state_str: count}}`` in a single query.
+
+    State keys are normalised with ``_enum_str`` so that enum-typed states
+    are keyed by their value (e.g. ``"validated"``), matching the plain
+    string lookups done by the report builders.
+    """
+    if not technique_ids:
+        return {}
+    rows = (
+        db.query(Test.technique_id, Test.state, func.count(Test.id))
+        .filter(Test.technique_id.in_(technique_ids))
+        .group_by(Test.technique_id, Test.state)
+        .all()
+    )
+    result: dict = {}
+    for tid, state, count in rows:
+        per_state = result.setdefault(tid, {})
+        key = _enum_str(state)
+        per_state[key] = per_state.get(key, 0) + count
+    return result
+
+
+def build_coverage_summary(
+    db: Session,
+    *,
+    tactic: str | None = None,
+    platform: str | None = None,
+) -> dict:
+    """Build the full coverage summary report as a dict.
+
+    The result has ``generated_at``, a ``summary`` of technique counts per
+    ``status_global`` plus ``coverage_percentage`` (validated / total), and
+    one ``techniques`` row per technique matching the filters.
+    """
+    techniques = _select_techniques(db, tactic, platform)
+    # Counts are fetched after filtering so only reported techniques are queried.
+    counts_map = _technique_test_counts(db, [t.id for t in techniques])
+
+    rows = []
+    for t in techniques:
+        counts = counts_map.get(t.id, {})
+        rows.append({
+            "mitre_id": t.mitre_id,
+            "name": t.name,
+            "tactic": t.tactic,
+            "platforms": t.platforms,
+            "status_global": t.status_global,
+            "total_tests": sum(counts.values()),
+            "tests_by_state": counts,
+        })
+
+    def with_status(status: str) -> int:
+        return sum(1 for r in rows if r["status_global"] == status)
+
+    total = len(rows)
+    validated = with_status("validated")
+
+    return {
+        "generated_at": _utc_now_iso(),
+        "summary": {
+            "total_techniques": total,
+            "validated": validated,
+            "partial": with_status("partial"),
+            "not_covered": with_status("not_covered"),
+            "in_progress": with_status("in_progress"),
+            "not_evaluated": with_status("not_evaluated"),
+            "coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
+        },
+        "techniques": rows,
+    }
+
+
+def build_coverage_csv_rows(
+    db: Session,
+    *,
+    tactic: str | None = None,
+    platform: str | None = None,
+) -> list[list]:
+    """Build rows for a CSV coverage export (header + data)."""
+    techniques = _select_techniques(db, tactic, platform)
+    counts_map = _technique_test_counts(db, [t.id for t in techniques])
+
+    rows = [list(_CSV_HEADER)]
+    for t in techniques:
+        counts = counts_map.get(t.id, {})
+        rows.append([
+            t.mitre_id,
+            t.name,
+            t.tactic,
+            ", ".join(t.platforms or []),
+            t.status_global,
+            sum(counts.values()),
+            counts.get("validated", 0),
+            sum(counts.get(s, 0) for s in _IN_PROGRESS_STATES),
+            # NOTE(review): the "Not Covered" column is filled with the rejected-test count.
+            counts.get("rejected", 0),
+        ])
+
+    return rows
+
+
+def build_test_results_report(
+    db: Session,
+    *,
+    state: str | None = None,
+    date_from: str | None = None,
+    date_to: str | None = None,
+) -> dict:
+    """Build a test results report with optional filters.
+
+    *state* filters on ``Test.state``; *date_from* / *date_to* are ISO-8601
+    strings bounding ``Test.created_at`` (inclusive). Unparseable dates are
+    ignored. Tests are listed newest first.
+    """
+    query = db.query(Test)
+
+    if state:
+        query = query.filter(Test.state == state)
+    created_from = _parse_iso_filter(date_from, "date_from")
+    if created_from is not None:
+        query = query.filter(Test.created_at >= created_from)
+    created_to = _parse_iso_filter(date_to, "date_to")
+    if created_to is not None:
+        query = query.filter(Test.created_at <= created_to)
+
+    tests = query.order_by(Test.created_at.desc()).all()
+
+    by_state: dict[str, int] = {}
+    by_result: dict[str, int] = {}
+    items = []
+    for t in tests:
+        state_str = _enum_str(t.state)
+        result_str = _enum_str(t.detection_result) if t.detection_result else None
+
+        by_state[state_str] = by_state.get(state_str, 0) + 1
+        if t.detection_result:
+            by_result[result_str] = by_result.get(result_str, 0) + 1
+
+        items.append({
+            "id": str(t.id),
+            "name": t.name,
+            "technique_id": str(t.technique_id),
+            "state": state_str,
+            "platform": t.platform,
+            "attack_success": t.attack_success,
+            "detection_result": result_str,
+            "red_validation_status": t.red_validation_status,
+            "blue_validation_status": t.blue_validation_status,
+            "created_at": t.created_at.isoformat() if t.created_at else None,
+        })
+
+    return {
+        "generated_at": _utc_now_iso(),
+        "filters": {"state": state, "date_from": date_from, "date_to": date_to},
+        "summary": {
+            "total_tests": len(tests),
+            "by_state": by_state,
+            "by_detection_result": by_result,
+        },
+        "tests": items,
+    }
+
+
+def build_remediation_status_report(
+    db: Session,
+    *,
+    status: str | None = None,
+) -> dict:
+    """Build a remediation status report.
+
+    Only tests with non-null ``remediation_steps`` are included; *status*
+    optionally narrows them by ``remediation_status``. Tests without a
+    remediation status are counted under ``"unset"``.
+    """
+    query = db.query(Test).filter(Test.remediation_steps.isnot(None))
+
+    if status:
+        query = query.filter(Test.remediation_status == status)
+
+    tests = query.order_by(Test.created_at.desc()).all()
+
+    by_status: dict[str, int] = {}
+    for t in tests:
+        s = t.remediation_status or "unset"
+        by_status[s] = by_status.get(s, 0) + 1
+
+    return {
+        "generated_at": _utc_now_iso(),
+        "summary": {
+            "total_with_remediation": len(tests),
+            "by_status": by_status,
+        },
+        "tests": [
+            {
+                "id": str(t.id),
+                "name": t.name,
+                "technique_id": str(t.technique_id),
+                "state": _enum_str(t.state),
+                "remediation_status": t.remediation_status,
+                "remediation_steps": t.remediation_steps,
+                "remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
+            }
+            for t in tests
+        ],
+    }
+