"""Campaign endpoints — CRUD, test management, activation, and auto-generation.

Provides comprehensive campaign lifecycle management including test ordering,
progress tracking, and threat actor integration.
"""

import logging
import uuid
from typing import Optional
from datetime import datetime

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import func
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field

from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.user import User
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
from app.models.test import Test
from app.models.technique import Technique
from app.models.threat_actor import ThreatActor
from app.services.campaign_service import (
    validate_no_circular_dependency,
    get_campaign_progress,
    generate_campaign_from_threat_actor,
)
from app.services.notification_service import create_notification
from app.services.audit_service import log_action

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/campaigns", tags=["campaigns"])


# ── Pydantic schemas ─────────────────────────────────────────────────


class CampaignCreate(BaseModel):
    """Request body for ``POST /campaigns``."""

    name: str
    description: Optional[str] = None
    type: str = "custom"
    threat_actor_id: Optional[str] = None  # UUID string of an existing threat actor
    target_platform: Optional[str] = None
    tags: Optional[list[str]] = Field(default_factory=list)
    scheduled_at: Optional[str] = None  # ISO-8601 datetime string


class CampaignUpdate(BaseModel):
    """Request body for ``PATCH /campaigns/{id}``; only fields sent are applied."""

    name: Optional[str] = None
    description: Optional[str] = None
    type: Optional[str] = None
    target_platform: Optional[str] = None
    tags: Optional[list[str]] = None
    scheduled_at: Optional[str] = None  # ISO-8601 datetime string; null/"" clears it


class AddTestPayload(BaseModel):
    """Request body for ``POST /campaigns/{id}/tests``."""

    test_id: str  # UUID of the Test to attach
    order_index: Optional[int] = None  # defaults to (current max + 1)
    depends_on: Optional[str] = None  # UUID of a CampaignTest in the same campaign
    phase: Optional[str] = None


# ── Helpers ──────────────────────────────────────────────────────────

# Statuses in which a campaign's definition and test list may still change.
_EDITABLE_STATUSES = ("draft", "active")

# Escape character used for LIKE/ILIKE patterns built from user input.
_LIKE_ESCAPE = "\\"


def _escape_like(term: str) -> str:
    """Escape LIKE wildcards (``%``, ``_``) and the escape char so *term* matches literally."""
    return (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _parse_uuid(value: str, field: str) -> uuid.UUID:
    """Parse a client-supplied UUID string.

    Raises:
        HTTPException: 422 naming *field* when *value* is not a valid UUID,
            instead of letting ``ValueError`` surface as a 500.
    """
    try:
        return uuid.UUID(value)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=f"Invalid UUID for '{field}'") from exc


def _path_uuid_or_404(value: str, detail: str) -> uuid.UUID:
    """Parse a UUID path parameter; a malformed id cannot exist, so answer 404 with *detail*."""
    try:
        return uuid.UUID(value)
    except ValueError:
        raise HTTPException(status_code=404, detail=detail) from None


def _parse_datetime(value: Optional[str], field: str) -> Optional[datetime]:
    """Parse an optional ISO-8601 string into a ``datetime``.

    Falsy values (``None`` or ``""``) yield ``None``.

    Raises:
        HTTPException: 422 naming *field* when the string is not valid ISO-8601.
    """
    if not value:
        return None
    # datetime.fromisoformat only accepts a trailing "Z" from Python 3.11 on;
    # normalise it so JavaScript-style timestamps parse on older runtimes too.
    if value.endswith("Z"):
        value = value[:-1] + "+00:00"
    try:
        return datetime.fromisoformat(value)
    except ValueError as exc:
        raise HTTPException(
            status_code=422, detail=f"Invalid ISO-8601 datetime for '{field}'"
        ) from exc


def _get_campaign_or_404(db: Session, campaign_id: str) -> Campaign:
    """Load a campaign by its path id.

    Raises:
        HTTPException: 404 if the id is malformed or no such campaign exists.
    """
    campaign_uuid = _path_uuid_or_404(campaign_id, "Campaign not found")
    campaign = db.query(Campaign).filter(Campaign.id == campaign_uuid).first()
    if campaign is None:
        raise HTTPException(status_code=404, detail="Campaign not found")
    return campaign


