feat(phase-30): add coverage snapshots, temporal comparison and auto re-testing (T-230 to T-232)

This commit is contained in:
2026-02-10 08:34:29 +01:00
parent 2ac8e7f4a5
commit 4d124b42dd
20 changed files with 1517 additions and 4 deletions

View File

@@ -11,6 +11,9 @@ class Settings(BaseSettings):
MINIO_SECRET_KEY: str = "minioadmin"
MINIO_BUCKET: str = "evidence"
# Re-testing
MAX_RETEST_COUNT: int = 3 # maximum automatic retests per original test
# Scoring weights (must sum to 100)
SCORING_WEIGHT_TESTS: int = 40
SCORING_WEIGHT_DETECTION_RULES: int = 20

View File

@@ -18,6 +18,7 @@ from app.database import SessionLocal
from app.services.mitre_sync_service import sync_mitre
from app.services.intel_service import scan_intel
from app.services.notification_service import cleanup_old_notifications
from app.services.snapshot_service import create_snapshot, cleanup_old_snapshots
logger = logging.getLogger(__name__)
@@ -59,6 +60,26 @@ def _run_notification_cleanup() -> None:
db.close()
def _run_weekly_snapshot() -> None:
"""Create a weekly coverage snapshot and clean up old ones."""
logger.info("Scheduled weekly snapshot job starting...")
db = SessionLocal()
try:
snapshot = create_snapshot(db, name="Auto-weekly")
logger.info(
"Weekly snapshot created — score %.1f, %d techniques",
snapshot.organization_score,
snapshot.total_techniques,
)
deleted = cleanup_old_snapshots(db, keep_last=52)
if deleted:
logger.info("Cleaned up %d old snapshots", deleted)
except Exception:
logger.exception("Weekly snapshot job failed")
finally:
db.close()
def _run_intel_scan() -> None:
"""Execute an intel scan inside its own DB session."""
logger.info("Scheduled intel scan job starting...")
@@ -111,5 +132,18 @@ def start_scheduler() -> None:
name="Notification cleanup (daily)",
replace_existing=True,
)
scheduler.add_job(
_run_weekly_snapshot,
trigger="cron",
day_of_week="sun",
hour=0,
minute=0,
id="weekly_snapshot",
name="Weekly coverage snapshot (Sundays 00:00)",
replace_existing=True,
)
scheduler.start()
logger.info("Background scheduler started — mitre_sync (24h), intel_scan (7d), notification_cleanup (24h)")
logger.info(
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00)"
)

View File

@@ -27,6 +27,7 @@ from app.routers import heatmap as heatmap_router
from app.routers import scores as scores_router
from app.routers import operational_metrics as operational_metrics_router
from app.routers import compliance as compliance_router
from app.routers import snapshots as snapshots_router
from app.storage import ensure_bucket_exists
from app.jobs.mitre_sync_job import start_scheduler, scheduler
@@ -78,6 +79,7 @@ app.include_router(heatmap_router.router, prefix="/api/v1")
app.include_router(scores_router.router, prefix="/api/v1")
app.include_router(operational_metrics_router.router, prefix="/api/v1")
app.include_router(compliance_router.router, prefix="/api/v1")
app.include_router(snapshots_router.router, prefix="/api/v1")
@app.get("/health")

View File

