Files
Aegis/backend/app/services/campaign_crud_service.py
kitos c62dafbc1f
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
feat(campaigns): campaign start date — scheduled activation, Jira start_date
DB: migration b047 adds start_date (DateTime nullable) + index to campaigns.

Backend:
- Campaign model: start_date field
- CampaignCreate/Update schemas: accept start_date (ISO string)
- CRUD service: persist + serialize start_date in both serializers
- Activation endpoint: blocks manual activation if start_date is in the future
  (campaign will auto-activate via scheduler)
- Scheduler: new hourly job _run_scheduled_campaign_activation — finds draft
  campaigns with start_date <= now, activates them, creates Jira tickets,
  notifies red_tech team
- Jira: campaign + test tickets now include JIRA_START_DATE_FIELD (configurable,
  default customfield_10015). Campaign uses start_date if set, else created_at.
  Tests inherit campaign start_date.
- config.py: JIRA_START_DATE_FIELD setting

Frontend:
- Campaign type: start_date field on Campaign + CampaignSummary
- CampaignCreatePayload: start_date optional field
- Create form: date picker with min=today, warning message explaining behavior
- Campaign detail header: start_date badge showing days remaining or started date

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-03 16:57:06 +02:00

553 lines
19 KiB
Python

"""Campaign CRUD service — list, create, update, and business logic.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
import uuid
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy.orm import Session

from app.domain.errors import (
    BusinessRuleViolation,
    EntityNotFoundError,
    PermissionViolation,
)
from app.models.campaign import Campaign, CampaignTest
from app.models.test import Test
from app.models.technique import Technique
from app.utils import escape_like
from app.services.campaign_service import (
    get_campaign_progress,
    validate_no_circular_dependency,
    TACTIC_TO_PHASE,
)
from app.services.campaign_scheduler_service import calculate_next_run
from app.services.status_service import recalculate_technique_status
# ── Serialization helpers ────────────────────────────────────────────────
def serialize_campaign(db: Session, campaign: Campaign) -> dict:
    """Serialize a campaign together with its ordered tests and progress.

    Args:
        db: Session used to load the campaign's tests and their techniques.
        campaign: The campaign to serialize.

    Returns:
        A dict of the campaign's fields with ids rendered as strings and
        datetimes as ISO-8601 strings (``None`` when unset), plus ``tests``
        (one entry per ``CampaignTest`` row, ordered by ``order_index``) and
        ``progress`` (the value returned by ``get_campaign_progress``).
    """

    def _iso(moment):
        return moment.isoformat() if moment else None

    progress = get_campaign_progress(db, campaign.id)
    campaign_tests = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign.id)
        .order_by(CampaignTest.order_index)
        .all()
    )

    # Load every referenced technique with a single query rather than issuing
    # one lookup per campaign test.
    technique_ids = {
        ct.test.technique_id
        for ct in campaign_tests
        if ct.test and ct.test.technique_id is not None
    }
    techniques_by_id = {}
    if technique_ids:
        techniques_by_id = {
            technique.id: technique
            for technique in db.query(Technique)
            .filter(Technique.id.in_(technique_ids))
            .all()
        }

    tests = []
    for ct in campaign_tests:
        test = ct.test
        technique = techniques_by_id.get(test.technique_id) if test else None
        tests.append({
            "id": str(ct.id),
            "test_id": str(ct.test_id),
            "order_index": ct.order_index,
            "depends_on": str(ct.depends_on) if ct.depends_on else None,
            "phase": ct.phase,
            "test_name": test.name if test else None,
            "test_state": test.state.value if test and test.state else None,
            "test_result": test.result.value if test and test.result else None,
            "technique_mitre_id": technique.mitre_id if technique else None,
            "technique_name": technique.name if technique else None,
            "platform": test.platform if test else None,
        })

    actor = campaign.threat_actor
    return {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
        "threat_actor_name": actor.name if actor else None,
        "created_by": str(campaign.created_by) if campaign.created_by else None,
        "start_date": _iso(campaign.start_date),
        "scheduled_at": _iso(campaign.scheduled_at),
        "completed_at": _iso(campaign.completed_at),
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": _iso(campaign.created_at),
        "is_recurring": campaign.is_recurring or False,
        "recurrence_pattern": campaign.recurrence_pattern,
        "next_run_at": _iso(campaign.next_run_at),
        "last_run_at": _iso(campaign.last_run_at),
        "parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
        "tests": tests,
        "progress": progress,
    }
def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
    """Build the lightweight campaign representation used by list views.

    Unlike ``serialize_campaign`` this omits the per-test entries and only
    exposes the test count and completion percentage taken from
    ``get_campaign_progress``.
    """

    def _iso(moment):
        return moment.isoformat() if moment else None

    stats = get_campaign_progress(db, campaign.id)
    threat_actor = campaign.threat_actor
    actor_id = campaign.threat_actor_id

    summary = {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": str(actor_id) if actor_id else None,
        "threat_actor_name": threat_actor.name if threat_actor else None,
        "start_date": _iso(campaign.start_date),
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": _iso(campaign.created_at),
        "is_recurring": campaign.is_recurring or False,
        "recurrence_pattern": campaign.recurrence_pattern,
        "next_run_at": _iso(campaign.next_run_at),
        "last_run_at": _iso(campaign.last_run_at),
        "test_count": stats["total"],
        "completion_pct": stats["completion_pct"],
    }
    return summary
# ── CRUD operations ───────────────────────────────────────────────────────
def list_campaigns(
    db: Session,
    *,
    type: Optional[str] = None,
    status: Optional[str] = None,
    threat_actor_id: Optional[str] = None,
    search: Optional[str] = None,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """Return one page of campaigns, newest first, with optional filters.

    ``type``, ``status`` and ``threat_actor_id`` are exact-match filters;
    ``search`` is a case-insensitive substring match against the campaign
    name or description. ``total`` in the result counts all rows matching
    the filters, not just the returned page.
    """
    query = db.query(Campaign)

    # Exact-match filters are only applied when a value was supplied.
    exact_filters = (
        (Campaign.type, type),
        (Campaign.status, status),
        (Campaign.threat_actor_id, threat_actor_id),
    )
    for column, wanted in exact_filters:
        if wanted:
            query = query.filter(column == wanted)

    if search:
        like_expr = f"%{escape_like(search)}%"
        query = query.filter(
            Campaign.name.ilike(like_expr) | Campaign.description.ilike(like_expr)
        )

    total = query.count()
    page = (
        query.order_by(Campaign.created_at.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )
    items = [serialize_campaign_summary(db, row) for row in page]
    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": items,
    }
def create_campaign(
    db: Session,
    *,
    creator_id: uuid.UUID,
    name: str,
    description: Optional[str] = None,
    type: str = "custom",
    threat_actor_id: Optional[str] = None,
    target_platform: Optional[str] = None,
    tags: Optional[list[str]] = None,
    scheduled_at: Optional[str] = None,
    start_date: Optional[str] = None,
) -> dict:
    """Create a new campaign. Does not commit; caller commits.

    ``scheduled_at`` and ``start_date`` are ISO-8601 strings; values that
    carry a UTC offset (including a trailing ``Z``) are converted to naive
    UTC before being stored.

    Returns:
        The new campaign as produced by ``serialize_campaign``.

    Raises:
        ValueError: If ``threat_actor_id`` is not a valid UUID string or a
            date string is not valid ISO-8601.
    """
    campaign = Campaign(
        name=name,
        description=description,
        type=type,
        threat_actor_id=uuid.UUID(threat_actor_id) if threat_actor_id else None,
        target_platform=target_platform,
        tags=tags or [],
        created_by=creator_id,
        scheduled_at=_parse_naive_utc(scheduled_at),
        start_date=_parse_naive_utc(start_date),
    )
    db.add(campaign)
    db.flush()
    return serialize_campaign(db, campaign)


def get_campaign_detail(db: Session, campaign_id: str) -> dict:
    """Get detailed campaign info including tests and progress.

    Raises EntityNotFoundError if campaign not found.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    return serialize_campaign(db, campaign)


