"""Snapshot endpoints — coverage snapshots CRUD and comparison.

Provides periodic and manual snapshots of the organisation's coverage state,
plus temporal comparison between any two snapshots.
"""

import logging
import uuid
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role, require_role
from app.models.user import User
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.services.snapshot_service import (
    create_snapshot,
    compare_snapshots,
    cleanup_old_snapshots,
)
from app.services.audit_service import log_action

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/snapshots", tags=["snapshots"])


# ── Pydantic schemas ─────────────────────────────────────────────────


class SnapshotCreate(BaseModel):
    """Request body for ``POST /snapshots``."""

    # Optional label for the snapshot; forwarded unchanged to create_snapshot().
    name: Optional[str] = None


# ── Helpers ──────────────────────────────────────────────────────────


def _parse_snapshot_id(raw: str) -> uuid.UUID:
    """Parse a snapshot identifier received from the client into a UUID.

    Args:
        raw: The identifier exactly as supplied in the path or query string.

    Returns:
        The parsed ``uuid.UUID``.

    Raises:
        HTTPException: 400 ("Invalid snapshot ID format") if ``raw`` is not
            a valid UUID string.
    """
    try:
        return uuid.UUID(raw)
    except ValueError:
        # Reject malformed IDs before they are used in a database filter;
        # the original ValueError adds nothing for the client, so drop it.
        raise HTTPException(status_code=400, detail="Invalid snapshot ID format") from None


def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
    """Lightweight serialization for list views.

    Returns a JSON-ready dict of the snapshot's aggregate counters. IDs are
    stringified, ``created_at`` is ISO-formatted, and nullable fields map
    to ``None``.
    """
    return {
        "id": str(snap.id),
        "name": snap.name,
        "organization_score": snap.organization_score,
        "total_techniques": snap.total_techniques,
        "validated_count": snap.validated_count,
        "partial_count": snap.partial_count,
        "not_covered_count": snap.not_covered_count,
        "in_progress_count": snap.in_progress_count,
        "not_evaluated_count": snap.not_evaluated_count,
        "created_by": str(snap.created_by) if snap.created_by else None,
        "created_at": snap.created_at.isoformat() if snap.created_at else None,
    }


def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
    """Full serialization including per-technique states.

    Extends the summary dict with a ``technique_states`` list, loaded with
    one query and ordered by ``mitre_id``.
    """
    base = _serialize_snapshot_summary(snap)
    technique_states = (
        db.query(SnapshotTechniqueState)
        .filter(SnapshotTechniqueState.snapshot_id == snap.id)
        .order_by(SnapshotTechniqueState.mitre_id)
        .all()
    )
    base["technique_states"] = [
        {
            "mitre_id": s.mitre_id,
            # Guard nullable IDs the same way the summary guards created_by,
            # so a missing value serializes as null rather than "None".
            "technique_id": str(s.technique_id) if s.technique_id else None,
            "status": s.status,
            "score": s.score,
        }
        for s in technique_states
    ]
    return base


# ---------------------------------------------------------------------------
# GET /snapshots — List snapshots (paginated)
# ---------------------------------------------------------------------------


@router.get("")
def list_snapshots(
    offset: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=200),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),  # resolved to require authentication
):
    """List coverage snapshots ordered by creation date (newest first).

    Returns a page envelope ``{total, offset, limit, items}``, where
    ``total`` is the unpaginated row count.
    """
    query = db.query(CoverageSnapshot)
    total = query.count()
    snapshots = (
        query
        # The id tiebreaker makes the order total, so snapshots sharing a
        # created_at value cannot shift or repeat across pages.
        .order_by(CoverageSnapshot.created_at.desc(), CoverageSnapshot.id.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": [_serialize_snapshot_summary(s) for s in snapshots],
    }


# ---------------------------------------------------------------------------
# POST /snapshots — Create snapshot manually
# ---------------------------------------------------------------------------


@router.post("", status_code=201)
def create_snapshot_endpoint(
    payload: SnapshotCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
):
    """Create a manual coverage snapshot with an optional name.

    Restricted to red_lead, blue_lead and admin. Writes an audit entry and
    returns the new snapshot's summary.
    """
    snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
    log_action(
        db,
        user_id=current_user.id,
        action="create_snapshot",
        entity_type="snapshot",
        entity_id=snapshot.id,
        details={"name": snapshot.name, "score": snapshot.organization_score},
    )
    return _serialize_snapshot_summary(snapshot)


# ---------------------------------------------------------------------------
# GET /snapshots/compare — Compare two snapshots
# ---------------------------------------------------------------------------


@router.get("/compare")
def compare_snapshots_endpoint(
    a: str = Query(..., description="Snapshot A ID"),
    b: str = Query(..., description="Snapshot B ID"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),  # resolved to require authentication
):
    """Compare two snapshots showing improved, worsened, and unchanged techniques.

    Raises:
        HTTPException: 400 if either ID is not a valid UUID; 404 if the
            comparison service reports an error (its message becomes the
            response detail).
    """
    a_id = _parse_snapshot_id(a)
    b_id = _parse_snapshot_id(b)

    result = compare_snapshots(db, a_id, b_id)
    # compare_snapshots signals failure via an "error" key in its result
    # rather than by raising.
    if "error" in result:
        raise HTTPException(status_code=404, detail=result["error"])
    return result


# ---------------------------------------------------------------------------
# GET /snapshots/{id} — Snapshot detail
# ---------------------------------------------------------------------------


@router.get("/{snapshot_id}")
def get_snapshot(
    snapshot_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),  # resolved to require authentication
):
    """Get detailed snapshot information including per-technique states.

    Raises:
        HTTPException: 400 for a malformed ID; 404 if no snapshot matches.
    """
    snapshot_uuid = _parse_snapshot_id(snapshot_id)
    snapshot = (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.id == snapshot_uuid)
        .first()
    )
    if not snapshot:
        raise HTTPException(status_code=404, detail="Snapshot not found")
    return _serialize_snapshot_detail(db, snapshot)


# ---------------------------------------------------------------------------
# DELETE /snapshots/{id} — Delete snapshot (admin only)
# ---------------------------------------------------------------------------


@router.delete("/{snapshot_id}")
def delete_snapshot(
    snapshot_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_role("admin")),
):
    """Delete a snapshot (admin only).

    The audit entry is recorded before the row is deleted, while the
    snapshot's name is still readable.

    Raises:
        HTTPException: 400 for a malformed ID; 404 if no snapshot matches.
    """
    snapshot_uuid = _parse_snapshot_id(snapshot_id)
    snapshot = (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.id == snapshot_uuid)
        .first()
    )
    if not snapshot:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    log_action(
        db,
        user_id=current_user.id,
        action="delete_snapshot",
        entity_type="snapshot",
        entity_id=snapshot.id,
        details={"name": snapshot.name},
    )
    # NOTE(review): whether SnapshotTechniqueState rows are removed with the
    # snapshot depends on the model's cascade / FK settings — confirm.
    db.delete(snapshot)
    db.commit()
    return {"detail": "Snapshot deleted"}