refactor(campaigns): extract CRUD/business logic to campaign_crud_service, use domain exceptions

This commit is contained in:
2026-02-19 19:04:32 +01:00
parent 50b70704ae
commit f4c74230ec
2 changed files with 529 additions and 325 deletions

View File

@@ -7,26 +7,29 @@ test ordering, progress tracking, and threat actor integration.
import logging import logging
import uuid import uuid
from typing import Optional from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
from app.models.user import User from app.models.user import User
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES from app.services.campaign_service import generate_campaign_from_threat_actor
from app.models.test import Test from app.services.campaign_crud_service import (
from app.models.technique import Technique add_test_to_campaign as crud_add_test,
from app.models.threat_actor import ThreatActor activate_campaign as crud_activate,
from app.services.campaign_service import ( complete_campaign as crud_complete,
validate_no_circular_dependency, create_campaign as crud_create,
get_campaign_progress, get_campaign_detail as crud_get_detail,
generate_campaign_from_threat_actor, get_campaign_history as crud_get_history,
TACTIC_TO_PHASE, get_campaign_progress_data as crud_get_progress,
list_campaigns as crud_list,
remove_test_from_campaign as crud_remove_test,
schedule_campaign as crud_schedule,
serialize_campaign,
update_campaign as crud_update,
) )
from app.services.campaign_scheduler_service import calculate_next_run
from app.services.notification_service import create_notification from app.services.notification_service import create_notification
from app.services.audit_service import log_action from app.services.audit_service import log_action
@@ -67,89 +70,6 @@ class SchedulePayload(BaseModel):
next_run_at: Optional[str] = None next_run_at: Optional[str] = None
# ── Helpers ──────────────────────────────────────────────────────────
def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
"""Serialize a campaign with its tests and progress."""
progress = get_campaign_progress(db, campaign.id)
campaign_tests = (
db.query(CampaignTest)
.filter(CampaignTest.campaign_id == campaign.id)
.order_by(CampaignTest.order_index)
.all()
)
tests = []
for ct in campaign_tests:
test = ct.test
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
tests.append({
"id": str(ct.id),
"test_id": str(ct.test_id),
"order_index": ct.order_index,
"depends_on": str(ct.depends_on) if ct.depends_on else None,
"phase": ct.phase,
"test_name": test.name if test else None,
"test_state": test.state.value if test and test.state else None,
"test_result": test.result.value if test and test.result else None,
"technique_mitre_id": technique.mitre_id if technique else None,
"technique_name": technique.name if technique else None,
"platform": test.platform if test else None,
})
actor = campaign.threat_actor
return {
"id": str(campaign.id),
"name": campaign.name,
"description": campaign.description,
"type": campaign.type,
"status": campaign.status,
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
"threat_actor_name": actor.name if actor else None,
"created_by": str(campaign.created_by) if campaign.created_by else None,
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
"target_platform": campaign.target_platform,
"tags": campaign.tags or [],
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
"is_recurring": campaign.is_recurring or False,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
"tests": tests,
"progress": progress,
}
def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
"""Lightweight campaign serialization for list views."""
progress = get_campaign_progress(db, campaign.id)
actor = campaign.threat_actor
return {
"id": str(campaign.id),
"name": campaign.name,
"description": campaign.description,
"type": campaign.type,
"status": campaign.status,
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
"threat_actor_name": actor.name if actor else None,
"target_platform": campaign.target_platform,
"tags": campaign.tags or [],
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
"is_recurring": campaign.is_recurring or False,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
"test_count": progress["total"],
"completion_pct": progress["completion_pct"],
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /campaigns — List campaigns with filters # GET /campaigns — List campaigns with filters
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -166,28 +86,15 @@ def list_campaigns(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""List campaigns with optional filters and pagination.""" """List campaigns with optional filters and pagination."""
query = db.query(Campaign) return crud_list(
db,
if type: type=type,
query = query.filter(Campaign.type == type) status=status,
if status: threat_actor_id=threat_actor_id,
query = query.filter(Campaign.status == status) search=search,
if threat_actor_id: offset=offset,
query = query.filter(Campaign.threat_actor_id == threat_actor_id) limit=limit,
if search: )
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
total = query.count()
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [_serialize_campaign_summary(db, c) for c in campaigns],
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -201,30 +108,29 @@ def create_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Create a new campaign.""" """Create a new campaign."""
campaign = Campaign( result = crud_create(
db,
creator_id=current_user.id,
name=payload.name, name=payload.name,
description=payload.description, description=payload.description,
type=payload.type, type=payload.type,
threat_actor_id=payload.threat_actor_id, threat_actor_id=payload.threat_actor_id,
target_platform=payload.target_platform, target_platform=payload.target_platform,
tags=payload.tags or [], tags=payload.tags,
created_by=current_user.id, scheduled_at=payload.scheduled_at,
scheduled_at=datetime.fromisoformat(payload.scheduled_at) if payload.scheduled_at else None,
) )
db.add(campaign)
db.commit()
db.refresh(campaign)
log_action( log_action(
db, db,
user_id=current_user.id, user_id=current_user.id,
action="create_campaign", action="create_campaign",
entity_type="campaign", entity_type="campaign",
entity_id=campaign.id, entity_id=result["id"],
details={"name": campaign.name, "type": campaign.type}, details={"name": payload.name, "type": payload.type},
) )
db.commit()
return _serialize_campaign(db, campaign) return result
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -238,11 +144,7 @@ def get_campaign(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Get detailed campaign info including tests and progress.""" """Get detailed campaign info including tests and progress."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() return crud_get_detail(db, campaign_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
return _serialize_campaign(db, campaign)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -257,37 +159,26 @@ def update_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Update a campaign. Only allowed in draft or active state.""" """Update a campaign. Only allowed in draft or active state."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
if campaign.status not in ("draft", "active"):
raise HTTPException(status_code=400, detail="Can only update draft or active campaigns")
# Check ownership or admin
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
raise HTTPException(status_code=403, detail="Only the creator or admin can update this campaign")
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
if "scheduled_at" in update_data and update_data["scheduled_at"]: result = crud_update(
update_data["scheduled_at"] = datetime.fromisoformat(update_data["scheduled_at"]) db,
campaign_id,
for field, value in update_data.items(): updater_id=current_user.id,
setattr(campaign, field, value) updater_role=current_user.role,
**update_data,
db.commit() )
db.refresh(campaign)
log_action( log_action(
db, db,
user_id=current_user.id, user_id=current_user.id,
action="update_campaign", action="update_campaign",
entity_type="campaign", entity_type="campaign",
entity_id=campaign.id, entity_id=campaign_id,
details={"updated_fields": list(update_data.keys())}, details={"updated_fields": list(update_data.keys())},
) )
db.commit()
return _serialize_campaign(db, campaign) return result
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -302,63 +193,16 @@ def add_test_to_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Add a test to a campaign with optional ordering and dependency.""" """Add a test to a campaign with optional ordering and dependency."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() result = crud_add_test(
if not campaign: db,
raise HTTPException(status_code=404, detail="Campaign not found") campaign_id,
if campaign.status not in ("draft", "active"):
raise HTTPException(status_code=400, detail="Can only add tests to draft or active campaigns")
test = db.query(Test).filter(Test.id == payload.test_id).first()
if not test:
raise HTTPException(status_code=404, detail="Test not found")
# Calculate order_index if not provided
if payload.order_index is not None:
order_index = payload.order_index
else:
max_order = (
db.query(CampaignTest.order_index)
.filter(CampaignTest.campaign_id == campaign_id)
.order_by(CampaignTest.order_index.desc())
.first()
)
order_index = (max_order[0] + 1) if max_order else 0
depends_on = uuid.UUID(payload.depends_on) if payload.depends_on else None
# Validate circular dependency
ct_id = uuid.uuid4()
if depends_on:
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on)
# Auto-detect kill chain phase from the test's technique tactic if not provided
phase = payload.phase
if not phase and test.technique_id:
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
if technique and technique.tactic:
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
campaign_test = CampaignTest(
id=ct_id,
campaign_id=campaign_id,
test_id=payload.test_id, test_id=payload.test_id,
order_index=order_index, order_index=payload.order_index,
depends_on=depends_on, depends_on=payload.depends_on,
phase=phase, phase=payload.phase,
) )
db.add(campaign_test)
db.commit() db.commit()
db.refresh(campaign_test) return result
return {
"id": str(campaign_test.id),
"campaign_id": str(campaign_test.campaign_id),
"test_id": str(campaign_test.test_id),
"order_index": campaign_test.order_index,
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
"phase": campaign_test.phase,
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -373,36 +217,8 @@ def remove_test_from_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Remove a test from a campaign.""" """Remove a test from a campaign."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() crud_remove_test(db, campaign_id, campaign_test_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
if campaign.status not in ("draft", "active"):
raise HTTPException(status_code=400, detail="Can only modify draft or active campaigns")
ct = (
db.query(CampaignTest)
.filter(
CampaignTest.id == campaign_test_id,
CampaignTest.campaign_id == campaign_id,
)
.first()
)
if not ct:
raise HTTPException(status_code=404, detail="Campaign test not found")
# Clear any references to this campaign_test
dependents = (
db.query(CampaignTest)
.filter(CampaignTest.depends_on == campaign_test_id)
.all()
)
for dep in dependents:
dep.depends_on = None
db.delete(ct)
db.commit() db.commit()
return {"detail": "Test removed from campaign"} return {"detail": "Test removed from campaign"}
@@ -417,23 +233,10 @@ def activate_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Activate a campaign, moving it from draft to active.""" """Activate a campaign, moving it from draft to active."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = crud_activate(db, campaign_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
if campaign.status != "draft":
raise HTTPException(status_code=400, detail="Only draft campaigns can be activated")
# Verify campaign has at least one test
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
if test_count == 0:
raise HTTPException(status_code=400, detail="Campaign must have at least one test to activate")
campaign.status = "active"
db.commit() db.commit()
db.refresh(campaign) db.refresh(campaign)
# Notify relevant users
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712 red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
for user in red_techs: for user in red_techs:
create_notification( create_notification(
@@ -455,7 +258,7 @@ def activate_campaign(
details={"name": campaign.name}, details={"name": campaign.name},
) )
return _serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -469,15 +272,7 @@ def complete_campaign(
current_user: User = Depends(require_any_role("red_lead", "admin")), current_user: User = Depends(require_any_role("red_lead", "admin")),
): ):
"""Mark a campaign as completed.""" """Mark a campaign as completed."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = crud_complete(db, campaign_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
if campaign.status != "active":
raise HTTPException(status_code=400, detail="Only active campaigns can be completed")
campaign.status = "completed"
campaign.completed_at = datetime.utcnow()
db.commit() db.commit()
db.refresh(campaign) db.refresh(campaign)
@@ -490,7 +285,7 @@ def complete_campaign(
details={"name": campaign.name}, details={"name": campaign.name},
) )
return _serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -504,16 +299,7 @@ def get_campaign_progress_endpoint(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Get progress statistics for a campaign.""" """Get progress statistics for a campaign."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() return crud_get_progress(db, campaign_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
return {
"campaign_id": str(campaign.id),
"campaign_name": campaign.name,
**progress,
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -546,7 +332,7 @@ def generate_campaign_from_actor(
details={"actor_id": actor_id, "campaign_name": campaign.name}, details={"actor_id": actor_id, "campaign_name": campaign.name},
) )
return _serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -564,31 +350,15 @@ def schedule_campaign(
Only the campaign creator or admin can change scheduling. Only the campaign creator or admin can change scheduling.
""" """
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = crud_schedule(
if not campaign: db,
raise HTTPException(status_code=404, detail="Campaign not found") campaign_id,
owner_id=current_user.id,
# Check ownership or admin owner_role=current_user.role,
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin": is_recurring=payload.is_recurring,
raise HTTPException(status_code=403, detail="Only the creator or admin can configure scheduling") recurrence_pattern=payload.recurrence_pattern,
next_run_at=payload.next_run_at,
campaign.is_recurring = payload.is_recurring )
if payload.is_recurring:
if payload.recurrence_pattern not in ("weekly", "monthly", "quarterly"):
raise HTTPException(
status_code=400,
detail="recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'",
)
campaign.recurrence_pattern = payload.recurrence_pattern
if payload.next_run_at:
campaign.next_run_at = datetime.fromisoformat(payload.next_run_at.replace("Z", "+00:00").replace("+00:00", ""))
elif not campaign.next_run_at:
campaign.next_run_at = calculate_next_run(datetime.utcnow(), payload.recurrence_pattern)
else:
campaign.recurrence_pattern = None
campaign.next_run_at = None
db.commit() db.commit()
db.refresh(campaign) db.refresh(campaign)
@@ -605,7 +375,7 @@ def schedule_campaign(
}, },
) )
return _serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -619,30 +389,4 @@ def get_campaign_history(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""List all child campaigns (execution history) of a recurring campaign.""" """List all child campaigns (execution history) of a recurring campaign."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() return crud_get_history(db, campaign_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
children = (
db.query(Campaign)
.filter(Campaign.parent_campaign_id == campaign_id)
.order_by(Campaign.created_at.desc())
.all()
)
return {
"campaign_id": str(campaign.id),
"campaign_name": campaign.name,
"items": [
{
"id": str(child.id),
"name": child.name,
"status": child.status,
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
"created_at": child.created_at.isoformat() if child.created_at else None,
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
}
for child in children
],
}

View File

@@ -0,0 +1,460 @@
"""Campaign CRUD service — list, create, update, and business logic.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
import uuid
from datetime import datetime
from typing import Optional
from sqlalchemy.orm import Session
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
from app.models.campaign import Campaign, CampaignTest
from app.models.test import Test
from app.models.technique import Technique
from app.utils import escape_like
from app.services.campaign_service import (
get_campaign_progress,
validate_no_circular_dependency,
TACTIC_TO_PHASE,
)
from app.services.campaign_scheduler_service import calculate_next_run
# ── Serialization helpers ────────────────────────────────────────────────
def serialize_campaign(db: Session, campaign: Campaign) -> dict:
"""Serialize a campaign with its tests and progress."""
progress = get_campaign_progress(db, campaign.id)
campaign_tests = (
db.query(CampaignTest)
.filter(CampaignTest.campaign_id == campaign.id)
.order_by(CampaignTest.order_index)
.all()
)
tests = []
for ct in campaign_tests:
test = ct.test
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
tests.append({
"id": str(ct.id),
"test_id": str(ct.test_id),
"order_index": ct.order_index,
"depends_on": str(ct.depends_on) if ct.depends_on else None,
"phase": ct.phase,
"test_name": test.name if test else None,
"test_state": test.state.value if test and test.state else None,
"test_result": test.result.value if test and test.result else None,
"technique_mitre_id": technique.mitre_id if technique else None,
"technique_name": technique.name if technique else None,
"platform": test.platform if test else None,
})
actor = campaign.threat_actor
return {
"id": str(campaign.id),
"name": campaign.name,
"description": campaign.description,
"type": campaign.type,
"status": campaign.status,
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
"threat_actor_name": actor.name if actor else None,
"created_by": str(campaign.created_by) if campaign.created_by else None,
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
"target_platform": campaign.target_platform,
"tags": campaign.tags or [],
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
"is_recurring": campaign.is_recurring or False,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
"tests": tests,
"progress": progress,
}
def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
"""Lightweight campaign serialization for list views."""
progress = get_campaign_progress(db, campaign.id)
actor = campaign.threat_actor
return {
"id": str(campaign.id),
"name": campaign.name,
"description": campaign.description,
"type": campaign.type,
"status": campaign.status,
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
"threat_actor_name": actor.name if actor else None,
"target_platform": campaign.target_platform,
"tags": campaign.tags or [],
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
"is_recurring": campaign.is_recurring or False,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
"test_count": progress["total"],
"completion_pct": progress["completion_pct"],
}
# ── CRUD operations ───────────────────────────────────────────────────────
def list_campaigns(
db: Session,
*,
type: Optional[str] = None,
status: Optional[str] = None,
threat_actor_id: Optional[str] = None,
search: Optional[str] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""Return a paginated list of campaigns with optional filters."""
query = db.query(Campaign)
if type:
query = query.filter(Campaign.type == type)
if status:
query = query.filter(Campaign.status == status)
if threat_actor_id:
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
total = query.count()
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [serialize_campaign_summary(db, c) for c in campaigns],
}
def create_campaign(
db: Session,
*,
creator_id: uuid.UUID,
name: str,
description: Optional[str] = None,
type: str = "custom",
threat_actor_id: Optional[str] = None,
target_platform: Optional[str] = None,
tags: Optional[list[str]] = None,
scheduled_at: Optional[str] = None,
) -> dict:
"""Create a new campaign. Does not commit; caller commits."""
campaign = Campaign(
name=name,
description=description,
type=type,
threat_actor_id=uuid.UUID(threat_actor_id) if threat_actor_id else None,
target_platform=target_platform,
tags=tags or [],
created_by=creator_id,
scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None,
)
db.add(campaign)
db.flush()
return serialize_campaign(db, campaign)
def get_campaign_detail(db: Session, campaign_id: str) -> dict:
"""Get detailed campaign info including tests and progress.
Raises EntityNotFoundError if campaign not found.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
return serialize_campaign(db, campaign)
def update_campaign(
db: Session,
campaign_id: str,
*,
updater_id: uuid.UUID,
updater_role: str,
**fields,
) -> dict:
"""Update a campaign. Only allowed in draft or active state.
Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status not in ("draft", "active"):
raise BusinessRuleViolation("Can only update draft or active campaigns")
if str(campaign.created_by) != str(updater_id) and updater_role != "admin":
raise PermissionViolation("Only the creator or admin can update this campaign")
if "scheduled_at" in fields and fields["scheduled_at"]:
fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
for field, value in fields.items():
setattr(campaign, field, value)
db.flush()
return serialize_campaign(db, campaign)
def add_test_to_campaign(
db: Session,
campaign_id: str,
*,
test_id: str,
order_index: Optional[int] = None,
depends_on: Optional[str] = None,
phase: Optional[str] = None,
) -> dict:
"""Add a test to a campaign with optional ordering and dependency.
Raises EntityNotFoundError for missing campaign or test.
Raises BusinessRuleViolation for invalid state or circular dependency.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status not in ("draft", "active"):
raise BusinessRuleViolation("Can only add tests to draft or active campaigns")
test = db.query(Test).filter(Test.id == test_id).first()
if not test:
raise EntityNotFoundError("Test", test_id)
if order_index is not None:
final_order_index = order_index
else:
max_order = (
db.query(CampaignTest.order_index)
.filter(CampaignTest.campaign_id == campaign_id)
.order_by(CampaignTest.order_index.desc())
.first()
)
final_order_index = (max_order[0] + 1) if max_order else 0
depends_on_uuid = uuid.UUID(depends_on) if depends_on else None
ct_id = uuid.uuid4()
if depends_on_uuid:
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on_uuid)
if not phase and test.technique_id:
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
if technique and technique.tactic:
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
campaign_test = CampaignTest(
id=ct_id,
campaign_id=campaign_id,
test_id=test_id,
order_index=final_order_index,
depends_on=depends_on_uuid,
phase=phase,
)
db.add(campaign_test)
db.flush()
return {
"id": str(campaign_test.id),
"campaign_id": str(campaign_test.campaign_id),
"test_id": str(campaign_test.test_id),
"order_index": campaign_test.order_index,
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
"phase": campaign_test.phase,
}
def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
"""Remove a test from a campaign.
Raises EntityNotFoundError for missing campaign or campaign test.
Raises BusinessRuleViolation for invalid state.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status not in ("draft", "active"):
raise BusinessRuleViolation("Can only modify draft or active campaigns")
ct = (
db.query(CampaignTest)
.filter(
CampaignTest.id == campaign_test_id,
CampaignTest.campaign_id == campaign_id,
)
.first()
)
if not ct:
raise EntityNotFoundError("CampaignTest", campaign_test_id)
dep_id = uuid.UUID(campaign_test_id)
dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
for dep in dependents:
dep.depends_on = None
db.delete(ct)
db.flush()
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
"""Activate a campaign, moving it from draft to active.
Raises EntityNotFoundError, BusinessRuleViolation.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status != "draft":
raise BusinessRuleViolation("Only draft campaigns can be activated")
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
if test_count == 0:
raise BusinessRuleViolation("Campaign must have at least one test to activate")
campaign.status = "active"
db.flush()
return campaign
def complete_campaign(db: Session, campaign_id: str) -> Campaign:
"""Mark a campaign as completed.
Raises EntityNotFoundError, BusinessRuleViolation.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status != "active":
raise BusinessRuleViolation("Only active campaigns can be completed")
campaign.status = "completed"
campaign.completed_at = datetime.utcnow()
db.flush()
return campaign
def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
"""Get progress statistics for a campaign.
Raises EntityNotFoundError if campaign not found.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
return {
"campaign_id": str(campaign.id),
"campaign_name": campaign.name,
**progress,
}
def schedule_campaign(
db: Session,
campaign_id: str,
*,
owner_id: uuid.UUID,
owner_role: str,
is_recurring: bool,
recurrence_pattern: Optional[str] = None,
next_run_at: Optional[str] = None,
) -> Campaign:
"""Configure or update the recurrence schedule for a campaign.
Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
raise PermissionViolation("Only the creator or admin can configure scheduling")
campaign.is_recurring = is_recurring
if is_recurring:
if recurrence_pattern not in ("weekly", "monthly", "quarterly"):
raise BusinessRuleViolation(
"recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
)
campaign.recurrence_pattern = recurrence_pattern
if next_run_at:
campaign.next_run_at = datetime.fromisoformat(
next_run_at.replace("Z", "+00:00").replace("+00:00", "")
)
elif not campaign.next_run_at:
campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
else:
campaign.recurrence_pattern = None
campaign.next_run_at = None
db.flush()
return campaign
def get_campaign_history(db: Session, campaign_id: str) -> dict:
"""List all child campaigns (execution history) of a recurring campaign.
Raises EntityNotFoundError if campaign not found.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
campaign_uuid = uuid.UUID(campaign_id)
children = (
db.query(Campaign)
.filter(Campaign.parent_campaign_id == campaign_uuid)
.order_by(Campaign.created_at.desc())
.all()
)
return {
"campaign_id": str(campaign.id),
"campaign_name": campaign.name,
"items": [
{
"id": str(child.id),
"name": child.name,
"status": child.status,
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
"created_at": child.created_at.isoformat() if child.created_at else None,
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
}
for child in children
],
}