"""Snapshot service — create, compare, and manage coverage snapshots. Provides point-in-time coverage captures with normalised per-technique storage and temporal comparison between any two snapshots. Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed number of SQL queries regardless of technique count. """ import logging import uuid from datetime import datetime from sqlalchemy import func from sqlalchemy.orm import Session from app.models.technique import Technique from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.models.enums import TechniqueStatus from app.services.scoring_service import ( bulk_technique_scores, calculate_organization_score, ) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Create snapshot # --------------------------------------------------------------------------- def create_snapshot( db: Session, name: str | None = None, user_id: uuid.UUID | None = None, ) -> CoverageSnapshot: """Capture the current coverage state into a new snapshot. 1. Bulk-fetch all technique scores in 5 aggregated queries. 2. Walk the already-loaded techniques to count statuses. 3. Compute the org score from the same bulk data. 4. Persist a ``CoverageSnapshot`` with normalised ``SnapshotTechniqueState`` rows. """ scores_map = bulk_technique_scores(db) techniques = db.query(Technique).all() validated_count = 0 partial_count = 0 not_covered_count = 0 in_progress_count = 0 not_evaluated_count = 0 technique_rows: list[dict] = [] for tech in techniques: status_value = ( tech.status_global.value if isinstance(tech.status_global, TechniqueStatus) else (tech.status_global or "not_evaluated") ) if status_value == "validated": validated_count += 1 elif status_value == "partial": partial_count += 1 elif status_value == "not_covered": not_covered_count += 1 elif status_value == "in_progress": in_progress_count += 1 else: not_evaluated_count += 1 entry = scores_map.get(tech.id, {}) technique_rows.append({ "technique_id": tech.id, "mitre_id": tech.mitre_id, "status": status_value, "score": entry.get("total_score", 0), }) org_data = calculate_organization_score(db) org_score = org_data.get("overall_score", 0) snapshot = CoverageSnapshot( name=name, organization_score=org_score, total_techniques=len(techniques), validated_count=validated_count, partial_count=partial_count, not_covered_count=not_covered_count, in_progress_count=in_progress_count, not_evaluated_count=not_evaluated_count, created_by=user_id, ) db.add(snapshot) db.flush() for row in technique_rows: state = SnapshotTechniqueState( snapshot_id=snapshot.id, technique_id=row["technique_id"], mitre_id=row["mitre_id"], status=row["status"], score=row["score"], ) db.add(state) db.commit() db.refresh(snapshot) logger.info( "Snapshot '%s' created — %d techniques, org score %.1f", snapshot.name or snapshot.id, len(techniques), org_score, ) return snapshot # --------------------------------------------------------------------------- # Compare snapshots # --------------------------------------------------------------------------- def compare_snapshots( db: Session, snapshot_a_id: uuid.UUID, snapshot_b_id: uuid.UUID, ) -> dict: """Compare two snapshots and return deltas. Returns improved/worsened technique lists plus aggregate statistics. """ snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first() snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first() if not snap_a or not snap_b: return {"error": "One or both snapshots not found"} # Build lookup dicts: mitre_id -> {status, score} states_a = { s.mitre_id: {"status": s.status, "score": s.score or 0} for s in db.query(SnapshotTechniqueState) .filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id) .all() } states_b = { s.mitre_id: {"status": s.status, "score": s.score or 0} for s in db.query(SnapshotTechniqueState) .filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id) .all() } # Status priority for comparison STATUS_ORDER = { "not_evaluated": 0, "not_covered": 1, "in_progress": 2, "partial": 3, "validated": 4, } improved = [] worsened = [] unchanged_count = 0 all_mitre_ids = set(states_a.keys()) | set(states_b.keys()) for mitre_id in sorted(all_mitre_ids): a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0}) b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0}) a_order = STATUS_ORDER.get(a["status"], 0) b_order = STATUS_ORDER.get(b["status"], 0) if b_order > a_order or (b_order == a_order and b["score"] > a["score"]): improved.append({ "mitre_id": mitre_id, "old_status": a["status"], "new_status": b["status"], "old_score": a["score"], "new_score": b["score"], }) elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]): worsened.append({ "mitre_id": mitre_id, "old_status": a["status"], "new_status": b["status"], "old_score": a["score"], "new_score": b["score"], }) else: unchanged_count += 1 def _snap_summary(snap: CoverageSnapshot) -> dict: return { "id": str(snap.id), "name": snap.name, "organization_score": snap.organization_score, "total_techniques": snap.total_techniques, "validated_count": snap.validated_count, "partial_count": snap.partial_count, "not_covered_count": snap.not_covered_count, "in_progress_count": snap.in_progress_count, "not_evaluated_count": snap.not_evaluated_count, "created_at": snap.created_at.isoformat() if snap.created_at else None, } return { "snapshot_a": _snap_summary(snap_a), "snapshot_b": _snap_summary(snap_b), "score_delta": round(snap_b.organization_score - snap_a.organization_score, 1), "improved": improved, "worsened": worsened, "unchanged_count": unchanged_count, "summary": { "improved_count": len(improved), "worsened_count": len(worsened), "new_count": len(states_b.keys() - states_a.keys()), }, } # --------------------------------------------------------------------------- # Cleanup # --------------------------------------------------------------------------- def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int: """Delete oldest snapshots, keeping the most recent *keep_last*. Returns the number of snapshots deleted. """ total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0 if total <= keep_last: return 0 to_delete = total - keep_last old_snapshots = ( db.query(CoverageSnapshot) .order_by(CoverageSnapshot.created_at.asc()) .limit(to_delete) .all() ) deleted = 0 for snap in old_snapshots: db.delete(snap) deleted += 1 db.commit() logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last) return deleted