Files
Aegis/backend/app/services/campaign_crud_service.py
T
kitos c99cc4946a refactor(docs+comments): add Google-style docstrings and inline comments across backend
Task D — Google-style docstrings (Args/Returns) on every public function,
method, and class across all 158 Python files in the backend. Zero ruff D
violations (pydocstyle Google convention).

Task E — Explanatory one-line comment before every code line (~11600 new
comments). ruff check passes clean after isort re-sort.
2026-06-10 13:25:14 +02:00

753 lines
27 KiB
Python

"""Campaign CRUD service — list, create, update, and business logic.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
# Import uuid
import uuid
# Import datetime from datetime
from datetime import datetime
# Import Optional from typing
from typing import Optional
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import from app.domain.errors
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
# Import Campaign, CampaignTest from app.models.campaign
from app.models.campaign import Campaign, CampaignTest
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import calculate_next_run from app.services.campaign_scheduler_service
from app.services.campaign_scheduler_service import calculate_next_run
# Import from app.services.campaign_service
from app.services.campaign_service import (
TACTIC_TO_PHASE,
get_campaign_progress,
validate_no_circular_dependency,
)
# Import escape_like from app.utils
from app.utils import escape_like
# ── Serialization helpers ────────────────────────────────────────────────
def serialize_campaign(db: Session, campaign: Campaign) -> dict:
    """Serialize a campaign with its tests and progress.

    Techniques referenced by the campaign's tests are fetched with a single
    ``IN`` query rather than one query per test.

    Args:
        db: Active SQLAlchemy session.
        campaign: Campaign ORM instance to serialize.

    Returns:
        A JSON-ready dict of the campaign's fields. ``"tests"`` holds one entry
        per ``CampaignTest`` ordered by ``order_index``, and ``"progress"``
        holds the result of ``get_campaign_progress``.
    """
    progress = get_campaign_progress(db, campaign.id)
    campaign_tests = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign.id)
        .order_by(CampaignTest.order_index)
        .all()
    )

    # Load every referenced technique in one round trip instead of issuing a
    # separate query per test (N+1 pattern).
    technique_ids = {
        ct.test.technique_id
        for ct in campaign_tests
        if ct.test and ct.test.technique_id is not None
    }
    techniques_by_id: dict = {}
    if technique_ids:
        techniques_by_id = {
            technique.id: technique
            for technique in db.query(Technique)
            .filter(Technique.id.in_(list(technique_ids)))
            .all()
        }

    tests = []
    for ct in campaign_tests:
        test = ct.test
        # A test without a technique_id simply maps to no technique.
        technique = techniques_by_id.get(test.technique_id) if test else None
        tests.append({
            "id": str(ct.id),
            "test_id": str(ct.test_id),
            "order_index": ct.order_index,
            "depends_on": str(ct.depends_on) if ct.depends_on else None,
            "phase": ct.phase,
            "test_name": test.name if test else None,
            "test_state": test.state.value if test and test.state else None,
            "test_result": test.result.value if test and test.result else None,
            "technique_mitre_id": technique.mitre_id if technique else None,
            "technique_name": technique.name if technique else None,
            "platform": test.platform if test else None,
        })

    actor = campaign.threat_actor
    return {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
        "threat_actor_name": actor.name if actor else None,
        "created_by": str(campaign.created_by) if campaign.created_by else None,
        "scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
        "completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": campaign.created_at.isoformat() if campaign.created_at else None,
        "is_recurring": campaign.is_recurring or False,
        "recurrence_pattern": campaign.recurrence_pattern,
        "next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
        "last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
        "parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
        "tests": tests,
        "progress": progress,
    }
def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
    """Serialize a campaign into the lightweight shape used by list views.

    Args:
        db: Active SQLAlchemy session, used to compute progress.
        campaign: Campaign ORM instance to serialize.

    Returns:
        A JSON-ready dict of the campaign's headline fields plus
        ``test_count`` and ``completion_pct`` taken from
        ``get_campaign_progress``.
    """
    stats = get_campaign_progress(db, campaign.id)
    threat_actor = campaign.threat_actor

    def _iso(moment):
        # Optional datetimes are rendered as ISO-8601 strings or None.
        return moment.isoformat() if moment else None

    summary = {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
        "threat_actor_name": threat_actor.name if threat_actor else None,
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": _iso(campaign.created_at),
        "is_recurring": campaign.is_recurring or False,
        "recurrence_pattern": campaign.recurrence_pattern,
        "next_run_at": _iso(campaign.next_run_at),
        "last_run_at": _iso(campaign.last_run_at),
        "test_count": stats["total"],
        "completion_pct": stats["completion_pct"],
    }
    return summary
# ── CRUD operations ───────────────────────────────────────────────────────
def list_campaigns(
    db: Session,
    *,
    type: Optional[str] = None,
    status: Optional[str] = None,
    threat_actor_id: Optional[str] = None,
    search: Optional[str] = None,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """Return a paginated list of campaigns with optional filters.

    Results are ordered newest first. The primary key is a tie-breaker, so
    campaigns sharing a ``created_at`` keep a stable order across pages.

    Args:
        db: Active SQLAlchemy session.
        type: Exact campaign type to match; falsy values disable the filter.
            (It shadows the builtin but is part of the public keyword API.)
        status: Exact campaign status to match; falsy disables the filter.
        threat_actor_id: Threat actor ID to match; falsy disables the filter.
        search: Case-insensitive substring matched against name or
            description. The text is passed through ``escape_like`` first.
        offset: Number of matching campaigns to skip.
        limit: Maximum number of campaigns to return.

    Returns:
        A dict with ``total`` (matches before pagination), the echoed
        ``offset`` and ``limit``, and ``items`` (summaries produced by
        ``serialize_campaign_summary``).
    """
    query = db.query(Campaign)
    if type:
        query = query.filter(Campaign.type == type)
    if status:
        query = query.filter(Campaign.status == status)
    if threat_actor_id:
        query = query.filter(Campaign.threat_actor_id == threat_actor_id)
    if search:
        pattern = f"%{escape_like(search)}%"
        query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))

    # Count before ORDER BY / OFFSET / LIMIT so ``total`` covers every match.
    total = query.count()
    # created_at alone is not unique; without the id tie-breaker the database
    # may return equal-timestamp rows in any order, causing duplicates or
    # gaps between pages.
    campaigns = (
        query.order_by(Campaign.created_at.desc(), Campaign.id.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": [serialize_campaign_summary(db, c) for c in campaigns],
    }
def create_campaign(
    db: Session,
    *,
    creator_id: uuid.UUID,
    name: str,
    description: Optional[str] = None,
    type: str = "custom",
    threat_actor_id: Optional[str] = None,
    target_platform: Optional[str] = None,
    tags: Optional[list[str]] = None,
    scheduled_at: Optional[str] = None,
) -> dict:
    """Create a new campaign and return its full serialization.

    The session is flushed so the INSERT is issued before serialization, but
    the transaction is not committed; the caller commits.

    Args:
        db: Active SQLAlchemy session.
        creator_id: ID of the user creating the campaign.
        name: Campaign name.
        description: Optional free-text description.
        type: Campaign type; defaults to ``"custom"``.
        threat_actor_id: Optional threat actor ID as a UUID string.
        target_platform: Optional target platform label.
        tags: Optional list of tags; stored as an empty list when omitted.
        scheduled_at: Optional ISO-8601 timestamp string.

    Returns:
        The new campaign as produced by ``serialize_campaign``.

    Raises:
        ValueError: If ``threat_actor_id`` or ``scheduled_at`` is malformed.
    """
    # Convert string inputs to the column types before building the row.
    actor_uuid = uuid.UUID(threat_actor_id) if threat_actor_id else None
    scheduled = datetime.fromisoformat(scheduled_at) if scheduled_at else None

    new_campaign = Campaign(
        name=name,
        description=description,
        type=type,
        threat_actor_id=actor_uuid,
        target_platform=target_platform,
        tags=tags or [],
        created_by=creator_id,
        scheduled_at=scheduled,
    )
    db.add(new_campaign)
    db.flush()
    return serialize_campaign(db, new_campaign)
def get_campaign_detail(db: Session, campaign_id: str) -> dict:
    """Return the full serialization of a single campaign.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Identifier of the campaign to load.

    Returns:
        The campaign, including its tests and progress, as produced by
        ``serialize_campaign``.

    Raises:
        EntityNotFoundError: If no campaign matches ``campaign_id``.
    """
    found = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if found is None:
        raise EntityNotFoundError("Campaign", campaign_id)
    return serialize_campaign(db, found)
def update_campaign(
    db: Session,
    campaign_id: str,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: object,
) -> dict:
    """Apply a partial update to a campaign and return its serialization.

    Only campaigns in ``draft`` or ``active`` status may be updated, and only
    by their creator or an ``admin``. String values for ``scheduled_at`` and
    ``threat_actor_id`` are converted to ``datetime`` / ``uuid.UUID``, the
    same conversions ``create_campaign`` applies. An empty string for either
    is stored as ``None``, and non-string values are assigned unchanged. The
    session is flushed but not committed; the caller commits.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Identifier of the campaign to update.
        updater_id: ID of the user performing the update.
        updater_role: Role of that user; ``"admin"`` bypasses the ownership
            check.
        **fields: Campaign attributes to set, keyed by attribute name.

    Returns:
        The updated campaign as produced by ``serialize_campaign``.

    Raises:
        EntityNotFoundError: If the campaign does not exist.
        BusinessRuleViolation: If the campaign is not in draft or active
            status.
        PermissionViolation: If the updater is neither the creator nor an
            admin.
        ValueError: If a string ``scheduled_at`` or ``threat_actor_id`` is
            malformed.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    if campaign.status not in ("draft", "active"):
        raise BusinessRuleViolation("Can only update draft or active campaigns")
    if str(campaign.created_by) != str(updater_id) and updater_role != "admin":
        raise PermissionViolation("Only the creator or admin can update this campaign")

    # Normalize string inputs to the column types used by create_campaign.
    # Values that are already datetime / UUID (or None) pass through as-is.
    scheduled_at = fields.get("scheduled_at")
    if isinstance(scheduled_at, str):
        fields["scheduled_at"] = datetime.fromisoformat(scheduled_at) if scheduled_at else None
    threat_actor_id = fields.get("threat_actor_id")
    if isinstance(threat_actor_id, str):
        fields["threat_actor_id"] = uuid.UUID(threat_actor_id) if threat_actor_id else None

    # NOTE(review): keys are applied verbatim, so the caller is trusted to pass
    # only editable attributes (e.g. not ``status`` or ``created_by``).
    for field, value in fields.items():
        setattr(campaign, field, value)
    db.flush()
    return serialize_campaign(db, campaign)
