"""Campaign service — business logic for campaign management.

Handles circular dependency validation, campaign generation from threat
actors, and progress calculation.
"""

import logging
import uuid
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import or_
from sqlalchemy.orm import Session

from app.domain.exceptions import EntityNotFoundError, InvalidOperationError
from app.models.campaign import Campaign, CampaignTest
from app.models.enums import TechniqueStatus, TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.user import User

logger = logging.getLogger(__name__)

# Mapping from ATT&CK tactic identifiers (hyphenated) to kill chain phase
# names (underscored). Entries are listed in ATT&CK Enterprise matrix order;
# generate_campaign_from_threat_actor relies on this insertion order to
# sequence campaign tests.
TACTIC_TO_PHASE: dict[str, str] = {
    "reconnaissance": "reconnaissance",
    "resource-development": "resource_development",
    "initial-access": "initial_access",
    "execution": "execution",
    "persistence": "persistence",
    "privilege-escalation": "privilege_escalation",
    "defense-evasion": "defense_evasion",
    "credential-access": "credential_access",
    "discovery": "discovery",
    "lateral-movement": "lateral_movement",
    "collection": "collection",
    "command-and-control": "command_and_control",
    "exfiltration": "exfiltration",
    "impact": "impact",
}

# Kill chain position of each tactic, derived from TACTIC_TO_PHASE's order.
_TACTIC_RANK: dict[str, int] = {
    tactic: position for position, tactic in enumerate(TACTIC_TO_PHASE)
}


def _kill_chain_rank(tactic: Optional[str]) -> int:
    """Return the kill chain position of *tactic*; unknown or missing tactics sort last."""
    return _TACTIC_RANK.get(tactic, len(_TACTIC_RANK))


def validate_no_circular_dependency(
    db: Session,
    campaign_id: uuid.UUID,
    test_id: uuid.UUID,
    depends_on_id: uuid.UUID | None,
) -> None:
    """Walk the ``depends_on`` chain and verify no cycle is formed.

    Starting at *depends_on_id*, follow each ``CampaignTest.depends_on`` link
    until the chain ends. The new dependency would create a cycle if the walk
    reaches *test_id* or revisits any node.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Campaign owning the test. Not used by the check itself;
            the walk follows ``depends_on`` links regardless of campaign.
        test_id: ID of the ``CampaignTest`` that would gain the dependency
            (compared against ``CampaignTest`` IDs along the chain).
        depends_on_id: Proposed ``CampaignTest`` ID to depend on, or ``None``
            to clear the dependency (always valid).

    Raises:
        InvalidOperationError: If a circular dependency is detected.
    """
    if depends_on_id is None:
        return

    visited: set[uuid.UUID] = set()
    current: uuid.UUID | None = depends_on_id
    while current is not None:
        # Checking `visited` also catches pre-existing cycles that do not
        # involve test_id, so the walk always terminates.
        if current in visited or current == test_id:
            raise InvalidOperationError(
                "Circular dependency detected in campaign test chain"
            )
        visited.add(current)
        parent = db.query(CampaignTest).filter_by(id=current).first()
        # A dangling reference (no matching row) simply ends the chain.
        current = parent.depends_on if parent else None


def get_campaign_progress(db: Session, campaign_id: uuid.UUID) -> dict:
    """Calculate progress statistics for a campaign.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Campaign whose tests are counted.

    Returns:
        A dict with:

        - ``total``: number of ``CampaignTest`` rows in the campaign.
        - ``by_state``: count per test state value. Campaign tests with no
          linked test, or a test with no state, are counted as ``"unknown"``.
        - ``completion_pct``: share of tests in the ``"validated"`` state,
          as a percentage rounded to one decimal place.
    """
    campaign_tests = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign_id)
        .all()
    )
    if not campaign_tests:
        return {
            "total": 0,
            "by_state": {},
            "completion_pct": 0.0,
        }

    by_state: dict[str, int] = {}
    for ct in campaign_tests:
        # NOTE(review): `ct.test` is presumably a lazy-loaded relationship,
        # i.e. one extra query per row; consider eager loading if this is hot.
        test = ct.test
        state = test.state.value if test and test.state else "unknown"
        by_state[state] = by_state.get(state, 0) + 1

    total = len(campaign_tests)
    completed = by_state.get("validated", 0)
    completion_pct = round(completed / total * 100, 1) if total > 0 else 0.0
    return {
        "total": total,
        "by_state": by_state,
        "completion_pct": completion_pct,
    }


