"""Campaign CRUD service — list, create, update, and business logic.

Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""

import uuid
from datetime import datetime, timezone
from typing import Optional, Union

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.domain.errors import (
    BusinessRuleViolation,
    EntityNotFoundError,
    PermissionViolation,
)
from app.models.campaign import Campaign, CampaignTest
from app.models.technique import Technique
from app.models.test import Test
from app.services.campaign_scheduler_service import calculate_next_run
from app.services.campaign_service import (
    TACTIC_TO_PHASE,
    get_campaign_progress,
    validate_no_circular_dependency,
)
from app.utils import escape_like

# Statuses in which a campaign's fields and test list may still be modified.
_EDITABLE_STATUSES = ("draft", "active")

# Recurrence patterns accepted by schedule_campaign.
_RECURRENCE_PATTERNS = ("weekly", "monthly", "quarterly")


# ── Internal helpers ──────────────────────────────────────────────────────


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Produces the same value as the deprecated ``datetime.utcnow()``; this
    module works with naive-UTC datetimes throughout.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _iso(value: Optional[datetime]) -> Optional[str]:
    """Return ``value.isoformat()``, or None when *value* is falsy."""
    return value.isoformat() if value else None


def _as_uuid(value: Union[str, uuid.UUID]) -> uuid.UUID:
    """Coerce a UUID or its string form to ``uuid.UUID``.

    Raises:
        ValueError: if *value* is a string that is not a valid UUID.
    """
    return value if isinstance(value, uuid.UUID) else uuid.UUID(str(value))


def _parse_iso_datetime(value: Union[str, datetime]) -> datetime:
    """Parse an ISO-8601 string (or accept a datetime) as a naive UTC datetime.

    A trailing ``Z`` is accepted as UTC (``datetime.fromisoformat`` only
    understands it from Python 3.11). Offset-aware values are converted to
    UTC before the tzinfo is dropped, so e.g. ``"10:00+02:00"`` becomes
    08:00 rather than keeping a foreign offset that cannot be compared with
    the naive-UTC values produced by ``_utcnow()``.

    Raises:
        ValueError: if *value* is not a valid ISO-8601 string.
    """
    if isinstance(value, datetime):
        parsed = value
    else:
        text = value
        if isinstance(text, str) and text.endswith("Z"):
            text = text[:-1] + "+00:00"
        parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed


def _get_campaign_or_raise(db: Session, campaign_id: str) -> Campaign:
    """Load a campaign by id or raise ``EntityNotFoundError("Campaign", id)``."""
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    return campaign


# ── Serialization helpers ────────────────────────────────────────────────


def serialize_campaign(db: Session, campaign: Campaign) -> dict:
    """Serialize a campaign with its tests and progress.

    Args:
        db: Active SQLAlchemy session.
        campaign: Campaign to serialize.

    Returns:
        A JSON-ready dict of campaign fields (ids as strings, datetimes as
        ISO-8601 strings), a ``tests`` list ordered by ``order_index``, and
        the ``progress`` dict returned by ``get_campaign_progress``.
    """
    progress = get_campaign_progress(db, campaign.id)
    campaign_tests = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign.id)
        .order_by(CampaignTest.order_index)
        .all()
    )

    # Several tests may share a technique; query each technique only once.
    techniques_by_id: dict = {}
    tests = []
    for ct in campaign_tests:
        test = ct.test
        technique = None
        if test:
            technique_id = test.technique_id
            if technique_id not in techniques_by_id:
                techniques_by_id[technique_id] = (
                    db.query(Technique).filter(Technique.id == technique_id).first()
                )
            technique = techniques_by_id[technique_id]

        tests.append({
            "id": str(ct.id),
            "test_id": str(ct.test_id),
            "order_index": ct.order_index,
            "depends_on": str(ct.depends_on) if ct.depends_on else None,
            "phase": ct.phase,
            "test_name": test.name if test else None,
            "test_state": test.state.value if test and test.state else None,
            "test_result": test.result.value if test and test.result else None,
            "technique_mitre_id": technique.mitre_id if technique else None,
            "technique_name": technique.name if technique else None,
            "platform": test.platform if test else None,
        })

    actor = campaign.threat_actor
    return {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
        "threat_actor_name": actor.name if actor else None,
        "created_by": str(campaign.created_by) if campaign.created_by else None,
        "scheduled_at": _iso(campaign.scheduled_at),
        "completed_at": _iso(campaign.completed_at),
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": _iso(campaign.created_at),
        "is_recurring": campaign.is_recurring or False,
        "recurrence_pattern": campaign.recurrence_pattern,
        "next_run_at": _iso(campaign.next_run_at),
        "last_run_at": _iso(campaign.last_run_at),
        "parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
        "tests": tests,
        "progress": progress,
    }


def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
    """Lightweight campaign serialization for list views.

    Omits the per-test list; reports ``test_count`` and ``completion_pct``
    taken from ``get_campaign_progress``.
    """
    progress = get_campaign_progress(db, campaign.id)
    actor = campaign.threat_actor
    return {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
        "threat_actor_name": actor.name if actor else None,
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": _iso(campaign.created_at),
        "is_recurring": campaign.is_recurring or False,
        "recurrence_pattern": campaign.recurrence_pattern,
        "next_run_at": _iso(campaign.next_run_at),
        "last_run_at": _iso(campaign.last_run_at),
        "test_count": progress["total"],
        "completion_pct": progress["completion_pct"],
    }


# ── CRUD operations ───────────────────────────────────────────────────────


def list_campaigns(
    db: Session,
    *,
    type: Optional[str] = None,
    status: Optional[str] = None,
    threat_actor_id: Optional[str] = None,
    search: Optional[str] = None,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """Return a paginated list of campaigns with optional filters.

    Args:
        db: Active SQLAlchemy session.
        type, status, threat_actor_id: Exact-match filters; falsy values are
            ignored.
        search: Case-insensitive substring matched against name or
            description (LIKE wildcards in the input are escaped).
        offset, limit: Pagination window, applied after ordering by
            ``created_at`` descending.

    Returns:
        ``{"total", "offset", "limit", "items"}`` where ``total`` counts all
        matching rows and ``items`` holds summary dicts for the window.
    """
    query = db.query(Campaign)
    if type:
        query = query.filter(Campaign.type == type)
    if status:
        query = query.filter(Campaign.status == status)
    if threat_actor_id:
        query = query.filter(Campaign.threat_actor_id == threat_actor_id)
    if search:
        pattern = f"%{escape_like(search)}%"
        query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))

    total = query.count()
    campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": [serialize_campaign_summary(db, c) for c in campaigns],
    }


def create_campaign(
    db: Session,
    *,
    creator_id: uuid.UUID,
    name: str,
    description: Optional[str] = None,
    type: str = "custom",
    threat_actor_id: Optional[str] = None,
    target_platform: Optional[str] = None,
    tags: Optional[list[str]] = None,
    scheduled_at: Optional[str] = None,
) -> dict:
    """Create a new campaign. Does not commit; caller commits.

    ``threat_actor_id`` may be a UUID or its string form. ``scheduled_at`` is
    an ISO-8601 string and is stored as naive UTC.

    Raises:
        ValueError: for a malformed ``threat_actor_id`` or ``scheduled_at``.
    """
    campaign = Campaign(
        name=name,
        description=description,
        type=type,
        threat_actor_id=_as_uuid(threat_actor_id) if threat_actor_id else None,
        target_platform=target_platform,
        tags=tags or [],
        created_by=creator_id,
        scheduled_at=_parse_iso_datetime(scheduled_at) if scheduled_at else None,
    )
    db.add(campaign)
    # Flush so server-side defaults (id, created_at) exist before serializing.
    db.flush()
    return serialize_campaign(db, campaign)


def get_campaign_detail(db: Session, campaign_id: str) -> dict:
    """Get detailed campaign info including tests and progress.

    Raises EntityNotFoundError if campaign not found.
    """
    return serialize_campaign(db, _get_campaign_or_raise(db, campaign_id))


def update_campaign(
    db: Session,
    campaign_id: str,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: object,
) -> dict:
    """Update a campaign. Only allowed in draft or active state.

    Only the creator or an ``"admin"`` may update. ``scheduled_at`` (ISO-8601
    string or datetime) is stored as naive UTC, and ``threat_actor_id`` is
    normalised to a UUID (or None when empty), matching create_campaign.

    Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
    Raises ValueError for a malformed ``scheduled_at`` or ``threat_actor_id``.
    Does not commit; caller commits.
    """
    campaign = _get_campaign_or_raise(db, campaign_id)
    if campaign.status not in _EDITABLE_STATUSES:
        raise BusinessRuleViolation("Can only update draft or active campaigns")
    if str(campaign.created_by) != str(updater_id) and updater_role != "admin":
        raise PermissionViolation("Only the creator or admin can update this campaign")

    # Convert incoming values before touching the ORM object, so a parse
    # error leaves the campaign unmodified in the session.
    if fields.get("scheduled_at"):
        fields["scheduled_at"] = _parse_iso_datetime(fields["scheduled_at"])
    if "threat_actor_id" in fields:
        actor_id = fields["threat_actor_id"]
        fields["threat_actor_id"] = _as_uuid(actor_id) if actor_id else None

    # NOTE(review): every key in ``fields`` is applied verbatim with setattr;
    # the caller is expected to restrict keys to user-editable columns.
    for field, value in fields.items():
        setattr(campaign, field, value)
    db.flush()
    return serialize_campaign(db, campaign)


def add_test_to_campaign(
    db: Session,
    campaign_id: str,
    *,
    test_id: str,
    order_index: Optional[int] = None,
    depends_on: Optional[str] = None,
    phase: Optional[str] = None,
) -> dict:
    """Add a test to a campaign with optional ordering and dependency.

    When ``order_index`` is omitted the test is appended after the highest
    existing index (or at 0 for an empty campaign). When ``phase`` is omitted
    it is derived from the test's technique tactic via ``TACTIC_TO_PHASE``.

    Raises EntityNotFoundError for missing campaign or test.
    Raises BusinessRuleViolation for invalid state or circular dependency.
    Does not commit; caller commits.
    """
    campaign = _get_campaign_or_raise(db, campaign_id)
    if campaign.status not in _EDITABLE_STATUSES:
        raise BusinessRuleViolation("Can only add tests to draft or active campaigns")

    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", test_id)

    if order_index is not None:
        final_order_index = order_index
    else:
        # MAX() ignores NULLs, unlike "ORDER BY ... DESC LIMIT 1", where a NULL
        # order_index would sort first on PostgreSQL and break the "+ 1".
        max_order = (
            db.query(func.max(CampaignTest.order_index))
            .filter(CampaignTest.campaign_id == campaign_id)
            .scalar()
        )
        final_order_index = 0 if max_order is None else max_order + 1

    depends_on_uuid = _as_uuid(depends_on) if depends_on else None
    # Pre-generate the id so the cycle check can reason about the new node.
    ct_id = uuid.uuid4()
    if depends_on_uuid:
        validate_no_circular_dependency(db, _as_uuid(campaign_id), ct_id, depends_on_uuid)

    if not phase and test.technique_id:
        technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
        if technique and technique.tactic:
            phase = TACTIC_TO_PHASE.get(technique.tactic, None)

    campaign_test = CampaignTest(
        id=ct_id,
        campaign_id=campaign_id,
        test_id=test_id,
        order_index=final_order_index,
        depends_on=depends_on_uuid,
        phase=phase,
    )
    db.add(campaign_test)
    db.flush()
    return {
        "id": str(campaign_test.id),
        "campaign_id": str(campaign_test.campaign_id),
        "test_id": str(campaign_test.test_id),
        "order_index": campaign_test.order_index,
        "depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
        "phase": campaign_test.phase,
    }


def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
    """Remove a test from a campaign.

    Any campaign tests that depended on the removed one have their
    ``depends_on`` cleared rather than being deleted.

    Raises EntityNotFoundError for missing campaign or campaign test.
    Raises BusinessRuleViolation for invalid state.
    Does not commit; caller commits.
    """
    campaign = _get_campaign_or_raise(db, campaign_id)
    if campaign.status not in _EDITABLE_STATUSES:
        raise BusinessRuleViolation("Can only modify draft or active campaigns")

    ct = (
        db.query(CampaignTest)
        .filter(
            CampaignTest.id == campaign_test_id,
            CampaignTest.campaign_id == campaign_test_id and campaign_id or campaign_id,
        )
        .first()
    ) if False else (
        db.query(CampaignTest)
        .filter(
            CampaignTest.id == campaign_test_id,
            CampaignTest.campaign_id == campaign_id,
        )
        .first()
    )
    if not ct:
        raise EntityNotFoundError("CampaignTest", campaign_test_id)

    # Detach dependents first so no row is left pointing at a deleted test.
    dep_id = _as_uuid(campaign_test_id)
    dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
    for dep in dependents:
        dep.depends_on = None

    db.delete(ct)
    db.flush()


def activate_campaign(db: Session, campaign_id: str) -> Campaign:
    """Activate a campaign, moving it from draft to active.

    The campaign must contain at least one test.

    Raises EntityNotFoundError, BusinessRuleViolation.
    Does not commit; caller commits.
    """
    campaign = _get_campaign_or_raise(db, campaign_id)
    if campaign.status != "draft":
        raise BusinessRuleViolation("Only draft campaigns can be activated")

    test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
    if test_count == 0:
        raise BusinessRuleViolation("Campaign must have at least one test to activate")

    campaign.status = "active"
    db.flush()
    return campaign


def complete_campaign(db: Session, campaign_id: str) -> Campaign:
    """Mark a campaign as completed, stamping ``completed_at`` (naive UTC).

    Raises EntityNotFoundError, BusinessRuleViolation.
    Does not commit; caller commits.
    """
    campaign = _get_campaign_or_raise(db, campaign_id)
    if campaign.status != "active":
        raise BusinessRuleViolation("Only active campaigns can be completed")

    campaign.status = "completed"
    campaign.completed_at = _utcnow()
    db.flush()
    return campaign


def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
    """Get progress statistics for a campaign.

    Returns ``campaign_id`` and ``campaign_name`` merged with the dict from
    ``get_campaign_progress``.

    Raises EntityNotFoundError if campaign not found.
    """
    campaign = _get_campaign_or_raise(db, campaign_id)
    progress = get_campaign_progress(db, _as_uuid(campaign_id))
    return {
        "campaign_id": str(campaign.id),
        "campaign_name": campaign.name,
        **progress,
    }


def schedule_campaign(
    db: Session,
    campaign_id: str,
    *,
    owner_id: uuid.UUID,
    owner_role: str,
    is_recurring: bool,
    recurrence_pattern: Optional[str] = None,
    next_run_at: Optional[str] = None,
) -> Campaign:
    """Configure or update the recurrence schedule for a campaign.

    When ``is_recurring`` is true, ``recurrence_pattern`` must be one of
    ``_RECURRENCE_PATTERNS``. ``next_run_at`` (ISO-8601; ``Z`` or any offset
    accepted) is stored as naive UTC. If it is omitted and the campaign has
    no next run yet, one is computed from now. Turning recurrence off clears
    both the pattern and the next run.

    All validation and parsing happen before the campaign is modified, so a
    rejected request leaves the ORM object unchanged.

    Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation.
    Raises ValueError for a malformed ``next_run_at``.
    Does not commit; caller commits.
    """
    campaign = _get_campaign_or_raise(db, campaign_id)
    if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
        raise PermissionViolation("Only the creator or admin can configure scheduling")

    parsed_next_run: Optional[datetime] = None
    if is_recurring:
        if recurrence_pattern not in _RECURRENCE_PATTERNS:
            raise BusinessRuleViolation(
                "recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
            )
        if next_run_at:
            parsed_next_run = _parse_iso_datetime(next_run_at)

    campaign.is_recurring = is_recurring
    if is_recurring:
        campaign.recurrence_pattern = recurrence_pattern
        if parsed_next_run is not None:
            campaign.next_run_at = parsed_next_run
        elif not campaign.next_run_at:
            campaign.next_run_at = calculate_next_run(_utcnow(), recurrence_pattern)
    else:
        campaign.recurrence_pattern = None
        campaign.next_run_at = None

    db.flush()
    return campaign


def get_campaign_history(db: Session, campaign_id: str) -> dict:
    """List all child campaigns (execution history) of a recurring campaign.

    Children are those whose ``parent_campaign_id`` is this campaign,
    newest first.

    Raises EntityNotFoundError if campaign not found.
    """
    campaign = _get_campaign_or_raise(db, campaign_id)
    children = (
        db.query(Campaign)
        .filter(Campaign.parent_campaign_id == _as_uuid(campaign_id))
        .order_by(Campaign.created_at.desc())
        .all()
    )
    return {
        "campaign_id": str(campaign.id),
        "campaign_name": campaign.name,
        "items": [
            {
                "id": str(child.id),
                "name": child.name,
                "status": child.status,
                "test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
                "completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
                "created_at": _iso(child.created_at),
                "completed_at": _iso(child.completed_at),
            }
            for child in children
        ],
    }