def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
    """Serialize a campaign with its tests (ordered by ``order_index``) and progress.

    Techniques for all attached tests are fetched in one query rather than one
    query per test.
    """
    progress = get_campaign_progress(db, campaign.id)
    campaign_tests = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign.id)
        .order_by(CampaignTest.order_index)
        .all()
    )

    # Batch-load techniques (avoids an N+1 query pattern over campaign tests).
    technique_ids = {
        ct.test.technique_id
        for ct in campaign_tests
        if ct.test is not None and ct.test.technique_id is not None
    }
    techniques_by_id = (
        {
            technique.id: technique
            for technique in db.query(Technique).filter(Technique.id.in_(technique_ids)).all()
        }
        if technique_ids
        else {}
    )

    tests = []
    for ct in campaign_tests:
        test = ct.test
        technique = techniques_by_id.get(test.technique_id) if test is not None else None
        tests.append({
            "id": str(ct.id),
            "test_id": str(ct.test_id),
            "order_index": ct.order_index,
            "depends_on": str(ct.depends_on) if ct.depends_on else None,
            "phase": ct.phase,
            "test_name": test.name if test else None,
            "test_state": test.state.value if test and test.state else None,
            "test_result": test.result.value if test and test.result else None,
            "technique_mitre_id": technique.mitre_id if technique else None,
            "technique_name": technique.name if technique else None,
            "platform": test.platform if test else None,
        })

    actor = campaign.threat_actor
    return {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
        "threat_actor_name": actor.name if actor else None,
        "created_by": str(campaign.created_by) if campaign.created_by else None,
        "scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
        "completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": campaign.created_at.isoformat() if campaign.created_at else None,
        "tests": tests,
        "progress": progress,
    }


def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
    """Lightweight campaign serialization for list views (no per-test detail)."""
    progress = get_campaign_progress(db, campaign.id)
    actor = campaign.threat_actor
    return {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
        "threat_actor_name": actor.name if actor else None,
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": campaign.created_at.isoformat() if campaign.created_at else None,
        "test_count": progress["total"],
        "completion_pct": progress["completion_pct"],
    }


# ---------------------------------------------------------------------------
# GET /campaigns — List campaigns with filters
# ---------------------------------------------------------------------------
@router.get("")
def list_campaigns(
    type: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    threat_actor_id: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    offset: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=200),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List campaigns with optional filters and pagination, newest first.

    ``search`` is a case-insensitive literal substring match on name or
    description; ``%`` and ``_`` in it are not treated as wildcards.

    Raises:
        HTTPException: 422 if ``threat_actor_id`` is not a valid UUID.
    """
    query = db.query(Campaign)
    if type:
        query = query.filter(Campaign.type == type)
    if status:
        query = query.filter(Campaign.status == status)
    if threat_actor_id:
        actor_uuid = _parse_uuid(threat_actor_id, "threat_actor_id")
        query = query.filter(Campaign.threat_actor_id == actor_uuid)
    if search:
        pattern = f"%{_escape_like(search)}%"
        query = query.filter(
            Campaign.name.ilike(pattern, escape=_LIKE_ESCAPE)
            | Campaign.description.ilike(pattern, escape=_LIKE_ESCAPE)
        )

    total = query.count()
    campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()

    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": [_serialize_campaign_summary(db, c) for c in campaigns],
    }


# ---------------------------------------------------------------------------
# POST /campaigns — Create campaign
# ---------------------------------------------------------------------------
@router.post("", status_code=201)
def create_campaign(
    payload: CampaignCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_any_role("red_tech", "admin")),
):
    """Create a new campaign owned by the current user.

    Raises:
        HTTPException: 422 for a malformed ``threat_actor_id`` or
            ``scheduled_at``; 404 if the referenced threat actor does not exist.
    """
    scheduled_at = _parse_datetime(payload.scheduled_at, "scheduled_at")

    threat_actor_id = None
    if payload.threat_actor_id:
        actor_uuid = _parse_uuid(payload.threat_actor_id, "threat_actor_id")
        actor = db.query(ThreatActor).filter(ThreatActor.id == actor_uuid).first()
        if actor is None:
            raise HTTPException(status_code=404, detail="Threat actor not found")
        threat_actor_id = actor.id

    campaign = Campaign(
        name=payload.name,
        description=payload.description,
        type=payload.type,
        threat_actor_id=threat_actor_id,
        target_platform=payload.target_platform,
        tags=payload.tags or [],
        created_by=current_user.id,
        scheduled_at=scheduled_at,
    )
    db.add(campaign)
    db.commit()
    db.refresh(campaign)

    log_action(
        db,
        user_id=current_user.id,
        action="create_campaign",
        entity_type="campaign",
        entity_id=campaign.id,
        details={"name": campaign.name, "type": campaign.type},
    )

    return _serialize_campaign(db, campaign)


# ---------------------------------------------------------------------------
# GET /campaigns/{id} — Detail with tests and progress
# ---------------------------------------------------------------------------
@router.get("/{campaign_id}")
def get_campaign(
    campaign_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Get detailed campaign info including tests and progress.

    Raises:
        HTTPException: 404 if the id is malformed or the campaign does not exist.
    """
    campaign = _get_campaign_or_404(db, campaign_id)
    return _serialize_campaign(db, campaign)


