refactor(campaigns): extract CRUD/business logic to campaign_crud_service, use domain exceptions
This commit is contained in:
@@ -7,26 +7,29 @@ test ordering, progress tracking, and threat actor integration.
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
|
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||||
from app.models.test import Test
|
from app.services.campaign_crud_service import (
|
||||||
from app.models.technique import Technique
|
add_test_to_campaign as crud_add_test,
|
||||||
from app.models.threat_actor import ThreatActor
|
activate_campaign as crud_activate,
|
||||||
from app.services.campaign_service import (
|
complete_campaign as crud_complete,
|
||||||
validate_no_circular_dependency,
|
create_campaign as crud_create,
|
||||||
get_campaign_progress,
|
get_campaign_detail as crud_get_detail,
|
||||||
generate_campaign_from_threat_actor,
|
get_campaign_history as crud_get_history,
|
||||||
TACTIC_TO_PHASE,
|
get_campaign_progress_data as crud_get_progress,
|
||||||
|
list_campaigns as crud_list,
|
||||||
|
remove_test_from_campaign as crud_remove_test,
|
||||||
|
schedule_campaign as crud_schedule,
|
||||||
|
serialize_campaign,
|
||||||
|
update_campaign as crud_update,
|
||||||
)
|
)
|
||||||
from app.services.campaign_scheduler_service import calculate_next_run
|
|
||||||
from app.services.notification_service import create_notification
|
from app.services.notification_service import create_notification
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
@@ -67,89 +70,6 @@ class SchedulePayload(BaseModel):
|
|||||||
next_run_at: Optional[str] = None
|
next_run_at: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
|
||||||
"""Serialize a campaign with its tests and progress."""
|
|
||||||
progress = get_campaign_progress(db, campaign.id)
|
|
||||||
|
|
||||||
campaign_tests = (
|
|
||||||
db.query(CampaignTest)
|
|
||||||
.filter(CampaignTest.campaign_id == campaign.id)
|
|
||||||
.order_by(CampaignTest.order_index)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
tests = []
|
|
||||||
for ct in campaign_tests:
|
|
||||||
test = ct.test
|
|
||||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
|
|
||||||
|
|
||||||
tests.append({
|
|
||||||
"id": str(ct.id),
|
|
||||||
"test_id": str(ct.test_id),
|
|
||||||
"order_index": ct.order_index,
|
|
||||||
"depends_on": str(ct.depends_on) if ct.depends_on else None,
|
|
||||||
"phase": ct.phase,
|
|
||||||
"test_name": test.name if test else None,
|
|
||||||
"test_state": test.state.value if test and test.state else None,
|
|
||||||
"test_result": test.result.value if test and test.result else None,
|
|
||||||
"technique_mitre_id": technique.mitre_id if technique else None,
|
|
||||||
"technique_name": technique.name if technique else None,
|
|
||||||
"platform": test.platform if test else None,
|
|
||||||
})
|
|
||||||
|
|
||||||
actor = campaign.threat_actor
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(campaign.id),
|
|
||||||
"name": campaign.name,
|
|
||||||
"description": campaign.description,
|
|
||||||
"type": campaign.type,
|
|
||||||
"status": campaign.status,
|
|
||||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
|
||||||
"threat_actor_name": actor.name if actor else None,
|
|
||||||
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
|
||||||
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
|
||||||
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
|
||||||
"target_platform": campaign.target_platform,
|
|
||||||
"tags": campaign.tags or [],
|
|
||||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
|
||||||
"is_recurring": campaign.is_recurring or False,
|
|
||||||
"recurrence_pattern": campaign.recurrence_pattern,
|
|
||||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
|
||||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
|
||||||
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
|
|
||||||
"tests": tests,
|
|
||||||
"progress": progress,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
|
||||||
"""Lightweight campaign serialization for list views."""
|
|
||||||
progress = get_campaign_progress(db, campaign.id)
|
|
||||||
actor = campaign.threat_actor
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(campaign.id),
|
|
||||||
"name": campaign.name,
|
|
||||||
"description": campaign.description,
|
|
||||||
"type": campaign.type,
|
|
||||||
"status": campaign.status,
|
|
||||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
|
||||||
"threat_actor_name": actor.name if actor else None,
|
|
||||||
"target_platform": campaign.target_platform,
|
|
||||||
"tags": campaign.tags or [],
|
|
||||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
|
||||||
"is_recurring": campaign.is_recurring or False,
|
|
||||||
"recurrence_pattern": campaign.recurrence_pattern,
|
|
||||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
|
||||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
|
||||||
"test_count": progress["total"],
|
|
||||||
"completion_pct": progress["completion_pct"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /campaigns — List campaigns with filters
|
# GET /campaigns — List campaigns with filters
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -166,28 +86,15 @@ def list_campaigns(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List campaigns with optional filters and pagination."""
|
"""List campaigns with optional filters and pagination."""
|
||||||
query = db.query(Campaign)
|
return crud_list(
|
||||||
|
db,
|
||||||
if type:
|
type=type,
|
||||||
query = query.filter(Campaign.type == type)
|
status=status,
|
||||||
if status:
|
threat_actor_id=threat_actor_id,
|
||||||
query = query.filter(Campaign.status == status)
|
search=search,
|
||||||
if threat_actor_id:
|
offset=offset,
|
||||||
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
limit=limit,
|
||||||
if search:
|
)
|
||||||
from app.utils import escape_like
|
|
||||||
pattern = f"%{escape_like(search)}%"
|
|
||||||
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"offset": offset,
|
|
||||||
"limit": limit,
|
|
||||||
"items": [_serialize_campaign_summary(db, c) for c in campaigns],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -201,30 +108,29 @@ def create_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Create a new campaign."""
|
"""Create a new campaign."""
|
||||||
campaign = Campaign(
|
result = crud_create(
|
||||||
|
db,
|
||||||
|
creator_id=current_user.id,
|
||||||
name=payload.name,
|
name=payload.name,
|
||||||
description=payload.description,
|
description=payload.description,
|
||||||
type=payload.type,
|
type=payload.type,
|
||||||
threat_actor_id=payload.threat_actor_id,
|
threat_actor_id=payload.threat_actor_id,
|
||||||
target_platform=payload.target_platform,
|
target_platform=payload.target_platform,
|
||||||
tags=payload.tags or [],
|
tags=payload.tags,
|
||||||
created_by=current_user.id,
|
scheduled_at=payload.scheduled_at,
|
||||||
scheduled_at=datetime.fromisoformat(payload.scheduled_at) if payload.scheduled_at else None,
|
|
||||||
)
|
)
|
||||||
db.add(campaign)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(campaign)
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
action="create_campaign",
|
action="create_campaign",
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
entity_id=campaign.id,
|
entity_id=result["id"],
|
||||||
details={"name": campaign.name, "type": campaign.type},
|
details={"name": payload.name, "type": payload.type},
|
||||||
)
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return result
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -238,11 +144,7 @@ def get_campaign(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get detailed campaign info including tests and progress."""
|
"""Get detailed campaign info including tests and progress."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
return crud_get_detail(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -257,37 +159,26 @@ def update_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Update a campaign. Only allowed in draft or active state."""
|
"""Update a campaign. Only allowed in draft or active state."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
if campaign.status not in ("draft", "active"):
|
|
||||||
raise HTTPException(status_code=400, detail="Can only update draft or active campaigns")
|
|
||||||
|
|
||||||
# Check ownership or admin
|
|
||||||
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
|
|
||||||
raise HTTPException(status_code=403, detail="Only the creator or admin can update this campaign")
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
if "scheduled_at" in update_data and update_data["scheduled_at"]:
|
result = crud_update(
|
||||||
update_data["scheduled_at"] = datetime.fromisoformat(update_data["scheduled_at"])
|
db,
|
||||||
|
campaign_id,
|
||||||
for field, value in update_data.items():
|
updater_id=current_user.id,
|
||||||
setattr(campaign, field, value)
|
updater_role=current_user.role,
|
||||||
|
**update_data,
|
||||||
db.commit()
|
)
|
||||||
db.refresh(campaign)
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
action="update_campaign",
|
action="update_campaign",
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
entity_id=campaign.id,
|
entity_id=campaign_id,
|
||||||
details={"updated_fields": list(update_data.keys())},
|
details={"updated_fields": list(update_data.keys())},
|
||||||
)
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return result
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -302,63 +193,16 @@ def add_test_to_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Add a test to a campaign with optional ordering and dependency."""
|
"""Add a test to a campaign with optional ordering and dependency."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
result = crud_add_test(
|
||||||
if not campaign:
|
db,
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
campaign_id,
|
||||||
|
|
||||||
if campaign.status not in ("draft", "active"):
|
|
||||||
raise HTTPException(status_code=400, detail="Can only add tests to draft or active campaigns")
|
|
||||||
|
|
||||||
test = db.query(Test).filter(Test.id == payload.test_id).first()
|
|
||||||
if not test:
|
|
||||||
raise HTTPException(status_code=404, detail="Test not found")
|
|
||||||
|
|
||||||
# Calculate order_index if not provided
|
|
||||||
if payload.order_index is not None:
|
|
||||||
order_index = payload.order_index
|
|
||||||
else:
|
|
||||||
max_order = (
|
|
||||||
db.query(CampaignTest.order_index)
|
|
||||||
.filter(CampaignTest.campaign_id == campaign_id)
|
|
||||||
.order_by(CampaignTest.order_index.desc())
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
order_index = (max_order[0] + 1) if max_order else 0
|
|
||||||
|
|
||||||
depends_on = uuid.UUID(payload.depends_on) if payload.depends_on else None
|
|
||||||
|
|
||||||
# Validate circular dependency
|
|
||||||
ct_id = uuid.uuid4()
|
|
||||||
if depends_on:
|
|
||||||
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on)
|
|
||||||
|
|
||||||
# Auto-detect kill chain phase from the test's technique tactic if not provided
|
|
||||||
phase = payload.phase
|
|
||||||
if not phase and test.technique_id:
|
|
||||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
|
||||||
if technique and technique.tactic:
|
|
||||||
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
|
|
||||||
|
|
||||||
campaign_test = CampaignTest(
|
|
||||||
id=ct_id,
|
|
||||||
campaign_id=campaign_id,
|
|
||||||
test_id=payload.test_id,
|
test_id=payload.test_id,
|
||||||
order_index=order_index,
|
order_index=payload.order_index,
|
||||||
depends_on=depends_on,
|
depends_on=payload.depends_on,
|
||||||
phase=phase,
|
phase=payload.phase,
|
||||||
)
|
)
|
||||||
db.add(campaign_test)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign_test)
|
return result
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(campaign_test.id),
|
|
||||||
"campaign_id": str(campaign_test.campaign_id),
|
|
||||||
"test_id": str(campaign_test.test_id),
|
|
||||||
"order_index": campaign_test.order_index,
|
|
||||||
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
|
|
||||||
"phase": campaign_test.phase,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -373,36 +217,8 @@ def remove_test_from_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Remove a test from a campaign."""
|
"""Remove a test from a campaign."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
crud_remove_test(db, campaign_id, campaign_test_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
if campaign.status not in ("draft", "active"):
|
|
||||||
raise HTTPException(status_code=400, detail="Can only modify draft or active campaigns")
|
|
||||||
|
|
||||||
ct = (
|
|
||||||
db.query(CampaignTest)
|
|
||||||
.filter(
|
|
||||||
CampaignTest.id == campaign_test_id,
|
|
||||||
CampaignTest.campaign_id == campaign_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not ct:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign test not found")
|
|
||||||
|
|
||||||
# Clear any references to this campaign_test
|
|
||||||
dependents = (
|
|
||||||
db.query(CampaignTest)
|
|
||||||
.filter(CampaignTest.depends_on == campaign_test_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
for dep in dependents:
|
|
||||||
dep.depends_on = None
|
|
||||||
|
|
||||||
db.delete(ct)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return {"detail": "Test removed from campaign"}
|
return {"detail": "Test removed from campaign"}
|
||||||
|
|
||||||
|
|
||||||
@@ -417,23 +233,10 @@ def activate_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Activate a campaign, moving it from draft to active."""
|
"""Activate a campaign, moving it from draft to active."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
campaign = crud_activate(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
if campaign.status != "draft":
|
|
||||||
raise HTTPException(status_code=400, detail="Only draft campaigns can be activated")
|
|
||||||
|
|
||||||
# Verify campaign has at least one test
|
|
||||||
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
|
|
||||||
if test_count == 0:
|
|
||||||
raise HTTPException(status_code=400, detail="Campaign must have at least one test to activate")
|
|
||||||
|
|
||||||
campaign.status = "active"
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
# Notify relevant users
|
|
||||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
||||||
for user in red_techs:
|
for user in red_techs:
|
||||||
create_notification(
|
create_notification(
|
||||||
@@ -455,7 +258,7 @@ def activate_campaign(
|
|||||||
details={"name": campaign.name},
|
details={"name": campaign.name},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -469,15 +272,7 @@ def complete_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
||||||
):
|
):
|
||||||
"""Mark a campaign as completed."""
|
"""Mark a campaign as completed."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
campaign = crud_complete(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
if campaign.status != "active":
|
|
||||||
raise HTTPException(status_code=400, detail="Only active campaigns can be completed")
|
|
||||||
|
|
||||||
campaign.status = "completed"
|
|
||||||
campaign.completed_at = datetime.utcnow()
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
@@ -490,7 +285,7 @@ def complete_campaign(
|
|||||||
details={"name": campaign.name},
|
details={"name": campaign.name},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -504,16 +299,7 @@ def get_campaign_progress_endpoint(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get progress statistics for a campaign."""
|
"""Get progress statistics for a campaign."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
return crud_get_progress(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
|
||||||
return {
|
|
||||||
"campaign_id": str(campaign.id),
|
|
||||||
"campaign_name": campaign.name,
|
|
||||||
**progress,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -546,7 +332,7 @@ def generate_campaign_from_actor(
|
|||||||
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -564,31 +350,15 @@ def schedule_campaign(
|
|||||||
|
|
||||||
Only the campaign creator or admin can change scheduling.
|
Only the campaign creator or admin can change scheduling.
|
||||||
"""
|
"""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
campaign = crud_schedule(
|
||||||
if not campaign:
|
db,
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
campaign_id,
|
||||||
|
owner_id=current_user.id,
|
||||||
# Check ownership or admin
|
owner_role=current_user.role,
|
||||||
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
|
is_recurring=payload.is_recurring,
|
||||||
raise HTTPException(status_code=403, detail="Only the creator or admin can configure scheduling")
|
recurrence_pattern=payload.recurrence_pattern,
|
||||||
|
next_run_at=payload.next_run_at,
|
||||||
campaign.is_recurring = payload.is_recurring
|
)
|
||||||
|
|
||||||
if payload.is_recurring:
|
|
||||||
if payload.recurrence_pattern not in ("weekly", "monthly", "quarterly"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'",
|
|
||||||
)
|
|
||||||
campaign.recurrence_pattern = payload.recurrence_pattern
|
|
||||||
if payload.next_run_at:
|
|
||||||
campaign.next_run_at = datetime.fromisoformat(payload.next_run_at.replace("Z", "+00:00").replace("+00:00", ""))
|
|
||||||
elif not campaign.next_run_at:
|
|
||||||
campaign.next_run_at = calculate_next_run(datetime.utcnow(), payload.recurrence_pattern)
|
|
||||||
else:
|
|
||||||
campaign.recurrence_pattern = None
|
|
||||||
campaign.next_run_at = None
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
@@ -605,7 +375,7 @@ def schedule_campaign(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -619,30 +389,4 @@ def get_campaign_history(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List all child campaigns (execution history) of a recurring campaign."""
|
"""List all child campaigns (execution history) of a recurring campaign."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
return crud_get_history(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
children = (
|
|
||||||
db.query(Campaign)
|
|
||||||
.filter(Campaign.parent_campaign_id == campaign_id)
|
|
||||||
.order_by(Campaign.created_at.desc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"campaign_id": str(campaign.id),
|
|
||||||
"campaign_name": campaign.name,
|
|
||||||
"items": [
|
|
||||||
{
|
|
||||||
"id": str(child.id),
|
|
||||||
"name": child.name,
|
|
||||||
"status": child.status,
|
|
||||||
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
|
|
||||||
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
|
|
||||||
"created_at": child.created_at.isoformat() if child.created_at else None,
|
|
||||||
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
|
|
||||||
}
|
|
||||||
for child in children
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|||||||
460
backend/app/services/campaign_crud_service.py
Normal file
460
backend/app/services/campaign_crud_service.py
Normal file
@@ -0,0 +1,460 @@
|
|||||||
|
"""Campaign CRUD service — list, create, update, and business logic.
|
||||||
|
|
||||||
|
Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||||
|
The router is responsible for HTTP concerns, auth, audit logging, and commit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import (
|
||||||
|
BusinessRuleViolation,
|
||||||
|
EntityNotFoundError,
|
||||||
|
PermissionViolation,
|
||||||
|
)
|
||||||
|
from app.models.campaign import Campaign, CampaignTest
|
||||||
|
from app.models.test import Test
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.utils import escape_like
|
||||||
|
from app.services.campaign_service import (
|
||||||
|
get_campaign_progress,
|
||||||
|
validate_no_circular_dependency,
|
||||||
|
TACTIC_TO_PHASE,
|
||||||
|
)
|
||||||
|
from app.services.campaign_scheduler_service import calculate_next_run
|
||||||
|
|
||||||
|
|
||||||
|
# ── Serialization helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||||
|
"""Serialize a campaign with its tests and progress."""
|
||||||
|
progress = get_campaign_progress(db, campaign.id)
|
||||||
|
|
||||||
|
campaign_tests = (
|
||||||
|
db.query(CampaignTest)
|
||||||
|
.filter(CampaignTest.campaign_id == campaign.id)
|
||||||
|
.order_by(CampaignTest.order_index)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
tests = []
|
||||||
|
for ct in campaign_tests:
|
||||||
|
test = ct.test
|
||||||
|
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
|
||||||
|
|
||||||
|
tests.append({
|
||||||
|
"id": str(ct.id),
|
||||||
|
"test_id": str(ct.test_id),
|
||||||
|
"order_index": ct.order_index,
|
||||||
|
"depends_on": str(ct.depends_on) if ct.depends_on else None,
|
||||||
|
"phase": ct.phase,
|
||||||
|
"test_name": test.name if test else None,
|
||||||
|
"test_state": test.state.value if test and test.state else None,
|
||||||
|
"test_result": test.result.value if test and test.result else None,
|
||||||
|
"technique_mitre_id": technique.mitre_id if technique else None,
|
||||||
|
"technique_name": technique.name if technique else None,
|
||||||
|
"platform": test.platform if test else None,
|
||||||
|
})
|
||||||
|
|
||||||
|
actor = campaign.threat_actor
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(campaign.id),
|
||||||
|
"name": campaign.name,
|
||||||
|
"description": campaign.description,
|
||||||
|
"type": campaign.type,
|
||||||
|
"status": campaign.status,
|
||||||
|
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||||
|
"threat_actor_name": actor.name if actor else None,
|
||||||
|
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
||||||
|
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
||||||
|
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
||||||
|
"target_platform": campaign.target_platform,
|
||||||
|
"tags": campaign.tags or [],
|
||||||
|
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||||
|
"is_recurring": campaign.is_recurring or False,
|
||||||
|
"recurrence_pattern": campaign.recurrence_pattern,
|
||||||
|
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||||
|
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||||
|
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
|
||||||
|
"tests": tests,
|
||||||
|
"progress": progress,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||||
|
"""Lightweight campaign serialization for list views."""
|
||||||
|
progress = get_campaign_progress(db, campaign.id)
|
||||||
|
actor = campaign.threat_actor
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(campaign.id),
|
||||||
|
"name": campaign.name,
|
||||||
|
"description": campaign.description,
|
||||||
|
"type": campaign.type,
|
||||||
|
"status": campaign.status,
|
||||||
|
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||||
|
"threat_actor_name": actor.name if actor else None,
|
||||||
|
"target_platform": campaign.target_platform,
|
||||||
|
"tags": campaign.tags or [],
|
||||||
|
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||||
|
"is_recurring": campaign.is_recurring or False,
|
||||||
|
"recurrence_pattern": campaign.recurrence_pattern,
|
||||||
|
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||||
|
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||||
|
"test_count": progress["total"],
|
||||||
|
"completion_pct": progress["completion_pct"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── CRUD operations ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def list_campaigns(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
type: Optional[str] = None,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
threat_actor_id: Optional[str] = None,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""Return a paginated list of campaigns with optional filters."""
|
||||||
|
query = db.query(Campaign)
|
||||||
|
|
||||||
|
if type:
|
||||||
|
query = query.filter(Campaign.type == type)
|
||||||
|
if status:
|
||||||
|
query = query.filter(Campaign.status == status)
|
||||||
|
if threat_actor_id:
|
||||||
|
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
||||||
|
if search:
|
||||||
|
pattern = f"%{escape_like(search)}%"
|
||||||
|
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"offset": offset,
|
||||||
|
"limit": limit,
|
||||||
|
"items": [serialize_campaign_summary(db, c) for c in campaigns],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_campaign(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
creator_id: uuid.UUID,
|
||||||
|
name: str,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
type: str = "custom",
|
||||||
|
threat_actor_id: Optional[str] = None,
|
||||||
|
target_platform: Optional[str] = None,
|
||||||
|
tags: Optional[list[str]] = None,
|
||||||
|
scheduled_at: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Create a new campaign. Does not commit; caller commits."""
|
||||||
|
campaign = Campaign(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
type=type,
|
||||||
|
threat_actor_id=uuid.UUID(threat_actor_id) if threat_actor_id else None,
|
||||||
|
target_platform=target_platform,
|
||||||
|
tags=tags or [],
|
||||||
|
created_by=creator_id,
|
||||||
|
scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None,
|
||||||
|
)
|
||||||
|
db.add(campaign)
|
||||||
|
db.flush()
|
||||||
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
|
def get_campaign_detail(db: Session, campaign_id: str) -> dict:
|
||||||
|
"""Get detailed campaign info including tests and progress.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if campaign not found.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
|
def update_campaign(
|
||||||
|
db: Session,
|
||||||
|
campaign_id: str,
|
||||||
|
*,
|
||||||
|
updater_id: uuid.UUID,
|
||||||
|
updater_role: str,
|
||||||
|
**fields,
|
||||||
|
) -> dict:
|
||||||
|
"""Update a campaign. Only allowed in draft or active state.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status not in ("draft", "active"):
|
||||||
|
raise BusinessRuleViolation("Can only update draft or active campaigns")
|
||||||
|
|
||||||
|
if str(campaign.created_by) != str(updater_id) and updater_role != "admin":
|
||||||
|
raise PermissionViolation("Only the creator or admin can update this campaign")
|
||||||
|
|
||||||
|
if "scheduled_at" in fields and fields["scheduled_at"]:
|
||||||
|
fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
setattr(campaign, field, value)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
|
def add_test_to_campaign(
|
||||||
|
db: Session,
|
||||||
|
campaign_id: str,
|
||||||
|
*,
|
||||||
|
test_id: str,
|
||||||
|
order_index: Optional[int] = None,
|
||||||
|
depends_on: Optional[str] = None,
|
||||||
|
phase: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Add a test to a campaign with optional ordering and dependency.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError for missing campaign or test.
|
||||||
|
Raises BusinessRuleViolation for invalid state or circular dependency.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status not in ("draft", "active"):
|
||||||
|
raise BusinessRuleViolation("Can only add tests to draft or active campaigns")
|
||||||
|
|
||||||
|
test = db.query(Test).filter(Test.id == test_id).first()
|
||||||
|
if not test:
|
||||||
|
raise EntityNotFoundError("Test", test_id)
|
||||||
|
|
||||||
|
if order_index is not None:
|
||||||
|
final_order_index = order_index
|
||||||
|
else:
|
||||||
|
max_order = (
|
||||||
|
db.query(CampaignTest.order_index)
|
||||||
|
.filter(CampaignTest.campaign_id == campaign_id)
|
||||||
|
.order_by(CampaignTest.order_index.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
final_order_index = (max_order[0] + 1) if max_order else 0
|
||||||
|
|
||||||
|
depends_on_uuid = uuid.UUID(depends_on) if depends_on else None
|
||||||
|
|
||||||
|
ct_id = uuid.uuid4()
|
||||||
|
if depends_on_uuid:
|
||||||
|
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on_uuid)
|
||||||
|
|
||||||
|
if not phase and test.technique_id:
|
||||||
|
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||||
|
if technique and technique.tactic:
|
||||||
|
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
|
||||||
|
|
||||||
|
campaign_test = CampaignTest(
|
||||||
|
id=ct_id,
|
||||||
|
campaign_id=campaign_id,
|
||||||
|
test_id=test_id,
|
||||||
|
order_index=final_order_index,
|
||||||
|
depends_on=depends_on_uuid,
|
||||||
|
phase=phase,
|
||||||
|
)
|
||||||
|
db.add(campaign_test)
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(campaign_test.id),
|
||||||
|
"campaign_id": str(campaign_test.campaign_id),
|
||||||
|
"test_id": str(campaign_test.test_id),
|
||||||
|
"order_index": campaign_test.order_index,
|
||||||
|
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
|
||||||
|
"phase": campaign_test.phase,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
|
||||||
|
"""Remove a test from a campaign.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError for missing campaign or campaign test.
|
||||||
|
Raises BusinessRuleViolation for invalid state.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status not in ("draft", "active"):
|
||||||
|
raise BusinessRuleViolation("Can only modify draft or active campaigns")
|
||||||
|
|
||||||
|
ct = (
|
||||||
|
db.query(CampaignTest)
|
||||||
|
.filter(
|
||||||
|
CampaignTest.id == campaign_test_id,
|
||||||
|
CampaignTest.campaign_id == campaign_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not ct:
|
||||||
|
raise EntityNotFoundError("CampaignTest", campaign_test_id)
|
||||||
|
|
||||||
|
dep_id = uuid.UUID(campaign_test_id)
|
||||||
|
dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
|
||||||
|
for dep in dependents:
|
||||||
|
dep.depends_on = None
|
||||||
|
|
||||||
|
db.delete(ct)
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||||
|
"""Activate a campaign, moving it from draft to active.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError, BusinessRuleViolation.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status != "draft":
|
||||||
|
raise BusinessRuleViolation("Only draft campaigns can be activated")
|
||||||
|
|
||||||
|
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
|
||||||
|
if test_count == 0:
|
||||||
|
raise BusinessRuleViolation("Campaign must have at least one test to activate")
|
||||||
|
|
||||||
|
campaign.status = "active"
|
||||||
|
db.flush()
|
||||||
|
return campaign
|
||||||
|
|
||||||
|
|
||||||
|
def complete_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||||
|
"""Mark a campaign as completed.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError, BusinessRuleViolation.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status != "active":
|
||||||
|
raise BusinessRuleViolation("Only active campaigns can be completed")
|
||||||
|
|
||||||
|
campaign.status = "completed"
|
||||||
|
campaign.completed_at = datetime.utcnow()
|
||||||
|
db.flush()
|
||||||
|
return campaign
|
||||||
|
|
||||||
|
|
||||||
|
def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
|
||||||
|
"""Get progress statistics for a campaign.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if campaign not found.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
||||||
|
return {
|
||||||
|
"campaign_id": str(campaign.id),
|
||||||
|
"campaign_name": campaign.name,
|
||||||
|
**progress,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def schedule_campaign(
|
||||||
|
db: Session,
|
||||||
|
campaign_id: str,
|
||||||
|
*,
|
||||||
|
owner_id: uuid.UUID,
|
||||||
|
owner_role: str,
|
||||||
|
is_recurring: bool,
|
||||||
|
recurrence_pattern: Optional[str] = None,
|
||||||
|
next_run_at: Optional[str] = None,
|
||||||
|
) -> Campaign:
|
||||||
|
"""Configure or update the recurrence schedule for a campaign.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
|
||||||
|
raise PermissionViolation("Only the creator or admin can configure scheduling")
|
||||||
|
|
||||||
|
campaign.is_recurring = is_recurring
|
||||||
|
|
||||||
|
if is_recurring:
|
||||||
|
if recurrence_pattern not in ("weekly", "monthly", "quarterly"):
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
"recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
|
||||||
|
)
|
||||||
|
campaign.recurrence_pattern = recurrence_pattern
|
||||||
|
if next_run_at:
|
||||||
|
campaign.next_run_at = datetime.fromisoformat(
|
||||||
|
next_run_at.replace("Z", "+00:00").replace("+00:00", "")
|
||||||
|
)
|
||||||
|
elif not campaign.next_run_at:
|
||||||
|
campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
|
||||||
|
else:
|
||||||
|
campaign.recurrence_pattern = None
|
||||||
|
campaign.next_run_at = None
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return campaign
|
||||||
|
|
||||||
|
|
||||||
|
def get_campaign_history(db: Session, campaign_id: str) -> dict:
|
||||||
|
"""List all child campaigns (execution history) of a recurring campaign.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if campaign not found.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
campaign_uuid = uuid.UUID(campaign_id)
|
||||||
|
children = (
|
||||||
|
db.query(Campaign)
|
||||||
|
.filter(Campaign.parent_campaign_id == campaign_uuid)
|
||||||
|
.order_by(Campaign.created_at.desc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"campaign_id": str(campaign.id),
|
||||||
|
"campaign_name": campaign.name,
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"id": str(child.id),
|
||||||
|
"name": child.name,
|
||||||
|
"status": child.status,
|
||||||
|
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
|
||||||
|
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
|
||||||
|
"created_at": child.created_at.isoformat() if child.created_at else None,
|
||||||
|
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
|
||||||
|
}
|
||||||
|
for child in children
|
||||||
|
],
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user