"""Snapshot service — create, compare, and manage coverage snapshots.

Provides point-in-time coverage captures with normalised per-technique
storage and temporal comparison between any two snapshots.

Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
number of SQL queries regardless of technique count.
"""

import logging
import uuid
from collections import Counter, defaultdict
from datetime import datetime, timedelta, timezone

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.domain.errors import EntityNotFoundError
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus
from app.models.technique import Technique
from app.services.scoring_service import (
    bulk_technique_scores,
    calculate_organization_score,
)

logger = logging.getLogger(__name__)

# Coverage status ordering for snapshot delta comparisons (higher = better coverage)
_STATUS_ORDER: dict[str, int] = {
    "not_evaluated": 0,
    "not_covered": 1,
    "in_progress": 2,
    "partial": 3,
    "validated": 4,
}

# State assumed for a technique that has no row in one of the compared snapshots.
# Shared read-only: its values are copied into result dicts, never mutated.
_ABSENT_STATE: dict = {"status": "not_evaluated", "score": 0}


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------

def _iso_or_none(moment: datetime | None) -> str | None:
    """Return *moment* as an ISO-8601 string, or ``None`` when it is unset."""
    return moment.isoformat() if moment else None


def _snapshot_core_fields(snap: CoverageSnapshot) -> dict:
    """Return the identity and status-count fields shared by all snapshot views."""
    return {
        "id": str(snap.id),
        "name": snap.name,
        "organization_score": snap.organization_score,
        "total_techniques": snap.total_techniques,
        "validated_count": snap.validated_count,
        "partial_count": snap.partial_count,
        "not_covered_count": snap.not_covered_count,
        "in_progress_count": snap.in_progress_count,
        "not_evaluated_count": snap.not_evaluated_count,
    }


def _comparison_summary(snap: CoverageSnapshot) -> dict:
    """Return the reduced snapshot summary embedded in comparison results."""
    return {**_snapshot_core_fields(snap), "created_at": _iso_or_none(snap.created_at)}


def _find_snapshot(db: Session, snapshot_id: uuid.UUID) -> CoverageSnapshot | None:
    """Return the snapshot with primary key *snapshot_id*, or ``None``."""
    return (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.id == snapshot_id)
        .first()
    )


def _load_states(db: Session, snapshot_id: uuid.UUID) -> dict:
    """Map ``mitre_id`` to ``{"status", "score"}`` for one snapshot.

    A null score is normalised to 0 so that callers can compare scores
    numerically.
    """
    rows = (
        db.query(SnapshotTechniqueState)
        .filter(SnapshotTechniqueState.snapshot_id == snapshot_id)
        .all()
    )
    return {
        row.mitre_id: {"status": row.status, "score": row.score or 0}
        for row in rows
    }


def _status_value(status) -> str:
    """Normalise a technique status (enum member, raw value, or ``None``).

    Enum members are unwrapped to their ``.value``; a falsy status becomes
    ``"not_evaluated"``; anything else is passed through unchanged.
    """
    if isinstance(status, TechniqueStatus):
        return status.value
    return status or "not_evaluated"


def _months_before(moment: datetime, months: int) -> datetime:
    """Return *moment* shifted back by *months* calendar months.

    The day of month is clamped to the last day of the target month
    (e.g. 31 March minus one month is 28/29 February). Time of day and
    tzinfo are preserved.
    """
    target_index = moment.year * 12 + (moment.month - 1) - months
    year, month0 = divmod(target_index, 12)
    next_year, next_month0 = divmod(target_index + 1, 12)
    # Last day of the target month = first day of the following month - 1 day.
    first_of_next = moment.replace(year=next_year, month=next_month0 + 1, day=1)
    last_day = (first_of_next - timedelta(days=1)).day
    return moment.replace(year=year, month=month0 + 1, day=min(moment.day, last_day))


# ---------------------------------------------------------------------------
# Serialization and queries
# ---------------------------------------------------------------------------

def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
    """Return a lightweight serialization of a snapshot for list views.

    Args:
        snap (CoverageSnapshot): The snapshot ORM object to serialize.

    Returns:
        dict: Flat dictionary with summary fields (counts, scores, tactic
        breakdown) suitable for paginated list responses.
    """
    return {
        **_snapshot_core_fields(snap),
        # These attributes are read defensively with a fallback in case the
        # object does not carry them; JSON-like fields fall back to {}.
        "coverage_percentage": getattr(snap, "coverage_percentage", 0.0),
        "by_tactic": getattr(snap, "by_tactic", None) or {},
        "by_status": getattr(snap, "by_status", None) or {},
        "stale_count": getattr(snap, "stale_count", 0),
        "never_tested_count": getattr(snap, "never_tested_count", 0),
        "created_by": str(snap.created_by) if snap.created_by else None,
        "created_at": _iso_or_none(snap.created_at),
    }


