"""Phase 12: Risk Intelligence service — compute and query per-technique risk scores.""" import time from datetime import datetime, timedelta from typing import List, Optional from uuid import UUID from sqlalchemy import func from sqlalchemy.orm import Session from app.domain.errors import EntityNotFoundError from app.models.risk_intelligence import TechniqueRiskProfile from app.models.technique import Technique from app.models.threat_actor import ThreatActorTechnique from app.models.osint_item import OsintItem from app.models.test import Test from app.models.test_detection_result import TestDetectionResult from app.models.detection_lifecycle import ( TechniqueConfidenceScore, DetectionTechniqueMapping, DetectionConfidence, ) from app.models.enums import TechniqueStatus # ── Scoring constants ────────────────────────────────────────────────────────── WEIGHT_DETECTION_GAP = 0.35 WEIGHT_THREAT_ACTORS = 0.30 WEIGHT_OSINT = 0.20 WEIGHT_TEST_FAILURES = 0.15 # Normalisation caps MAX_THREAT_ACTORS = 5 # beyond this → factor saturates at 1.0 MAX_OSINT_SIGNALS = 10 # OSINT items in last 30 days OSINT_LOOKBACK_DAYS = 30 LEVEL_CRITICAL = 75.0 LEVEL_HIGH = 50.0 LEVEL_MEDIUM = 25.0 LEVEL_LOW = 10.0 def _risk_level(score: float) -> str: if score >= LEVEL_CRITICAL: return "critical" if score >= LEVEL_HIGH: return "high" if score >= LEVEL_MEDIUM: return "medium" if score >= LEVEL_LOW: return "low" return "info" def _clamp(v: float, lo: float = 0.0, hi: float = 1.0) -> float: return max(lo, min(hi, v)) # ── Single-technique computation ─────────────────────────────────────────────── def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile: """Calculate the risk profile for one technique and return the (unsaved) model.""" breakdown: dict = {} recs: list = [] # ── Factor 1: Detection gap (0=covered, 1=no coverage) ─────────────────── # Check if technique is covered (has at least one DetectionTechniqueMapping) mapping_count = db.query(DetectionTechniqueMapping).filter( DetectionTechniqueMapping.technique_id == tech.id, ).count() # Get DLC confidence score if available dlc_conf = db.query(TechniqueConfidenceScore).filter( TechniqueConfidenceScore.technique_id == tech.id, ).order_by(TechniqueConfidenceScore.computed_at.desc()).first() confidence_level: float = 0.0 if dlc_conf: confidence_level = float(dlc_conf.score or 0.0) # Also factor in technique status if tech.status == TechniqueStatus.covered: status_coverage = 1.0 elif tech.status == TechniqueStatus.partial: status_coverage = 0.5 else: # uncovered / unknown status_coverage = 0.0 if mapping_count > 0: # Has at least one asset mapped — use confidence as detection quality raw_coverage = max(status_coverage, _clamp(confidence_level)) else: raw_coverage = 0.0 detection_gap = 1.0 - raw_coverage detection_gap_factor = detection_gap # already 0–1 breakdown["detection_gap"] = { "mapping_count": mapping_count, "status_coverage": status_coverage, "confidence_level": confidence_level, "detection_gap": round(detection_gap, 3), "contribution": round(detection_gap_factor * WEIGHT_DETECTION_GAP * 100, 2), } if detection_gap >= 0.8: recs.append("Implement detection coverage — technique is largely undetected.") elif detection_gap >= 0.5: recs.append("Improve detection quality — coverage is partial.") # ── Factor 2: Threat actor relevance ───────────────────────────────────── actor_count = db.query(ThreatActorTechnique).filter( ThreatActorTechnique.technique_id == tech.id, ).count() ta_factor = _clamp(actor_count / MAX_THREAT_ACTORS) breakdown["threat_actor"] = { "actor_count": actor_count, "max_cap": MAX_THREAT_ACTORS, "normalised": round(ta_factor, 3), "contribution": round(ta_factor * WEIGHT_THREAT_ACTORS * 100, 2), } if actor_count >= 3: recs.append( f"High threat-actor relevance — {actor_count} tracked actors use this technique." ) elif actor_count >= 1: recs.append( f"{actor_count} threat actor(s) use this technique — monitor closely." ) # ── Factor 3: OSINT signals (last 30 days) ──────────────────────────────── cutoff = datetime.utcnow() - timedelta(days=OSINT_LOOKBACK_DAYS) osint_count = db.query(OsintItem).filter( OsintItem.technique_id == tech.id, OsintItem.discovered_at >= cutoff, ).count() osint_factor = _clamp(osint_count / MAX_OSINT_SIGNALS) breakdown["osint"] = { "signal_count_30d": osint_count, "max_cap": MAX_OSINT_SIGNALS, "normalised": round(osint_factor, 3), "contribution": round(osint_factor * WEIGHT_OSINT * 100, 2), } if osint_count >= 5: recs.append( f"High OSINT activity — {osint_count} signals in the last 30 days. Review urgently." ) elif osint_count >= 1: recs.append( f"{osint_count} OSINT signal(s) detected in last 30 days. Review for IoCs." ) # ── Factor 4: Test failure rate ─────────────────────────────────────────── # Count TestDetectionResult rows for this technique's tests from app.models.enums import TestResult tech_tests = db.query(Test).filter(Test.technique_id == tech.id).all() test_ids = [t.id for t in tech_tests] test_total = 0 test_not_detected = 0 if test_ids: from app.models.test_detection_result import TestDetectionResult as TDR results = db.query(TDR).filter(TDR.test_id.in_(test_ids)).all() test_total = len(results) test_not_detected = sum( 1 for r in results if hasattr(r, 'result') and str(getattr(r, 'result', '')) == 'not_detected' ) # Also count tests where overall result is not_detected if test_total == 0: for t in tech_tests: if hasattr(t, 'result') and t.result is not None: test_total += 1 if str(t.result) in ('not_detected', 'TestResult.not_detected'): test_not_detected += 1 test_failure_rate = (test_not_detected / test_total) if test_total > 0 else 0.0 # If no tests exist at all → treat as unknown risk (moderate) test_factor = test_failure_rate if test_total > 0 else 0.3 breakdown["test_failures"] = { "total_tests": test_total, "not_detected": test_not_detected, "failure_rate": round(test_failure_rate, 3), "factor_used": round(test_factor, 3), "contribution": round(test_factor * WEIGHT_TEST_FAILURES * 100, 2), } if test_total == 0: recs.append("No purple-team tests found — add tests to validate detection.") elif test_failure_rate >= 0.5: recs.append( f"High test failure rate ({test_failure_rate:.0%}) — blue team is missing this technique." ) # ── Weighted risk score ─────────────────────────────────────────────────── raw_score = ( detection_gap_factor * WEIGHT_DETECTION_GAP + ta_factor * WEIGHT_THREAT_ACTORS + osint_factor * WEIGHT_OSINT + test_factor * WEIGHT_TEST_FAILURES ) risk_score = _clamp(raw_score) * 100.0 # Likelihood = detection + actor contribution (exposure) likelihood = _clamp( detection_gap_factor * 0.5 + ta_factor * 0.35 + osint_factor * 0.15 ) * 100.0 # Impact = test failures + osint severity signal impact = _clamp( test_factor * 0.6 + osint_factor * 0.25 + detection_gap_factor * 0.15 ) * 100.0 level = _risk_level(risk_score) breakdown["total"] = { "risk_score": round(risk_score, 2), "likelihood": round(likelihood, 2), "impact": round(impact, 2), "risk_level": level, } return TechniqueRiskProfile( technique_id = tech.id, risk_score = round(risk_score, 4), likelihood = round(likelihood, 4), impact = round(impact, 4), risk_level = level, detection_gap = round(detection_gap, 4), threat_actor_count = actor_count, osint_signal_count = osint_count, test_fail_count = test_not_detected, test_total_count = test_total, test_failure_rate = round(test_failure_rate, 4), confidence_level = round(confidence_level, 4), scoring_breakdown = breakdown, recommendations = recs or ["Risk profile looks healthy — continue monitoring."], computed_at = datetime.utcnow(), is_stale = False, ) # ── Upsert helpers ───────────────────────────────────────────────────────────── def _upsert_profile(db: Session, profile: TechniqueRiskProfile) -> TechniqueRiskProfile: existing = db.query(TechniqueRiskProfile).filter( TechniqueRiskProfile.technique_id == profile.technique_id, ).first() if existing: for attr in ( "risk_score", "likelihood", "impact", "risk_level", "detection_gap", "threat_actor_count", "osint_signal_count", "test_fail_count", "test_total_count", "test_failure_rate", "confidence_level", "scoring_breakdown", "recommendations", "computed_at", "is_stale", ): setattr(existing, attr, getattr(profile, attr)) db.commit() db.refresh(existing) return existing db.add(profile) db.commit() db.refresh(profile) return profile # ── Public API ──────────────────────────────────────────────────────────────── def compute_technique_risk(db: Session, technique_id: UUID) -> TechniqueRiskProfile: """Compute (or recompute) risk profile for a single technique.""" tech = db.query(Technique).filter(Technique.id == technique_id).first() if not tech: raise EntityNotFoundError("Technique", str(technique_id)) profile = _compute_for_technique(db, tech) return _upsert_profile(db, profile) def compute_all_risk_scores(db: Session) -> dict: """Compute risk profiles for all techniques. Returns summary counts.""" t0 = time.monotonic() techniques = db.query(Technique).all() computed = 0 errors = 0 for tech in techniques: try: profile = _compute_for_technique(db, tech) _upsert_profile(db, profile) computed += 1 except Exception: errors += 1 duration = time.monotonic() - t0 return { "computed": computed, "skipped": 0, "errors": errors, "duration_seconds": round(duration, 2), } def get_risk_profile(db: Session, technique_id: UUID) -> TechniqueRiskProfile: profile = db.query(TechniqueRiskProfile).filter( TechniqueRiskProfile.technique_id == technique_id, ).first() if not profile: raise EntityNotFoundError("TechniqueRiskProfile", str(technique_id)) return profile def list_risk_profiles( db: Session, risk_level: Optional[str] = None, min_score: Optional[float] = None, max_score: Optional[float] = None, stale_only: bool = False, limit: int = 100, offset: int = 0, ) -> List[TechniqueRiskProfile]: q = db.query(TechniqueRiskProfile) if risk_level: q = q.filter(TechniqueRiskProfile.risk_level == risk_level) if min_score is not None: q = q.filter(TechniqueRiskProfile.risk_score >= min_score) if max_score is not None: q = q.filter(TechniqueRiskProfile.risk_score <= max_score) if stale_only: q = q.filter(TechniqueRiskProfile.is_stale == True) return ( q.order_by(TechniqueRiskProfile.risk_score.desc()) .offset(offset) .limit(limit) .all() ) def get_risk_matrix(db: Session) -> list: """Return all profiled techniques with name+tid for the matrix view.""" rows = ( db.query(TechniqueRiskProfile, Technique) .join(Technique, TechniqueRiskProfile.technique_id == Technique.id) .order_by(TechniqueRiskProfile.risk_score.desc()) .all() ) result = [] for profile, tech in rows: result.append({ "technique_id": str(profile.technique_id), "technique_name": tech.name, "technique_tid": tech.technique_id, # MITRE T-ID string "risk_score": profile.risk_score, "likelihood": profile.likelihood, "impact": profile.impact, "risk_level": profile.risk_level, "detection_gap": profile.detection_gap, "computed_at": profile.computed_at.isoformat() if profile.computed_at else None, }) return result def get_risk_summary(db: Session) -> dict: """Aggregate statistics across all risk profiles.""" all_profiles = db.query(TechniqueRiskProfile).all() total_tech = db.query(Technique).count() scored = len(all_profiles) stale = sum(1 for p in all_profiles if p.is_stale) by_level: dict = {lvl: 0 for lvl in ("critical", "high", "medium", "low", "info")} score_sum = 0.0 for p in all_profiles: by_level[p.risk_level] = by_level.get(p.risk_level, 0) + 1 score_sum += p.risk_score avg_score = (score_sum / scored) if scored > 0 else 0.0 # Top 5 by risk score (with technique name) top_rows = ( db.query(TechniqueRiskProfile, Technique) .join(Technique, TechniqueRiskProfile.technique_id == Technique.id) .order_by(TechniqueRiskProfile.risk_score.desc()) .limit(5) .all() ) top_risks = [ { "technique_id": str(p.technique_id), "technique_name": t.name, "technique_tid": t.technique_id, "risk_score": p.risk_score, "risk_level": p.risk_level, "likelihood": p.likelihood, "impact": p.impact, "detection_gap": p.detection_gap, "computed_at": p.computed_at.isoformat() if p.computed_at else None, } for p, t in top_rows ] return { "total_techniques": total_tech, "scored_techniques": scored, "stale_count": stale, "by_level": by_level, "avg_risk_score": round(avg_score, 2), "top_risks": top_risks, } def get_recommendations(db: Session, limit: int = 20) -> list: """Prioritised list of techniques with actionable recommendations.""" rows = ( db.query(TechniqueRiskProfile, Technique) .join(Technique, TechniqueRiskProfile.technique_id == Technique.id) .filter(TechniqueRiskProfile.risk_score > 0) .order_by(TechniqueRiskProfile.risk_score.desc()) .limit(limit) .all() ) result = [] for priority, (profile, tech) in enumerate(rows, start=1): result.append({ "technique_id": str(profile.technique_id), "technique_name": tech.name, "technique_tid": tech.technique_id, "risk_level": profile.risk_level, "risk_score": profile.risk_score, "recommendations": profile.recommendations or [], "priority": priority, }) return result