206 lines
6.7 KiB
Python
206 lines
6.7 KiB
Python
"""Snapshot endpoints — coverage snapshots CRUD and comparison.
|
|
|
|
Provides periodic and manual snapshots of the organisation's coverage
|
|
state, plus temporal comparison between any two snapshots.
|
|
"""
|
|
|
|
import logging
|
|
import uuid
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.database import get_db
|
|
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
|
from app.models.user import User
|
|
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
|
from app.services.snapshot_service import (
|
|
create_snapshot,
|
|
compare_snapshots,
|
|
cleanup_old_snapshots,
|
|
)
|
|
from app.services.audit_service import log_action
|
|
|
|
# Module logger, named after this module (currently unused by the handlers below).
logger = logging.getLogger(name=__name__)

# Every endpoint in this module is mounted under ``/snapshots`` and grouped
# under the "snapshots" tag.
router = APIRouter(tags=["snapshots"], prefix="/snapshots")
|
|
|
|
|
|
# ── Pydantic schemas ─────────────────────────────────────────────────
|
|
|
|
class SnapshotCreate(BaseModel):
    """Request body for manually creating a coverage snapshot.

    ``name`` is an optional label; when omitted it stays ``None`` and is passed
    to ``create_snapshot`` as-is.
    """

    name: Optional[str] = None
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────
|
|
|
|
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
    """Build the compact, JSON-ready view of a snapshot used by list endpoints.

    ``id`` and ``created_by`` are rendered as strings and ``created_at`` as an
    ISO-8601 string. ``created_by`` and ``created_at`` become ``None`` when the
    underlying attribute is falsy. The aggregate score and per-status counts are
    copied through unchanged.
    """
    creator = snap.created_by
    created_at = snap.created_at
    return dict(
        id=str(snap.id),
        name=snap.name,
        organization_score=snap.organization_score,
        total_techniques=snap.total_techniques,
        validated_count=snap.validated_count,
        partial_count=snap.partial_count,
        not_covered_count=snap.not_covered_count,
        in_progress_count=snap.in_progress_count,
        not_evaluated_count=snap.not_evaluated_count,
        created_by=str(creator) if creator else None,
        created_at=created_at.isoformat() if created_at else None,
    )
