- Replace the per-technique `calculate_technique_score` loop with `bulk_technique_scores()` from scoring_service.
- Snapshot creation now runs ~10 fixed queries instead of N*5 + N*5 (i.e. ~10 per technique, ~2000+ for 200 techniques).
257 lines
7.9 KiB
Python
257 lines
7.9 KiB
Python
"""Snapshot service — create, compare, and manage coverage snapshots.
|
|
|
|
Provides point-in-time coverage captures with normalised per-technique
|
|
storage and temporal comparison between any two snapshots.
|
|
|
|
Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
|
|
number of SQL queries regardless of technique count.
|
|
"""
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
from sqlalchemy import func
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.models.technique import Technique
|
|
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
|
from app.models.enums import TechniqueStatus
|
|
from app.services.scoring_service import (
|
|
bulk_technique_scores,
|
|
calculate_organization_score,
|
|
)
|
|
|
|
# Module-level logger named after this module; create_snapshot and
# cleanup_old_snapshots report their outcomes through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Create snapshot
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def create_snapshot(
    db: Session,
    name: str | None = None,
    user_id: uuid.UUID | None = None,
) -> CoverageSnapshot:
    """Capture the current coverage state into a new snapshot.

    1. Fetch per-technique scores with a single ``bulk_technique_scores``
       call (a mapping keyed by technique id).
    2. Load every ``Technique`` and tally its global status.
    3. Obtain the organisation score from ``calculate_organization_score``.
       This is a separate scoring-service call; it is not derived from the
       step-1 data inside this function.
    4. Persist a ``CoverageSnapshot`` with one normalised
       ``SnapshotTechniqueState`` row per technique, then commit.

    Args:
        db: SQLAlchemy session; the new rows are committed on it.
        name: Optional label stored on the snapshot.
        user_id: Optional id stored as the snapshot's ``created_by``.

    Returns:
        The committed and refreshed ``CoverageSnapshot``.
    """
    scores_map = bulk_technique_scores(db)
    techniques = db.query(Technique).all()

    # One tally per count column on CoverageSnapshot. Any status that is not
    # one of these keys is tallied as "not_evaluated". A missing status is
    # also stored as "not_evaluated".
    status_counts = dict.fromkeys(
        ("validated", "partial", "not_covered", "in_progress", "not_evaluated"),
        0,
    )
    technique_rows: list[dict] = []

    for tech in techniques:
        raw_status = tech.status_global
        if isinstance(raw_status, TechniqueStatus):
            status_value = raw_status.value
        else:
            status_value = raw_status or "not_evaluated"

        bucket = status_value if status_value in status_counts else "not_evaluated"
        status_counts[bucket] += 1

        # Techniques missing from the bulk result are recorded with score 0.
        entry = scores_map.get(tech.id, {})
        technique_rows.append({
            "technique_id": tech.id,
            "mitre_id": tech.mitre_id,
            "status": status_value,
            "score": entry.get("total_score", 0),
        })

    org_data = calculate_organization_score(db)
    org_score = org_data.get("overall_score", 0)

    snapshot = CoverageSnapshot(
        name=name,
        organization_score=org_score,
        total_techniques=len(techniques),
        validated_count=status_counts["validated"],
        partial_count=status_counts["partial"],
        not_covered_count=status_counts["not_covered"],
        in_progress_count=status_counts["in_progress"],
        not_evaluated_count=status_counts["not_evaluated"],
        created_by=user_id,
    )
    db.add(snapshot)
    # Flush so snapshot.id is populated before the child rows reference it.
    db.flush()

    db.add_all(
        [SnapshotTechniqueState(snapshot_id=snapshot.id, **row) for row in technique_rows]
    )

    db.commit()
    db.refresh(snapshot)

    logger.info(
        "Snapshot '%s' created — %d techniques, org score %.1f",
        snapshot.name or snapshot.id,
        len(techniques),
        org_score,
    )
    return snapshot
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Compare snapshots
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
# Rank of each technique status, worst (0) to best (4). Statuses not listed
# here are ranked 0 by compare_snapshots.
_STATUS_ORDER = {
    "not_evaluated": 0,
    "not_covered": 1,
    "in_progress": 2,
    "partial": 3,
    "validated": 4,
}

# Stand-in state for a technique with no row in one of the snapshots.
# Read-only: never mutate it.
_MISSING_STATE = {"status": "not_evaluated", "score": 0}