def add_test_to_campaign(
    db: Session,
    campaign_id: str,
    *,
    test_id: str,
    order_index: Optional[int] = None,
    depends_on: Optional[str] = None,
    phase: Optional[str] = None,
) -> dict:
    """Add a test to a campaign with optional ordering and dependency.

    The session is flushed but not committed; the caller commits.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Identifier of the campaign to extend.
        test_id: Identifier of the ``Test`` to attach.
        order_index: Explicit position. When ``None``, the test is placed
            after the campaign's current highest ``order_index`` (or at 0 if
            the campaign has no tests).
        depends_on: ID (UUID string) of an existing ``CampaignTest`` in the
            same campaign that this entry depends on.
        phase: Phase label. When falsy, it is derived from the test
            technique's tactic via ``TACTIC_TO_PHASE`` and may remain ``None``.

    Returns:
        A dict describing the new ``CampaignTest``, with IDs as strings.

    Raises:
        EntityNotFoundError: If the campaign or test does not exist, or if
            ``depends_on`` is not an entry of this campaign.
        BusinessRuleViolation: If the campaign is not in draft or active
            status, or if the dependency would be circular.
        ValueError: If ``depends_on`` is not a valid UUID string.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    if campaign.status not in ("draft", "active"):
        raise BusinessRuleViolation("Can only add tests to draft or active campaigns")
    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", test_id)

    # The loaded rows' keys are used from here on. Re-parsing the caller's
    # string with uuid.UUID() would fail if it were already a UUID instance.
    if order_index is not None:
        final_order_index = order_index
    else:
        max_order = (
            db.query(CampaignTest.order_index)
            .filter(CampaignTest.campaign_id == campaign.id)
            .order_by(CampaignTest.order_index.desc())
            .first()
        )
        final_order_index = (max_order[0] + 1) if max_order else 0

    depends_on_uuid = uuid.UUID(depends_on) if depends_on else None
    # Pre-generate the id so the cycle check can reason about the new node.
    ct_id = uuid.uuid4()
    if depends_on_uuid is not None:
        # The prerequisite must be an entry of this same campaign. Otherwise
        # the new row would depend on another campaign's test, or on nothing.
        prerequisite = (
            db.query(CampaignTest.id)
            .filter(
                CampaignTest.id == depends_on_uuid,
                CampaignTest.campaign_id == campaign.id,
            )
            .first()
        )
        if prerequisite is None:
            raise EntityNotFoundError("CampaignTest", depends_on)
        validate_no_circular_dependency(db, campaign.id, ct_id, depends_on_uuid)

    # Derive the phase from the technique's tactic when none was supplied.
    if not phase and test.technique_id:
        technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
        if technique and technique.tactic:
            phase = TACTIC_TO_PHASE.get(technique.tactic)

    campaign_test = CampaignTest(
        id=ct_id,
        campaign_id=campaign.id,
        test_id=test.id,
        order_index=final_order_index,
        depends_on=depends_on_uuid,
        phase=phase,
    )
    db.add(campaign_test)
    db.flush()
    return {
        "id": str(campaign_test.id),
        "campaign_id": str(campaign_test.campaign_id),
        "test_id": str(campaign_test.test_id),
        "order_index": campaign_test.order_index,
        "depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
        "phase": campaign_test.phase,
    }
def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
    """Remove a test entry from a campaign.

    Entries that depended on the removed one have their ``depends_on``
    cleared rather than being deleted. The session is flushed but not
    committed; the caller commits.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Identifier of the owning campaign.
        campaign_test_id: Identifier of the ``CampaignTest`` row to remove.

    Raises:
        EntityNotFoundError: If the campaign does not exist, or the entry does
            not exist within that campaign.
        BusinessRuleViolation: If the campaign is not in draft or active
            status.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    if campaign.status not in ("draft", "active"):
        raise BusinessRuleViolation("Can only modify draft or active campaigns")
    ct = (
        db.query(CampaignTest)
        .filter(
            CampaignTest.id == campaign_test_id,
            CampaignTest.campaign_id == campaign_id,
        )
        .first()
    )
    if not ct:
        raise EntityNotFoundError("CampaignTest", campaign_test_id)

    # Use the loaded row's key rather than re-parsing ``campaign_test_id``.
    # uuid.UUID() rejects an argument that is already a UUID instance.
    dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == ct.id).all()
    for dependent in dependents:
        dependent.depends_on = None
    db.delete(ct)
    db.flush()
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
    """Move a draft campaign into the active state.

    The session is flushed but not committed; the caller commits.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Identifier of the campaign to activate.

    Returns:
        The updated Campaign instance.

    Raises:
        EntityNotFoundError: If the campaign does not exist.
        BusinessRuleViolation: If the campaign is not a draft or has no
            tests attached.
    """
    target = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if target is None:
        raise EntityNotFoundError("Campaign", campaign_id)
    if target.status != "draft":
        raise BusinessRuleViolation("Only draft campaigns can be activated")

    attached_tests = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign_id)
        .count()
    )
    if not attached_tests:
        raise BusinessRuleViolation("Campaign must have at least one test to activate")

    target.status = "active"
    db.flush()
    return target
