"""Scoring service — granular 0-100 scoring for techniques, tactics, actors, and org.

Reads configurable weights from the ``scoring_config`` table (falling back to
env-var defaults) to compute coverage scores with detailed breakdowns.

Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a fixed
number of aggregated queries so that organisation-wide calculations never
produce N+1 traffic. The bulk path and the single-technique path feed the same
pure ``_score_breakdown`` helper, so a technique receives the same score no
matter which entry point computed it.
"""

from datetime import datetime, timedelta, timezone

from sqlalchemy import case, func
from sqlalchemy.orm import Session

from app.domain.errors import EntityNotFoundError
from app.models.defensive_technique import DefensiveTechniqueMapping
from app.models.detection_rule import DetectionRule
from app.models.enums import TestResult, TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_detection_result import TestDetectionResult
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.services.scoring_config_service import get_scoring_weights

# Multiplier applied to the severity weight, keyed by lower-cased template severity.
_SEVERITY_FACTORS: dict[str, float] = {
    "critical": 1.0,
    "high": 0.85,
    "medium": 0.65,
    "low": 0.5,
}

# Factor used when a technique has no template severity or an unknown label.
_DEFAULT_SEVERITY_FACTOR = 0.7

# Ranking used to pick the highest template severity per MITRE id.
_SEVERITY_RANK: dict[str, int] = {"critical": 4, "high": 3, "medium": 2, "low": 1}

# Lower-cased severities that make a technique count towards ``critical_coverage``.
_CRITICAL_SEVERITIES = frozenset({"high", "critical"})

# Actor techniques scoring strictly above this are reported as "covered".
_ACTOR_COVERED_THRESHOLD = 50


# ── Small pure helpers ────────────────────────────────────────────────


def _as_utc(value: datetime) -> datetime:
    """Return *value* as an aware UTC datetime; naive values are taken to be UTC."""
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


def _days_since(value: datetime) -> int:
    """Whole days elapsed between *value* and the current UTC time."""
    return (datetime.now(timezone.utc) - _as_utc(value)).days


def _average(values: list) -> float:
    """Mean of *values* rounded to one decimal, or 0 for an empty list."""
    return round(sum(values) / len(values), 1) if values else 0


def _recency_factor(last_tested: datetime | None) -> float:
    """Decay factor: 1.0 when recent, decreasing over time.

    Returns 1.0 up to 90 days, 0.8 up to 180 days, 0.5 up to 365 days and
    0.2 beyond that; 0.0 when there is no validation timestamp at all.
    """
    if not last_tested:
        return 0.0
    days_ago = _days_since(last_tested)
    if days_ago <= 90:
        return 1.0
    if days_ago <= 180:
        return 0.8
    if days_ago <= 365:
        return 0.5
    return 0.2


def _severity_factor(severity_label: str | None) -> float:
    """Map template severity to a 0–1 multiplier (case-insensitive)."""
    if not severity_label:
        return _DEFAULT_SEVERITY_FACTOR
    return _SEVERITY_FACTORS.get(severity_label.lower(), _DEFAULT_SEVERITY_FACTOR)


def _max_severity_by_mitre(db: Session) -> dict[str, str]:
    """Highest severity label per MITRE id from active test templates.

    Labels are ranked case-insensitively but returned exactly as stored.
    """
    # Kept function-local as in the original module; presumably avoids an
    # import cycle with the template model — verify before hoisting.
    from app.models.test_template import TestTemplate

    rows = (
        db.query(TestTemplate.mitre_technique_id, TestTemplate.severity)
        .filter(
            TestTemplate.is_active == True,  # noqa: E712
            TestTemplate.severity.isnot(None),
        )
        .all()
    )
    best: dict[str, str] = {}
    for mitre_id, severity in rows:
        if not mitre_id or not severity:
            continue
        current = best.get(mitre_id)
        if current is None or _SEVERITY_RANK.get(severity.lower(), 0) > _SEVERITY_RANK.get(
            current.lower(), 0
        ):
            best[mitre_id] = severity
    return best


# ── Shared query builders (used by both bulk and single-technique paths) ──


