feat(phase-26): add Campaign models, endpoints, service with kill chain timeline UI (T-217 to T-220)
This commit is contained in:
524
backend/app/routers/campaigns.py
Normal file
524
backend/app/routers/campaigns.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""Campaign endpoints — CRUD, test management, activation, and auto-generation.
|
||||
|
||||
Provides comprehensive campaign lifecycle management including
|
||||
test ordering, progress tracking, and threat actor integration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
|
||||
from app.models.test import Test
|
||||
from app.models.technique import Technique
|
||||
from app.models.threat_actor import ThreatActor
|
||||
from app.services.campaign_service import (
|
||||
validate_no_circular_dependency,
|
||||
get_campaign_progress,
|
||||
generate_campaign_from_threat_actor,
|
||||
)
|
||||
from app.services.notification_service import create_notification
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/campaigns", tags=["campaigns"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ─────────────────────────────────────────────────
|
||||
|
||||
class CampaignCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
type: str = "custom"
|
||||
threat_actor_id: Optional[str] = None
|
||||
target_platform: Optional[str] = None
|
||||
tags: Optional[list[str]] = Field(default_factory=list)
|
||||
scheduled_at: Optional[str] = None
|
||||
|
||||
class CampaignUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
target_platform: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
scheduled_at: Optional[str] = None
|
||||
|
||||
class AddTestPayload(BaseModel):
|
||||
test_id: str
|
||||
order_index: Optional[int] = None
|
||||
depends_on: Optional[str] = None
|
||||
phase: Optional[str] = None
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||
"""Serialize a campaign with its tests and progress."""
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == campaign.id)
|
||||
.order_by(CampaignTest.order_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
tests = []
|
||||
for ct in campaign_tests:
|
||||
test = ct.test
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
|
||||
|
||||
tests.append({
|
||||
"id": str(ct.id),
|
||||
"test_id": str(ct.test_id),
|
||||
"order_index": ct.order_index,
|
||||
"depends_on": str(ct.depends_on) if ct.depends_on else None,
|
||||
"phase": ct.phase,
|
||||
"test_name": test.name if test else None,
|
||||
"test_state": test.state.value if test and test.state else None,
|
||||
"test_result": test.result.value if test and test.result else None,
|
||||
"technique_mitre_id": technique.mitre_id if technique else None,
|
||||
"technique_name": technique.name if technique else None,
|
||||
"platform": test.platform if test else None,
|
||||
})
|
||||
|
||||
actor = campaign.threat_actor
|
||||
|
||||
return {
|
||||
"id": str(campaign.id),
|
||||
"name": campaign.name,
|
||||
"description": campaign.description,
|
||||
"type": campaign.type,
|
||||
"status": campaign.status,
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
||||
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
||||
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"tests": tests,
|
||||
"progress": progress,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||
"""Lightweight campaign serialization for list views."""
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
actor = campaign.threat_actor
|
||||
|
||||
return {
|
||||
"id": str(campaign.id),
|
||||
"name": campaign.name,
|
||||
"description": campaign.description,
|
||||
"type": campaign.type,
|
||||
"status": campaign.status,
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"test_count": progress["total"],
|
||||
"completion_pct": progress["completion_pct"],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns — List campaigns with filters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("")
|
||||
def list_campaigns(
|
||||
type: Optional[str] = Query(None),
|
||||
status: Optional[str] = Query(None),
|
||||
threat_actor_id: Optional[str] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List campaigns with optional filters and pagination."""
|
||||
query = db.query(Campaign)
|
||||
|
||||
if type:
|
||||
query = query.filter(Campaign.type == type)
|
||||
if status:
|
||||
query = query.filter(Campaign.status == status)
|
||||
if threat_actor_id:
|
||||
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
||||
|
||||
total = query.count()
|
||||
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [_serialize_campaign_summary(db, c) for c in campaigns],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns — Create campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("", status_code=201)
|
||||
def create_campaign(
|
||||
payload: CampaignCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Create a new campaign."""
|
||||
campaign = Campaign(
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
type=payload.type,
|
||||
threat_actor_id=payload.threat_actor_id,
|
||||
target_platform=payload.target_platform,
|
||||
tags=payload.tags or [],
|
||||
created_by=current_user.id,
|
||||
scheduled_at=datetime.fromisoformat(payload.scheduled_at) if payload.scheduled_at else None,
|
||||
)
|
||||
db.add(campaign)
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="create_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"name": campaign.name, "type": campaign.type},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns/{id} — Detail with tests and progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{campaign_id}")
|
||||
def get_campaign(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed campaign info including tests and progress."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /campaigns/{id} — Update campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.patch("/{campaign_id}")
|
||||
def update_campaign(
|
||||
campaign_id: str,
|
||||
payload: CampaignUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Update a campaign. Only allowed in draft or active state."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise HTTPException(status_code=400, detail="Can only update draft or active campaigns")
|
||||
|
||||
# Check ownership or admin
|
||||
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Only the creator or admin can update this campaign")
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
if "scheduled_at" in update_data and update_data["scheduled_at"]:
|
||||
update_data["scheduled_at"] = datetime.fromisoformat(update_data["scheduled_at"])
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(campaign, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/{id}/tests — Add test to campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/{campaign_id}/tests")
|
||||
def add_test_to_campaign(
|
||||
campaign_id: str,
|
||||
payload: AddTestPayload,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Add a test to a campaign with optional ordering and dependency."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise HTTPException(status_code=400, detail="Can only add tests to draft or active campaigns")
|
||||
|
||||
test = db.query(Test).filter(Test.id == payload.test_id).first()
|
||||
if not test:
|
||||
raise HTTPException(status_code=404, detail="Test not found")
|
||||
|
||||
# Calculate order_index if not provided
|
||||
if payload.order_index is not None:
|
||||
order_index = payload.order_index
|
||||
else:
|
||||
max_order = (
|
||||
db.query(CampaignTest.order_index)
|
||||
.filter(CampaignTest.campaign_id == campaign_id)
|
||||
.order_by(CampaignTest.order_index.desc())
|
||||
.first()
|
||||
)
|
||||
order_index = (max_order[0] + 1) if max_order else 0
|
||||
|
||||
depends_on = uuid.UUID(payload.depends_on) if payload.depends_on else None
|
||||
|
||||
# Validate circular dependency
|
||||
ct_id = uuid.uuid4()
|
||||
if depends_on:
|
||||
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on)
|
||||
|
||||
campaign_test = CampaignTest(
|
||||
id=ct_id,
|
||||
campaign_id=campaign_id,
|
||||
test_id=payload.test_id,
|
||||
order_index=order_index,
|
||||
depends_on=depends_on,
|
||||
phase=payload.phase,
|
||||
)
|
||||
db.add(campaign_test)
|
||||
db.commit()
|
||||
db.refresh(campaign_test)
|
||||
|
||||
return {
|
||||
"id": str(campaign_test.id),
|
||||
"campaign_id": str(campaign_test.campaign_id),
|
||||
"test_id": str(campaign_test.test_id),
|
||||
"order_index": campaign_test.order_index,
|
||||
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
|
||||
"phase": campaign_test.phase,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /campaigns/{id}/tests/{campaign_test_id} — Remove test from campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/{campaign_id}/tests/{campaign_test_id}")
|
||||
def remove_test_from_campaign(
|
||||
campaign_id: str,
|
||||
campaign_test_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Remove a test from a campaign."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise HTTPException(status_code=400, detail="Can only modify draft or active campaigns")
|
||||
|
||||
ct = (
|
||||
db.query(CampaignTest)
|
||||
.filter(
|
||||
CampaignTest.id == campaign_test_id,
|
||||
CampaignTest.campaign_id == campaign_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not ct:
|
||||
raise HTTPException(status_code=404, detail="Campaign test not found")
|
||||
|
||||
# Clear any references to this campaign_test
|
||||
dependents = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.depends_on == campaign_test_id)
|
||||
.all()
|
||||
)
|
||||
for dep in dependents:
|
||||
dep.depends_on = None
|
||||
|
||||
db.delete(ct)
|
||||
db.commit()
|
||||
|
||||
return {"detail": "Test removed from campaign"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/{id}/activate — Activate campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/{campaign_id}/activate")
|
||||
def activate_campaign(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Activate a campaign, moving it from draft to active."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status != "draft":
|
||||
raise HTTPException(status_code=400, detail="Only draft campaigns can be activated")
|
||||
|
||||
# Verify campaign has at least one test
|
||||
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
|
||||
if test_count == 0:
|
||||
raise HTTPException(status_code=400, detail="Campaign must have at least one test to activate")
|
||||
|
||||
campaign.status = "active"
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
# Notify relevant users
|
||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
||||
for user in red_techs:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=user.id,
|
||||
type="campaign_activated",
|
||||
title="Campaign activated",
|
||||
message=f'Campaign "{campaign.name}" has been activated.',
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="activate_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"name": campaign.name},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/{id}/complete — Mark campaign as completed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/{campaign_id}/complete")
|
||||
def complete_campaign(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
||||
):
|
||||
"""Mark a campaign as completed."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status != "active":
|
||||
raise HTTPException(status_code=400, detail="Only active campaigns can be completed")
|
||||
|
||||
campaign.status = "completed"
|
||||
campaign.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="complete_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"name": campaign.name},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns/{id}/progress — Campaign progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{campaign_id}/progress")
|
||||
def get_campaign_progress_endpoint(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get progress statistics for a campaign."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
||||
return {
|
||||
"campaign_id": str(campaign.id),
|
||||
"campaign_name": campaign.name,
|
||||
**progress,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/from-threat-actor/{actor_id} — Auto-generate campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/from-threat-actor/{actor_id}", status_code=201)
|
||||
def generate_campaign_from_actor(
|
||||
actor_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_tech", "admin")),
|
||||
):
|
||||
"""Auto-generate a campaign from a threat actor's uncovered techniques.
|
||||
|
||||
Creates tests from the best available templates and orders them
|
||||
by kill chain phase.
|
||||
"""
|
||||
campaign = generate_campaign_from_threat_actor(
|
||||
db,
|
||||
uuid.UUID(actor_id),
|
||||
current_user,
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="generate_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
Reference in New Issue
Block a user