430 lines
14 KiB
Python
430 lines
14 KiB
Python
"""Snapshot service — create, compare, and manage coverage snapshots.

Provides point-in-time coverage captures with normalised per-technique
storage and temporal comparison between any two snapshots.

Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
number of SQL queries regardless of technique count.
"""
|
|
|
|
import logging
|
|
import uuid
|
|
from collections import defaultdict
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from sqlalchemy import func
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.domain.errors import EntityNotFoundError
|
|
from app.models.technique import Technique
|
|
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
|
from app.models.enums import TechniqueStatus
|
|
from app.services.scoring_service import (
|
|
bulk_technique_scores,
|
|
calculate_organization_score,
|
|
)
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Serialization and queries
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
    """Build the compact dict representation of a snapshot for list views.

    Identifier and count columns are read directly from *snap*.  The
    aggregate columns (``coverage_percentage``, ``by_tactic``,
    ``by_status``, ``stale_count``, ``never_tested_count``) are read with
    ``getattr`` and a fallback, so an object lacking those attributes still
    serializes.  ``created_by`` is stringified and ``created_at`` is
    ISO-formatted; either becomes ``None`` when falsy.
    """
    payload: dict = {
        "id": str(snap.id),
        "name": snap.name,
        "organization_score": snap.organization_score,
    }

    # Plain count columns: copied through unchanged.
    for column in (
        "total_techniques",
        "validated_count",
        "partial_count",
        "not_covered_count",
        "in_progress_count",
        "not_evaluated_count",
    ):
        payload[column] = getattr(snap, column)

    payload["coverage_percentage"] = getattr(snap, "coverage_percentage", 0.0)

    # Breakdown columns: a missing attribute or a falsy value becomes ``{}``.
    for column in ("by_tactic", "by_status"):
        payload[column] = getattr(snap, column, None) or {}

    # Counters that fall back to 0 only when the attribute is absent.
    for column in ("stale_count", "never_tested_count"):
        payload[column] = getattr(snap, column, 0)

    author = snap.created_by
    payload["created_by"] = str(author) if author else None

    created = snap.created_at
    payload["created_at"] = created.isoformat() if created else None

    return payload
|
|
|
|
|
|
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
    """Serialize a snapshot together with its per-technique states.

    Starts from ``serialize_snapshot_summary(snap)`` and adds a
    ``"technique_states"`` list holding one dict per
    ``SnapshotTechniqueState`` row of this snapshot, ordered by
    ``mitre_id``.
    """
    payload = serialize_snapshot_summary(snap)

    states_query = (
        db.query(SnapshotTechniqueState)
        .filter(SnapshotTechniqueState.snapshot_id == snap.id)
        .order_by(SnapshotTechniqueState.mitre_id)
    )

    serialized_states = []
    for state in states_query.all():
        serialized_states.append(
            {
                "mitre_id": state.mitre_id,
                "technique_id": str(state.technique_id),
                "status": state.status,
                "score": state.score,
            }
        )

    payload["technique_states"] = serialized_states
    return payload
|
|
|
|
|
|
def list_snapshots(
    db: Session,
    *,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """Return one page of coverage snapshots, newest first.

    The result dict carries the overall row count under ``"total"``, echoes
    back ``offset`` and ``limit``, and holds the summary-serialized page
    under ``"items"``.
    """
    base_query = db.query(CoverageSnapshot)

    # Count against the unpaginated query so "total" reflects every snapshot.
    total = base_query.count()

    newest_first = base_query.order_by(CoverageSnapshot.created_at.desc())
    page = newest_first.offset(offset).limit(limit).all()

    items = []
    for snapshot in page:
        items.append(serialize_snapshot_summary(snapshot))

    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": items,
    }