def _last_validated_expr():
    """SQL expression for a test's most relevant validation timestamp."""
    return func.coalesce(
        Test.blue_validated_at,
        Test.red_validated_at,
        Test.created_at,
    )


def _validated_test_stats_query(db: Session):
    """Validated-test aggregates grouped by technique.

    Each row has ``technique_id``, ``validated_count``, ``detected_count`` and
    ``latest_validated_at``.
    """
    return (
        db.query(
            Test.technique_id,
            func.count(Test.id).label("validated_count"),
            func.count(
                case((Test.detection_result == TestResult.detected, Test.id))
            ).label("detected_count"),
            func.max(_last_validated_expr()).label("latest_validated_at"),
        )
        .filter(Test.state == TestState.validated)
        .group_by(Test.technique_id)
    )


def _active_rule_counts_query(db: Session):
    """Number of active detection rules grouped by MITRE id (``total``)."""
    return (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(DetectionRule.id).label("total"),
        )
        .filter(DetectionRule.is_active == True)  # noqa: E712
        .group_by(DetectionRule.mitre_technique_id)
    )


def _triggered_rule_counts_query(db: Session):
    """Number of distinct *active* rules triggered at least once, per MITRE id.

    Counting distinct rules (not detection-result rows) keeps the numerator
    comparable with ``_active_rule_counts_query``: a single rule firing in many
    tests no longer looks like many rules firing.
    """
    return (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(DetectionRule.id.distinct()).label("triggered"),
        )
        .select_from(TestDetectionResult)
        .join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
        .filter(
            TestDetectionResult.triggered == True,  # noqa: E712
            DetectionRule.is_active == True,  # noqa: E712
        )
        .group_by(DetectionRule.mitre_technique_id)
    )


def _d3fend_counts_query(db: Session):
    """Number of D3FEND mappings grouped by ATT&CK technique id (``total``)."""
    return db.query(
        DefensiveTechniqueMapping.attack_technique_id,
        func.count(DefensiveTechniqueMapping.id).label("total"),
    ).group_by(DefensiveTechniqueMapping.attack_technique_id)


def _score_breakdown(
    *,
    weights,
    validated: int,
    detected: int,
    total_rules: int,
    triggered_rules: int,
    total_countermeasures: int,
    latest_validated_at: datetime | None,
    severity_label: str | None,
) -> tuple[float, dict]:
    """Compute ``(total_score, breakdown)`` from pre-aggregated inputs.

    ``weights`` is the object returned by ``get_scoring_weights``; the
    ``tests``, ``detection_rules``, ``d3fend``, ``recency`` and ``severity``
    attributes are read. Each breakdown entry holds ``score``, ``max`` (the
    weight) and a human-readable ``detail``. The total is capped at 100 and
    rounded to one decimal.
    """
    breakdown: dict = {}

    # 1. Tests validated with detection
    if validated > 0:
        test_score = round((detected / validated) * weights.tests, 1)
    else:
        test_score = 0
    breakdown["tests_validated"] = {
        "score": test_score,
        "max": weights.tests,
        "detail": (
            f"{detected}/{validated} tests detected" if validated else "No validated tests"
        ),
    }

    # 2. Detection rules
    if total_rules > 0:
        detection_ratio = min(triggered_rules / total_rules, 1.0)
        detection_score = round(detection_ratio * weights.detection_rules, 1)
    else:
        detection_score = 0
    breakdown["detection_rules"] = {
        "score": detection_score,
        "max": weights.detection_rules,
        "detail": (
            f"{triggered_rules}/{total_rules} rules triggered"
            if total_rules > 0
            else "No detection rules available"
        ),
    }

    # 3. D3FEND coverage
    # NOTE: the number of detected tests is used as a proxy for "verified"
    # countermeasures; no per-countermeasure verification is queried.
    if total_countermeasures > 0 and detected > 0:
        verified_cm = min(detected, total_countermeasures)
        d3fend_score = round((verified_cm / total_countermeasures) * weights.d3fend, 1)
    else:
        verified_cm = 0
        d3fend_score = 0
    breakdown["d3fend_coverage"] = {
        "score": d3fend_score,
        "max": weights.d3fend,
        "detail": (
            f"{verified_cm}/{total_countermeasures} countermeasures"
            if total_countermeasures > 0
            else "No D3FEND mappings"
        ),
    }

    # 4. Recency decay
    recency_mult = _recency_factor(latest_validated_at)
    recency_score = round(recency_mult * weights.recency, 1)
    if latest_validated_at:
        recency_detail = (
            f"Last validated {_days_since(latest_validated_at)} days ago "
            f"(factor {recency_mult})"
        )
    else:
        recency_detail = "No validated tests"
    breakdown["recency"] = {
        "score": recency_score,
        "max": weights.recency,
        "detail": recency_detail,
    }

    # 5. Severity / criticality (template-driven)
    sev_mult = _severity_factor(severity_label)
    severity_score = round(sev_mult * weights.severity, 1)
    breakdown["severity"] = {
        "score": severity_score,
        "max": weights.severity,
        "detail": (
            f"Template severity: {severity_label} (factor {sev_mult})"
            if severity_label
            else "No severity template (default factor)"
        ),
    }

    total = min(
        test_score + detection_score + d3fend_score + recency_score + severity_score,
        100,
    )
    return round(total, 1), breakdown