# ---------------------------------------------------------------------------
# PATCH /campaigns/{id} — Update campaign
# ---------------------------------------------------------------------------
@router.patch("/{campaign_id}")
def update_campaign(
    campaign_id: str,
    payload: CampaignUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_any_role("red_tech", "admin")),
):
    """Update a campaign. Only allowed in draft or active state.

    Only fields present in the request body are applied. Only the creator or
    an admin may update.

    Raises:
        HTTPException: 404 missing campaign; 400 wrong status; 403 not owner;
            422 malformed ``scheduled_at`` or explicit ``"name": null``.
    """
    campaign = _get_campaign_or_404(db, campaign_id)
    if campaign.status not in _EDITABLE_STATUSES:
        raise HTTPException(status_code=400, detail="Can only update draft or active campaigns")

    # Check ownership or admin
    if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
        raise HTTPException(status_code=403, detail="Only the creator or admin can update this campaign")

    update_data = payload.model_dump(exclude_unset=True)
    # name is required at creation time (CampaignCreate.name: str); don't let PATCH null it.
    if "name" in update_data and update_data["name"] is None:
        raise HTTPException(status_code=422, detail="Campaign name cannot be null")
    if "scheduled_at" in update_data:
        update_data["scheduled_at"] = _parse_datetime(update_data["scheduled_at"], "scheduled_at")

    for field, value in update_data.items():
        setattr(campaign, field, value)

    db.commit()
    db.refresh(campaign)

    log_action(
        db,
        user_id=current_user.id,
        action="update_campaign",
        entity_type="campaign",
        entity_id=campaign.id,
        details={"updated_fields": list(update_data.keys())},
    )

    return _serialize_campaign(db, campaign)


