From 4d124b42dd7397c50db1914b5b19a3553f095eef Mon Sep 17 00:00:00 2001 From: Kitos Date: Tue, 10 Feb 2026 08:34:29 +0100 Subject: [PATCH] feat(phase-30): add coverage snapshots, temporal comparison and auto re-testing (T-230 to T-232) --- .../versions/b015_add_coverage_snapshots.py | 77 +++ .../versions/b016_add_retest_fields.py | 41 ++ backend/app/config.py | 3 + backend/app/jobs/mitre_sync_job.py | 36 +- backend/app/main.py | 2 + backend/app/models/__init__.py | 2 + backend/app/models/coverage_snapshot.py | 78 +++ backend/app/models/test.py | 8 +- backend/app/routers/snapshots.py | 205 ++++++++ backend/app/routers/tests.py | 49 +- backend/app/schemas/test.py | 4 + backend/app/services/snapshot_service.py | 253 ++++++++++ backend/app/services/test_workflow_service.py | 128 ++++- frontend/src/App.tsx | 2 + frontend/src/api/snapshots.ts | 93 ++++ frontend/src/api/tests.ts | 20 + frontend/src/components/Sidebar.tsx | 2 + frontend/src/pages/ComparisonPage.tsx | 458 ++++++++++++++++++ frontend/src/pages/TestDetailPage.tsx | 56 +++ frontend/src/types/models.ts | 4 + 20 files changed, 1517 insertions(+), 4 deletions(-) create mode 100644 backend/alembic/versions/b015_add_coverage_snapshots.py create mode 100644 backend/alembic/versions/b016_add_retest_fields.py create mode 100644 backend/app/models/coverage_snapshot.py create mode 100644 backend/app/routers/snapshots.py create mode 100644 backend/app/services/snapshot_service.py create mode 100644 frontend/src/api/snapshots.ts create mode 100644 frontend/src/pages/ComparisonPage.tsx diff --git a/backend/alembic/versions/b015_add_coverage_snapshots.py b/backend/alembic/versions/b015_add_coverage_snapshots.py new file mode 100644 index 0000000..c9f897a --- /dev/null +++ b/backend/alembic/versions/b015_add_coverage_snapshots.py @@ -0,0 +1,77 @@ +"""add_coverage_snapshots + +Revision ID: b015snapshots +Revises: b014compliance +Create Date: 2026-02-10 00:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "b015snapshots" +down_revision: Union[str, None] = "b014compliance" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── coverage_snapshots ──────────────────────────────────────── + op.create_table( + "coverage_snapshots", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("name", sa.String, nullable=True), + sa.Column("organization_score", sa.Float, nullable=False), + sa.Column("total_techniques", sa.Integer, nullable=False), + sa.Column("validated_count", sa.Integer, nullable=False), + sa.Column("partial_count", sa.Integer, nullable=False), + sa.Column("not_covered_count", sa.Integer, nullable=False), + sa.Column("in_progress_count", sa.Integer, nullable=False), + sa.Column("not_evaluated_count", sa.Integer, nullable=False), + sa.Column( + "created_by", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("created_at", sa.DateTime, server_default=sa.func.now()), + ) + + # ── snapshot_technique_states ───────────────────────────────── + op.create_table( + "snapshot_technique_states", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column( + "snapshot_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("coverage_snapshots.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "technique_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("techniques.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("mitre_id", sa.String, nullable=False), + sa.Column("status", sa.String, nullable=False), + sa.Column("score", sa.Float, nullable=True), + ) + op.create_index( + "ix_snapshot_technique_states_snapshot", + "snapshot_technique_states", + ["snapshot_id"], + ) + op.create_index( + "ix_snapshot_technique_states_technique", + "snapshot_technique_states", + ["technique_id"], + ) + + +def downgrade() -> None: + op.drop_table("snapshot_technique_states") + op.drop_table("coverage_snapshots") diff --git a/backend/alembic/versions/b016_add_retest_fields.py b/backend/alembic/versions/b016_add_retest_fields.py new file mode 100644 index 0000000..729b1cb --- /dev/null +++ b/backend/alembic/versions/b016_add_retest_fields.py @@ -0,0 +1,41 @@ +"""add_retest_fields + +Revision ID: b016retests +Revises: b015snapshots +Create Date: 2026-02-10 01:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "b016retests" +down_revision: Union[str, None] = "b015snapshots" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "tests", + sa.Column( + "retest_of", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("tests.id"), + nullable=True, + ), + ) + op.add_column( + "tests", + sa.Column("retest_count", sa.Integer, server_default="0", nullable=False), + ) + op.create_index("ix_tests_retest_of", "tests", ["retest_of"]) + + +def downgrade() -> None: + op.drop_index("ix_tests_retest_of", table_name="tests") + op.drop_column("tests", "retest_count") + op.drop_column("tests", "retest_of") diff --git a/backend/app/config.py b/backend/app/config.py index 2f676b2..82920f1 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -11,6 +11,9 @@ class Settings(BaseSettings): MINIO_SECRET_KEY: str = "minioadmin" MINIO_BUCKET: str = "evidence" + # Re-testing + MAX_RETEST_COUNT: int = 3 # maximum automatic retests per original test + # Scoring weights (must sum to 100) SCORING_WEIGHT_TESTS: int = 40 SCORING_WEIGHT_DETECTION_RULES: int = 20 diff --git a/backend/app/jobs/mitre_sync_job.py b/backend/app/jobs/mitre_sync_job.py index b3eec45..87a153e 100644 --- a/backend/app/jobs/mitre_sync_job.py +++ b/backend/app/jobs/mitre_sync_job.py @@ -18,6 +18,7 @@ from app.database import SessionLocal from app.services.mitre_sync_service import sync_mitre from app.services.intel_service import scan_intel from app.services.notification_service import cleanup_old_notifications +from app.services.snapshot_service import create_snapshot, cleanup_old_snapshots logger = logging.getLogger(__name__) @@ -59,6 +60,26 @@ def _run_notification_cleanup() -> None: db.close() +def _run_weekly_snapshot() -> None: + """Create a weekly coverage snapshot and clean up old ones.""" + logger.info("Scheduled weekly snapshot job starting...") + db = SessionLocal() + try: + snapshot = create_snapshot(db, name="Auto-weekly") + logger.info( + "Weekly snapshot created — score %.1f, %d techniques", + snapshot.organization_score, + snapshot.total_techniques, + ) + deleted = cleanup_old_snapshots(db, keep_last=52) + if deleted: + logger.info("Cleaned up %d old snapshots", deleted) + except Exception: + logger.exception("Weekly snapshot job failed") + finally: + db.close() + + def _run_intel_scan() -> None: """Execute an intel scan inside its own DB session.""" logger.info("Scheduled intel scan job starting...") @@ -111,5 +132,18 @@ def start_scheduler() -> None: name="Notification cleanup (daily)", replace_existing=True, ) + scheduler.add_job( + _run_weekly_snapshot, + trigger="cron", + day_of_week="sun", + hour=0, + minute=0, + id="weekly_snapshot", + name="Weekly coverage snapshot (Sundays 00:00)", + replace_existing=True, + ) scheduler.start() - logger.info("Background scheduler started — mitre_sync (24h), intel_scan (7d), notification_cleanup (24h)") + logger.info( + "Background scheduler started — mitre_sync (24h), intel_scan (7d), " + "notification_cleanup (24h), weekly_snapshot (Sundays 00:00)" + ) diff --git a/backend/app/main.py b/backend/app/main.py index 4d9ae29..68dbf81 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -27,6 +27,7 @@ from app.routers import heatmap as heatmap_router from app.routers import scores as scores_router from app.routers import operational_metrics as operational_metrics_router from app.routers import compliance as compliance_router +from app.routers import snapshots as snapshots_router from app.storage import ensure_bucket_exists from app.jobs.mitre_sync_job import start_scheduler, scheduler @@ -78,6 +79,7 @@ app.include_router(heatmap_router.router, prefix="/api/v1") app.include_router(scores_router.router, prefix="/api/v1") app.include_router(operational_metrics_router.router, prefix="/api/v1") app.include_router(compliance_router.router, prefix="/api/v1") +app.include_router(snapshots_router.router, prefix="/api/v1") @app.get("/health") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index c61e495..2fb5911 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -15,6 +15,7 @@ from app.models.test_template_detection_rule import TestTemplateDetectionRule from app.models.test_detection_result import TestDetectionResult from app.models.campaign import Campaign, CampaignTest from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping +from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide __all__ = [ @@ -25,5 +26,6 @@ __all__ = [ "TestTemplateDetectionRule", "TestDetectionResult", "Campaign", "CampaignTest", "ComplianceFramework", "ComplianceControl", "ComplianceControlMapping", + "CoverageSnapshot", "SnapshotTechniqueState", "TechniqueStatus", "TestState", "TestResult", "TeamSide", ] diff --git a/backend/app/models/coverage_snapshot.py b/backend/app/models/coverage_snapshot.py new file mode 100644 index 0000000..7744d30 --- /dev/null +++ b/backend/app/models/coverage_snapshot.py @@ -0,0 +1,78 @@ +"""Coverage snapshot models — periodic snapshots of coverage state. + +CoverageSnapshot stores aggregate metrics at a point in time. +SnapshotTechniqueState stores per-technique state (normalized, one row +per technique per snapshot) to avoid bloated JSONB fields. +""" + +import uuid +from datetime import datetime + +from sqlalchemy import ( + Column, String, Float, Integer, DateTime, + ForeignKey, Index, +) +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship + +from app.database import Base + + +class CoverageSnapshot(Base): + """A point-in-time snapshot of the organisation's overall coverage.""" + + __tablename__ = "coverage_snapshots" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String, nullable=True) # e.g. "Pre-remediación Q1" + organization_score = Column(Float, nullable=False) + total_techniques = Column(Integer, nullable=False) + validated_count = Column(Integer, nullable=False) + partial_count = Column(Integer, nullable=False) + not_covered_count = Column(Integer, nullable=False) + in_progress_count = Column(Integer, nullable=False) + not_evaluated_count = Column(Integer, nullable=False) + created_by = Column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ) + created_at = Column(DateTime, default=datetime.utcnow) + + # Relationships + creator = relationship("User", foreign_keys=[created_by]) + technique_states = relationship( + "SnapshotTechniqueState", + back_populates="snapshot", + cascade="all, delete-orphan", + ) + + +class SnapshotTechniqueState(Base): + """Per-technique state within a snapshot (normalised storage).""" + + __tablename__ = "snapshot_technique_states" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + snapshot_id = Column( + UUID(as_uuid=True), + ForeignKey("coverage_snapshots.id", ondelete="CASCADE"), + nullable=False, + ) + technique_id = Column( + UUID(as_uuid=True), + ForeignKey("techniques.id", ondelete="CASCADE"), + nullable=False, + ) + mitre_id = Column(String, nullable=False) # denormalised for fast queries + status = Column(String, nullable=False) + score = Column(Float, nullable=True) + + # Relationships + snapshot = relationship("CoverageSnapshot", back_populates="technique_states") + technique = relationship("Technique") + + __table_args__ = ( + Index("ix_snapshot_technique_states_snapshot", "snapshot_id"), + Index("ix_snapshot_technique_states_technique", "technique_id"), + ) diff --git a/backend/app/models/test.py b/backend/app/models/test.py index 6b8d989..b23c37a 100644 --- a/backend/app/models/test.py +++ b/backend/app/models/test.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime -from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum +from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, ForeignKey, Enum from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship @@ -54,6 +54,10 @@ class Test(Base): remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + # ── Re-test fields ──────────────────────────────────────────── + retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True) + retest_count = Column(Integer, default=0) + # ── Relationships ─────────────────────────────────────────────── technique = relationship("Technique", back_populates="tests") evidences = relationship("Evidence", back_populates="test") @@ -61,3 +65,5 @@ class Test(Base): red_validator = relationship("User", foreign_keys=[red_validated_by]) blue_validator = relationship("User", foreign_keys=[blue_validated_by]) remediation_user = relationship("User", foreign_keys=[remediation_assignee]) + original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of]) + retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test") diff --git a/backend/app/routers/snapshots.py b/backend/app/routers/snapshots.py new file mode 100644 index 0000000..cacea02 --- /dev/null +++ b/backend/app/routers/snapshots.py @@ -0,0 +1,205 @@ +"""Snapshot endpoints — coverage snapshots CRUD and comparison. + +Provides periodic and manual snapshots of the organisation's coverage +state, plus temporal comparison between any two snapshots. +""" + +import logging +import uuid +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from app.database import get_db +from app.dependencies.auth import get_current_user, require_any_role, require_role +from app.models.user import User +from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState +from app.services.snapshot_service import ( + create_snapshot, + compare_snapshots, + cleanup_old_snapshots, +) +from app.services.audit_service import log_action + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/snapshots", tags=["snapshots"]) + + +# ── Pydantic schemas ───────────────────────────────────────────────── + +class SnapshotCreate(BaseModel): + name: Optional[str] = None + + +# ── Helpers ────────────────────────────────────────────────────────── + +def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict: + """Lightweight serialization for list views.""" + return { + "id": str(snap.id), + "name": snap.name, + "organization_score": snap.organization_score, + "total_techniques": snap.total_techniques, + "validated_count": snap.validated_count, + "partial_count": snap.partial_count, + "not_covered_count": snap.not_covered_count, + "in_progress_count": snap.in_progress_count, + "not_evaluated_count": snap.not_evaluated_count, + "created_by": str(snap.created_by) if snap.created_by else None, + "created_at": snap.created_at.isoformat() if snap.created_at else None, + } + + +def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict: + """Full serialization including technique states.""" + base = _serialize_snapshot_summary(snap) + + technique_states = ( + db.query(SnapshotTechniqueState) + .filter(SnapshotTechniqueState.snapshot_id == snap.id) + .order_by(SnapshotTechniqueState.mitre_id) + .all() + ) + + base["technique_states"] = [ + { + "mitre_id": s.mitre_id, + "technique_id": str(s.technique_id), + "status": s.status, + "score": s.score, + } + for s in technique_states + ] + return base + + +# --------------------------------------------------------------------------- +# GET /snapshots — List snapshots (paginated) +# --------------------------------------------------------------------------- + +@router.get("") +def list_snapshots( + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """List coverage snapshots ordered by creation date (newest first).""" + query = db.query(CoverageSnapshot) + total = query.count() + + snapshots = ( + query + .order_by(CoverageSnapshot.created_at.desc()) + .offset(offset) + .limit(limit) + .all() + ) + + return { + "total": total, + "offset": offset, + "limit": limit, + "items": [_serialize_snapshot_summary(s) for s in snapshots], + } + + +# --------------------------------------------------------------------------- +# POST /snapshots — Create snapshot manually +# --------------------------------------------------------------------------- + +@router.post("", status_code=201) +def create_snapshot_endpoint( + payload: SnapshotCreate, + db: Session = Depends(get_db), + current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")), +): + """Create a manual coverage snapshot with an optional name.""" + snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id) + + log_action( + db, + user_id=current_user.id, + action="create_snapshot", + entity_type="snapshot", + entity_id=snapshot.id, + details={"name": snapshot.name, "score": snapshot.organization_score}, + ) + + return _serialize_snapshot_summary(snapshot) + + +# --------------------------------------------------------------------------- +# GET /snapshots/compare — Compare two snapshots +# --------------------------------------------------------------------------- + +@router.get("/compare") +def compare_snapshots_endpoint( + a: str = Query(..., description="Snapshot A ID"), + b: str = Query(..., description="Snapshot B ID"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Compare two snapshots showing improved, worsened, and unchanged techniques.""" + try: + a_id = uuid.UUID(a) + b_id = uuid.UUID(b) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid snapshot ID format") + + result = compare_snapshots(db, a_id, b_id) + if "error" in result: + raise HTTPException(status_code=404, detail=result["error"]) + + return result + + +# --------------------------------------------------------------------------- +# GET /snapshots/{id} — Snapshot detail +# --------------------------------------------------------------------------- + +@router.get("/{snapshot_id}") +def get_snapshot( + snapshot_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Get detailed snapshot information including per-technique states.""" + snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first() + if not snapshot: + raise HTTPException(status_code=404, detail="Snapshot not found") + + return _serialize_snapshot_detail(db, snapshot) + + +# --------------------------------------------------------------------------- +# DELETE /snapshots/{id} — Delete snapshot (admin only) +# --------------------------------------------------------------------------- + +@router.delete("/{snapshot_id}") +def delete_snapshot( + snapshot_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(require_role("admin")), +): + """Delete a snapshot (admin only).""" + snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first() + if not snapshot: + raise HTTPException(status_code=404, detail="Snapshot not found") + + log_action( + db, + user_id=current_user.id, + action="delete_snapshot", + entity_type="snapshot", + entity_id=snapshot.id, + details={"name": snapshot.name}, + ) + + db.delete(snapshot) + db.commit() + + return {"detail": "Snapshot deleted"} diff --git a/backend/app/routers/tests.py b/backend/app/routers/tests.py index 2106415..dbb72fb 100644 --- a/backend/app/routers/tests.py +++ b/backend/app/routers/tests.py @@ -52,6 +52,8 @@ from app.services.test_workflow_service import ( validate_as_red_lead as wf_validate_red, validate_as_blue_lead as wf_validate_blue, reopen_test as wf_reopen, + handle_remediation_completed as wf_handle_remediation, + get_retest_chain as wf_get_retest_chain, ) router = APIRouter(prefix="/tests", tags=["tests"]) @@ -546,9 +548,15 @@ def update_remediation( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): - """Update remediation fields on a test (any authenticated user).""" + """Update remediation fields on a test (any authenticated user). + + When ``remediation_status`` transitions to ``'completed'``, an automatic + re-test is created (subject to ``MAX_RETEST_COUNT``). + """ test = _get_test_or_404(db, test_id) + old_remediation_status = test.remediation_status + update_data = payload.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(test, field, value) @@ -565,6 +573,13 @@ def update_remediation( details={"updated_fields": list(update_data.keys())}, ) + # Auto-create retest when remediation is marked completed + new_status = update_data.get("remediation_status") + if new_status == "completed" and old_remediation_status != "completed": + retest = wf_handle_remediation(db, test, current_user) + if retest: + db.refresh(test) + return test @@ -603,3 +618,35 @@ def get_test_timeline( } for log in logs ] + + +# --------------------------------------------------------------------------- +# GET /tests/{id}/retest-chain — full retest chain +# --------------------------------------------------------------------------- + + +@router.get("/{test_id}/retest-chain") +def get_retest_chain( + test_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Return the full chain of retests (original + all retests) for a test.""" + chain = wf_get_retest_chain(db, test_id) + if not chain: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found") + + return [ + { + "id": str(t.id), + "name": t.name, + "state": t.state.value if t.state else None, + "retest_of": str(t.retest_of) if t.retest_of else None, + "retest_count": t.retest_count, + "result": t.result.value if t.result else None, + "detection_result": t.detection_result.value if t.detection_result else None, + "remediation_status": t.remediation_status, + "created_at": t.created_at.isoformat() if t.created_at else None, + } + for t in chain + ] diff --git a/backend/app/schemas/test.py b/backend/app/schemas/test.py index 75c927b..40c6c04 100644 --- a/backend/app/schemas/test.py +++ b/backend/app/schemas/test.py @@ -142,6 +142,10 @@ class TestOut(BaseModel): remediation_status: str | None = None remediation_assignee: uuid.UUID | None = None + # Re-test fields + retest_of: uuid.UUID | None = None + retest_count: int = 0 + # Technique info (populated when joined) technique_mitre_id: str | None = None technique_name: str | None = None diff --git a/backend/app/services/snapshot_service.py b/backend/app/services/snapshot_service.py new file mode 100644 index 0000000..7cf5385 --- /dev/null +++ b/backend/app/services/snapshot_service.py @@ -0,0 +1,253 @@ +"""Snapshot service — create, compare, and manage coverage snapshots. + +Provides point-in-time coverage captures with normalised per-technique +storage and temporal comparison between any two snapshots. +""" + +import logging +import uuid +from datetime import datetime + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.models.technique import Technique +from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState +from app.models.enums import TechniqueStatus +from app.services.scoring_service import calculate_technique_score, calculate_organization_score + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Create snapshot +# --------------------------------------------------------------------------- + + +def create_snapshot( + db: Session, + name: str | None = None, + user_id: uuid.UUID | None = None, +) -> CoverageSnapshot: + """Capture the current coverage state into a new snapshot. + + 1. Fetch every technique with its status and score. + 2. Compute aggregate counts. + 3. Persist a ``CoverageSnapshot`` with normalised + ``SnapshotTechniqueState`` rows. + """ + techniques = db.query(Technique).all() + + # Aggregate counters + validated_count = 0 + partial_count = 0 + not_covered_count = 0 + in_progress_count = 0 + not_evaluated_count = 0 + + technique_rows: list[dict] = [] + + for tech in techniques: + status_value = ( + tech.status_global.value + if isinstance(tech.status_global, TechniqueStatus) + else (tech.status_global or "not_evaluated") + ) + + # Count by status + if status_value == "validated": + validated_count += 1 + elif status_value == "partial": + partial_count += 1 + elif status_value == "not_covered": + not_covered_count += 1 + elif status_value == "in_progress": + in_progress_count += 1 + else: + not_evaluated_count += 1 + + # Compute technique score + score_data = calculate_technique_score(tech, db) + technique_rows.append({ + "technique_id": tech.id, + "mitre_id": tech.mitre_id, + "status": status_value, + "score": score_data["total_score"], + }) + + # Organization score + org_data = calculate_organization_score(db) + org_score = org_data.get("overall_score", 0) + + # Create the snapshot + snapshot = CoverageSnapshot( + name=name, + organization_score=org_score, + total_techniques=len(techniques), + validated_count=validated_count, + partial_count=partial_count, + not_covered_count=not_covered_count, + in_progress_count=in_progress_count, + not_evaluated_count=not_evaluated_count, + created_by=user_id, + ) + db.add(snapshot) + db.flush() # get snapshot.id + + # Create normalised technique state rows + for row in technique_rows: + state = SnapshotTechniqueState( + snapshot_id=snapshot.id, + technique_id=row["technique_id"], + mitre_id=row["mitre_id"], + status=row["status"], + score=row["score"], + ) + db.add(state) + + db.commit() + db.refresh(snapshot) + + logger.info( + "Snapshot '%s' created — %d techniques, org score %.1f", + snapshot.name or snapshot.id, + len(techniques), + org_score, + ) + return snapshot + + +# --------------------------------------------------------------------------- +# Compare snapshots +# --------------------------------------------------------------------------- + + +def compare_snapshots( + db: Session, + snapshot_a_id: uuid.UUID, + snapshot_b_id: uuid.UUID, +) -> dict: + """Compare two snapshots and return deltas. + + Returns improved/worsened technique lists plus aggregate statistics. + """ + snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first() + snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first() + + if not snap_a or not snap_b: + return {"error": "One or both snapshots not found"} + + # Build lookup dicts: mitre_id -> {status, score} + states_a = { + s.mitre_id: {"status": s.status, "score": s.score or 0} + for s in db.query(SnapshotTechniqueState) + .filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id) + .all() + } + states_b = { + s.mitre_id: {"status": s.status, "score": s.score or 0} + for s in db.query(SnapshotTechniqueState) + .filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id) + .all() + } + + # Status priority for comparison + STATUS_ORDER = { + "not_evaluated": 0, + "not_covered": 1, + "in_progress": 2, + "partial": 3, + "validated": 4, + } + + improved = [] + worsened = [] + unchanged_count = 0 + + all_mitre_ids = set(states_a.keys()) | set(states_b.keys()) + + for mitre_id in sorted(all_mitre_ids): + a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0}) + b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0}) + + a_order = STATUS_ORDER.get(a["status"], 0) + b_order = STATUS_ORDER.get(b["status"], 0) + + if b_order > a_order or (b_order == a_order and b["score"] > a["score"]): + improved.append({ + "mitre_id": mitre_id, + "old_status": a["status"], + "new_status": b["status"], + "old_score": a["score"], + "new_score": b["score"], + }) + elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]): + worsened.append({ + "mitre_id": mitre_id, + "old_status": a["status"], + "new_status": b["status"], + "old_score": a["score"], + "new_score": b["score"], + }) + else: + unchanged_count += 1 + + def _snap_summary(snap: CoverageSnapshot) -> dict: + return { + "id": str(snap.id), + "name": snap.name, + "organization_score": snap.organization_score, + "total_techniques": snap.total_techniques, + "validated_count": snap.validated_count, + "partial_count": snap.partial_count, + "not_covered_count": snap.not_covered_count, + "in_progress_count": snap.in_progress_count, + "not_evaluated_count": snap.not_evaluated_count, + "created_at": snap.created_at.isoformat() if snap.created_at else None, + } + + return { + "snapshot_a": _snap_summary(snap_a), + "snapshot_b": _snap_summary(snap_b), + "score_delta": round(snap_b.organization_score - snap_a.organization_score, 1), + "improved": improved, + "worsened": worsened, + "unchanged_count": unchanged_count, + "summary": { + "improved_count": len(improved), + "worsened_count": len(worsened), + "new_count": len(states_b.keys() - states_a.keys()), + }, + } + + +# --------------------------------------------------------------------------- +# Cleanup +# --------------------------------------------------------------------------- + + +def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int: + """Delete oldest snapshots, keeping the most recent *keep_last*. + + Returns the number of snapshots deleted. + """ + total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0 + if total <= keep_last: + return 0 + + to_delete = total - keep_last + old_snapshots = ( + db.query(CoverageSnapshot) + .order_by(CoverageSnapshot.created_at.asc()) + .limit(to_delete) + .all() + ) + + deleted = 0 + for snap in old_snapshots: + db.delete(snap) + deleted += 1 + + db.commit() + logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last) + return deleted diff --git a/backend/app/services/test_workflow_service.py b/backend/app/services/test_workflow_service.py index 1f4cfe4..a6205aa 100644 --- a/backend/app/services/test_workflow_service.py +++ b/backend/app/services/test_workflow_service.py @@ -16,11 +16,12 @@ from datetime import datetime from fastapi import HTTPException, status from sqlalchemy.orm import Session +from app.config import settings from app.models.enums import TestState from app.models.test import Test from app.models.user import User from app.services.audit_service import log_action -from app.services.notification_service import notify_test_state_change +from app.services.notification_service import notify_test_state_change, create_notification # --------------------------------------------------------------------------- # Valid transition map @@ -298,6 +299,131 @@ def check_dual_validation(db: Session, test: Test) -> Test: return test +def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None: + """Create a re-test when remediation is completed. + + When a test's remediation_status changes to 'completed', this function + creates a new test (retest) with the same base data to verify that the + fix was effective. + + Prevents infinite loops by enforcing ``MAX_RETEST_COUNT``. + + Returns the new retest or *None* if the limit was reached. + """ + # Always reference the original test, not an intermediate retest + original_test_id = test.retest_of or test.id + + if test.retest_count >= settings.MAX_RETEST_COUNT: + # Max retests reached — notify and bail out + if test.created_by: + create_notification( + db, + user_id=test.created_by, + type="max_retests_reached", + title="Maximum retests reached", + message=( + f'Test "{test.name}" has reached the maximum of ' + f'{settings.MAX_RETEST_COUNT} retests. Manual review required.' + ), + entity_type="test", + entity_id=test.id, + ) + + log_action( + db, + user_id=user.id, + action="max_retests_reached", + entity_type="test", + entity_id=test.id, + details={ + "retest_count": test.retest_count, + "max_allowed": settings.MAX_RETEST_COUNT, + "original_test_id": str(original_test_id), + }, + ) + return None + + retest = Test( + technique_id=test.technique_id, + name=f"[Retest #{test.retest_count + 1}] {test.name.replace(f'[Retest #{test.retest_count}] ', '')}", + description=test.description, + platform=test.platform, + procedure_text=test.procedure_text, + tool_used=test.tool_used, + state=TestState.draft, + created_by=test.created_by, + retest_of=original_test_id, + retest_count=test.retest_count + 1, + ) + db.add(retest) + db.flush() + + log_action( + db, + user_id=user.id, + action="create_retest", + entity_type="test", + entity_id=retest.id, + details={ + "original_test_id": str(original_test_id), + "retest_number": retest.retest_count, + "source_test_id": str(test.id), + }, + ) + + # Notify the test creator and any red_tech users + if test.created_by: + create_notification( + db, + user_id=test.created_by, + type="retest_created", + title="Re-test created", + message=( + f'A re-test has been automatically created for "{test.name}" ' + f'after remediation was completed.' + ), + entity_type="test", + entity_id=retest.id, + ) + + db.commit() + db.refresh(retest) + return retest + + +def get_retest_chain(db: Session, test_id) -> list[Test]: + """Return the full chain of retests for a given test. + + Includes the original test and all subsequent retests, ordered + by retest_count. + """ + import uuid as _uuid + + tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) else test_id + + # Find the original test first + test = db.query(Test).filter(Test.id == tid).first() + if not test: + return [] + + original_id = test.retest_of or test.id + + # Get original + original = db.query(Test).filter(Test.id == original_id).first() + if not original: + return [test] + + # Get all retests of the original + retests = ( + db.query(Test) + .filter(Test.retest_of == original_id) + .order_by(Test.retest_count) + .all() + ) + + return [original] + retests + + def reopen_test(db: Session, test: Test, user: User) -> Test: """Move a ``rejected`` test back to ``draft``, clearing validation fields. diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 7debaaf..e41c525 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -19,6 +19,7 @@ import ThreatActorsPage from "./pages/ThreatActorsPage"; import ThreatActorDetailPage from "./pages/ThreatActorDetailPage"; import CampaignsPage from "./pages/CampaignsPage"; import CampaignDetailPage from "./pages/CampaignDetailPage"; +import ComparisonPage from "./pages/ComparisonPage"; import Layout from "./components/Layout"; import ProtectedRoute from "./components/ProtectedRoute"; @@ -58,6 +59,7 @@ export default function App() { } /> } /> } /> + } /> } /> { + const { data } = await client.get("/snapshots", { params }); + return data; +} + +/** Create a manual snapshot. */ +export async function createSnapshot(name?: string): Promise { + const { data } = await client.post("/snapshots", { + name: name || null, + }); + return data; +} + +/** Get snapshot detail with per-technique states. */ +export async function getSnapshot(snapshotId: string): Promise { + const { data } = await client.get( + `/snapshots/${snapshotId}`, + ); + return data; +} + +/** Compare two snapshots. */ +export async function compareSnapshots( + aId: string, + bId: string, +): Promise { + const { data } = await client.get("/snapshots/compare", { + params: { a: aId, b: bId }, + }); + return data; +} + +/** Delete a snapshot (admin only). */ +export async function deleteSnapshot(snapshotId: string): Promise { + await client.delete(`/snapshots/${snapshotId}`); +} diff --git a/frontend/src/api/tests.ts b/frontend/src/api/tests.ts index e4b73ca..70d55cf 100644 --- a/frontend/src/api/tests.ts +++ b/frontend/src/api/tests.ts @@ -204,6 +204,26 @@ export async function getTestTimeline( return data; } +// ── Retest Chain ──────────────────────────────────────────────────── + +export interface RetestChainEntry { + id: string; + name: string; + state: string | null; + retest_of: string | null; + retest_count: number; + result: string | null; + detection_result: string | null; + remediation_status: string | null; + created_at: string | null; +} + +/** Get the full retest chain for a test. */ +export async function getRetestChain(testId: string): Promise { + const { data } = await client.get(`/tests/${testId}/retest-chain`); + return data; +} + // ── Legacy (kept for backwards compat) ───────────────────────────── /** Validate a test (legacy endpoint). */ diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 84b8707..d5aee29 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -18,6 +18,7 @@ import { Grid3X3, Gauge, ShieldCheck, + GitCompareArrows, } from "lucide-react"; import { useAuth } from "../context/AuthContext"; @@ -46,6 +47,7 @@ const mainLinks: NavItem[] = [ { to: "/reports", label: "Reports", icon: BarChart3 }, { to: "/threat-actors", label: "Threat Actors", icon: Crosshair }, { to: "/campaigns", label: "Campaigns", icon: Zap }, + { to: "/comparison", label: "Comparison", icon: GitCompareArrows }, { to: "/compliance", label: "Compliance", icon: ShieldCheck }, ]; diff --git a/frontend/src/pages/ComparisonPage.tsx b/frontend/src/pages/ComparisonPage.tsx new file mode 100644 index 0000000..a8a4f8e --- /dev/null +++ b/frontend/src/pages/ComparisonPage.tsx @@ -0,0 +1,458 @@ +import { useState, useMemo } from "react"; +import { useNavigate } from "react-router-dom"; +import { useQuery } from "@tanstack/react-query"; +import { + Loader2, + AlertCircle, + ArrowUp, + ArrowDown, + Minus, + GitCompareArrows, + Camera, + TrendingUp, + TrendingDown, +} from "lucide-react"; +import { + listSnapshots, + compareSnapshots, + type SnapshotSummary, + type SnapshotComparison, +} from "../api/snapshots"; + +type Tab = "improved" | "worsened" | "unchanged"; + +const statusColors: Record = { + validated: "text-green-400", + partial: "text-yellow-400", + not_covered: "text-red-400", + in_progress: "text-blue-400", + not_evaluated: "text-gray-500", +}; + +const statusDots: Record = { + validated: "bg-green-400", + partial: "bg-yellow-400", + not_covered: "bg-red-400", + in_progress: "bg-blue-400", + not_evaluated: "bg-gray-500", +}; + +function StatusBadge({ status }: { status: string }) { + return ( + + + + {status.replace(/_/g, " ")} + + + ); +} + +function DeltaArrow({ delta }: { delta: number }) { + if (delta > 0) return ; + if (delta < 0) return ; + return ; +} + +function MetricCard({ + label, + valueA, + valueB, + suffix, +}: { + label: string; + valueA: number; + valueB: number; + suffix?: string; +}) { + const delta = valueB - valueA; + const deltaColor = + delta > 0 ? "text-green-400" : delta < 0 ? "text-red-400" : "text-gray-500"; + + return ( +
+ {label} +
+ + {valueB} + {suffix} + + {delta !== 0 && ( + + + {delta > 0 ? "+" : ""} + {delta} + + )} +
+
+ ); +} + +export default function ComparisonPage() { + const navigate = useNavigate(); + const [snapA, setSnapA] = useState(""); + const [snapB, setSnapB] = useState(""); + const [activeTab, setActiveTab] = useState("improved"); + + // Fetch all snapshots for the dropdowns + const { data: snapshotsData, isLoading: isLoadingSnapshots } = useQuery({ + queryKey: ["snapshots", "all"], + queryFn: () => listSnapshots({ limit: 200 }), + }); + + const snapshots = snapshotsData?.items || []; + + // Comparison query + const { + data: comparison, + isLoading: isComparing, + error: compareError, + } = useQuery({ + queryKey: ["snapshot-compare", snapA, snapB], + queryFn: () => compareSnapshots(snapA, snapB), + enabled: !!snapA && !!snapB && snapA !== snapB, + }); + + const formatDate = (dateStr: string | null) => { + if (!dateStr) return "—"; + return new Date(dateStr).toLocaleDateString("en-US", { + year: "numeric", + month: "short", + day: "numeric", + hour: "2-digit", + minute: "2-digit", + }); + }; + + // For the "unchanged" tab, we don't get individual rows from the API, + // just a count, so we show the count. + const tabData = useMemo(() => { + if (!comparison) return { improved: [], worsened: [], unchanged_count: 0 }; + return { + improved: comparison.improved, + worsened: comparison.worsened, + unchanged_count: comparison.unchanged_count, + }; + }, [comparison]); + + return ( +
+ {/* Header */} +
+
+
+ +
+
+

Temporal Comparison

+

Compare coverage snapshots over time

+
+
+
+ + {/* Snapshot selectors */} +
+
+ {/* Snapshot A */} +
+ + +
+ + {/* Snapshot B */} +
+ + +
+
+ + {isLoadingSnapshots && ( +
+ + Loading snapshots... +
+ )} +
+ + {/* Loading / Error */} + {isComparing && ( +
+ +
+ )} + {compareError && ( +
+ + Failed to compare snapshots +
+ )} + + {/* Comparison results */} + {comparison && ( + <> + {/* Side-by-side score cards */} +
+ {/* Snapshot A card */} +
+
+ + + {comparison.snapshot_a.name || "Snapshot A"} + + + {formatDate(comparison.snapshot_a.created_at)} + +
+
+ {comparison.snapshot_a.organization_score} +
+
+ + + + +
+
+ + {/* Snapshot B card */} +
+
+ + + {comparison.snapshot_b.name || "Snapshot B"} + + + {formatDate(comparison.snapshot_b.created_at)} + +
+
+ + {comparison.snapshot_b.organization_score} + + {comparison.score_delta !== 0 && ( + 0 ? "text-green-400" : "text-red-400" + }`} + > + {comparison.score_delta > 0 ? ( + + ) : ( + + )} + {comparison.score_delta > 0 ? "+" : ""} + {comparison.score_delta} + + )} +
+
+ + + + +
+
+
+ + {/* Tabs */} +
+
+ + + +
+ +
+ {activeTab === "unchanged" ? ( +
+ +

+ {comparison.unchanged_count} techniques unchanged +

+

+ These techniques had the same status and score in both snapshots. +

+
+ ) : ( + <> + {(activeTab === "improved" + ? tabData.improved + : tabData.worsened + ).length === 0 ? ( +
+

No techniques {activeTab} between snapshots.

+
+ ) : ( +
+ + + + + + + + + + + + + {(activeTab === "improved" + ? tabData.improved + : tabData.worsened + ).map((item) => ( + + navigate(`/techniques/${item.mitre_id}`) + } + > + + + + + + + + ))} + +
MITRE IDBeforeAfterScore BeforeScore AfterDelta
+ + {item.mitre_id} + + + + + + + {item.old_score} + + {item.new_score} + + item.old_score + ? "text-green-400" + : item.new_score < item.old_score + ? "text-red-400" + : "text-gray-500" + }`} + > + + {item.new_score > item.old_score ? "+" : ""} + {Math.round((item.new_score - item.old_score) * 10) / 10} + +
+
+ )} + + )} +
+
+ + )} + + {/* No selection prompt */} + {!comparison && !isComparing && !compareError && ( +
+ +

Select two snapshots to compare

+

+ Choose a baseline and current snapshot from the dropdowns above. +

+
+ )} +
+ ); +} diff --git a/frontend/src/pages/TestDetailPage.tsx b/frontend/src/pages/TestDetailPage.tsx index 075ff92..fb1768c 100644 --- a/frontend/src/pages/TestDetailPage.tsx +++ b/frontend/src/pages/TestDetailPage.tsx @@ -14,6 +14,7 @@ import { validateAsBlueLead, reopenTest, getTestTimeline, + getRetestChain, } from "../api/tests"; import { uploadEvidence, getEvidence } from "../api/evidence"; import { useAuth } from "../context/AuthContext"; @@ -79,6 +80,12 @@ export default function TestDetailPage() { enabled: !!testId, }); + const { data: retestChain = [] } = useQuery({ + queryKey: ["retest-chain", testId], + queryFn: () => getRetestChain(testId!), + enabled: !!testId && !!test && (test.retest_of !== null || test.retest_count > 0), + }); + // Hydrate drafts from test data useEffect(() => { if (test) { @@ -442,6 +449,55 @@ export default function TestDetailPage() { )} + + {/* Retest Chain */} + {(test.retest_of || test.retest_count > 0 || retestChain.length > 1) && ( +
+

Retest Chain

+ {test.retest_of && ( +
+ + Retest {test.retest_count} / 3 + +
+
+
+
+ )} +
+ {retestChain.map((entry) => ( + + ))} +
+
+ )}
diff --git a/frontend/src/types/models.ts b/frontend/src/types/models.ts index 7ecacbb..a3c3ea7 100644 --- a/frontend/src/types/models.ts +++ b/frontend/src/types/models.ts @@ -91,6 +91,10 @@ export interface Test { remediation_status: string | null; remediation_assignee: string | null; + // Re-test fields + retest_of: string | null; + retest_count: number; + // Technique info (populated in list endpoints) technique_mitre_id: string | null; technique_name: string | null;