# ── Bulk scoring helpers (fixed query count for ALL techniques) ───────


def bulk_technique_scores(db: Session) -> dict:
    """Pre-fetch all scoring data and compute per-technique scores in memory.

    The number of queries is independent of the technique count:
    scoring weights, template severities, and then

    Q1 — Validated-test aggregates per technique (count / detected / freshness)
    Q2 — Active detection rules per MITRE id
    Q3 — Distinct active rules triggered per MITRE id
    Q4 — D3FEND mapping counts per technique
    Q5 — All techniques

    Returns ``{technique_id: {"total_score": float, "breakdown": dict,
    "mitre_id": str, "name": str, "tactic": str, "severity": str | None}}``.
    """
    weights = get_scoring_weights(db)
    severity_by_mitre = _max_severity_by_mitre(db)

    test_stats = {row.technique_id: row for row in _validated_test_stats_query(db).all()}
    rules_by_mitre = {r.mitre_technique_id: r.total for r in _active_rule_counts_query(db).all()}
    triggered_by_mitre = {
        r.mitre_technique_id: r.triggered for r in _triggered_rule_counts_query(db).all()
    }
    d3fend_by_tech = {r.attack_technique_id: r.total for r in _d3fend_counts_query(db).all()}
    techniques = db.query(Technique).all()

    results: dict = {}
    for tech in techniques:
        stats = test_stats.get(tech.id)
        sev_label = severity_by_mitre.get(tech.mitre_id)
        total, breakdown = _score_breakdown(
            weights=weights,
            validated=stats.validated_count if stats else 0,
            detected=stats.detected_count if stats else 0,
            total_rules=rules_by_mitre.get(tech.mitre_id, 0),
            triggered_rules=triggered_by_mitre.get(tech.mitre_id, 0),
            total_countermeasures=d3fend_by_tech.get(tech.id, 0),
            latest_validated_at=stats.latest_validated_at if stats else None,
            severity_label=sev_label,
        )
        results[tech.id] = {
            "total_score": total,
            "breakdown": breakdown,
            "mitre_id": tech.mitre_id,
            "name": tech.name,
            "tactic": tech.tactic,
            "severity": sev_label,
        }
    return results


# ── Technique-level scoring (single technique — preserved API) ────────


def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict:
    """Get detailed score with breakdown for a technique by MITRE ID.

    Raises:
        EntityNotFoundError: no technique has the given MITRE id.
    """
    technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
    if not technique:
        raise EntityNotFoundError("Technique", mitre_id)
    result = calculate_technique_score(technique, db)
    return {
        "mitre_id": technique.mitre_id,
        "name": technique.name,
        "tactic": technique.tactic,
        "status_global": technique.status_global.value if technique.status_global else None,
        **result,
    }