|
|
|
|
|
|
def get_snapshot_or_raise(db: Session, snapshot_id: str | uuid.UUID) -> CoverageSnapshot:
    """Fetch a snapshot by ID or raise ``EntityNotFoundError``.

    Args:
        db: Active SQLAlchemy session.
        snapshot_id: Snapshot identifier, either as a UUID string or as an
            already-parsed ``uuid.UUID``.

    Returns:
        The matching ``CoverageSnapshot``.

    Raises:
        EntityNotFoundError: If *snapshot_id* cannot be parsed as a UUID or
            no snapshot with that ID exists.
    """
    if isinstance(snapshot_id, uuid.UUID):
        # Already parsed; uuid.UUID() would fail with AttributeError on it.
        sid = snapshot_id
    else:
        try:
            sid = uuid.UUID(snapshot_id)
        except (ValueError, TypeError, AttributeError):
            # AttributeError covers non-string inputs (e.g. ints), which
            # uuid.UUID tries to call ``.replace`` on.  A malformed ID is
            # reported as "not found"; the parsing traceback is suppressed
            # because it adds nothing for the caller.
            raise EntityNotFoundError("Snapshot", snapshot_id) from None

    snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
    if snapshot is None:
        raise EntityNotFoundError("Snapshot", snapshot_id)
    return snapshot
|
|
|
|
|
|
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
    """Look up a snapshot by ID and return its detailed serialization.

    Lookup is delegated to ``get_snapshot_or_raise`` and serialization to
    ``serialize_snapshot_detail``.
    """
    return serialize_snapshot_detail(db, get_snapshot_or_raise(db, snapshot_id))
