"""Campaign service — business logic for campaign management. Handles circular dependency validation, campaign generation from threat actors, and progress calculation. """ import logging import uuid from datetime import datetime from sqlalchemy.orm import Session from app.domain.exceptions import EntityNotFoundError, InvalidOperationError from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES from app.models.test import Test from app.models.test_template import TestTemplate from app.models.technique import Technique from app.models.threat_actor import ThreatActor, ThreatActorTechnique from app.models.enums import TechniqueStatus, TestState from app.services.notification_service import create_notification from app.models.user import User logger = logging.getLogger(__name__) # Mapping from ATT&CK tactics to kill chain phases TACTIC_TO_PHASE: dict[str, str] = { "reconnaissance": "reconnaissance", "resource-development": "resource_development", "initial-access": "initial_access", "execution": "execution", "persistence": "persistence", "privilege-escalation": "privilege_escalation", "defense-evasion": "defense_evasion", "credential-access": "credential_access", "discovery": "discovery", "lateral-movement": "lateral_movement", "collection": "collection", "command-and-control": "command_and_control", "exfiltration": "exfiltration", "impact": "impact", } def validate_no_circular_dependency( db: Session, campaign_id: uuid.UUID, test_id: uuid.UUID, depends_on_id: uuid.UUID | None, ) -> None: """Walk the depends_on chain and verify no cycle is formed. Raises :class:`InvalidOperationError` if a circular dependency is detected. """ if depends_on_id is None: return visited: set[uuid.UUID] = set() current = depends_on_id while current is not None: if current in visited or current == test_id: raise InvalidOperationError( "Circular dependency detected in campaign test chain" ) visited.add(current) parent = db.query(CampaignTest).filter_by(id=current).first() current = parent.depends_on if parent else None def get_campaign_progress(db: Session, campaign_id: uuid.UUID) -> dict: """Calculate progress statistics for a campaign. Returns counts of tests by state, plus total and completion percentage. """ campaign_tests = ( db.query(CampaignTest) .filter(CampaignTest.campaign_id == campaign_id) .all() ) if not campaign_tests: return { "total": 0, "by_state": {}, "completion_pct": 0.0, } by_state: dict[str, int] = {} for ct in campaign_tests: test = ct.test state = test.state.value if test and test.state else "unknown" by_state[state] = by_state.get(state, 0) + 1 total = len(campaign_tests) completed = by_state.get("validated", 0) completion_pct = round(completed / total * 100, 1) if total > 0 else 0.0 return { "total": total, "by_state": by_state, "completion_pct": completion_pct, } def generate_campaign_from_threat_actor( db: Session, actor_id: uuid.UUID, user: User, ) -> Campaign: """Auto-generate a campaign from a threat actor's uncovered techniques. Steps: 1. Get techniques of the actor that are NOT validated 2. For each, find the best template (highest severity) 3. Create a test from each template 4. Create a campaign with tests ordered by kill chain phase 5. Return the campaign """ actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() if not actor: raise EntityNotFoundError("ThreatActor", str(actor_id)) # Get unvalidated techniques for this actor gap_techniques = ( db.query(Technique, ThreatActorTechnique) .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) .filter(ThreatActorTechnique.threat_actor_id == actor_id) .filter(Technique.status_global != TechniqueStatus.validated) .order_by(Technique.tactic, Technique.mitre_id) .all() ) if not gap_techniques: raise InvalidOperationError( f"No uncovered techniques found for {actor.name}" ) # Create the campaign campaign = Campaign( name=f"APT Emulation: {actor.name}", description=f"Auto-generated campaign to test coverage against {actor.name} " f"({actor.mitre_id or 'unknown'}). " f"Covers {len(gap_techniques)} uncovered technique(s).", type="apt_emulation", threat_actor_id=actor_id, status="draft", created_by=user.id, tags=[actor.name, "auto-generated"], ) db.add(campaign) db.flush() # Get campaign.id order_index = 0 for tech, _at in gap_techniques: # Find best template for this technique template = ( db.query(TestTemplate) .filter( TestTemplate.mitre_technique_id == tech.mitre_id, TestTemplate.is_active == True, # noqa: E712 ) .order_by( # Prioritize by severity: critical > high > medium > low TestTemplate.severity.desc(), TestTemplate.name, ) .first() ) if not template: continue # Skip techniques without templates # Create a test from the template test = Test( technique_id=tech.id, name=f"[Campaign] {template.name}", description=template.description, platform=template.platform, procedure_text=template.attack_procedure, tool_used=template.tool_suggested, created_by=user.id, state=TestState.draft, ) db.add(test) db.flush() # Get test.id # Determine kill chain phase from the technique's tactic phase = TACTIC_TO_PHASE.get(tech.tactic, None) if tech.tactic else None # Add to campaign campaign_test = CampaignTest( campaign_id=campaign.id, test_id=test.id, order_index=order_index, phase=phase, ) db.add(campaign_test) order_index += 1 db.commit() db.refresh(campaign) logger.info( "Generated campaign '%s' with %d tests for actor %s", campaign.name, order_index, actor.name, ) return campaign