@@ -15,6 +15,7 @@ from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.campaign import Campaign, CampaignTest
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
__all__ = [
@@ -25,5 +26,6 @@ __all__ = [
"TestTemplateDetectionRule", "TestDetectionResult",
"Campaign", "CampaignTest",
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
"CoverageSnapshot", "SnapshotTechniqueState",
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
]

View File

@@ -0,0 +1,78 @@
"""Coverage snapshot models — periodic snapshots of coverage state.
CoverageSnapshot stores aggregate metrics at a point in time.
SnapshotTechniqueState stores per-technique state (normalized, one row
per technique per snapshot) to avoid bloated JSONB fields.
"""
import uuid
from datetime import datetime
from sqlalchemy import (
Column, String, Float, Integer, DateTime,
ForeignKey, Index,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.database import Base
class CoverageSnapshot(Base):
"""A point-in-time snapshot of the organisation's overall coverage."""
__tablename__ = "coverage_snapshots"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
organization_score = Column(Float, nullable=False)
total_techniques = Column(Integer, nullable=False)
validated_count = Column(Integer, nullable=False)
partial_count = Column(Integer, nullable=False)
not_covered_count = Column(Integer, nullable=False)
in_progress_count = Column(Integer, nullable=False)
not_evaluated_count = Column(Integer, nullable=False)
created_by = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
creator = relationship("User", foreign_keys=[created_by])
technique_states = relationship(
"SnapshotTechniqueState",
back_populates="snapshot",
cascade="all, delete-orphan",
)
class SnapshotTechniqueState(Base):
"""Per-technique state within a snapshot (normalised storage)."""
__tablename__ = "snapshot_technique_states"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
snapshot_id = Column(
UUID(as_uuid=True),
ForeignKey("coverage_snapshots.id", ondelete="CASCADE"),
nullable=False,
)
technique_id = Column(
UUID(as_uuid=True),
ForeignKey("techniques.id", ondelete="CASCADE"),
nullable=False,
)
mitre_id = Column(String, nullable=False) # denormalised for fast queries
status = Column(String, nullable=False)
score = Column(Float, nullable=True)
# Relationships
snapshot = relationship("CoverageSnapshot", back_populates="technique_states")
technique = relationship("Technique")
__table_args__ = (
Index("ix_snapshot_technique_states_snapshot", "snapshot_id"),
Index("ix_snapshot_technique_states_technique", "technique_id"),
)

View File

@@ -1,7 +1,7 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum
from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, ForeignKey, Enum
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -54,6 +54,10 @@ class Test(Base):
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable
remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
# ── Re-test fields ────────────────────────────────────────────
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
retest_count = Column(Integer, default=0)
# ── Relationships ───────────────────────────────────────────────
technique = relationship("Technique", back_populates="tests")
evidences = relationship("Evidence", back_populates="test")
@@ -61,3 +65,5 @@ class Test(Base):
red_validator = relationship("User", foreign_keys=[red_validated_by])
blue_validator = relationship("User", foreign_keys=[blue_validated_by])
remediation_user = relationship("User", foreign_keys=[remediation_assignee])
original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test")

View File

@@ -0,0 +1,205 @@
"""Snapshot endpoints — coverage snapshots CRUD and comparison.
Provides periodic and manual snapshots of the organisation's coverage
state, plus temporal comparison between any two snapshots.
"""
import logging
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role, require_role
from app.models.user import User
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.services.snapshot_service import (
create_snapshot,
compare_snapshots,
cleanup_old_snapshots,
)
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/snapshots", tags=["snapshots"])
# ── Pydantic schemas ─────────────────────────────────────────────────
class SnapshotCreate(BaseModel):
name: Optional[str] = None
# ── Helpers ──────────────────────────────────────────────────────────
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = _serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
# ---------------------------------------------------------------------------
# GET /snapshots — List snapshots (paginated)
# ---------------------------------------------------------------------------
@router.get("")
def list_snapshots(
offset: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [_serialize_snapshot_summary(s) for s in snapshots],
}
# ---------------------------------------------------------------------------
# POST /snapshots — Create snapshot manually
# ---------------------------------------------------------------------------
@router.post("", status_code=201)
def create_snapshot_endpoint(
payload: SnapshotCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
):
"""Create a manual coverage snapshot with an optional name."""
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
log_action(
db,
user_id=current_user.id,
action="create_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name, "score": snapshot.organization_score},
)
return _serialize_snapshot_summary(snapshot)
# ---------------------------------------------------------------------------
# GET /snapshots/compare — Compare two snapshots
# ---------------------------------------------------------------------------
@router.get("/compare")
def compare_snapshots_endpoint(
a: str = Query(..., description="Snapshot A ID"),
b: str = Query(..., description="Snapshot B ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Compare two snapshots showing improved, worsened, and unchanged techniques."""
try:
a_id = uuid.UUID(a)
b_id = uuid.UUID(b)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
result = compare_snapshots(db, a_id, b_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
# ---------------------------------------------------------------------------
# GET /snapshots/{id} — Snapshot detail
# ---------------------------------------------------------------------------
@router.get("/{snapshot_id}")
def get_snapshot(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get detailed snapshot information including per-technique states."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
return _serialize_snapshot_detail(db, snapshot)
# ---------------------------------------------------------------------------
# DELETE /snapshots/{id} — Delete snapshot (admin only)
# ---------------------------------------------------------------------------
@router.delete("/{snapshot_id}")
def delete_snapshot(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Delete a snapshot (admin only)."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
log_action(
db,
user_id=current_user.id,
action="delete_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name},
)
db.delete(snapshot)
db.commit()
return {"detail": "Snapshot deleted"}

View File

@@ -52,6 +52,8 @@ from app.services.test_workflow_service import (
validate_as_red_lead as wf_validate_red,
validate_as_blue_lead as wf_validate_blue,
reopen_test as wf_reopen,
handle_remediation_completed as wf_handle_remediation,
get_retest_chain as wf_get_retest_chain,
)
router = APIRouter(prefix="/tests", tags=["tests"])
@@ -546,9 +548,15 @@ def update_remediation(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Update remediation fields on a test (any authenticated user)."""
"""Update remediation fields on a test (any authenticated user).
When ``remediation_status`` transitions to ``'completed'``, an automatic
re-test is created (subject to ``MAX_RETEST_COUNT``).
"""
test = _get_test_or_404(db, test_id)
old_remediation_status = test.remediation_status
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
@@ -565,6 +573,13 @@ def update_remediation(
details={"updated_fields": list(update_data.keys())},
)
# Auto-create retest when remediation is marked completed
new_status = update_data.get("remediation_status")
if new_status == "completed" and old_remediation_status != "completed":
retest = wf_handle_remediation(db, test, current_user)
if retest:
db.refresh(test)
return test
@@ -603,3 +618,35 @@ def get_test_timeline(
}
for log in logs
]
# ---------------------------------------------------------------------------
# GET /tests/{id}/retest-chain — full retest chain
# ---------------------------------------------------------------------------
@router.get("/{test_id}/retest-chain")
def get_retest_chain(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return the full chain of retests (original + all retests) for a test."""
chain = wf_get_retest_chain(db, test_id)
if not chain:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"retest_of": str(t.retest_of) if t.retest_of else None,
"retest_count": t.retest_count,
"result": t.result.value if t.result else None,
"detection_result": t.detection_result.value if t.detection_result else None,
"remediation_status": t.remediation_status,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in chain
]

View File

@@ -142,6 +142,10 @@ class TestOut(BaseModel):
remediation_status: str | None = None
remediation_assignee: uuid.UUID | None = None
# Re-test fields
retest_of: uuid.UUID | None = None
retest_count: int = 0
# Technique info (populated when joined)
technique_mitre_id: str | None = None
technique_name: str | None = None

View File

@@ -0,0 +1,253 @@
"""Snapshot service — create, compare, and manage coverage snapshots.
Provides point-in-time coverage captures with normalised per-technique
storage and temporal comparison between any two snapshots.
"""
import logging
import uuid
from datetime import datetime
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.models.technique import Technique
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus
from app.services.scoring_service import calculate_technique_score, calculate_organization_score
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Create snapshot
# ---------------------------------------------------------------------------
def create_snapshot(
db: Session,
name: str | None = None,
user_id: uuid.UUID | None = None,
) -> CoverageSnapshot:
"""Capture the current coverage state into a new snapshot.
1. Fetch every technique with its status and score.
2. Compute aggregate counts.
3. Persist a ``CoverageSnapshot`` with normalised
``SnapshotTechniqueState`` rows.
"""
techniques = db.query(Technique).all()
# Aggregate counters
validated_count = 0
partial_count = 0
not_covered_count = 0
in_progress_count = 0
not_evaluated_count = 0
technique_rows: list[dict] = []
for tech in techniques:
status_value = (
tech.status_global.value
if isinstance(tech.status_global, TechniqueStatus)
else (tech.status_global or "not_evaluated")
)
# Count by status
if status_value == "validated":
validated_count += 1
elif status_value == "partial":
partial_count += 1
elif status_value == "not_covered":
not_covered_count += 1
elif status_value == "in_progress":
in_progress_count += 1
else:
not_evaluated_count += 1
# Compute technique score
score_data = calculate_technique_score(tech, db)
technique_rows.append({
"technique_id": tech.id,
"mitre_id": tech.mitre_id,
"status": status_value,
"score": score_data["total_score"],
})
# Organization score
org_data = calculate_organization_score(db)
org_score = org_data.get("overall_score", 0)
# Create the snapshot
snapshot = CoverageSnapshot(
name=name,
organization_score=org_score,
total_techniques=len(techniques),
validated_count=validated_count,
partial_count=partial_count,
not_covered_count=not_covered_count,
in_progress_count=in_progress_count,
not_evaluated_count=not_evaluated_count,
created_by=user_id,
)
db.add(snapshot)
db.flush() # get snapshot.id
# Create normalised technique state rows
for row in technique_rows:
state = SnapshotTechniqueState(
snapshot_id=snapshot.id,
technique_id=row["technique_id"],
mitre_id=row["mitre_id"],
status=row["status"],
score=row["score"],
)
db.add(state)
db.commit()
db.refresh(snapshot)
logger.info(
"Snapshot '%s' created — %d techniques, org score %.1f",
snapshot.name or snapshot.id,
len(techniques),
org_score,
)
return snapshot
# ---------------------------------------------------------------------------
# Compare snapshots
# ---------------------------------------------------------------------------
def compare_snapshots(
db: Session,
snapshot_a_id: uuid.UUID,
snapshot_b_id: uuid.UUID,
) -> dict:
"""Compare two snapshots and return deltas.
Returns improved/worsened technique lists plus aggregate statistics.
"""
snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
if not snap_a or not snap_b:
return {"error": "One or both snapshots not found"}
# Build lookup dicts: mitre_id -> {status, score}
states_a = {
s.mitre_id: {"status": s.status, "score": s.score or 0}
for s in db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id)
.all()
}
states_b = {
s.mitre_id: {"status": s.status, "score": s.score or 0}
for s in db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id)
.all()
}
# Status priority for comparison
STATUS_ORDER = {
"not_evaluated": 0,
"not_covered": 1,
"in_progress": 2,
"partial": 3,
"validated": 4,
}
improved = []
worsened = []
unchanged_count = 0
all_mitre_ids = set(states_a.keys()) | set(states_b.keys())
for mitre_id in sorted(all_mitre_ids):
a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0})
b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0})
a_order = STATUS_ORDER.get(a["status"], 0)
b_order = STATUS_ORDER.get(b["status"], 0)
if b_order > a_order or (b_order == a_order and b["score"] > a["score"]):
improved.append({
"mitre_id": mitre_id,
"old_status": a["status"],
"new_status": b["status"],
"old_score": a["score"],
"new_score": b["score"],
})
elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]):
worsened.append({
"mitre_id": mitre_id,
"old_status": a["status"],
"new_status": b["status"],
"old_score": a["score"],
"new_score": b["score"],
})
else:
unchanged_count += 1
def _snap_summary(snap: CoverageSnapshot) -> dict:
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
return {
"snapshot_a": _snap_summary(snap_a),
"snapshot_b": _snap_summary(snap_b),
"score_delta": round(snap_b.organization_score - snap_a.organization_score, 1),
"improved": improved,
"worsened": worsened,
"unchanged_count": unchanged_count,
"summary": {
"improved_count": len(improved),
"worsened_count": len(worsened),
"new_count": len(states_b.keys() - states_a.keys()),
},
}
# ---------------------------------------------------------------------------
# Cleanup
# ---------------------------------------------------------------------------
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
"""Delete oldest snapshots, keeping the most recent *keep_last*.
Returns the number of snapshots deleted.
"""
total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
if total <= keep_last:
return 0
to_delete = total - keep_last
old_snapshots = (
db.query(CoverageSnapshot)
.order_by(CoverageSnapshot.created_at.asc())
.limit(to_delete)
.all()
)
deleted = 0
for snap in old_snapshots:
db.delete(snap)
deleted += 1
db.commit()
logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
return deleted

