Files
Aegis/backend/app/services/attack_path_service.py
kitos 080ce56de7
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
feat(attack-paths): Phase 10 — Attack Paths & Advanced Purple Team [FASE-10]
Models (5 tables):
  - AttackPath: named reusable attack scenario with template flag
  - AttackPathStep: ordered kill-chain step (technique + test link)
  - AttackPathExecution: a run with Red/Blue leads, timing, stored metrics
  - AttackPathStepResult: per-step detected/not_detected/skipped result
  - TimelineEntry: timestamped Red/Blue/system actions for MTTD/MTTR

Migration b036atk: raw SQL to avoid SQLAlchemy DDL hook issues

Service (attack_path_service.py):
  - Full CRUD for paths + steps (add, update, delete, reorder)
  - Execution lifecycle: create → start → execute steps → complete/abort
  - Pre-creates pending step results on execution creation
  - Auto-adds system timeline entries on key state transitions
  - complete_execution() computes: detection_rate, mttd_seconds,
    furthest_undetected_step, detected/not_detected/skipped counts
  - get_kill_chain_metrics(): per-step breakdown + phase summary

Router /api/v1/attack-paths (20 endpoints):
  POST/GET/PATCH/DELETE attack paths
  GET/POST/PATCH/DELETE steps + reorder
  POST/GET executions per path
  GET/POST/start/complete/abort executions
  POST/GET step results
  POST/GET timeline entries
  GET kill-chain metrics

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-20 13:11:01 +02:00

554 lines
20 KiB
Python

"""Phase 10: Attack Path CRUD service."""
import logging
from datetime import datetime, timezone
from typing import Optional
from uuid import UUID

from sqlalchemy.orm import Session, joinedload

from app.domain.exceptions import EntityNotFoundError
from app.models.attack_path import (
    AttackPath, AttackPathStep, AttackPathExecution,
    AttackPathStepResult, TimelineEntry,
    ExecutionStatus, StepResultStatus, TimelineActorSide, TimelineEntryType,
)
from app.services import audit_service
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.utcnow()
# ── Attack Path CRUD ──────────────────────────────────────────────────────────
def create_attack_path(db: Session, data: dict, user_id: UUID) -> AttackPath:
    """Persist a new attack path built from ``data`` and audit the creation.

    Args:
        db: Active SQLAlchemy session.
        data: Must contain ``name``. ``description``, ``objective`` and
            ``threat_actor_id`` default to ``None``, and ``is_template``
            defaults to ``False``. ``tags`` defaults to ``[]``, and any falsy
            value is also replaced by ``[]``.
        user_id: Stored as ``created_by`` and used as the audit actor.

    Returns:
        The committed, refreshed ``AttackPath``.
    """
    fields = {
        "name": data["name"],
        "description": data.get("description"),
        "objective": data.get("objective"),
        "is_template": data.get("is_template", False),
        "threat_actor_id": data.get("threat_actor_id"),
        "tags": data.get("tags") or [],
        "created_by": user_id,
    }
    new_path = AttackPath(**fields)
    db.add(new_path)
    db.commit()
    db.refresh(new_path)

    audit_details = {"name": new_path.name, "is_template": new_path.is_template}
    audit_service.log_action(
        db,
        user_id,
        "ATTACK_PATH_CREATED",
        "attack_path",
        str(new_path.id),
        details=audit_details,
    )
    return new_path
def get_attack_path(db: Session, path_id: UUID) -> AttackPath:
    """Load one attack path with its ``steps`` relationship eagerly joined.

    Raises:
        EntityNotFoundError: if no attack path has ``path_id``.
    """
    query = db.query(AttackPath).options(joinedload(AttackPath.steps))
    found = query.filter(AttackPath.id == path_id).first()
    if not found:
        raise EntityNotFoundError("AttackPath", str(path_id))
    return found
