feat(phase-30): add coverage snapshots, temporal comparison and auto re-testing (T-230 to T-232)

This commit is contained in:
2026-02-10 08:34:29 +01:00
parent 2ac8e7f4a5
commit 4d124b42dd
20 changed files with 1517 additions and 4 deletions

View File

@@ -0,0 +1,77 @@
"""add_coverage_snapshots
Revision ID: b015snapshots
Revises: b014compliance
Create Date: 2026-02-10 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "b015snapshots"
down_revision: Union[str, None] = "b014compliance"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── coverage_snapshots ────────────────────────────────────────
op.create_table(
"coverage_snapshots",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("name", sa.String, nullable=True),
sa.Column("organization_score", sa.Float, nullable=False),
sa.Column("total_techniques", sa.Integer, nullable=False),
sa.Column("validated_count", sa.Integer, nullable=False),
sa.Column("partial_count", sa.Integer, nullable=False),
sa.Column("not_covered_count", sa.Integer, nullable=False),
sa.Column("in_progress_count", sa.Integer, nullable=False),
sa.Column("not_evaluated_count", sa.Integer, nullable=False),
sa.Column(
"created_by",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("created_at", sa.DateTime, server_default=sa.func.now()),
)
# ── snapshot_technique_states ─────────────────────────────────
op.create_table(
"snapshot_technique_states",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"snapshot_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("coverage_snapshots.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"technique_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("techniques.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("mitre_id", sa.String, nullable=False),
sa.Column("status", sa.String, nullable=False),
sa.Column("score", sa.Float, nullable=True),
)
op.create_index(
"ix_snapshot_technique_states_snapshot",
"snapshot_technique_states",
["snapshot_id"],
)
op.create_index(
"ix_snapshot_technique_states_technique",
"snapshot_technique_states",
["technique_id"],
)
def downgrade() -> None:
op.drop_table("snapshot_technique_states")
op.drop_table("coverage_snapshots")

View File

@@ -0,0 +1,41 @@
"""add_retest_fields
Revision ID: b016retests
Revises: b015snapshots
Create Date: 2026-02-10 01:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "b016retests"
down_revision: Union[str, None] = "b015snapshots"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"tests",
sa.Column(
"retest_of",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("tests.id"),
nullable=True,
),
)
op.add_column(
"tests",
sa.Column("retest_count", sa.Integer, server_default="0", nullable=False),
)
op.create_index("ix_tests_retest_of", "tests", ["retest_of"])
def downgrade() -> None:
op.drop_index("ix_tests_retest_of", table_name="tests")
op.drop_column("tests", "retest_count")
op.drop_column("tests", "retest_of")

View File

@@ -11,6 +11,9 @@ class Settings(BaseSettings):
MINIO_SECRET_KEY: str = "minioadmin"
MINIO_BUCKET: str = "evidence"
# Re-testing
MAX_RETEST_COUNT: int = 3 # maximum automatic retests per original test
# Scoring weights (must sum to 100)
SCORING_WEIGHT_TESTS: int = 40
SCORING_WEIGHT_DETECTION_RULES: int = 20

View File

@@ -18,6 +18,7 @@ from app.database import SessionLocal
from app.services.mitre_sync_service import sync_mitre
from app.services.intel_service import scan_intel
from app.services.notification_service import cleanup_old_notifications
from app.services.snapshot_service import create_snapshot, cleanup_old_snapshots
logger = logging.getLogger(__name__)
@@ -59,6 +60,26 @@ def _run_notification_cleanup() -> None:
db.close()
def _run_weekly_snapshot() -> None:
"""Create a weekly coverage snapshot and clean up old ones."""
logger.info("Scheduled weekly snapshot job starting...")
db = SessionLocal()
try:
snapshot = create_snapshot(db, name="Auto-weekly")
logger.info(
"Weekly snapshot created — score %.1f, %d techniques",
snapshot.organization_score,
snapshot.total_techniques,
)
deleted = cleanup_old_snapshots(db, keep_last=52)
if deleted:
logger.info("Cleaned up %d old snapshots", deleted)
except Exception:
logger.exception("Weekly snapshot job failed")
finally:
db.close()
def _run_intel_scan() -> None:
"""Execute an intel scan inside its own DB session."""
logger.info("Scheduled intel scan job starting...")
@@ -111,5 +132,18 @@ def start_scheduler() -> None:
name="Notification cleanup (daily)",
replace_existing=True,
)
scheduler.add_job(
_run_weekly_snapshot,
trigger="cron",
day_of_week="sun",
hour=0,
minute=0,
id="weekly_snapshot",
name="Weekly coverage snapshot (Sundays 00:00)",
replace_existing=True,
)
scheduler.start()
logger.info("Background scheduler started — mitre_sync (24h), intel_scan (7d), notification_cleanup (24h)")
logger.info(
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00)"
)

View File

@@ -27,6 +27,7 @@ from app.routers import heatmap as heatmap_router
from app.routers import scores as scores_router
from app.routers import operational_metrics as operational_metrics_router
from app.routers import compliance as compliance_router
from app.routers import snapshots as snapshots_router
from app.storage import ensure_bucket_exists
from app.jobs.mitre_sync_job import start_scheduler, scheduler
@@ -78,6 +79,7 @@ app.include_router(heatmap_router.router, prefix="/api/v1")
app.include_router(scores_router.router, prefix="/api/v1")
app.include_router(operational_metrics_router.router, prefix="/api/v1")
app.include_router(compliance_router.router, prefix="/api/v1")
app.include_router(snapshots_router.router, prefix="/api/v1")
@app.get("/health")