def update_campaign(
    db: Session,
    campaign_id: str,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields,
) -> dict:
    """Update a campaign. Only allowed in draft or active state.

    Each entry of ``fields`` is assigned onto the campaign by attribute
    name. ``scheduled_at`` / ``start_date`` strings are parsed to naive UTC
    datetimes and a string ``threat_actor_id`` is converted to ``uuid.UUID``
    (matching ``create_campaign``); falsy values are assigned as given.

    Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
    Raises ValueError for a malformed date or UUID string.
    Does not commit; caller commits.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    if campaign.status not in ("draft", "active"):
        raise BusinessRuleViolation("Can only update draft or active campaigns")
    if str(campaign.created_by) != str(updater_id) and updater_role != "admin":
        raise PermissionViolation("Only the creator or admin can update this campaign")

    for date_field in ("scheduled_at", "start_date"):
        if fields.get(date_field):
            fields[date_field] = _parse_naive_utc(fields[date_field])

    actor_id = fields.get("threat_actor_id")
    if isinstance(actor_id, str) and actor_id:
        fields["threat_actor_id"] = uuid.UUID(actor_id)

    # NOTE(review): fields are applied verbatim by name; this relies on the
    # caller passing only schema-validated attribute names — confirm.
    for field, value in fields.items():
        setattr(campaign, field, value)
    db.flush()
    return serialize_campaign(db, campaign)


def _parse_naive_utc(value) -> Optional[datetime]:
    """Turn an ISO-8601 string (or a datetime) into a naive UTC datetime.

    Other functions in this module stamp times with ``datetime.utcnow()``,
    i.e. naive UTC, so client-supplied values are normalized to that same
    representation instead of being stored with whatever offset was sent.

    Returns ``None`` for empty input. Raises ValueError for text that is not
    valid ISO-8601.
    """
    if not value:
        return None
    if isinstance(value, datetime):
        parsed = value
    else:
        text = value.strip()
        # datetime.fromisoformat() only understands a trailing "Z" from
        # Python 3.11 onwards, so spell the UTC designator out explicitly.
        if text[-1:] in ("Z", "z"):
            text = f"{text[:-1]}+00:00"
        parsed = datetime.fromisoformat(text)
    if parsed.utcoffset() is not None:
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed
def add_test_to_campaign(
    db: Session,
    campaign_id: str,
    *,
    test_id: str,
    order_index: Optional[int] = None,
    depends_on: Optional[str] = None,
    phase: Optional[str] = None,
) -> dict:
    """Attach a test to a campaign, with optional position and dependency.

    When ``order_index`` is omitted the test is placed after the campaign's
    current highest ``order_index`` (or at 0 for an empty campaign). When
    ``phase`` is omitted it is looked up in ``TACTIC_TO_PHASE`` from the
    tactic of the test's technique, if any.

    Raises EntityNotFoundError for missing campaign or test.
    Raises BusinessRuleViolation for invalid state or circular dependency.
    Does not commit; caller commits.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    if campaign.status not in ("draft", "active"):
        raise BusinessRuleViolation("Can only add tests to draft or active campaigns")

    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", test_id)

    # Default position: one past the highest existing index in this campaign.
    position = order_index
    if position is None:
        highest = (
            db.query(CampaignTest.order_index)
            .filter(CampaignTest.campaign_id == campaign_id)
            .order_by(CampaignTest.order_index.desc())
            .first()
        )
        position = highest[0] + 1 if highest else 0

    prerequisite_id = uuid.UUID(depends_on) if depends_on else None
    # The id is generated up front because the cycle check needs it before
    # the row exists.
    new_id = uuid.uuid4()
    if prerequisite_id:
        validate_no_circular_dependency(db, uuid.UUID(campaign_id), new_id, prerequisite_id)

    if not phase and test.technique_id:
        linked_technique = (
            db.query(Technique).filter(Technique.id == test.technique_id).first()
        )
        if linked_technique and linked_technique.tactic:
            phase = TACTIC_TO_PHASE.get(linked_technique.tactic)

    link = CampaignTest(
        id=new_id,
        campaign_id=campaign_id,
        test_id=test_id,
        order_index=position,
        depends_on=prerequisite_id,
        phase=phase,
    )
    db.add(link)
    db.flush()

    return {
        "id": str(link.id),
        "campaign_id": str(link.campaign_id),
        "test_id": str(link.test_id),
        "order_index": link.order_index,
        "depends_on": str(link.depends_on) if link.depends_on else None,
        "phase": link.phase,
    }
