From f4c74230ec3cd2d58d1749da2afd8610511a7571 Mon Sep 17 00:00:00 2001 From: Kitos Date: Thu, 19 Feb 2026 19:04:32 +0100 Subject: [PATCH] refactor(campaigns): extract CRUD/business logic to campaign_crud_service, use domain exceptions --- backend/app/routers/campaigns.py | 394 +++------------ backend/app/services/campaign_crud_service.py | 460 ++++++++++++++++++ 2 files changed, 529 insertions(+), 325 deletions(-) create mode 100644 backend/app/services/campaign_crud_service.py diff --git a/backend/app/routers/campaigns.py b/backend/app/routers/campaigns.py index 1d62a51..335c906 100644 --- a/backend/app/routers/campaigns.py +++ b/backend/app/routers/campaigns.py @@ -7,26 +7,29 @@ test ordering, progress tracking, and threat actor integration. import logging import uuid from typing import Optional -from datetime import datetime -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session from pydantic import BaseModel, Field from app.database import get_db from app.dependencies.auth import get_current_user, require_any_role from app.models.user import User -from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES -from app.models.test import Test -from app.models.technique import Technique -from app.models.threat_actor import ThreatActor -from app.services.campaign_service import ( - validate_no_circular_dependency, - get_campaign_progress, - generate_campaign_from_threat_actor, - TACTIC_TO_PHASE, +from app.services.campaign_service import generate_campaign_from_threat_actor +from app.services.campaign_crud_service import ( + add_test_to_campaign as crud_add_test, + activate_campaign as crud_activate, + complete_campaign as crud_complete, + create_campaign as crud_create, + get_campaign_detail as crud_get_detail, + get_campaign_history as crud_get_history, + get_campaign_progress_data as crud_get_progress, + list_campaigns as crud_list, + remove_test_from_campaign as crud_remove_test, + schedule_campaign as crud_schedule, + serialize_campaign, + update_campaign as crud_update, ) -from app.services.campaign_scheduler_service import calculate_next_run from app.services.notification_service import create_notification from app.services.audit_service import log_action @@ -67,89 +70,6 @@ class SchedulePayload(BaseModel): next_run_at: Optional[str] = None -# ── Helpers ────────────────────────────────────────────────────────── - -def _serialize_campaign(db: Session, campaign: Campaign) -> dict: - """Serialize a campaign with its tests and progress.""" - progress = get_campaign_progress(db, campaign.id) - - campaign_tests = ( - db.query(CampaignTest) - .filter(CampaignTest.campaign_id == campaign.id) - .order_by(CampaignTest.order_index) - .all() - ) - - tests = [] - for ct in campaign_tests: - test = ct.test - technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None - - tests.append({ - "id": str(ct.id), - "test_id": str(ct.test_id), - "order_index": ct.order_index, - "depends_on": str(ct.depends_on) if ct.depends_on else None, - "phase": ct.phase, - "test_name": test.name if test else None, - "test_state": test.state.value if test and test.state else None, - "test_result": test.result.value if test and test.result else None, - "technique_mitre_id": technique.mitre_id if technique else None, - "technique_name": technique.name if technique else None, - "platform": test.platform if test else None, - }) - - actor = campaign.threat_actor - - return { - "id": str(campaign.id), - "name": campaign.name, - "description": campaign.description, - "type": campaign.type, - "status": campaign.status, - "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None, - "threat_actor_name": actor.name if actor else None, - "created_by": str(campaign.created_by) if campaign.created_by else None, - "scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None, - "completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None, - "target_platform": campaign.target_platform, - "tags": campaign.tags or [], - "created_at": campaign.created_at.isoformat() if campaign.created_at else None, - "is_recurring": campaign.is_recurring or False, - "recurrence_pattern": campaign.recurrence_pattern, - "next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None, - "last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None, - "parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None, - "tests": tests, - "progress": progress, - } - - -def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict: - """Lightweight campaign serialization for list views.""" - progress = get_campaign_progress(db, campaign.id) - actor = campaign.threat_actor - - return { - "id": str(campaign.id), - "name": campaign.name, - "description": campaign.description, - "type": campaign.type, - "status": campaign.status, - "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None, - "threat_actor_name": actor.name if actor else None, - "target_platform": campaign.target_platform, - "tags": campaign.tags or [], - "created_at": campaign.created_at.isoformat() if campaign.created_at else None, - "is_recurring": campaign.is_recurring or False, - "recurrence_pattern": campaign.recurrence_pattern, - "next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None, - "last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None, - "test_count": progress["total"], - "completion_pct": progress["completion_pct"], - } - - # --------------------------------------------------------------------------- # GET /campaigns — List campaigns with filters # --------------------------------------------------------------------------- @@ -166,28 +86,15 @@ def list_campaigns( current_user: User = Depends(get_current_user), ): """List campaigns with optional filters and pagination.""" - query = db.query(Campaign) - - if type: - query = query.filter(Campaign.type == type) - if status: - query = query.filter(Campaign.status == status) - if threat_actor_id: - query = query.filter(Campaign.threat_actor_id == threat_actor_id) - if search: - from app.utils import escape_like - pattern = f"%{escape_like(search)}%" - query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern)) - - total = query.count() - campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all() - - return { - "total": total, - "offset": offset, - "limit": limit, - "items": [_serialize_campaign_summary(db, c) for c in campaigns], - } + return crud_list( + db, + type=type, + status=status, + threat_actor_id=threat_actor_id, + search=search, + offset=offset, + limit=limit, + ) # --------------------------------------------------------------------------- @@ -201,30 +108,29 @@ def create_campaign( current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Create a new campaign.""" - campaign = Campaign( + result = crud_create( + db, + creator_id=current_user.id, name=payload.name, description=payload.description, type=payload.type, threat_actor_id=payload.threat_actor_id, target_platform=payload.target_platform, - tags=payload.tags or [], - created_by=current_user.id, - scheduled_at=datetime.fromisoformat(payload.scheduled_at) if payload.scheduled_at else None, + tags=payload.tags, + scheduled_at=payload.scheduled_at, ) - db.add(campaign) - db.commit() - db.refresh(campaign) log_action( db, user_id=current_user.id, action="create_campaign", entity_type="campaign", - entity_id=campaign.id, - details={"name": campaign.name, "type": campaign.type}, + entity_id=result["id"], + details={"name": payload.name, "type": payload.type}, ) + db.commit() - return _serialize_campaign(db, campaign) + return result # --------------------------------------------------------------------------- @@ -238,11 +144,7 @@ def get_campaign( current_user: User = Depends(get_current_user), ): """Get detailed campaign info including tests and progress.""" - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() - if not campaign: - raise HTTPException(status_code=404, detail="Campaign not found") - - return _serialize_campaign(db, campaign) + return crud_get_detail(db, campaign_id) # --------------------------------------------------------------------------- @@ -257,37 +159,26 @@ def update_campaign( current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Update a campaign. Only allowed in draft or active state.""" - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() - if not campaign: - raise HTTPException(status_code=404, detail="Campaign not found") - - if campaign.status not in ("draft", "active"): - raise HTTPException(status_code=400, detail="Can only update draft or active campaigns") - - # Check ownership or admin - if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin": - raise HTTPException(status_code=403, detail="Only the creator or admin can update this campaign") - update_data = payload.model_dump(exclude_unset=True) - if "scheduled_at" in update_data and update_data["scheduled_at"]: - update_data["scheduled_at"] = datetime.fromisoformat(update_data["scheduled_at"]) - - for field, value in update_data.items(): - setattr(campaign, field, value) - - db.commit() - db.refresh(campaign) + result = crud_update( + db, + campaign_id, + updater_id=current_user.id, + updater_role=current_user.role, + **update_data, + ) log_action( db, user_id=current_user.id, action="update_campaign", entity_type="campaign", - entity_id=campaign.id, + entity_id=campaign_id, details={"updated_fields": list(update_data.keys())}, ) + db.commit() - return _serialize_campaign(db, campaign) + return result # --------------------------------------------------------------------------- @@ -302,63 +193,16 @@ def add_test_to_campaign( current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Add a test to a campaign with optional ordering and dependency.""" - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() - if not campaign: - raise HTTPException(status_code=404, detail="Campaign not found") - - if campaign.status not in ("draft", "active"): - raise HTTPException(status_code=400, detail="Can only add tests to draft or active campaigns") - - test = db.query(Test).filter(Test.id == payload.test_id).first() - if not test: - raise HTTPException(status_code=404, detail="Test not found") - - # Calculate order_index if not provided - if payload.order_index is not None: - order_index = payload.order_index - else: - max_order = ( - db.query(CampaignTest.order_index) - .filter(CampaignTest.campaign_id == campaign_id) - .order_by(CampaignTest.order_index.desc()) - .first() - ) - order_index = (max_order[0] + 1) if max_order else 0 - - depends_on = uuid.UUID(payload.depends_on) if payload.depends_on else None - - # Validate circular dependency - ct_id = uuid.uuid4() - if depends_on: - validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on) - - # Auto-detect kill chain phase from the test's technique tactic if not provided - phase = payload.phase - if not phase and test.technique_id: - technique = db.query(Technique).filter(Technique.id == test.technique_id).first() - if technique and technique.tactic: - phase = TACTIC_TO_PHASE.get(technique.tactic, None) - - campaign_test = CampaignTest( - id=ct_id, - campaign_id=campaign_id, + result = crud_add_test( + db, + campaign_id, test_id=payload.test_id, - order_index=order_index, - depends_on=depends_on, - phase=phase, + order_index=payload.order_index, + depends_on=payload.depends_on, + phase=payload.phase, ) - db.add(campaign_test) db.commit() - db.refresh(campaign_test) - - return { - "id": str(campaign_test.id), - "campaign_id": str(campaign_test.campaign_id), - "test_id": str(campaign_test.test_id), - "order_index": campaign_test.order_index, - "depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None, - "phase": campaign_test.phase, - } + return result # --------------------------------------------------------------------------- @@ -373,36 +217,8 @@ def remove_test_from_campaign( current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Remove a test from a campaign.""" - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() - if not campaign: - raise HTTPException(status_code=404, detail="Campaign not found") - - if campaign.status not in ("draft", "active"): - raise HTTPException(status_code=400, detail="Can only modify draft or active campaigns") - - ct = ( - db.query(CampaignTest) - .filter( - CampaignTest.id == campaign_test_id, - CampaignTest.campaign_id == campaign_id, - ) - .first() - ) - if not ct: - raise HTTPException(status_code=404, detail="Campaign test not found") - - # Clear any references to this campaign_test - dependents = ( - db.query(CampaignTest) - .filter(CampaignTest.depends_on == campaign_test_id) - .all() - ) - for dep in dependents: - dep.depends_on = None - - db.delete(ct) + crud_remove_test(db, campaign_id, campaign_test_id) db.commit() - return {"detail": "Test removed from campaign"} @@ -417,23 +233,10 @@ def activate_campaign( current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Activate a campaign, moving it from draft to active.""" - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() - if not campaign: - raise HTTPException(status_code=404, detail="Campaign not found") - - if campaign.status != "draft": - raise HTTPException(status_code=400, detail="Only draft campaigns can be activated") - - # Verify campaign has at least one test - test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count() - if test_count == 0: - raise HTTPException(status_code=400, detail="Campaign must have at least one test to activate") - - campaign.status = "active" + campaign = crud_activate(db, campaign_id) db.commit() db.refresh(campaign) - # Notify relevant users red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712 for user in red_techs: create_notification( @@ -455,7 +258,7 @@ def activate_campaign( details={"name": campaign.name}, ) - return _serialize_campaign(db, campaign) + return serialize_campaign(db, campaign) # --------------------------------------------------------------------------- @@ -469,15 +272,7 @@ def complete_campaign( current_user: User = Depends(require_any_role("red_lead", "admin")), ): """Mark a campaign as completed.""" - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() - if not campaign: - raise HTTPException(status_code=404, detail="Campaign not found") - - if campaign.status != "active": - raise HTTPException(status_code=400, detail="Only active campaigns can be completed") - - campaign.status = "completed" - campaign.completed_at = datetime.utcnow() + campaign = crud_complete(db, campaign_id) db.commit() db.refresh(campaign) @@ -490,7 +285,7 @@ def complete_campaign( details={"name": campaign.name}, ) - return _serialize_campaign(db, campaign) + return serialize_campaign(db, campaign) # --------------------------------------------------------------------------- @@ -504,16 +299,7 @@ def get_campaign_progress_endpoint( current_user: User = Depends(get_current_user), ): """Get progress statistics for a campaign.""" - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() - if not campaign: - raise HTTPException(status_code=404, detail="Campaign not found") - - progress = get_campaign_progress(db, uuid.UUID(campaign_id)) - return { - "campaign_id": str(campaign.id), - "campaign_name": campaign.name, - **progress, - } + return crud_get_progress(db, campaign_id) # --------------------------------------------------------------------------- @@ -546,7 +332,7 @@ def generate_campaign_from_actor( details={"actor_id": actor_id, "campaign_name": campaign.name}, ) - return _serialize_campaign(db, campaign) + return serialize_campaign(db, campaign) # --------------------------------------------------------------------------- @@ -564,31 +350,15 @@ def schedule_campaign( Only the campaign creator or admin can change scheduling. """ - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() - if not campaign: - raise HTTPException(status_code=404, detail="Campaign not found") - - # Check ownership or admin - if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin": - raise HTTPException(status_code=403, detail="Only the creator or admin can configure scheduling") - - campaign.is_recurring = payload.is_recurring - - if payload.is_recurring: - if payload.recurrence_pattern not in ("weekly", "monthly", "quarterly"): - raise HTTPException( - status_code=400, - detail="recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'", - ) - campaign.recurrence_pattern = payload.recurrence_pattern - if payload.next_run_at: - campaign.next_run_at = datetime.fromisoformat(payload.next_run_at.replace("Z", "+00:00").replace("+00:00", "")) - elif not campaign.next_run_at: - campaign.next_run_at = calculate_next_run(datetime.utcnow(), payload.recurrence_pattern) - else: - campaign.recurrence_pattern = None - campaign.next_run_at = None - + campaign = crud_schedule( + db, + campaign_id, + owner_id=current_user.id, + owner_role=current_user.role, + is_recurring=payload.is_recurring, + recurrence_pattern=payload.recurrence_pattern, + next_run_at=payload.next_run_at, + ) db.commit() db.refresh(campaign) @@ -605,7 +375,7 @@ def schedule_campaign( }, ) - return _serialize_campaign(db, campaign) + return serialize_campaign(db, campaign) # --------------------------------------------------------------------------- @@ -619,30 +389,4 @@ def get_campaign_history( current_user: User = Depends(get_current_user), ): """List all child campaigns (execution history) of a recurring campaign.""" - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() - if not campaign: - raise HTTPException(status_code=404, detail="Campaign not found") - - children = ( - db.query(Campaign) - .filter(Campaign.parent_campaign_id == campaign_id) - .order_by(Campaign.created_at.desc()) - .all() - ) - - return { - "campaign_id": str(campaign.id), - "campaign_name": campaign.name, - "items": [ - { - "id": str(child.id), - "name": child.name, - "status": child.status, - "test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(), - "completion_pct": get_campaign_progress(db, child.id)["completion_pct"], - "created_at": child.created_at.isoformat() if child.created_at else None, - "completed_at": child.completed_at.isoformat() if child.completed_at else None, - } - for child in children - ], - } + return crud_get_history(db, campaign_id) diff --git a/backend/app/services/campaign_crud_service.py b/backend/app/services/campaign_crud_service.py new file mode 100644 index 0000000..34a503c --- /dev/null +++ b/backend/app/services/campaign_crud_service.py @@ -0,0 +1,460 @@ +"""Campaign CRUD service — list, create, update, and business logic. + +Framework-agnostic; uses domain exceptions from app.domain.errors. +The router is responsible for HTTP concerns, auth, audit logging, and commit. +""" + +import uuid +from datetime import datetime +from typing import Optional + +from sqlalchemy.orm import Session + +from app.domain.errors import ( + BusinessRuleViolation, + EntityNotFoundError, + PermissionViolation, +) +from app.models.campaign import Campaign, CampaignTest +from app.models.test import Test +from app.models.technique import Technique +from app.utils import escape_like +from app.services.campaign_service import ( + get_campaign_progress, + validate_no_circular_dependency, + TACTIC_TO_PHASE, +) +from app.services.campaign_scheduler_service import calculate_next_run + + +# ── Serialization helpers ──────────────────────────────────────────────── + + +def serialize_campaign(db: Session, campaign: Campaign) -> dict: + """Serialize a campaign with its tests and progress.""" + progress = get_campaign_progress(db, campaign.id) + + campaign_tests = ( + db.query(CampaignTest) + .filter(CampaignTest.campaign_id == campaign.id) + .order_by(CampaignTest.order_index) + .all() + ) + + tests = [] + for ct in campaign_tests: + test = ct.test + technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None + + tests.append({ + "id": str(ct.id), + "test_id": str(ct.test_id), + "order_index": ct.order_index, + "depends_on": str(ct.depends_on) if ct.depends_on else None, + "phase": ct.phase, + "test_name": test.name if test else None, + "test_state": test.state.value if test and test.state else None, + "test_result": test.result.value if test and test.result else None, + "technique_mitre_id": technique.mitre_id if technique else None, + "technique_name": technique.name if technique else None, + "platform": test.platform if test else None, + }) + + actor = campaign.threat_actor + + return { + "id": str(campaign.id), + "name": campaign.name, + "description": campaign.description, + "type": campaign.type, + "status": campaign.status, + "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None, + "threat_actor_name": actor.name if actor else None, + "created_by": str(campaign.created_by) if campaign.created_by else None, + "scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None, + "completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None, + "target_platform": campaign.target_platform, + "tags": campaign.tags or [], + "created_at": campaign.created_at.isoformat() if campaign.created_at else None, + "is_recurring": campaign.is_recurring or False, + "recurrence_pattern": campaign.recurrence_pattern, + "next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None, + "last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None, + "parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None, + "tests": tests, + "progress": progress, + } + + +def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict: + """Lightweight campaign serialization for list views.""" + progress = get_campaign_progress(db, campaign.id) + actor = campaign.threat_actor + + return { + "id": str(campaign.id), + "name": campaign.name, + "description": campaign.description, + "type": campaign.type, + "status": campaign.status, + "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None, + "threat_actor_name": actor.name if actor else None, + "target_platform": campaign.target_platform, + "tags": campaign.tags or [], + "created_at": campaign.created_at.isoformat() if campaign.created_at else None, + "is_recurring": campaign.is_recurring or False, + "recurrence_pattern": campaign.recurrence_pattern, + "next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None, + "last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None, + "test_count": progress["total"], + "completion_pct": progress["completion_pct"], + } + + +# ── CRUD operations ─────────────────────────────────────────────────────── + + +def list_campaigns( + db: Session, + *, + type: Optional[str] = None, + status: Optional[str] = None, + threat_actor_id: Optional[str] = None, + search: Optional[str] = None, + offset: int = 0, + limit: int = 50, +) -> dict: + """Return a paginated list of campaigns with optional filters.""" + query = db.query(Campaign) + + if type: + query = query.filter(Campaign.type == type) + if status: + query = query.filter(Campaign.status == status) + if threat_actor_id: + query = query.filter(Campaign.threat_actor_id == threat_actor_id) + if search: + pattern = f"%{escape_like(search)}%" + query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern)) + + total = query.count() + campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all() + + return { + "total": total, + "offset": offset, + "limit": limit, + "items": [serialize_campaign_summary(db, c) for c in campaigns], + } + + +def create_campaign( + db: Session, + *, + creator_id: uuid.UUID, + name: str, + description: Optional[str] = None, + type: str = "custom", + threat_actor_id: Optional[str] = None, + target_platform: Optional[str] = None, + tags: Optional[list[str]] = None, + scheduled_at: Optional[str] = None, +) -> dict: + """Create a new campaign. Does not commit; caller commits.""" + campaign = Campaign( + name=name, + description=description, + type=type, + threat_actor_id=uuid.UUID(threat_actor_id) if threat_actor_id else None, + target_platform=target_platform, + tags=tags or [], + created_by=creator_id, + scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None, + ) + db.add(campaign) + db.flush() + return serialize_campaign(db, campaign) + + +def get_campaign_detail(db: Session, campaign_id: str) -> dict: + """Get detailed campaign info including tests and progress. + + Raises EntityNotFoundError if campaign not found. + """ + campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + if not campaign: + raise EntityNotFoundError("Campaign", campaign_id) + return serialize_campaign(db, campaign) + + +def update_campaign( + db: Session, + campaign_id: str, + *, + updater_id: uuid.UUID, + updater_role: str, + **fields, +) -> dict: + """Update a campaign. Only allowed in draft or active state. + + Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation. + Does not commit; caller commits. + """ + campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + if not campaign: + raise EntityNotFoundError("Campaign", campaign_id) + + if campaign.status not in ("draft", "active"): + raise BusinessRuleViolation("Can only update draft or active campaigns") + + if str(campaign.created_by) != str(updater_id) and updater_role != "admin": + raise PermissionViolation("Only the creator or admin can update this campaign") + + if "scheduled_at" in fields and fields["scheduled_at"]: + fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"]) + + for field, value in fields.items(): + setattr(campaign, field, value) + + db.flush() + return serialize_campaign(db, campaign) + + +def add_test_to_campaign( + db: Session, + campaign_id: str, + *, + test_id: str, + order_index: Optional[int] = None, + depends_on: Optional[str] = None, + phase: Optional[str] = None, +) -> dict: + """Add a test to a campaign with optional ordering and dependency. + + Raises EntityNotFoundError for missing campaign or test. + Raises BusinessRuleViolation for invalid state or circular dependency. + Does not commit; caller commits. + """ + campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + if not campaign: + raise EntityNotFoundError("Campaign", campaign_id) + + if campaign.status not in ("draft", "active"): + raise BusinessRuleViolation("Can only add tests to draft or active campaigns") + + test = db.query(Test).filter(Test.id == test_id).first() + if not test: + raise EntityNotFoundError("Test", test_id) + + if order_index is not None: + final_order_index = order_index + else: + max_order = ( + db.query(CampaignTest.order_index) + .filter(CampaignTest.campaign_id == campaign_id) + .order_by(CampaignTest.order_index.desc()) + .first() + ) + final_order_index = (max_order[0] + 1) if max_order else 0 + + depends_on_uuid = uuid.UUID(depends_on) if depends_on else None + + ct_id = uuid.uuid4() + if depends_on_uuid: + validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on_uuid) + + if not phase and test.technique_id: + technique = db.query(Technique).filter(Technique.id == test.technique_id).first() + if technique and technique.tactic: + phase = TACTIC_TO_PHASE.get(technique.tactic, None) + + campaign_test = CampaignTest( + id=ct_id, + campaign_id=campaign_id, + test_id=test_id, + order_index=final_order_index, + depends_on=depends_on_uuid, + phase=phase, + ) + db.add(campaign_test) + db.flush() + + return { + "id": str(campaign_test.id), + "campaign_id": str(campaign_test.campaign_id), + "test_id": str(campaign_test.test_id), + "order_index": campaign_test.order_index, + "depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None, + "phase": campaign_test.phase, + } + + +def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None: + """Remove a test from a campaign. + + Raises EntityNotFoundError for missing campaign or campaign test. + Raises BusinessRuleViolation for invalid state. + Does not commit; caller commits. + """ + campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + if not campaign: + raise EntityNotFoundError("Campaign", campaign_id) + + if campaign.status not in ("draft", "active"): + raise BusinessRuleViolation("Can only modify draft or active campaigns") + + ct = ( + db.query(CampaignTest) + .filter( + CampaignTest.id == campaign_test_id, + CampaignTest.campaign_id == campaign_id, + ) + .first() + ) + if not ct: + raise EntityNotFoundError("CampaignTest", campaign_test_id) + + dep_id = uuid.UUID(campaign_test_id) + dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all() + for dep in dependents: + dep.depends_on = None + + db.delete(ct) + db.flush() + + +def activate_campaign(db: Session, campaign_id: str) -> Campaign: + """Activate a campaign, moving it from draft to active. + + Raises EntityNotFoundError, BusinessRuleViolation. + Does not commit; caller commits. + """ + campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + if not campaign: + raise EntityNotFoundError("Campaign", campaign_id) + + if campaign.status != "draft": + raise BusinessRuleViolation("Only draft campaigns can be activated") + + test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count() + if test_count == 0: + raise BusinessRuleViolation("Campaign must have at least one test to activate") + + campaign.status = "active" + db.flush() + return campaign + + +def complete_campaign(db: Session, campaign_id: str) -> Campaign: + """Mark a campaign as completed. + + Raises EntityNotFoundError, BusinessRuleViolation. + Does not commit; caller commits. + """ + campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + if not campaign: + raise EntityNotFoundError("Campaign", campaign_id) + + if campaign.status != "active": + raise BusinessRuleViolation("Only active campaigns can be completed") + + campaign.status = "completed" + campaign.completed_at = datetime.utcnow() + db.flush() + return campaign + + +def get_campaign_progress_data(db: Session, campaign_id: str) -> dict: + """Get progress statistics for a campaign. + + Raises EntityNotFoundError if campaign not found. + """ + campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + if not campaign: + raise EntityNotFoundError("Campaign", campaign_id) + + progress = get_campaign_progress(db, uuid.UUID(campaign_id)) + return { + "campaign_id": str(campaign.id), + "campaign_name": campaign.name, + **progress, + } + + +def schedule_campaign( + db: Session, + campaign_id: str, + *, + owner_id: uuid.UUID, + owner_role: str, + is_recurring: bool, + recurrence_pattern: Optional[str] = None, + next_run_at: Optional[str] = None, +) -> Campaign: + """Configure or update the recurrence schedule for a campaign. + + Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation. + Does not commit; caller commits. + """ + campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + if not campaign: + raise EntityNotFoundError("Campaign", campaign_id) + + if str(campaign.created_by) != str(owner_id) and owner_role != "admin": + raise PermissionViolation("Only the creator or admin can configure scheduling") + + campaign.is_recurring = is_recurring + + if is_recurring: + if recurrence_pattern not in ("weekly", "monthly", "quarterly"): + raise BusinessRuleViolation( + "recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'" + ) + campaign.recurrence_pattern = recurrence_pattern + if next_run_at: + campaign.next_run_at = datetime.fromisoformat( + next_run_at.replace("Z", "+00:00").replace("+00:00", "") + ) + elif not campaign.next_run_at: + campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern) + else: + campaign.recurrence_pattern = None + campaign.next_run_at = None + + db.flush() + return campaign + + +def get_campaign_history(db: Session, campaign_id: str) -> dict: + """List all child campaigns (execution history) of a recurring campaign. + + Raises EntityNotFoundError if campaign not found. + """ + campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + if not campaign: + raise EntityNotFoundError("Campaign", campaign_id) + + campaign_uuid = uuid.UUID(campaign_id) + children = ( + db.query(Campaign) + .filter(Campaign.parent_campaign_id == campaign_uuid) + .order_by(Campaign.created_at.desc()) + .all() + ) + + return { + "campaign_id": str(campaign.id), + "campaign_name": campaign.name, + "items": [ + { + "id": str(child.id), + "name": child.name, + "status": child.status, + "test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(), + "completion_pct": get_campaign_progress(db, child.id)["completion_pct"], + "created_at": child.created_at.isoformat() if child.created_at else None, + "completed_at": child.completed_at.isoformat() if child.completed_at else None, + } + for child in children + ], + }