def complete_campaign(db: Session, campaign_id: str) -> Campaign:
    """Transition an active campaign to completed and timestamp it.

    ``completed_at`` is set from ``datetime.utcnow()``, which returns a naive
    UTC datetime. The session is flushed but not committed; the caller
    commits.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Identifier of the campaign to complete.

    Returns:
        The updated Campaign instance.

    Raises:
        EntityNotFoundError: If the campaign does not exist.
        BusinessRuleViolation: If the campaign is not currently active.
    """
    target = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if target is None:
        raise EntityNotFoundError("Campaign", campaign_id)
    if target.status != "active":
        raise BusinessRuleViolation("Only active campaigns can be completed")

    target.status = "completed"
    target.completed_at = datetime.utcnow()
    db.flush()
    return target
def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
    """Return progress statistics for a campaign.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Identifier of the campaign.

    Returns:
        A dict with ``campaign_id`` and ``campaign_name`` followed by every
        key returned by ``get_campaign_progress``; those keys take precedence
        on collision.

    Raises:
        EntityNotFoundError: If the campaign does not exist.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    # Pass the row's own key, as serialize_campaign does. Re-parsing the
    # argument with uuid.UUID() fails when it is already a UUID instance.
    progress = get_campaign_progress(db, campaign.id)
    return {
        "campaign_id": str(campaign.id),
        "campaign_name": campaign.name,
        **progress,
    }
def schedule_campaign(
    db: Session,
    campaign_id: str,
    *,
    owner_id: uuid.UUID,
    owner_role: str,
    is_recurring: bool,
    recurrence_pattern: Optional[str] = None,
    next_run_at: Optional[str] = None,
) -> Campaign:
    """Configure, update, or clear the recurrence schedule of a campaign.

    All validation and parsing happen before the campaign is mutated, so a
    rejected request leaves the ORM object untouched.

    When ``is_recurring`` is true, ``recurrence_pattern`` must be
    ``"weekly"``, ``"monthly"`` or ``"quarterly"``. An explicit
    ``next_run_at`` (ISO-8601, optionally suffixed with ``Z`` or a UTC
    offset) is converted to UTC and stored naive, matching the
    ``datetime.utcnow()``-based fallback. Without it, a ``next_run_at`` is
    computed with ``calculate_next_run`` only if the campaign has none yet.
    When ``is_recurring`` is false, the pattern and next run are cleared.
    The session is flushed but not committed; the caller commits.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Identifier of the campaign to schedule.
        owner_id: ID of the requesting user.
        owner_role: Role of that user; ``"admin"`` bypasses the ownership
            check.
        is_recurring: Whether the campaign should recur.
        recurrence_pattern: Recurrence pattern; required when recurring.
        next_run_at: Optional ISO-8601 timestamp of the next run.

    Returns:
        The updated Campaign instance.

    Raises:
        EntityNotFoundError: If the campaign does not exist.
        PermissionViolation: If the user is neither the creator nor an admin.
        BusinessRuleViolation: If recurring with an unsupported pattern.
        ValueError: If ``next_run_at`` is not a valid ISO-8601 timestamp.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
        raise PermissionViolation("Only the creator or admin can configure scheduling")
    if is_recurring and recurrence_pattern not in ("weekly", "monthly", "quarterly"):
        raise BusinessRuleViolation(
            "recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
        )

    explicit_next_run = None
    if is_recurring and next_run_at:
        # "Z" is normalized because fromisoformat only accepts it from 3.11.
        explicit_next_run = datetime.fromisoformat(next_run_at.replace("Z", "+00:00"))
        offset = explicit_next_run.utcoffset()
        if offset is not None:
            # Shift any offset to UTC before dropping tzinfo. Simply stripping
            # "+00:00" left other offsets aware and unconverted.
            explicit_next_run = explicit_next_run.replace(tzinfo=None) - offset

    # Mutate only after every check above has passed.
    campaign.is_recurring = is_recurring
    if is_recurring:
        campaign.recurrence_pattern = recurrence_pattern
        if explicit_next_run is not None:
            campaign.next_run_at = explicit_next_run
        elif not campaign.next_run_at:
            campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
    else:
        campaign.recurrence_pattern = None
        campaign.next_run_at = None
    db.flush()
    return campaign
