feat(phase-26): add Campaign models, endpoints, service with kill chain timeline UI (T-217 to T-220)
This commit is contained in:
74
backend/alembic/versions/b013_add_campaigns_tables.py
Normal file
74
backend/alembic/versions/b013_add_campaigns_tables.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""add_campaigns_tables
|
||||
|
||||
Revision ID: b013campaigns
|
||||
Revises: b012detectionassoc
|
||||
Create Date: 2026-02-09 18:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b013campaigns'
|
||||
down_revision: Union[str, Sequence[str], None] = 'b012detectionassoc'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create campaigns and campaign_tests tables."""
|
||||
|
||||
# campaigns
|
||||
op.create_table(
|
||||
'campaigns',
|
||||
sa.Column('id', UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('type', sa.String(), nullable=False, server_default='custom'),
|
||||
sa.Column('threat_actor_id', UUID(as_uuid=True),
|
||||
sa.ForeignKey('threat_actors.id', ondelete='SET NULL'), nullable=True),
|
||||
sa.Column('status', sa.String(), nullable=False, server_default='draft'),
|
||||
sa.Column('created_by', UUID(as_uuid=True),
|
||||
sa.ForeignKey('users.id', ondelete='SET NULL'), nullable=True),
|
||||
sa.Column('scheduled_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('completed_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('target_platform', sa.String(), nullable=True),
|
||||
sa.Column('tags', JSONB(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.func.now()),
|
||||
)
|
||||
op.create_index('ix_campaigns_status', 'campaigns', ['status'])
|
||||
op.create_index('ix_campaigns_type', 'campaigns', ['type'])
|
||||
op.create_index('ix_campaigns_threat_actor', 'campaigns', ['threat_actor_id'])
|
||||
op.create_index('ix_campaigns_created_by', 'campaigns', ['created_by'])
|
||||
|
||||
# campaign_tests
|
||||
op.create_table(
|
||||
'campaign_tests',
|
||||
sa.Column('id', UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column('campaign_id', UUID(as_uuid=True),
|
||||
sa.ForeignKey('campaigns.id', ondelete='CASCADE'), nullable=False),
|
||||
sa.Column('test_id', UUID(as_uuid=True),
|
||||
sa.ForeignKey('tests.id', ondelete='CASCADE'), nullable=False),
|
||||
sa.Column('order_index', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('depends_on', UUID(as_uuid=True),
|
||||
sa.ForeignKey('campaign_tests.id', ondelete='SET NULL'), nullable=True),
|
||||
sa.Column('phase', sa.String(), nullable=True),
|
||||
)
|
||||
op.create_index('ix_campaign_tests_campaign', 'campaign_tests', ['campaign_id'])
|
||||
op.create_index('ix_campaign_tests_test', 'campaign_tests', ['test_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop campaign_tests and campaigns tables."""
|
||||
op.drop_index('ix_campaign_tests_test', table_name='campaign_tests')
|
||||
op.drop_index('ix_campaign_tests_campaign', table_name='campaign_tests')
|
||||
op.drop_table('campaign_tests')
|
||||
op.drop_index('ix_campaigns_created_by', table_name='campaigns')
|
||||
op.drop_index('ix_campaigns_threat_actor', table_name='campaigns')
|
||||
op.drop_index('ix_campaigns_type', table_name='campaigns')
|
||||
op.drop_index('ix_campaigns_status', table_name='campaigns')
|
||||
op.drop_table('campaigns')
|
||||
@@ -22,6 +22,7 @@ from app.routers import data_sources as data_sources_router
|
||||
from app.routers import threat_actors as threat_actors_router
|
||||
from app.routers import d3fend as d3fend_router
|
||||
from app.routers import detection_rules as detection_rules_router
|
||||
from app.routers import campaigns as campaigns_router
|
||||
from app.storage import ensure_bucket_exists
|
||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||
|
||||
@@ -68,6 +69,7 @@ app.include_router(data_sources_router.router, prefix="/api/v1")
|
||||
app.include_router(threat_actors_router.router, prefix="/api/v1")
|
||||
app.include_router(d3fend_router.router, prefix="/api/v1")
|
||||
app.include_router(detection_rules_router.router, prefix="/api/v1")
|
||||
app.include_router(campaigns_router.router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
|
||||
|
||||
__all__ = [
|
||||
@@ -21,5 +22,6 @@ __all__ = [
|
||||
"DetectionRule", "ThreatActor", "ThreatActorTechnique",
|
||||
"DefensiveTechnique", "DefensiveTechniqueMapping",
|
||||
"TestTemplateDetectionRule", "TestDetectionResult",
|
||||
"Campaign", "CampaignTest",
|
||||
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
||||
]
|
||||
|
||||
132
backend/app/models/campaign.py
Normal file
132
backend/app/models/campaign.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Campaign and CampaignTest models.
|
||||
|
||||
Campaigns group multiple tests into a kill chain sequence,
|
||||
enabling simulation of complete attack chains and APT emulations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, String, Text, Integer, DateTime,
|
||||
ForeignKey, Index,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Campaign(Base):
|
||||
"""
|
||||
A campaign groups multiple tests into a sequenced attack chain.
|
||||
|
||||
Types:
|
||||
- custom: manually created campaign
|
||||
- apt_emulation: generated from a threat actor profile
|
||||
- kill_chain: structured around kill chain phases
|
||||
- compliance: targeting specific compliance requirements
|
||||
|
||||
Status:
|
||||
- draft: being configured
|
||||
- active: tests are being executed
|
||||
- completed: all tests done
|
||||
- archived: historical record
|
||||
"""
|
||||
__tablename__ = "campaigns"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String, nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
type = Column(String, nullable=False, default="custom") # custom, apt_emulation, kill_chain, compliance
|
||||
threat_actor_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("threat_actors.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
status = Column(String, nullable=False, default="draft") # draft, active, completed, archived
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
scheduled_at = Column(DateTime, nullable=True)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
target_platform = Column(String, nullable=True)
|
||||
tags = Column(JSONB, nullable=True, default=[])
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
threat_actor = relationship("ThreatActor")
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
campaign_tests = relationship(
|
||||
"CampaignTest",
|
||||
back_populates="campaign",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="CampaignTest.order_index",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_campaigns_status', 'status'),
|
||||
Index('ix_campaigns_type', 'type'),
|
||||
Index('ix_campaigns_threat_actor', 'threat_actor_id'),
|
||||
Index('ix_campaigns_created_by', 'created_by'),
|
||||
)
|
||||
|
||||
|
||||
# Kill chain phases in order (for sorting and validation)
|
||||
KILL_CHAIN_PHASES = [
|
||||
"reconnaissance",
|
||||
"resource_development",
|
||||
"initial_access",
|
||||
"execution",
|
||||
"persistence",
|
||||
"privilege_escalation",
|
||||
"defense_evasion",
|
||||
"credential_access",
|
||||
"discovery",
|
||||
"lateral_movement",
|
||||
"collection",
|
||||
"command_and_control",
|
||||
"exfiltration",
|
||||
"impact",
|
||||
]
|
||||
|
||||
|
||||
class CampaignTest(Base):
|
||||
"""
|
||||
A test within a campaign, with ordering and dependency information.
|
||||
|
||||
``depends_on`` creates a self-referential chain (A -> B -> C).
|
||||
Circular dependencies are validated at the service layer.
|
||||
"""
|
||||
__tablename__ = "campaign_tests"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
campaign_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("campaigns.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
test_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("tests.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
order_index = Column(Integer, nullable=False, default=0)
|
||||
depends_on = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("campaign_tests.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
phase = Column(String, nullable=True) # kill chain phase
|
||||
|
||||
# Relationships
|
||||
campaign = relationship("Campaign", back_populates="campaign_tests")
|
||||
test = relationship("Test")
|
||||
dependency = relationship("CampaignTest", remote_side="CampaignTest.id")
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_campaign_tests_campaign', 'campaign_id'),
|
||||
Index('ix_campaign_tests_test', 'test_id'),
|
||||
)
|
||||
524
backend/app/routers/campaigns.py
Normal file
524
backend/app/routers/campaigns.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""Campaign endpoints — CRUD, test management, activation, and auto-generation.
|
||||
|
||||
Provides comprehensive campaign lifecycle management including
|
||||
test ordering, progress tracking, and threat actor integration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
|
||||
from app.models.test import Test
|
||||
from app.models.technique import Technique
|
||||
from app.models.threat_actor import ThreatActor
|
||||
from app.services.campaign_service import (
|
||||
validate_no_circular_dependency,
|
||||
get_campaign_progress,
|
||||
generate_campaign_from_threat_actor,
|
||||
)
|
||||
from app.services.notification_service import create_notification
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/campaigns", tags=["campaigns"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ─────────────────────────────────────────────────
|
||||
|
||||
class CampaignCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
type: str = "custom"
|
||||
threat_actor_id: Optional[str] = None
|
||||
target_platform: Optional[str] = None
|
||||
tags: Optional[list[str]] = Field(default_factory=list)
|
||||
scheduled_at: Optional[str] = None
|
||||
|
||||
class CampaignUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
target_platform: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
scheduled_at: Optional[str] = None
|
||||
|
||||
class AddTestPayload(BaseModel):
|
||||
test_id: str
|
||||
order_index: Optional[int] = None
|
||||
depends_on: Optional[str] = None
|
||||
phase: Optional[str] = None
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||
"""Serialize a campaign with its tests and progress."""
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == campaign.id)
|
||||
.order_by(CampaignTest.order_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
tests = []
|
||||
for ct in campaign_tests:
|
||||
test = ct.test
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
|
||||
|
||||
tests.append({
|
||||
"id": str(ct.id),
|
||||
"test_id": str(ct.test_id),
|
||||
"order_index": ct.order_index,
|
||||
"depends_on": str(ct.depends_on) if ct.depends_on else None,
|
||||
"phase": ct.phase,
|
||||
"test_name": test.name if test else None,
|
||||
"test_state": test.state.value if test and test.state else None,
|
||||
"test_result": test.result.value if test and test.result else None,
|
||||
"technique_mitre_id": technique.mitre_id if technique else None,
|
||||
"technique_name": technique.name if technique else None,
|
||||
"platform": test.platform if test else None,
|
||||
})
|
||||
|
||||
actor = campaign.threat_actor
|
||||
|
||||
return {
|
||||
"id": str(campaign.id),
|
||||
"name": campaign.name,
|
||||
"description": campaign.description,
|
||||
"type": campaign.type,
|
||||
"status": campaign.status,
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
||||
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
||||
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"tests": tests,
|
||||
"progress": progress,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||
"""Lightweight campaign serialization for list views."""
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
actor = campaign.threat_actor
|
||||
|
||||
return {
|
||||
"id": str(campaign.id),
|
||||
"name": campaign.name,
|
||||
"description": campaign.description,
|
||||
"type": campaign.type,
|
||||
"status": campaign.status,
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"test_count": progress["total"],
|
||||
"completion_pct": progress["completion_pct"],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns — List campaigns with filters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("")
|
||||
def list_campaigns(
|
||||
type: Optional[str] = Query(None),
|
||||
status: Optional[str] = Query(None),
|
||||
threat_actor_id: Optional[str] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List campaigns with optional filters and pagination."""
|
||||
query = db.query(Campaign)
|
||||
|
||||
if type:
|
||||
query = query.filter(Campaign.type == type)
|
||||
if status:
|
||||
query = query.filter(Campaign.status == status)
|
||||
if threat_actor_id:
|
||||
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
||||
|
||||
total = query.count()
|
||||
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [_serialize_campaign_summary(db, c) for c in campaigns],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns — Create campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("", status_code=201)
|
||||
def create_campaign(
|
||||
payload: CampaignCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Create a new campaign."""
|
||||
campaign = Campaign(
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
type=payload.type,
|
||||
threat_actor_id=payload.threat_actor_id,
|
||||
target_platform=payload.target_platform,
|
||||
tags=payload.tags or [],
|
||||
created_by=current_user.id,
|
||||
scheduled_at=datetime.fromisoformat(payload.scheduled_at) if payload.scheduled_at else None,
|
||||
)
|
||||
db.add(campaign)
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="create_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"name": campaign.name, "type": campaign.type},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns/{id} — Detail with tests and progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{campaign_id}")
|
||||
def get_campaign(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed campaign info including tests and progress."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /campaigns/{id} — Update campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.patch("/{campaign_id}")
|
||||
def update_campaign(
|
||||
campaign_id: str,
|
||||
payload: CampaignUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Update a campaign. Only allowed in draft or active state."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise HTTPException(status_code=400, detail="Can only update draft or active campaigns")
|
||||
|
||||
# Check ownership or admin
|
||||
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Only the creator or admin can update this campaign")
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
if "scheduled_at" in update_data and update_data["scheduled_at"]:
|
||||
update_data["scheduled_at"] = datetime.fromisoformat(update_data["scheduled_at"])
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(campaign, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/{id}/tests — Add test to campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/{campaign_id}/tests")
|
||||
def add_test_to_campaign(
|
||||
campaign_id: str,
|
||||
payload: AddTestPayload,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Add a test to a campaign with optional ordering and dependency."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise HTTPException(status_code=400, detail="Can only add tests to draft or active campaigns")
|
||||
|
||||
test = db.query(Test).filter(Test.id == payload.test_id).first()
|
||||
if not test:
|
||||
raise HTTPException(status_code=404, detail="Test not found")
|
||||
|
||||
# Calculate order_index if not provided
|
||||
if payload.order_index is not None:
|
||||
order_index = payload.order_index
|
||||
else:
|
||||
max_order = (
|
||||
db.query(CampaignTest.order_index)
|
||||
.filter(CampaignTest.campaign_id == campaign_id)
|
||||
.order_by(CampaignTest.order_index.desc())
|
||||
.first()
|
||||
)
|
||||
order_index = (max_order[0] + 1) if max_order else 0
|
||||
|
||||
depends_on = uuid.UUID(payload.depends_on) if payload.depends_on else None
|
||||
|
||||
# Validate circular dependency
|
||||
ct_id = uuid.uuid4()
|
||||
if depends_on:
|
||||
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on)
|
||||
|
||||
campaign_test = CampaignTest(
|
||||
id=ct_id,
|
||||
campaign_id=campaign_id,
|
||||
test_id=payload.test_id,
|
||||
order_index=order_index,
|
||||
depends_on=depends_on,
|
||||
phase=payload.phase,
|
||||
)
|
||||
db.add(campaign_test)
|
||||
db.commit()
|
||||
db.refresh(campaign_test)
|
||||
|
||||
return {
|
||||
"id": str(campaign_test.id),
|
||||
"campaign_id": str(campaign_test.campaign_id),
|
||||
"test_id": str(campaign_test.test_id),
|
||||
"order_index": campaign_test.order_index,
|
||||
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
|
||||
"phase": campaign_test.phase,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /campaigns/{id}/tests/{campaign_test_id} — Remove test from campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/{campaign_id}/tests/{campaign_test_id}")
|
||||
def remove_test_from_campaign(
|
||||
campaign_id: str,
|
||||
campaign_test_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Remove a test from a campaign."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise HTTPException(status_code=400, detail="Can only modify draft or active campaigns")
|
||||
|
||||
ct = (
|
||||
db.query(CampaignTest)
|
||||
.filter(
|
||||
CampaignTest.id == campaign_test_id,
|
||||
CampaignTest.campaign_id == campaign_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not ct:
|
||||
raise HTTPException(status_code=404, detail="Campaign test not found")
|
||||
|
||||
# Clear any references to this campaign_test
|
||||
dependents = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.depends_on == campaign_test_id)
|
||||
.all()
|
||||
)
|
||||
for dep in dependents:
|
||||
dep.depends_on = None
|
||||
|
||||
db.delete(ct)
|
||||
db.commit()
|
||||
|
||||
return {"detail": "Test removed from campaign"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/{id}/activate — Activate campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/{campaign_id}/activate")
|
||||
def activate_campaign(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Activate a campaign, moving it from draft to active."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status != "draft":
|
||||
raise HTTPException(status_code=400, detail="Only draft campaigns can be activated")
|
||||
|
||||
# Verify campaign has at least one test
|
||||
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
|
||||
if test_count == 0:
|
||||
raise HTTPException(status_code=400, detail="Campaign must have at least one test to activate")
|
||||
|
||||
campaign.status = "active"
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
# Notify relevant users
|
||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
||||
for user in red_techs:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=user.id,
|
||||
type="campaign_activated",
|
||||
title="Campaign activated",
|
||||
message=f'Campaign "{campaign.name}" has been activated.',
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="activate_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"name": campaign.name},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/{id}/complete — Mark campaign as completed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/{campaign_id}/complete")
|
||||
def complete_campaign(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
||||
):
|
||||
"""Mark a campaign as completed."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status != "active":
|
||||
raise HTTPException(status_code=400, detail="Only active campaigns can be completed")
|
||||
|
||||
campaign.status = "completed"
|
||||
campaign.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="complete_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"name": campaign.name},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns/{id}/progress — Campaign progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{campaign_id}/progress")
|
||||
def get_campaign_progress_endpoint(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get progress statistics for a campaign."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
||||
return {
|
||||
"campaign_id": str(campaign.id),
|
||||
"campaign_name": campaign.name,
|
||||
**progress,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/from-threat-actor/{actor_id} — Auto-generate campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/from-threat-actor/{actor_id}", status_code=201)
|
||||
def generate_campaign_from_actor(
|
||||
actor_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Auto-generate a campaign from a threat actor's uncovered techniques.
|
||||
|
||||
Creates tests from the best available templates and orders them
|
||||
by kill chain phase.
|
||||
"""
|
||||
campaign = generate_campaign_from_threat_actor(
|
||||
db,
|
||||
uuid.UUID(actor_id),
|
||||
current_user,
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="generate_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
213
backend/app/services/campaign_service.py
Normal file
213
backend/app/services/campaign_service.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Campaign service — business logic for campaign management.
|
||||
|
||||
Handles circular dependency validation, campaign generation from
|
||||
threat actors, and progress calculation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.technique import Technique
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.enums import TechniqueStatus, TestState
|
||||
from app.services.notification_service import create_notification
|
||||
from app.models.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mapping from ATT&CK tactics to kill chain phases
|
||||
TACTIC_TO_PHASE: dict[str, str] = {
|
||||
"reconnaissance": "reconnaissance",
|
||||
"resource-development": "resource_development",
|
||||
"initial-access": "initial_access",
|
||||
"execution": "execution",
|
||||
"persistence": "persistence",
|
||||
"privilege-escalation": "privilege_escalation",
|
||||
"defense-evasion": "defense_evasion",
|
||||
"credential-access": "credential_access",
|
||||
"discovery": "discovery",
|
||||
"lateral-movement": "lateral_movement",
|
||||
"collection": "collection",
|
||||
"command-and-control": "command_and_control",
|
||||
"exfiltration": "exfiltration",
|
||||
"impact": "impact",
|
||||
}
|
||||
|
||||
|
||||
def validate_no_circular_dependency(
|
||||
db: Session,
|
||||
campaign_id: uuid.UUID,
|
||||
test_id: uuid.UUID,
|
||||
depends_on_id: uuid.UUID | None,
|
||||
) -> None:
|
||||
"""Walk the depends_on chain and verify no cycle is formed.
|
||||
|
||||
Raises HTTPException(400) if a circular dependency is detected.
|
||||
"""
|
||||
if depends_on_id is None:
|
||||
return
|
||||
|
||||
visited: set[uuid.UUID] = set()
|
||||
current = depends_on_id
|
||||
|
||||
while current is not None:
|
||||
if current in visited or current == test_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Circular dependency detected in campaign test chain",
|
||||
)
|
||||
visited.add(current)
|
||||
parent = db.query(CampaignTest).filter_by(id=current).first()
|
||||
current = parent.depends_on if parent else None
|
||||
|
||||
|
||||
def get_campaign_progress(db: Session, campaign_id: uuid.UUID) -> dict:
|
||||
"""Calculate progress statistics for a campaign.
|
||||
|
||||
Returns counts of tests by state, plus total and completion percentage.
|
||||
"""
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == campaign_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not campaign_tests:
|
||||
return {
|
||||
"total": 0,
|
||||
"by_state": {},
|
||||
"completion_pct": 0.0,
|
||||
}
|
||||
|
||||
by_state: dict[str, int] = {}
|
||||
for ct in campaign_tests:
|
||||
test = ct.test
|
||||
state = test.state.value if test and test.state else "unknown"
|
||||
by_state[state] = by_state.get(state, 0) + 1
|
||||
|
||||
total = len(campaign_tests)
|
||||
completed = by_state.get("validated", 0)
|
||||
completion_pct = round(completed / total * 100, 1) if total > 0 else 0.0
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"by_state": by_state,
|
||||
"completion_pct": completion_pct,
|
||||
}
|
||||
|
||||
|
||||
def generate_campaign_from_threat_actor(
|
||||
db: Session,
|
||||
actor_id: uuid.UUID,
|
||||
user: User,
|
||||
) -> Campaign:
|
||||
"""Auto-generate a campaign from a threat actor's uncovered techniques.
|
||||
|
||||
Steps:
|
||||
1. Get techniques of the actor that are NOT validated
|
||||
2. For each, find the best template (highest severity)
|
||||
3. Create a test from each template
|
||||
4. Create a campaign with tests ordered by kill chain phase
|
||||
5. Return the campaign
|
||||
"""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
||||
|
||||
# Get unvalidated techniques for this actor
|
||||
gap_techniques = (
|
||||
db.query(Technique, ThreatActorTechnique)
|
||||
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor_id)
|
||||
.filter(Technique.status_global != TechniqueStatus.validated)
|
||||
.order_by(Technique.tactic, Technique.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not gap_techniques:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No uncovered techniques found for {actor.name}",
|
||||
)
|
||||
|
||||
# Create the campaign
|
||||
campaign = Campaign(
|
||||
name=f"APT Emulation: {actor.name}",
|
||||
description=f"Auto-generated campaign to test coverage against {actor.name} "
|
||||
f"({actor.mitre_id or 'unknown'}). "
|
||||
f"Covers {len(gap_techniques)} uncovered technique(s).",
|
||||
type="apt_emulation",
|
||||
threat_actor_id=actor_id,
|
||||
status="draft",
|
||||
created_by=user.id,
|
||||
tags=[actor.name, "auto-generated"],
|
||||
)
|
||||
db.add(campaign)
|
||||
db.flush() # Get campaign.id
|
||||
|
||||
order_index = 0
|
||||
|
||||
for tech, _at in gap_techniques:
|
||||
# Find best template for this technique
|
||||
template = (
|
||||
db.query(TestTemplate)
|
||||
.filter(
|
||||
TestTemplate.mitre_technique_id == tech.mitre_id,
|
||||
TestTemplate.is_active == True, # noqa: E712
|
||||
)
|
||||
.order_by(
|
||||
# Prioritize by severity: critical > high > medium > low
|
||||
TestTemplate.severity.desc(),
|
||||
TestTemplate.name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not template:
|
||||
continue # Skip techniques without templates
|
||||
|
||||
# Create a test from the template
|
||||
test = Test(
|
||||
technique_id=tech.id,
|
||||
name=f"[Campaign] {template.name}",
|
||||
description=template.description,
|
||||
platform=template.platform,
|
||||
procedure_text=template.attack_procedure,
|
||||
tool_used=template.tool_suggested,
|
||||
created_by=user.id,
|
||||
state=TestState.draft,
|
||||
)
|
||||
db.add(test)
|
||||
db.flush() # Get test.id
|
||||
|
||||
# Determine kill chain phase from the technique's tactic
|
||||
phase = TACTIC_TO_PHASE.get(tech.tactic, None) if tech.tactic else None
|
||||
|
||||
# Add to campaign
|
||||
campaign_test = CampaignTest(
|
||||
campaign_id=campaign.id,
|
||||
test_id=test.id,
|
||||
order_index=order_index,
|
||||
phase=phase,
|
||||
)
|
||||
db.add(campaign_test)
|
||||
order_index += 1
|
||||
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
logger.info(
|
||||
"Generated campaign '%s' with %d tests for actor %s",
|
||||
campaign.name,
|
||||
order_index,
|
||||
actor.name,
|
||||
)
|
||||
|
||||
return campaign
|
||||
Reference in New Issue
Block a user