"""Scoring service — granular 0-100 scoring for techniques, tactics, actors, and org.

Reads configurable weights from the ``scoring_config`` table (falling back to
env-var defaults) to compute coverage scores with detailed breakdowns.

Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a fixed
number of aggregated queries so that organization-wide calculations never
produce N+1 traffic. Both the bulk path and the single-technique path
(``calculate_technique_score``) feed their aggregates through the same
``_score_breakdown`` helper, so identical inputs always score identically.
"""

from datetime import datetime, timedelta, timezone
from typing import Optional

from sqlalchemy import case, func
from sqlalchemy.orm import Session

from app.models.technique import Technique
from app.models.test import Test
from app.models.detection_rule import DetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.defensive_technique import DefensiveTechniqueMapping
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.enums import TestState, TestResult
from app.services.scoring_config_service import get_scoring_weights


# ── Shared constants & helpers ──────────────────────────────────────

# Platform-diversity denominator used when a technique lists no platforms.
_DEFAULT_PLATFORM_COUNT = 3

# Freshness bands. A latest validation younger than _FRESH_DAYS earns the
# full freshness weight, one younger than _STALE_DAYS earns half, and older
# ones earn nothing.
_FRESH_DAYS = 90
_STALE_DAYS = 180

# A threat-actor technique counts as "covered" above this score.
_ACTOR_COVERED_THRESHOLD = 50

