"""Phase 12: Risk Intelligence service — compute and query per-technique risk scores.""" import time from datetime import datetime, timedelta from typing import List, Optional from uuid import UUID from sqlalchemy.orm import Session from app.domain.errors import EntityNotFoundError from app.models.risk_intelligence import TechniqueRiskProfile from app.models.technique import Technique from app.models.threat_actor import ThreatActorTechnique from app.models.osint_item import OsintItem from app.models.test import Test from app.models.detection_lifecycle import ( TechniqueConfidenceScore, DetectionTechniqueMapping, ) from app.models.enums import TechniqueStatus, TestResult # ── Scoring weights & thresholds ─────────────────────────────────────────────── WEIGHT_DETECTION_GAP = 0.35 WEIGHT_THREAT_ACTORS = 0.30 WEIGHT_OSINT = 0.20 WEIGHT_TEST_FAILURES = 0.15 MAX_THREAT_ACTORS = 5 MAX_OSINT_SIGNALS = 10 OSINT_LOOKBACK_DAYS = 30 LEVEL_CRITICAL = 75.0 LEVEL_HIGH = 50.0 LEVEL_MEDIUM = 25.0 LEVEL_LOW = 10.0 def _risk_level(score: float) -> str: if score >= LEVEL_CRITICAL: return "critical" if score >= LEVEL_HIGH: return "high" if score >= LEVEL_MEDIUM: return "medium" if score >= LEVEL_LOW: return "low" return "info" def _clamp(v: float, lo: float = 0.0, hi: float = 1.0) -> float: return max(lo, min(hi, v)) # ── Per-technique computation ────────────────────────────────────────────────── def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile: breakdown: dict = {} recs: list = [] # ── Factor 1: Detection gap ─────────────────────────────────────────────── # Count how many assets map this technique mapping_count = db.query(DetectionTechniqueMapping).filter( DetectionTechniqueMapping.technique_id == tech.id, ).count() # DLC confidence score (unique per technique — no order_by needed) dlc_conf = db.query(TechniqueConfidenceScore).filter( TechniqueConfidenceScore.technique_id == tech.id, ).first() confidence_level: float = float(dlc_conf.confidence_score 
or 0.0) if dlc_conf else 0.0 # Map technique status_global to coverage factor status = tech.status_global if status == TechniqueStatus.validated: status_coverage = 1.0 elif status == TechniqueStatus.partial: status_coverage = 0.5 elif status == TechniqueStatus.in_progress: status_coverage = 0.25 else: status_coverage = 0.0 if mapping_count > 0: raw_coverage = max(status_coverage, _clamp(confidence_level)) else: raw_coverage = status_coverage # rely on status alone if no asset mapping detection_gap = 1.0 - raw_coverage breakdown["detection_gap"] = { "mapping_count": mapping_count, "status": str(status.value) if status else None, "status_coverage": status_coverage, "confidence_level": round(confidence_level, 3), "detection_gap": round(detection_gap, 3), "contribution": round(detection_gap * WEIGHT_DETECTION_GAP * 100, 2), } if detection_gap >= 0.8: recs.append("Implement detection coverage — technique is largely undetected.") elif detection_gap >= 0.5: recs.append("Improve detection quality — coverage is partial.") # ── Factor 2: Threat actor relevance ───────────────────────────────────── actor_count = db.query(ThreatActorTechnique).filter( ThreatActorTechnique.technique_id == tech.id, ).count() ta_factor = _clamp(actor_count / MAX_THREAT_ACTORS) breakdown["threat_actor"] = { "actor_count": actor_count, "max_cap": MAX_THREAT_ACTORS, "normalised": round(ta_factor, 3), "contribution": round(ta_factor * WEIGHT_THREAT_ACTORS * 100, 2), } if actor_count >= 3: recs.append( f"High threat-actor relevance — {actor_count} tracked actors use this technique." ) elif actor_count >= 1: recs.append( f"{actor_count} threat actor(s) use this technique — monitor closely." 
) # ── Factor 3: OSINT signals (last 30 days) ──────────────────────────────── cutoff = datetime.utcnow() - timedelta(days=OSINT_LOOKBACK_DAYS) osint_count = db.query(OsintItem).filter( OsintItem.technique_id == tech.id, OsintItem.discovered_at >= cutoff, ).count() osint_factor = _clamp(osint_count / MAX_OSINT_SIGNALS) breakdown["osint"] = { "signal_count_30d": osint_count, "max_cap": MAX_OSINT_SIGNALS, "normalised": round(osint_factor, 3), "contribution": round(osint_factor * WEIGHT_OSINT * 100, 2), } if osint_count >= 5: recs.append( f"High OSINT activity — {osint_count} signals in last 30 days. Review urgently." ) elif osint_count >= 1: recs.append( f"{osint_count} OSINT signal(s) in last 30 days. Review for IoCs." ) # ── Factor 4: Test failure rate ─────────────────────────────────────────── # Use Test.result (TestResult enum) to determine not-detected count tech_tests = db.query(Test).filter(Test.technique_id == tech.id).all() test_total = len([t for t in tech_tests if t.result is not None]) test_not_detected = sum( 1 for t in tech_tests if t.result == TestResult.not_detected ) test_failure_rate = (test_not_detected / test_total) if test_total > 0 else 0.0 # No tests → moderate unknown risk factor test_factor = test_failure_rate if test_total > 0 else 0.3 breakdown["test_failures"] = { "total_tests": test_total, "not_detected": test_not_detected, "failure_rate": round(test_failure_rate, 3), "factor_used": round(test_factor, 3), "contribution": round(test_factor * WEIGHT_TEST_FAILURES * 100, 2), } if test_total == 0: recs.append("No purple-team tests found — add tests to validate detection.") elif test_failure_rate >= 0.5: recs.append( f"High test failure rate ({test_failure_rate:.0%}) — blue team misses this technique." 
) # ── Aggregate scores ────────────────────────────────────────────────────── raw_score = ( detection_gap * WEIGHT_DETECTION_GAP + ta_factor * WEIGHT_THREAT_ACTORS + osint_factor * WEIGHT_OSINT + test_factor * WEIGHT_TEST_FAILURES ) risk_score = _clamp(raw_score) * 100.0 likelihood = _clamp( detection_gap * 0.5 + ta_factor * 0.35 + osint_factor * 0.15 ) * 100.0 impact = _clamp( test_factor * 0.6 + osint_factor * 0.25 + detection_gap * 0.15 ) * 100.0 level = _risk_level(risk_score) breakdown["total"] = { "risk_score": round(risk_score, 2), "likelihood": round(likelihood, 2), "impact": round(impact, 2), "risk_level": level, } return TechniqueRiskProfile( technique_id = tech.id, risk_score = round(risk_score, 4), likelihood = round(likelihood, 4), impact = round(impact, 4), risk_level = level, detection_gap = round(detection_gap, 4), threat_actor_count = actor_count, osint_signal_count = osint_count, test_fail_count = test_not_detected, test_total_count = test_total, test_failure_rate = round(test_failure_rate, 4), confidence_level = round(confidence_level, 4), scoring_breakdown = breakdown, recommendations = recs or ["Risk profile looks healthy — continue monitoring."], computed_at = datetime.utcnow(), is_stale = False, ) # ── Upsert ───────────────────────────────────────────────────────────────────── def _upsert_profile(db: Session, profile: TechniqueRiskProfile) -> TechniqueRiskProfile: existing = db.query(TechniqueRiskProfile).filter( TechniqueRiskProfile.technique_id == profile.technique_id, ).first() if existing: for attr in ( "risk_score", "likelihood", "impact", "risk_level", "detection_gap", "threat_actor_count", "osint_signal_count", "test_fail_count", "test_total_count", "test_failure_rate", "confidence_level", "scoring_breakdown", "recommendations", "computed_at", "is_stale", ): setattr(existing, attr, getattr(profile, attr)) db.commit() db.refresh(existing) return existing db.add(profile) db.commit() db.refresh(profile) return profile # ── Public API 
def compute_technique_risk(db: Session, technique_id: UUID) -> TechniqueRiskProfile:
    """Compute, persist and return the risk profile for a single technique.

    Raises:
        EntityNotFoundError: if no technique with ``technique_id`` exists.
    """
    tech = db.query(Technique).filter(Technique.id == technique_id).first()
    if not tech:
        raise EntityNotFoundError("Technique", str(technique_id))
    profile = _compute_for_technique(db, tech)
    return _upsert_profile(db, profile)


