214 lines
6.5 KiB
Python
214 lines
6.5 KiB
Python
"""Campaign service — business logic for campaign management.

Handles circular dependency validation, campaign generation from
threat actors, and progress calculation.
"""
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
from fastapi import HTTPException
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
|
|
from app.models.test import Test
|
|
from app.models.test_template import TestTemplate
|
|
from app.models.technique import Technique
|
|
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
|
from app.models.enums import TechniqueStatus, TestState
|
|
from app.services.notification_service import create_notification
|
|
from app.models.user import User
|
|
|
|
logger = logging.getLogger(__name__)


# ATT&CK tactic slug -> kill chain phase key. Each phase key is simply the
# tactic slug with hyphens turned into underscores. Entries are listed in
# ATT&CK Enterprise tactic order (reconnaissance first, impact last), and the
# dict preserves that insertion order.
TACTIC_TO_PHASE: dict[str, str] = {
    slug: slug.replace("-", "_")
    for slug in (
        "reconnaissance",
        "resource-development",
        "initial-access",
        "execution",
        "persistence",
        "privilege-escalation",
        "defense-evasion",
        "credential-access",
        "discovery",
        "lateral-movement",
        "collection",
        "command-and-control",
        "exfiltration",
        "impact",
    )
}
|
|
|
|
|
|
def validate_no_circular_dependency(
    db: Session,
    campaign_id: uuid.UUID,
    test_id: uuid.UUID,
    depends_on_id: uuid.UUID | None,
) -> None:
    """Reject a ``depends_on`` link that would close a dependency loop.

    Starting from *depends_on_id*, follow each ``CampaignTest.depends_on``
    pointer upward. Landing on *test_id*, or on any node already seen during
    the walk, means the chain loops back on itself.

    Args:
        db: SQLAlchemy session used to load each ancestor ``CampaignTest``.
        campaign_id: Currently not used by the walk.
        test_id: ID of the campaign test whose dependency is being set; it is
            compared against the ``CampaignTest`` ids met along the chain.
        depends_on_id: ID of the proposed parent, or ``None`` for no parent
            (in which case there is nothing to check).

    Raises:
        HTTPException: status 400 when a cycle is detected.
    """
    seen: set[uuid.UUID] = set()
    node = depends_on_id

    # A missing ancestor row, or a row without a parent, ends the walk.
    while node is not None:
        if node == test_id or node in seen:
            raise HTTPException(
                status_code=400,
                detail="Circular dependency detected in campaign test chain",
            )
        seen.add(node)
        ancestor = db.query(CampaignTest).filter_by(id=node).first()
        node = None if ancestor is None else ancestor.depends_on
|
|
|
|
|
|
def get_campaign_progress(db: Session, campaign_id: uuid.UUID) -> dict:
    """Calculate progress statistics for a campaign.

    Counts the campaign's tests by state and derives a completion percentage
    from how many of them are in the ``"validated"`` state.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Campaign whose tests are summarised.

    Returns:
        A dict with:
          * ``total`` -- number of campaign tests in the campaign;
          * ``by_state`` -- mapping of test state value -> count, using
            ``"unknown"`` when a campaign test has no linked test or the test
            has no state;
          * ``completion_pct`` -- validated / total * 100, rounded to one
            decimal place (``0.0`` for an empty campaign).
    """
    # Fetch only each linked test's state, in a single query. The LEFT OUTER
    # JOIN along the CampaignTest.test relationship keeps campaign tests whose
    # test row is missing (their state comes back as None), matching the old
    # per-row `ct.test` access without one lazy load per campaign test.
    states = [
        state
        for (state,) in (
            db.query(Test.state)
            .select_from(CampaignTest)
            .outerjoin(CampaignTest.test)
            .filter(CampaignTest.campaign_id == campaign_id)
            .all()
        )
    ]

    if not states:
        return {
            "total": 0,
            "by_state": {},
            "completion_pct": 0.0,
        }

    by_state: dict[str, int] = {}
    for state in states:
        key = state.value if state else "unknown"
        by_state[key] = by_state.get(key, 0) + 1

    total = len(states)
    completed = by_state.get("validated", 0)
    # total > 0 is guaranteed by the early return above.
    completion_pct = round(completed / total * 100, 1)

    return {
        "total": total,
        "by_state": by_state,
        "completion_pct": completion_pct,
    }
|
|
|
|
|
|
# Template severities from most to least severe; a lower rank wins when
# choosing a template. Severities not listed here rank after "low".
_SEVERITY_RANK: dict[str, int] = {"critical": 0, "high": 1, "medium": 2, "low": 3}


def _severity_rank(severity: object) -> int:
    """Return the sort rank of a template severity (lower = more severe).

    Accepts plain strings or Enum members (their ``.value`` is used) and
    compares case-insensitively. Missing or unrecognised values rank last.
    """
    value = getattr(severity, "value", severity)
    if not isinstance(value, str):
        return len(_SEVERITY_RANK)
    return _SEVERITY_RANK.get(value.strip().lower(), len(_SEVERITY_RANK))