def _load_states(db: Session, snapshot_id: uuid.UUID) -> dict[str, dict]:
    """Return ``{mitre_id: {"status": ..., "score": ...}}`` for one snapshot.

    A NULL score is normalised to 0.
    """
    rows = (
        db.query(SnapshotTechniqueState)
        .filter(SnapshotTechniqueState.snapshot_id == snapshot_id)
        .all()
    )
    return {s.mitre_id: {"status": s.status, "score": s.score or 0} for s in rows}


def _snapshot_summary(snap: CoverageSnapshot) -> dict:
    """Serialise the aggregate columns of *snap* into a plain dict."""
    return {
        "id": str(snap.id),
        "name": snap.name,
        "organization_score": snap.organization_score,
        "total_techniques": snap.total_techniques,
        "validated_count": snap.validated_count,
        "partial_count": snap.partial_count,
        "not_covered_count": snap.not_covered_count,
        "in_progress_count": snap.in_progress_count,
        "not_evaluated_count": snap.not_evaluated_count,
        "created_at": snap.created_at.isoformat() if snap.created_at else None,
    }


def compare_snapshots(
    db: Session,
    snapshot_a_id: uuid.UUID,
    snapshot_b_id: uuid.UUID,
) -> dict:
    """Compare two snapshots and return deltas (A is the baseline, B the newer).

    A technique counts as *improved* when its status rank rose, or when the
    rank is equal and its score rose. *Worsened* is the mirror case;
    everything else is unchanged. A technique present in only one snapshot
    is treated as ``not_evaluated`` with score 0 in the other.

    Returns:
        A dict containing both snapshot summaries, ``score_delta`` (B minus A,
        with a NULL org score treated as 0), the ``improved`` and ``worsened``
        technique lists, ``unchanged_count`` and a ``summary`` of counts.
        If either snapshot does not exist, returns ``{"error": ...}`` instead.
    """
    snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
    snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()

    if not snap_a or not snap_b:
        return {"error": "One or both snapshots not found"}

    states_a = _load_states(db, snapshot_a_id)
    states_b = _load_states(db, snapshot_b_id)

    improved: list[dict] = []
    worsened: list[dict] = []
    unchanged_count = 0

    for mitre_id in sorted(states_a.keys() | states_b.keys()):
        a = states_a.get(mitre_id, _MISSING_STATE)
        b = states_b.get(mitre_id, _MISSING_STATE)

        # Lexicographic tuple comparison: status rank first, score breaks ties.
        a_key = (_STATUS_ORDER.get(a["status"], 0), a["score"])
        b_key = (_STATUS_ORDER.get(b["status"], 0), b["score"])

        if b_key > a_key:
            target = improved
        elif b_key < a_key:
            target = worsened
        else:
            unchanged_count += 1
            continue

        target.append({
            "mitre_id": mitre_id,
            "old_status": a["status"],
            "new_status": b["status"],
            "old_score": a["score"],
            "new_score": b["score"],
        })

    # Guard against NULL org scores, consistent with the per-technique
    # `score or 0` normalisation above.
    score_delta = (snap_b.organization_score or 0) - (snap_a.organization_score or 0)

    return {
        "snapshot_a": _snapshot_summary(snap_a),
        "snapshot_b": _snapshot_summary(snap_b),
        "score_delta": round(score_delta, 1),
        "improved": improved,
        "worsened": worsened,
        "unchanged_count": unchanged_count,
        "summary": {
            "improved_count": len(improved),
            "worsened_count": len(worsened),
            "new_count": len(states_b.keys() - states_a.keys()),
        },
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Cleanup
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
    """Delete the oldest snapshots, keeping the *keep_last* most recent.

    Age is determined by ``created_at`` (ascending = oldest first). Each
    snapshot is removed individually with ``db.delete``, so any ORM-level
    cascades configured on ``CoverageSnapshot`` apply. The session is
    committed only when something was deleted.

    Args:
        db: SQLAlchemy session used for the query and the deletions.
        keep_last: Number of most recent snapshots to retain; must be >= 0.
            ``0`` deletes every snapshot.

    Returns:
        The number of snapshots deleted.

    Raises:
        ValueError: If *keep_last* is negative. A negative value would make
            the delete limit exceed the total and wipe every snapshot.
    """
    if keep_last < 0:
        raise ValueError(f"keep_last must be >= 0, got {keep_last}")

    total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
    if total <= keep_last:
        return 0

    old_snapshots = (
        db.query(CoverageSnapshot)
        .order_by(CoverageSnapshot.created_at.asc())
        .limit(total - keep_last)
        .all()
    )
    for snap in old_snapshots:
        db.delete(snap)

    db.commit()

    deleted = len(old_snapshots)
    logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
    return deleted
|