feat(phase-30): add coverage snapshots, temporal comparison and auto re-testing (T-230 to T-232)
This commit is contained in:
253
backend/app/services/snapshot_service.py
Normal file
253
backend/app/services/snapshot_service.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""Snapshot service — create, compare, and manage coverage snapshots.
|
||||
|
||||
Provides point-in-time coverage captures with normalised per-technique
|
||||
storage and temporal comparison between any two snapshots.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.technique import Technique
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.models.enums import TechniqueStatus
|
||||
from app.services.scoring_service import calculate_technique_score, calculate_organization_score
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Create snapshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_snapshot(
|
||||
db: Session,
|
||||
name: str | None = None,
|
||||
user_id: uuid.UUID | None = None,
|
||||
) -> CoverageSnapshot:
|
||||
"""Capture the current coverage state into a new snapshot.
|
||||
|
||||
1. Fetch every technique with its status and score.
|
||||
2. Compute aggregate counts.
|
||||
3. Persist a ``CoverageSnapshot`` with normalised
|
||||
``SnapshotTechniqueState`` rows.
|
||||
"""
|
||||
techniques = db.query(Technique).all()
|
||||
|
||||
# Aggregate counters
|
||||
validated_count = 0
|
||||
partial_count = 0
|
||||
not_covered_count = 0
|
||||
in_progress_count = 0
|
||||
not_evaluated_count = 0
|
||||
|
||||
technique_rows: list[dict] = []
|
||||
|
||||
for tech in techniques:
|
||||
status_value = (
|
||||
tech.status_global.value
|
||||
if isinstance(tech.status_global, TechniqueStatus)
|
||||
else (tech.status_global or "not_evaluated")
|
||||
)
|
||||
|
||||
# Count by status
|
||||
if status_value == "validated":
|
||||
validated_count += 1
|
||||
elif status_value == "partial":
|
||||
partial_count += 1
|
||||
elif status_value == "not_covered":
|
||||
not_covered_count += 1
|
||||
elif status_value == "in_progress":
|
||||
in_progress_count += 1
|
||||
else:
|
||||
not_evaluated_count += 1
|
||||
|
||||
# Compute technique score
|
||||
score_data = calculate_technique_score(tech, db)
|
||||
technique_rows.append({
|
||||
"technique_id": tech.id,
|
||||
"mitre_id": tech.mitre_id,
|
||||
"status": status_value,
|
||||
"score": score_data["total_score"],
|
||||
})
|
||||
|
||||
# Organization score
|
||||
org_data = calculate_organization_score(db)
|
||||
org_score = org_data.get("overall_score", 0)
|
||||
|
||||
# Create the snapshot
|
||||
snapshot = CoverageSnapshot(
|
||||
name=name,
|
||||
organization_score=org_score,
|
||||
total_techniques=len(techniques),
|
||||
validated_count=validated_count,
|
||||
partial_count=partial_count,
|
||||
not_covered_count=not_covered_count,
|
||||
in_progress_count=in_progress_count,
|
||||
not_evaluated_count=not_evaluated_count,
|
||||
created_by=user_id,
|
||||
)
|
||||
db.add(snapshot)
|
||||
db.flush() # get snapshot.id
|
||||
|
||||
# Create normalised technique state rows
|
||||
for row in technique_rows:
|
||||
state = SnapshotTechniqueState(
|
||||
snapshot_id=snapshot.id,
|
||||
technique_id=row["technique_id"],
|
||||
mitre_id=row["mitre_id"],
|
||||
status=row["status"],
|
||||
score=row["score"],
|
||||
)
|
||||
db.add(state)
|
||||
|
||||
db.commit()
|
||||
db.refresh(snapshot)
|
||||
|
||||
logger.info(
|
||||
"Snapshot '%s' created — %d techniques, org score %.1f",
|
||||
snapshot.name or snapshot.id,
|
||||
len(techniques),
|
||||
org_score,
|
||||
)
|
||||
return snapshot
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Compare snapshots
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compare_snapshots(
|
||||
db: Session,
|
||||
snapshot_a_id: uuid.UUID,
|
||||
snapshot_b_id: uuid.UUID,
|
||||
) -> dict:
|
||||
"""Compare two snapshots and return deltas.
|
||||
|
||||
Returns improved/worsened technique lists plus aggregate statistics.
|
||||
"""
|
||||
snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
|
||||
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
|
||||
|
||||
if not snap_a or not snap_b:
|
||||
return {"error": "One or both snapshots not found"}
|
||||
|
||||
# Build lookup dicts: mitre_id -> {status, score}
|
||||
states_a = {
|
||||
s.mitre_id: {"status": s.status, "score": s.score or 0}
|
||||
for s in db.query(SnapshotTechniqueState)
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id)
|
||||
.all()
|
||||
}
|
||||
states_b = {
|
||||
s.mitre_id: {"status": s.status, "score": s.score or 0}
|
||||
for s in db.query(SnapshotTechniqueState)
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id)
|
||||
.all()
|
||||
}
|
||||
|
||||
# Status priority for comparison
|
||||
STATUS_ORDER = {
|
||||
"not_evaluated": 0,
|
||||
"not_covered": 1,
|
||||
"in_progress": 2,
|
||||
"partial": 3,
|
||||
"validated": 4,
|
||||
}
|
||||
|
||||
improved = []
|
||||
worsened = []
|
||||
unchanged_count = 0
|
||||
|
||||
all_mitre_ids = set(states_a.keys()) | set(states_b.keys())
|
||||
|
||||
for mitre_id in sorted(all_mitre_ids):
|
||||
a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0})
|
||||
b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0})
|
||||
|
||||
a_order = STATUS_ORDER.get(a["status"], 0)
|
||||
b_order = STATUS_ORDER.get(b["status"], 0)
|
||||
|
||||
if b_order > a_order or (b_order == a_order and b["score"] > a["score"]):
|
||||
improved.append({
|
||||
"mitre_id": mitre_id,
|
||||
"old_status": a["status"],
|
||||
"new_status": b["status"],
|
||||
"old_score": a["score"],
|
||||
"new_score": b["score"],
|
||||
})
|
||||
elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]):
|
||||
worsened.append({
|
||||
"mitre_id": mitre_id,
|
||||
"old_status": a["status"],
|
||||
"new_status": b["status"],
|
||||
"old_score": a["score"],
|
||||
"new_score": b["score"],
|
||||
})
|
||||
else:
|
||||
unchanged_count += 1
|
||||
|
||||
def _snap_summary(snap: CoverageSnapshot) -> dict:
|
||||
return {
|
||||
"id": str(snap.id),
|
||||
"name": snap.name,
|
||||
"organization_score": snap.organization_score,
|
||||
"total_techniques": snap.total_techniques,
|
||||
"validated_count": snap.validated_count,
|
||||
"partial_count": snap.partial_count,
|
||||
"not_covered_count": snap.not_covered_count,
|
||||
"in_progress_count": snap.in_progress_count,
|
||||
"not_evaluated_count": snap.not_evaluated_count,
|
||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||
}
|
||||
|
||||
return {
|
||||
"snapshot_a": _snap_summary(snap_a),
|
||||
"snapshot_b": _snap_summary(snap_b),
|
||||
"score_delta": round(snap_b.organization_score - snap_a.organization_score, 1),
|
||||
"improved": improved,
|
||||
"worsened": worsened,
|
||||
"unchanged_count": unchanged_count,
|
||||
"summary": {
|
||||
"improved_count": len(improved),
|
||||
"worsened_count": len(worsened),
|
||||
"new_count": len(states_b.keys() - states_a.keys()),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
|
||||
"""Delete oldest snapshots, keeping the most recent *keep_last*.
|
||||
|
||||
Returns the number of snapshots deleted.
|
||||
"""
|
||||
total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
|
||||
if total <= keep_last:
|
||||
return 0
|
||||
|
||||
to_delete = total - keep_last
|
||||
old_snapshots = (
|
||||
db.query(CoverageSnapshot)
|
||||
.order_by(CoverageSnapshot.created_at.asc())
|
||||
.limit(to_delete)
|
||||
.all()
|
||||
)
|
||||
|
||||
deleted = 0
|
||||
for snap in old_snapshots:
|
||||
db.delete(snap)
|
||||
deleted += 1
|
||||
|
||||
db.commit()
|
||||
logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
|
||||
return deleted
|
||||
@@ -16,11 +16,12 @@ from datetime import datetime
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models.enums import TestState
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.notification_service import notify_test_state_change
|
||||
from app.services.notification_service import notify_test_state_change, create_notification
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Valid transition map
|
||||
@@ -298,6 +299,131 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
||||
return test
|
||||
|
||||
|
||||
def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None:
|
||||
"""Create a re-test when remediation is completed.
|
||||
|
||||
When a test's remediation_status changes to 'completed', this function
|
||||
creates a new test (retest) with the same base data to verify that the
|
||||
fix was effective.
|
||||
|
||||
Prevents infinite loops by enforcing ``MAX_RETEST_COUNT``.
|
||||
|
||||
Returns the new retest or *None* if the limit was reached.
|
||||
"""
|
||||
# Always reference the original test, not an intermediate retest
|
||||
original_test_id = test.retest_of or test.id
|
||||
|
||||
if test.retest_count >= settings.MAX_RETEST_COUNT:
|
||||
# Max retests reached — notify and bail out
|
||||
if test.created_by:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=test.created_by,
|
||||
type="max_retests_reached",
|
||||
title="Maximum retests reached",
|
||||
message=(
|
||||
f'Test "{test.name}" has reached the maximum of '
|
||||
f'{settings.MAX_RETEST_COUNT} retests. Manual review required.'
|
||||
),
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=user.id,
|
||||
action="max_retests_reached",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={
|
||||
"retest_count": test.retest_count,
|
||||
"max_allowed": settings.MAX_RETEST_COUNT,
|
||||
"original_test_id": str(original_test_id),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
retest = Test(
|
||||
technique_id=test.technique_id,
|
||||
name=f"[Retest #{test.retest_count + 1}] {test.name.replace(f'[Retest #{test.retest_count}] ', '')}",
|
||||
description=test.description,
|
||||
platform=test.platform,
|
||||
procedure_text=test.procedure_text,
|
||||
tool_used=test.tool_used,
|
||||
state=TestState.draft,
|
||||
created_by=test.created_by,
|
||||
retest_of=original_test_id,
|
||||
retest_count=test.retest_count + 1,
|
||||
)
|
||||
db.add(retest)
|
||||
db.flush()
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=user.id,
|
||||
action="create_retest",
|
||||
entity_type="test",
|
||||
entity_id=retest.id,
|
||||
details={
|
||||
"original_test_id": str(original_test_id),
|
||||
"retest_number": retest.retest_count,
|
||||
"source_test_id": str(test.id),
|
||||
},
|
||||
)
|
||||
|
||||
# Notify the test creator and any red_tech users
|
||||
if test.created_by:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=test.created_by,
|
||||
type="retest_created",
|
||||
title="Re-test created",
|
||||
message=(
|
||||
f'A re-test has been automatically created for "{test.name}" '
|
||||
f'after remediation was completed.'
|
||||
),
|
||||
entity_type="test",
|
||||
entity_id=retest.id,
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(retest)
|
||||
return retest
|
||||
|
||||
|
||||
def get_retest_chain(db: Session, test_id) -> list[Test]:
|
||||
"""Return the full chain of retests for a given test.
|
||||
|
||||
Includes the original test and all subsequent retests, ordered
|
||||
by retest_count.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) else test_id
|
||||
|
||||
# Find the original test first
|
||||
test = db.query(Test).filter(Test.id == tid).first()
|
||||
if not test:
|
||||
return []
|
||||
|
||||
original_id = test.retest_of or test.id
|
||||
|
||||
# Get original
|
||||
original = db.query(Test).filter(Test.id == original_id).first()
|
||||
if not original:
|
||||
return [test]
|
||||
|
||||
# Get all retests of the original
|
||||
retests = (
|
||||
db.query(Test)
|
||||
.filter(Test.retest_of == original_id)
|
||||
.order_by(Test.retest_count)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [original] + retests
|
||||
|
||||
|
||||
def reopen_test(db: Session, test: Test, user: User) -> Test:
|
||||
"""Move a ``rejected`` test back to ``draft``, clearing validation fields.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user