Files
Aegis/backend/app/services/snapshot_service.py
T
kitos ec26183e2e refactor(pep8): enforce full PEP8 compliance across backend Python codebase
- ruff.toml: select E/W/F/I/N rules, line-length=120, drop legacy ignores
- Auto-fix: sort 82 import blocks (isort), remove 29 unused imports,
  strip 6 trailing-whitespace blank lines in docstrings
- main.py: move setup_logging and settings imports to top (E402)
- errors.py: noqa N818 on DDD exception names (96 call sites, safe)
- intel_service.py: noqa N817 for universal ET alias
- atomic/elastic/sigma import services: move _MAX_UNCOMPRESSED_SIZE and
  _MAX_ENTRIES to module level (N806)
- compliance_import_service.py: move SAMPLE_CONTROLS / CIS_CONTROLS to
  module level; wrap long description strings (N806 + E501)
- snapshot_service.py: move STATUS_ORDER dict to module level (N806)
- sigma_import_service.py: remove dead dedup_key expression (F841)
- threat_actor_import_service.py: remove dead stix_to_actor expression (F841)
- data_source.py, seed_demo.py, campaign_scheduler_service.py,
  lolbas_import_service.py: wrap lines exceeding 120 chars (E501)
- d3fend_import_service.py: per-file E501 ignore (data file with long strings)

All 439 unit tests pass. ruff check app/ → All checks passed!

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-09 16:40:14 +02:00

430 lines
14 KiB
Python