View File

@@ -15,6 +15,7 @@ from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.campaign import Campaign, CampaignTest
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
__all__ = [
@@ -25,5 +26,6 @@ __all__ = [
"TestTemplateDetectionRule", "TestDetectionResult",
"Campaign", "CampaignTest",
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
"CoverageSnapshot", "SnapshotTechniqueState",
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
]

View File

@@ -0,0 +1,78 @@
"""Coverage snapshot models — periodic snapshots of coverage state.
CoverageSnapshot stores aggregate metrics at a point in time.
SnapshotTechniqueState stores per-technique state (normalized, one row
per technique per snapshot) to avoid bloated JSONB fields.
"""
import uuid
from datetime import datetime
from sqlalchemy import (
Column, String, Float, Integer, DateTime,
ForeignKey, Index,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.database import Base
class CoverageSnapshot(Base):
"""A point-in-time snapshot of the organisation's overall coverage."""
__tablename__ = "coverage_snapshots"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
organization_score = Column(Float, nullable=False)
total_techniques = Column(Integer, nullable=False)
validated_count = Column(Integer, nullable=False)
partial_count = Column(Integer, nullable=False)
not_covered_count = Column(Integer, nullable=False)
in_progress_count = Column(Integer, nullable=False)
not_evaluated_count = Column(Integer, nullable=False)
created_by = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
creator = relationship("User", foreign_keys=[created_by])
technique_states = relationship(
"SnapshotTechniqueState",
back_populates="snapshot",
cascade="all, delete-orphan",
)
class SnapshotTechniqueState(Base):
"""Per-technique state within a snapshot (normalised storage)."""
__tablename__ = "snapshot_technique_states"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
snapshot_id = Column(
UUID(as_uuid=True),
ForeignKey("coverage_snapshots.id", ondelete="CASCADE"),
nullable=False,
)
technique_id = Column(
UUID(as_uuid=True),
ForeignKey("techniques.id", ondelete="CASCADE"),
nullable=False,
)
mitre_id = Column(String, nullable=False) # denormalised for fast queries
status = Column(String, nullable=False)
score = Column(Float, nullable=True)
# Relationships
snapshot = relationship("CoverageSnapshot", back_populates="technique_states")
technique = relationship("Technique")
__table_args__ = (
Index("ix_snapshot_technique_states_snapshot", "snapshot_id"),
Index("ix_snapshot_technique_states_technique", "technique_id"),
)

View File

@@ -1,7 +1,7 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum
from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, ForeignKey, Enum
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -54,6 +54,10 @@ class Test(Base):
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable
remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
# ── Re-test fields ────────────────────────────────────────────
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
retest_count = Column(Integer, default=0)
# ── Relationships ───────────────────────────────────────────────
technique = relationship("Technique", back_populates="tests")
evidences = relationship("Evidence", back_populates="test")
@@ -61,3 +65,5 @@ class Test(Base):
red_validator = relationship("User", foreign_keys=[red_validated_by])
blue_validator = relationship("User", foreign_keys=[blue_validated_by])
remediation_user = relationship("User", foreign_keys=[remediation_assignee])
original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test")

View File

@@ -0,0 +1,205 @@
"""Snapshot endpoints — coverage snapshots CRUD and comparison.
Provides periodic and manual snapshots of the organisation's coverage
state, plus temporal comparison between any two snapshots.
"""
import logging
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role, require_role
from app.models.user import User
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.services.snapshot_service import (
create_snapshot,
compare_snapshots,
cleanup_old_snapshots,
)
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/snapshots", tags=["snapshots"])
# ── Pydantic schemas ─────────────────────────────────────────────────
class SnapshotCreate(BaseModel):
name: Optional[str] = None
# ── Helpers ──────────────────────────────────────────────────────────
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = _serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
# ---------------------------------------------------------------------------
# GET /snapshots — List snapshots (paginated)
# ---------------------------------------------------------------------------
@router.get("")
def list_snapshots(
offset: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [_serialize_snapshot_summary(s) for s in snapshots],
}
# ---------------------------------------------------------------------------
# POST /snapshots — Create snapshot manually
# ---------------------------------------------------------------------------
@router.post("", status_code=201)
def create_snapshot_endpoint(
payload: SnapshotCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
):
"""Create a manual coverage snapshot with an optional name."""
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
log_action(
db,
user_id=current_user.id,
action="create_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name, "score": snapshot.organization_score},
)
return _serialize_snapshot_summary(snapshot)
# ---------------------------------------------------------------------------
# GET /snapshots/compare — Compare two snapshots
# ---------------------------------------------------------------------------
@router.get("/compare")
def compare_snapshots_endpoint(
a: str = Query(..., description="Snapshot A ID"),
b: str = Query(..., description="Snapshot B ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Compare two snapshots showing improved, worsened, and unchanged techniques."""
try:
a_id = uuid.UUID(a)
b_id = uuid.UUID(b)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
result = compare_snapshots(db, a_id, b_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
# ---------------------------------------------------------------------------
# GET /snapshots/{id} — Snapshot detail
# ---------------------------------------------------------------------------
@router.get("/{snapshot_id}")
def get_snapshot(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get detailed snapshot information including per-technique states."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
return _serialize_snapshot_detail(db, snapshot)
# ---------------------------------------------------------------------------
# DELETE /snapshots/{id} — Delete snapshot (admin only)
# ---------------------------------------------------------------------------
@router.delete("/{snapshot_id}")
def delete_snapshot(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Delete a snapshot (admin only)."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
log_action(
db,
user_id=current_user.id,
action="delete_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name},
)
db.delete(snapshot)
db.commit()
return {"detail": "Snapshot deleted"}

View File

@@ -52,6 +52,8 @@ from app.services.test_workflow_service import (
validate_as_red_lead as wf_validate_red,
validate_as_blue_lead as wf_validate_blue,
reopen_test as wf_reopen,
handle_remediation_completed as wf_handle_remediation,
get_retest_chain as wf_get_retest_chain,
)
router = APIRouter(prefix="/tests", tags=["tests"])
@@ -546,9 +548,15 @@ def update_remediation(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Update remediation fields on a test (any authenticated user)."""
"""Update remediation fields on a test (any authenticated user).
When ``remediation_status`` transitions to ``'completed'``, an automatic
re-test is created (subject to ``MAX_RETEST_COUNT``).
"""
test = _get_test_or_404(db, test_id)
old_remediation_status = test.remediation_status
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
@@ -565,6 +573,13 @@ def update_remediation(
details={"updated_fields": list(update_data.keys())},
)
# Auto-create retest when remediation is marked completed
new_status = update_data.get("remediation_status")
if new_status == "completed" and old_remediation_status != "completed":
retest = wf_handle_remediation(db, test, current_user)
if retest:
db.refresh(test)
return test
@@ -603,3 +618,35 @@ def get_test_timeline(
}
for log in logs
]
# ---------------------------------------------------------------------------
# GET /tests/{id}/retest-chain — full retest chain
# ---------------------------------------------------------------------------
@router.get("/{test_id}/retest-chain")
def get_retest_chain(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return the full chain of retests (original + all retests) for a test."""
chain = wf_get_retest_chain(db, test_id)
if not chain:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"retest_of": str(t.retest_of) if t.retest_of else None,
"retest_count": t.retest_count,
"result": t.result.value if t.result else None,
"detection_result": t.detection_result.value if t.detection_result else None,
"remediation_status": t.remediation_status,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in chain
]

View File

@@ -142,6 +142,10 @@ class TestOut(BaseModel):
remediation_status: str | None = None
remediation_assignee: uuid.UUID | None = None
# Re-test fields
retest_of: uuid.UUID | None = None
retest_count: int = 0
# Technique info (populated when joined)
technique_mitre_id: str | None = None
technique_name: str | None = None

View File

@@ -0,0 +1,253 @@
"""Snapshot service — create, compare, and manage coverage snapshots.
Provides point-in-time coverage captures with normalised per-technique
storage and temporal comparison between any two snapshots.
"""
import logging
import uuid
from datetime import datetime
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.models.technique import Technique
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus
from app.services.scoring_service import calculate_technique_score, calculate_organization_score
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Create snapshot
# ---------------------------------------------------------------------------
def create_snapshot(
db: Session,
name: str | None = None,
user_id: uuid.UUID | None = None,
) -> CoverageSnapshot:
"""Capture the current coverage state into a new snapshot.
1. Fetch every technique with its status and score.
2. Compute aggregate counts.
3. Persist a ``CoverageSnapshot`` with normalised
``SnapshotTechniqueState`` rows.
"""
techniques = db.query(Technique).all()
# Aggregate counters
validated_count = 0
partial_count = 0
not_covered_count = 0
in_progress_count = 0
not_evaluated_count = 0
technique_rows: list[dict] = []
for tech in techniques:
status_value = (
tech.status_global.value
if isinstance(tech.status_global, TechniqueStatus)
else (tech.status_global or "not_evaluated")
)
# Count by status
if status_value == "validated":
validated_count += 1
elif status_value == "partial":
partial_count += 1
elif status_value == "not_covered":
not_covered_count += 1
elif status_value == "in_progress":
in_progress_count += 1
else:
not_evaluated_count += 1
# Compute technique score
score_data = calculate_technique_score(tech, db)
technique_rows.append({
"technique_id": tech.id,
"mitre_id": tech.mitre_id,
"status": status_value,
"score": score_data["total_score"],
})
# Organization score
org_data = calculate_organization_score(db)
org_score = org_data.get("overall_score", 0)
# Create the snapshot
snapshot = CoverageSnapshot(
name=name,
organization_score=org_score,
total_techniques=len(techniques),
validated_count=validated_count,
partial_count=partial_count,
not_covered_count=not_covered_count,
in_progress_count=in_progress_count,
not_evaluated_count=not_evaluated_count,
created_by=user_id,
)
db.add(snapshot)
db.flush() # get snapshot.id
# Create normalised technique state rows
for row in technique_rows:
state = SnapshotTechniqueState(
snapshot_id=snapshot.id,
technique_id=row["technique_id"],
mitre_id=row["mitre_id"],
status=row["status"],
score=row["score"],
)
db.add(state)
db.commit()
db.refresh(snapshot)
logger.info(
"Snapshot '%s' created — %d techniques, org score %.1f",
snapshot.name or snapshot.id,
len(techniques),
org_score,
)
return snapshot
# ---------------------------------------------------------------------------
# Compare snapshots
# ---------------------------------------------------------------------------
def compare_snapshots(
db: Session,
snapshot_a_id: uuid.UUID,
snapshot_b_id: uuid.UUID,
) -> dict:
"""Compare two snapshots and return deltas.
Returns improved/worsened technique lists plus aggregate statistics.
"""
snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
if not snap_a or not snap_b:
return {"error": "One or both snapshots not found"}
# Build lookup dicts: mitre_id -> {status, score}
states_a = {
s.mitre_id: {"status": s.status, "score": s.score or 0}
for s in db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id)
.all()
}
states_b = {
s.mitre_id: {"status": s.status, "score": s.score or 0}
for s in db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id)
.all()
}
# Status priority for comparison
STATUS_ORDER = {
"not_evaluated": 0,
"not_covered": 1,
"in_progress": 2,
"partial": 3,
"validated": 4,
}
improved = []
worsened = []
unchanged_count = 0
all_mitre_ids = set(states_a.keys()) | set(states_b.keys())
for mitre_id in sorted(all_mitre_ids):
a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0})
b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0})
a_order = STATUS_ORDER.get(a["status"], 0)
b_order = STATUS_ORDER.get(b["status"], 0)
if b_order > a_order or (b_order == a_order and b["score"] > a["score"]):
improved.append({
"mitre_id": mitre_id,
"old_status": a["status"],
"new_status": b["status"],
"old_score": a["score"],
"new_score": b["score"],
})
elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]):
worsened.append({
"mitre_id": mitre_id,
"old_status": a["status"],
"new_status": b["status"],
"old_score": a["score"],
"new_score": b["score"],
})
else:
unchanged_count += 1
def _snap_summary(snap: CoverageSnapshot) -> dict:
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
return {
"snapshot_a": _snap_summary(snap_a),
"snapshot_b": _snap_summary(snap_b),
"score_delta": round(snap_b.organization_score - snap_a.organization_score, 1),
"improved": improved,
"worsened": worsened,
"unchanged_count": unchanged_count,
"summary": {
"improved_count": len(improved),
"worsened_count": len(worsened),
"new_count": len(states_b.keys() - states_a.keys()),
},
}
# ---------------------------------------------------------------------------
# Cleanup
# ---------------------------------------------------------------------------
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
"""Delete oldest snapshots, keeping the most recent *keep_last*.
Returns the number of snapshots deleted.
"""
total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
if total <= keep_last:
return 0
to_delete = total - keep_last
old_snapshots = (
db.query(CoverageSnapshot)
.order_by(CoverageSnapshot.created_at.asc())
.limit(to_delete)
.all()
)
deleted = 0
for snap in old_snapshots:
db.delete(snap)
deleted += 1
db.commit()
logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
return deleted