def compute_all_risk_scores(db: Session) -> dict:
    """Recompute and persist the risk profile of every technique.

    Each technique is committed on its own, so one failure does not discard
    the profiles already saved. A failure is logged, rolled back and counted,
    and the batch continues.

    Returns:
        ``{"computed", "skipped", "errors", "duration_seconds"}``.
    """
    import logging

    log = logging.getLogger(__name__)
    t0 = time.monotonic()
    computed = 0
    errors = 0

    for tech in db.query(Technique).all():
        # Read the id up front: after a rollback the instance is expired and
        # touching attributes would trigger another query.
        tech_id = tech.id
        try:
            _upsert_profile(db, _compute_for_technique(db, tech))
        except Exception:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; without this every later technique would fail too.
            db.rollback()
            errors += 1
            log.exception("Risk scoring failed for technique %s", tech_id)
        else:
            computed += 1

    return {
        "computed": computed,
        "skipped": 0,
        "errors": errors,
        "duration_seconds": round(time.monotonic() - t0, 2),
    }


def get_risk_profile(db: Session, technique_id: UUID) -> TechniqueRiskProfile:
    """Return the stored risk profile for *technique_id*.

    Raises:
        EntityNotFoundError: if the technique has not been scored yet.
    """
    profile = (
        db.query(TechniqueRiskProfile)
        .filter(TechniqueRiskProfile.technique_id == technique_id)
        .first()
    )
    if not profile:
        raise EntityNotFoundError("TechniqueRiskProfile", str(technique_id))
    return profile


