feat(phase-30): add coverage snapshots, temporal comparison and auto re-testing (T-230 to T-232)
This commit is contained in:
77
backend/alembic/versions/b015_add_coverage_snapshots.py
Normal file
77
backend/alembic/versions/b015_add_coverage_snapshots.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""add_coverage_snapshots
|
||||
|
||||
Revision ID: b015snapshots
|
||||
Revises: b014compliance
|
||||
Create Date: 2026-02-10 00:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b015snapshots"
|
||||
down_revision: Union[str, None] = "b014compliance"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── coverage_snapshots ────────────────────────────────────────
|
||||
op.create_table(
|
||||
"coverage_snapshots",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("name", sa.String, nullable=True),
|
||||
sa.Column("organization_score", sa.Float, nullable=False),
|
||||
sa.Column("total_techniques", sa.Integer, nullable=False),
|
||||
sa.Column("validated_count", sa.Integer, nullable=False),
|
||||
sa.Column("partial_count", sa.Integer, nullable=False),
|
||||
sa.Column("not_covered_count", sa.Integer, nullable=False),
|
||||
sa.Column("in_progress_count", sa.Integer, nullable=False),
|
||||
sa.Column("not_evaluated_count", sa.Integer, nullable=False),
|
||||
sa.Column(
|
||||
"created_by",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime, server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
# ── snapshot_technique_states ─────────────────────────────────
|
||||
op.create_table(
|
||||
"snapshot_technique_states",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"snapshot_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("coverage_snapshots.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"technique_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("mitre_id", sa.String, nullable=False),
|
||||
sa.Column("status", sa.String, nullable=False),
|
||||
sa.Column("score", sa.Float, nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_snapshot_technique_states_snapshot",
|
||||
"snapshot_technique_states",
|
||||
["snapshot_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_snapshot_technique_states_technique",
|
||||
"snapshot_technique_states",
|
||||
["technique_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("snapshot_technique_states")
|
||||
op.drop_table("coverage_snapshots")
|
||||
41
backend/alembic/versions/b016_add_retest_fields.py
Normal file
41
backend/alembic/versions/b016_add_retest_fields.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""add_retest_fields
|
||||
|
||||
Revision ID: b016retests
|
||||
Revises: b015snapshots
|
||||
Create Date: 2026-02-10 01:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b016retests"
|
||||
down_revision: Union[str, None] = "b015snapshots"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tests",
|
||||
sa.Column(
|
||||
"retest_of",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("tests.id"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"tests",
|
||||
sa.Column("retest_count", sa.Integer, server_default="0", nullable=False),
|
||||
)
|
||||
op.create_index("ix_tests_retest_of", "tests", ["retest_of"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_tests_retest_of", table_name="tests")
|
||||
op.drop_column("tests", "retest_count")
|
||||
op.drop_column("tests", "retest_of")
|
||||
@@ -11,6 +11,9 @@ class Settings(BaseSettings):
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "evidence"
|
||||
|
||||
# Re-testing
|
||||
MAX_RETEST_COUNT: int = 3 # maximum automatic retests per original test
|
||||
|
||||
# Scoring weights (must sum to 100)
|
||||
SCORING_WEIGHT_TESTS: int = 40
|
||||
SCORING_WEIGHT_DETECTION_RULES: int = 20
|
||||
|
||||
@@ -18,6 +18,7 @@ from app.database import SessionLocal
|
||||
from app.services.mitre_sync_service import sync_mitre
|
||||
from app.services.intel_service import scan_intel
|
||||
from app.services.notification_service import cleanup_old_notifications
|
||||
from app.services.snapshot_service import create_snapshot, cleanup_old_snapshots
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,6 +60,26 @@ def _run_notification_cleanup() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_weekly_snapshot() -> None:
|
||||
"""Create a weekly coverage snapshot and clean up old ones."""
|
||||
logger.info("Scheduled weekly snapshot job starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
snapshot = create_snapshot(db, name="Auto-weekly")
|
||||
logger.info(
|
||||
"Weekly snapshot created — score %.1f, %d techniques",
|
||||
snapshot.organization_score,
|
||||
snapshot.total_techniques,
|
||||
)
|
||||
deleted = cleanup_old_snapshots(db, keep_last=52)
|
||||
if deleted:
|
||||
logger.info("Cleaned up %d old snapshots", deleted)
|
||||
except Exception:
|
||||
logger.exception("Weekly snapshot job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_intel_scan() -> None:
|
||||
"""Execute an intel scan inside its own DB session."""
|
||||
logger.info("Scheduled intel scan job starting...")
|
||||
@@ -111,5 +132,18 @@ def start_scheduler() -> None:
|
||||
name="Notification cleanup (daily)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_weekly_snapshot,
|
||||
trigger="cron",
|
||||
day_of_week="sun",
|
||||
hour=0,
|
||||
minute=0,
|
||||
id="weekly_snapshot",
|
||||
name="Weekly coverage snapshot (Sundays 00:00)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("Background scheduler started — mitre_sync (24h), intel_scan (7d), notification_cleanup (24h)")
|
||||
logger.info(
|
||||
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
||||
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00)"
|
||||
)
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.routers import heatmap as heatmap_router
|
||||
from app.routers import scores as scores_router
|
||||
from app.routers import operational_metrics as operational_metrics_router
|
||||
from app.routers import compliance as compliance_router
|
||||
from app.routers import snapshots as snapshots_router
|
||||
from app.storage import ensure_bucket_exists
|
||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||
|
||||
@@ -78,6 +79,7 @@ app.include_router(heatmap_router.router, prefix="/api/v1")
|
||||
app.include_router(scores_router.router, prefix="/api/v1")
|
||||
app.include_router(operational_metrics_router.router, prefix="/api/v1")
|
||||
app.include_router(compliance_router.router, prefix="/api/v1")
|
||||
app.include_router(snapshots_router.router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
|
||||
|
||||
__all__ = [
|
||||
@@ -25,5 +26,6 @@ __all__ = [
|
||||
"TestTemplateDetectionRule", "TestDetectionResult",
|
||||
"Campaign", "CampaignTest",
|
||||
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
|
||||
"CoverageSnapshot", "SnapshotTechniqueState",
|
||||
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
||||
]
|
||||
|
||||
78
backend/app/models/coverage_snapshot.py
Normal file
78
backend/app/models/coverage_snapshot.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Coverage snapshot models — periodic snapshots of coverage state.
|
||||
|
||||
CoverageSnapshot stores aggregate metrics at a point in time.
|
||||
SnapshotTechniqueState stores per-technique state (normalized, one row
|
||||
per technique per snapshot) to avoid bloated JSONB fields.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, String, Float, Integer, DateTime,
|
||||
ForeignKey, Index,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class CoverageSnapshot(Base):
|
||||
"""A point-in-time snapshot of the organisation's overall coverage."""
|
||||
|
||||
__tablename__ = "coverage_snapshots"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
|
||||
organization_score = Column(Float, nullable=False)
|
||||
total_techniques = Column(Integer, nullable=False)
|
||||
validated_count = Column(Integer, nullable=False)
|
||||
partial_count = Column(Integer, nullable=False)
|
||||
not_covered_count = Column(Integer, nullable=False)
|
||||
in_progress_count = Column(Integer, nullable=False)
|
||||
not_evaluated_count = Column(Integer, nullable=False)
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
technique_states = relationship(
|
||||
"SnapshotTechniqueState",
|
||||
back_populates="snapshot",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class SnapshotTechniqueState(Base):
|
||||
"""Per-technique state within a snapshot (normalised storage)."""
|
||||
|
||||
__tablename__ = "snapshot_technique_states"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
snapshot_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("coverage_snapshots.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
mitre_id = Column(String, nullable=False) # denormalised for fast queries
|
||||
status = Column(String, nullable=False)
|
||||
score = Column(Float, nullable=True)
|
||||
|
||||
# Relationships
|
||||
snapshot = relationship("CoverageSnapshot", back_populates="technique_states")
|
||||
technique = relationship("Technique")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_snapshot_technique_states_snapshot", "snapshot_id"),
|
||||
Index("ix_snapshot_technique_states_technique", "technique_id"),
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum
|
||||
from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, ForeignKey, Enum
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -54,6 +54,10 @@ class Test(Base):
|
||||
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable
|
||||
remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
|
||||
# ── Re-test fields ────────────────────────────────────────────
|
||||
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
|
||||
retest_count = Column(Integer, default=0)
|
||||
|
||||
# ── Relationships ───────────────────────────────────────────────
|
||||
technique = relationship("Technique", back_populates="tests")
|
||||
evidences = relationship("Evidence", back_populates="test")
|
||||
@@ -61,3 +65,5 @@ class Test(Base):
|
||||
red_validator = relationship("User", foreign_keys=[red_validated_by])
|
||||
blue_validator = relationship("User", foreign_keys=[blue_validated_by])
|
||||
remediation_user = relationship("User", foreign_keys=[remediation_assignee])
|
||||
original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
|
||||
retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test")
|
||||
|
||||
205
backend/app/routers/snapshots.py
Normal file
205
backend/app/routers/snapshots.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Snapshot endpoints — coverage snapshots CRUD and comparison.
|
||||
|
||||
Provides periodic and manual snapshots of the organisation's coverage
|
||||
state, plus temporal comparison between any two snapshots.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||
from app.models.user import User
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.services.snapshot_service import (
|
||||
create_snapshot,
|
||||
compare_snapshots,
|
||||
cleanup_old_snapshots,
|
||||
)
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/snapshots", tags=["snapshots"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ─────────────────────────────────────────────────
|
||||
|
||||
class SnapshotCreate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
|
||||
"""Lightweight serialization for list views."""
|
||||
return {
|
||||
"id": str(snap.id),
|
||||
"name": snap.name,
|
||||
"organization_score": snap.organization_score,
|
||||
"total_techniques": snap.total_techniques,
|
||||
"validated_count": snap.validated_count,
|
||||
"partial_count": snap.partial_count,
|
||||
"not_covered_count": snap.not_covered_count,
|
||||
"in_progress_count": snap.in_progress_count,
|
||||
"not_evaluated_count": snap.not_evaluated_count,
|
||||
"created_by": str(snap.created_by) if snap.created_by else None,
|
||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
|
||||
"""Full serialization including technique states."""
|
||||
base = _serialize_snapshot_summary(snap)
|
||||
|
||||
technique_states = (
|
||||
db.query(SnapshotTechniqueState)
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
|
||||
.order_by(SnapshotTechniqueState.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
base["technique_states"] = [
|
||||
{
|
||||
"mitre_id": s.mitre_id,
|
||||
"technique_id": str(s.technique_id),
|
||||
"status": s.status,
|
||||
"score": s.score,
|
||||
}
|
||||
for s in technique_states
|
||||
]
|
||||
return base
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /snapshots — List snapshots (paginated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("")
|
||||
def list_snapshots(
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||
query = db.query(CoverageSnapshot)
|
||||
total = query.count()
|
||||
|
||||
snapshots = (
|
||||
query
|
||||
.order_by(CoverageSnapshot.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [_serialize_snapshot_summary(s) for s in snapshots],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /snapshots — Create snapshot manually
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("", status_code=201)
|
||||
def create_snapshot_endpoint(
|
||||
payload: SnapshotCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
|
||||
):
|
||||
"""Create a manual coverage snapshot with an optional name."""
|
||||
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="create_snapshot",
|
||||
entity_type="snapshot",
|
||||
entity_id=snapshot.id,
|
||||
details={"name": snapshot.name, "score": snapshot.organization_score},
|
||||
)
|
||||
|
||||
return _serialize_snapshot_summary(snapshot)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /snapshots/compare — Compare two snapshots
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/compare")
|
||||
def compare_snapshots_endpoint(
|
||||
a: str = Query(..., description="Snapshot A ID"),
|
||||
b: str = Query(..., description="Snapshot B ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Compare two snapshots showing improved, worsened, and unchanged techniques."""
|
||||
try:
|
||||
a_id = uuid.UUID(a)
|
||||
b_id = uuid.UUID(b)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
|
||||
|
||||
result = compare_snapshots(db, a_id, b_id)
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=404, detail=result["error"])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /snapshots/{id} — Snapshot detail
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{snapshot_id}")
|
||||
def get_snapshot(
|
||||
snapshot_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed snapshot information including per-technique states."""
|
||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
||||
|
||||
return _serialize_snapshot_detail(db, snapshot)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /snapshots/{id} — Delete snapshot (admin only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/{snapshot_id}")
|
||||
def delete_snapshot(
|
||||
snapshot_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Delete a snapshot (admin only)."""
|
||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="delete_snapshot",
|
||||
entity_type="snapshot",
|
||||
entity_id=snapshot.id,
|
||||
details={"name": snapshot.name},
|
||||
)
|
||||
|
||||
db.delete(snapshot)
|
||||
db.commit()
|
||||
|
||||
return {"detail": "Snapshot deleted"}
|
||||
@@ -52,6 +52,8 @@ from app.services.test_workflow_service import (
|
||||
validate_as_red_lead as wf_validate_red,
|
||||
validate_as_blue_lead as wf_validate_blue,
|
||||
reopen_test as wf_reopen,
|
||||
handle_remediation_completed as wf_handle_remediation,
|
||||
get_retest_chain as wf_get_retest_chain,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/tests", tags=["tests"])
|
||||
@@ -546,9 +548,15 @@ def update_remediation(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Update remediation fields on a test (any authenticated user)."""
|
||||
"""Update remediation fields on a test (any authenticated user).
|
||||
|
||||
When ``remediation_status`` transitions to ``'completed'``, an automatic
|
||||
re-test is created (subject to ``MAX_RETEST_COUNT``).
|
||||
"""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
|
||||
old_remediation_status = test.remediation_status
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(test, field, value)
|
||||
@@ -565,6 +573,13 @@ def update_remediation(
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
|
||||
# Auto-create retest when remediation is marked completed
|
||||
new_status = update_data.get("remediation_status")
|
||||
if new_status == "completed" and old_remediation_status != "completed":
|
||||
retest = wf_handle_remediation(db, test, current_user)
|
||||
if retest:
|
||||
db.refresh(test)
|
||||
|
||||
return test
|
||||
|
||||
|
||||
@@ -603,3 +618,35 @@ def get_test_timeline(
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /tests/{id}/retest-chain — full retest chain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{test_id}/retest-chain")
|
||||
def get_retest_chain(
|
||||
test_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return the full chain of retests (original + all retests) for a test."""
|
||||
chain = wf_get_retest_chain(db, test_id)
|
||||
if not chain:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"name": t.name,
|
||||
"state": t.state.value if t.state else None,
|
||||
"retest_of": str(t.retest_of) if t.retest_of else None,
|
||||
"retest_count": t.retest_count,
|
||||
"result": t.result.value if t.result else None,
|
||||
"detection_result": t.detection_result.value if t.detection_result else None,
|
||||
"remediation_status": t.remediation_status,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in chain
|
||||
]
|
||||
|
||||
@@ -142,6 +142,10 @@ class TestOut(BaseModel):
|
||||
remediation_status: str | None = None
|
||||
remediation_assignee: uuid.UUID | None = None
|
||||
|
||||
# Re-test fields
|
||||
retest_of: uuid.UUID | None = None
|
||||
retest_count: int = 0
|
||||
|
||||
# Technique info (populated when joined)
|
||||
technique_mitre_id: str | None = None
|
||||
technique_name: str | None = None
|
||||
|
||||
253
backend/app/services/snapshot_service.py
Normal file
253
backend/app/services/snapshot_service.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""Snapshot service — create, compare, and manage coverage snapshots.
|
||||
|
||||
Provides point-in-time coverage captures with normalised per-technique
|
||||
storage and temporal comparison between any two snapshots.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.technique import Technique
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.models.enums import TechniqueStatus
|
||||
from app.services.scoring_service import calculate_technique_score, calculate_organization_score
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Create snapshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_snapshot(
|
||||
db: Session,
|
||||
name: str | None = None,
|
||||
user_id: uuid.UUID | None = None,
|
||||
) -> CoverageSnapshot:
|
||||
"""Capture the current coverage state into a new snapshot.
|
||||
|
||||
1. Fetch every technique with its status and score.
|
||||
2. Compute aggregate counts.
|
||||
3. Persist a ``CoverageSnapshot`` with normalised
|
||||
``SnapshotTechniqueState`` rows.
|
||||
"""
|
||||
techniques = db.query(Technique).all()
|
||||
|
||||
# Aggregate counters
|
||||
validated_count = 0
|
||||
partial_count = 0
|
||||
not_covered_count = 0
|
||||
in_progress_count = 0
|
||||
not_evaluated_count = 0
|
||||
|
||||
technique_rows: list[dict] = []
|
||||
|
||||
for tech in techniques:
|
||||
status_value = (
|
||||
tech.status_global.value
|
||||
if isinstance(tech.status_global, TechniqueStatus)
|
||||
else (tech.status_global or "not_evaluated")
|
||||
)
|
||||
|
||||
# Count by status
|
||||
if status_value == "validated":
|
||||
validated_count += 1
|
||||
elif status_value == "partial":
|
||||
partial_count += 1
|
||||
elif status_value == "not_covered":
|
||||
not_covered_count += 1
|
||||
elif status_value == "in_progress":
|
||||
in_progress_count += 1
|
||||
else:
|
||||
not_evaluated_count += 1
|
||||
|
||||
# Compute technique score
|
||||
score_data = calculate_technique_score(tech, db)
|
||||
technique_rows.append({
|
||||
"technique_id": tech.id,
|
||||
"mitre_id": tech.mitre_id,
|
||||
"status": status_value,
|
||||
"score": score_data["total_score"],
|
||||
})
|
||||
|
||||
# Organization score
|
||||
org_data = calculate_organization_score(db)
|
||||
org_score = org_data.get("overall_score", 0)
|
||||
|
||||
# Create the snapshot
|
||||
snapshot = CoverageSnapshot(
|
||||
name=name,
|
||||
organization_score=org_score,
|
||||
total_techniques=len(techniques),
|
||||
validated_count=validated_count,
|
||||
partial_count=partial_count,
|
||||
not_covered_count=not_covered_count,
|
||||
in_progress_count=in_progress_count,
|
||||
not_evaluated_count=not_evaluated_count,
|
||||
created_by=user_id,
|
||||
)
|
||||
db.add(snapshot)
|
||||
db.flush() # get snapshot.id
|
||||
|
||||
# Create normalised technique state rows
|
||||
for row in technique_rows:
|
||||
state = SnapshotTechniqueState(
|
||||
snapshot_id=snapshot.id,
|
||||
technique_id=row["technique_id"],
|
||||
mitre_id=row["mitre_id"],
|
||||
status=row["status"],
|
||||
score=row["score"],
|
||||
)
|
||||
db.add(state)
|
||||
|
||||
db.commit()
|
||||
db.refresh(snapshot)
|
||||
|
||||
logger.info(
|
||||
"Snapshot '%s' created — %d techniques, org score %.1f",
|
||||
snapshot.name or snapshot.id,
|
||||
len(techniques),
|
||||
org_score,
|
||||
)
|
||||
return snapshot
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Compare snapshots
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compare_snapshots(
|
||||
db: Session,
|
||||
snapshot_a_id: uuid.UUID,
|
||||
snapshot_b_id: uuid.UUID,
|
||||
) -> dict:
|
||||
"""Compare two snapshots and return deltas.
|
||||
|
||||
Returns improved/worsened technique lists plus aggregate statistics.
|
||||
"""
|
||||
snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
|
||||
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
|
||||
|
||||
if not snap_a or not snap_b:
|
||||
return {"error": "One or both snapshots not found"}
|
||||
|
||||
# Build lookup dicts: mitre_id -> {status, score}
|
||||
states_a = {
|
||||
s.mitre_id: {"status": s.status, "score": s.score or 0}
|
||||
for s in db.query(SnapshotTechniqueState)
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id)
|
||||
.all()
|
||||
}
|
||||
states_b = {
|
||||
s.mitre_id: {"status": s.status, "score": s.score or 0}
|
||||
for s in db.query(SnapshotTechniqueState)
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id)
|
||||
.all()
|
||||
}
|
||||
|
||||
# Status priority for comparison
|
||||
STATUS_ORDER = {
|
||||
"not_evaluated": 0,
|
||||
"not_covered": 1,
|
||||
"in_progress": 2,
|
||||
"partial": 3,
|
||||
"validated": 4,
|
||||
}
|
||||
|
||||
improved = []
|
||||
worsened = []
|
||||
unchanged_count = 0
|
||||
|
||||
all_mitre_ids = set(states_a.keys()) | set(states_b.keys())
|
||||
|
||||
for mitre_id in sorted(all_mitre_ids):
|
||||
a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0})
|
||||
b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0})
|
||||
|
||||
a_order = STATUS_ORDER.get(a["status"], 0)
|
||||
b_order = STATUS_ORDER.get(b["status"], 0)
|
||||
|
||||
if b_order > a_order or (b_order == a_order and b["score"] > a["score"]):
|
||||
improved.append({
|
||||
"mitre_id": mitre_id,
|
||||
"old_status": a["status"],
|
||||
"new_status": b["status"],
|
||||
"old_score": a["score"],
|
||||
"new_score": b["score"],
|
||||
})
|
||||
elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]):
|
||||
worsened.append({
|
||||
"mitre_id": mitre_id,
|
||||
"old_status": a["status"],
|
||||
"new_status": b["status"],
|
||||
"old_score": a["score"],
|
||||
"new_score": b["score"],
|
||||
})
|
||||
else:
|
||||
unchanged_count += 1
|
||||
|
||||
def _snap_summary(snap: CoverageSnapshot) -> dict:
|
||||
return {
|
||||
"id": str(snap.id),
|
||||
"name": snap.name,
|
||||
"organization_score": snap.organization_score,
|
||||
"total_techniques": snap.total_techniques,
|
||||
"validated_count": snap.validated_count,
|
||||
"partial_count": snap.partial_count,
|
||||
"not_covered_count": snap.not_covered_count,
|
||||
"in_progress_count": snap.in_progress_count,
|
||||
"not_evaluated_count": snap.not_evaluated_count,
|
||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||
}
|
||||
|
||||
return {
|
||||
"snapshot_a": _snap_summary(snap_a),
|
||||
"snapshot_b": _snap_summary(snap_b),
|
||||
"score_delta": round(snap_b.organization_score - snap_a.organization_score, 1),
|
||||
"improved": improved,
|
||||
"worsened": worsened,
|
||||
"unchanged_count": unchanged_count,
|
||||
"summary": {
|
||||
"improved_count": len(improved),
|
||||
"worsened_count": len(worsened),
|
||||
"new_count": len(states_b.keys() - states_a.keys()),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
|
||||
"""Delete oldest snapshots, keeping the most recent *keep_last*.
|
||||
|
||||
Returns the number of snapshots deleted.
|
||||
"""
|
||||
total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
|
||||
if total <= keep_last:
|
||||
return 0
|
||||
|
||||
to_delete = total - keep_last
|
||||
old_snapshots = (
|
||||
db.query(CoverageSnapshot)
|
||||
.order_by(CoverageSnapshot.created_at.asc())
|
||||
.limit(to_delete)
|
||||
.all()
|
||||
)
|
||||
|
||||
deleted = 0
|
||||
for snap in old_snapshots:
|
||||
db.delete(snap)
|
||||
deleted += 1
|
||||
|
||||
db.commit()
|
||||
logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
|
||||
return deleted
|
||||
@@ -16,11 +16,12 @@ from datetime import datetime
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models.enums import TestState
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.notification_service import notify_test_state_change
|
||||
from app.services.notification_service import notify_test_state_change, create_notification
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Valid transition map
|
||||
@@ -298,6 +299,131 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
||||
return test
|
||||
|
||||
|
||||
def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None:
|
||||
"""Create a re-test when remediation is completed.
|
||||
|
||||
When a test's remediation_status changes to 'completed', this function
|
||||
creates a new test (retest) with the same base data to verify that the
|
||||
fix was effective.
|
||||
|
||||
Prevents infinite loops by enforcing ``MAX_RETEST_COUNT``.
|
||||
|
||||
Returns the new retest or *None* if the limit was reached.
|
||||
"""
|
||||
# Always reference the original test, not an intermediate retest
|
||||
original_test_id = test.retest_of or test.id
|
||||
|
||||
if test.retest_count >= settings.MAX_RETEST_COUNT:
|
||||
# Max retests reached — notify and bail out
|
||||
if test.created_by:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=test.created_by,
|
||||
type="max_retests_reached",
|
||||
title="Maximum retests reached",
|
||||
message=(
|
||||
f'Test "{test.name}" has reached the maximum of '
|
||||
f'{settings.MAX_RETEST_COUNT} retests. Manual review required.'
|
||||
),
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=user.id,
|
||||
action="max_retests_reached",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={
|
||||
"retest_count": test.retest_count,
|
||||
"max_allowed": settings.MAX_RETEST_COUNT,
|
||||
"original_test_id": str(original_test_id),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
retest = Test(
|
||||
technique_id=test.technique_id,
|
||||
name=f"[Retest #{test.retest_count + 1}] {test.name.replace(f'[Retest #{test.retest_count}] ', '')}",
|
||||
description=test.description,
|
||||
platform=test.platform,
|
||||
procedure_text=test.procedure_text,
|
||||
tool_used=test.tool_used,
|
||||
state=TestState.draft,
|
||||
created_by=test.created_by,
|
||||
retest_of=original_test_id,
|
||||
retest_count=test.retest_count + 1,
|
||||
)
|
||||
db.add(retest)
|
||||
db.flush()
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=user.id,
|
||||
action="create_retest",
|
||||
entity_type="test",
|
||||
entity_id=retest.id,
|
||||
details={
|
||||
"original_test_id": str(original_test_id),
|
||||
"retest_number": retest.retest_count,
|
||||
"source_test_id": str(test.id),
|
||||
},
|
||||
)
|
||||
|
||||
# Notify the test creator and any red_tech users
|
||||
if test.created_by:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=test.created_by,
|
||||
type="retest_created",
|
||||
title="Re-test created",
|
||||
message=(
|
||||
f'A re-test has been automatically created for "{test.name}" '
|
||||
f'after remediation was completed.'
|
||||
),
|
||||
entity_type="test",
|
||||
entity_id=retest.id,
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(retest)
|
||||
return retest
|
||||
|
||||
|
||||
def get_retest_chain(db: Session, test_id) -> list[Test]:
|
||||
"""Return the full chain of retests for a given test.
|
||||
|
||||
Includes the original test and all subsequent retests, ordered
|
||||
by retest_count.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) else test_id
|
||||
|
||||
# Find the original test first
|
||||
test = db.query(Test).filter(Test.id == tid).first()
|
||||
if not test:
|
||||
return []
|
||||
|
||||
original_id = test.retest_of or test.id
|
||||
|
||||
# Get original
|
||||
original = db.query(Test).filter(Test.id == original_id).first()
|
||||
if not original:
|
||||
return [test]
|
||||
|
||||
# Get all retests of the original
|
||||
retests = (
|
||||
db.query(Test)
|
||||
.filter(Test.retest_of == original_id)
|
||||
.order_by(Test.retest_count)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [original] + retests
|
||||
|
||||
|
||||
def reopen_test(db: Session, test: Test, user: User) -> Test:
|
||||
"""Move a ``rejected`` test back to ``draft``, clearing validation fields.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user