"""Coverage report data service. Extracts query and aggregation logic from the reports router so that the router remains a thin HTTP adapter. Fixes the N+1 technique/test-count pattern by using a single grouped query. This module is framework-agnostic: no FastAPI imports. """ from __future__ import annotations from datetime import datetime from sqlalchemy import func from sqlalchemy.orm import Session from app.models.technique import Technique from app.models.test import Test from app.utils import escape_like def _technique_test_counts( db: Session, technique_ids: list, ) -> dict: """Return ``{technique_id: {state_str: count}}`` in a single query.""" if not technique_ids: return {} rows = ( db.query(Test.technique_id, Test.state, func.count(Test.id)) .filter(Test.technique_id.in_(technique_ids)) .group_by(Test.technique_id, Test.state) .all() ) result: dict = {} for tid, state, count in rows: result.setdefault(tid, {})[str(state)] = count return result def build_coverage_summary( db: Session, *, tactic: str | None = None, platform: str | None = None, ) -> dict: """Build the full coverage summary report as a dict.""" query = db.query(Technique) if tactic: query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%")) techniques = query.order_by(Technique.mitre_id).all() counts_map = _technique_test_counts(db, [t.id for t in techniques]) rows = [] for t in techniques: if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]: continue counts = counts_map.get(t.id, {}) rows.append({ "mitre_id": t.mitre_id, "name": t.name, "tactic": t.tactic, "platforms": t.platforms, "status_global": t.status_global, "total_tests": sum(counts.values()), "tests_by_state": counts, }) total = len(rows) validated = sum(1 for r in rows if r["status_global"] == "validated") partial = sum(1 for r in rows if r["status_global"] == "partial") not_covered = sum(1 for r in rows if r["status_global"] == "not_covered") in_progress = sum(1 for r in rows if r["status_global"] == "in_progress") not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated") return { "generated_at": datetime.utcnow().isoformat(), "summary": { "total_techniques": total, "validated": validated, "partial": partial, "not_covered": not_covered, "in_progress": in_progress, "not_evaluated": not_evaluated, "coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1), }, "techniques": rows, } def build_coverage_csv_rows( db: Session, *, tactic: str | None = None, platform: str | None = None, ) -> list[list]: """Build rows for a CSV coverage export (header + data).""" query = db.query(Technique) if tactic: query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%")) techniques = query.order_by(Technique.mitre_id).all() counts_map = _technique_test_counts(db, [t.id for t in techniques]) header = [ "MITRE ID", "Name", "Tactic", "Platforms", "Status", "Total Tests", "Validated", "In Progress", "Not Covered", ] rows = [header] in_progress_states = {"draft", "red_executing", "blue_evaluating", "in_review"} for t in techniques: if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]: continue counts = counts_map.get(t.id, {}) rows.append([ t.mitre_id, t.name, t.tactic, ", ".join(t.platforms or []), t.status_global, sum(counts.values()), counts.get("validated", 0), sum(counts.get(s, 0) for s in in_progress_states), counts.get("rejected", 0), ]) return rows def build_test_results_report( db: Session, *, state: str | None = None, date_from: str | None = None, date_to: str | None = None, ) -> dict: """Build a test results report with optional filters.""" query = db.query(Test) if state: query = query.filter(Test.state == state) if date_from: try: query = query.filter(Test.created_at >= datetime.fromisoformat(date_from)) except ValueError: pass if date_to: try: query = query.filter(Test.created_at <= datetime.fromisoformat(date_to)) except ValueError: pass tests = query.order_by(Test.created_at.desc()).all() by_state: dict[str, int] = {} by_result: dict[str, int] = {} for t in tests: s = t.state.value if hasattr(t.state, "value") else str(t.state) by_state[s] = by_state.get(s, 0) + 1 if t.detection_result: r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result) by_result[r] = by_result.get(r, 0) + 1 return { "generated_at": datetime.utcnow().isoformat(), "filters": {"state": state, "date_from": date_from, "date_to": date_to}, "summary": { "total_tests": len(tests), "by_state": by_state, "by_detection_result": by_result, }, "tests": [ { "id": str(t.id), "name": t.name, "technique_id": str(t.technique_id), "state": t.state.value if hasattr(t.state, "value") else str(t.state), "platform": t.platform, "attack_success": t.attack_success, "detection_result": ( t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value") else str(t.detection_result) if t.detection_result else None ), "red_validation_status": t.red_validation_status, "blue_validation_status": t.blue_validation_status, "created_at": t.created_at.isoformat() if t.created_at else None, } for t in tests ], } def build_remediation_status_report( db: Session, *, status: str | None = None, ) -> dict: """Build a remediation status report.""" query = db.query(Test).filter(Test.remediation_steps.isnot(None)) if status: query = query.filter(Test.remediation_status == status) tests = query.order_by(Test.created_at.desc()).all() by_status: dict[str, int] = {} for t in tests: s = t.remediation_status or "unset" by_status[s] = by_status.get(s, 0) + 1 return { "generated_at": datetime.utcnow().isoformat(), "summary": { "total_with_remediation": len(tests), "by_status": by_status, }, "tests": [ { "id": str(t.id), "name": t.name, "technique_id": str(t.technique_id), "state": t.state.value if hasattr(t.state, "value") else str(t.state), "remediation_status": t.remediation_status, "remediation_steps": t.remediation_steps, "remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None, } for t in tests ], }