def list_risk_profiles(
    db: Session,
    risk_level: Optional[str] = None,
    min_score: Optional[float] = None,
    max_score: Optional[float] = None,
    stale_only: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> List[TechniqueRiskProfile]:
    """List stored profiles, highest risk first, with optional filters.

    ``min_score``/``max_score`` are inclusive bounds on ``risk_score``.
    ``stale_only`` restricts the result to profiles flagged ``is_stale``.
    """
    q = db.query(TechniqueRiskProfile)
    if risk_level:
        q = q.filter(TechniqueRiskProfile.risk_level == risk_level)
    if min_score is not None:
        q = q.filter(TechniqueRiskProfile.risk_score >= min_score)
    if max_score is not None:
        q = q.filter(TechniqueRiskProfile.risk_score <= max_score)
    if stale_only:
        # SQL expression, not a Python identity test.
        q = q.filter(TechniqueRiskProfile.is_stale == True)  # noqa: E712
    return (
        q.order_by(TechniqueRiskProfile.risk_score.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )


def _ranked_rows(db: Session):
    """Return a query of (profile, technique) pairs ordered by descending risk score."""
    return (
        db.query(TechniqueRiskProfile, Technique)
        .join(Technique, TechniqueRiskProfile.technique_id == Technique.id)
        .order_by(TechniqueRiskProfile.risk_score.desc())
    )


def _risk_row(p: TechniqueRiskProfile, t: Technique) -> dict:
    """Serialise a (profile, technique) pair into the matrix/top-risk row shape."""
    return {
        "technique_id": str(p.technique_id),
        "technique_name": t.name,
        "technique_tid": t.mitre_id,
        "risk_score": p.risk_score,
        "likelihood": p.likelihood,
        "impact": p.impact,
        "risk_level": p.risk_level,
        "detection_gap": p.detection_gap,
        "computed_at": p.computed_at.isoformat() if p.computed_at else None,
    }


def get_risk_matrix(db: Session) -> list:
    """Return every scored technique as a row dict, highest risk first."""
    return [_risk_row(p, t) for p, t in _ranked_rows(db).all()]


def get_risk_summary(db: Session) -> dict:
    """Aggregate dashboard figures over all stored risk profiles.

    Returns total/scored/stale counts and a per-level histogram that always
    contains the five standard levels. It also returns the mean risk score and
    the five highest-risk techniques.
    """
    # Fetch only the columns needed for aggregation instead of full profiles,
    # so the breakdown/recommendation payloads are not loaded.
    stats = db.query(
        TechniqueRiskProfile.risk_level,
        TechniqueRiskProfile.risk_score,
        TechniqueRiskProfile.is_stale,
    ).all()
    total_tech = db.query(Technique).count()
    scored = len(stats)
    stale = sum(1 for row in stats if row.is_stale)

    by_level: dict = {
        level: 0 for level in ("critical", "high", "medium", "low", "info")
    }
    score_sum = 0.0
    for row in stats:
        by_level[row.risk_level] = by_level.get(row.risk_level, 0) + 1
        score_sum += row.risk_score
    avg_score = (score_sum / scored) if scored > 0 else 0.0

    top_risks = [_risk_row(p, t) for p, t in _ranked_rows(db).limit(5).all()]

    return {
        "total_techniques": total_tech,
        "scored_techniques": scored,
        "stale_count": stale,
        "by_level": by_level,
        "avg_risk_score": round(avg_score, 2),
        "top_risks": top_risks,
    }


def get_recommendations(db: Session, limit: int = 20) -> list:
    """Return stored recommendations for the top-*limit* techniques with a positive score.

    ``priority`` is the 1-based rank by descending risk score.
    """
    rows = (
        _ranked_rows(db)
        .filter(TechniqueRiskProfile.risk_score > 0)
        .limit(limit)
        .all()
    )
    return [
        {
            "technique_id": str(p.technique_id),
            "technique_name": t.name,
            "technique_tid": t.mitre_id,
            "risk_level": p.risk_level,
            "risk_score": p.risk_score,
            "recommendations": p.recommendations or [],
            "priority": i,
        }
        for i, (p, t) in enumerate(rows, start=1)
    ]