235 lines
7.4 KiB
Python
235 lines
7.4 KiB
Python
"""Coverage report data service.
|
|
|
|
Extracts query and aggregation logic from the reports router so
|
|
that the router remains a thin HTTP adapter. Fixes the N+1
|
|
technique/test-count pattern by using a single grouped query.
|
|
|
|
This module is framework-agnostic: no FastAPI imports.
|
|
"""
|
|
|
|
from __future__ import annotations

from datetime import date, datetime, timedelta

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models.technique import Technique
from app.models.test import Test
from app.utils import escape_like
|
|
|
|
|
|
def _technique_test_counts(
|
|
db: Session,
|
|
technique_ids: list,
|
|
) -> dict:
|
|
"""Return ``{technique_id: {state_str: count}}`` in a single query."""
|
|
if not technique_ids:
|
|
return {}
|
|
rows = (
|
|
db.query(Test.technique_id, Test.state, func.count(Test.id))
|
|
.filter(Test.technique_id.in_(technique_ids))
|
|
.group_by(Test.technique_id, Test.state)
|
|
.all()
|
|
)
|
|
result: dict = {}
|
|
for tid, state, count in rows:
|
|
result.setdefault(tid, {})[str(state)] = count
|
|
return result
|
|
|
|
|
|
def build_coverage_summary(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> dict:
    """Build the full coverage summary report as a dict.

    Args:
        db: Database session used for the technique and test-count queries.
        tactic: Optional filter; keeps techniques whose ``tactic`` matches it
            as a case-insensitive substring (``ilike`` on the escaped value).
        platform: Optional platform name; keeps only techniques whose
            ``platforms`` contain it, compared case-insensitively.

    Returns:
        A dict with:

        - ``generated_at``: ISO timestamp from ``datetime.utcnow()``.
        - ``summary``: the technique total, per-status counts and
          ``coverage_percentage`` (share of ``validated`` techniques,
          rounded to one decimal; 0 when there are no techniques).
        - ``techniques``: one row per kept technique, ordered by
          ``mitre_id``, with its test totals.
    """
    query = db.query(Technique)
    if tactic:
        query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
    techniques = query.order_by(Technique.mitre_id).all()

    # Platform matching is done in Python. Filter before the grouped counts
    # query so counts are only fetched for techniques that will be reported.
    if platform:
        wanted = platform.lower()
        techniques = [
            t for t in techniques
            if wanted in (p.lower() for p in (t.platforms or []))
        ]

    counts_map = _technique_test_counts(db, [t.id for t in techniques])

    rows = []
    for t in techniques:
        counts = counts_map.get(t.id, {})
        rows.append({
            "mitre_id": t.mitre_id,
            "name": t.name,
            "tactic": t.tactic,
            "platforms": t.platforms,
            "status_global": t.status_global,
            "total_tests": sum(counts.values()),
            "tests_by_state": counts,
        })

    total = len(rows)
    # Compare with ``==`` (not hashing into a Counter) so enum-valued statuses
    # that compare equal to these strings are still counted.
    by_status = {
        status: sum(1 for r in rows if r["status_global"] == status)
        for status in ("validated", "partial", "not_covered", "in_progress", "not_evaluated")
    }

    return {
        "generated_at": datetime.utcnow().isoformat(),
        "summary": {
            "total_techniques": total,
            **by_status,
            "coverage_percentage": round(
                (by_status["validated"] / total * 100) if total > 0 else 0, 1
            ),
        },
        "techniques": rows,
    }
|
|
|
|
|
|
def build_coverage_csv_rows(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> list[list]:
    """Build rows for a CSV coverage export (header + data).

    Args:
        db: Database session used for the technique and test-count queries.
        tactic: Optional filter; keeps techniques whose ``tactic`` matches it
            as a case-insensitive substring (``ilike`` on the escaped value).
        platform: Optional platform name; keeps only techniques whose
            ``platforms`` contain it, compared case-insensitively.

    Returns:
        A list whose first element is the header row, followed by one row
        per kept technique ordered by ``mitre_id``.
    """
    query = db.query(Technique)
    if tactic:
        query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
    techniques = query.order_by(Technique.mitre_id).all()

    # Platform matching is done in Python. Filter before the grouped counts
    # query so counts are only fetched for techniques that end up exported.
    if platform:
        wanted = platform.lower()
        techniques = [
            t for t in techniques
            if wanted in (p.lower() for p in (t.platforms or []))
        ]

    counts_map = _technique_test_counts(db, [t.id for t in techniques])

    # Test states rolled up into the "In Progress" column.
    in_progress_states = ("draft", "red_executing", "blue_evaluating", "in_review")

    header = [
        "MITRE ID", "Name", "Tactic", "Platforms", "Status",
        "Total Tests", "Validated", "In Progress", "Not Covered",
    ]
    rows: list[list] = [header]

    for t in techniques:
        counts = counts_map.get(t.id, {})
        rows.append([
            t.mitre_id,
            t.name,
            t.tactic,
            ", ".join(t.platforms or []),
            t.status_global,
            sum(counts.values()),
            counts.get("validated", 0),
            sum(counts.get(s, 0) for s in in_progress_states),
            # NOTE(review): the "Not Covered" column reports tests in the
            # ``rejected`` state; confirm this mapping is intended.
            counts.get("rejected", 0),
        ])

    return rows
|
|
|
|
|
|
def build_test_results_report(
|
|
db: Session,
|
|
*,
|
|
state: str | None = None,
|
|
date_from: str | None = None,
|
|
date_to: str | None = None,
|
|
) -> dict:
|
|
"""Build a test results report with optional filters."""
|
|
query = db.query(Test)
|
|
|
|
if state:
|
|
query = query.filter(Test.state == state)
|
|
if date_from:
|
|
try:
|
|
query = query.filter(Test.created_at >= datetime.fromisoformat(date_from))
|
|
except ValueError:
|
|
pass
|
|
if date_to:
|
|
try:
|
|
query = query.filter(Test.created_at <= datetime.fromisoformat(date_to))
|
|
except ValueError:
|
|
pass
|
|
|
|
tests = query.order_by(Test.created_at.desc()).all()
|
|
|
|
by_state: dict[str, int] = {}
|
|
by_result: dict[str, int] = {}
|
|
for t in tests:
|
|
s = t.state.value if hasattr(t.state, "value") else str(t.state)
|
|
by_state[s] = by_state.get(s, 0) + 1
|
|
if t.detection_result:
|
|
r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result)
|
|
by_result[r] = by_result.get(r, 0) + 1
|
|
|
|
return {
|
|
"generated_at": datetime.utcnow().isoformat(),
|
|
"filters": {"state": state, "date_from": date_from, "date_to": date_to},
|
|
"summary": {
|
|
"total_tests": len(tests),
|
|
"by_state": by_state,
|
|
"by_detection_result": by_result,
|
|
},
|
|
"tests": [
|
|
{
|
|
"id": str(t.id),
|
|
"name": t.name,
|
|
"technique_id": str(t.technique_id),
|
|
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
|
|
"platform": t.platform,
|
|
"attack_success": t.attack_success,
|
|
"detection_result": (
|
|
t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value")
|
|
else str(t.detection_result) if t.detection_result else None
|
|
),
|
|
"red_validation_status": t.red_validation_status,
|
|
"blue_validation_status": t.blue_validation_status,
|
|
"created_at": t.created_at.isoformat() if t.created_at else None,
|
|
}
|
|
for t in tests
|
|
],
|
|
}
|
|
|
|
|
|
def build_remediation_status_report(
    db: Session,
    *,
    status: str | None = None,
) -> dict:
    """Summarise tests that carry remediation steps.

    Args:
        db: Database session to query.
        status: When given, only tests whose ``remediation_status`` equals
            it are included.

    Returns:
        A dict with:

        - ``generated_at``: ISO timestamp of report creation.
        - ``summary``: the number of matching tests and a ``by_status``
          tally, where a falsy remediation status is counted as ``"unset"``.
        - ``tests``: per-test remediation details, newest ``created_at``
          first.
    """
    criteria = [Test.remediation_steps.isnot(None)]
    if status:
        criteria.append(Test.remediation_status == status)

    tests = (
        db.query(Test)
        .filter(*criteria)
        .order_by(Test.created_at.desc())
        .all()
    )

    tally: dict[str, int] = {}
    for test in tests:
        bucket = test.remediation_status or "unset"
        tally[bucket] = tally.get(bucket, 0) + 1

    def describe(test) -> dict:
        # Render one test; enum-like states are reported by their ``.value``.
        raw_state = test.state
        assignee = test.remediation_assignee
        return {
            "id": str(test.id),
            "name": test.name,
            "technique_id": str(test.technique_id),
            "state": raw_state.value if hasattr(raw_state, "value") else str(raw_state),
            "remediation_status": test.remediation_status,
            "remediation_steps": test.remediation_steps,
            "remediation_assignee": str(assignee) if assignee else None,
        }

    return {
        "generated_at": datetime.utcnow().isoformat(),
        "summary": {
            "total_with_remediation": len(tests),
            "by_status": tally,
        },
        "tests": [describe(test) for test in tests],
    }
|