feat(phase-31): add campaign scheduling and recurring automation (T-233 to T-234)
This commit is contained in:
@@ -19,6 +19,7 @@ from app.services.mitre_sync_service import sync_mitre
|
||||
from app.services.intel_service import scan_intel
|
||||
from app.services.notification_service import cleanup_old_notifications
|
||||
from app.services.snapshot_service import create_snapshot, cleanup_old_snapshots
|
||||
from app.services.campaign_scheduler_service import check_and_run_recurring_campaigns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -80,6 +81,19 @@ def _run_weekly_snapshot() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_recurring_campaigns() -> None:
|
||||
"""Check and run any due recurring campaigns."""
|
||||
logger.info("Scheduled recurring campaigns check starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
spawned = check_and_run_recurring_campaigns(db)
|
||||
logger.info("Recurring campaigns check finished — spawned %d campaigns", spawned)
|
||||
except Exception:
|
||||
logger.exception("Recurring campaigns check failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_intel_scan() -> None:
|
||||
"""Execute an intel scan inside its own DB session."""
|
||||
logger.info("Scheduled intel scan job starting...")
|
||||
@@ -142,8 +156,17 @@ def start_scheduler() -> None:
|
||||
name="Weekly coverage snapshot (Sundays 00:00)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_recurring_campaigns,
|
||||
trigger="interval",
|
||||
hours=24,
|
||||
id="recurring_campaigns",
|
||||
name="Recurring campaigns check (daily)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info(
|
||||
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
||||
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00)"
|
||||
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
|
||||
"recurring_campaigns (daily)"
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, String, Text, Integer, DateTime,
|
||||
Column, String, Text, Integer, Boolean, DateTime,
|
||||
ForeignKey, Index,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
@@ -56,6 +56,17 @@ class Campaign(Base):
|
||||
tags = Column(JSONB, nullable=True, default=[])
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Recurring scheduling fields
|
||||
is_recurring = Column(Boolean, default=False)
|
||||
recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly
|
||||
next_run_at = Column(DateTime, nullable=True)
|
||||
last_run_at = Column(DateTime, nullable=True)
|
||||
parent_campaign_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("campaigns.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
threat_actor = relationship("ThreatActor")
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
@@ -65,12 +76,23 @@ class Campaign(Base):
|
||||
cascade="all, delete-orphan",
|
||||
order_by="CampaignTest.order_index",
|
||||
)
|
||||
parent_campaign = relationship(
|
||||
"Campaign",
|
||||
remote_side="Campaign.id",
|
||||
foreign_keys=[parent_campaign_id],
|
||||
)
|
||||
child_campaigns = relationship(
|
||||
"Campaign",
|
||||
foreign_keys=[parent_campaign_id],
|
||||
back_populates="parent_campaign",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_campaigns_status', 'status'),
|
||||
Index('ix_campaigns_type', 'type'),
|
||||
Index('ix_campaigns_threat_actor', 'threat_actor_id'),
|
||||
Index('ix_campaigns_created_by', 'created_by'),
|
||||
Index('ix_campaigns_next_run', 'next_run_at'),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from app.services.campaign_service import (
|
||||
get_campaign_progress,
|
||||
generate_campaign_from_threat_actor,
|
||||
)
|
||||
from app.services.campaign_scheduler_service import calculate_next_run
|
||||
from app.services.notification_service import create_notification
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
@@ -59,6 +60,12 @@ class AddTestPayload(BaseModel):
|
||||
phase: Optional[str] = None
|
||||
|
||||
|
||||
class SchedulePayload(BaseModel):
|
||||
is_recurring: bool
|
||||
recurrence_pattern: Optional[str] = None # weekly, monthly, quarterly
|
||||
next_run_at: Optional[str] = None
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||
@@ -107,6 +114,11 @@ def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"is_recurring": campaign.is_recurring or False,
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
|
||||
"tests": tests,
|
||||
"progress": progress,
|
||||
}
|
||||
@@ -128,6 +140,10 @@ def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"is_recurring": campaign.is_recurring or False,
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||
"test_count": progress["total"],
|
||||
"completion_pct": progress["completion_pct"],
|
||||
}
|
||||
@@ -522,3 +538,102 @@ def generate_campaign_from_actor(
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /campaigns/{id}/schedule — Configure recurrence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.patch("/{campaign_id}/schedule")
|
||||
def schedule_campaign(
|
||||
campaign_id: str,
|
||||
payload: SchedulePayload,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Configure or update the recurrence schedule for a campaign.
|
||||
|
||||
Only the campaign creator or admin can change scheduling.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
# Check ownership or admin
|
||||
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Only the creator or admin can configure scheduling")
|
||||
|
||||
campaign.is_recurring = payload.is_recurring
|
||||
|
||||
if payload.is_recurring:
|
||||
if payload.recurrence_pattern not in ("weekly", "monthly", "quarterly"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'",
|
||||
)
|
||||
campaign.recurrence_pattern = payload.recurrence_pattern
|
||||
if payload.next_run_at:
|
||||
campaign.next_run_at = datetime.fromisoformat(payload.next_run_at.replace("Z", "+00:00").replace("+00:00", ""))
|
||||
elif not campaign.next_run_at:
|
||||
campaign.next_run_at = calculate_next_run(datetime.utcnow(), payload.recurrence_pattern)
|
||||
else:
|
||||
campaign.recurrence_pattern = None
|
||||
campaign.next_run_at = None
|
||||
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="schedule_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={
|
||||
"is_recurring": campaign.is_recurring,
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns/{id}/history — Execution history (child campaigns)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{campaign_id}/history")
|
||||
def get_campaign_history(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all child campaigns (execution history) of a recurring campaign."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
children = (
|
||||
db.query(Campaign)
|
||||
.filter(Campaign.parent_campaign_id == campaign_id)
|
||||
.order_by(Campaign.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"campaign_id": str(campaign.id),
|
||||
"campaign_name": campaign.name,
|
||||
"items": [
|
||||
{
|
||||
"id": str(child.id),
|
||||
"name": child.name,
|
||||
"status": child.status,
|
||||
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
|
||||
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
|
||||
"created_at": child.created_at.isoformat() if child.created_at else None,
|
||||
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
|
||||
}
|
||||
for child in children
|
||||
],
|
||||
}
|
||||
|
||||
193
backend/app/services/campaign_scheduler_service.py
Normal file
193
backend/app/services/campaign_scheduler_service.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Campaign scheduler service — recurring campaign execution.
|
||||
|
||||
Handles checking which recurring campaigns are due, cloning them with
|
||||
fresh tests, and computing the next run date.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.test import Test
|
||||
from app.models.enums import TestState
|
||||
from app.services.notification_service import create_notification
|
||||
from app.services.audit_service import log_action
|
||||
from app.models.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Next-run calculation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def calculate_next_run(current_date: datetime, pattern: str) -> datetime:
|
||||
"""Compute the next run date from *current_date* and a recurrence pattern.
|
||||
|
||||
Supported patterns:
|
||||
- ``weekly`` : +7 days
|
||||
- ``monthly`` : +30 days
|
||||
- ``quarterly``: +90 days
|
||||
"""
|
||||
offsets = {
|
||||
"weekly": timedelta(days=7),
|
||||
"monthly": timedelta(days=30),
|
||||
"quarterly": timedelta(days=90),
|
||||
}
|
||||
return current_date + offsets.get(pattern, timedelta(days=30))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Clone a campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _clone_campaign(db: Session, original: Campaign) -> Campaign:
|
||||
"""Create a new child campaign from a recurring template.
|
||||
|
||||
1. Clone the campaign with a date-stamped name.
|
||||
2. For each ``CampaignTest`` in the original, create a new ``Test``
|
||||
with the same base data (in ``draft`` state) and link it.
|
||||
3. Activate the new campaign.
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
run_label = now.strftime("%Y-%m-%d")
|
||||
|
||||
child = Campaign(
|
||||
name=f"{original.name} (Run {run_label})",
|
||||
description=original.description,
|
||||
type=original.type,
|
||||
threat_actor_id=original.threat_actor_id,
|
||||
status="active",
|
||||
created_by=original.created_by,
|
||||
target_platform=original.target_platform,
|
||||
tags=original.tags or [],
|
||||
parent_campaign_id=original.id,
|
||||
)
|
||||
db.add(child)
|
||||
db.flush() # get child.id
|
||||
|
||||
# Clone each campaign_test with a fresh Test
|
||||
original_cts = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == original.id)
|
||||
.order_by(CampaignTest.order_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
for ct in original_cts:
|
||||
src_test = ct.test
|
||||
if not src_test:
|
||||
continue
|
||||
|
||||
new_test = Test(
|
||||
technique_id=src_test.technique_id,
|
||||
name=src_test.name,
|
||||
description=src_test.description,
|
||||
platform=src_test.platform,
|
||||
procedure_text=src_test.procedure_text,
|
||||
tool_used=src_test.tool_used,
|
||||
created_by=original.created_by,
|
||||
state=TestState.draft,
|
||||
)
|
||||
db.add(new_test)
|
||||
db.flush() # get new_test.id
|
||||
|
||||
new_ct = CampaignTest(
|
||||
campaign_id=child.id,
|
||||
test_id=new_test.id,
|
||||
order_index=ct.order_index,
|
||||
phase=ct.phase,
|
||||
# depends_on is not copied — would need ID remapping
|
||||
)
|
||||
db.add(new_ct)
|
||||
|
||||
db.flush()
|
||||
return child
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Check and run recurring campaigns (daily job)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def check_and_run_recurring_campaigns(db: Session) -> int:
|
||||
"""Check all recurring campaigns and clone any that are due.
|
||||
|
||||
Returns the number of campaigns spawned.
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
due_campaigns = (
|
||||
db.query(Campaign)
|
||||
.filter(
|
||||
Campaign.is_recurring == True, # noqa: E712
|
||||
Campaign.next_run_at <= now,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
spawned = 0
|
||||
|
||||
for campaign in due_campaigns:
|
||||
try:
|
||||
child = _clone_campaign(db, campaign)
|
||||
|
||||
# Update the original's scheduling fields
|
||||
campaign.last_run_at = now
|
||||
campaign.next_run_at = calculate_next_run(now, campaign.recurrence_pattern or "monthly")
|
||||
|
||||
db.commit()
|
||||
db.refresh(child)
|
||||
|
||||
# Audit
|
||||
log_action(
|
||||
db,
|
||||
user_id=campaign.created_by,
|
||||
action="recurring_campaign_run",
|
||||
entity_type="campaign",
|
||||
entity_id=child.id,
|
||||
details={
|
||||
"parent_campaign_id": str(campaign.id),
|
||||
"child_campaign_name": child.name,
|
||||
"pattern": campaign.recurrence_pattern,
|
||||
},
|
||||
)
|
||||
|
||||
# Notify
|
||||
if campaign.created_by:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=campaign.created_by,
|
||||
type="recurring_campaign_run",
|
||||
title="Recurring campaign executed",
|
||||
message=f'Campaign "{child.name}" was automatically created from recurring template "{campaign.name}".',
|
||||
entity_type="campaign",
|
||||
entity_id=child.id,
|
||||
)
|
||||
|
||||
# Notify red_tech users
|
||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
||||
for user in red_techs:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=user.id,
|
||||
type="campaign_activated",
|
||||
title="New recurring campaign active",
|
||||
message=f'Campaign "{child.name}" is now active and ready for execution.',
|
||||
entity_type="campaign",
|
||||
entity_id=child.id,
|
||||
)
|
||||
|
||||
spawned += 1
|
||||
logger.info("Spawned child campaign '%s' from parent '%s'", child.name, campaign.name)
|
||||
|
||||
except Exception:
|
||||
db.rollback()
|
||||
logger.exception("Failed to run recurring campaign '%s'", campaign.name)
|
||||
|
||||
return spawned
|
||||
Reference in New Issue
Block a user