"""Scoring service — granular 0-100 scoring for techniques, tactics, actors, and org. Reads configurable weights from the ``scoring_config`` table (falling back to env-var defaults) to compute coverage scores with detailed breakdowns. Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a fixed number of aggregated queries so that organisation-wide calculations never produce N+1 traffic. """ from datetime import datetime, timedelta, timezone from typing import Optional from sqlalchemy import case, func from sqlalchemy.orm import Session from app.domain.errors import EntityNotFoundError from app.models.technique import Technique from app.models.test import Test from app.models.detection_rule import DetectionRule from app.models.test_detection_result import TestDetectionResult from app.models.defensive_technique import DefensiveTechniqueMapping from app.models.threat_actor import ThreatActor, ThreatActorTechnique from app.models.enums import TestState, TestResult from app.services.scoring_config_service import get_scoring_weights _SEVERITY_FACTORS: dict[str, float] = { "critical": 1.0, "high": 0.85, "medium": 0.65, "low": 0.5, } def _recency_factor(last_tested: datetime | None) -> float: """Decay factor: 1.0 when recent, decreasing over time.""" if not last_tested: return 0.0 now = datetime.now(timezone.utc) tested = last_tested if tested.tzinfo is None: tested = tested.replace(tzinfo=timezone.utc) days_ago = (now - tested).days if days_ago <= 90: return 1.0 if days_ago <= 180: return 0.8 if days_ago <= 365: return 0.5 return 0.2 def _severity_factor(severity_label: str | None) -> float: """Map template severity to a 0–1 multiplier.""" if not severity_label: return 0.7 return _SEVERITY_FACTORS.get(severity_label.lower(), 0.7) def _max_severity_by_mitre(db: Session) -> dict[str, str]: """Highest severity label per MITRE id from active test templates.""" from app.models.test_template import TestTemplate order = {"critical": 4, "high": 3, "medium": 2, "low": 1} rows = ( db.query(TestTemplate.mitre_technique_id, TestTemplate.severity) .filter( TestTemplate.is_active == True, # noqa: E712 TestTemplate.severity.isnot(None), ) .all() ) best: dict[str, str] = {} for mitre_id, severity in rows: if not mitre_id or not severity: continue current = best.get(mitre_id) if current is None or order.get(severity.lower(), 0) > order.get(current.lower(), 0): best[mitre_id] = severity return best # ── Bulk scoring helpers (5 queries for ALL techniques) ─────────────── def bulk_technique_scores(db: Session) -> dict: """Pre-fetch all scoring data and compute per-technique scores in memory. Executes exactly 5 queries regardless of technique count: Q1 — Test aggregates per technique (validated / detected / platforms / freshness) Q2 — Detection rules per mitre_id Q3 — Triggered rules per mitre_id Q4 — D3FEND mapping counts per technique Q5 — All techniques Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``. """ w = get_scoring_weights(db) w_tests = w.tests w_detection = w.detection_rules w_d3fend = w.d3fend w_recency = w.recency w_severity = w.severity severity_by_mitre = _max_severity_by_mitre(db) last_validated = func.coalesce( Test.blue_validated_at, Test.red_validated_at, Test.created_at, ) # Q1: test stats grouped by technique_id test_rows = ( db.query( Test.technique_id, func.count(Test.id).label("validated_count"), func.count( case((Test.detection_result == TestResult.detected, Test.id)) ).label("detected_count"), func.max(last_validated).label("latest_validated_at"), ) .filter(Test.state == TestState.validated) .group_by(Test.technique_id) .all() ) test_stats: dict = {} for row in test_rows: test_stats[row.technique_id] = { "validated": row.validated_count, "detected": row.detected_count, "latest_validated_at": row.latest_validated_at, } # Q2: active detection rules per mitre_id rule_rows = ( db.query( DetectionRule.mitre_technique_id, func.count(DetectionRule.id).label("total"), ) .filter(DetectionRule.is_active == True) # noqa: E712 .group_by(DetectionRule.mitre_technique_id) .all() ) rules_by_mitre: dict[str, int] = {r.mitre_technique_id: r.total for r in rule_rows} # Q3: triggered rules per mitre_id triggered_rows = ( db.query( DetectionRule.mitre_technique_id, func.count(TestDetectionResult.id).label("triggered"), ) .join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id) .filter(TestDetectionResult.triggered == True) # noqa: E712 .group_by(DetectionRule.mitre_technique_id) .all() ) triggered_by_mitre: dict[str, int] = { r.mitre_technique_id: r.triggered for r in triggered_rows } # Q4: D3FEND mapping counts per technique d3fend_rows = ( db.query( DefensiveTechniqueMapping.attack_technique_id, func.count(DefensiveTechniqueMapping.id).label("total"), ) .group_by(DefensiveTechniqueMapping.attack_technique_id) .all() ) d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows} # Q5: all techniques techniques = db.query(Technique).all() results: dict = {} for tech in techniques: ts = test_stats.get(tech.id, {}) validated = ts.get("validated", 0) detected = ts.get("detected", 0) latest_at = ts.get("latest_validated_at") breakdown = {} # 1. Tests validated with detection if validated > 0: test_ratio = detected / validated test_score = round(test_ratio * w_tests, 1) else: test_ratio = 0 test_score = 0 breakdown["tests_validated"] = { "score": test_score, "max": w_tests, "detail": ( f"{detected}/{validated} tests detected" if validated else "No validated tests" ), } # 2. Detection rules total_rules = rules_by_mitre.get(tech.mitre_id, 0) triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0) if total_rules > 0: detection_ratio = min(triggered_rules / total_rules, 1.0) detection_score = round(detection_ratio * w_detection, 1) else: detection_ratio = 0 detection_score = 0 breakdown["detection_rules"] = { "score": detection_score, "max": w_detection, "detail": ( f"{triggered_rules}/{total_rules} rules triggered" if total_rules > 0 else "No detection rules available" ), } # 3. D3FEND coverage total_cm = d3fend_by_tech.get(tech.id, 0) if total_cm > 0 and detected > 0: verified_cm = min(detected, total_cm) d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1) else: verified_cm = 0 d3fend_score = 0 breakdown["d3fend_coverage"] = { "score": d3fend_score, "max": w_d3fend, "detail": ( f"{verified_cm}/{total_cm} countermeasures" if total_cm > 0 else "No D3FEND mappings" ), } # 4. Recency decay recency_mult = _recency_factor(latest_at) recency_score = round(recency_mult * w_recency, 1) if latest_at: tested = latest_at if tested.tzinfo is None: days_ago = (datetime.utcnow() - tested).days else: days_ago = (datetime.now(timezone.utc) - tested.astimezone(timezone.utc)).days recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})" else: recency_detail = "No validated tests" breakdown["recency"] = { "score": recency_score, "max": w_recency, "detail": recency_detail, } # 5. Severity / criticality (template-driven) sev_label = severity_by_mitre.get(tech.mitre_id) sev_mult = _severity_factor(sev_label) severity_score = round(sev_mult * w_severity, 1) breakdown["severity"] = { "score": severity_score, "max": w_severity, "detail": ( f"Template severity: {sev_label} (factor {sev_mult})" if sev_label else "No severity template (default factor)" ), } total = min( test_score + detection_score + d3fend_score + recency_score + severity_score, 100, ) results[tech.id] = { "total_score": round(total, 1), "breakdown": breakdown, "mitre_id": tech.mitre_id, "tactic": tech.tactic, } return results # ── Technique-level scoring (single technique — preserved API) ──────── def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict: """Get detailed score with breakdown for a technique by MITRE ID.""" technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first() if not technique: raise EntityNotFoundError("Technique", mitre_id) result = calculate_technique_score(technique, db) return { "mitre_id": technique.mitre_id, "name": technique.name, "tactic": technique.tactic, "status_global": technique.status_global.value if technique.status_global else None, **result, } def score_actor_by_id(db: Session, actor_id: str) -> dict: """Get coverage score for a threat actor by ID.""" actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() if not actor: raise EntityNotFoundError("ThreatActor", actor_id) return calculate_actor_coverage_score(actor_id, db) def calculate_technique_score(technique: Technique, db: Session) -> dict: """Calculate a 0-100 score for a technique with detailed breakdown. Weights are read from the ``scoring_config`` table (or env defaults). """ w = get_scoring_weights(db) w_tests = w.tests w_detection = w.detection_rules w_d3fend = w.d3fend w_recency = w.recency w_severity = w.severity severity_by_mitre = _max_severity_by_mitre(db) breakdown = {} # ── 1. Tests validated with detection ────────────────────────── all_tests = ( db.query(Test) .filter(Test.technique_id == technique.id) .all() ) validated_tests = [t for t in all_tests if t.state == TestState.validated] detected_tests = [ t for t in validated_tests if t.detection_result == TestResult.detected ] if validated_tests: test_ratio = len(detected_tests) / len(validated_tests) test_score = round(test_ratio * w_tests, 1) else: test_ratio = 0 test_score = 0 breakdown["tests_validated"] = { "score": test_score, "max": w_tests, "detail": f"{len(detected_tests)}/{len(validated_tests)} tests detected" if validated_tests else "No validated tests", } # ── 2. Detection rules coverage ─────────────────────────────── total_rules = ( db.query(func.count(DetectionRule.id)) .filter( DetectionRule.mitre_technique_id == technique.mitre_id, DetectionRule.is_active == True, # noqa: E712 ) .scalar() ) or 0 triggered_rules = 0 if total_rules > 0: triggered_rules = ( db.query(func.count(TestDetectionResult.id)) .join( DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id, ) .filter( DetectionRule.mitre_technique_id == technique.mitre_id, TestDetectionResult.triggered == True, # noqa: E712 ) .scalar() ) or 0 detection_ratio = min(triggered_rules / total_rules, 1.0) detection_score = round(detection_ratio * w_detection, 1) else: detection_ratio = 0 detection_score = 0 breakdown["detection_rules"] = { "score": detection_score, "max": w_detection, "detail": f"{triggered_rules}/{total_rules} rules triggered" if total_rules > 0 else "No detection rules available", } # ── 3. D3FEND coverage ──────────────────────────────────────── total_countermeasures = ( db.query(func.count(DefensiveTechniqueMapping.id)) .filter(DefensiveTechniqueMapping.attack_technique_id == technique.id) .scalar() ) or 0 verified_countermeasures = 0 if total_countermeasures > 0 and len(detected_tests) > 0: verified_countermeasures = min(len(detected_tests), total_countermeasures) d3fend_ratio = verified_countermeasures / total_countermeasures d3fend_score = round(d3fend_ratio * w_d3fend, 1) else: d3fend_ratio = 0 d3fend_score = 0 breakdown["d3fend_coverage"] = { "score": d3fend_score, "max": w_d3fend, "detail": f"{verified_countermeasures}/{total_countermeasures} countermeasures" if total_countermeasures > 0 else "No D3FEND mappings", } # ── 4. Recency ──────────────────────────────────────────────── most_recent_test = None for t in validated_tests: candidate = t.blue_validated_at or t.red_validated_at or t.created_at if candidate and (most_recent_test is None or candidate > most_recent_test): most_recent_test = candidate recency_mult = _recency_factor(most_recent_test) recency_score = round(recency_mult * w_recency, 1) if most_recent_test: days_ago = ( datetime.now(timezone.utc) - ( most_recent_test.replace(tzinfo=timezone.utc) if most_recent_test.tzinfo is None else most_recent_test.astimezone(timezone.utc) ) ).days recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})" else: recency_detail = "No validated tests" breakdown["recency"] = { "score": recency_score, "max": w_recency, "detail": recency_detail, } # ── 5. Severity ─────────────────────────────────────────────── sev_label = severity_by_mitre.get(technique.mitre_id) sev_mult = _severity_factor(sev_label) severity_score = round(sev_mult * w_severity, 1) breakdown["severity"] = { "score": severity_score, "max": w_severity, "detail": ( f"Template severity: {sev_label} (factor {sev_mult})" if sev_label else "No severity template (default factor)" ), } # ── Total ───────────────────────────────────────────────────── total = min( test_score + detection_score + d3fend_score + recency_score + severity_score, 100, ) return { "total_score": round(total, 1), "breakdown": breakdown, } # ── Tactic-level scoring ───────────────────────────────────────────── def calculate_tactic_score(tactic: str, db: Session) -> dict: """Calculate average score for all techniques in a tactic.""" scores_map = bulk_technique_scores(db) matching = [ v["total_score"] for v in scores_map.values() if v.get("tactic") and tactic.lower() in v["tactic"].lower() ] return { "tactic": tactic, "average_score": round(sum(matching) / len(matching), 1) if matching else 0, "techniques_count": len(matching), "techniques_scored": len([s for s in matching if s > 0]), } # ── Threat actor scoring ───────────────────────────────────────────── def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict: """Calculate coverage score for a specific threat actor's techniques.""" actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() if not actor: return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0} actor_techniques = ( db.query(ThreatActorTechnique) .filter(ThreatActorTechnique.threat_actor_id == actor.id) .all() ) technique_ids = {at.technique_id for at in actor_techniques} if not technique_ids: return { "actor_id": str(actor.id), "actor_name": actor.name, "total_score": 0, "techniques_count": 0, "techniques_covered": 0, "techniques_detail": [], } scores_map = bulk_technique_scores(db) scores = [] details = [] for tid in technique_ids: entry = scores_map.get(tid) if not entry: continue score = entry["total_score"] scores.append(score) details.append({ "mitre_id": entry["mitre_id"], "name": entry.get("name", ""), "score": score, "breakdown": entry["breakdown"], }) avg_score = round(sum(scores) / len(scores), 1) if scores else 0 return { "actor_id": str(actor.id), "actor_name": actor.name, "total_score": avg_score, "techniques_count": len(technique_ids), "techniques_covered": len([s for s in scores if s > 50]), "techniques_detail": details, } # ── Organization-level scoring ──────────────────────────────────────── def calculate_organization_score(db: Session) -> dict: """Calculate the overall organization security score. Uses ``bulk_technique_scores`` to compute all technique scores in 5 aggregated queries instead of N*5. """ scores_map = bulk_technique_scores(db) total_count = len(scores_map) if total_count == 0: return { "overall_score": 0, "total_coverage": 0, "critical_coverage": 0, "detection_maturity": 0, "response_readiness": 0, "techniques_evaluated": 0, "techniques_total": 0, } all_scores = [v["total_score"] for v in scores_map.values()] evaluated_scores = [s for s in all_scores if s > 0] evaluated_count = len(evaluated_scores) total_coverage = ( round(sum(evaluated_scores) / len(evaluated_scores), 1) if evaluated_scores else 0 ) # Critical coverage: techniques with high/critical severity templates from app.models.test_template import TestTemplate critical_mitre_ids = set( row[0] for row in db.query(TestTemplate.mitre_technique_id) .filter(TestTemplate.severity.in_(["high", "critical"])) .distinct() .all() ) critical_scores = [ v["total_score"] for v in scores_map.values() if v.get("mitre_id") in critical_mitre_ids ] critical_coverage = ( round(sum(critical_scores) / len(critical_scores), 1) if critical_scores else 0 ) # Detection maturity (2 scalar queries — already efficient) total_rules = ( db.query(func.count(DetectionRule.id)) .filter(DetectionRule.is_active == True) # noqa: E712 .scalar() ) or 0 triggered_total = ( db.query(func.count(TestDetectionResult.id)) .filter(TestDetectionResult.triggered == True) # noqa: E712 .scalar() ) or 0 detection_maturity = ( round((triggered_total / total_rules) * 100, 1) if total_rules > 0 else 0 ) detection_maturity = min(detection_maturity, 100) # Response readiness (2 scalar queries — already efficient) remediation_total = ( db.query(func.count(Test.id)) .filter(Test.remediation_status.isnot(None)) .scalar() ) or 0 remediation_completed = ( db.query(func.count(Test.id)) .filter(Test.remediation_status == "completed") .scalar() ) or 0 response_readiness = ( round((remediation_completed / remediation_total) * 100, 1) if remediation_total > 0 else 0 ) overall = round( total_coverage * 0.4 + critical_coverage * 0.25 + detection_maturity * 0.2 + response_readiness * 0.15, 1, ) return { "overall_score": overall, "total_coverage": total_coverage, "critical_coverage": critical_coverage, "detection_maturity": detection_maturity, "response_readiness": response_readiness, "techniques_evaluated": evaluated_count, "techniques_total": total_count, } # ── Score history ──────────────────────────────────────────────────── def get_score_history(db: Session, period: str = "90d") -> list: """Get historical score snapshots. Since we don't have a dedicated history table, we approximate by computing scores based on test dates within time windows. Returns a list of weekly data points. """ from app.models.audit import AuditLog now = datetime.utcnow() if period == "30d": start = now - timedelta(days=30) elif period == "1y": start = now - timedelta(days=365) else: # 90d default start = now - timedelta(days=90) # Group validated tests by week weeks = [] current = start while current < now: week_end = min(current + timedelta(days=7), now) # Count validated tests up to this week validated_up_to = ( db.query(func.count(Test.id)) .filter( Test.state == TestState.validated, Test.red_validated_at <= week_end, ) .scalar() ) or 0 total_techniques = ( db.query(func.count(Technique.id)).scalar() ) or 1 # Simple approximation: coverage percentage as score proxy score_approx = round((validated_up_to / total_techniques) * 100, 1) weeks.append({ "date": current.strftime("%Y-%m-%d"), "score": min(score_approx, 100), "validated_tests": validated_up_to, }) current = week_end return weeks