Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
Foundational changes required before any new feature work can begin. - 0.1 Redis infrastructure: add redis:7-alpine to docker-compose dev and prod, REDIS_URL config, singleton client in app/infrastructure/redis_client.py - 0.2 Token blacklist on Redis SEC-001: replace in-memory dict with Redis SETEX keyed by jti, auto-expiring TTL derived from token exp - 0.3 Database indexes SR-006: Alembic migration b019 with 5 composite indexes for scoring, MTTD/MTTR, remediation, and notification queries - 0.4 Domain exceptions TD-003: app/domain/exceptions.py with typed errors, error_handler middleware mapping them to HTTP, services decoupled from FastAPI - 0.5 Fix silenced exceptions TD-007: replace 4 bare except-pass blocks in test_workflow_service with logger.warning with exc_info - 0.6 CI pipeline TD-009: GitHub Actions workflow with Postgres and Redis service containers, ruff lint, pytest; ruff.toml for baseline config
465 lines
16 KiB
Python
465 lines
16 KiB
Python
"""Tests for campaigns, snapshots, and re-testing — T-237.
|
|
|
|
Uses the in-memory SQLite test database from conftest.py.
|
|
"""
|
|
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
|
|
import pytest
|
|
|
|
from app.models.technique import Technique
|
|
from app.models.test import Test
|
|
from app.models.test_template import TestTemplate
|
|
from app.models.campaign import Campaign, CampaignTest
|
|
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
|
from app.models.enums import TestState, TestResult, TechniqueStatus
|
|
from app.services.campaign_service import (
|
|
validate_no_circular_dependency,
|
|
get_campaign_progress,
|
|
)
|
|
from app.services.campaign_scheduler_service import (
|
|
calculate_next_run,
|
|
check_and_run_recurring_campaigns,
|
|
)
|
|
from app.services.snapshot_service import (
|
|
create_snapshot,
|
|
compare_snapshots,
|
|
cleanup_old_snapshots,
|
|
)
|
|
from app.services.test_workflow_service import (
|
|
handle_remediation_completed,
|
|
get_retest_chain,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def techniques(db):
    """Seed three techniques that differ only in ID, name and global status.

    Every technique uses tactic ``"execution"`` and platform ``"windows"``;
    their ``status_global`` values are validated, partial and not_covered
    respectively.

    Returns:
        The persisted, refreshed ``Technique`` rows in seed order.
    """
    seed_rows = (
        ("T1059", "Command Line", TechniqueStatus.validated),
        ("T1078", "Valid Accounts", TechniqueStatus.partial),
        ("T1053", "Scheduled Tasks", TechniqueStatus.not_covered),
    )
    created = [
        Technique(
            mitre_id=mitre_id,
            name=label,
            tactic="execution",
            platforms=["windows"],
            status_global=coverage,
        )
        for mitre_id, label, coverage in seed_rows
    ]
    db.add_all(created)
    db.commit()

    # Reload server-side defaults / generated IDs after the commit.
    for row in created:
        db.refresh(row)
    return created
|
|
|
|
|
|
@pytest.fixture
def campaign_with_tests(db, techniques, admin_user):
    """Build a draft "custom" campaign containing one draft test per technique.

    Each test is linked to the campaign through a ``CampaignTest`` row in the
    ``"execution"`` phase, whose ``order_index`` equals the technique's
    position in the ``techniques`` fixture.

    Returns:
        A dict with ``"campaign"`` (the refreshed ``Campaign``) and
        ``"tests"`` (the ``Test`` rows, listed in ``order_index`` order).
    """
    owner_id = admin_user.id

    campaign = Campaign(
        name="Test Campaign",
        type="custom",
        status="draft",
        created_by=owner_id,
    )
    db.add(campaign)
    db.flush()  # assigns campaign.id so link rows can reference it

    created_tests = []
    for position, technique in enumerate(techniques):
        new_test = Test(
            technique_id=technique.id,
            name=f"Test for {technique.mitre_id}",
            state=TestState.draft,
            created_by=owner_id,
        )
        db.add(new_test)
        db.flush()  # assigns new_test.id for the CampaignTest link
        created_tests.append(new_test)

        db.add(
            CampaignTest(
                campaign_id=campaign.id,
                test_id=new_test.id,
                order_index=position,
                phase="execution",
            )
        )

    db.commit()
    db.refresh(campaign)
    return {"campaign": campaign, "tests": created_tests}
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
# Campaign Tests
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestCampaigns:
    """Campaign CRUD, progress, dependency validation, scheduling and cloning."""

    def test_create_campaign_with_tests(self, db, campaign_with_tests):
        """Basic campaign CRUD: the campaign persists with its ordered tests."""
        campaign = campaign_with_tests["campaign"]
        tests = campaign_with_tests["tests"]
        assert campaign.name == "Test Campaign"
        assert campaign.status == "draft"

        cts = (
            db.query(CampaignTest)
            .filter(CampaignTest.campaign_id == campaign.id)
            .order_by(CampaignTest.order_index)
            .all()
        )
        assert len(cts) == 3
        assert [ct.order_index for ct in cts] == [0, 1, 2]
        # The fixture links tests[i] at order_index i, so the ordered link
        # rows must reference the fixture's tests in that same order.
        assert [ct.test_id for ct in cts] == [t.id for t in tests]

    def test_campaign_progress_calculation(self, db, campaign_with_tests):
        """Progress is computed from the state of the campaign's tests."""
        campaign = campaign_with_tests["campaign"]
        tests = campaign_with_tests["tests"]

        # Initially all draft -> 0% complete
        progress = get_campaign_progress(db, campaign.id)
        assert progress["total"] == 3
        assert progress["completion_pct"] == 0.0

        # Validate one of the three tests -> one third complete
        tests[0].state = TestState.validated
        db.commit()

        progress = get_campaign_progress(db, campaign.id)
        assert progress["completion_pct"] == pytest.approx(33.3, abs=0.1)

    def test_circular_dependency_prevention(self, db, campaign_with_tests):
        """Creating a circular dependency between campaign_tests is rejected."""
        from app.domain.exceptions import InvalidOperationError

        campaign = campaign_with_tests["campaign"]
        cts = (
            db.query(CampaignTest)
            .filter(CampaignTest.campaign_id == campaign.id)
            .order_by(CampaignTest.order_index)
            .all()
        )

        # Existing edge: cts[1] depends on cts[0].
        cts[1].depends_on = cts[0].id
        db.commit()

        # Proposing the reverse edge (cts[0] depending on cts[1]) would close
        # the cycle. NOTE(review): assumes the argument order is
        # (db, campaign_id, campaign_test_id, depends_on_id) — confirm against
        # campaign_service.
        with pytest.raises(InvalidOperationError) as exc_info:
            validate_no_circular_dependency(
                db, campaign.id, cts[0].id, cts[1].id
            )
        assert exc_info.value.code == "INVALID_OPERATION"

    def test_campaign_scheduling_next_run(self):
        """next_run_at is computed correctly for weekly/monthly/quarterly."""
        base = datetime(2026, 1, 1)

        # These expectations are fixed offsets from ``base`` (7, 30 and 90
        # days), not calendar-month arithmetic.
        weekly = calculate_next_run(base, "weekly")
        assert weekly == datetime(2026, 1, 8)

        monthly = calculate_next_run(base, "monthly")
        assert monthly == datetime(2026, 1, 31)

        quarterly = calculate_next_run(base, "quarterly")
        assert quarterly == datetime(2026, 4, 1)

    def test_campaign_cloning(self, db, campaign_with_tests, admin_user):
        """Running a due recurring campaign clones it with fresh draft tests."""
        from datetime import timezone

        def utc_now():
            # Naive UTC "now": same value as the deprecated datetime.utcnow().
            return datetime.now(timezone.utc).replace(tzinfo=None)

        campaign = campaign_with_tests["campaign"]
        original_tests = campaign_with_tests["tests"]

        # Set up as recurring and already due.
        campaign.is_recurring = True
        campaign.recurrence_pattern = "monthly"
        campaign.next_run_at = utc_now() - timedelta(hours=1)
        db.commit()

        # Run the scheduler
        spawned = check_and_run_recurring_campaigns(db)
        assert spawned == 1

        # Find child campaign
        child = (
            db.query(Campaign)
            .filter(Campaign.parent_campaign_id == campaign.id)
            .first()
        )
        assert child is not None
        assert "Run" in child.name
        assert child.status == "active"

        # Child tests must be fresh copies (new IDs, draft state).
        child_cts = (
            db.query(CampaignTest)
            .filter(CampaignTest.campaign_id == child.id)
            .all()
        )
        assert len(child_cts) == len(original_tests)

        child_test_ids = {ct.test_id for ct in child_cts}
        original_test_ids = {t.id for t in original_tests}
        assert child_test_ids.isdisjoint(original_test_ids)  # All new IDs

        for ct in child_cts:
            test = db.query(Test).filter(Test.id == ct.test_id).first()
            assert test is not None, f"missing cloned test {ct.test_id}"
            assert test.state == TestState.draft

        # The parent's schedule must have advanced.
        db.refresh(campaign)
        assert campaign.last_run_at is not None
        assert campaign.next_run_at > utc_now()
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
# Snapshot Tests
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestSnapshots:
    """Coverage snapshot creation, comparison, cleanup and storage layout."""

    def test_create_snapshot(self, db, techniques, admin_user):
        """A snapshot correctly captures the current coverage state."""
        snapshot = create_snapshot(db, name="Test Snapshot", user_id=admin_user.id)

        assert snapshot is not None
        assert snapshot.name == "Test Snapshot"
        assert snapshot.total_techniques == len(techniques)
        assert snapshot.created_by == admin_user.id
        assert snapshot.organization_score >= 0

        # Verify per-technique states
        states = (
            db.query(SnapshotTechniqueState)
            .filter(SnapshotTechniqueState.snapshot_id == snapshot.id)
            .all()
        )
        assert len(states) == len(techniques)

        mitre_ids = {s.mitre_id for s in states}
        assert "T1059" in mitre_ids
        assert "T1078" in mitre_ids
        assert "T1053" in mitre_ids

    def test_compare_snapshots_improvements(self, db, techniques, admin_user):
        """Comparison detects techniques whose coverage improved."""
        snap_a = create_snapshot(db, name="Before")

        # Improve a technique: not_covered -> validated
        tech = db.query(Technique).filter(Technique.mitre_id == "T1053").first()
        tech.status_global = TechniqueStatus.validated
        db.commit()

        snap_b = create_snapshot(db, name="After")

        result = compare_snapshots(db, snap_a.id, snap_b.id)

        assert result["score_delta"] is not None
        assert isinstance(result["improved"], list)
        assert isinstance(result["worsened"], list)
        assert result["unchanged_count"] >= 0
        # T1053 was deliberately improved, so at least one improvement must
        # be reported (a ``>= 0`` check could never fail).
        assert result["summary"]["improved_count"] >= 1
        assert len(result["improved"]) >= 1

    def test_compare_snapshots_regressions(self, db, techniques, admin_user):
        """Comparison detects techniques whose coverage regressed."""
        snap_a = create_snapshot(db, name="Before Regression")

        # Worsen a technique: validated -> not_covered
        tech = db.query(Technique).filter(Technique.mitre_id == "T1059").first()
        tech.status_global = TechniqueStatus.not_covered
        db.commit()

        snap_b = create_snapshot(db, name="After Regression")

        result = compare_snapshots(db, snap_a.id, snap_b.id)
        assert isinstance(result["worsened"], list)
        # T1059 was deliberately downgraded, so at least one regression must
        # be reported (a ``>= 0`` check could never fail).
        assert result["summary"]["worsened_count"] >= 1
        assert len(result["worsened"]) >= 1

    def test_snapshot_cleanup(self, db, techniques, admin_user):
        """Cleanup keeps only the most recent N snapshots."""
        for i in range(5):
            create_snapshot(db, name=f"Snapshot {i}")

        # Relies on the per-test database starting without snapshots.
        total_before = db.query(CoverageSnapshot).count()
        assert total_before == 5

        # Cleanup keeping only 3
        deleted = cleanup_old_snapshots(db, keep_last=3)
        assert deleted == 2

        total_after = db.query(CoverageSnapshot).count()
        assert total_after == 3

    def test_snapshot_normalized_storage(self, db, techniques, admin_user):
        """Per-technique snapshot state is stored as one normalized row each."""
        snapshot = create_snapshot(db, name="Normalized Check")

        # Each technique should have exactly one SnapshotTechniqueState row
        for tech in techniques:
            states = (
                db.query(SnapshotTechniqueState)
                .filter(
                    SnapshotTechniqueState.snapshot_id == snapshot.id,
                    SnapshotTechniqueState.technique_id == tech.id,
                )
                .all()
            )
            assert len(states) == 1
            state = states[0]
            assert state.mitre_id == tech.mitre_id
            assert state.status is not None
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
# Re-testing Tests
|
|
# ═══════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestRetesting:
    """Automatic retest creation after remediation, limits and chain queries."""

    def test_retest_created_on_remediation(self, db, techniques, admin_user):
        """Completing remediation automatically creates a retest."""
        test = Test(
            technique_id=techniques[0].id,
            name="Original Test",
            state=TestState.validated,
            remediation_status="completed",
            created_by=admin_user.id,
        )
        db.add(test)
        db.commit()
        db.refresh(test)

        retest = handle_remediation_completed(db, test, admin_user)
        assert retest is not None
        assert retest.retest_of == test.id
        assert retest.retest_count == 1
        assert retest.state == TestState.draft
        assert retest.technique_id == test.technique_id

    def test_retest_points_to_original(self, db, techniques, admin_user):
        """A retest of a retest points at the original test, not the middle one."""
        original = Test(
            technique_id=techniques[0].id,
            name="Original",
            state=TestState.validated,
            remediation_status="completed",
            created_by=admin_user.id,
            retest_count=0,
        )
        db.add(original)
        db.commit()
        db.refresh(original)

        # First retest
        retest1 = handle_remediation_completed(db, original, admin_user)
        assert retest1 is not None
        assert retest1.retest_of == original.id
        assert retest1.retest_count == 1

        # Simulate completing remediation on retest1
        retest1.state = TestState.validated
        retest1.remediation_status = "completed"
        db.commit()
        db.refresh(retest1)

        # Second retest — must point to ORIGINAL, not retest1
        retest2 = handle_remediation_completed(db, retest1, admin_user)
        assert retest2 is not None
        assert retest2.retest_of == original.id
        assert retest2.retest_count == 2

    def test_retest_max_limit(self, db, techniques, admin_user):
        """No retest is created once MAX_RETEST_COUNT has been reached."""
        from app.config import settings

        test = Test(
            technique_id=techniques[0].id,
            name="Max Retests Test",
            state=TestState.validated,
            remediation_status="completed",
            created_by=admin_user.id,
            retest_count=settings.MAX_RETEST_COUNT,  # Already at max
        )
        db.add(test)
        db.commit()
        db.refresh(test)

        tests_before = db.query(Test).count()
        result = handle_remediation_completed(db, test, admin_user)
        assert result is None  # No retest returned...
        # ...and none persisted either.
        assert db.query(Test).count() == tests_before

    def test_retest_chain_query(self, db, techniques, admin_user):
        """get_retest_chain returns the full chain, original first.

        NOTE(review): presumably this service backs the
        /tests/{id}/retest-chain endpoint; this test calls the service
        directly, not over HTTP.
        """
        original = Test(
            technique_id=techniques[0].id,
            name="Chain Original",
            state=TestState.validated,
            remediation_status="completed",
            created_by=admin_user.id,
        )
        db.add(original)
        db.commit()
        db.refresh(original)

        retest1 = handle_remediation_completed(db, original, admin_user)
        assert retest1 is not None

        # Complete retest1 and trigger another
        retest1.state = TestState.validated
        retest1.remediation_status = "completed"
        db.commit()
        db.refresh(retest1)

        retest2 = handle_remediation_completed(db, retest1, admin_user)
        assert retest2 is not None

        chain = get_retest_chain(db, original.id)
        assert len(chain) == 3  # original + retest1 + retest2
        assert chain[0].id == original.id
        assert chain[1].retest_count == 1
        assert chain[2].retest_count == 2
        assert [t.id for t in chain] == [original.id, retest1.id, retest2.id]

    def test_retest_has_correct_data(self, db, techniques, admin_user):
        """A retest copies the original test's base data and starts as draft."""
        original = Test(
            technique_id=techniques[0].id,
            name="Data Check Original",
            description="Test description",
            platform="windows",
            procedure_text="Run cmd /c whoami",
            tool_used="cmd.exe",
            state=TestState.validated,
            remediation_status="completed",
            created_by=admin_user.id,
        )
        db.add(original)
        db.commit()
        db.refresh(original)

        retest = handle_remediation_completed(db, original, admin_user)
        assert retest is not None

        # Verify base data is copied
        assert retest.technique_id == original.technique_id
        assert retest.description == original.description
        assert retest.platform == original.platform
        assert retest.procedure_text == original.procedure_text
        assert retest.tool_used == original.tool_used
        assert retest.created_by == original.created_by
        assert retest.state == TestState.draft
|