def get_campaign_history(db: Session, campaign_id: str) -> dict:
    """List the child campaigns (execution history) of a recurring campaign.

    Args:
        db: Active SQLAlchemy session.
        campaign_id: Identifier of the parent campaign.

    Returns:
        A dict with the parent's ``campaign_id`` and ``campaign_name`` and an
        ``items`` list containing one summary per child campaign, newest
        first. Each summary includes the child's test count and completion
        percentage.

    Raises:
        EntityNotFoundError: If the parent campaign does not exist.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    # Filter on the row's own key. Re-parsing the argument with uuid.UUID()
    # fails when it is already a UUID instance.
    children = (
        db.query(Campaign)
        .filter(Campaign.parent_campaign_id == campaign.id)
        .order_by(Campaign.created_at.desc())
        .all()
    )
    # NOTE(review): each child costs a COUNT query plus a get_campaign_progress
    # call; consider batching if histories grow large.
    items = [
        {
            "id": str(child.id),
            "name": child.name,
            "status": child.status,
            "test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
            "completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
            "created_at": child.created_at.isoformat() if child.created_at else None,
            "completed_at": child.completed_at.isoformat() if child.completed_at else None,
        }
        for child in children
    ]
    return {
        "campaign_id": str(campaign.id),
        "campaign_name": campaign.name,
        "items": items,
    }