def score_actor_by_id(db: Session, actor_id: str) -> dict:
    """Get coverage score for a threat actor by ID.

    Raises:
        EntityNotFoundError: no threat actor has the given id.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("ThreatActor", actor_id)
    return calculate_actor_coverage_score(actor_id, db)


def calculate_technique_score(technique: Technique, db: Session) -> dict:
    """Calculate a 0-100 score for a technique with detailed breakdown.

    Weights are read from the ``scoring_config`` table (or env defaults). Uses
    the same aggregate queries as ``bulk_technique_scores``, filtered to this
    technique, so both paths produce identical scores.

    Returns ``{"total_score": float, "breakdown": dict}``.
    """
    weights = get_scoring_weights(db)
    severity_by_mitre = _max_severity_by_mitre(db)

    stats = (
        _validated_test_stats_query(db)
        .filter(Test.technique_id == technique.id)
        .first()
    )
    rules = (
        _active_rule_counts_query(db)
        .filter(DetectionRule.mitre_technique_id == technique.mitre_id)
        .first()
    )
    triggered = (
        _triggered_rule_counts_query(db)
        .filter(DetectionRule.mitre_technique_id == technique.mitre_id)
        .first()
    )
    d3fend = (
        _d3fend_counts_query(db)
        .filter(DefensiveTechniqueMapping.attack_technique_id == technique.id)
        .first()
    )

    total, breakdown = _score_breakdown(
        weights=weights,
        validated=stats.validated_count if stats else 0,
        detected=stats.detected_count if stats else 0,
        total_rules=rules.total if rules else 0,
        triggered_rules=triggered.triggered if triggered else 0,
        total_countermeasures=d3fend.total if d3fend else 0,
        latest_validated_at=stats.latest_validated_at if stats else None,
        severity_label=severity_by_mitre.get(technique.mitre_id),
    )
    return {
        "total_score": total,
        "breakdown": breakdown,
    }


# ── Tactic-level scoring ─────────────────────────────────────────────


def calculate_tactic_score(tactic: str, db: Session) -> dict:
    """Calculate average score for all techniques in a tactic.

    A technique matches when *tactic* appears (case-insensitively) as a
    substring of its ``tactic`` field.
    """
    scores_map = bulk_technique_scores(db)
    needle = tactic.lower()
    matching = [
        v["total_score"]
        for v in scores_map.values()
        if v.get("tactic") and needle in v["tactic"].lower()
    ]
    return {
        "tactic": tactic,
        "average_score": _average(matching),
        "techniques_count": len(matching),
        "techniques_scored": len([s for s in matching if s > 0]),
    }


# ── Threat actor scoring ─────────────────────────────────────────────


def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
    """Calculate coverage score for a specific threat actor's techniques.

    ``total_score`` is the mean score of the actor's techniques;
    ``techniques_covered`` counts those scoring above 50. Details are ordered
    by MITRE id for stable output.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0}

    technique_ids = {
        row.technique_id
        for row in db.query(ThreatActorTechnique.technique_id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .all()
    }
    if not technique_ids:
        return {
            "actor_id": str(actor.id),
            "actor_name": actor.name,
            "total_score": 0,
            "techniques_count": 0,
            "techniques_covered": 0,
            "techniques_detail": [],
        }

    scores_map = bulk_technique_scores(db)
    entries = [scores_map[tid] for tid in technique_ids if tid in scores_map]
    entries.sort(key=lambda e: e["mitre_id"] or "")

    scores = [e["total_score"] for e in entries]
    details = [
        {
            "mitre_id": e["mitre_id"],
            "name": e.get("name") or "",
            "score": e["total_score"],
            "breakdown": e["breakdown"],
        }
        for e in entries
    ]

    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_score": _average(scores),
        "techniques_count": len(technique_ids),
        "techniques_covered": len([s for s in scores if s > _ACTOR_COVERED_THRESHOLD]),
        "techniques_detail": details,
    }


# ── Organization-level scoring ────────────────────────────────────────