# Look-back window (days) per ``get_score_history`` period; unknown -> 90.
_HISTORY_PERIOD_DAYS = {"30d": 30, "90d": 90, "1y": 365}


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime (like ``utcnow()``)."""
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _days_since(ts: datetime, now: datetime) -> int:
    """Return whole days elapsed between ``ts`` and the naive-UTC ``now``.

    Naive timestamps are assumed to already be UTC. Aware ones are converted
    to naive UTC first, so naive and aware values are never subtracted from
    each other (which raises ``TypeError``).
    """
    if ts.tzinfo is not None:
        ts = ts.astimezone(timezone.utc).replace(tzinfo=None)
    return (now - ts).days


def _score_breakdown(
    w,
    *,
    validated: int,
    detected: int,
    latest_validated_at: Optional[datetime],
    platform_count: int,
    available_platforms,
    total_rules: int,
    triggered_rules: int,
    total_countermeasures: int,
    now: datetime,
) -> tuple:
    """Turn pre-aggregated stats for one technique into ``(total, breakdown)``.

    Args:
        w: Weights object returned by ``get_scoring_weights``. Its
            ``tests``, ``detection_rules``, ``d3fend``, ``freshness`` and
            ``platform_diversity`` attributes are the per-component maxima.
        validated: Number of validated tests.
        detected: Number of validated tests whose result was ``detected``.
        latest_validated_at: Newest ``red_validated_at`` among validated
            tests, or ``None``.
        platform_count: Number of distinct platforms those tests ran on.
        available_platforms: The technique's declared platforms. When this is
            empty or ``None``, ``_DEFAULT_PLATFORM_COUNT`` is used instead.
        total_rules: Number of active detection rules for the technique.
        triggered_rules: Number of distinct active rules that triggered.
        total_countermeasures: Number of D3FEND mappings for the technique.
        now: Naive-UTC reference time for freshness.

    Returns:
        A tuple ``(total_score, breakdown)``. ``total_score`` is capped at 100
        and rounded to one decimal. ``breakdown`` maps each component name to
        ``{"score", "max", "detail"}``.
    """
    breakdown = {}

    # 1. Tests validated with detection
    if validated > 0:
        test_score = round((detected / validated) * w.tests, 1)
    else:
        test_score = 0
    breakdown["tests_validated"] = {
        "score": test_score,
        "max": w.tests,
        "detail": (
            f"{detected}/{validated} tests detected"
            if validated
            else "No validated tests"
        ),
    }

    # 2. Detection rules
    if total_rules > 0:
        # The cap is a guard only: triggered_rules counts distinct active
        # rules, so it should never exceed total_rules.
        detection_ratio = min(triggered_rules / total_rules, 1.0)
        detection_score = round(detection_ratio * w.detection_rules, 1)
    else:
        detection_score = 0
    breakdown["detection_rules"] = {
        "score": detection_score,
        "max": w.detection_rules,
        "detail": (
            f"{triggered_rules}/{total_rules} rules triggered"
            if total_rules > 0
            else "No detection rules available"
        ),
    }

    # 3. D3FEND coverage
    # NOTE(review): "verified" countermeasures are approximated by the number
    # of detected tests, capped at the number of mappings. Nothing here
    # verifies individual countermeasures.
    if total_countermeasures > 0 and detected > 0:
        verified_cm = min(detected, total_countermeasures)
        d3fend_score = round((verified_cm / total_countermeasures) * w.d3fend, 1)
    else:
        verified_cm = 0
        d3fend_score = 0
    breakdown["d3fend_coverage"] = {
        "score": d3fend_score,
        "max": w.d3fend,
        "detail": (
            f"{verified_cm}/{total_countermeasures} countermeasures"
            if total_countermeasures > 0
            else "No D3FEND mappings"
        ),
    }

    # 4. Freshness
    if latest_validated_at is not None:
        days_ago = _days_since(latest_validated_at, now)
        if days_ago < _FRESH_DAYS:
            freshness_pct = 1.0
        elif days_ago < _STALE_DAYS:
            freshness_pct = 0.5
        else:
            freshness_pct = 0.0
        freshness_score = round(freshness_pct * w.freshness, 1)
        freshness_detail = f"Last test {days_ago} days ago"
    else:
        freshness_score = 0
        freshness_detail = "No validated tests"
    breakdown["freshness"] = {
        "score": freshness_score,
        "max": w.freshness,
        "detail": freshness_detail,
    }

    # 5. Platform diversity (total_platforms is always >= 1 here)
    total_platforms = (
        len(available_platforms) if available_platforms else _DEFAULT_PLATFORM_COUNT
    )
    if platform_count > 0:
        diversity_score = round(
            min(platform_count / total_platforms, 1.0) * w.platform_diversity,
            1,
        )
    else:
        diversity_score = 0
    breakdown["platform_diversity"] = {
        "score": diversity_score,
        "max": w.platform_diversity,
        "detail": (
            f"{platform_count}/{total_platforms} platforms covered"
            if platform_count > 0
            else "No platforms tested"
        ),
    }

    total = min(
        test_score + detection_score + d3fend_score + freshness_score + diversity_score,
        100,
    )
    return round(total, 1), breakdown


# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────


def bulk_technique_scores(db: Session) -> dict:
    """Pre-fetch all scoring data and compute per-technique scores in memory.

    Executes five aggregated queries regardless of technique count. Any
    queries made by ``get_scoring_weights`` are extra and not shown here.

      Q1 — Test aggregates per technique (validated / detected / platforms / freshness)
      Q2 — Active detection rules per mitre_id
      Q3 — Distinct triggered active rules per mitre_id
      Q4 — D3FEND mapping counts per technique
      Q5 — All techniques

    Returns ``{technique_id: {"total_score": float, "breakdown": dict,
    "mitre_id": str, "tactic": str, "name": str}}``.
    """
    w = get_scoring_weights(db)

    # Q1: test stats grouped by technique_id (validated tests only)
    test_rows = (
        db.query(
            Test.technique_id,
            func.count(Test.id).label("validated_count"),
            func.count(
                case((Test.detection_result == TestResult.detected, Test.id))
            ).label("detected_count"),
            func.max(Test.red_validated_at).label("latest_validated_at"),
            # NOTE(review): distinct on the raw column. The single-technique
            # path lower-cases platforms and skips empty ones, so mixed-case
            # data can still count differently between the two paths.
            func.count(func.distinct(Test.platform)).label("platform_count"),
        )
        .filter(Test.state == TestState.validated)
        .group_by(Test.technique_id)
        .all()
    )
    test_stats: dict = {
        row.technique_id: {
            "validated": row.validated_count,
            "detected": row.detected_count,
            "latest_validated_at": row.latest_validated_at,
            "platform_count": row.platform_count,
        }
        for row in test_rows
    }

    # Q2: active detection rules per mitre_id
    rule_rows = (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(DetectionRule.id).label("total"),
        )
        .filter(DetectionRule.is_active == True)  # noqa: E712
        .group_by(DetectionRule.mitre_technique_id)
        .all()
    )
    rules_by_mitre: dict = {r.mitre_technique_id: r.total for r in rule_rows}

    # Q3: distinct *active* rules that triggered at least once, per mitre_id.
    # Counting rules rather than result rows keeps the numerator on the
    # same basis as Q2's denominator.
    triggered_rows = (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(func.distinct(DetectionRule.id)).label("triggered"),
        )
        .select_from(TestDetectionResult)
        .join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
        .filter(
            TestDetectionResult.triggered == True,  # noqa: E712
            DetectionRule.is_active == True,  # noqa: E712
        )
        .group_by(DetectionRule.mitre_technique_id)
        .all()
    )
    triggered_by_mitre: dict = {
        r.mitre_technique_id: r.triggered for r in triggered_rows
    }

    # Q4: D3FEND mapping counts per technique
    d3fend_rows = (
        db.query(
            DefensiveTechniqueMapping.attack_technique_id,
            func.count(DefensiveTechniqueMapping.id).label("total"),
        )
        .group_by(DefensiveTechniqueMapping.attack_technique_id)
        .all()
    )
    d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows}

    # Q5: all techniques
    techniques = db.query(Technique).all()

    now = _utcnow()
    results: dict = {}
    for tech in techniques:
        ts = test_stats.get(tech.id, {})
        total, breakdown = _score_breakdown(
            w,
            validated=ts.get("validated", 0),
            detected=ts.get("detected", 0),
            latest_validated_at=ts.get("latest_validated_at"),
            platform_count=ts.get("platform_count", 0),
            available_platforms=tech.platforms,
            total_rules=rules_by_mitre.get(tech.mitre_id, 0),
            triggered_rules=triggered_by_mitre.get(tech.mitre_id, 0),
            total_countermeasures=d3fend_by_tech.get(tech.id, 0),
            now=now,
        )
        results[tech.id] = {
            "total_score": total,
            "breakdown": breakdown,
            "mitre_id": tech.mitre_id,
            "tactic": tech.tactic,
            # Consumed by calculate_actor_coverage_score's detail rows.
            "name": getattr(tech, "name", None) or "",
        }
    return results


# ── Technique-level scoring (single technique — preserved API) ────────


def calculate_technique_score(technique: Technique, db: Session) -> dict:
    """Calculate a 0-100 score for one technique with a detailed breakdown.

    Weights are read from the ``scoring_config`` table (or env defaults).
    The scoring rules are shared with ``bulk_technique_scores`` via
    ``_score_breakdown``.

    Returns ``{"total_score": float, "breakdown": dict}``.
    """
    w = get_scoring_weights(db)

    # Validated tests for this technique. Detection counts, freshness and
    # platform coverage are all derived from this one result set.
    validated_tests = (
        db.query(Test)
        .filter(
            Test.technique_id == technique.id,
            Test.state == TestState.validated,
        )
        .all()
    )
    detected = sum(
        1 for t in validated_tests if t.detection_result == TestResult.detected
    )
    latest_validated_at = max(
        (t.red_validated_at for t in validated_tests if t.red_validated_at is not None),
        default=None,
    )
    tested_platforms = {t.platform.lower() for t in validated_tests if t.platform}

    total_rules = (
        db.query(func.count(DetectionRule.id))
        .filter(
            DetectionRule.mitre_technique_id == technique.mitre_id,
            DetectionRule.is_active == True,  # noqa: E712
        )
        .scalar()
    ) or 0

    triggered_rules = 0
    if total_rules > 0:
        # Distinct active rules that triggered (same basis as total_rules).
        triggered_rules = (
            db.query(func.count(func.distinct(DetectionRule.id)))
            .select_from(TestDetectionResult)
            .join(
                DetectionRule,
                DetectionRule.id == TestDetectionResult.detection_rule_id,
            )
            .filter(
                DetectionRule.mitre_technique_id == technique.mitre_id,
                DetectionRule.is_active == True,  # noqa: E712
                TestDetectionResult.triggered == True,  # noqa: E712
            )
            .scalar()
        ) or 0

    total_countermeasures = (
        db.query(func.count(DefensiveTechniqueMapping.id))
        .filter(DefensiveTechniqueMapping.attack_technique_id == technique.id)
        .scalar()
    ) or 0

    total, breakdown = _score_breakdown(
        w,
        validated=len(validated_tests),
        detected=detected,
        latest_validated_at=latest_validated_at,
        platform_count=len(tested_platforms),
        available_platforms=technique.platforms,
        total_rules=total_rules,
        triggered_rules=triggered_rules,
        total_countermeasures=total_countermeasures,
        now=_utcnow(),
    )
    return {
        "total_score": total,
        "breakdown": breakdown,
    }


# ── Tactic-level scoring ─────────────────────────────────────────────


def calculate_tactic_score(tactic: str, db: Session) -> dict:
    """Calculate the average score for all techniques in a tactic.

    A technique matches when ``tactic`` is a case-insensitive *substring* of
    its ``tactic`` field. Substring rather than equality presumably allows
    multi-tactic strings; confirm the stored format.
    """
    scores_map = bulk_technique_scores(db)
    needle = tactic.lower()
    matching = [
        v["total_score"]
        for v in scores_map.values()
        if v.get("tactic") and needle in v["tactic"].lower()
    ]
    return {
        "tactic": tactic,
        "average_score": round(sum(matching) / len(matching), 1) if matching else 0,
        "techniques_count": len(matching),
        "techniques_scored": sum(1 for s in matching if s > 0),
    }


# ── Threat actor scoring ─────────────────────────────────────────────


def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
    """Calculate the coverage score for a specific threat actor's techniques.

    The score is the average of the bulk technique scores over the actor's
    techniques. A technique counts as "covered" when its score exceeds
    ``_ACTOR_COVERED_THRESHOLD``. Detail rows are sorted by ``mitre_id``.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        return {
            "total_score": 0,
            "techniques_count": 0,
            "techniques_covered": 0,
            "techniques_detail": [],
        }

    technique_ids = {
        row[0]
        for row in db.query(ThreatActorTechnique.technique_id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .all()
    }
    if not technique_ids:
        return {
            "actor_id": str(actor.id),
            "actor_name": actor.name,
            "total_score": 0,
            "techniques_count": 0,
            "techniques_covered": 0,
            "techniques_detail": [],
        }

    scores_map = bulk_technique_scores(db)
    scores = []
    details = []
    for tid in technique_ids:
        entry = scores_map.get(tid)
        if not entry:
            continue
        score = entry["total_score"]
        scores.append(score)
        details.append({
            "mitre_id": entry["mitre_id"],
            "name": entry.get("name", ""),
            "score": score,
            "breakdown": entry["breakdown"],
        })
    # Sets have no stable order; sort so API output is deterministic.
    details.sort(key=lambda d: d["mitre_id"] or "")

    avg_score = round(sum(scores) / len(scores), 1) if scores else 0
    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_score": avg_score,
        "techniques_count": len(technique_ids),
        "techniques_covered": sum(1 for s in scores if s > _ACTOR_COVERED_THRESHOLD),
        "techniques_detail": details,
    }