# ---------------------------------------------------------------------------
# POST /campaigns/{id}/tests — Add test to campaign
# ---------------------------------------------------------------------------
@router.post("/{campaign_id}/tests")
def add_test_to_campaign(
    campaign_id: str,
    payload: AddTestPayload,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_any_role("red_tech", "admin")),
):
    """Add a test to a campaign with optional ordering and dependency.

    When ``order_index`` is omitted the test is appended after the current
    highest index. ``depends_on`` must name a CampaignTest of this campaign.

    Raises:
        HTTPException: 404 missing campaign/test; 400 wrong status or foreign
            ``depends_on``; 422 malformed ``test_id``/``depends_on``.
    """
    campaign = _get_campaign_or_404(db, campaign_id)
    if campaign.status not in _EDITABLE_STATUSES:
        raise HTTPException(status_code=400, detail="Can only add tests to draft or active campaigns")

    test_uuid = _parse_uuid(payload.test_id, "test_id")
    test = db.query(Test).filter(Test.id == test_uuid).first()
    if not test:
        raise HTTPException(status_code=404, detail="Test not found")

    # Calculate order_index if not provided
    if payload.order_index is not None:
        order_index = payload.order_index
    else:
        # MAX() ignores NULLs, whereas ORDER BY ... DESC puts NULLs first on
        # PostgreSQL and would make "max + 1" fail on None.
        max_order = (
            db.query(func.max(CampaignTest.order_index))
            .filter(CampaignTest.campaign_id == campaign.id)
            .scalar()
        )
        order_index = max_order + 1 if max_order is not None else 0

    # The new row's id is generated up front so the cycle check can reason about it.
    ct_id = uuid.uuid4()
    depends_on = None
    if payload.depends_on:
        depends_on = _parse_uuid(payload.depends_on, "depends_on")
        parent = (
            db.query(CampaignTest)
            .filter(CampaignTest.id == depends_on, CampaignTest.campaign_id == campaign.id)
            .first()
        )
        if parent is None:
            raise HTTPException(
                status_code=400, detail="depends_on must reference a test in this campaign"
            )
        # campaign_id was validated as a UUID by _get_campaign_or_404 above.
        validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on)

    campaign_test = CampaignTest(
        id=ct_id,
        campaign_id=campaign.id,
        test_id=test.id,
        order_index=order_index,
        depends_on=depends_on,
        phase=payload.phase,
    )
    db.add(campaign_test)
    db.commit()
    db.refresh(campaign_test)

    # Audit like every other mutating campaign endpoint.
    log_action(
        db,
        user_id=current_user.id,
        action="add_campaign_test",
        entity_type="campaign",
        entity_id=campaign.id,
        details={"campaign_test_id": str(campaign_test.id), "test_id": str(campaign_test.test_id)},
    )

    return {
        "id": str(campaign_test.id),
        "campaign_id": str(campaign_test.campaign_id),
        "test_id": str(campaign_test.test_id),
        "order_index": campaign_test.order_index,
        "depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
        "phase": campaign_test.phase,
    }


# ---------------------------------------------------------------------------
# DELETE /campaigns/{id}/tests/{campaign_test_id} — Remove test from campaign
# ---------------------------------------------------------------------------
@router.delete("/{campaign_id}/tests/{campaign_test_id}")
def remove_test_from_campaign(
    campaign_id: str,
    campaign_test_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_any_role("red_tech", "admin")),
):
    """Remove a test from a campaign.

    Any campaign tests that depended on the removed one have ``depends_on``
    cleared rather than being deleted.

    Raises:
        HTTPException: 404 missing/malformed campaign or campaign test; 400
            if the campaign is not draft/active.
    """
    campaign = _get_campaign_or_404(db, campaign_id)
    if campaign.status not in _EDITABLE_STATUSES:
        raise HTTPException(status_code=400, detail="Can only modify draft or active campaigns")

    ct_uuid = _path_uuid_or_404(campaign_test_id, "Campaign test not found")
    ct = (
        db.query(CampaignTest)
        .filter(
            CampaignTest.id == ct_uuid,
            CampaignTest.campaign_id == campaign.id,
        )
        .first()
    )
    if not ct:
        raise HTTPException(status_code=404, detail="Campaign test not found")

    # Clear any references to this campaign_test
    dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == ct.id).all()
    for dep in dependents:
        dep.depends_on = None
    # Push the NULL updates before the DELETE so a self-referencing FK is never violated.
    db.flush()

    # Capture identifiers before the row is deleted and expired.
    removed_id = str(ct.id)
    removed_test_id = str(ct.test_id)
    db.delete(ct)
    db.commit()

    log_action(
        db,
        user_id=current_user.id,
        action="remove_campaign_test",
        entity_type="campaign",
        entity_id=campaign.id,
        details={"campaign_test_id": removed_id, "test_id": removed_test_id},
    )

    return {"detail": "Test removed from campaign"}


