"""Coverage report data service.

Extracts query and aggregation logic from the reports router so that the
router remains a thin HTTP adapter.  Fixes the N+1 technique/test-count
pattern by using a single grouped query.

This module is framework-agnostic: no FastAPI imports.
"""

from __future__ import annotations

import logging
from collections import Counter
from datetime import datetime, timezone

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models.technique import Technique
from app.models.test import Test
from app.utils import escape_like

logger = logging.getLogger(__name__)

# ``status_global`` values reported in the coverage summary, in output order.
_STATUS_BUCKETS = (
    "validated",
    "partial",
    "not_covered",
    "in_progress",
    "not_evaluated",
)

# Test states that the CSV export adds together for its "In Progress" column.
_IN_PROGRESS_STATES = frozenset(
    {"draft", "red_executing", "blue_evaluating", "in_review"}
)


def _plain(value: object) -> str:
    """Return ``value.value`` when the attribute exists, else ``str(value)``.

    Used wherever a state/result may be either an enum-like object or a
    plain string, so that every report keys on the same representation.
    """
    return value.value if hasattr(value, "value") else str(value)


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Produces the same format as the deprecated
    ``datetime.utcnow().isoformat()`` (no UTC offset suffix).
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def _parse_iso_filter(value: str | None, field: str) -> datetime | None:
    """Parse an optional ISO-8601 filter value.

    Returns ``None`` when *value* is empty or cannot be parsed.  An
    unparseable value is ignored (best-effort filtering) but logged so the
    dropped filter is not completely silent.
    """
    if not value:
        return None
    try:
        return datetime.fromisoformat(value)
    except ValueError:
        logger.warning("Ignoring unparseable %s filter: %r", field, value)
        return None


def _technique_test_counts(
    db: Session,
    technique_ids: list,
) -> dict:
    """Return ``{technique_id: {state_str: count}}`` in a single query.

    State keys are normalised with :func:`_plain` so they match the plain
    state names (``"validated"``, ``"draft"``, ...) that the report builders
    look up, whether the column yields enum-like objects or strings.
    """
    if not technique_ids:
        return {}

    grouped = (
        db.query(Test.technique_id, Test.state, func.count(Test.id))
        .filter(Test.technique_id.in_(technique_ids))
        .group_by(Test.technique_id, Test.state)
        .all()
    )

    result: dict = {}
    for technique_id, state, count in grouped:
        result.setdefault(technique_id, {})[_plain(state)] = count
    return result


def _select_techniques(
    db: Session,
    tactic: str | None,
    platform: str | None,
) -> list:
    """Load techniques ordered by MITRE id, applying the optional filters.

    *tactic* is a case-insensitive substring match done in SQL; *platform*
    is a case-insensitive exact match against each technique's ``platforms``
    entries, done in Python.
    """
    query = db.query(Technique)
    if tactic:
        # NOTE(review): no explicit ``escape=`` is passed to ``ilike``;
        # confirm escape_like's escape character matches the backend default.
        query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
    techniques = query.order_by(Technique.mitre_id).all()

    if not platform:
        return techniques

    wanted = platform.lower()
    return [
        t
        for t in techniques
        if any(p.lower() == wanted for p in (t.platforms or []))
    ]