View File

@@ -16,11 +16,12 @@ from datetime import datetime
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from app.config import settings
from app.models.enums import TestState
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
from app.services.notification_service import notify_test_state_change
from app.services.notification_service import notify_test_state_change, create_notification
# ---------------------------------------------------------------------------
# Valid transition map
@@ -298,6 +299,131 @@ def check_dual_validation(db: Session, test: Test) -> Test:
return test
def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None:
"""Create a re-test when remediation is completed.
When a test's remediation_status changes to 'completed', this function
creates a new test (retest) with the same base data to verify that the
fix was effective.
Prevents infinite loops by enforcing ``MAX_RETEST_COUNT``.
Returns the new retest or *None* if the limit was reached.
"""
# Always reference the original test, not an intermediate retest
original_test_id = test.retest_of or test.id
if test.retest_count >= settings.MAX_RETEST_COUNT:
# Max retests reached — notify and bail out
if test.created_by:
create_notification(
db,
user_id=test.created_by,
type="max_retests_reached",
title="Maximum retests reached",
message=(
f'Test "{test.name}" has reached the maximum of '
f'{settings.MAX_RETEST_COUNT} retests. Manual review required.'
),
entity_type="test",
entity_id=test.id,
)
log_action(
db,
user_id=user.id,
action="max_retests_reached",
entity_type="test",
entity_id=test.id,
details={
"retest_count": test.retest_count,
"max_allowed": settings.MAX_RETEST_COUNT,
"original_test_id": str(original_test_id),
},
)
return None
retest = Test(
technique_id=test.technique_id,
name=f"[Retest #{test.retest_count + 1}] {test.name.replace(f'[Retest #{test.retest_count}] ', '')}",
description=test.description,
platform=test.platform,
procedure_text=test.procedure_text,
tool_used=test.tool_used,
state=TestState.draft,
created_by=test.created_by,
retest_of=original_test_id,
retest_count=test.retest_count + 1,
)
db.add(retest)
db.flush()
log_action(
db,
user_id=user.id,
action="create_retest",
entity_type="test",
entity_id=retest.id,
details={
"original_test_id": str(original_test_id),
"retest_number": retest.retest_count,
"source_test_id": str(test.id),
},
)
# Notify the test creator and any red_tech users
if test.created_by:
create_notification(
db,
user_id=test.created_by,
type="retest_created",
title="Re-test created",
message=(
f'A re-test has been automatically created for "{test.name}" '
f'after remediation was completed.'
),
entity_type="test",
entity_id=retest.id,
)
db.commit()
db.refresh(retest)
return retest
def get_retest_chain(db: Session, test_id) -> list[Test]:
"""Return the full chain of retests for a given test.
Includes the original test and all subsequent retests, ordered
by retest_count.
"""
import uuid as _uuid
tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) else test_id
# Find the original test first
test = db.query(Test).filter(Test.id == tid).first()
if not test:
return []
original_id = test.retest_of or test.id
# Get original
original = db.query(Test).filter(Test.id == original_id).first()
if not original:
return [test]
# Get all retests of the original
retests = (
db.query(Test)
.filter(Test.retest_of == original_id)
.order_by(Test.retest_count)
.all()
)
return [original] + retests
def reopen_test(db: Session, test: Test, user: User) -> Test:
"""Move a ``rejected`` test back to ``draft``, clearing validation fields.