feat(phase-30): add coverage snapshots, temporal comparison and auto re-testing (T-230 to T-232)

This commit is contained in:
2026-02-10 08:34:29 +01:00
parent 2ac8e7f4a5
commit 4d124b42dd
20 changed files with 1517 additions and 4 deletions

View File

@@ -0,0 +1,205 @@
"""Snapshot endpoints — coverage snapshots CRUD and comparison.
Provides periodic and manual snapshots of the organisation's coverage
state, plus temporal comparison between any two snapshots.
"""
import logging
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role, require_role
from app.models.user import User
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.services.snapshot_service import (
create_snapshot,
compare_snapshots,
cleanup_old_snapshots,
)
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/snapshots", tags=["snapshots"])
# ── Pydantic schemas ─────────────────────────────────────────────────
class SnapshotCreate(BaseModel):
name: Optional[str] = None
# ── Helpers ──────────────────────────────────────────────────────────
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = _serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
# ---------------------------------------------------------------------------
# GET /snapshots — List snapshots (paginated)
# ---------------------------------------------------------------------------
@router.get("")
def list_snapshots(
offset: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [_serialize_snapshot_summary(s) for s in snapshots],
}
# ---------------------------------------------------------------------------
# POST /snapshots — Create snapshot manually
# ---------------------------------------------------------------------------
@router.post("", status_code=201)
def create_snapshot_endpoint(
payload: SnapshotCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
):
"""Create a manual coverage snapshot with an optional name."""
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
log_action(
db,
user_id=current_user.id,
action="create_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name, "score": snapshot.organization_score},
)
return _serialize_snapshot_summary(snapshot)
# ---------------------------------------------------------------------------
# GET /snapshots/compare — Compare two snapshots
# ---------------------------------------------------------------------------
@router.get("/compare")
def compare_snapshots_endpoint(
a: str = Query(..., description="Snapshot A ID"),
b: str = Query(..., description="Snapshot B ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Compare two snapshots showing improved, worsened, and unchanged techniques."""
try:
a_id = uuid.UUID(a)
b_id = uuid.UUID(b)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
result = compare_snapshots(db, a_id, b_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
# ---------------------------------------------------------------------------
# GET /snapshots/{id} — Snapshot detail
# ---------------------------------------------------------------------------
@router.get("/{snapshot_id}")
def get_snapshot(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get detailed snapshot information including per-technique states."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
return _serialize_snapshot_detail(db, snapshot)
# ---------------------------------------------------------------------------
# DELETE /snapshots/{id} — Delete snapshot (admin only)
# ---------------------------------------------------------------------------
@router.delete("/{snapshot_id}")
def delete_snapshot(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Delete a snapshot (admin only)."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
log_action(
db,
user_id=current_user.id,
action="delete_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name},
)
db.delete(snapshot)
db.commit()
return {"detail": "Snapshot deleted"}

View File

@@ -52,6 +52,8 @@ from app.services.test_workflow_service import (
validate_as_red_lead as wf_validate_red,
validate_as_blue_lead as wf_validate_blue,
reopen_test as wf_reopen,
handle_remediation_completed as wf_handle_remediation,
get_retest_chain as wf_get_retest_chain,
)
router = APIRouter(prefix="/tests", tags=["tests"])
@@ -546,9 +548,15 @@ def update_remediation(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Update remediation fields on a test (any authenticated user)."""
"""Update remediation fields on a test (any authenticated user).
When ``remediation_status`` transitions to ``'completed'``, an automatic
re-test is created (subject to ``MAX_RETEST_COUNT``).
"""
test = _get_test_or_404(db, test_id)
old_remediation_status = test.remediation_status
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
@@ -565,6 +573,13 @@ def update_remediation(
details={"updated_fields": list(update_data.keys())},
)
# Auto-create retest when remediation is marked completed
new_status = update_data.get("remediation_status")
if new_status == "completed" and old_remediation_status != "completed":
retest = wf_handle_remediation(db, test, current_user)
if retest:
db.refresh(test)
return test
@@ -603,3 +618,35 @@ def get_test_timeline(
}
for log in logs
]
# ---------------------------------------------------------------------------
# GET /tests/{id}/retest-chain — full retest chain
# ---------------------------------------------------------------------------
@router.get("/{test_id}/retest-chain")
def get_retest_chain(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return the full chain of retests (original + all retests) for a test."""
chain = wf_get_retest_chain(db, test_id)
if not chain:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"retest_of": str(t.retest_of) if t.retest_of else None,
"retest_count": t.retest_count,
"result": t.result.value if t.result else None,
"detection_result": t.detection_result.value if t.detection_result else None,
"remediation_status": t.remediation_status,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in chain
]