feat(phase-30): add coverage snapshots, temporal comparison and auto re-testing (T-230 to T-232)

This commit is contained in:
2026-02-10 08:34:29 +01:00
parent 2ac8e7f4a5
commit 4d124b42dd
20 changed files with 1517 additions and 4 deletions

View File

@@ -0,0 +1,253 @@
"""Snapshot service — create, compare, and manage coverage snapshots.
Provides point-in-time coverage captures with normalised per-technique
storage and temporal comparison between any two snapshots.
"""
import logging
import uuid
from datetime import datetime
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.models.technique import Technique
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus
from app.services.scoring_service import calculate_technique_score, calculate_organization_score
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Create snapshot
# ---------------------------------------------------------------------------
def create_snapshot(
db: Session,
name: str | None = None,
user_id: uuid.UUID | None = None,
) -> CoverageSnapshot:
"""Capture the current coverage state into a new snapshot.
1. Fetch every technique with its status and score.
2. Compute aggregate counts.
3. Persist a ``CoverageSnapshot`` with normalised
``SnapshotTechniqueState`` rows.
"""
techniques = db.query(Technique).all()
# Aggregate counters
validated_count = 0
partial_count = 0
not_covered_count = 0
in_progress_count = 0
not_evaluated_count = 0
technique_rows: list[dict] = []
for tech in techniques:
status_value = (
tech.status_global.value
if isinstance(tech.status_global, TechniqueStatus)
else (tech.status_global or "not_evaluated")
)
# Count by status
if status_value == "validated":
validated_count += 1
elif status_value == "partial":
partial_count += 1
elif status_value == "not_covered":
not_covered_count += 1
elif status_value == "in_progress":
in_progress_count += 1
else:
not_evaluated_count += 1
# Compute technique score
score_data = calculate_technique_score(tech, db)
technique_rows.append({
"technique_id": tech.id,
"mitre_id": tech.mitre_id,
"status": status_value,
"score": score_data["total_score"],
})
# Organization score
org_data = calculate_organization_score(db)
org_score = org_data.get("overall_score", 0)
# Create the snapshot
snapshot = CoverageSnapshot(
name=name,
organization_score=org_score,
total_techniques=len(techniques),
validated_count=validated_count,
partial_count=partial_count,
not_covered_count=not_covered_count,
in_progress_count=in_progress_count,
not_evaluated_count=not_evaluated_count,
created_by=user_id,
)
db.add(snapshot)
db.flush() # get snapshot.id
# Create normalised technique state rows
for row in technique_rows:
state = SnapshotTechniqueState(
snapshot_id=snapshot.id,
technique_id=row["technique_id"],
mitre_id=row["mitre_id"],
status=row["status"],
score=row["score"],
)
db.add(state)
db.commit()
db.refresh(snapshot)
logger.info(
"Snapshot '%s' created — %d techniques, org score %.1f",
snapshot.name or snapshot.id,
len(techniques),
org_score,
)
return snapshot
# ---------------------------------------------------------------------------
# Compare snapshots
# ---------------------------------------------------------------------------
def compare_snapshots(
db: Session,
snapshot_a_id: uuid.UUID,
snapshot_b_id: uuid.UUID,
) -> dict:
"""Compare two snapshots and return deltas.
Returns improved/worsened technique lists plus aggregate statistics.
"""
snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
if not snap_a or not snap_b:
return {"error": "One or both snapshots not found"}
# Build lookup dicts: mitre_id -> {status, score}
states_a = {
s.mitre_id: {"status": s.status, "score": s.score or 0}
for s in db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id)
.all()
}
states_b = {
s.mitre_id: {"status": s.status, "score": s.score or 0}
for s in db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id)
.all()
}
# Status priority for comparison
STATUS_ORDER = {
"not_evaluated": 0,
"not_covered": 1,
"in_progress": 2,
"partial": 3,
"validated": 4,
}
improved = []
worsened = []
unchanged_count = 0
all_mitre_ids = set(states_a.keys()) | set(states_b.keys())
for mitre_id in sorted(all_mitre_ids):
a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0})
b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0})
a_order = STATUS_ORDER.get(a["status"], 0)
b_order = STATUS_ORDER.get(b["status"], 0)
if b_order > a_order or (b_order == a_order and b["score"] > a["score"]):
improved.append({
"mitre_id": mitre_id,
"old_status": a["status"],
"new_status": b["status"],
"old_score": a["score"],
"new_score": b["score"],
})
elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]):
worsened.append({
"mitre_id": mitre_id,
"old_status": a["status"],
"new_status": b["status"],
"old_score": a["score"],
"new_score": b["score"],
})
else:
unchanged_count += 1
def _snap_summary(snap: CoverageSnapshot) -> dict:
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
return {
"snapshot_a": _snap_summary(snap_a),
"snapshot_b": _snap_summary(snap_b),
"score_delta": round(snap_b.organization_score - snap_a.organization_score, 1),
"improved": improved,
"worsened": worsened,
"unchanged_count": unchanged_count,
"summary": {
"improved_count": len(improved),
"worsened_count": len(worsened),
"new_count": len(states_b.keys() - states_a.keys()),
},
}
# ---------------------------------------------------------------------------
# Cleanup
# ---------------------------------------------------------------------------
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
"""Delete oldest snapshots, keeping the most recent *keep_last*.
Returns the number of snapshots deleted.
"""
total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
if total <= keep_last:
return 0
to_delete = total - keep_last
old_snapshots = (
db.query(CoverageSnapshot)
.order_by(CoverageSnapshot.created_at.asc())
.limit(to_delete)
.all()
)
deleted = 0
for snap in old_snapshots:
db.delete(snap)
deleted += 1
db.commit()
logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
return deleted

