diff --git a/backend/app/services/risk_intelligence_service.py b/backend/app/services/risk_intelligence_service.py index c1bed66..53eb1d0 100644 --- a/backend/app/services/risk_intelligence_service.py +++ b/backend/app/services/risk_intelligence_service.py @@ -5,7 +5,6 @@ from datetime import datetime, timedelta from typing import List, Optional from uuid import UUID -from sqlalchemy import func from sqlalchemy.orm import Session from app.domain.errors import EntityNotFoundError @@ -14,25 +13,22 @@ from app.models.technique import Technique from app.models.threat_actor import ThreatActorTechnique from app.models.osint_item import OsintItem from app.models.test import Test -from app.models.test_detection_result import TestDetectionResult from app.models.detection_lifecycle import ( TechniqueConfidenceScore, DetectionTechniqueMapping, - DetectionConfidence, ) -from app.models.enums import TechniqueStatus +from app.models.enums import TechniqueStatus, TestResult -# ── Scoring constants ────────────────────────────────────────────────────────── +# ── Scoring weights & thresholds ─────────────────────────────────────────────── WEIGHT_DETECTION_GAP = 0.35 WEIGHT_THREAT_ACTORS = 0.30 WEIGHT_OSINT = 0.20 WEIGHT_TEST_FAILURES = 0.15 -# Normalisation caps -MAX_THREAT_ACTORS = 5 # beyond this → factor saturates at 1.0 -MAX_OSINT_SIGNALS = 10 # OSINT items in last 30 days +MAX_THREAT_ACTORS = 5 +MAX_OSINT_SIGNALS = 10 OSINT_LOOKBACK_DAYS = 30 LEVEL_CRITICAL = 75.0 @@ -53,52 +49,48 @@ def _clamp(v: float, lo: float = 0.0, hi: float = 1.0) -> float: return max(lo, min(hi, v)) -# ── Single-technique computation ─────────────────────────────────────────────── +# ── Per-technique computation ────────────────────────────────────────────────── def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile: - """Calculate the risk profile for one technique and return the (unsaved) model.""" - breakdown: dict = {} recs: list = [] - # ── Factor 1: Detection gap (0=covered, 1=no coverage) ─────────────────── - # Check if technique is covered (has at least one DetectionTechniqueMapping) + # ── Factor 1: Detection gap ─────────────────────────────────────────────── + # Count how many assets map this technique mapping_count = db.query(DetectionTechniqueMapping).filter( DetectionTechniqueMapping.technique_id == tech.id, ).count() - # Get DLC confidence score if available + # DLC confidence score (unique per technique — no order_by needed) dlc_conf = db.query(TechniqueConfidenceScore).filter( TechniqueConfidenceScore.technique_id == tech.id, - ).order_by(TechniqueConfidenceScore.computed_at.desc()).first() + ).first() + confidence_level: float = float(dlc_conf.confidence_score or 0.0) if dlc_conf else 0.0 - confidence_level: float = 0.0 - if dlc_conf: - confidence_level = float(dlc_conf.score or 0.0) - - # Also factor in technique status - if tech.status == TechniqueStatus.covered: + # Map technique status to coverage factor + status = tech.status + if status == TechniqueStatus.validated: status_coverage = 1.0 - elif tech.status == TechniqueStatus.partial: + elif status == TechniqueStatus.partial: status_coverage = 0.5 - else: # uncovered / unknown + elif status == TechniqueStatus.in_progress: + status_coverage = 0.25 + else: status_coverage = 0.0 if mapping_count > 0: - # Has at least one asset mapped — use confidence as detection quality raw_coverage = max(status_coverage, _clamp(confidence_level)) else: - raw_coverage = 0.0 + raw_coverage = status_coverage # rely on status alone if no asset mapping detection_gap = 1.0 - raw_coverage - detection_gap_factor = detection_gap # already 0–1 - breakdown["detection_gap"] = { - "mapping_count": mapping_count, + "mapping_count": mapping_count, + "status": str(status) if status else None, "status_coverage": status_coverage, - "confidence_level": confidence_level, - "detection_gap": round(detection_gap, 3), - "contribution": round(detection_gap_factor * WEIGHT_DETECTION_GAP * 100, 2), + "confidence_level": round(confidence_level, 3), + "detection_gap": round(detection_gap, 3), + "contribution": round(detection_gap * WEIGHT_DETECTION_GAP * 100, 2), } if detection_gap >= 0.8: recs.append("Implement detection coverage — technique is largely undetected.") @@ -109,12 +101,11 @@ def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile actor_count = db.query(ThreatActorTechnique).filter( ThreatActorTechnique.technique_id == tech.id, ).count() - ta_factor = _clamp(actor_count / MAX_THREAT_ACTORS) breakdown["threat_actor"] = { - "actor_count": actor_count, - "max_cap": MAX_THREAT_ACTORS, - "normalised": round(ta_factor, 3), + "actor_count": actor_count, + "max_cap": MAX_THREAT_ACTORS, + "normalised": round(ta_factor, 3), "contribution": round(ta_factor * WEIGHT_THREAT_ACTORS * 100, 2), } if actor_count >= 3: @@ -132,85 +123,64 @@ def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile OsintItem.technique_id == tech.id, OsintItem.discovered_at >= cutoff, ).count() - osint_factor = _clamp(osint_count / MAX_OSINT_SIGNALS) breakdown["osint"] = { "signal_count_30d": osint_count, - "max_cap": MAX_OSINT_SIGNALS, - "normalised": round(osint_factor, 3), - "contribution": round(osint_factor * WEIGHT_OSINT * 100, 2), + "max_cap": MAX_OSINT_SIGNALS, + "normalised": round(osint_factor, 3), + "contribution": round(osint_factor * WEIGHT_OSINT * 100, 2), } if osint_count >= 5: recs.append( - f"High OSINT activity — {osint_count} signals in the last 30 days. Review urgently." + f"High OSINT activity — {osint_count} signals in last 30 days. Review urgently." ) elif osint_count >= 1: recs.append( - f"{osint_count} OSINT signal(s) detected in last 30 days. Review for IoCs." + f"{osint_count} OSINT signal(s) in last 30 days. Review for IoCs." ) # ── Factor 4: Test failure rate ─────────────────────────────────────────── - # Count TestDetectionResult rows for this technique's tests - from app.models.enums import TestResult + # Use Test.result (TestResult enum) to determine not-detected count tech_tests = db.query(Test).filter(Test.technique_id == tech.id).all() - test_ids = [t.id for t in tech_tests] - - test_total = 0 - test_not_detected = 0 - if test_ids: - from app.models.test_detection_result import TestDetectionResult as TDR - results = db.query(TDR).filter(TDR.test_id.in_(test_ids)).all() - test_total = len(results) - test_not_detected = sum( - 1 for r in results - if hasattr(r, 'result') and str(getattr(r, 'result', '')) == 'not_detected' - ) - # Also count tests where overall result is not_detected - if test_total == 0: - for t in tech_tests: - if hasattr(t, 'result') and t.result is not None: - test_total += 1 - if str(t.result) in ('not_detected', 'TestResult.not_detected'): - test_not_detected += 1 - + test_total = len([t for t in tech_tests if t.result is not None]) + test_not_detected = sum( + 1 for t in tech_tests + if t.result == TestResult.not_detected + ) test_failure_rate = (test_not_detected / test_total) if test_total > 0 else 0.0 - # If no tests exist at all → treat as unknown risk (moderate) + # No tests → moderate unknown risk factor test_factor = test_failure_rate if test_total > 0 else 0.3 breakdown["test_failures"] = { - "total_tests": test_total, - "not_detected": test_not_detected, - "failure_rate": round(test_failure_rate, 3), - "factor_used": round(test_factor, 3), - "contribution": round(test_factor * WEIGHT_TEST_FAILURES * 100, 2), + "total_tests": test_total, + "not_detected": test_not_detected, + "failure_rate": round(test_failure_rate, 3), + "factor_used": round(test_factor, 3), + "contribution": round(test_factor * WEIGHT_TEST_FAILURES * 100, 2), } if test_total == 0: recs.append("No purple-team tests found — add tests to validate detection.") elif test_failure_rate >= 0.5: recs.append( - f"High test failure rate ({test_failure_rate:.0%}) — blue team is missing this technique." + f"High test failure rate ({test_failure_rate:.0%}) — blue team misses this technique." ) - # ── Weighted risk score ─────────────────────────────────────────────────── + # ── Aggregate scores ────────────────────────────────────────────────────── raw_score = ( - detection_gap_factor * WEIGHT_DETECTION_GAP - + ta_factor * WEIGHT_THREAT_ACTORS - + osint_factor * WEIGHT_OSINT - + test_factor * WEIGHT_TEST_FAILURES + detection_gap * WEIGHT_DETECTION_GAP + + ta_factor * WEIGHT_THREAT_ACTORS + + osint_factor * WEIGHT_OSINT + + test_factor * WEIGHT_TEST_FAILURES ) risk_score = _clamp(raw_score) * 100.0 - - # Likelihood = detection + actor contribution (exposure) likelihood = _clamp( - detection_gap_factor * 0.5 + ta_factor * 0.35 + osint_factor * 0.15 + detection_gap * 0.5 + ta_factor * 0.35 + osint_factor * 0.15 ) * 100.0 - - # Impact = test failures + osint severity signal impact = _clamp( - test_factor * 0.6 + osint_factor * 0.25 + detection_gap_factor * 0.15 + test_factor * 0.6 + osint_factor * 0.25 + detection_gap * 0.15 ) * 100.0 - level = _risk_level(risk_score) + breakdown["total"] = { "risk_score": round(risk_score, 2), "likelihood": round(likelihood, 2), @@ -238,7 +208,7 @@ def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile ) -# ── Upsert helpers ───────────────────────────────────────────────────────────── +# ── Upsert ───────────────────────────────────────────────────────────────────── def _upsert_profile(db: Session, profile: TechniqueRiskProfile) -> TechniqueRiskProfile: existing = db.query(TechniqueRiskProfile).filter( @@ -265,7 +235,6 @@ def _upsert_profile(db: Session, profile: TechniqueRiskProfile) -> TechniqueRisk # ── Public API ──────────────────────────────────────────────────────────────── def compute_technique_risk(db: Session, technique_id: UUID) -> TechniqueRiskProfile: - """Compute (or recompute) risk profile for a single technique.""" tech = db.query(Technique).filter(Technique.id == technique_id).first() if not tech: raise EntityNotFoundError("Technique", str(technique_id)) @@ -274,12 +243,10 @@ def compute_technique_risk(db: Session, technique_id: UUID) -> TechniqueRiskProf def compute_all_risk_scores(db: Session) -> dict: - """Compute risk profiles for all techniques. Returns summary counts.""" t0 = time.monotonic() techniques = db.query(Technique).all() computed = 0 errors = 0 - for tech in techniques: try: profile = _compute_for_technique(db, tech) @@ -287,13 +254,11 @@ def compute_all_risk_scores(db: Session) -> dict: computed += 1 except Exception: errors += 1 - - duration = time.monotonic() - t0 return { - "computed": computed, - "skipped": 0, - "errors": errors, - "duration_seconds": round(duration, 2), + "computed": computed, + "skipped": 0, + "errors": errors, + "duration_seconds": round(time.monotonic() - t0, 2), } @@ -308,12 +273,12 @@ def get_risk_profile(db: Session, technique_id: UUID) -> TechniqueRiskProfile: def list_risk_profiles( db: Session, - risk_level: Optional[str] = None, - min_score: Optional[float] = None, - max_score: Optional[float] = None, - stale_only: bool = False, - limit: int = 100, - offset: int = 0, + risk_level: Optional[str] = None, + min_score: Optional[float] = None, + max_score: Optional[float] = None, + stale_only: bool = False, + limit: int = 100, + offset: int = 0, ) -> List[TechniqueRiskProfile]: q = db.query(TechniqueRiskProfile) if risk_level: @@ -326,58 +291,51 @@ def list_risk_profiles( q = q.filter(TechniqueRiskProfile.is_stale == True) return ( q.order_by(TechniqueRiskProfile.risk_score.desc()) - .offset(offset) - .limit(limit) - .all() + .offset(offset).limit(limit).all() ) def get_risk_matrix(db: Session) -> list: - """Return all profiled techniques with name+tid for the matrix view.""" rows = ( db.query(TechniqueRiskProfile, Technique) .join(Technique, TechniqueRiskProfile.technique_id == Technique.id) .order_by(TechniqueRiskProfile.risk_score.desc()) .all() ) - result = [] - for profile, tech in rows: - result.append({ - "technique_id": str(profile.technique_id), - "technique_name": tech.name, - "technique_tid": tech.technique_id, # MITRE T-ID string - "risk_score": profile.risk_score, - "likelihood": profile.likelihood, - "impact": profile.impact, - "risk_level": profile.risk_level, - "detection_gap": profile.detection_gap, - "computed_at": profile.computed_at.isoformat() if profile.computed_at else None, - }) - return result + return [ + { + "technique_id": str(p.technique_id), + "technique_name": t.name, + "technique_tid": t.technique_id, + "risk_score": p.risk_score, + "likelihood": p.likelihood, + "impact": p.impact, + "risk_level": p.risk_level, + "detection_gap": p.detection_gap, + "computed_at": p.computed_at.isoformat() if p.computed_at else None, + } + for p, t in rows + ] def get_risk_summary(db: Session) -> dict: - """Aggregate statistics across all risk profiles.""" all_profiles = db.query(TechniqueRiskProfile).all() total_tech = db.query(Technique).count() scored = len(all_profiles) stale = sum(1 for p in all_profiles if p.is_stale) - by_level: dict = {lvl: 0 for lvl in ("critical", "high", "medium", "low", "info")} + by_level: dict = {l: 0 for l in ("critical", "high", "medium", "low", "info")} score_sum = 0.0 for p in all_profiles: by_level[p.risk_level] = by_level.get(p.risk_level, 0) + 1 score_sum += p.risk_score - avg_score = (score_sum / scored) if scored > 0 else 0.0 - # Top 5 by risk score (with technique name) top_rows = ( db.query(TechniqueRiskProfile, Technique) .join(Technique, TechniqueRiskProfile.technique_id == Technique.id) .order_by(TechniqueRiskProfile.risk_score.desc()) - .limit(5) - .all() + .limit(5).all() ) top_risks = [ { @@ -393,7 +351,6 @@ def get_risk_summary(db: Session) -> dict: } for p, t in top_rows ] - return { "total_techniques": total_tech, "scored_techniques": scored, @@ -405,24 +362,22 @@ def get_risk_summary(db: Session) -> dict: def get_recommendations(db: Session, limit: int = 20) -> list: - """Prioritised list of techniques with actionable recommendations.""" rows = ( db.query(TechniqueRiskProfile, Technique) .join(Technique, TechniqueRiskProfile.technique_id == Technique.id) .filter(TechniqueRiskProfile.risk_score > 0) .order_by(TechniqueRiskProfile.risk_score.desc()) - .limit(limit) - .all() + .limit(limit).all() ) - result = [] - for priority, (profile, tech) in enumerate(rows, start=1): - result.append({ - "technique_id": str(profile.technique_id), - "technique_name": tech.name, - "technique_tid": tech.technique_id, - "risk_level": profile.risk_level, - "risk_score": profile.risk_score, - "recommendations": profile.recommendations or [], - "priority": priority, - }) - return result + return [ + { + "technique_id": str(p.technique_id), + "technique_name": t.name, + "technique_tid": t.technique_id, + "risk_level": p.risk_level, + "risk_score": p.risk_score, + "recommendations": p.recommendations or [], + "priority": i, + } + for i, (p, t) in enumerate(rows, start=1) + ]