feat(phase-32): add automated tests V3 for data sources, scoring, campaigns and snapshots (T-235 to T-237)

This commit is contained in:
2026-02-10 09:07:43 +01:00
parent 02034d60f0
commit 35983de67e
11 changed files with 1676 additions and 12 deletions

View File

@@ -0,0 +1,464 @@
"""Tests for campaigns, snapshots, and re-testing — T-237.
Uses the in-memory SQLite test database from conftest.py.
"""
import uuid
from datetime import datetime, timedelta
import pytest
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.campaign import Campaign, CampaignTest
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TestState, TestResult, TechniqueStatus
from app.services.campaign_service import (
validate_no_circular_dependency,
get_campaign_progress,
)
from app.services.campaign_scheduler_service import (
calculate_next_run,
check_and_run_recurring_campaigns,
)
from app.services.snapshot_service import (
create_snapshot,
compare_snapshots,
cleanup_old_snapshots,
)
from app.services.test_workflow_service import (
handle_remediation_completed,
get_retest_chain,
)
# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def techniques(db):
"""Create a set of techniques for testing."""
techs = []
for mid, name, status in [
("T1059", "Command Line", TechniqueStatus.validated),
("T1078", "Valid Accounts", TechniqueStatus.partial),
("T1053", "Scheduled Tasks", TechniqueStatus.not_covered),
]:
tech = Technique(
mitre_id=mid,
name=name,
tactic="execution",
platforms=["windows"],
status_global=status,
)
db.add(tech)
techs.append(tech)
db.commit()
for t in techs:
db.refresh(t)
return techs
@pytest.fixture
def campaign_with_tests(db, techniques, admin_user):
"""Create a campaign with ordered tests."""
campaign = Campaign(
name="Test Campaign",
type="custom",
status="draft",
created_by=admin_user.id,
)
db.add(campaign)
db.flush()
tests = []
for i, tech in enumerate(techniques):
test = Test(
technique_id=tech.id,
name=f"Test for {tech.mitre_id}",
state=TestState.draft,
created_by=admin_user.id,
)
db.add(test)
db.flush()
tests.append(test)
ct = CampaignTest(
campaign_id=campaign.id,
test_id=test.id,
order_index=i,
phase="execution",
)
db.add(ct)
db.commit()
db.refresh(campaign)
return {"campaign": campaign, "tests": tests}
# ═══════════════════════════════════════════════════════════════════════
# Campaign Tests
# ═══════════════════════════════════════════════════════════════════════
class TestCampaigns:
def test_create_campaign_with_tests(self, db, campaign_with_tests):
"""CRUD básico de campaña con tests ordenados."""
campaign = campaign_with_tests["campaign"]
assert campaign.name == "Test Campaign"
assert campaign.status == "draft"
cts = (
db.query(CampaignTest)
.filter(CampaignTest.campaign_id == campaign.id)
.order_by(CampaignTest.order_index)
.all()
)
assert len(cts) == 3
assert cts[0].order_index == 0
assert cts[1].order_index == 1
assert cts[2].order_index == 2
def test_campaign_progress_calculation(self, db, campaign_with_tests):
"""Progreso se calcula según estado de tests."""
campaign = campaign_with_tests["campaign"]
tests = campaign_with_tests["tests"]
# Initially all draft → 0% complete
progress = get_campaign_progress(db, campaign.id)
assert progress["total"] == 3
assert progress["completion_pct"] == 0.0
# Validate one test
tests[0].state = TestState.validated
db.commit()
progress = get_campaign_progress(db, campaign.id)
assert progress["completion_pct"] == pytest.approx(33.3, abs=0.1)
def test_circular_dependency_prevention(self, db, campaign_with_tests):
"""Intentar crear dependencia circular en campaign_tests falla."""
from fastapi import HTTPException
campaign = campaign_with_tests["campaign"]
cts = (
db.query(CampaignTest)
.filter(CampaignTest.campaign_id == campaign.id)
.order_by(CampaignTest.order_index)
.all()
)
# Create A -> B dependency
cts[1].depends_on = cts[0].id
db.commit()
# Try to create B -> A (circular)
with pytest.raises(HTTPException) as exc_info:
validate_no_circular_dependency(
db, campaign.id, cts[0].id, cts[1].id
)
assert exc_info.value.status_code == 400
def test_campaign_scheduling_next_run(self):
"""next_run_at se calcula correctamente para weekly/monthly/quarterly."""
base = datetime(2026, 1, 1)
weekly = calculate_next_run(base, "weekly")
assert weekly == datetime(2026, 1, 8)
monthly = calculate_next_run(base, "monthly")
assert monthly == datetime(2026, 1, 31)
quarterly = calculate_next_run(base, "quarterly")
assert quarterly == datetime(2026, 4, 1)
def test_campaign_cloning(self, db, campaign_with_tests, admin_user):
"""Clonación de campaña recurrente crea tests nuevos con datos correctos."""
campaign = campaign_with_tests["campaign"]
original_tests = campaign_with_tests["tests"]
# Set up as recurring
campaign.is_recurring = True
campaign.recurrence_pattern = "monthly"
campaign.next_run_at = datetime.utcnow() - timedelta(hours=1) # Due now
db.commit()
# Run the scheduler
spawned = check_and_run_recurring_campaigns(db)
assert spawned == 1
# Find child campaign
child = (
db.query(Campaign)
.filter(Campaign.parent_campaign_id == campaign.id)
.first()
)
assert child is not None
assert "Run" in child.name
assert child.status == "active"
# Check child tests are fresh copies (new IDs, draft state)
child_cts = (
db.query(CampaignTest)
.filter(CampaignTest.campaign_id == child.id)
.all()
)
assert len(child_cts) == len(original_tests)
child_test_ids = {ct.test_id for ct in child_cts}
original_test_ids = {t.id for t in original_tests}
assert child_test_ids.isdisjoint(original_test_ids) # All new IDs
for ct in child_cts:
test = db.query(Test).filter(Test.id == ct.test_id).first()
assert test.state == TestState.draft
# Check parent was updated
db.refresh(campaign)
assert campaign.last_run_at is not None
assert campaign.next_run_at > datetime.utcnow()
# ═══════════════════════════════════════════════════════════════════════
# Snapshot Tests
# ═══════════════════════════════════════════════════════════════════════
class TestSnapshots:
def test_create_snapshot(self, db, techniques, admin_user):
"""Snapshot captura estado actual correctamente."""
snapshot = create_snapshot(db, name="Test Snapshot", user_id=admin_user.id)
assert snapshot is not None
assert snapshot.name == "Test Snapshot"
assert snapshot.total_techniques == len(techniques)
assert snapshot.created_by == admin_user.id
assert snapshot.organization_score >= 0
# Verify per-technique states
states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snapshot.id)
.all()
)
assert len(states) == len(techniques)
mitre_ids = {s.mitre_id for s in states}
assert "T1059" in mitre_ids
assert "T1078" in mitre_ids
assert "T1053" in mitre_ids
def test_compare_snapshots_improvements(self, db, techniques, admin_user):
"""Comparación detecta técnicas que mejoraron."""
# Create snapshot A
snap_a = create_snapshot(db, name="Before")
# Improve a technique
tech = db.query(Technique).filter(Technique.mitre_id == "T1053").first()
tech.status_global = TechniqueStatus.validated
db.commit()
# Create snapshot B
snap_b = create_snapshot(db, name="After")
result = compare_snapshots(db, snap_a.id, snap_b.id)
assert result["score_delta"] is not None
assert result["summary"]["improved_count"] >= 0
assert isinstance(result["improved"], list)
assert isinstance(result["worsened"], list)
assert result["unchanged_count"] >= 0
def test_compare_snapshots_regressions(self, db, techniques, admin_user):
"""Comparación detecta técnicas que empeoraron."""
# Create snapshot A
snap_a = create_snapshot(db, name="Before Regression")
# Worsen a technique
tech = db.query(Technique).filter(Technique.mitre_id == "T1059").first()
tech.status_global = TechniqueStatus.not_covered
db.commit()
snap_b = create_snapshot(db, name="After Regression")
result = compare_snapshots(db, snap_a.id, snap_b.id)
assert result["summary"]["worsened_count"] >= 0
def test_snapshot_cleanup(self, db, techniques, admin_user):
"""Cleanup mantiene solo los últimos N snapshots."""
# Create 5 snapshots
for i in range(5):
create_snapshot(db, name=f"Snapshot {i}")
total_before = db.query(CoverageSnapshot).count()
assert total_before == 5
# Cleanup keeping only 3
deleted = cleanup_old_snapshots(db, keep_last=3)
assert deleted == 2
total_after = db.query(CoverageSnapshot).count()
assert total_after == 3
def test_snapshot_normalized_storage(self, db, techniques, admin_user):
"""Verificar que el almacenamiento normalizado funciona correctamente."""
snapshot = create_snapshot(db, name="Normalized Check")
# Each technique should have exactly one SnapshotTechniqueState row
for tech in techniques:
states = (
db.query(SnapshotTechniqueState)
.filter(
SnapshotTechniqueState.snapshot_id == snapshot.id,
SnapshotTechniqueState.technique_id == tech.id,
)
.all()
)
assert len(states) == 1
state = states[0]
assert state.mitre_id == tech.mitre_id
assert state.status is not None
# ═══════════════════════════════════════════════════════════════════════
# Re-testing Tests
# ═══════════════════════════════════════════════════════════════════════
class TestRetesting:
def test_retest_created_on_remediation(self, db, techniques, admin_user):
"""Completar remediación crea retest automáticamente."""
test = Test(
technique_id=techniques[0].id,
name="Original Test",
state=TestState.validated,
remediation_status="completed",
created_by=admin_user.id,
)
db.add(test)
db.commit()
db.refresh(test)
retest = handle_remediation_completed(db, test, admin_user)
assert retest is not None
assert retest.retest_of == test.id
assert retest.retest_count == 1
assert retest.state == TestState.draft
assert retest.technique_id == test.technique_id
def test_retest_points_to_original(self, db, techniques, admin_user):
"""Retest de un retest apunta al test original, no al intermedio."""
original = Test(
technique_id=techniques[0].id,
name="Original",
state=TestState.validated,
remediation_status="completed",
created_by=admin_user.id,
retest_count=0,
)
db.add(original)
db.commit()
db.refresh(original)
# First retest
retest1 = handle_remediation_completed(db, original, admin_user)
assert retest1 is not None
assert retest1.retest_of == original.id
# Simulate completing remediation on retest1
retest1.state = TestState.validated
retest1.remediation_status = "completed"
db.commit()
db.refresh(retest1)
# Second retest — should point to ORIGINAL, not retest1
retest2 = handle_remediation_completed(db, retest1, admin_user)
assert retest2 is not None
assert retest2.retest_of == original.id # Points to original!
assert retest2.retest_count == 2
def test_retest_max_limit(self, db, techniques, admin_user):
"""Al alcanzar MAX_RETEST_COUNT no se crea retest."""
from app.config import settings
test = Test(
technique_id=techniques[0].id,
name="Max Retests Test",
state=TestState.validated,
remediation_status="completed",
created_by=admin_user.id,
retest_count=settings.MAX_RETEST_COUNT, # Already at max
)
db.add(test)
db.commit()
db.refresh(test)
result = handle_remediation_completed(db, test, admin_user)
assert result is None # No retest created
def test_retest_chain_query(self, db, techniques, admin_user):
"""Endpoint /tests/{id}/retest-chain retorna cadena completa."""
original = Test(
technique_id=techniques[0].id,
name="Chain Original",
state=TestState.validated,
remediation_status="completed",
created_by=admin_user.id,
)
db.add(original)
db.commit()
db.refresh(original)
retest1 = handle_remediation_completed(db, original, admin_user)
assert retest1 is not None
# Complete retest1 and trigger another
retest1.state = TestState.validated
retest1.remediation_status = "completed"
db.commit()
db.refresh(retest1)
retest2 = handle_remediation_completed(db, retest1, admin_user)
assert retest2 is not None
# Get chain
chain = get_retest_chain(db, original.id)
assert len(chain) == 3 # original + retest1 + retest2
assert chain[0].id == original.id
assert chain[1].retest_count == 1
assert chain[2].retest_count == 2
def test_retest_has_correct_data(self, db, techniques, admin_user):
"""Retest tiene mismos datos base que el original."""
original = Test(
technique_id=techniques[0].id,
name="Data Check Original",
description="Test description",
platform="windows",
procedure_text="Run cmd /c whoami",
tool_used="cmd.exe",
state=TestState.validated,
remediation_status="completed",
created_by=admin_user.id,
)
db.add(original)
db.commit()
db.refresh(original)
retest = handle_remediation_completed(db, original, admin_user)
assert retest is not None
# Verify base data is copied
assert retest.technique_id == original.technique_id
assert retest.description == original.description
assert retest.platform == original.platform
assert retest.procedure_text == original.procedure_text
assert retest.tool_used == original.tool_used
assert retest.created_by == original.created_by
assert retest.state == TestState.draft