def build_coverage_summary(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> dict:
    """Build the full coverage summary report as a dict.

    Args:
        db: Session used for the technique query and the grouped count query.
        tactic: Optional case-insensitive substring filter on the tactic.
        platform: Optional case-insensitive platform name a technique must
            list to be included.

    Returns:
        A dict with ``generated_at`` (naive UTC ISO timestamp), ``summary``
        (per-status totals plus ``coverage_percentage``) and ``techniques``
        (one entry per technique with its per-state test counts).
    """
    techniques = _select_techniques(db, tactic, platform)
    counts_map = _technique_test_counts(db, [t.id for t in techniques])

    rows = []
    for t in techniques:
        counts = counts_map.get(t.id, {})
        rows.append({
            "mitre_id": t.mitre_id,
            "name": t.name,
            "tactic": t.tactic,
            "platforms": t.platforms,
            "status_global": t.status_global,
            "total_tests": sum(counts.values()),
            "tests_by_state": counts,
        })

    # Compare with ``==`` rather than hashing the status: the stored value
    # may be an enum-like object that compares equal to its string form.
    status_totals = {
        status: sum(1 for r in rows if r["status_global"] == status)
        for status in _STATUS_BUCKETS
    }
    total = len(rows)
    validated = status_totals["validated"]

    return {
        "generated_at": _utc_timestamp(),
        "summary": {
            "total_techniques": total,
            **status_totals,
            "coverage_percentage": round(
                (validated / total * 100) if total > 0 else 0, 1
            ),
        },
        "techniques": rows,
    }


def build_coverage_csv_rows(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> list[list]:
    """Build rows for a CSV coverage export (header + data).

    Args:
        db: Session used for the technique query and the grouped count query.
        tactic: Optional case-insensitive substring filter on the tactic.
        platform: Optional case-insensitive platform name a technique must
            list to be included.

    Returns:
        A list whose first element is the header row, followed by one row per
        technique in MITRE-id order.
    """
    techniques = _select_techniques(db, tactic, platform)
    counts_map = _technique_test_counts(db, [t.id for t in techniques])

    rows: list[list] = [[
        "MITRE ID", "Name", "Tactic", "Platforms", "Status",
        "Total Tests", "Validated", "In Progress", "Not Covered",
    ]]

    for t in techniques:
        counts = counts_map.get(t.id, {})
        rows.append([
            t.mitre_id,
            t.name,
            t.tactic,
            ", ".join(t.platforms or []),
            t.status_global,
            sum(counts.values()),
            counts.get("validated", 0),
            sum(counts.get(state, 0) for state in _IN_PROGRESS_STATES),
            # NOTE(review): the "Not Covered" column is filled from the
            # "rejected" test count - confirm this mapping is intended.
            counts.get("rejected", 0),
        ])

    return rows


def build_test_results_report(
    db: Session,
    *,
    state: str | None = None,
    date_from: str | None = None,
    date_to: str | None = None,
) -> dict:
    """Build a test results report with optional filters.

    Args:
        db: Session used to query tests.
        state: Optional exact match on ``Test.state``.
        date_from: Optional ISO-8601 lower bound (inclusive) on ``created_at``.
        date_to: Optional ISO-8601 upper bound (inclusive) on ``created_at``.

    Returns:
        A dict with ``generated_at``, the ``filters`` as received, a
        ``summary`` (total plus counts by state and by detection result) and
        the ``tests`` themselves, newest first.

    An unparseable ``date_from``/``date_to`` is ignored (and logged) rather
    than raising; the raw value is still echoed back under ``filters``.
    """
    query = db.query(Test)
    if state:
        query = query.filter(Test.state == state)

    created_after = _parse_iso_filter(date_from, "date_from")
    if created_after is not None:
        query = query.filter(Test.created_at >= created_after)

    created_before = _parse_iso_filter(date_to, "date_to")
    if created_before is not None:
        query = query.filter(Test.created_at <= created_before)

    tests = query.order_by(Test.created_at.desc()).all()

    by_state: dict[str, int] = dict(Counter(_plain(t.state) for t in tests))
    # Tests with no (falsy) detection result are left out of this tally.
    by_result: dict[str, int] = dict(
        Counter(_plain(t.detection_result) for t in tests if t.detection_result)
    )

    return {
        "generated_at": _utc_timestamp(),
        "filters": {"state": state, "date_from": date_from, "date_to": date_to},
        "summary": {
            "total_tests": len(tests),
            "by_state": by_state,
            "by_detection_result": by_result,
        },
        "tests": [
            {
                "id": str(t.id),
                "name": t.name,
                "technique_id": str(t.technique_id),
                "state": _plain(t.state),
                "platform": t.platform,
                "attack_success": t.attack_success,
                "detection_result": (
                    _plain(t.detection_result) if t.detection_result else None
                ),
                "red_validation_status": t.red_validation_status,
                "blue_validation_status": t.blue_validation_status,
                "created_at": t.created_at.isoformat() if t.created_at else None,
            }
            for t in tests
        ],
    }


def build_remediation_status_report(
    db: Session,
    *,
    status: str | None = None,
) -> dict:
    """Build a remediation status report.

    Only tests whose ``remediation_steps`` is not NULL are considered.

    Args:
        db: Session used to query tests.
        status: Optional exact match on ``Test.remediation_status``.

    Returns:
        A dict with ``generated_at``, a ``summary`` (total and counts by
        remediation status, with a missing status counted as ``"unset"``)
        and the matching ``tests``, newest first.
    """
    query = db.query(Test).filter(Test.remediation_steps.isnot(None))
    if status:
        query = query.filter(Test.remediation_status == status)
    tests = query.order_by(Test.created_at.desc()).all()

    by_status: dict[str, int] = dict(
        Counter(t.remediation_status or "unset" for t in tests)
    )

    return {
        "generated_at": _utc_timestamp(),
        "summary": {
            "total_with_remediation": len(tests),
            "by_status": by_status,
        },
        "tests": [
            {
                "id": str(t.id),
                "name": t.name,
                "technique_id": str(t.technique_id),
                "state": _plain(t.state),
                "remediation_status": t.remediation_status,
                "remediation_steps": t.remediation_steps,
                "remediation_assignee": (
                    str(t.remediation_assignee) if t.remediation_assignee else None
                ),
            }
            for t in tests
        ],
    }