def generate_campaign_from_threat_actor(
    db: Session,
    actor_id: uuid.UUID,
    user: User,
    *,
    start_date: Optional[datetime] = None,
) -> Campaign:
    """Auto-generate a campaign from a threat actor's uncovered techniques.

    Steps:
        1. Get the actor's techniques whose global status is not validated.
        2. For each, find the best active template (highest severity, then name).
        3. Create a draft test from each template; techniques without a
           template are skipped.
        4. Create a draft campaign with the tests ordered by kill chain phase.
        5. Commit and return the campaign.

    Args:
        db: Active SQLAlchemy session. It is committed by this function.
        actor_id: ID of the threat actor to emulate.
        user: User recorded as creator of the campaign and its tests.
        start_date: Optional campaign start date.

    Returns:
        The committed and refreshed ``Campaign``.

    Raises:
        EntityNotFoundError: If no threat actor has ``actor_id``.
        InvalidOperationError: If the actor has no unvalidated techniques.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("ThreatActor", str(actor_id))

    rows = (
        db.query(Technique, ThreatActorTechnique)
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor_id)
        # SQL `<>` evaluates to NULL (row excluded) when status_global is
        # NULL. Match NULL explicitly so never-assessed techniques still
        # count as gaps; this is a no-op if the column is NOT NULL.
        .filter(
            or_(
                Technique.status_global.is_(None),
                Technique.status_global != TechniqueStatus.validated,
            )
        )
        .order_by(Technique.tactic, Technique.mitre_id)
        .all()
    )
    if not rows:
        raise InvalidOperationError(
            f"No uncovered techniques found for {actor.name}"
        )

    # ORDER BY tactic sorts tactics alphabetically, not in kill chain order.
    # A stable re-sort by kill chain position keeps the SQL (tactic, mitre_id)
    # order within each tactic; unrecognised tactics go last.
    gap_techniques = sorted(rows, key=lambda row: _kill_chain_rank(row[0].tactic))

    campaign = Campaign(
        name=f"APT Emulation: {actor.name}",
        description=f"Auto-generated campaign to test coverage against {actor.name} "
        f"({actor.mitre_id or 'unknown'}). "
        f"Covers {len(gap_techniques)} uncovered technique(s).",
        type="apt_emulation",
        threat_actor_id=actor_id,
        status="draft",
        created_by=user.id,
        tags=[actor.name, "auto-generated"],
        start_date=start_date,
    )
    db.add(campaign)
    db.flush()  # Assigns campaign.id, needed by the CampaignTest rows below.

    order_index = 0
    for tech, _at in gap_techniques:
        # NOTE(review): severity.desc() ranks critical > high > medium > low
        # only if the column's sort order encodes that ranking (e.g. an enum
        # declared low -> critical). For a plain string column, alphabetical
        # DESC yields medium > low > high > critical. Confirm the column type.
        template = (
            db.query(TestTemplate)
            .filter(
                TestTemplate.mitre_technique_id == tech.mitre_id,
                TestTemplate.is_active == True,  # noqa: E712
            )
            .order_by(
                TestTemplate.severity.desc(),
                TestTemplate.name,
            )
            .first()
        )
        if not template:
            continue  # Skip techniques without templates

        test = Test(
            technique_id=tech.id,
            name=f"[Campaign] {template.name}",
            description=template.description,
            platform=template.platform,
            procedure_text=template.attack_procedure,
            tool_used=template.tool_suggested,
            created_by=user.id,
            state=TestState.draft,
            # Naive UTC timestamp: the same value datetime.utcnow() produced
            # before its deprecation in Python 3.12.
            created_at=datetime.now(timezone.utc).replace(tzinfo=None),
        )
        db.add(test)
        db.flush()  # Assigns test.id for the CampaignTest link.

        phase = TACTIC_TO_PHASE.get(tech.tactic, None) if tech.tactic else None

        campaign_test = CampaignTest(
            campaign_id=campaign.id,
            test_id=test.id,
            order_index=order_index,
            phase=phase,
        )
        db.add(campaign_test)
        # order_index counts only the tests actually created.
        order_index += 1

    db.commit()
    db.refresh(campaign)

    logger.info(
        "Generated campaign '%s' with %d tests for actor %s",
        campaign.name,
        order_index,
        actor.name,
    )
    if order_index == 0:
        # The campaign is still committed (existing behavior); make the empty
        # result visible instead of silent.
        logger.warning(
            "Campaign '%s' has no tests: no active template matched any of the "
            "%d uncovered technique(s) for actor %s",
            campaign.name,
            len(gap_techniques),
            actor.name,
        )
    return campaign