View File

@@ -16,11 +16,12 @@ from datetime import datetime
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from app.config import settings
from app.models.enums import TestState
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
from app.services.notification_service import notify_test_state_change
from app.services.notification_service import notify_test_state_change, create_notification
# ---------------------------------------------------------------------------
# Valid transition map
@@ -298,6 +299,131 @@ def check_dual_validation(db: Session, test: Test) -> Test:
return test
def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None:
"""Create a re-test when remediation is completed.
When a test's remediation_status changes to 'completed', this function
creates a new test (retest) with the same base data to verify that the
fix was effective.
Prevents infinite loops by enforcing ``MAX_RETEST_COUNT``.
Returns the new retest or *None* if the limit was reached.
"""
# Always reference the original test, not an intermediate retest
original_test_id = test.retest_of or test.id
if test.retest_count >= settings.MAX_RETEST_COUNT:
# Max retests reached — notify and bail out
if test.created_by:
create_notification(
db,
user_id=test.created_by,
type="max_retests_reached",
title="Maximum retests reached",
message=(
f'Test "{test.name}" has reached the maximum of '
f'{settings.MAX_RETEST_COUNT} retests. Manual review required.'
),
entity_type="test",
entity_id=test.id,
)
log_action(
db,
user_id=user.id,
action="max_retests_reached",
entity_type="test",
entity_id=test.id,
details={
"retest_count": test.retest_count,
"max_allowed": settings.MAX_RETEST_COUNT,
"original_test_id": str(original_test_id),
},
)
return None
retest = Test(
technique_id=test.technique_id,
name=f"[Retest #{test.retest_count + 1}] {test.name.replace(f'[Retest #{test.retest_count}] ', '')}",
description=test.description,
platform=test.platform,
procedure_text=test.procedure_text,
tool_used=test.tool_used,
state=TestState.draft,
created_by=test.created_by,
retest_of=original_test_id,
retest_count=test.retest_count + 1,
)
db.add(retest)
db.flush()
log_action(
db,
user_id=user.id,
action="create_retest",
entity_type="test",
entity_id=retest.id,
details={
"original_test_id": str(original_test_id),
"retest_number": retest.retest_count,
"source_test_id": str(test.id),
},
)
# Notify the test creator and any red_tech users
if test.created_by:
create_notification(
db,
user_id=test.created_by,
type="retest_created",
title="Re-test created",
message=(
f'A re-test has been automatically created for "{test.name}" '
f'after remediation was completed.'
),
entity_type="test",
entity_id=retest.id,
)
db.commit()
db.refresh(retest)
return retest
def get_retest_chain(db: Session, test_id) -> list[Test]:
"""Return the full chain of retests for a given test.
Includes the original test and all subsequent retests, ordered
by retest_count.
"""
import uuid as _uuid
tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) else test_id
# Find the original test first
test = db.query(Test).filter(Test.id == tid).first()
if not test:
return []
original_id = test.retest_of or test.id
# Get original
original = db.query(Test).filter(Test.id == original_id).first()
if not original:
return [test]
# Get all retests of the original
retests = (
db.query(Test)
.filter(Test.retest_of == original_id)
.order_by(Test.retest_count)
.all()
)
return [original] + retests
def reopen_test(db: Session, test: Test, user: User) -> Test:
"""Move a ``rejected`` test back to ``draft``, clearing validation fields.

View File

@@ -19,6 +19,7 @@ import ThreatActorsPage from "./pages/ThreatActorsPage";
import ThreatActorDetailPage from "./pages/ThreatActorDetailPage";
import CampaignsPage from "./pages/CampaignsPage";
import CampaignDetailPage from "./pages/CampaignDetailPage";
import ComparisonPage from "./pages/ComparisonPage";
import Layout from "./components/Layout";
import ProtectedRoute from "./components/ProtectedRoute";
@@ -58,6 +59,7 @@ export default function App() {
<Route path="/threat-actors/:actorId" element={<ThreatActorDetailPage />} />
<Route path="/campaigns" element={<CampaignsPage />} />
<Route path="/campaigns/:campaignId" element={<CampaignDetailPage />} />
<Route path="/comparison" element={<ComparisonPage />} />
<Route path="/compliance" element={<CompliancePage />} />
<Route
path="/system"

View File

@@ -0,0 +1,93 @@
import client from "./client";
// ── Types ───────────────────────────────────────────────────────────
export interface SnapshotSummary {
id: string;
name: string | null;
organization_score: number;
total_techniques: number;
validated_count: number;
partial_count: number;
not_covered_count: number;
in_progress_count: number;
not_evaluated_count: number;
created_by: string | null;
created_at: string | null;
}
export interface TechniqueState {
mitre_id: string;
technique_id: string;
status: string;
score: number | null;
}
export interface SnapshotDetail extends SnapshotSummary {
technique_states: TechniqueState[];
}
export interface SnapshotComparisonDelta {
mitre_id: string;
old_status: string;
new_status: string;
old_score: number;
new_score: number;
}
export interface SnapshotComparison {
snapshot_a: SnapshotSummary;
snapshot_b: SnapshotSummary;
score_delta: number;
improved: SnapshotComparisonDelta[];
worsened: SnapshotComparisonDelta[];
unchanged_count: number;
summary: {
improved_count: number;
worsened_count: number;
new_count: number;
};
}
// ── API Functions ───────────────────────────────────────────────────
/** List snapshots (paginated, newest first). */
export async function listSnapshots(params?: {
offset?: number;
limit?: number;
}): Promise<{ total: number; items: SnapshotSummary[] }> {
const { data } = await client.get("/snapshots", { params });
return data;
}
/** Create a manual snapshot. */
export async function createSnapshot(name?: string): Promise<SnapshotSummary> {
const { data } = await client.post<SnapshotSummary>("/snapshots", {
name: name || null,
});
return data;
}
/** Get snapshot detail with per-technique states. */
export async function getSnapshot(snapshotId: string): Promise<SnapshotDetail> {
const { data } = await client.get<SnapshotDetail>(
`/snapshots/${snapshotId}`,
);
return data;
}
/** Compare two snapshots. */
export async function compareSnapshots(
aId: string,
bId: string,
): Promise<SnapshotComparison> {
const { data } = await client.get<SnapshotComparison>("/snapshots/compare", {
params: { a: aId, b: bId },
});
return data;
}
/** Delete a snapshot (admin only). */
export async function deleteSnapshot(snapshotId: string): Promise<void> {
await client.delete(`/snapshots/${snapshotId}`);
}

View File

@@ -204,6 +204,26 @@ export async function getTestTimeline(
return data;
}
// ── Retest Chain ────────────────────────────────────────────────────
export interface RetestChainEntry {
id: string;
name: string;
state: string | null;
retest_of: string | null;
retest_count: number;
result: string | null;
detection_result: string | null;
remediation_status: string | null;
created_at: string | null;
}
/** Get the full retest chain for a test. */
export async function getRetestChain(testId: string): Promise<RetestChainEntry[]> {
const { data } = await client.get<RetestChainEntry[]>(`/tests/${testId}/retest-chain`);
return data;
}
// ── Legacy (kept for backwards compat) ─────────────────────────────
/** Validate a test (legacy endpoint). */

View File

@@ -18,6 +18,7 @@ import {
Grid3X3,
Gauge,
ShieldCheck,
GitCompareArrows,
} from "lucide-react";
import { useAuth } from "../context/AuthContext";
@@ -46,6 +47,7 @@ const mainLinks: NavItem[] = [
{ to: "/reports", label: "Reports", icon: BarChart3 },
{ to: "/threat-actors", label: "Threat Actors", icon: Crosshair },
{ to: "/campaigns", label: "Campaigns", icon: Zap },
{ to: "/comparison", label: "Comparison", icon: GitCompareArrows },
{ to: "/compliance", label: "Compliance", icon: ShieldCheck },
];

View File

@@ -0,0 +1,458 @@
import { useState, useMemo } from "react";
import { useNavigate } from "react-router-dom";
import { useQuery } from "@tanstack/react-query";
import {
Loader2,
AlertCircle,
ArrowUp,
ArrowDown,
Minus,
GitCompareArrows,
Camera,
TrendingUp,
TrendingDown,
} from "lucide-react";
import {
listSnapshots,
compareSnapshots,
type SnapshotSummary,
type SnapshotComparison,
} from "../api/snapshots";
type Tab = "improved" | "worsened" | "unchanged";
const statusColors: Record<string, string> = {
validated: "text-green-400",
partial: "text-yellow-400",
not_covered: "text-red-400",
in_progress: "text-blue-400",
not_evaluated: "text-gray-500",
};
const statusDots: Record<string, string> = {
validated: "bg-green-400",
partial: "bg-yellow-400",
not_covered: "bg-red-400",
in_progress: "bg-blue-400",
not_evaluated: "bg-gray-500",
};
function StatusBadge({ status }: { status: string }) {
return (
<span className="inline-flex items-center gap-1.5 text-xs">
<span className={`h-2 w-2 rounded-full ${statusDots[status] || statusDots.not_evaluated}`} />
<span className={statusColors[status] || statusColors.not_evaluated}>
{status.replace(/_/g, " ")}
</span>
</span>
);
}
function DeltaArrow({ delta }: { delta: number }) {
if (delta > 0) return <ArrowUp className="h-3.5 w-3.5 text-green-400" />;
if (delta < 0) return <ArrowDown className="h-3.5 w-3.5 text-red-400" />;
return <Minus className="h-3.5 w-3.5 text-gray-500" />;
}
function MetricCard({
label,
valueA,
valueB,
suffix,
}: {
label: string;
valueA: number;
valueB: number;
suffix?: string;
}) {
const delta = valueB - valueA;
const deltaColor =
delta > 0 ? "text-green-400" : delta < 0 ? "text-red-400" : "text-gray-500";
return (
<div className="flex flex-col gap-1">
<span className="text-xs text-gray-500">{label}</span>
<div className="flex items-baseline gap-2">
<span className="text-lg font-bold text-white">
{valueB}
{suffix}
</span>
{delta !== 0 && (
<span className={`flex items-center gap-0.5 text-xs font-medium ${deltaColor}`}>
<DeltaArrow delta={delta} />
{delta > 0 ? "+" : ""}
{delta}
</span>
)}
</div>
</div>
);
}
export default function ComparisonPage() {
const navigate = useNavigate();
const [snapA, setSnapA] = useState<string>("");
const [snapB, setSnapB] = useState<string>("");
const [activeTab, setActiveTab] = useState<Tab>("improved");
// Fetch all snapshots for the dropdowns
const { data: snapshotsData, isLoading: isLoadingSnapshots } = useQuery({
queryKey: ["snapshots", "all"],
queryFn: () => listSnapshots({ limit: 200 }),
});
const snapshots = snapshotsData?.items || [];
// Comparison query
const {
data: comparison,
isLoading: isComparing,
error: compareError,
} = useQuery({
queryKey: ["snapshot-compare", snapA, snapB],
queryFn: () => compareSnapshots(snapA, snapB),
enabled: !!snapA && !!snapB && snapA !== snapB,
});
const formatDate = (dateStr: string | null) => {
if (!dateStr) return "—";
return new Date(dateStr).toLocaleDateString("en-US", {
year: "numeric",
month: "short",
day: "numeric",
hour: "2-digit",
minute: "2-digit",
});
};
// For the "unchanged" tab, we don't get individual rows from the API,
// just a count, so we show the count.
const tabData = useMemo(() => {
if (!comparison) return { improved: [], worsened: [], unchanged_count: 0 };
return {
improved: comparison.improved,
worsened: comparison.worsened,
unchanged_count: comparison.unchanged_count,
};
}, [comparison]);
return (
<div className="space-y-6">
{/* Header */}
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
<div className="rounded-lg bg-cyan-500/10 p-2.5">
<GitCompareArrows className="h-6 w-6 text-cyan-400" />
</div>
<div>
<h1 className="text-2xl font-bold text-white">Temporal Comparison</h1>
<p className="text-sm text-gray-400">Compare coverage snapshots over time</p>
</div>
</div>
</div>
{/* Snapshot selectors */}
<div className="rounded-xl border border-gray-800 bg-gray-900 p-6">
<div className="grid grid-cols-1 gap-6 md:grid-cols-2">
{/* Snapshot A */}
<div>
<label className="mb-2 block text-sm font-medium text-gray-400">
Snapshot A (Baseline)
</label>
<select
value={snapA}
onChange={(e) => setSnapA(e.target.value)}
className="w-full rounded-lg border border-gray-700 bg-gray-800 px-4 py-2.5 text-sm text-white focus:border-cyan-500 focus:outline-none focus:ring-1 focus:ring-cyan-500"
>
<option value="">Select snapshot...</option>
{snapshots.map((s) => (
<option key={s.id} value={s.id}>
{s.name || `Snapshot ${formatDate(s.created_at)}`} Score:{" "}
{s.organization_score}
</option>
))}
</select>
</div>
{/* Snapshot B */}
<div>
<label className="mb-2 block text-sm font-medium text-gray-400">
Snapshot B (Current)
</label>
<select
value={snapB}
onChange={(e) => setSnapB(e.target.value)}
className="w-full rounded-lg border border-gray-700 bg-gray-800 px-4 py-2.5 text-sm text-white focus:border-cyan-500 focus:outline-none focus:ring-1 focus:ring-cyan-500"
>
<option value="">Select snapshot...</option>
{snapshots.map((s) => (
<option key={s.id} value={s.id}>
{s.name || `Snapshot ${formatDate(s.created_at)}`} Score:{" "}
{s.organization_score}
</option>
))}
</select>
</div>
</div>
{isLoadingSnapshots && (
<div className="mt-4 flex items-center gap-2 text-sm text-gray-400">
<Loader2 className="h-4 w-4 animate-spin" />
Loading snapshots...
</div>
)}
</div>
{/* Loading / Error */}
{isComparing && (
<div className="flex h-40 items-center justify-center">
<Loader2 className="h-8 w-8 animate-spin text-cyan-400" />
</div>
)}
{compareError && (
<div className="flex items-center gap-2 rounded-lg border border-red-500/30 bg-red-900/30 px-4 py-3 text-sm text-red-400">
<AlertCircle className="h-4 w-4" />
Failed to compare snapshots
</div>
)}
{/* Comparison results */}
{comparison && (
<>
{/* Side-by-side score cards */}
<div className="grid grid-cols-1 gap-4 md:grid-cols-2">
{/* Snapshot A card */}
<div className="rounded-xl border border-gray-800 bg-gray-900 p-6">
<div className="mb-4 flex items-center gap-2">
<Camera className="h-4 w-4 text-gray-500" />
<span className="text-sm font-medium text-gray-400">
{comparison.snapshot_a.name || "Snapshot A"}
</span>
<span className="ml-auto text-xs text-gray-600">
{formatDate(comparison.snapshot_a.created_at)}
</span>
</div>
<div className="text-3xl font-bold text-white">
{comparison.snapshot_a.organization_score}
</div>
<div className="mt-3 grid grid-cols-2 gap-3">
<MetricCard
label="Validated"
valueA={comparison.snapshot_a.validated_count}
valueB={comparison.snapshot_a.validated_count}
/>
<MetricCard
label="Partial"
valueA={comparison.snapshot_a.partial_count}
valueB={comparison.snapshot_a.partial_count}
/>
<MetricCard
label="Not Covered"
valueA={comparison.snapshot_a.not_covered_count}
valueB={comparison.snapshot_a.not_covered_count}
/>
<MetricCard
label="In Progress"
valueA={comparison.snapshot_a.in_progress_count}
valueB={comparison.snapshot_a.in_progress_count}
/>
</div>
</div>
{/* Snapshot B card */}
<div className="rounded-xl border border-gray-800 bg-gray-900 p-6">
<div className="mb-4 flex items-center gap-2">
<Camera className="h-4 w-4 text-gray-500" />
<span className="text-sm font-medium text-gray-400">
{comparison.snapshot_b.name || "Snapshot B"}
</span>
<span className="ml-auto text-xs text-gray-600">
{formatDate(comparison.snapshot_b.created_at)}
</span>
</div>
<div className="flex items-baseline gap-3">
<span className="text-3xl font-bold text-white">
{comparison.snapshot_b.organization_score}
</span>
{comparison.score_delta !== 0 && (
<span
className={`flex items-center gap-1 text-sm font-semibold ${
comparison.score_delta > 0 ? "text-green-400" : "text-red-400"
}`}
>
{comparison.score_delta > 0 ? (
<TrendingUp className="h-4 w-4" />
) : (
<TrendingDown className="h-4 w-4" />
)}
{comparison.score_delta > 0 ? "+" : ""}
{comparison.score_delta}
</span>
)}
</div>
<div className="mt-3 grid grid-cols-2 gap-3">
<MetricCard
label="Validated"
valueA={comparison.snapshot_a.validated_count}
valueB={comparison.snapshot_b.validated_count}
/>
<MetricCard
label="Partial"
valueA={comparison.snapshot_a.partial_count}
valueB={comparison.snapshot_b.partial_count}
/>
<MetricCard
label="Not Covered"
valueA={comparison.snapshot_a.not_covered_count}
valueB={comparison.snapshot_b.not_covered_count}
/>
<MetricCard
label="In Progress"
valueA={comparison.snapshot_a.in_progress_count}
valueB={comparison.snapshot_b.in_progress_count}
/>
</div>
</div>
</div>
{/* Tabs */}
<div className="rounded-xl border border-gray-800 bg-gray-900">
<div className="flex border-b border-gray-800">
<button
onClick={() => setActiveTab("improved")}
className={`flex items-center gap-2 px-6 py-3 text-sm font-medium transition-colors ${
activeTab === "improved"
? "border-b-2 border-green-400 text-green-400"
: "text-gray-400 hover:text-white"
}`}
>
<ArrowUp className="h-4 w-4" />
Improved ({comparison.summary.improved_count})
</button>
<button
onClick={() => setActiveTab("worsened")}
className={`flex items-center gap-2 px-6 py-3 text-sm font-medium transition-colors ${
activeTab === "worsened"
? "border-b-2 border-red-400 text-red-400"
: "text-gray-400 hover:text-white"
}`}
>
<ArrowDown className="h-4 w-4" />
Worsened ({comparison.summary.worsened_count})
</button>
<button
onClick={() => setActiveTab("unchanged")}
className={`flex items-center gap-2 px-6 py-3 text-sm font-medium transition-colors ${
activeTab === "unchanged"
? "border-b-2 border-gray-400 text-gray-400"
: "text-gray-500 hover:text-white"
}`}
>
<Minus className="h-4 w-4" />
Unchanged ({comparison.unchanged_count})
</button>
</div>
<div className="p-6">
{activeTab === "unchanged" ? (
<div className="flex flex-col items-center justify-center py-8 text-gray-400">
<Minus className="mb-2 h-8 w-8 text-gray-600" />
<p className="text-lg font-medium">
{comparison.unchanged_count} techniques unchanged
</p>
<p className="mt-1 text-sm text-gray-500">
These techniques had the same status and score in both snapshots.
</p>
</div>
) : (
<>
{(activeTab === "improved"
? tabData.improved
: tabData.worsened
).length === 0 ? (
<div className="flex flex-col items-center justify-center py-8 text-gray-400">
<p className="text-sm">No techniques {activeTab} between snapshots.</p>
</div>
) : (
<div className="overflow-x-auto">
<table className="w-full text-left text-sm">
<thead>
<tr className="border-b border-gray-800">
<th className="pb-3 pr-4 font-medium text-gray-400">MITRE ID</th>
<th className="pb-3 px-4 font-medium text-gray-400">Before</th>
<th className="pb-3 px-4 font-medium text-gray-400">After</th>
<th className="pb-3 px-4 font-medium text-gray-400">Score Before</th>
<th className="pb-3 px-4 font-medium text-gray-400">Score After</th>
<th className="pb-3 pl-4 font-medium text-gray-400">Delta</th>
</tr>
</thead>
<tbody>
{(activeTab === "improved"
? tabData.improved
: tabData.worsened
).map((item) => (
<tr
key={item.mitre_id}
className="border-b border-gray-800/50 hover:bg-gray-800/30 transition-colors cursor-pointer"
onClick={() =>
navigate(`/techniques/${item.mitre_id}`)
}
>
<td className="py-3 pr-4">
<span className="font-mono text-xs text-cyan-400">
{item.mitre_id}
</span>
</td>
<td className="py-3 px-4">
<StatusBadge status={item.old_status} />
</td>
<td className="py-3 px-4">
<StatusBadge status={item.new_status} />
</td>
<td className="py-3 px-4">
<span className="text-xs text-gray-400">{item.old_score}</span>
</td>
<td className="py-3 px-4">
<span className="text-xs text-white">{item.new_score}</span>
</td>
<td className="py-3 pl-4">
<span
className={`flex items-center gap-1 text-xs font-medium ${
item.new_score > item.old_score
? "text-green-400"
: item.new_score < item.old_score
? "text-red-400"
: "text-gray-500"
}`}
>
<DeltaArrow delta={item.new_score - item.old_score} />
{item.new_score > item.old_score ? "+" : ""}
{Math.round((item.new_score - item.old_score) * 10) / 10}
</span>
</td>
</tr>
))}
</tbody>
</table>
</div>
)}
</>
)}
</div>
</div>
</>
)}
{/* No selection prompt */}
{!comparison && !isComparing && !compareError && (
<div className="flex flex-col items-center justify-center py-16 text-gray-400">
<GitCompareArrows className="mb-3 h-12 w-12 text-gray-600" />
<p className="text-lg font-medium">Select two snapshots to compare</p>
<p className="mt-1 text-sm text-gray-500">
Choose a baseline and current snapshot from the dropdowns above.
</p>
</div>
)}
</div>
);
}

View File

@@ -14,6 +14,7 @@ import {
validateAsBlueLead,
reopenTest,
getTestTimeline,
getRetestChain,
} from "../api/tests";
import { uploadEvidence, getEvidence } from "../api/evidence";
import { useAuth } from "../context/AuthContext";
@@ -79,6 +80,12 @@ export default function TestDetailPage() {
enabled: !!testId,
});
const { data: retestChain = [] } = useQuery({
queryKey: ["retest-chain", testId],
queryFn: () => getRetestChain(testId!),
enabled: !!testId && !!test && (test.retest_of !== null || test.retest_count > 0),
});
// Hydrate drafts from test data
useEffect(() => {
if (test) {
@@ -442,6 +449,55 @@ export default function TestDetailPage() {
)}
</dl>
</div>
{/* Retest Chain */}
{(test.retest_of || test.retest_count > 0 || retestChain.length > 1) && (
<div className="rounded-xl border border-gray-800 bg-gray-900 p-6">
<h2 className="mb-4 text-lg font-semibold text-white">Retest Chain</h2>
{test.retest_of && (
<div className="mb-3">
<span className="text-xs font-medium uppercase text-gray-500">
Retest {test.retest_count} / 3
</span>
<div className="mt-1 h-2 w-full rounded-full bg-gray-800 overflow-hidden">
<div
className="h-full rounded-full bg-cyan-500 transition-all"
style={{ width: `${(test.retest_count / 3) * 100}%` }}
/>
</div>
</div>
)}
<div className="space-y-2">
{retestChain.map((entry) => (
<button
key={entry.id}
onClick={() => entry.id !== testId && navigate(`/tests/${entry.id}`)}
className={`flex w-full items-center justify-between rounded-lg border px-3 py-2 text-left text-sm transition-colors ${
entry.id === testId
? "border-cyan-500/30 bg-cyan-900/30 text-cyan-400"
: "border-gray-700 bg-gray-800/50 text-gray-300 hover:border-cyan-500/30 hover:text-cyan-400"
}`}
>
<div className="flex items-center gap-2 truncate">
<span className="truncate text-xs">
{entry.retest_of ? `#${entry.retest_count}` : "Original"}
</span>
<span className="truncate">{entry.name}</span>
</div>
<span className={`shrink-0 rounded-full px-2 py-0.5 text-xs ${
entry.state === "validated"
? "bg-green-900/50 text-green-400"
: entry.state === "rejected"
? "bg-red-900/50 text-red-400"
: "bg-gray-800/50 text-gray-500"
}`}>
{entry.state || "draft"}
</span>
</button>
))}
</div>
</div>
)}
</div>
</div>

View File

@@ -91,6 +91,10 @@ export interface Test {
remediation_status: string | null;
remediation_assignee: string | null;
// Re-test fields
retest_of: string | null;
retest_count: number;
// Technique info (populated in list endpoints)
technique_mitre_id: string | null;
technique_name: string | null;