def list_attack_paths(
    db: Session,
    is_template: Optional[bool] = None,
    technique_id: Optional[UUID] = None,
    is_active: Optional[bool] = True,
) -> list[AttackPath]:
    """Return attack paths ordered newest first, narrowed by optional filters.

    Args:
        db: Active SQLAlchemy session.
        is_template: When not ``None``, keep only paths whose template flag
            matches.
        technique_id: When truthy, keep only paths that have a step using this
            technique (joined through ``AttackPathStep``).
        is_active: Defaults to ``True``, which hides archived paths. Pass
            ``None`` to include both active and archived paths.
    """
    criteria = []
    if is_active is not None:
        criteria.append(AttackPath.is_active == is_active)
    if is_template is not None:
        criteria.append(AttackPath.is_template == is_template)

    query = db.query(AttackPath)
    if criteria:
        query = query.filter(*criteria)
    if technique_id:
        query = query.join(AttackPathStep).filter(
            AttackPathStep.technique_id == technique_id
        )

    newest_first = AttackPath.created_at.desc()
    return query.order_by(newest_first).all()
def update_attack_path(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPath:
    """Apply a partial update to an attack path and bump ``updated_at``.

    A key of ``data`` is applied only when its value is not ``None`` and it
    names an existing attribute of the model. As a result, ``None`` cannot be
    used to clear a field. ``user_id`` is currently unused.

    NOTE(review): any existing attribute name is accepted, so keys such as
    ``id`` or ``created_by`` would also be written. Presumably the caller's
    schema restricts the keys; confirm.

    Raises:
        EntityNotFoundError: if no attack path has ``path_id``.
    """
    target = db.query(AttackPath).filter(AttackPath.id == path_id).first()
    if not target:
        raise EntityNotFoundError("AttackPath", str(path_id))

    non_null_changes = ((key, value) for key, value in data.items() if value is not None)
    for key, value in non_null_changes:
        if hasattr(target, key):
            setattr(target, key, value)
    target.updated_at = _now()

    db.commit()
    db.refresh(target)
    return target
def delete_attack_path(db: Session, path_id: UUID, user_id: UUID) -> None:
    """Soft-delete (archive) an attack path by clearing ``is_active``.

    The row is kept and ``updated_at`` is bumped. After the commit, an
    ``ATTACK_PATH_ARCHIVED`` audit record is written.

    Raises:
        EntityNotFoundError: if no attack path has ``path_id``.
    """
    archived = db.query(AttackPath).filter(AttackPath.id == path_id).first()
    if not archived:
        raise EntityNotFoundError("AttackPath", str(path_id))

    archived.is_active, archived.updated_at = False, _now()
    db.commit()

    audit_service.log_action(
        db, user_id, "ATTACK_PATH_ARCHIVED", "attack_path", str(path_id)
    )
# ── Steps CRUD ────────────────────────────────────────────────────────────────
def add_step(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPathStep:
    """Add a step to an attack path and bump the path's ``updated_at``.

    Args:
        db: Active SQLAlchemy session.
        path_id: Attack path that receives the step.
        data: Step fields. When ``order_index`` is absent or ``None``, the step
            is placed after the highest existing ``order_index`` of the path
            (``0`` for an empty path). ``expected_detection`` defaults to
            ``True``, and the other fields default to ``None``. The dict is
            not modified.
        user_id: Currently unused.

    Returns:
        The committed, refreshed ``AttackPathStep``.

    Raises:
        EntityNotFoundError: if no attack path has ``path_id``.
    """
    path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
    if not path:
        raise EntityNotFoundError("AttackPath", str(path_id))

    order_index = data.get("order_index")
    if order_index is None:
        # Use max(order_index) + 1 rather than the step count. After a middle
        # step is deleted (indices 0, 2), count() returns 2 and would collide.
        highest = (
            db.query(AttackPathStep.order_index)
            .filter(
                AttackPathStep.attack_path_id == path_id,
                AttackPathStep.order_index.isnot(None),
            )
            .order_by(AttackPathStep.order_index.desc())
            .first()
        )
        order_index = highest[0] + 1 if highest is not None else 0

    step = AttackPathStep(
        attack_path_id=path_id,
        order_index=order_index,
        kill_chain_phase=data.get("kill_chain_phase"),
        technique_id=data.get("technique_id"),
        test_id=data.get("test_id"),
        name=data.get("name"),
        description=data.get("description"),
        expected_detection=data.get("expected_detection", True),
        notes=data.get("notes"),
    )
    db.add(step)
    path.updated_at = _now()
    db.commit()
    db.refresh(step)
    return step
def update_step(db: Session, step_id: UUID, data: dict, user_id: UUID) -> AttackPathStep:
    """Apply a partial update to a single attack-path step.

    Keys of ``data`` are ignored when their value is ``None`` or when they do
    not name an attribute of the step model. ``user_id`` is currently unused.

    Raises:
        EntityNotFoundError: if no step has ``step_id``.
    """
    target_step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
    if not target_step:
        raise EntityNotFoundError("AttackPathStep", str(step_id))

    for field_name, new_value in data.items():
        if new_value is None:
            continue
        if not hasattr(target_step, field_name):
            continue
        setattr(target_step, field_name, new_value)

    db.commit()
    db.refresh(target_step)
    return target_step
def delete_step(db: Session, step_id: UUID, user_id: UUID) -> None:
    """Hard-delete a single attack-path step.

    The remaining steps are not renumbered, so a gap may be left in
    ``order_index``. ``user_id`` is currently unused.

    Raises:
        EntityNotFoundError: if no step has ``step_id``.
    """
    doomed = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
    if not doomed:
        raise EntityNotFoundError("AttackPathStep", str(step_id))
    db.delete(doomed)
    db.commit()
def reorder_steps(db: Session, path_id: UUID, step_ids: list[UUID], user_id: UUID) -> list[AttackPathStep]:
    """Reorder steps by providing an ordered list of step IDs.

    Listed steps receive ``order_index`` 0, 1, 2, ... in the given order. A
    repeated ID keeps its first position. An ID that does not belong to
    ``path_id`` updates nothing but still consumes its slot, which can leave a
    gap. Steps of the path that are not listed are moved after the listed
    ones and keep their previous relative order, so no two steps share an
    ``order_index``.

    Args:
        db: Active SQLAlchemy session.
        path_id: Attack path whose steps are reordered.
        step_ids: Desired order of step IDs.
        user_id: Currently unused.

    Returns:
        All steps of the path ordered by their new ``order_index``.

    Raises:
        EntityNotFoundError: if no attack path has ``path_id``.
    """
    path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
    if not path:
        raise EntityNotFoundError("AttackPath", str(path_id))

    # De-duplicate while preserving order, so one step cannot take two slots.
    ordered_ids = list(dict.fromkeys(step_ids))
    for idx, step_id in enumerate(ordered_ids):
        db.query(AttackPathStep).filter(
            AttackPathStep.id == step_id,
            AttackPathStep.attack_path_id == path_id,
        ).update({"order_index": idx})

    # Steps missing from the request would otherwise keep an old index that
    # collides with one just assigned. Push them after the listed steps.
    unlisted_q = db.query(AttackPathStep).filter(AttackPathStep.attack_path_id == path_id)
    if ordered_ids:
        unlisted_q = unlisted_q.filter(~AttackPathStep.id.in_(ordered_ids))
    unlisted = unlisted_q.order_by(AttackPathStep.order_index).all()
    for offset, step in enumerate(unlisted, start=len(ordered_ids)):
        step.order_index = offset

    path.updated_at = _now()
    db.commit()
    return (
        db.query(AttackPathStep)
        .filter(AttackPathStep.attack_path_id == path_id)
        .order_by(AttackPathStep.order_index)
        .all()
    )
# ── Executions ────────────────────────────────────────────────────────────────
def create_execution(
    db: Session, path_id: UUID, data: dict, user_id: UUID
) -> AttackPathExecution:
    """Create a ``planned`` execution of an attack path.

    A ``pending`` step result is pre-created for every step currently on the
    path, ordered by ``order_index``. A system timeline entry and an
    ``ATTACK_PATH_EXECUTION_STARTED`` audit record are then written.

    Args:
        db: Active SQLAlchemy session.
        path_id: Attack path to execute.
        data: Optional ``environment``, ``red_team_lead``, ``blue_team_lead``
            and ``notes`` values. Missing keys become ``None``.
        user_id: Stored as ``started_by`` and used as the audit actor.

    Returns:
        The committed, refreshed ``AttackPathExecution``.

    Raises:
        EntityNotFoundError: if no attack path has ``path_id``.
    """
    path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
    if not path:
        raise EntityNotFoundError("AttackPath", str(path_id))

    passthrough = ("environment", "red_team_lead", "blue_team_lead", "notes")
    execution = AttackPathExecution(
        attack_path_id=path_id,
        status=ExecutionStatus.planned,
        started_by=user_id,
        **{field: data.get(field) for field in passthrough},
    )
    db.add(execution)
    # Flush so execution.id is populated before the step results reference it.
    db.flush()

    path_steps = (
        db.query(AttackPathStep)
        .filter(AttackPathStep.attack_path_id == path_id)
        .order_by(AttackPathStep.order_index)
        .all()
    )
    db.add_all(
        [
            AttackPathStepResult(
                execution_id=execution.id,
                step_id=s.id,
                step_order=s.order_index,
                status=StepResultStatus.pending,
            )
            for s in path_steps
        ]
    )
    db.commit()
    db.refresh(execution)

    step_count = len(path_steps)
    _add_system_entry(
        db,
        execution.id,
        entry_type=TimelineEntryType.phase_transition,
        content=f"Execution created for '{path.name}' with {step_count} steps.",
    )
    audit_service.log_action(
        db,
        user_id,
        "ATTACK_PATH_EXECUTION_STARTED",
        "attack_path_execution",
        str(execution.id),
        details={"path_id": str(path_id), "path_name": path.name, "steps": step_count},
    )
    return execution
def get_execution(db: Session, execution_id: UUID) -> AttackPathExecution:
    """Load one execution with its step results and timeline eagerly joined.

    Raises:
        EntityNotFoundError: if no execution has ``execution_id``.
    """
    eager_loads = (
        joinedload(AttackPathExecution.step_results),
        joinedload(AttackPathExecution.timeline),
    )
    execution = (
        db.query(AttackPathExecution)
        .options(*eager_loads)
        .filter(AttackPathExecution.id == execution_id)
        .first()
    )
    if not execution:
        raise EntityNotFoundError("AttackPathExecution", str(execution_id))
    return execution
def list_executions(db: Session, path_id: UUID) -> list[AttackPathExecution]:
    """Return every execution of an attack path, newest first.

    Raises:
        EntityNotFoundError: if no attack path has ``path_id``.
    """
    owner = db.query(AttackPath).filter(AttackPath.id == path_id).first()
    if not owner:
        raise EntityNotFoundError("AttackPath", str(path_id))

    newest_first = AttackPathExecution.created_at.desc()
    executions = db.query(AttackPathExecution).filter(
        AttackPathExecution.attack_path_id == path_id
    )
    return executions.order_by(newest_first).all()
def start_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
    """Move a ``planned`` execution to ``in_progress`` and stamp ``started_at``.

    After the commit, a system ``phase_transition`` timeline entry is added
    and attributed to ``user_id``.

    Raises:
        EntityNotFoundError: if no execution has ``execution_id``.
        HTTPException: 400 if the execution is not in the ``planned`` state.
    """
    execution = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
    if not execution:
        raise EntityNotFoundError("AttackPathExecution", str(execution_id))

    if ex_not_planned := execution.status not in (ExecutionStatus.planned,):
        from fastapi import HTTPException
        raise HTTPException(400, "Execution is not in 'planned' state")

    execution.status = ExecutionStatus.in_progress
    execution.started_at = _now()
    db.commit()
    db.refresh(execution)

    _add_system_entry(
        db,
        execution_id,
        entry_type=TimelineEntryType.phase_transition,
        content="Execution started.",
        actor_id=user_id,
        actor_side=TimelineActorSide.system,
    )
    return execution
# ── Step Execution ────────────────────────────────────────────────────────────
def execute_step(
    db: Session,
    execution_id: UUID,
    step_id: UUID,
    data: dict,
    user_id: UUID,
) -> AttackPathStepResult:
    """Record (or re-record) the result of executing one step.

    A ``planned`` execution is auto-started. If no result row exists yet for
    the step (for example, the step was added after the execution was
    created), one is created. This happens only for steps that belong to the
    execution's own attack path. When a step is recorded as anything other
    than ``detected``, its ``detected_at`` and ``time_to_detect_seconds`` are
    cleared so a previous detection cannot linger.

    Args:
        db: Active SQLAlchemy session.
        execution_id: Execution being recorded.
        step_id: Step whose result is recorded.
        data: Must contain ``status`` (a ``StepResultStatus`` value). May
            contain ``executed_at`` and ``detected_at`` (both default to now),
            plus ``notes``, ``evidence_ids`` and ``detection_asset_id``.
        user_id: Stored as ``executed_by`` and as the timeline actor.

    Returns:
        The committed, refreshed ``AttackPathStepResult``.

    Raises:
        EntityNotFoundError: if the execution is unknown, or if there is no
            result row and the step is unknown or belongs to another path.
        HTTPException: 400 if the execution is neither ``planned`` nor
            ``in_progress``.
        ValueError: if ``data["status"]`` is not a valid ``StepResultStatus``.
    """
    ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
    if not ex:
        raise EntityNotFoundError("AttackPathExecution", str(execution_id))
    if ex.status not in (ExecutionStatus.in_progress, ExecutionStatus.planned):
        from fastapi import HTTPException
        raise HTTPException(400, "Execution must be in_progress to record step results")

    # Validate the status and resolve the step/result rows before mutating
    # anything, so a bad request leaves no half-applied changes in the session.
    new_status = StepResultStatus(data["status"])
    step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
    result = (
        db.query(AttackPathStepResult)
        .filter(
            AttackPathStepResult.execution_id == execution_id,
            AttackPathStepResult.step_id == step_id,
        )
        .first()
    )
    if not result:
        # Create on the fly if the step was added after the execution was
        # created, but never for a step of a different attack path.
        if not step or step.attack_path_id != ex.attack_path_id:
            raise EntityNotFoundError("AttackPathStep", str(step_id))
        result = AttackPathStepResult(
            execution_id=execution_id,
            step_id=step_id,
            step_order=step.order_index,
        )
        db.add(result)

    # Auto-start if still planned
    if ex.status == ExecutionStatus.planned:
        ex.status = ExecutionStatus.in_progress
        ex.started_at = _now()

    now = _now()
    result.status = new_status
    result.executed_by = user_id
    result.executed_at = data.get("executed_at") or now
    result.notes = data.get("notes")
    result.evidence_ids = [str(e) for e in (data.get("evidence_ids") or [])]
    result.detection_asset_id = data.get("detection_asset_id")

    is_detected = new_status == StepResultStatus.detected
    if is_detected:
        result.detected_at = data.get("detected_at") or now
        delta = (result.detected_at - result.executed_at).total_seconds()
        result.time_to_detect_seconds = max(0.0, delta)
    else:
        # Re-marking a previously detected step must not keep stale detection
        # data, because completion metrics read time_to_detect_seconds.
        result.detected_at = None
        result.time_to_detect_seconds = None
    db.commit()
    db.refresh(result)

    # Timeline entry: blue side records detections, red side everything else.
    step_name = (step.name or step.kill_chain_phase) if step else None
    step_name = step_name or "Unknown step"
    actor_side = TimelineActorSide.blue if is_detected else TimelineActorSide.red
    entry_type = TimelineEntryType.detection if is_detected else TimelineEntryType.action
    content = f"Step '{step_name}' marked as {new_status.value}."
    if is_detected and result.time_to_detect_seconds is not None:
        content += f" Detected in {result.time_to_detect_seconds:.0f}s."
    _add_system_entry(
        db, execution_id, entry_type, content,
        actor_id=user_id, actor_side=actor_side, step_id=step_id,
    )
    return result
# ── Completion & Metrics ──────────────────────────────────────────────────────
def complete_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
    """Mark an execution ``completed`` and store its kill-chain metrics.

    Metrics are computed from the execution's step results:
      - counts of detected, not_detected and skipped steps, and the total;
      - ``detection_rate`` = detected / total (0.0 when there are no results);
      - ``mttd_seconds`` = mean ``time_to_detect_seconds`` over detected steps
        only, or ``None`` if there is none;
      - ``furthest_undetected_step`` = highest ``step_order`` marked
        not_detected, or ``None``.

    A system timeline entry and an ``ATTACK_PATH_EXECUTION_COMPLETED`` audit
    record are written after the commit.

    Raises:
        EntityNotFoundError: if no execution has ``execution_id``.
        HTTPException: 400 if the execution is already completed or aborted.
    """
    ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
    if not ex:
        raise EntityNotFoundError("AttackPathExecution", str(execution_id))
    # A terminal execution must not be re-completed: that would overwrite an
    # aborted state or the original completion time.
    if ex.status in (ExecutionStatus.completed, ExecutionStatus.aborted):
        from fastapi import HTTPException
        raise HTTPException(400, "Execution is already completed or aborted")

    results = (
        db.query(AttackPathStepResult)
        .filter(AttackPathStepResult.execution_id == execution_id)
        .order_by(AttackPathStepResult.step_order)
        .all()
    )
    total = len(results)
    detected_results = [r for r in results if r.status == StepResultStatus.detected]
    detected = len(detected_results)
    not_detected = sum(1 for r in results if r.status == StepResultStatus.not_detected)
    skipped = sum(1 for r in results if r.status == StepResultStatus.skipped)
    detection_rate = (detected / total) if total > 0 else 0.0

    # Only detected steps contribute to MTTD. A leftover value on a step that
    # was later re-marked as not detected must not skew the mean.
    ttds = [
        r.time_to_detect_seconds
        for r in detected_results
        if r.time_to_detect_seconds is not None
    ]
    mttd = (sum(ttds) / len(ttds)) if ttds else None

    # Furthest undetected step (highest step_order with not_detected status)
    furthest = max(
        (r.step_order for r in results if r.status == StepResultStatus.not_detected),
        default=None,
    )

    ex.status = ExecutionStatus.completed
    ex.completed_at = _now()
    ex.total_steps = total
    ex.detected_steps = detected
    ex.not_detected_steps = not_detected
    ex.skipped_steps = skipped
    ex.detection_rate = round(detection_rate, 4)
    ex.mttd_seconds = round(mttd, 1) if mttd is not None else None
    ex.furthest_undetected_step = furthest
    db.commit()
    db.refresh(ex)

    summary = (
        f"Execution completed. Detection rate: {detection_rate:.0%}. "
        f"Detected {detected}/{total} steps. "
    )
    if mttd is not None:
        summary += f"MTTD: {mttd:.0f}s."
    _add_system_entry(
        db, execution_id, TimelineEntryType.phase_transition, summary,
        actor_id=user_id, actor_side=TimelineActorSide.system,
    )
    audit_service.log_action(
        db, user_id, "ATTACK_PATH_EXECUTION_COMPLETED", "attack_path_execution",
        str(execution_id),
        details={"detection_rate": detection_rate, "mttd_seconds": mttd,
                 "detected": detected, "total": total},
    )
    return ex
def abort_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
    """Mark a planned or in-progress execution as ``aborted``.

    ``completed_at`` is stamped, and a system ``flag`` timeline entry
    attributed to ``user_id`` is added after the commit.

    Raises:
        EntityNotFoundError: if no execution has ``execution_id``.
        HTTPException: 400 if the execution is already completed or aborted.
            Aborting a completed run would otherwise discard its final state
            and overwrite ``completed_at``.
    """
    ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
    if not ex:
        raise EntityNotFoundError("AttackPathExecution", str(execution_id))
    if ex.status in (ExecutionStatus.completed, ExecutionStatus.aborted):
        from fastapi import HTTPException
        raise HTTPException(400, "Execution is already completed or aborted")

    ex.status = ExecutionStatus.aborted
    ex.completed_at = _now()
    db.commit()
    db.refresh(ex)
    _add_system_entry(db, execution_id, TimelineEntryType.flag, "Execution aborted.",
                      actor_id=user_id, actor_side=TimelineActorSide.system)
    return ex
# ── Timeline ──────────────────────────────────────────────────────────────────
def add_timeline_entry(
    db: Session, execution_id: UUID, data: dict, user_id: UUID
) -> TimelineEntry:
    """Add a user-authored timeline entry to an execution.

    Args:
        db: Active SQLAlchemy session.
        execution_id: Execution the entry belongs to.
        data: Must provide ``actor_side`` and ``entry_type``, which are
            converted to their enums (an invalid value raises
            ``ValueError``), and ``content``. ``step_id``, ``timestamp``
            (defaults to now) and ``extra`` are optional.
        user_id: Always recorded as the entry's ``actor_id``.

    Returns:
        The committed, refreshed ``TimelineEntry``.

    Raises:
        EntityNotFoundError: if no execution has ``execution_id``.
    """
    execution = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
    if not execution:
        raise EntityNotFoundError("AttackPathExecution", str(execution_id))

    # Evaluated in the same order as the constructor arguments, so the same
    # invalid field raises first.
    related_step = data.get("step_id")
    stamp = data.get("timestamp") or _now()
    side = TimelineActorSide(data["actor_side"])
    kind = TimelineEntryType(data["entry_type"])
    text = data["content"]

    new_entry = TimelineEntry(
        execution_id=execution_id,
        step_id=related_step,
        timestamp=stamp,
        actor_side=side,
        actor_id=user_id,
        entry_type=kind,
        content=text,
        extra=data.get("extra"),
    )
    db.add(new_entry)
    db.commit()
    db.refresh(new_entry)
    return new_entry
def get_timeline(db: Session, execution_id: UUID) -> list[TimelineEntry]:
    """Return an execution's timeline entries ordered by ascending timestamp.

    Raises:
        EntityNotFoundError: if no execution has ``execution_id``.
    """
    owner = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
    if not owner:
        raise EntityNotFoundError("AttackPathExecution", str(execution_id))

    chronological = TimelineEntry.timestamp.asc()
    entries = db.query(TimelineEntry).filter(TimelineEntry.execution_id == execution_id)
    return entries.order_by(chronological).all()
# ── Kill-Chain Metrics ────────────────────────────────────────────────────────
def get_kill_chain_metrics(db: Session, execution_id: UUID) -> dict:
    """Build a per-step and per-phase detection report for one execution.

    The aggregate fields (``detected_steps``, ``detection_rate``,
    ``mttd_seconds``, ...) are read from the values that ``complete_execution``
    stores on the execution. Until then they fall back to 0 / 0.0 / ``None``,
    and ``total_steps`` falls back to the number of results.
    ``step_breakdown`` and ``phase_summary`` are always derived from the
    current step results.

    Raises:
        EntityNotFoundError: if no execution has ``execution_id``.
    """
    ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
    if not ex:
        raise EntityNotFoundError("AttackPathExecution", str(execution_id))
    results = (
        db.query(AttackPathStepResult)
        .filter(AttackPathStepResult.execution_id == execution_id)
        .order_by(AttackPathStepResult.step_order)
        .all()
    )

    # Load every referenced step with one IN query instead of one query per
    # result (and another for the furthest-undetected lookup).
    step_ids = list({r.step_id for r in results})
    steps_by_id: dict = {}
    if step_ids:
        steps_by_id = {
            s.id: s
            for s in db.query(AttackPathStep).filter(AttackPathStep.id.in_(step_ids)).all()
        }

    step_breakdown = []
    phase_detected: dict[str, list] = {}
    for r in results:
        step = steps_by_id.get(r.step_id)
        phase = step.kill_chain_phase if step else None
        step_breakdown.append({
            "step_id": str(r.step_id),
            "step_order": r.step_order,
            "step_name": step.name if step else None,
            "kill_chain_phase": phase,
            # status may surface as an enum or as a plain value
            "status": r.status.value if hasattr(r.status, "value") else r.status,
            "executed_at": r.executed_at.isoformat() if r.executed_at else None,
            "detected_at": r.detected_at.isoformat() if r.detected_at else None,
            "time_to_detect_seconds": r.time_to_detect_seconds,
            "detection_asset_id": str(r.detection_asset_id) if r.detection_asset_id else None,
        })
        if phase:
            phase_detected.setdefault(phase, []).append(
                r.status == StepResultStatus.detected
            )

    phase_summary = {
        phase: {
            "total": len(flags),
            "detected": sum(flags),
            "detection_rate": round(sum(flags) / len(flags), 3) if flags else 0.0,
        }
        for phase, flags in phase_detected.items()
    }

    # Phase of the furthest undetected step: scan from the end for a result at
    # that step_order whose step still exists.
    furthest_undetected_phase = None
    if ex.furthest_undetected_step is not None:
        for r in reversed(results):
            if r.step_order != ex.furthest_undetected_step:
                continue
            step = steps_by_id.get(r.step_id)
            if step:
                furthest_undetected_phase = step.kill_chain_phase
                break

    return {
        "execution_id": str(execution_id),
        "total_steps": ex.total_steps or len(results),
        "detected_steps": ex.detected_steps or 0,
        "not_detected_steps": ex.not_detected_steps or 0,
        "skipped_steps": ex.skipped_steps or 0,
        "detection_rate": ex.detection_rate or 0.0,
        "mttd_seconds": ex.mttd_seconds,
        "furthest_undetected_step": ex.furthest_undetected_step,
        "furthest_undetected_phase": furthest_undetected_phase,
        "step_breakdown": step_breakdown,
        "phase_summary": phase_summary,
    }
# ── Helper ────────────────────────────────────────────────────────────────────
def _add_system_entry(
    db: Session,
    execution_id: UUID,
    entry_type: TimelineEntryType,
    content: str,
    actor_id: Optional[UUID] = None,
    actor_side: TimelineActorSide = TimelineActorSide.system,
    step_id: Optional[UUID] = None,
) -> None:
    """Append an automatically generated timeline entry and commit it.

    The entry is timestamped with ``_now()``. ``actor_side`` defaults to
    ``system``, but callers may attribute the entry to another side. No
    ``extra`` payload is set.
    """
    fields = dict(
        execution_id=execution_id,
        step_id=step_id,
        timestamp=_now(),
        actor_side=actor_side,
        actor_id=actor_id,
        entry_type=entry_type,
        content=content,
    )
    db.add(TimelineEntry(**fields))
    db.commit()