"""Phase 13: Operational Alert service — rule evaluation engine + CRUD.""" from __future__ import annotations import logging import time from datetime import datetime, timedelta from typing import List, Optional from uuid import UUID from sqlalchemy.orm import Session from app.domain.errors import EntityNotFoundError, BusinessRuleViolation from app.models.operational_alert import ( AlertInstance, AlertRule, AlertRuleType, AlertSeverity, AlertStatus, ) from app.models.technique import Technique from app.models.risk_intelligence import TechniqueRiskProfile from app.models.ownership_queue import RevalidationQueueItem, QueueStatus from app.models.ownership_queue import TechniqueOwnership from app.models.executive_dashboard import PostureSnapshot from app.models.enums import TechniqueStatus from app.models.user import User log = logging.getLogger(__name__) # ── Notification & webhook dispatch helpers ─────────────────────────────────── def _dispatch_inapp_notifications(db: Session, rule: AlertRule, instance: AlertInstance) -> None: """Create in-app Notification rows for all admins and leads.""" from app.services.notification_service import create_notification admin_roles = {"admin", "red_lead", "blue_lead"} users = db.query(User).filter( User.role.in_(admin_roles), User.is_active == True, # noqa: E712 ).all() for user in users: create_notification( db, user_id = user.id, type = "alert_fired", title = instance.title, message = instance.message, entity_type = "alert_instance", entity_id = instance.id, ) def _dispatch_webhooks(rule: AlertRule, instance: AlertInstance) -> None: """Fire webhook(s) for a triggered alert (all exceptions caught).""" from app.services.webhook_service import dispatch_webhook, dispatch_webhook_targeted payload = { "alert_id": str(instance.id), "rule_id": str(rule.id) if rule.id else None, "rule_name": instance.rule_name, "rule_type": instance.rule_type, "severity": instance.severity, "title": instance.title, "message": instance.message, "details": instance.details, } # 1. Targeted webhook configured on the rule if rule.notify_webhook and rule.webhook_id: dispatch_webhook_targeted(str(rule.webhook_id), "alert.fired", payload) # 2. Broadcast to all global "alert.fired" subscribers dispatch_webhook("alert.fired", payload) # ── Pre-configured system rules (seeded at startup) ─────────────────────────── SYSTEM_RULES = [ { "name": "Critical Risk Techniques", "description": "Fires when 3 or more techniques reach critical risk level (score ≥ 75).", "rule_type": AlertRuleType.high_risk.value, "severity": AlertSeverity.critical.value, "is_system": True, "config": {"min_risk_score": 75.0, "min_count": 3}, "cooldown_hours": 24, }, { "name": "High-Risk Technique Spike", "description": "Fires when 10 or more techniques reach high risk (score ≥ 50).", "rule_type": AlertRuleType.high_risk.value, "severity": AlertSeverity.high.value, "is_system": True, "config": {"min_risk_score": 50.0, "min_count": 10}, "cooldown_hours": 24, }, { "name": "Stale Technique Detection", "description": "Fires when 5+ validated techniques have not been reviewed in 30+ days.", "rule_type": AlertRuleType.stale_technique.value, "severity": AlertSeverity.medium.value, "is_system": True, "config": {"days_stale": 30, "min_count": 5}, "cooldown_hours": 48, }, { "name": "Coverage Regression", "description": "Fires when coverage drops by 5 or more percentage points between daily snapshots.", "rule_type": AlertRuleType.coverage_regression.value, "severity": AlertSeverity.high.value, "is_system": True, "config": {"min_drop_pct": 5.0}, "cooldown_hours": 12, }, { "name": "Low Coverage Warning", "description": "Fires when overall coverage falls below 30%.", "rule_type": AlertRuleType.low_coverage.value, "severity": AlertSeverity.medium.value, "is_system": True, "config": {"max_coverage_pct": 30.0}, "cooldown_hours": 72, }, { "name": "Revalidation Queue Backlog", "description": "Fires when 15+ techniques are waiting in the revalidation queue.", "rule_type": AlertRuleType.expiry_wave.value, "severity": AlertSeverity.medium.value, "is_system": True, "config": {"min_pending_count": 15}, "cooldown_hours": 24, }, { "name": "New MITRE Techniques Detected", "description": "Fires when new ATT&CK techniques are added in the last 7 days.", "rule_type": AlertRuleType.new_technique.value, "severity": AlertSeverity.info.value, "is_system": True, "config": {"lookback_days": 7, "min_count": 1}, "cooldown_hours": 168, # once a week }, { "name": "Orphan Technique Spike", "description": "Fires when 20+ techniques have no assigned owner.", "rule_type": AlertRuleType.orphan_spike.value, "severity": AlertSeverity.low.value, "is_system": True, "config": {"min_orphan_count": 20}, "cooldown_hours": 48, }, ] def seed_system_rules(db: Session) -> int: """Ensure all system rules exist (idempotent). Returns count created.""" created = 0 for rule_def in SYSTEM_RULES: exists = db.query(AlertRule).filter( AlertRule.name == rule_def["name"], AlertRule.is_system == True, ).first() if not exists: rule = AlertRule(**rule_def) db.add(rule) created += 1 if created: db.commit() return created # ── Rule evaluators (one per AlertRuleType) ─────────────────────────────────── def _eval_high_risk(db: Session, rule: AlertRule) -> Optional[dict]: min_score = float(rule.config.get("min_risk_score", 75.0)) min_count = int(rule.config.get("min_count", 1)) profiles = db.query(TechniqueRiskProfile).filter( TechniqueRiskProfile.risk_score >= min_score, ).all() count = len(profiles) if count < min_count: return None top = sorted(profiles, key=lambda p: p.risk_score, reverse=True)[:5] return { "title": f"{count} technique(s) with risk score ≥ {min_score:.0f}", "message": ( f"{count} technique(s) have reached risk score ≥ {min_score:.0f}. " f"Top: {', '.join(str(p.technique_id)[:8] + '…' for p in top[:3])}." ), "details": { "count": count, "threshold": min_score, "top_ids": [str(p.technique_id) for p in top], "top_scores": [p.risk_score for p in top], }, } def _eval_stale_technique(db: Session, rule: AlertRule) -> Optional[dict]: days_stale = int(rule.config.get("days_stale", 30)) min_count = int(rule.config.get("min_count", 1)) cutoff = datetime.utcnow() - timedelta(days=days_stale) stale = db.query(Technique).filter( Technique.status_global == TechniqueStatus.validated, Technique.last_review_date < cutoff, ).all() count = len(stale) if count < min_count: return None return { "title": f"{count} validated technique(s) stale for {days_stale}+ days", "message": ( f"{count} technique(s) have been validated but not reviewed in over " f"{days_stale} days. Re-validate to maintain confidence." ), "details": { "count": count, "days_stale": days_stale, "example_ids": [str(t.id) for t in stale[:10]], }, } def _eval_coverage_regression(db: Session, rule: AlertRule) -> Optional[dict]: min_drop = float(rule.config.get("min_drop_pct", 5.0)) snaps = ( db.query(PostureSnapshot) .order_by(PostureSnapshot.snapshot_date.desc()) .limit(2) .all() ) if len(snaps) < 2: return None latest, previous = snaps[0], snaps[1] drop = previous.coverage_pct - latest.coverage_pct if drop < min_drop: return None return { "title": f"Coverage dropped {drop:.1f}% ({previous.coverage_pct:.1f}% → {latest.coverage_pct:.1f}%)", "message": ( f"Overall coverage fell by {drop:.1f} percentage points " f"between {previous.snapshot_date} and {latest.snapshot_date}. " f"Investigate recent technique status changes." ), "details": { "previous_pct": previous.coverage_pct, "current_pct": latest.coverage_pct, "drop_pct": round(drop, 2), "previous_date": str(previous.snapshot_date), "current_date": str(latest.snapshot_date), }, } def _eval_low_coverage(db: Session, rule: AlertRule) -> Optional[dict]: max_pct = float(rule.config.get("max_coverage_pct", 30.0)) techniques = db.query(Technique).all() total = len(techniques) if total == 0: return None validated = sum(1 for t in techniques if t.status_global == TechniqueStatus.validated) partial = sum(1 for t in techniques if t.status_global == TechniqueStatus.partial) coverage = (validated + partial * 0.5) / total * 100.0 if coverage > max_pct: return None return { "title": f"Coverage is critically low: {coverage:.1f}%", "message": ( f"Current detection coverage is {coverage:.1f}%, below the minimum " f"threshold of {max_pct:.0f}%. Prioritise coverage improvements." ), "details": { "coverage_pct": round(coverage, 2), "threshold": max_pct, "validated": validated, "partial": partial, "total": total, }, } def _eval_expiry_wave(db: Session, rule: AlertRule) -> Optional[dict]: min_pending = int(rule.config.get("min_pending_count", 15)) pending_count = db.query(RevalidationQueueItem).filter( RevalidationQueueItem.status.in_([ QueueStatus.pending, QueueStatus.in_progress, ]), ).count() if pending_count < min_pending: return None return { "title": f"Revalidation queue backlog: {pending_count} items pending", "message": ( f"{pending_count} technique(s) are waiting in the revalidation queue " f"(threshold: {min_pending}). Assign analysts to clear the backlog." ), "details": { "pending_count": pending_count, "threshold": min_pending, }, } def _eval_new_technique(db: Session, rule: AlertRule) -> Optional[dict]: lookback_days = int(rule.config.get("lookback_days", 7)) min_count = int(rule.config.get("min_count", 1)) cutoff = datetime.utcnow() - timedelta(days=lookback_days) new_techs = db.query(Technique).filter( Technique.mitre_last_modified >= cutoff, ).all() count = len(new_techs) if count < min_count: return None return { "title": f"{count} new/updated MITRE technique(s) in last {lookback_days} days", "message": ( f"{count} ATT&CK technique(s) have been added or updated in the last " f"{lookback_days} days. Review and assign coverage." ), "details": { "count": count, "lookback_days": lookback_days, "technique_ids": [str(t.id) for t in new_techs[:20]], "mitre_ids": [t.mitre_id for t in new_techs[:20]], }, } def _eval_orphan_spike(db: Session, rule: AlertRule) -> Optional[dict]: min_orphans = int(rule.config.get("min_orphan_count", 20)) total = db.query(Technique).count() owned = db.query(TechniqueOwnership).filter( TechniqueOwnership.owner_id.isnot(None), ).count() orphans = max(total - owned, 0) if orphans < min_orphans: return None return { "title": f"{orphans} unowned techniques detected", "message": ( f"{orphans} out of {total} technique(s) have no assigned owner. " f"Assign ownership to ensure accountability." ), "details": { "orphan_count": orphans, "total": total, "threshold": min_orphans, }, } _EVALUATORS = { AlertRuleType.high_risk.value: _eval_high_risk, AlertRuleType.stale_technique.value: _eval_stale_technique, AlertRuleType.coverage_regression.value: _eval_coverage_regression, AlertRuleType.low_coverage.value: _eval_low_coverage, AlertRuleType.expiry_wave.value: _eval_expiry_wave, AlertRuleType.new_technique.value: _eval_new_technique, AlertRuleType.orphan_spike.value: _eval_orphan_spike, } # ── Core evaluation engine ──────────────────────────────────────────────────── def _in_cooldown(rule: AlertRule) -> bool: if rule.last_fired_at is None: return False if rule.cooldown_hours <= 0: return False return datetime.utcnow() < rule.last_fired_at + timedelta(hours=rule.cooldown_hours) def evaluate_all_rules(db: Session) -> dict: """Evaluate every enabled rule; create AlertInstances for those that fire. After persisting each alert, dispatches: - In-app notifications to all admins/leads (if rule.notify_in_app) - Webhooks to the rule's targeted webhook + global "alert.fired" subscribers (if rule.notify_webhook) """ t0 = time.monotonic() rules = db.query(AlertRule).filter(AlertRule.is_enabled == True).all() # (rule, instance) pairs so we can dispatch after commit fired_pairs: List[tuple] = [] for rule in rules: if _in_cooldown(rule): continue evaluator = _EVALUATORS.get(rule.rule_type) if not evaluator: continue try: result = evaluator(db, rule) except Exception: log.exception("Error evaluating rule %s (%s)", rule.id, rule.name) continue if result is None: continue # condition not met instance = AlertInstance( rule_id = rule.id, rule_name = rule.name, rule_type = rule.rule_type, severity = rule.severity, title = result["title"], message = result["message"], details = result.get("details"), status = AlertStatus.open.value, ) db.add(instance) rule.last_fired_at = datetime.utcnow() fired_pairs.append((rule, instance)) # ── Persist alerts ──────────────────────────────────────────────────────── db.commit() for _rule, inst in fired_pairs: db.refresh(inst) # populate id + created_at from DB # ── In-app notifications (need instance.id, so must be after refresh) ──── for rule, inst in fired_pairs: if rule.notify_in_app: try: _dispatch_inapp_notifications(db, rule, inst) except Exception: log.exception("In-app notification failed for alert %s", inst.id) if fired_pairs: try: db.commit() # commit notifications except Exception: log.exception("Failed to commit in-app notifications") db.rollback() # ── Webhooks (fire-and-forget, own sessions) ────────────────────────────── for rule, inst in fired_pairs: try: _dispatch_webhooks(rule, inst) except Exception: log.exception("Webhook dispatch failed for alert %s", inst.id) fired = [inst for _, inst in fired_pairs] return { "rules_evaluated": len(rules), "alerts_fired": len(fired), "alerts": fired, "duration_seconds": round(time.monotonic() - t0, 3), } # ── AlertRule CRUD ──────────────────────────────────────────────────────────── def list_rules( db: Session, rule_type: Optional[str] = None, include_disabled: bool = False, ) -> List[AlertRule]: q = db.query(AlertRule) if rule_type: q = q.filter(AlertRule.rule_type == rule_type) if not include_disabled: q = q.filter(AlertRule.is_enabled == True) return q.order_by(AlertRule.created_at.asc()).all() def get_rule(db: Session, rule_id: UUID) -> AlertRule: rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first() if not rule: raise EntityNotFoundError("AlertRule", str(rule_id)) return rule def create_rule(db: Session, created_by: UUID, **kwargs) -> AlertRule: kwargs["is_system"] = False kwargs["created_by"] = created_by rule = AlertRule(**kwargs) db.add(rule) db.commit() db.refresh(rule) return rule def update_rule(db: Session, rule_id: UUID, **kwargs) -> AlertRule: rule = get_rule(db, rule_id) for k, v in kwargs.items(): if v is not None: setattr(rule, k, v) db.commit() db.refresh(rule) return rule def delete_rule(db: Session, rule_id: UUID) -> None: rule = get_rule(db, rule_id) if rule.is_system: raise BusinessRuleViolation("System rules cannot be deleted. Disable them instead.") db.delete(rule) db.commit() # ── AlertInstance CRUD ──────────────────────────────────────────────────────── def list_instances( db: Session, status: Optional[str] = None, severity: Optional[str] = None, rule_type: Optional[str] = None, limit: int = 50, offset: int = 0, ) -> List[AlertInstance]: q = db.query(AlertInstance) if status: q = q.filter(AlertInstance.status == status) if severity: q = q.filter(AlertInstance.severity == severity) if rule_type: q = q.filter(AlertInstance.rule_type == rule_type) return q.order_by(AlertInstance.created_at.desc()).offset(offset).limit(limit).all() def get_instance(db: Session, instance_id: UUID) -> AlertInstance: inst = db.query(AlertInstance).filter(AlertInstance.id == instance_id).first() if not inst: raise EntityNotFoundError("AlertInstance", str(instance_id)) return inst def _transition( db: Session, instance_id: UUID, new_status: str, user_id: Optional[UUID] = None, ) -> AlertInstance: inst = get_instance(db, instance_id) inst.status = new_status if new_status == AlertStatus.acknowledged.value: inst.acknowledged_by = user_id inst.acknowledged_at = datetime.utcnow() elif new_status == AlertStatus.resolved.value: inst.resolved_at = datetime.utcnow() db.commit() db.refresh(inst) return inst def acknowledge(db: Session, instance_id: UUID, user_id: UUID) -> AlertInstance: inst = get_instance(db, instance_id) if inst.status != AlertStatus.open.value: raise BusinessRuleViolation(f"Cannot acknowledge alert in status '{inst.status}'.") return _transition(db, instance_id, AlertStatus.acknowledged.value, user_id) def resolve(db: Session, instance_id: UUID, user_id: UUID) -> AlertInstance: inst = get_instance(db, instance_id) if inst.status == AlertStatus.resolved.value: raise BusinessRuleViolation("Alert is already resolved.") return _transition(db, instance_id, AlertStatus.resolved.value, user_id) def dismiss(db: Session, instance_id: UUID, user_id: UUID) -> AlertInstance: inst = get_instance(db, instance_id) if inst.status in (AlertStatus.resolved.value, AlertStatus.dismissed.value): raise BusinessRuleViolation(f"Cannot dismiss alert in status '{inst.status}'.") return _transition(db, instance_id, AlertStatus.dismissed.value, user_id) def get_summary(db: Session) -> dict: instances = db.query(AlertInstance).all() by_status = {s.value: 0 for s in AlertStatus} by_severity = {s.value: 0 for s in AlertSeverity} by_type = {} for i in instances: by_status[i.status] = by_status.get(i.status, 0) + 1 by_severity[i.severity] = by_severity.get(i.severity, 0) + 1 by_type[i.rule_type] = by_type.get(i.rule_type, 0) + 1 recent = ( db.query(AlertInstance) .filter(AlertInstance.status == AlertStatus.open.value) .order_by(AlertInstance.created_at.desc()) .limit(5) .all() ) return { "total_open": by_status.get(AlertStatus.open.value, 0), "total_acknowledged": by_status.get(AlertStatus.acknowledged.value, 0), "total_resolved": by_status.get(AlertStatus.resolved.value, 0), "by_severity": by_severity, "by_rule_type": by_type, "recent_alerts": recent, }