# ── Organization-level scoring ────────────────────────────────────────


def calculate_organization_score(db: Session) -> dict:
    """Calculate the overall organization security score.

    Uses ``bulk_technique_scores`` to compute all technique scores in a fixed
    number of aggregated queries instead of N*5. The overall score weights
    the components as follows:

    - total coverage: 40%
    - critical coverage: 25%
    - detection maturity: 20%
    - response readiness: 15%
    """
    scores_map = bulk_technique_scores(db)
    total_count = len(scores_map)

    if total_count == 0:
        return {
            "overall_score": 0,
            "total_coverage": 0,
            "critical_coverage": 0,
            "detection_maturity": 0,
            "response_readiness": 0,
            "techniques_evaluated": 0,
            "techniques_total": 0,
        }

    # Total coverage averages only techniques with a non-zero score; the
    # count of those is reported as "techniques_evaluated".
    evaluated_scores = [
        v["total_score"] for v in scores_map.values() if v["total_score"] > 0
    ]
    evaluated_count = len(evaluated_scores)
    total_coverage = (
        round(sum(evaluated_scores) / evaluated_count, 1) if evaluated_scores else 0
    )

    # Critical coverage: techniques with high/critical severity templates
    from app.models.test_template import TestTemplate

    critical_mitre_ids = {
        row[0]
        for row in db.query(TestTemplate.mitre_technique_id)
        .filter(TestTemplate.severity.in_(["high", "critical"]))
        .distinct()
        .all()
    }
    critical_scores = [
        v["total_score"]
        for v in scores_map.values()
        if v.get("mitre_id") in critical_mitre_ids
    ]
    critical_coverage = (
        round(sum(critical_scores) / len(critical_scores), 1) if critical_scores else 0
    )

    # Detection maturity: share of active rules that have triggered at least
    # once. The numerator counts distinct active rules, not result rows.
    total_rules = (
        db.query(func.count(DetectionRule.id))
        .filter(DetectionRule.is_active == True)  # noqa: E712
        .scalar()
    ) or 0
    triggered_total = (
        db.query(func.count(func.distinct(DetectionRule.id)))
        .select_from(TestDetectionResult)
        .join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
        .filter(
            TestDetectionResult.triggered == True,  # noqa: E712
            DetectionRule.is_active == True,  # noqa: E712
        )
        .scalar()
    ) or 0
    detection_maturity = (
        round((triggered_total / total_rules) * 100, 1) if total_rules > 0 else 0
    )
    detection_maturity = min(detection_maturity, 100)

    # Response readiness: completed remediations over all tests that have
    # any remediation status.
    remediation_total = (
        db.query(func.count(Test.id))
        .filter(Test.remediation_status.isnot(None))
        .scalar()
    ) or 0
    remediation_completed = (
        db.query(func.count(Test.id))
        .filter(Test.remediation_status == "completed")
        .scalar()
    ) or 0
    response_readiness = (
        round((remediation_completed / remediation_total) * 100, 1)
        if remediation_total > 0
        else 0
    )

    overall = round(
        total_coverage * 0.4
        + critical_coverage * 0.25
        + detection_maturity * 0.2
        + response_readiness * 0.15,
        1,
    )

    return {
        "overall_score": overall,
        "total_coverage": total_coverage,
        "critical_coverage": critical_coverage,
        "detection_maturity": detection_maturity,
        "response_readiness": response_readiness,
        "techniques_evaluated": evaluated_count,
        "techniques_total": total_count,
    }