"""Snapshot service — create, compare, and manage coverage snapshots.
Provides point-in-time coverage captures with normalised per-technique
storage and temporal comparison between any two snapshots.
Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
number of SQL queries regardless of technique count.
"""
import logging
import uuid
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus
from app.models.technique import Technique
from app.services.scoring_service import (
bulk_technique_scores,
calculate_organization_score,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Coverage status ordering for snapshot delta comparisons (higher = better coverage).
# ``compare_snapshots`` looks each status up here and falls back to rank 0 for
# any status string that is not listed.
_STATUS_ORDER: dict[str, int] = {
    "not_evaluated": 0,
    "not_covered": 1,
    "in_progress": 2,
    "partial": 3,
    "validated": 4,
}
# ---------------------------------------------------------------------------
# Serialization and queries
# ---------------------------------------------------------------------------
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
    """Serialize a snapshot's aggregate fields for list views.

    Per-technique states are not included. The extended metrics
    (``coverage_percentage``, ``by_tactic``, ``by_status``, ``stale_count``,
    ``never_tested_count``) fall back to a zero/empty value both when the
    attribute is missing and when it holds ``None``, so consumers always
    receive a number or a dict for them.

    Args:
        snap: The snapshot to serialize.

    Returns:
        A dict of plain values: ``id`` and ``created_by`` are rendered with
        ``str()`` and ``created_at`` with ``isoformat()``; the latter two are
        ``None`` when unset.
    """
    return {
        "id": str(snap.id),
        "name": snap.name,
        "organization_score": snap.organization_score,
        "total_techniques": snap.total_techniques,
        "validated_count": snap.validated_count,
        "partial_count": snap.partial_count,
        "not_covered_count": snap.not_covered_count,
        "in_progress_count": snap.in_progress_count,
        "not_evaluated_count": snap.not_evaluated_count,
        # ``getattr``'s default only covers a missing attribute; ``or`` also
        # covers an attribute that exists but is None, matching the handling
        # of by_tactic / by_status below.
        "coverage_percentage": getattr(snap, "coverage_percentage", None) or 0.0,
        "by_tactic": getattr(snap, "by_tactic", None) or {},
        "by_status": getattr(snap, "by_status", None) or {},
        "stale_count": getattr(snap, "stale_count", None) or 0,
        "never_tested_count": getattr(snap, "never_tested_count", None) or 0,
        "created_by": str(snap.created_by) if snap.created_by else None,
        "created_at": snap.created_at.isoformat() if snap.created_at else None,
    }
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
    """Serialize a snapshot together with its per-technique states.

    Starts from the summary payload and adds a ``technique_states`` list,
    loaded for this snapshot and ordered by MITRE ID.
    """
    payload = serialize_snapshot_summary(snap)

    states_query = (
        db.query(SnapshotTechniqueState)
        .filter(SnapshotTechniqueState.snapshot_id == snap.id)
        .order_by(SnapshotTechniqueState.mitre_id)
    )

    serialized_states = []
    for state in states_query.all():
        serialized_states.append(
            {
                "mitre_id": state.mitre_id,
                "technique_id": str(state.technique_id),
                "status": state.status,
                "score": state.score,
            }
        )

    payload["technique_states"] = serialized_states
    return payload
def list_snapshots(
    db: Session,
    *,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """Return one page of snapshot summaries, newest first.

    The result carries the total number of snapshots, the ``offset`` and
    ``limit`` that were applied, and the serialized page under ``items``.
    """
    base_query = db.query(CoverageSnapshot)
    total = base_query.count()

    newest_first = base_query.order_by(CoverageSnapshot.created_at.desc())
    page = newest_first.offset(offset).limit(limit).all()

    items = []
    for snapshot in page:
        items.append(serialize_snapshot_summary(snapshot))

    return {"total": total, "offset": offset, "limit": limit, "items": items}
def get_snapshot_or_raise(db: Session, snapshot_id: str | uuid.UUID) -> CoverageSnapshot:
    """Fetch a snapshot by ID or raise ``EntityNotFoundError``.

    Args:
        db: Session used for the lookup.
        snapshot_id: The snapshot's UUID, either as a string or as an
            already-parsed ``uuid.UUID``.

    Returns:
        The matching ``CoverageSnapshot``.

    Raises:
        EntityNotFoundError: If ``snapshot_id`` is not a valid UUID or no
            snapshot has that ID.
    """
    if isinstance(snapshot_id, uuid.UUID):
        sid = snapshot_id
    else:
        try:
            sid = uuid.UUID(snapshot_id)
        except (ValueError, TypeError, AttributeError):
            # uuid.UUID raises AttributeError (not ValueError) for non-string
            # input such as an int. A malformed ID cannot match any row, so
            # report it as not found; ``from None`` keeps the parsing error
            # out of the traceback.
            raise EntityNotFoundError("Snapshot", snapshot_id) from None

    snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
    if snapshot is None:
        raise EntityNotFoundError("Snapshot", snapshot_id)
    return snapshot
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
    """Return the full serialized form of one snapshot, technique states included.

    The lookup goes through ``get_snapshot_or_raise``, so an unknown ID
    surfaces as ``EntityNotFoundError``.
    """
    return serialize_snapshot_detail(db, get_snapshot_or_raise(db, snapshot_id))
def delete_snapshot(db: Session, snapshot_id: str) -> None:
    """Mark a snapshot for deletion in the session.

    Nothing is committed here; the caller is responsible for committing.
    An unknown ID surfaces as ``EntityNotFoundError`` from the lookup.
    """
    db.delete(get_snapshot_or_raise(db, snapshot_id))
# ---------------------------------------------------------------------------
# Create snapshot
# ---------------------------------------------------------------------------
def create_snapshot(
    db: Session,
    name: str | None = None,
    user_id: uuid.UUID | None = None,
) -> CoverageSnapshot:
    """Record the current coverage state as a new, committed snapshot.

    Steps:
      * load per-technique scores via ``bulk_technique_scores`` and all
        ``Technique`` rows;
      * tally statuses globally and per tactic in one pass over the techniques;
      * read the organization score from ``calculate_organization_score``;
      * store a ``CoverageSnapshot`` plus one ``SnapshotTechniqueState`` per
        technique, then commit and refresh.

    Args:
        db: Active session; this function commits it.
        name: Optional label for the snapshot.
        user_id: Optional ID stored as the snapshot's ``created_by``.

    Returns:
        The persisted, refreshed ``CoverageSnapshot``.
    """
    scores_map = bulk_technique_scores(db)
    techniques = db.query(Technique).all()

    # Statuses with a dedicated counter; anything else is tallied as not evaluated.
    tracked = ("validated", "partial", "not_covered", "in_progress")
    status_totals = dict.fromkeys((*tracked, "not_evaluated"), 0)
    by_status: dict[str, int] = {}
    tactic_buckets: dict[str, dict] = {}
    state_rows: list[tuple] = []
    stale_count = 0
    never_tested_count = 0

    for tech in techniques:
        raw_status = tech.status_global
        if isinstance(raw_status, TechniqueStatus):
            status_value = raw_status.value
        else:
            status_value = raw_status or "not_evaluated"

        counter_key = status_value if status_value in tracked else "not_evaluated"
        status_totals[counter_key] += 1

        score = scores_map.get(tech.id, {}).get("total_score", 0)
        state_rows.append((tech.id, tech.mitre_id, status_value, score))
        by_status[status_value] = by_status.get(status_value, 0) + 1

        bucket = tactic_buckets.setdefault(
            tech.tactic or "unknown",
            {"total": 0, "validated": 0, "partial": 0, "score_sum": 0.0},
        )
        bucket["total"] += 1
        bucket["score_sum"] += score
        if status_value in ("validated", "partial"):
            bucket[status_value] += 1

        # Only the literal "not_evaluated" status counts as never tested.
        if status_value == "not_evaluated":
            never_tested_count += 1
        if tech.review_required:
            stale_count += 1

    org_score = calculate_organization_score(db).get("overall_score", 0)

    technique_total = len(techniques)
    # Divide by 1 when there are no techniques so the percentage is 0.0.
    coverage_pct = round((status_totals["validated"] / (technique_total or 1)) * 100, 1)

    by_tactic_out = {}
    for tactic, data in tactic_buckets.items():
        average = round(data["score_sum"] / data["total"], 1) if data["total"] else 0
        by_tactic_out[tactic] = {
            "total": data["total"],
            "validated": data["validated"],
            "partial": data["partial"],
            "average_score": average,
        }

    snapshot = CoverageSnapshot(
        name=name,
        organization_score=org_score,
        total_techniques=technique_total,
        validated_count=status_totals["validated"],
        partial_count=status_totals["partial"],
        not_covered_count=status_totals["not_covered"],
        in_progress_count=status_totals["in_progress"],
        not_evaluated_count=status_totals["not_evaluated"],
        coverage_percentage=coverage_pct,
        by_tactic=by_tactic_out,
        by_status=by_status,
        stale_count=stale_count,
        never_tested_count=never_tested_count,
        created_by=user_id,
    )
    db.add(snapshot)
    # Flush before building the state rows, which reference snapshot.id.
    db.flush()

    for technique_id, mitre_id, status_value, score in state_rows:
        db.add(
            SnapshotTechniqueState(
                snapshot_id=snapshot.id,
                technique_id=technique_id,
                mitre_id=mitre_id,
                status=status_value,
                score=score,
            )
        )

    db.commit()
    db.refresh(snapshot)

    logger.info(
        "Snapshot '%s' created — %d techniques, org score %.1f",
        snapshot.name or snapshot.id,
        technique_total,
        org_score,
    )
    return snapshot
# ---------------------------------------------------------------------------
# Compare snapshots
# ---------------------------------------------------------------------------
def compare_snapshots(
    db: Session,
    snapshot_a_id: uuid.UUID,
    snapshot_b_id: uuid.UUID,
) -> dict:
    """Compare snapshot A (old) with snapshot B (new) and report the deltas.

    A technique counts as improved when its status ranks higher in B per
    ``_STATUS_ORDER``, or ranks the same with a higher score; worsened is the
    mirror case. A technique missing from one snapshot is treated there as
    ``not_evaluated`` with score 0.

    Returns:
        A dict with both snapshot summaries, the organization score delta
        (B minus A, rounded to one decimal), the ``improved`` and ``worsened``
        lists, ``unchanged_count``, and a ``summary`` of counts where
        ``new_count`` is the number of MITRE IDs present only in B.

    Raises:
        EntityNotFoundError: If either snapshot does not exist.
    """

    def _load_snapshot(snapshot_id: uuid.UUID):
        return db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()

    def _states_by_mitre_id(snapshot_id: uuid.UUID) -> dict:
        # mitre_id -> {status, score}; a missing score is read as 0.
        rows = (
            db.query(SnapshotTechniqueState)
            .filter(SnapshotTechniqueState.snapshot_id == snapshot_id)
            .all()
        )
        return {row.mitre_id: {"status": row.status, "score": row.score or 0} for row in rows}

    snap_a = _load_snapshot(snapshot_a_id)
    snap_b = _load_snapshot(snapshot_b_id)
    if not (snap_a and snap_b):
        raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")

    states_a = _states_by_mitre_id(snapshot_a_id)
    states_b = _states_by_mitre_id(snapshot_b_id)

    absent = {"status": "not_evaluated", "score": 0}
    improved = []
    worsened = []
    unchanged_count = 0

    for mitre_id in sorted(states_a.keys() | states_b.keys()):
        before = states_a.get(mitre_id, absent)
        after = states_b.get(mitre_id, absent)

        # Compare by status rank first, then by score as the tie-breaker.
        rank_before = (_STATUS_ORDER.get(before["status"], 0), before["score"])
        rank_after = (_STATUS_ORDER.get(after["status"], 0), after["score"])
        if rank_after == rank_before:
            unchanged_count += 1
            continue

        change = {
            "mitre_id": mitre_id,
            "old_status": before["status"],
            "new_status": after["status"],
            "old_score": before["score"],
            "new_score": after["score"],
        }
        if rank_after > rank_before:
            improved.append(change)
        elif rank_after < rank_before:
            worsened.append(change)
        else:
            unchanged_count += 1

    copied_fields = (
        "organization_score",
        "total_techniques",
        "validated_count",
        "partial_count",
        "not_covered_count",
        "in_progress_count",
        "not_evaluated_count",
    )

    def _snap_summary(snap: CoverageSnapshot) -> dict:
        summary = {"id": str(snap.id), "name": snap.name}
        for field in copied_fields:
            summary[field] = getattr(snap, field)
        summary["created_at"] = snap.created_at.isoformat() if snap.created_at else None
        return summary

    score_delta = round(snap_b.organization_score - snap_a.organization_score, 1)
    counts = {
        "improved_count": len(improved),
        "worsened_count": len(worsened),
        "new_count": len(states_b.keys() - states_a.keys()),
    }
    return {
        "snapshot_a": _snap_summary(snap_a),
        "snapshot_b": _snap_summary(snap_b),
        "score_delta": score_delta,
        "improved": improved,
        "worsened": worsened,
        "unchanged_count": unchanged_count,
        "summary": counts,
    }
# ---------------------------------------------------------------------------
# Coverage evolution (trends)
# ---------------------------------------------------------------------------
def _months_before(moment: datetime, months: int) -> datetime:
    """Return *moment* moved back by *months* calendar months.

    The day of month is clamped to the last day of the target month
    (e.g. 31 March minus one month gives the last day of February).
    """
    import calendar

    month_index = moment.year * 12 + (moment.month - 1) - months
    year, zero_based_month = divmod(month_index, 12)
    month = zero_based_month + 1
    day = min(moment.day, calendar.monthrange(year, month)[1])
    return moment.replace(year=year, month=month, day=day)


def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
    """Return snapshot trend points for the last *months* calendar months.

    The window starts *months* calendar months before the current UTC time,
    so ``months=12`` spans a full year rather than 360 days.

    Args:
        db: Session used for the query.
        months: Size of the look-back window in calendar months.

    Returns:
        One dict per snapshot created inside the window, oldest first. The
        extended metrics fall back to a zero/empty value when the attribute
        is missing or ``None``.
    """
    cutoff = _months_before(datetime.now(timezone.utc), months)
    snapshots = (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.created_at >= cutoff)
        .order_by(CoverageSnapshot.created_at.asc())
        .all()
    )
    return [
        {
            "date": snap.created_at.isoformat() if snap.created_at else None,
            "name": snap.name,
            "org_score": snap.organization_score,
            # ``or`` also covers attributes that exist but hold None, which
            # ``getattr``'s default alone does not.
            "coverage_pct": getattr(snap, "coverage_percentage", None) or 0.0,
            "by_tactic": getattr(snap, "by_tactic", None) or {},
            "by_status": getattr(snap, "by_status", None) or {},
            "stale_count": getattr(snap, "stale_count", None) or 0,
            "never_tested_count": getattr(snap, "never_tested_count", None) or 0,
            "validated_count": snap.validated_count,
            "total_techniques": snap.total_techniques,
        }
        for snap in snapshots
    ]
# ---------------------------------------------------------------------------
# Cleanup
# ---------------------------------------------------------------------------
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
    """Prune the oldest snapshots so that only the newest *keep_last* remain.

    Commits the session when anything is deleted.

    Returns:
        How many snapshots were deleted (0 when there is no surplus).
    """
    snapshot_total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
    surplus = snapshot_total - keep_last
    if surplus <= 0:
        return 0

    oldest_first = db.query(CoverageSnapshot).order_by(CoverageSnapshot.created_at.asc())
    doomed = oldest_first.limit(surplus).all()
    for stale_snapshot in doomed:
        db.delete(stale_snapshot)
    deleted = len(doomed)

    db.commit()
    logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
    return deleted