# ---------------------------------------------------------------------------
# POST /campaigns/{id}/activate — Activate campaign
# ---------------------------------------------------------------------------
@router.post("/{campaign_id}/activate")
def activate_campaign(
    campaign_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_any_role("red_tech", "admin")),
):
    """Activate a campaign, moving it from draft to active.

    Notifies every active ``red_tech`` user and records an audit entry.

    Raises:
        HTTPException: 404 missing campaign; 400 if not draft or it has no tests.
    """
    campaign = _get_campaign_or_404(db, campaign_id)
    if campaign.status != "draft":
        raise HTTPException(status_code=400, detail="Only draft campaigns can be activated")

    # Verify campaign has at least one test
    test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign.id).count()
    if test_count == 0:
        raise HTTPException(status_code=400, detail="Campaign must have at least one test to activate")

    campaign.status = "active"
    db.commit()
    db.refresh(campaign)

    # Notify relevant users
    red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all()  # noqa: E712
    for user in red_techs:
        create_notification(
            db,
            user_id=user.id,
            type="campaign_activated",
            title="Campaign activated",
            message=f'Campaign "{campaign.name}" has been activated.',
            entity_type="campaign",
            entity_id=campaign.id,
        )

    log_action(
        db,
        user_id=current_user.id,
        action="activate_campaign",
        entity_type="campaign",
        entity_id=campaign.id,
        details={"name": campaign.name},
    )

    return _serialize_campaign(db, campaign)
# ---------------------------------------------------------------------------
# POST /campaigns/{id}/complete — Mark campaign as completed
# ---------------------------------------------------------------------------
@router.post("/{campaign_id}/complete")
def complete_campaign(
    campaign_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_any_role("red_lead", "admin")),
):
    """Mark an active campaign as completed.

    Sets ``completed_at`` to ``datetime.utcnow()`` (a naive UTC timestamp)
    and records an audit entry.

    Raises:
        HTTPException: 404 if the campaign does not exist; 400 if it is not active.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise HTTPException(status_code=404, detail="Campaign not found")
    if campaign.status != "active":
        raise HTTPException(status_code=400, detail="Only active campaigns can be completed")

    campaign.status = "completed"
    # NOTE(review): utcnow() is deprecated since Python 3.12; kept naive to
    # match whatever the column currently stores — confirm before changing.
    campaign.completed_at = datetime.utcnow()
    db.commit()
    db.refresh(campaign)

    log_action(
        db,
        user_id=current_user.id,
        action="complete_campaign",
        entity_type="campaign",
        entity_id=campaign.id,
        details={"name": campaign.name},
    )

    return _serialize_campaign(db, campaign)


# ---------------------------------------------------------------------------
# GET /campaigns/{id}/progress — Campaign progress
# ---------------------------------------------------------------------------
@router.get("/{campaign_id}/progress")
def get_campaign_progress_endpoint(
    campaign_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Get progress statistics for a campaign.

    Returns the campaign id and name merged with the dict produced by
    ``get_campaign_progress``.

    Raises:
        HTTPException: 404 if the campaign does not exist.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise HTTPException(status_code=404, detail="Campaign not found")

    # Use the id of the loaded row (as the serializers do) instead of re-parsing
    # the raw path string.
    progress = get_campaign_progress(db, campaign.id)
    return {
        "campaign_id": str(campaign.id),
        "campaign_name": campaign.name,
        **progress,
    }


# ---------------------------------------------------------------------------
# POST /campaigns/from-threat-actor/{actor_id} — Auto-generate campaign
# ---------------------------------------------------------------------------
@router.post("/from-threat-actor/{actor_id}", status_code=201)
def generate_campaign_from_actor(
    actor_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_any_role("red_tech", "admin")),
):
    """Auto-generate a campaign from a threat actor's uncovered techniques.

    Creates tests from the best available templates and orders them by kill
    chain phase (delegated to ``generate_campaign_from_threat_actor``).

    Raises:
        HTTPException: 404 if ``actor_id`` is not a valid UUID or no such
            threat actor exists.
    """
    # A malformed id previously escaped as ValueError (HTTP 500).
    try:
        actor_uuid = uuid.UUID(actor_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Threat actor not found") from None

    if db.query(ThreatActor).filter(ThreatActor.id == actor_uuid).first() is None:
        raise HTTPException(status_code=404, detail="Threat actor not found")

    campaign = generate_campaign_from_threat_actor(
        db,
        actor_uuid,
        current_user,
    )

    log_action(
        db,
        user_id=current_user.id,
        action="generate_campaign",
        entity_type="campaign",
        entity_id=campaign.id,
        details={"actor_id": actor_id, "campaign_name": campaign.name},
    )

    return _serialize_campaign(db, campaign)