def calculate_organization_score(db: Session) -> dict:
    """Calculate the overall organization security score.

    Uses ``bulk_technique_scores`` for all technique scores (fixed query
    count), plus four scalar queries for detection maturity and response
    readiness. The overall score weights total coverage 40%, critical
    coverage 25%, detection maturity 20% and response readiness 15%.
    """
    scores_map = bulk_technique_scores(db)
    total_count = len(scores_map)

    if total_count == 0:
        return {
            "overall_score": 0,
            "total_coverage": 0,
            "critical_coverage": 0,
            "detection_maturity": 0,
            "response_readiness": 0,
            "techniques_evaluated": 0,
            "techniques_total": 0,
        }

    # NOTE(review): the severity component is awarded even without any tests,
    # so with a non-zero severity weight almost every technique is "evaluated".
    evaluated_scores = [v["total_score"] for v in scores_map.values() if v["total_score"] > 0]
    total_coverage = _average(evaluated_scores)

    # Critical coverage: techniques whose highest *active* template severity is
    # high/critical — the same labels that drive the severity component.
    critical_scores = [
        v["total_score"]
        for v in scores_map.values()
        if (v.get("severity") or "").lower() in _CRITICAL_SEVERITIES
    ]
    critical_coverage = _average(critical_scores)

    # Detection maturity: share of active rules that have triggered at least once.
    total_rules = (
        db.query(func.count(DetectionRule.id))
        .filter(DetectionRule.is_active == True)  # noqa: E712
        .scalar()
    ) or 0
    triggered_rules = (
        db.query(func.count(DetectionRule.id.distinct()))
        .select_from(TestDetectionResult)
        .join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
        .filter(
            TestDetectionResult.triggered == True,  # noqa: E712
            DetectionRule.is_active == True,  # noqa: E712
        )
        .scalar()
    ) or 0
    detection_maturity = (
        round((triggered_rules / total_rules) * 100, 1) if total_rules > 0 else 0
    )
    detection_maturity = min(detection_maturity, 100)

    # Response readiness: completed remediations over all tracked remediations.
    remediation_total = (
        db.query(func.count(Test.id))
        .filter(Test.remediation_status.isnot(None))
        .scalar()
    ) or 0
    remediation_completed = (
        db.query(func.count(Test.id))
        .filter(Test.remediation_status == "completed")
        .scalar()
    ) or 0
    response_readiness = (
        round((remediation_completed / remediation_total) * 100, 1)
        if remediation_total > 0
        else 0
    )

    overall = round(
        total_coverage * 0.4
        + critical_coverage * 0.25
        + detection_maturity * 0.2
        + response_readiness * 0.15,
        1,
    )

    return {
        "overall_score": overall,
        "total_coverage": total_coverage,
        "critical_coverage": critical_coverage,
        "detection_maturity": detection_maturity,
        "response_readiness": response_readiness,
        "techniques_evaluated": len(evaluated_scores),
        "techniques_total": total_count,
    }


# ── Score history ────────────────────────────────────────────────────


def get_score_history(db: Session, period: str = "90d") -> list:
    """Get historical score snapshots.

    Since we don't have a dedicated history table, we approximate by computing,
    for each weekly window, the share of techniques that have at least one test
    validated by the window's end. A test's validation time is the same
    ``coalesce(blue_validated_at, red_validated_at, created_at)`` used for
    recency scoring.

    *period* is ``"30d"``, ``"1y"`` or anything else for 90 days.

    Returns a list of weekly data points with ``date``, ``score`` (0-100) and
    ``validated_tests`` (cumulative validated-test count).
    """
    # Naive UTC, matching the value of the deprecated ``datetime.utcnow()``.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    days = {"30d": 30, "1y": 365}.get(period, 90)
    start = now - timedelta(days=days)

    # Loop-invariant: total technique count (1 avoids division by zero).
    total_techniques = db.query(func.count(Technique.id)).scalar() or 1
    validated_at = _last_validated_expr()

    weeks = []
    current = start
    while current < now:
        week_end = min(current + timedelta(days=7), now)

        validated_up_to, techniques_covered = (
            db.query(
                func.count(Test.id),
                func.count(Test.technique_id.distinct()),
            )
            .filter(
                Test.state == TestState.validated,
                validated_at <= week_end,
            )
            .one()
        )

        # Coverage percentage (techniques with a validated test) as score proxy.
        score_approx = round(((techniques_covered or 0) / total_techniques) * 100, 1)
        weeks.append({
            "date": current.strftime("%Y-%m-%d"),
            "score": min(score_approx, 100),
            "validated_tests": validated_up_to or 0,
        })
        current = week_end

    return weeks