def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
    """Detach a test from a campaign and delete the underlying test record.

    Any campaign test that depended on the removed one has its dependency
    cleared, and the status of the deleted test's technique is recalculated.

    Raises EntityNotFoundError for missing campaign or campaign test.
    Raises BusinessRuleViolation for invalid state.
    Does not commit; caller commits.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    if campaign.status not in ("draft", "active"):
        raise BusinessRuleViolation("Can only modify draft or active campaigns")

    link = (
        db.query(CampaignTest)
        .filter(
            CampaignTest.id == campaign_test_id,
            CampaignTest.campaign_id == campaign_id,
        )
        .first()
    )
    if not link:
        raise EntityNotFoundError("CampaignTest", campaign_test_id)

    # Detach anything that depended on the row being removed.
    removed_id = uuid.UUID(campaign_test_id)
    for dependent in (
        db.query(CampaignTest).filter(CampaignTest.depends_on == removed_id).all()
    ):
        dependent.depends_on = None

    # Resolve the underlying test (and its technique) before the join row goes away.
    linked_test = db.query(Test).filter(Test.id == link.test_id).first()
    technique_id = linked_test.technique_id if linked_test else None

    db.delete(link)
    db.flush()

    # The test record itself is deleted too (it was created for this campaign).
    if linked_test:
        db.delete(linked_test)
        db.flush()

    # Keep the technique's status (and thus coverage metrics) consistent.
    if technique_id:
        technique = db.query(Technique).filter(Technique.id == technique_id).first()
        if technique:
            recalculate_technique_status(db, technique)
            db.flush()
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
    """Move a campaign from ``draft`` to ``active``.

    The campaign must currently be a draft and have at least one test.

    Raises EntityNotFoundError, BusinessRuleViolation.
    Does not commit; caller commits.
    """
    target = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not target:
        raise EntityNotFoundError("Campaign", campaign_id)
    if target.status != "draft":
        raise BusinessRuleViolation("Only draft campaigns can be activated")

    attached_tests = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign_id)
        .count()
    )
    if not attached_tests:
        raise BusinessRuleViolation("Campaign must have at least one test to activate")

    target.status = "active"
    db.flush()
    return target
def complete_campaign(db: Session, campaign_id: str) -> Campaign:
    """Mark an active campaign as completed and stamp ``completed_at``.

    ``completed_at`` is set to the current naive UTC time.

    Raises EntityNotFoundError, BusinessRuleViolation.
    Does not commit; caller commits.
    """
    target = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not target:
        raise EntityNotFoundError("Campaign", campaign_id)
    if target.status != "active":
        raise BusinessRuleViolation("Only active campaigns can be completed")

    finished_at = datetime.utcnow()
    target.status = "completed"
    target.completed_at = finished_at
    db.flush()
    return target
def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
    """Return a campaign's id and name merged with its progress statistics.

    The progress values come from ``get_campaign_progress`` and are merged
    last, so they take precedence on any key collision.

    Raises EntityNotFoundError if campaign not found.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)

    stats = get_campaign_progress(db, uuid.UUID(campaign_id))
    payload = {
        "campaign_id": str(campaign.id),
        "campaign_name": campaign.name,
    }
    payload.update(stats)
    return payload
