"""Tests for campaigns, snapshots, and re-testing — T-237. Uses the in-memory SQLite test database from conftest.py. """ import uuid from datetime import datetime, timedelta import pytest from app.models.technique import Technique from app.models.test import Test from app.models.test_template import TestTemplate from app.models.campaign import Campaign, CampaignTest from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.models.enums import TestState, TestResult, TechniqueStatus from app.services.campaign_service import ( validate_no_circular_dependency, get_campaign_progress, ) from app.services.campaign_scheduler_service import ( calculate_next_run, check_and_run_recurring_campaigns, ) from app.services.snapshot_service import ( create_snapshot, compare_snapshots, cleanup_old_snapshots, ) from app.services.test_workflow_service import ( handle_remediation_completed, get_retest_chain, ) # --------------------------------------------------------------------------- # Shared fixtures # --------------------------------------------------------------------------- @pytest.fixture def techniques(db): """Create a set of techniques for testing.""" techs = [] for mid, name, status in [ ("T1059", "Command Line", TechniqueStatus.validated), ("T1078", "Valid Accounts", TechniqueStatus.partial), ("T1053", "Scheduled Tasks", TechniqueStatus.not_covered), ]: tech = Technique( mitre_id=mid, name=name, tactic="execution", platforms=["windows"], status_global=status, ) db.add(tech) techs.append(tech) db.commit() for t in techs: db.refresh(t) return techs @pytest.fixture def campaign_with_tests(db, techniques, admin_user): """Create a campaign with ordered tests.""" campaign = Campaign( name="Test Campaign", type="custom", status="draft", created_by=admin_user.id, ) db.add(campaign) db.flush() tests = [] for i, tech in enumerate(techniques): test = Test( technique_id=tech.id, name=f"Test for {tech.mitre_id}", state=TestState.draft, created_by=admin_user.id, ) db.add(test) db.flush() tests.append(test) ct = CampaignTest( campaign_id=campaign.id, test_id=test.id, order_index=i, phase="execution", ) db.add(ct) db.commit() db.refresh(campaign) return {"campaign": campaign, "tests": tests} # ═══════════════════════════════════════════════════════════════════════ # Campaign Tests # ═══════════════════════════════════════════════════════════════════════ class TestCampaigns: def test_create_campaign_with_tests(self, db, campaign_with_tests): """CRUD básico de campaña con tests ordenados.""" campaign = campaign_with_tests["campaign"] assert campaign.name == "Test Campaign" assert campaign.status == "draft" cts = ( db.query(CampaignTest) .filter(CampaignTest.campaign_id == campaign.id) .order_by(CampaignTest.order_index) .all() ) assert len(cts) == 3 assert cts[0].order_index == 0 assert cts[1].order_index == 1 assert cts[2].order_index == 2 def test_campaign_progress_calculation(self, db, campaign_with_tests): """Progreso se calcula según estado de tests.""" campaign = campaign_with_tests["campaign"] tests = campaign_with_tests["tests"] # Initially all draft → 0% complete progress = get_campaign_progress(db, campaign.id) assert progress["total"] == 3 assert progress["completion_pct"] == 0.0 # Validate one test tests[0].state = TestState.validated db.commit() progress = get_campaign_progress(db, campaign.id) assert progress["completion_pct"] == pytest.approx(33.3, abs=0.1) def test_circular_dependency_prevention(self, db, campaign_with_tests): """Intentar crear dependencia circular en campaign_tests falla.""" from app.domain.exceptions import InvalidOperationError campaign = campaign_with_tests["campaign"] cts = ( db.query(CampaignTest) .filter(CampaignTest.campaign_id == campaign.id) .order_by(CampaignTest.order_index) .all() ) # Create A -> B dependency cts[1].depends_on = cts[0].id db.commit() # Try to create B -> A (circular) with pytest.raises(InvalidOperationError) as exc_info: validate_no_circular_dependency( db, campaign.id, cts[0].id, cts[1].id ) assert exc_info.value.code == "INVALID_OPERATION" def test_campaign_scheduling_next_run(self): """next_run_at se calcula correctamente para weekly/monthly/quarterly.""" base = datetime(2026, 1, 1) weekly = calculate_next_run(base, "weekly") assert weekly == datetime(2026, 1, 8) monthly = calculate_next_run(base, "monthly") assert monthly == datetime(2026, 1, 31) quarterly = calculate_next_run(base, "quarterly") assert quarterly == datetime(2026, 4, 1) def test_campaign_cloning(self, db, campaign_with_tests, admin_user): """Clonación de campaña recurrente crea tests nuevos con datos correctos.""" campaign = campaign_with_tests["campaign"] original_tests = campaign_with_tests["tests"] # Set up as recurring campaign.is_recurring = True campaign.recurrence_pattern = "monthly" campaign.next_run_at = datetime.utcnow() - timedelta(hours=1) # Due now db.commit() # Run the scheduler spawned = check_and_run_recurring_campaigns(db) assert spawned == 1 # Find child campaign child = ( db.query(Campaign) .filter(Campaign.parent_campaign_id == campaign.id) .first() ) assert child is not None assert "Run" in child.name assert child.status == "active" # Check child tests are fresh copies (new IDs, draft state) child_cts = ( db.query(CampaignTest) .filter(CampaignTest.campaign_id == child.id) .all() ) assert len(child_cts) == len(original_tests) child_test_ids = {ct.test_id for ct in child_cts} original_test_ids = {t.id for t in original_tests} assert child_test_ids.isdisjoint(original_test_ids) # All new IDs for ct in child_cts: test = db.query(Test).filter(Test.id == ct.test_id).first() assert test.state == TestState.draft # Check parent was updated db.refresh(campaign) assert campaign.last_run_at is not None assert campaign.next_run_at > datetime.utcnow() # ═══════════════════════════════════════════════════════════════════════ # Snapshot Tests # ═══════════════════════════════════════════════════════════════════════ class TestSnapshots: def test_create_snapshot(self, db, techniques, admin_user): """Snapshot captura estado actual correctamente.""" snapshot = create_snapshot(db, name="Test Snapshot", user_id=admin_user.id) assert snapshot is not None assert snapshot.name == "Test Snapshot" assert snapshot.total_techniques == len(techniques) assert snapshot.created_by == admin_user.id assert snapshot.organization_score >= 0 # Verify per-technique states states = ( db.query(SnapshotTechniqueState) .filter(SnapshotTechniqueState.snapshot_id == snapshot.id) .all() ) assert len(states) == len(techniques) mitre_ids = {s.mitre_id for s in states} assert "T1059" in mitre_ids assert "T1078" in mitre_ids assert "T1053" in mitre_ids def test_compare_snapshots_improvements(self, db, techniques, admin_user): """Comparación detecta técnicas que mejoraron.""" # Create snapshot A snap_a = create_snapshot(db, name="Before") # Improve a technique tech = db.query(Technique).filter(Technique.mitre_id == "T1053").first() tech.status_global = TechniqueStatus.validated db.commit() # Create snapshot B snap_b = create_snapshot(db, name="After") result = compare_snapshots(db, snap_a.id, snap_b.id) assert result["score_delta"] is not None assert result["summary"]["improved_count"] >= 0 assert isinstance(result["improved"], list) assert isinstance(result["worsened"], list) assert result["unchanged_count"] >= 0 def test_compare_snapshots_regressions(self, db, techniques, admin_user): """Comparación detecta técnicas que empeoraron.""" # Create snapshot A snap_a = create_snapshot(db, name="Before Regression") # Worsen a technique tech = db.query(Technique).filter(Technique.mitre_id == "T1059").first() tech.status_global = TechniqueStatus.not_covered db.commit() snap_b = create_snapshot(db, name="After Regression") result = compare_snapshots(db, snap_a.id, snap_b.id) assert result["summary"]["worsened_count"] >= 0 def test_snapshot_cleanup(self, db, techniques, admin_user): """Cleanup mantiene solo los últimos N snapshots.""" # Create 5 snapshots for i in range(5): create_snapshot(db, name=f"Snapshot {i}") total_before = db.query(CoverageSnapshot).count() assert total_before == 5 # Cleanup keeping only 3 deleted = cleanup_old_snapshots(db, keep_last=3) assert deleted == 2 total_after = db.query(CoverageSnapshot).count() assert total_after == 3 def test_snapshot_normalized_storage(self, db, techniques, admin_user): """Verificar que el almacenamiento normalizado funciona correctamente.""" snapshot = create_snapshot(db, name="Normalized Check") # Each technique should have exactly one SnapshotTechniqueState row for tech in techniques: states = ( db.query(SnapshotTechniqueState) .filter( SnapshotTechniqueState.snapshot_id == snapshot.id, SnapshotTechniqueState.technique_id == tech.id, ) .all() ) assert len(states) == 1 state = states[0] assert state.mitre_id == tech.mitre_id assert state.status is not None # ═══════════════════════════════════════════════════════════════════════ # Re-testing Tests # ═══════════════════════════════════════════════════════════════════════ class TestRetesting: def test_retest_created_on_remediation(self, db, techniques, admin_user): """Completar remediación crea retest automáticamente.""" test = Test( technique_id=techniques[0].id, name="Original Test", state=TestState.validated, remediation_status="completed", created_by=admin_user.id, ) db.add(test) db.commit() db.refresh(test) retest = handle_remediation_completed(db, test, admin_user) assert retest is not None assert retest.retest_of == test.id assert retest.retest_count == 1 assert retest.state == TestState.draft assert retest.technique_id == test.technique_id def test_retest_points_to_original(self, db, techniques, admin_user): """Retest de un retest apunta al test original, no al intermedio.""" original = Test( technique_id=techniques[0].id, name="Original", state=TestState.validated, remediation_status="completed", created_by=admin_user.id, retest_count=0, ) db.add(original) db.commit() db.refresh(original) # First retest retest1 = handle_remediation_completed(db, original, admin_user) assert retest1 is not None assert retest1.retest_of == original.id # Simulate completing remediation on retest1 retest1.state = TestState.validated retest1.remediation_status = "completed" db.commit() db.refresh(retest1) # Second retest — should point to ORIGINAL, not retest1 retest2 = handle_remediation_completed(db, retest1, admin_user) assert retest2 is not None assert retest2.retest_of == original.id # Points to original! assert retest2.retest_count == 2 def test_retest_max_limit(self, db, techniques, admin_user): """Al alcanzar MAX_RETEST_COUNT no se crea retest.""" from app.config import settings test = Test( technique_id=techniques[0].id, name="Max Retests Test", state=TestState.validated, remediation_status="completed", created_by=admin_user.id, retest_count=settings.MAX_RETEST_COUNT, # Already at max ) db.add(test) db.commit() db.refresh(test) result = handle_remediation_completed(db, test, admin_user) assert result is None # No retest created def test_retest_chain_query(self, db, techniques, admin_user): """Endpoint /tests/{id}/retest-chain retorna cadena completa.""" original = Test( technique_id=techniques[0].id, name="Chain Original", state=TestState.validated, remediation_status="completed", created_by=admin_user.id, ) db.add(original) db.commit() db.refresh(original) retest1 = handle_remediation_completed(db, original, admin_user) assert retest1 is not None # Complete retest1 and trigger another retest1.state = TestState.validated retest1.remediation_status = "completed" db.commit() db.refresh(retest1) retest2 = handle_remediation_completed(db, retest1, admin_user) assert retest2 is not None # Get chain chain = get_retest_chain(db, original.id) assert len(chain) == 3 # original + retest1 + retest2 assert chain[0].id == original.id assert chain[1].retest_count == 1 assert chain[2].retest_count == 2 def test_retest_has_correct_data(self, db, techniques, admin_user): """Retest tiene mismos datos base que el original.""" original = Test( technique_id=techniques[0].id, name="Data Check Original", description="Test description", platform="windows", procedure_text="Run cmd /c whoami", tool_used="cmd.exe", state=TestState.validated, remediation_status="completed", created_by=admin_user.id, ) db.add(original) db.commit() db.refresh(original) retest = handle_remediation_completed(db, original, admin_user) assert retest is not None # Verify base data is copied assert retest.technique_id == original.technique_id assert retest.description == original.description assert retest.platform == original.platform assert retest.procedure_text == original.procedure_text assert retest.tool_used == original.tool_used assert retest.created_by == original.created_by assert retest.state == TestState.draft