"""Phase 10: Attack Path CRUD service.""" import logging from datetime import datetime from typing import Optional from uuid import UUID from sqlalchemy.orm import Session, joinedload from app.models.attack_path import ( AttackPath, AttackPathStep, AttackPathExecution, AttackPathStepResult, TimelineEntry, ExecutionStatus, StepResultStatus, TimelineActorSide, TimelineEntryType, ) from app.domain.exceptions import EntityNotFoundError from app.services import audit_service logger = logging.getLogger(__name__) def _now() -> datetime: return datetime.utcnow() # ── Attack Path CRUD ────────────────────────────────────────────────────────── def create_attack_path(db: Session, data: dict, user_id: UUID) -> AttackPath: path = AttackPath( name=data["name"], description=data.get("description"), objective=data.get("objective"), is_template=data.get("is_template", False), threat_actor_id=data.get("threat_actor_id"), tags=data.get("tags") or [], created_by=user_id, ) db.add(path) db.commit() db.refresh(path) audit_service.log_action( db, user_id, "ATTACK_PATH_CREATED", "attack_path", str(path.id), details={"name": path.name, "is_template": path.is_template}, ) return path def get_attack_path(db: Session, path_id: UUID) -> AttackPath: path = ( db.query(AttackPath) .options(joinedload(AttackPath.steps)) .filter(AttackPath.id == path_id) .first() ) if not path: raise EntityNotFoundError("AttackPath", str(path_id)) return path def list_attack_paths( db: Session, is_template: Optional[bool] = None, technique_id: Optional[UUID] = None, is_active: Optional[bool] = True, ) -> list[AttackPath]: q = db.query(AttackPath) if is_active is not None: q = q.filter(AttackPath.is_active == is_active) if is_template is not None: q = q.filter(AttackPath.is_template == is_template) if technique_id: q = q.join(AttackPathStep).filter(AttackPathStep.technique_id == technique_id) return q.order_by(AttackPath.created_at.desc()).all() def update_attack_path(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPath: path = db.query(AttackPath).filter(AttackPath.id == path_id).first() if not path: raise EntityNotFoundError("AttackPath", str(path_id)) for k, v in data.items(): if v is not None and hasattr(path, k): setattr(path, k, v) path.updated_at = _now() db.commit() db.refresh(path) return path def delete_attack_path(db: Session, path_id: UUID, user_id: UUID) -> None: path = db.query(AttackPath).filter(AttackPath.id == path_id).first() if not path: raise EntityNotFoundError("AttackPath", str(path_id)) path.is_active = False path.updated_at = _now() db.commit() audit_service.log_action(db, user_id, "ATTACK_PATH_ARCHIVED", "attack_path", str(path_id)) # ── Steps CRUD ──────────────────────────────────────────────────────────────── def add_step(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPathStep: path = db.query(AttackPath).filter(AttackPath.id == path_id).first() if not path: raise EntityNotFoundError("AttackPath", str(path_id)) # Auto-assign order_index if not provided if data.get("order_index") is None: max_idx = db.query(AttackPathStep).filter( AttackPathStep.attack_path_id == path_id ).count() data["order_index"] = max_idx step = AttackPathStep( attack_path_id=path_id, order_index=data.get("order_index", 0), kill_chain_phase=data.get("kill_chain_phase"), technique_id=data.get("technique_id"), test_id=data.get("test_id"), name=data.get("name"), description=data.get("description"), expected_detection=data.get("expected_detection", True), notes=data.get("notes"), ) db.add(step) path.updated_at = _now() db.commit() db.refresh(step) return step def update_step(db: Session, step_id: UUID, data: dict, user_id: UUID) -> AttackPathStep: step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first() if not step: raise EntityNotFoundError("AttackPathStep", str(step_id)) for k, v in data.items(): if v is not None and hasattr(step, k): setattr(step, k, v) db.commit() db.refresh(step) return step def delete_step(db: Session, step_id: UUID, user_id: UUID) -> None: step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first() if not step: raise EntityNotFoundError("AttackPathStep", str(step_id)) db.delete(step) db.commit() def reorder_steps(db: Session, path_id: UUID, step_ids: list[UUID], user_id: UUID) -> list[AttackPathStep]: """Reorder steps by providing ordered list of step IDs.""" path = db.query(AttackPath).filter(AttackPath.id == path_id).first() if not path: raise EntityNotFoundError("AttackPath", str(path_id)) for idx, step_id in enumerate(step_ids): db.query(AttackPathStep).filter( AttackPathStep.id == step_id, AttackPathStep.attack_path_id == path_id, ).update({"order_index": idx}) path.updated_at = _now() db.commit() return ( db.query(AttackPathStep) .filter(AttackPathStep.attack_path_id == path_id) .order_by(AttackPathStep.order_index) .all() ) # ── Executions ──────────────────────────────────────────────────────────────── def create_execution( db: Session, path_id: UUID, data: dict, user_id: UUID ) -> AttackPathExecution: path = db.query(AttackPath).filter(AttackPath.id == path_id).first() if not path: raise EntityNotFoundError("AttackPath", str(path_id)) execution = AttackPathExecution( attack_path_id=path_id, status=ExecutionStatus.planned, environment=data.get("environment"), red_team_lead=data.get("red_team_lead"), blue_team_lead=data.get("blue_team_lead"), notes=data.get("notes"), started_by=user_id, ) db.add(execution) db.flush() # Pre-create pending step results for every step in the path steps = ( db.query(AttackPathStep) .filter(AttackPathStep.attack_path_id == path_id) .order_by(AttackPathStep.order_index) .all() ) for step in steps: result = AttackPathStepResult( execution_id=execution.id, step_id=step.id, step_order=step.order_index, status=StepResultStatus.pending, ) db.add(result) db.commit() db.refresh(execution) # Auto-add system timeline entry _add_system_entry( db, execution.id, entry_type=TimelineEntryType.phase_transition, content=f"Execution created for '{path.name}' with {len(steps)} steps.", ) audit_service.log_action( db, user_id, "ATTACK_PATH_EXECUTION_STARTED", "attack_path_execution", str(execution.id), details={"path_id": str(path_id), "path_name": path.name, "steps": len(steps)}, ) return execution def get_execution(db: Session, execution_id: UUID) -> AttackPathExecution: ex = ( db.query(AttackPathExecution) .options( joinedload(AttackPathExecution.step_results), joinedload(AttackPathExecution.timeline), ) .filter(AttackPathExecution.id == execution_id) .first() ) if not ex: raise EntityNotFoundError("AttackPathExecution", str(execution_id)) return ex def list_executions(db: Session, path_id: UUID) -> list[AttackPathExecution]: path = db.query(AttackPath).filter(AttackPath.id == path_id).first() if not path: raise EntityNotFoundError("AttackPath", str(path_id)) return ( db.query(AttackPathExecution) .filter(AttackPathExecution.attack_path_id == path_id) .order_by(AttackPathExecution.created_at.desc()) .all() ) def start_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution: ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first() if not ex: raise EntityNotFoundError("AttackPathExecution", str(execution_id)) if ex.status not in (ExecutionStatus.planned,): from fastapi import HTTPException raise HTTPException(400, "Execution is not in 'planned' state") ex.status = ExecutionStatus.in_progress ex.started_at = _now() db.commit() db.refresh(ex) _add_system_entry(db, execution_id, TimelineEntryType.phase_transition, "Execution started.", actor_id=user_id, actor_side=TimelineActorSide.system) return ex # ── Step Execution ──────────────────────────────────────────────────────────── def execute_step( db: Session, execution_id: UUID, step_id: UUID, data: dict, user_id: UUID, ) -> AttackPathStepResult: """Record the result of executing one step.""" ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first() if not ex: raise EntityNotFoundError("AttackPathExecution", str(execution_id)) if ex.status not in (ExecutionStatus.in_progress, ExecutionStatus.planned): from fastapi import HTTPException raise HTTPException(400, "Execution must be in_progress to record step results") # Auto-start if still planned if ex.status == ExecutionStatus.planned: ex.status = ExecutionStatus.in_progress ex.started_at = _now() result = ( db.query(AttackPathStepResult) .filter( AttackPathStepResult.execution_id == execution_id, AttackPathStepResult.step_id == step_id, ) .first() ) if not result: # Create on-the-fly if step was added after execution started step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first() if not step: raise EntityNotFoundError("AttackPathStep", str(step_id)) result = AttackPathStepResult( execution_id=execution_id, step_id=step_id, step_order=step.order_index, ) db.add(result) now = _now() new_status = StepResultStatus(data["status"]) result.status = new_status result.executed_by = user_id result.executed_at = data.get("executed_at") or now result.notes = data.get("notes") result.evidence_ids = [str(e) for e in (data.get("evidence_ids") or [])] result.detection_asset_id = data.get("detection_asset_id") if new_status == StepResultStatus.detected: result.detected_at = data.get("detected_at") or now if result.executed_at: delta = (result.detected_at - result.executed_at).total_seconds() result.time_to_detect_seconds = max(0.0, delta) db.commit() db.refresh(result) # Add timeline entry step_obj = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first() step_name = step_obj.name or (step_obj.kill_chain_phase or "Unknown step") actor_side = TimelineActorSide.red if new_status != StepResultStatus.detected else TimelineActorSide.blue entry_type = ( TimelineEntryType.detection if new_status == StepResultStatus.detected else TimelineEntryType.action ) content = ( f"Step '{step_name}' marked as {new_status.value}." + (f" Detected in {result.time_to_detect_seconds:.0f}s." if result.time_to_detect_seconds else "") ) _add_system_entry( db, execution_id, entry_type, content, actor_id=user_id, actor_side=actor_side, step_id=step_id, ) return result # ── Completion & Metrics ────────────────────────────────────────────────────── def complete_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution: """Mark execution complete and compute all kill-chain metrics.""" ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first() if not ex: raise EntityNotFoundError("AttackPathExecution", str(execution_id)) results = ( db.query(AttackPathStepResult) .filter(AttackPathStepResult.execution_id == execution_id) .order_by(AttackPathStepResult.step_order) .all() ) total = len(results) detected = sum(1 for r in results if r.status == StepResultStatus.detected) not_detected = sum(1 for r in results if r.status == StepResultStatus.not_detected) skipped = sum(1 for r in results if r.status == StepResultStatus.skipped) detection_rate = (detected / total) if total > 0 else 0.0 ttds = [r.time_to_detect_seconds for r in results if r.time_to_detect_seconds is not None] mttd = (sum(ttds) / len(ttds)) if ttds else None # Furthest undetected step (highest order_index with not_detected status) undetected = [r for r in results if r.status == StepResultStatus.not_detected] furthest = max((r.step_order for r in undetected), default=None) ex.status = ExecutionStatus.completed ex.completed_at = _now() ex.total_steps = total ex.detected_steps = detected ex.not_detected_steps = not_detected ex.skipped_steps = skipped ex.detection_rate = round(detection_rate, 4) ex.mttd_seconds = round(mttd, 1) if mttd is not None else None ex.furthest_undetected_step = furthest db.commit() db.refresh(ex) _add_system_entry( db, execution_id, TimelineEntryType.phase_transition, f"Execution completed. Detection rate: {detection_rate:.0%}. " f"Detected {detected}/{total} steps. " + (f"MTTD: {mttd:.0f}s." if mttd else ""), actor_id=user_id, actor_side=TimelineActorSide.system, ) audit_service.log_action( db, user_id, "ATTACK_PATH_EXECUTION_COMPLETED", "attack_path_execution", str(execution_id), details={"detection_rate": detection_rate, "mttd_seconds": mttd, "detected": detected, "total": total}, ) return ex def abort_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution: ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first() if not ex: raise EntityNotFoundError("AttackPathExecution", str(execution_id)) ex.status = ExecutionStatus.aborted ex.completed_at = _now() db.commit() db.refresh(ex) _add_system_entry(db, execution_id, TimelineEntryType.flag, "Execution aborted.", actor_id=user_id, actor_side=TimelineActorSide.system) return ex # ── Timeline ────────────────────────────────────────────────────────────────── def add_timeline_entry( db: Session, execution_id: UUID, data: dict, user_id: UUID ) -> TimelineEntry: ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first() if not ex: raise EntityNotFoundError("AttackPathExecution", str(execution_id)) entry = TimelineEntry( execution_id=execution_id, step_id=data.get("step_id"), timestamp=data.get("timestamp") or _now(), actor_side=TimelineActorSide(data["actor_side"]), actor_id=user_id, entry_type=TimelineEntryType(data["entry_type"]), content=data["content"], extra=data.get("extra"), ) db.add(entry) db.commit() db.refresh(entry) return entry def get_timeline(db: Session, execution_id: UUID) -> list[TimelineEntry]: ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first() if not ex: raise EntityNotFoundError("AttackPathExecution", str(execution_id)) return ( db.query(TimelineEntry) .filter(TimelineEntry.execution_id == execution_id) .order_by(TimelineEntry.timestamp.asc()) .all() ) # ── Kill-Chain Metrics ──────────────────────────────────────────────────────── def get_kill_chain_metrics(db: Session, execution_id: UUID) -> dict: ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first() if not ex: raise EntityNotFoundError("AttackPathExecution", str(execution_id)) results = ( db.query(AttackPathStepResult) .filter(AttackPathStepResult.execution_id == execution_id) .order_by(AttackPathStepResult.step_order) .all() ) step_breakdown = [] phase_detected: dict[str, list] = {} for r in results: step = db.query(AttackPathStep).filter(AttackPathStep.id == r.step_id).first() phase = step.kill_chain_phase if step else None entry = { "step_id": str(r.step_id), "step_order": r.step_order, "step_name": step.name if step else None, "kill_chain_phase": phase, "status": r.status.value if hasattr(r.status, "value") else r.status, "executed_at": r.executed_at.isoformat() if r.executed_at else None, "detected_at": r.detected_at.isoformat() if r.detected_at else None, "time_to_detect_seconds": r.time_to_detect_seconds, "detection_asset_id": str(r.detection_asset_id) if r.detection_asset_id else None, } step_breakdown.append(entry) if phase: phase_detected.setdefault(phase, []).append( r.status == StepResultStatus.detected ) phase_summary = { phase: { "total": len(v), "detected": sum(v), "detection_rate": round(sum(v) / len(v), 3) if v else 0.0, } for phase, v in phase_detected.items() } # Furthest undetected phase furthest_undetected_phase = None if ex.furthest_undetected_step is not None: for r in reversed(results): if r.step_order == ex.furthest_undetected_step: step = db.query(AttackPathStep).filter(AttackPathStep.id == r.step_id).first() if step: furthest_undetected_phase = step.kill_chain_phase break return { "execution_id": str(execution_id), "total_steps": ex.total_steps or len(results), "detected_steps": ex.detected_steps or 0, "not_detected_steps": ex.not_detected_steps or 0, "skipped_steps": ex.skipped_steps or 0, "detection_rate": ex.detection_rate or 0.0, "mttd_seconds": ex.mttd_seconds, "furthest_undetected_step": ex.furthest_undetected_step, "furthest_undetected_phase": furthest_undetected_phase, "step_breakdown": step_breakdown, "phase_summary": phase_summary, } # ── Helper ──────────────────────────────────────────────────────────────────── def _add_system_entry( db: Session, execution_id: UUID, entry_type: TimelineEntryType, content: str, actor_id: Optional[UUID] = None, actor_side: TimelineActorSide = TimelineActorSide.system, step_id: Optional[UUID] = None, ) -> None: entry = TimelineEntry( execution_id=execution_id, step_id=step_id, timestamp=_now(), actor_side=actor_side, actor_id=actor_id, entry_type=entry_type, content=content, ) db.add(entry) db.commit()