def _kill_chain_rank(tactic: str | None) -> int:
    """Return the kill-chain position of an ATT&CK tactic slug.

    Positions follow the insertion order of TACTIC_TO_PHASE (reconnaissance
    first, impact last). Unmapped or missing tactics sort after every known
    phase.
    """
    tactics = list(TACTIC_TO_PHASE)
    return tactics.index(tactic) if tactic in TACTIC_TO_PHASE else len(tactics)


def _best_templates_by_mitre_id(
    db: Session,
    mitre_ids: list[str],
) -> dict[str, TestTemplate]:
    """Map each ATT&CK technique ID to its preferred active template.

    Loads every active template for *mitre_ids* in one query. For each
    technique it keeps the most severe template (critical > high > medium >
    low), breaking ties by template name.
    """
    if not mitre_ids:
        return {}

    candidates = (
        db.query(TestTemplate)
        .filter(
            TestTemplate.mitre_technique_id.in_(mitre_ids),
            TestTemplate.is_active == True,  # noqa: E712
        )
        .all()
    )

    best: dict[str, TestTemplate] = {}
    # Rank in Python rather than with ORDER BY severity DESC: a string (or
    # non-native enum) column sorts alphabetically, which would put "medium"
    # ahead of "critical".
    for template in sorted(
        candidates,
        key=lambda t: (_severity_rank(t.severity), t.name or ""),
    ):
        # After sorting, the first template seen per technique is the best.
        best.setdefault(template.mitre_technique_id, template)
    return best


def generate_campaign_from_threat_actor(
    db: Session,
    actor_id: uuid.UUID,
    user: User,
) -> Campaign:
    """Auto-generate a campaign from a threat actor's uncovered techniques.

    Steps:
        1. Get techniques of the actor that are NOT validated (a NULL
           ``status_global`` also counts as not validated).
        2. For each, pick the best active template (most severe first:
           critical > high > medium > low, ties broken by template name).
        3. Create a draft test from each template; techniques without a
           template are skipped.
        4. Create a campaign with tests ordered by kill chain phase.
        5. Commit and return the campaign.

    Args:
        db: Active SQLAlchemy session; this function commits it.
        actor_id: ID of the threat actor to emulate.
        user: User recorded as creator of the campaign and its tests.

    Returns:
        The committed and refreshed ``Campaign``.

    Raises:
        HTTPException: 404 if the threat actor does not exist; 400 if the
            actor has no uncovered techniques.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise HTTPException(status_code=404, detail="Threat actor not found")

    # Get unvalidated techniques for this actor.
    gap_rows = (
        db.query(Technique, ThreatActorTechnique)
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor_id)
        # IS DISTINCT FROM rather than !=: in SQL `NULL != 'validated'` is
        # NULL (not true), so `!=` would silently drop never-assessed
        # techniques, which are coverage gaps too.
        .filter(Technique.status_global.is_distinct_from(TechniqueStatus.validated))
        .order_by(Technique.tactic, Technique.mitre_id)
        .all()
    )

    if not gap_rows:
        raise HTTPException(
            status_code=400,
            detail=f"No uncovered techniques found for {actor.name}",
        )

    # ORDER BY tactic is alphabetical, not kill-chain order, so re-sort.
    # sorted() is stable: the (tactic, mitre_id) order from SQL is kept within
    # a phase, and unmapped tactics stay last in their original order.
    gap_techniques = sorted(
        (tech for tech, _link in gap_rows),
        key=lambda tech: _kill_chain_rank(tech.tactic),
    )

    # One query for all candidate templates instead of one per technique.
    templates = _best_templates_by_mitre_id(
        db,
        sorted({tech.mitre_id for tech in gap_techniques if tech.mitre_id}),
    )

    # Create the campaign.
    campaign = Campaign(
        name=f"APT Emulation: {actor.name}",
        description=f"Auto-generated campaign to test coverage against {actor.name} "
        f"({actor.mitre_id or 'unknown'}). "
        f"Covers {len(gap_techniques)} uncovered technique(s).",
        type="apt_emulation",
        threat_actor_id=actor_id,
        status="draft",
        created_by=user.id,
        tags=[actor.name, "auto-generated"],
    )
    db.add(campaign)
    db.flush()  # Populate campaign.id for the CampaignTest rows below.

    order_index = 0
    for tech in gap_techniques:
        template = templates.get(tech.mitre_id)
        if template is None:
            continue  # Skip techniques without templates.

        # Create a test from the template.
        test = Test(
            technique_id=tech.id,
            name=f"[Campaign] {template.name}",
            description=template.description,
            platform=template.platform,
            procedure_text=template.attack_procedure,
            tool_used=template.tool_suggested,
            created_by=user.id,
            state=TestState.draft,
        )
        db.add(test)
        db.flush()  # Populate test.id for the CampaignTest link.

        # Determine kill chain phase from the technique's tactic.
        phase = TACTIC_TO_PHASE.get(tech.tactic, None) if tech.tactic else None

        # Add to campaign; order_index follows the kill-chain sort above.
        db.add(
            CampaignTest(
                campaign_id=campaign.id,
                test_id=test.id,
                order_index=order_index,
                phase=phase,
            )
        )
        order_index += 1

    db.commit()
    db.refresh(campaign)

    # order_index now equals the number of tests actually created.
    logger.info(
        "Generated campaign '%s' with %d tests for actor %s",
        campaign.name,
        order_index,
        actor.name,
    )

    return campaign
|