|
|
|
|
|
|
def delete_snapshot(db: Session, snapshot_id: str) -> None:
    """Mark a snapshot for deletion in the session.

    The snapshot is resolved through ``get_snapshot_or_raise``.  No commit
    is issued here — the caller must commit.
    """
    target = get_snapshot_or_raise(db, snapshot_id)
    db.delete(target)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Create snapshot
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def create_snapshot(
    db: Session,
    name: str | None = None,
    user_id: uuid.UUID | None = None,
) -> CoverageSnapshot:
    """Capture the current coverage state into a new snapshot.

    Steps:

    1. Load per-technique scores via ``bulk_technique_scores`` and all
       ``Technique`` rows.
    2. Walk the techniques once to tally statuses, per-tactic buckets and
       the review-required count.
    3. Obtain the organization score from ``calculate_organization_score``.
    4. Persist a ``CoverageSnapshot`` plus one ``SnapshotTechniqueState``
       row per technique.

    This function commits the session itself and returns the refreshed
    snapshot.
    """
    score_lookup = bulk_technique_scores(db)
    all_techniques = db.query(Technique).all()
    technique_total = len(all_techniques)

    status_tally: dict[str, int] = defaultdict(int)
    tactic_buckets: dict[str, dict] = defaultdict(
        lambda: {"total": 0, "validated": 0, "partial": 0, "score_sum": 0.0}
    )
    # One (technique_id, mitre_id, status, score) tuple per technique.
    per_technique: list[tuple] = []
    stale_count = 0

    for technique in all_techniques:
        if isinstance(technique.status_global, TechniqueStatus):
            status = technique.status_global.value
        else:
            # Missing/falsy status is recorded as "not_evaluated".
            status = technique.status_global or "not_evaluated"

        score = score_lookup.get(technique.id, {}).get("total_score", 0)
        per_technique.append((technique.id, technique.mitre_id, status, score))

        status_tally[status] += 1

        bucket = tactic_buckets[technique.tactic or "unknown"]
        bucket["total"] += 1
        bucket["score_sum"] += score
        if status in ("validated", "partial"):
            bucket[status] += 1

        if technique.review_required:
            stale_count += 1

    # Derive the headline counters from the tally.  ``.get`` is used so the
    # defaultdict does not grow keys for statuses that never occurred.
    validated_count = status_tally.get("validated", 0)
    partial_count = status_tally.get("partial", 0)
    not_covered_count = status_tally.get("not_covered", 0)
    in_progress_count = status_tally.get("in_progress", 0)
    # Everything outside the four known statuses counts as not evaluated.
    not_evaluated_count = technique_total - (
        validated_count + partial_count + not_covered_count + in_progress_count
    )
    never_tested_count = status_tally.get("not_evaluated", 0)

    org_score = calculate_organization_score(db).get("overall_score", 0)

    # ``or 1`` avoids a division by zero when there are no techniques.
    coverage_pct = round((validated_count / (technique_total or 1)) * 100, 1)

    by_tactic_out = {}
    for tactic, data in tactic_buckets.items():
        if data["total"]:
            average = round(data["score_sum"] / data["total"], 1)
        else:
            average = 0
        by_tactic_out[tactic] = {
            "total": data["total"],
            "validated": data["validated"],
            "partial": data["partial"],
            "average_score": average,
        }

    snapshot = CoverageSnapshot(
        name=name,
        organization_score=org_score,
        total_techniques=technique_total,
        validated_count=validated_count,
        partial_count=partial_count,
        not_covered_count=not_covered_count,
        in_progress_count=in_progress_count,
        not_evaluated_count=not_evaluated_count,
        coverage_percentage=coverage_pct,
        by_tactic=by_tactic_out,
        by_status=dict(status_tally),
        stale_count=stale_count,
        never_tested_count=never_tested_count,
        created_by=user_id,
    )
    db.add(snapshot)
    # Flush before building the child rows, which reference ``snapshot.id``.
    db.flush()

    db.add_all(
        [
            SnapshotTechniqueState(
                snapshot_id=snapshot.id,
                technique_id=technique_id,
                mitre_id=mitre_id,
                status=status,
                score=score,
            )
            for technique_id, mitre_id, status, score in per_technique
        ]
    )

    db.commit()
    db.refresh(snapshot)

    logger.info(
        "Snapshot '%s' created — %d techniques, org score %.1f",
        snapshot.name or snapshot.id,
        technique_total,
        org_score,
    )
    return snapshot
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Compare snapshots
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def compare_snapshots(
    db: Session,
    snapshot_a_id: uuid.UUID,
    snapshot_b_id: uuid.UUID,
) -> dict:
    """Compare two snapshots and return deltas from A (old) to B (new).

    Techniques are matched by ``mitre_id``.  A technique absent from one
    snapshot is treated as ``"not_evaluated"`` with score 0 on that side.
    A technique counts as improved when its status ranks higher in B, or
    ranks the same with a higher score; worsened is the mirror image.

    Args:
        db: Active SQLAlchemy session.
        snapshot_a_id: ID of the baseline snapshot.
        snapshot_b_id: ID of the snapshot compared against the baseline.

    Returns:
        A dict with summaries of both snapshots, ``score_delta``
        (B minus A organization score, rounded to one decimal), the
        ``improved`` and ``worsened`` technique lists, ``unchanged_count``
        and a ``summary`` of counts.

    Raises:
        EntityNotFoundError: If either snapshot does not exist; the error
            names only the ID(s) that were not found.
    """
    snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
    snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()

    missing = [
        str(requested_id)
        for requested_id, snap in ((snapshot_a_id, snap_a), (snapshot_b_id, snap_b))
        if snap is None
    ]
    if missing:
        raise EntityNotFoundError("Snapshot", " or ".join(missing))

    def _load_states(snapshot_id: uuid.UUID) -> dict:
        # Build the lookup mitre_id -> {status, score} for one snapshot.
        rows = (
            db.query(SnapshotTechniqueState)
            .filter(SnapshotTechniqueState.snapshot_id == snapshot_id)
            .all()
        )
        return {
            row.mitre_id: {"status": row.status, "score": row.score or 0}
            for row in rows
        }

    states_a = _load_states(snapshot_a_id)
    states_b = _load_states(snapshot_b_id)

    # Status priority for comparison; unknown statuses rank as 0.
    status_order = {
        "not_evaluated": 0,
        "not_covered": 1,
        "in_progress": 2,
        "partial": 3,
        "validated": 4,
    }
    # Stand-in for a technique that one of the snapshots does not contain.
    absent = {"status": "not_evaluated", "score": 0}

    improved = []
    worsened = []
    unchanged_count = 0

    for mitre_id in sorted(states_a.keys() | states_b.keys()):
        a = states_a.get(mitre_id, absent)
        b = states_b.get(mitre_id, absent)

        a_order = status_order.get(a["status"], 0)
        b_order = status_order.get(b["status"], 0)

        if b_order > a_order or (b_order == a_order and b["score"] > a["score"]):
            target = improved
        elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]):
            target = worsened
        else:
            unchanged_count += 1
            continue

        target.append({
            "mitre_id": mitre_id,
            "old_status": a["status"],
            "new_status": b["status"],
            "old_score": a["score"],
            "new_score": b["score"],
        })

    def _snap_summary(snap: CoverageSnapshot) -> dict:
        # Reduced summary: identifiers, headline counts and creation time.
        return {
            "id": str(snap.id),
            "name": snap.name,
            "organization_score": snap.organization_score,
            "total_techniques": snap.total_techniques,
            "validated_count": snap.validated_count,
            "partial_count": snap.partial_count,
            "not_covered_count": snap.not_covered_count,
            "in_progress_count": snap.in_progress_count,
            "not_evaluated_count": snap.not_evaluated_count,
            "created_at": snap.created_at.isoformat() if snap.created_at else None,
        }

    # A missing organization score is treated as 0 (same convention as the
    # per-technique scores above) so the delta cannot raise a TypeError.
    score_delta = round(
        (snap_b.organization_score or 0) - (snap_a.organization_score or 0), 1
    )

    return {
        "snapshot_a": _snap_summary(snap_a),
        "snapshot_b": _snap_summary(snap_b),
        "score_delta": score_delta,
        "improved": improved,
        "worsened": worsened,
        "unchanged_count": unchanged_count,
        "summary": {
            "improved_count": len(improved),
            "worsened_count": len(worsened),
            # Techniques present in B but not in A.
            "new_count": len(states_b.keys() - states_a.keys()),
        },
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Coverage evolution (trends)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
    """Return snapshot trend points for roughly the last *months* months.

    The window is approximated as ``months * 30`` days counted back from
    the current UTC time.  Snapshots created at or after that cutoff are
    returned oldest first, one dict per snapshot.
    """
    window = timedelta(days=months * 30)
    cutoff = datetime.now(timezone.utc) - window

    in_window = (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.created_at >= cutoff)
        .order_by(CoverageSnapshot.created_at.asc())
    )

    points: list[dict] = []
    for snap in in_window.all():
        created = snap.created_at
        points.append(
            {
                "date": created.isoformat() if created else None,
                "name": snap.name,
                "org_score": snap.organization_score,
                "coverage_pct": getattr(snap, "coverage_percentage", 0.0),
                "by_tactic": getattr(snap, "by_tactic", None) or {},
                "by_status": getattr(snap, "by_status", None) or {},
                "stale_count": getattr(snap, "stale_count", 0),
                "never_tested_count": getattr(snap, "never_tested_count", 0),
                "validated_count": snap.validated_count,
                "total_techniques": snap.total_techniques,
            }
        )
    return points
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Cleanup
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
    """Delete the oldest snapshots so that at most *keep_last* remain.

    Snapshots are ordered by ``created_at`` ascending and the surplus is
    deleted one by one through the session.  The session is committed here
    when anything was deleted.

    Returns the number of snapshots deleted (0 when nothing exceeds the
    retention limit).
    """
    total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0

    surplus = total - keep_last
    if surplus <= 0:
        return 0

    expired = (
        db.query(CoverageSnapshot)
        .order_by(CoverageSnapshot.created_at.asc())
        .limit(surplus)
        .all()
    )

    for snapshot in expired:
        db.delete(snapshot)
    deleted = len(expired)

    db.commit()
    logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
    return deleted
|