def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
    """Return full serialization of a snapshot including per-technique states.

    Args:
        db (Session): Active SQLAlchemy database session.
        snap (CoverageSnapshot): The snapshot ORM object to serialize.

    Returns:
        dict: Summary fields merged with a ``technique_states`` list (ordered
        by ``mitre_id``), each entry containing ``mitre_id``,
        ``technique_id``, ``status``, and ``score``.
    """
    base = serialize_snapshot_summary(snap)
    technique_states = (
        db.query(SnapshotTechniqueState)
        .filter(SnapshotTechniqueState.snapshot_id == snap.id)
        .order_by(SnapshotTechniqueState.mitre_id)
        .all()
    )
    base["technique_states"] = [
        {
            "mitre_id": state.mitre_id,
            "technique_id": str(state.technique_id),
            "status": state.status,
            "score": state.score,
        }
        for state in technique_states
    ]
    return base


def list_snapshots(
    db: Session,
    *,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """List coverage snapshots ordered by creation date (newest first).

    Args:
        db (Session): Active SQLAlchemy database session.
        offset (int): Number of records to skip for pagination.
        limit (int): Maximum number of records to return.

    Returns:
        dict: Contains ``total``, ``offset``, ``limit``, and ``items`` (list
        of serialized snapshot summaries).
    """
    query = db.query(CoverageSnapshot)
    total = query.count()
    snapshots = (
        query
        # ``id`` is a tiebreaker: without a total order, rows sharing a
        # ``created_at`` value may repeat or disappear between pages.
        .order_by(CoverageSnapshot.created_at.desc(), CoverageSnapshot.id.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": [serialize_snapshot_summary(snap) for snap in snapshots],
    }


def get_snapshot_or_raise(db: Session, snapshot_id: str | uuid.UUID) -> CoverageSnapshot:
    """Fetch snapshot by ID or raise EntityNotFoundError.

    Args:
        db (Session): Active SQLAlchemy database session.
        snapshot_id (str | uuid.UUID): UUID (string or ``uuid.UUID``
            instance) of the snapshot to retrieve.

    Returns:
        CoverageSnapshot: The matching snapshot ORM object.

    Raises:
        EntityNotFoundError: If *snapshot_id* is not a valid UUID or no
            snapshot with that ID exists.
    """
    if isinstance(snapshot_id, uuid.UUID):
        sid = snapshot_id
    else:
        try:
            sid = uuid.UUID(snapshot_id)
        # AttributeError covers non-string inputs: UUID() calls ``.replace``
        # on its argument before validating it.
        except (ValueError, TypeError, AttributeError):
            # A malformed ID is reported as "not found"; the parsing error is
            # suppressed so it does not show up as a chained traceback.
            raise EntityNotFoundError("Snapshot", snapshot_id) from None

    snapshot = _find_snapshot(db, sid)
    if snapshot is None:
        raise EntityNotFoundError("Snapshot", snapshot_id)
    return snapshot


def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
    """Return detailed snapshot data including per-technique states.

    Args:
        db (Session): Active SQLAlchemy database session.
        snapshot_id (str): UUID string of the snapshot to retrieve.

    Returns:
        dict: Full snapshot serialization from
        :func:`serialize_snapshot_detail`.

    Raises:
        EntityNotFoundError: If the snapshot does not exist.
    """
    return serialize_snapshot_detail(db, get_snapshot_or_raise(db, snapshot_id))


def delete_snapshot(db: Session, snapshot_id: str) -> None:
    """Delete a snapshot. Does not commit — caller must commit.

    Args:
        db (Session): Active SQLAlchemy database session.
        snapshot_id (str): UUID string of the snapshot to delete.

    Raises:
        EntityNotFoundError: If the snapshot does not exist.
    """
    db.delete(get_snapshot_or_raise(db, snapshot_id))


# ---------------------------------------------------------------------------
# Create snapshot
# ---------------------------------------------------------------------------

def create_snapshot(
    db: Session,
    name: str | None = None,
    user_id: uuid.UUID | None = None,
) -> CoverageSnapshot:
    """Capture the current coverage state into a new snapshot.

    1. Bulk-fetch all technique scores via ``bulk_technique_scores``.
    2. Walk the loaded techniques once to tally statuses, per-tactic
       aggregates, and stale techniques.
    3. Obtain the organization score from ``calculate_organization_score``.
    4. Persist a ``CoverageSnapshot`` with normalised
       ``SnapshotTechniqueState`` rows and commit.

    Args:
        db (Session): Active SQLAlchemy database session.
        name (str | None): Optional human-readable label for the snapshot.
        user_id (uuid.UUID | None): UUID of the user creating the snapshot,
            stored in ``created_by``.

    Returns:
        CoverageSnapshot: The newly created and committed snapshot ORM object.
    """
    scores_map = bulk_technique_scores(db)
    techniques = db.query(Technique).all()

    by_status: Counter[str] = Counter()
    by_tactic: dict[str, dict] = defaultdict(
        lambda: {"total": 0, "validated": 0, "partial": 0, "score_sum": 0.0}
    )
    technique_rows: list[dict] = []
    stale_count = 0

    for tech in techniques:
        status_value = _status_value(tech.status_global)

        score = scores_map.get(tech.id, {}).get("total_score", 0)
        if score is None:
            # A null score would break the per-tactic sum below.
            score = 0

        technique_rows.append({
            "technique_id": tech.id,
            "mitre_id": tech.mitre_id,
            "status": status_value,
            "score": score,
        })
        by_status[status_value] += 1

        bucket = by_tactic[tech.tactic or "unknown"]
        bucket["total"] += 1
        bucket["score_sum"] += score
        if status_value in ("validated", "partial"):
            bucket[status_value] += 1

        if tech.review_required:
            stale_count += 1

    # Counter returns 0 for absent keys without inserting them, so these
    # lookups do not pollute the ``by_status`` breakdown that gets stored.
    validated_count = by_status["validated"]
    partial_count = by_status["partial"]
    not_covered_count = by_status["not_covered"]
    in_progress_count = by_status["in_progress"]
    # Any status outside the four above (including unrecognised values) is
    # counted as not evaluated; ``never_tested_count`` is the strict subset
    # whose status is exactly "not_evaluated".
    not_evaluated_count = len(techniques) - (
        validated_count + partial_count + not_covered_count + in_progress_count
    )
    never_tested_count = by_status["not_evaluated"]

    org_data = calculate_organization_score(db)
    org_score = org_data.get("overall_score", 0)

    # ``or 1`` avoids a division by zero when there are no techniques.
    coverage_pct = round((validated_count / (len(techniques) or 1)) * 100, 1)

    by_tactic_out = {
        tactic: {
            "total": data["total"],
            "validated": data["validated"],
            "partial": data["partial"],
            "average_score": round(data["score_sum"] / data["total"], 1)
            if data["total"] else 0,
        }
        for tactic, data in by_tactic.items()
    }

    snapshot = CoverageSnapshot(
        name=name,
        organization_score=org_score,
        total_techniques=len(techniques),
        validated_count=validated_count,
        partial_count=partial_count,
        not_covered_count=not_covered_count,
        in_progress_count=in_progress_count,
        not_evaluated_count=not_evaluated_count,
        coverage_percentage=coverage_pct,
        by_tactic=by_tactic_out,
        by_status=dict(by_status),
        stale_count=stale_count,
        never_tested_count=never_tested_count,
        created_by=user_id,
    )
    db.add(snapshot)
    # Flush before building child rows, which reference ``snapshot.id``.
    db.flush()

    # Row dict keys match the SnapshotTechniqueState constructor keywords.
    db.add_all([
        SnapshotTechniqueState(snapshot_id=snapshot.id, **row)
        for row in technique_rows
    ])

    db.commit()
    db.refresh(snapshot)

    logger.info(
        "Snapshot '%s' created — %d techniques, org score %.1f",
        snapshot.name or snapshot.id, len(techniques), org_score,
    )
    return snapshot


# ---------------------------------------------------------------------------
# Compare snapshots
# ---------------------------------------------------------------------------

def compare_snapshots(
    db: Session,
    snapshot_a_id: uuid.UUID,
    snapshot_b_id: uuid.UUID,
) -> dict:
    """Compare two snapshots and return deltas.

    Returns improved/worsened technique lists plus aggregate statistics.
    A technique counts as improved when its status ranks higher in
    ``_STATUS_ORDER``, or ranks equal with a higher score; worsened is the
    mirror case. A technique missing from one snapshot is treated as
    ``not_evaluated`` with score 0 on that side.

    Args:
        db (Session): Active SQLAlchemy database session.
        snapshot_a_id (uuid.UUID): UUID of the baseline (older) snapshot.
        snapshot_b_id (uuid.UUID): UUID of the comparison (newer) snapshot.

    Returns:
        dict: Contains ``snapshot_a``, ``snapshot_b``, ``score_delta``,
        ``improved``, ``worsened``, ``unchanged_count``, and ``summary`` keys.

    Raises:
        EntityNotFoundError: If either snapshot does not exist; the error
            identifier names only the missing ID(s).
    """
    snap_a = _find_snapshot(db, snapshot_a_id)
    snap_b = _find_snapshot(db, snapshot_b_id)
    missing = [
        str(sid)
        for sid, snap in ((snapshot_a_id, snap_a), (snapshot_b_id, snap_b))
        if snap is None
    ]
    if missing:
        raise EntityNotFoundError("Snapshot", " or ".join(missing))

    states_a = _load_states(db, snapshot_a_id)
    states_b = _load_states(db, snapshot_b_id)

    improved: list[dict] = []
    worsened: list[dict] = []
    unchanged_count = 0

    for mitre_id in sorted(states_a.keys() | states_b.keys()):
        before = states_a.get(mitre_id, _ABSENT_STATE)
        after = states_b.get(mitre_id, _ABSENT_STATE)

        # Tuples compare lexicographically: status rank first, then score.
        # Unknown statuses rank as 0, the same as "not_evaluated".
        before_rank = (_STATUS_ORDER.get(before["status"], 0), before["score"])
        after_rank = (_STATUS_ORDER.get(after["status"], 0), after["score"])

        if after_rank > before_rank:
            bucket = improved
        elif after_rank < before_rank:
            bucket = worsened
        else:
            unchanged_count += 1
            continue

        bucket.append({
            "mitre_id": mitre_id,
            "old_status": before["status"],
            "new_status": after["status"],
            "old_score": before["score"],
            "new_score": after["score"],
        })

    # A null organization score is treated as 0 so the delta stays computable.
    score_delta = round(
        (snap_b.organization_score or 0) - (snap_a.organization_score or 0), 1
    )

    return {
        "snapshot_a": _comparison_summary(snap_a),
        "snapshot_b": _comparison_summary(snap_b),
        "score_delta": score_delta,
        "improved": improved,
        "worsened": worsened,
        "unchanged_count": unchanged_count,
        "summary": {
            "improved_count": len(improved),
            "worsened_count": len(worsened),
            # Techniques present in B that had no row in A.
            "new_count": len(states_b.keys() - states_a.keys()),
        },
    }


# ---------------------------------------------------------------------------
# Coverage evolution (trends)
# ---------------------------------------------------------------------------

def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
    """Return snapshot trend points for the last *months* months.

    The window is measured in calendar months back from the current UTC
    time, so the default of 12 covers a full year.

    Args:
        db (Session): Active SQLAlchemy database session.
        months (int): Number of months to look back; defaults to 12.

    Returns:
        list[dict]: Snapshot trend entries ordered by creation date
        ascending, each containing ``date``, ``name``, ``org_score``,
        ``coverage_pct``, ``by_tactic``, ``by_status``, ``stale_count``,
        ``never_tested_count``, ``validated_count``, and
        ``total_techniques``.
    """
    cutoff = _months_before(datetime.now(timezone.utc), months)
    snapshots = (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.created_at >= cutoff)
        .order_by(CoverageSnapshot.created_at.asc())
        .all()
    )
    return [
        {
            "date": _iso_or_none(snap.created_at),
            "name": snap.name,
            "org_score": snap.organization_score,
            "coverage_pct": getattr(snap, "coverage_percentage", 0.0),
            "by_tactic": getattr(snap, "by_tactic", None) or {},
            "by_status": getattr(snap, "by_status", None) or {},
            "stale_count": getattr(snap, "stale_count", 0),
            "never_tested_count": getattr(snap, "never_tested_count", 0),
            "validated_count": snap.validated_count,
            "total_techniques": snap.total_techniques,
        }
        for snap in snapshots
    ]


# ---------------------------------------------------------------------------
# Cleanup
# ---------------------------------------------------------------------------

def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
    """Delete oldest snapshots, keeping the most recent *keep_last*.

    Commits the deletion when at least one snapshot is removed.

    Args:
        db (Session): Active SQLAlchemy database session.
        keep_last (int): Number of most-recent snapshots to retain; defaults
            to 52 (one year of weekly snapshots). Negative values are
            treated as 0.

    Returns:
        int: Number of snapshots deleted.
    """
    keep_last = max(keep_last, 0)
    total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
    excess = total - keep_last
    if excess <= 0:
        return 0

    old_snapshots = (
        db.query(CoverageSnapshot)
        .order_by(CoverageSnapshot.created_at.asc())
        .limit(excess)
        .all()
    )
    # Delete through the session, one object at a time, rather than with a
    # bulk DELETE statement.
    for snap in old_snapshots:
        db.delete(snap)
    db.commit()

    deleted = len(old_snapshots)
    logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
    return deleted