View File

@@ -16,11 +16,12 @@ from datetime import datetime
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from app.config import settings
from app.models.enums import TestState
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
from app.services.notification_service import notify_test_state_change
from app.services.notification_service import notify_test_state_change, create_notification
# ---------------------------------------------------------------------------
# Valid transition map
@@ -298,6 +299,131 @@ def check_dual_validation(db: Session, test: Test) -> Test:
return test
def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None:
"""Create a re-test when remediation is completed.
When a test's remediation_status changes to 'completed', this function
creates a new test (retest) with the same base data to verify that the
fix was effective.
Prevents infinite loops by enforcing ``MAX_RETEST_COUNT``.
Returns the new retest or *None* if the limit was reached.
"""
# Always reference the original test, not an intermediate retest
original_test_id = test.retest_of or test.id
if test.retest_count >= settings.MAX_RETEST_COUNT:
# Max retests reached — notify and bail out
if test.created_by:
create_notification(
db,
user_id=test.created_by,
type="max_retests_reached",
title="Maximum retests reached",
message=(
f'Test "{test.name}" has reached the maximum of '
f'{settings.MAX_RETEST_COUNT} retests. Manual review required.'
),
entity_type="test",
entity_id=test.id,
)
log_action(
db,
user_id=user.id,
action="max_retests_reached",
entity_type="test",
entity_id=test.id,
details={
"retest_count": test.retest_count,
"max_allowed": settings.MAX_RETEST_COUNT,
"original_test_id": str(original_test_id),
},
)
return None
retest = Test(
technique_id=test.technique_id,
name=f"[Retest #{test.retest_count + 1}] {test.name.replace(f'[Retest #{test.retest_count}] ', '')}",
description=test.description,
platform=test.platform,
procedure_text=test.procedure_text,
tool_used=test.tool_used,
state=TestState.draft,
created_by=test.created_by,
retest_of=original_test_id,
retest_count=test.retest_count + 1,
)
db.add(retest)
db.flush()
log_action(
db,
user_id=user.id,
action="create_retest",
entity_type="test",
entity_id=retest.id,
details={
"original_test_id": str(original_test_id),
"retest_number": retest.retest_count,
"source_test_id": str(test.id),
},
)
# Notify the test creator and any red_tech users
if test.created_by:
create_notification(
db,
user_id=test.created_by,
type="retest_created",
title="Re-test created",
message=(
f'A re-test has been automatically created for "{test.name}" '
f'after remediation was completed.'
),
entity_type="test",
entity_id=retest.id,
)
db.commit()
db.refresh(retest)
return retest
def get_retest_chain(db: Session, test_id) -> list[Test]:
"""Return the full chain of retests for a given test.
Includes the original test and all subsequent retests, ordered
by retest_count.
"""
import uuid as _uuid
tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) else test_id
# Find the original test first
test = db.query(Test).filter(Test.id == tid).first()
if not test:
return []
original_id = test.retest_of or test.id
# Get original
original = db.query(Test).filter(Test.id == original_id).first()
if not original:
return [test]
# Get all retests of the original
retests = (
db.query(Test)
.filter(Test.retest_of == original_id)
.order_by(Test.retest_count)
.all()
)
return [original] + retests
def reopen_test(db: Session, test: Test, user: User) -> Test:
"""Move a ``rejected`` test back to ``draft``, clearing validation fields.