# ── Score history ────────────────────────────────────────────────────


def get_score_history(db: Session, period: str = "90d") -> list:
    """Get approximate historical score snapshots, one per week.

    There is no dedicated history table, so each weekly point approximates
    the score as the percentage of techniques that had at least one
    validated test by the end of that week.

    Args:
        db: Database session.
        period: ``"30d"``, ``"90d"`` or ``"1y"``. Any other value is treated
            as ``"90d"``.

    Returns:
        A list of ``{"date": "YYYY-MM-DD", "score": float,
        "validated_tests": int}`` dicts, oldest first.
    """
    now = _utcnow()
    start = now - timedelta(days=_HISTORY_PERIOD_DAYS.get(period, 90))

    # Loop-invariant; "or 1" avoids division by zero with no techniques.
    total_techniques = (db.query(func.count(Technique.id)).scalar()) or 1

    weeks = []
    current = start
    while current < now:
        week_end = min(current + timedelta(days=7), now)

        # One query returns both the cumulative validated-test count and the
        # number of distinct techniques those tests cover.
        validated_up_to, techniques_up_to = (
            db.query(
                func.count(Test.id),
                func.count(func.distinct(Test.technique_id)),
            )
            .filter(
                Test.state == TestState.validated,
                Test.red_validated_at <= week_end,
            )
            .one()
        )
        validated_up_to = validated_up_to or 0
        techniques_up_to = techniques_up_to or 0

        # Coverage percentage (techniques with validated tests) as score proxy.
        score_approx = round((techniques_up_to / total_techniques) * 100, 1)

        weeks.append({
            "date": current.strftime("%Y-%m-%d"),
            "score": min(score_approx, 100),
            "validated_tests": validated_up_to,
        })
        current = week_end

    return weeks