feat: move all remaining inline logic from routers to services (Tier 2)

This commit is contained in:
2026-02-20 14:34:24 +01:00
parent 9e22fde746
commit 339d669498
17 changed files with 632 additions and 414 deletions

View File

@@ -8,18 +8,24 @@ import logging
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role, require_role
from app.domain.errors import BusinessRuleViolation
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.services.snapshot_service import (
create_snapshot,
compare_snapshots,
cleanup_old_snapshots,
serialize_snapshot_summary,
list_snapshots as list_snapshots_svc,
get_snapshot_or_raise,
get_snapshot_detail,
delete_snapshot,
)
from app.services.audit_service import log_action
@@ -34,48 +40,6 @@ class SnapshotCreate(BaseModel):
name: Optional[str] = None
# ── Helpers ──────────────────────────────────────────────────────────
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = _serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
# ---------------------------------------------------------------------------
# GET /snapshots — List snapshots (paginated)
# ---------------------------------------------------------------------------
@@ -88,23 +52,7 @@ def list_snapshots(
current_user: User = Depends(get_current_user),
):
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [_serialize_snapshot_summary(s) for s in snapshots],
}
return list_snapshots_svc(db, offset=offset, limit=limit)
# ---------------------------------------------------------------------------
@@ -129,7 +77,7 @@ def create_snapshot_endpoint(
details={"name": snapshot.name, "score": snapshot.organization_score},
)
return _serialize_snapshot_summary(snapshot)
return serialize_snapshot_summary(snapshot)
# ---------------------------------------------------------------------------
@@ -148,13 +96,9 @@ def compare_snapshots_endpoint(
a_id = uuid.UUID(a)
b_id = uuid.UUID(b)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
raise BusinessRuleViolation("Invalid snapshot ID format")
result = compare_snapshots(db, a_id, b_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
return compare_snapshots(db, a_id, b_id)
# ---------------------------------------------------------------------------
@@ -168,11 +112,7 @@ def get_snapshot(
current_user: User = Depends(get_current_user),
):
"""Get detailed snapshot information including per-technique states."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
return _serialize_snapshot_detail(db, snapshot)
return get_snapshot_detail(db, snapshot_id)
# ---------------------------------------------------------------------------
@@ -180,15 +120,13 @@ def get_snapshot(
# ---------------------------------------------------------------------------
@router.delete("/{snapshot_id}")
def delete_snapshot(
def delete_snapshot_endpoint(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Delete a snapshot (admin only)."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
snapshot = get_snapshot_or_raise(db, snapshot_id)
log_action(
db,
@@ -199,7 +137,8 @@ def delete_snapshot(
details={"name": snapshot.name},
)
db.delete(snapshot)
db.commit()
with UnitOfWork(db) as uow:
delete_snapshot(db, snapshot_id)
uow.commit()
return {"detail": "Snapshot deleted"}