|
|
|
|
|
|
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
    """Full serialization including technique states.

    Starts from ``_serialize_snapshot_summary(snap)`` and adds a
    ``technique_states`` list. The list holds one entry per
    ``SnapshotTechniqueState`` row belonging to ``snap``, ordered by
    ``mitre_id``.

    Args:
        db: Session used to load the snapshot's technique-state rows.
        snap: The snapshot to serialize.

    Returns:
        The summary dict extended with ``technique_states``. Each entry has
        ``mitre_id``, ``technique_id`` (string, or ``None`` when unset),
        ``status`` and ``score``.
    """
    base = _serialize_snapshot_summary(snap)

    technique_states = (
        db.query(SnapshotTechniqueState)
        .filter(SnapshotTechniqueState.snapshot_id == snap.id)
        .order_by(SnapshotTechniqueState.mitre_id)
        .all()
    )

    base["technique_states"] = [
        {
            "mitre_id": s.mitre_id,
            # Emit null rather than the literal string "None" when the row has
            # no technique link (same approach as ``created_by`` in the summary).
            "technique_id": str(s.technique_id) if s.technique_id is not None else None,
            "status": s.status,
            "score": s.score,
        }
        for s in technique_states
    ]
    return base
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /snapshots — List snapshots (paginated)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.get("")
def list_snapshots(
    offset: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=200),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one page of coverage snapshots, most recently created first.

    Query parameters:
        offset: number of snapshots to skip (>= 0, default 0).
        limit: page size (1-200, default 50).

    The page is wrapped in an envelope containing ``total`` (the unpaged
    snapshot count) and echoing ``offset`` and ``limit``. Only
    ``get_current_user`` is required; there is no role check.
    """
    all_snapshots = db.query(CoverageSnapshot)
    total_count = all_snapshots.count()

    newest_first = all_snapshots.order_by(CoverageSnapshot.created_at.desc())
    page = newest_first.offset(offset).limit(limit).all()

    return {
        "total": total_count,
        "offset": offset,
        "limit": limit,
        "items": list(map(_serialize_snapshot_summary, page)),
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /snapshots — Create snapshot manually
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.post("", status_code=201)
def create_snapshot_endpoint(
    payload: SnapshotCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
):
    """Create a manual coverage snapshot with an optional name.

    Restricted to the ``red_lead``, ``blue_lead`` and ``admin`` roles. The
    snapshot is built by ``create_snapshot``. This endpoint then records a
    ``create_snapshot`` audit entry (name and organisation score) and commits.

    Returns:
        The summary representation of the new snapshot (HTTP 201).
    """
    snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)

    log_action(
        db,
        user_id=current_user.id,
        action="create_snapshot",
        entity_type="snapshot",
        entity_id=snapshot.id,
        details={"name": snapshot.name, "score": snapshot.organization_score},
    )
    # Commit explicitly so the audit entry is persisted, as ``delete_snapshot``
    # does after ``log_action``. If ``log_action`` already commits on its own,
    # this is a harmless no-op.
    db.commit()

    return _serialize_snapshot_summary(snapshot)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /snapshots/compare — Compare two snapshots
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Declared before ``/{snapshot_id}`` so that "/compare" is matched by this route
# instead of being captured as a snapshot ID.
@router.get("/compare")
def compare_snapshots_endpoint(
    a: str = Query(..., description="Snapshot A ID"),
    b: str = Query(..., description="Snapshot B ID"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Compare two snapshots showing improved, worsened, and unchanged techniques.

    Both query parameters are parsed as UUIDs before delegating to
    ``compare_snapshots``.

    Raises:
        HTTPException: 400 if either ID is not a valid UUID. 404, carrying the
            service's message, if the comparison result contains an
            ``"error"`` key.
    """
    try:
        a_id, b_id = uuid.UUID(a), uuid.UUID(b)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid snapshot ID format")

    result = compare_snapshots(db, a_id, b_id)
    # compare_snapshots signals failure in-band through an "error" key.
    if "error" not in result:
        return result
    raise HTTPException(status_code=404, detail=result["error"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /snapshots/{id} — Snapshot detail
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.get("/{snapshot_id}")
def get_snapshot(
    snapshot_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Get detailed snapshot information including per-technique states.

    Args:
        snapshot_id: Snapshot ID as a UUID string (path parameter).

    Returns:
        The snapshot summary plus its ``technique_states`` list.

    Raises:
        HTTPException: 400 if ``snapshot_id`` is not a valid UUID, 404 if no
            snapshot has that ID.
    """
    # Validate the ID the same way the compare endpoint does, so a malformed
    # value is rejected with 400 instead of being sent to the database query.
    try:
        snapshot_uuid = uuid.UUID(snapshot_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid snapshot ID format") from None

    snapshot = (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.id == snapshot_uuid)
        .first()
    )
    if not snapshot:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    return _serialize_snapshot_detail(db, snapshot)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DELETE /snapshots/{id} — Delete snapshot (admin only)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.delete("/{snapshot_id}")
def delete_snapshot(
    snapshot_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_role("admin")),
):
    """Delete a snapshot (admin only).

    Records a ``delete_snapshot`` audit entry (with the snapshot's name), then
    deletes the snapshot and commits.

    Args:
        snapshot_id: Snapshot ID as a UUID string (path parameter).

    Raises:
        HTTPException: 400 if ``snapshot_id`` is not a valid UUID, 404 if no
            snapshot has that ID.
    """
    # Validate the ID the same way the compare endpoint does, so a malformed
    # value is rejected with 400 instead of being sent to the database query.
    try:
        snapshot_uuid = uuid.UUID(snapshot_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid snapshot ID format") from None

    snapshot = (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.id == snapshot_uuid)
        .first()
    )
    if not snapshot:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    # Log before deleting, while the snapshot's id and name are still readable.
    # Unless log_action commits on its own, the audit row and the deletion are
    # committed together below.
    log_action(
        db,
        user_id=current_user.id,
        action="delete_snapshot",
        entity_type="snapshot",
        entity_id=snapshot.id,
        details={"name": snapshot.name},
    )

    # NOTE(review): removal of this snapshot's SnapshotTechniqueState rows is
    # assumed to be handled by an ORM/DB cascade; confirm in the model.
    db.delete(snapshot)
    db.commit()

    return {"detail": "Snapshot deleted"}
|