def schedule_campaign(
    db: Session,
    campaign_id: str,
    *,
    owner_id: uuid.UUID,
    owner_role: str,
    is_recurring: bool,
    recurrence_pattern: Optional[str] = None,
    next_run_at: Optional[str] = None,
) -> Campaign:
    """Configure or update the recurrence schedule for a campaign.

    With ``is_recurring`` true, ``recurrence_pattern`` must be ``weekly``,
    ``monthly`` or ``quarterly``. ``next_run_at`` (ISO-8601) is stored as
    naive UTC; when omitted and the campaign has no next run yet, one is
    derived via ``calculate_next_run`` from the current UTC time. With
    ``is_recurring`` false the pattern and next run are cleared.

    Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation.
    Raises ValueError if ``next_run_at`` is not valid ISO-8601.
    Does not commit; caller commits.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
        raise PermissionViolation("Only the creator or admin can configure scheduling")

    # Validate and parse everything before touching the entity, so a rejected
    # request leaves no pending changes on the session.
    parsed_next_run = None
    if is_recurring:
        if recurrence_pattern not in ("weekly", "monthly", "quarterly"):
            raise BusinessRuleViolation(
                "recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
            )
        if next_run_at:
            text = next_run_at.strip()
            # datetime.fromisoformat() only understands a trailing "Z" from
            # Python 3.11 onwards, so spell the UTC designator out explicitly.
            if text[-1:] in ("Z", "z"):
                text = f"{text[:-1]}+00:00"
            parsed_next_run = datetime.fromisoformat(text)
            # Convert any offset to UTC rather than merely dropping it.
            if parsed_next_run.utcoffset() is not None:
                parsed_next_run = parsed_next_run.astimezone(timezone.utc).replace(
                    tzinfo=None
                )

    campaign.is_recurring = is_recurring
    if is_recurring:
        campaign.recurrence_pattern = recurrence_pattern
        if parsed_next_run is not None:
            campaign.next_run_at = parsed_next_run
        elif not campaign.next_run_at:
            campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
    else:
        campaign.recurrence_pattern = None
        campaign.next_run_at = None

    db.flush()
    return campaign
def delete_campaign(
    db: Session,
    campaign_id: str,
    *,
    deleter_id: uuid.UUID,
    deleter_role: str,
    delete_tests: bool = False,
) -> None:
    """Delete a campaign and its campaign-test join rows.

    Non-admin callers may only delete draft campaigns they created.
    If delete_tests=True, the associated Test objects are also deleted and
    the status of every affected technique is recalculated. Child campaigns
    are kept; their ``parent_campaign_id`` is set to NULL.

    Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
    Does not commit; caller commits.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)

    is_admin = deleter_role == "admin"
    if not is_admin and campaign.status != "draft":
        raise BusinessRuleViolation("Only draft campaigns can be deleted")
    if not is_admin and str(campaign.created_by) != str(deleter_id):
        raise PermissionViolation("Only the creator or admin can delete this campaign")

    # Remember which tests were attached before the join rows disappear.
    links = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).all()
    attached_test_ids = [link.test_id for link in links]

    # Clear depends_on references first (avoids FK cycles), then drop the rows.
    for link in links:
        link.depends_on = None
    db.flush()
    for link in links:
        db.delete(link)
    db.flush()

    if delete_tests:
        touched_technique_ids: set = set()
        for attached_id in attached_test_ids:
            doomed = db.query(Test).filter(Test.id == attached_id).first()
            if not doomed:
                continue
            if doomed.technique_id:
                touched_technique_ids.add(doomed.technique_id)
            db.delete(doomed)
        db.flush()

        # Recalculate each affected technique so coverage metrics stay
        # consistent after the test deletions.
        for technique_id in touched_technique_ids:
            technique = db.query(Technique).filter(Technique.id == technique_id).first()
            if technique:
                recalculate_technique_status(db, technique)
        db.flush()

    # Detach child campaigns so deleting the parent does not violate the FK.
    children = db.query(Campaign).filter(Campaign.parent_campaign_id == campaign.id)
    children.update({"parent_campaign_id": None})
    db.flush()

    db.delete(campaign)
    db.flush()
def get_campaign_history(db: Session, campaign_id: str) -> dict:
    """List the child campaigns (execution history) of a campaign, newest first.

    Each item carries the child's id, name, status, test count, completion
    percentage and its created/completed timestamps as ISO-8601 strings.

    Raises EntityNotFoundError if campaign not found.
    """
    parent = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not parent:
        raise EntityNotFoundError("Campaign", campaign_id)

    parent_uuid = uuid.UUID(campaign_id)
    runs = (
        db.query(Campaign)
        .filter(Campaign.parent_campaign_id == parent_uuid)
        .order_by(Campaign.created_at.desc())
        .all()
    )

    items = []
    for run in runs:
        test_total = (
            db.query(CampaignTest).filter(CampaignTest.campaign_id == run.id).count()
        )
        completion = get_campaign_progress(db, run.id)["completion_pct"]
        items.append({
            "id": str(run.id),
            "name": run.name,
            "status": run.status,
            "test_count": test_total,
            "completion_pct": completion,
            "created_at": run.created_at.isoformat() if run.created_at else None,
            "completed_at": run.completed_at.isoformat() if run.completed_at else None,
        })

    return {
        "campaign_id": str(parent.id),
        "campaign_name": parent.name,
        "items": items,
    }