From f0f59facdbdbe19ad30646d8dc8605fbba1999c1 Mon Sep 17 00:00:00 2001 From: Kitos Date: Wed, 18 Feb 2026 12:18:48 +0100 Subject: [PATCH] perf(scoring): eliminate N+1 in organization score calculation - Add bulk_technique_scores() that pre-fetches all scoring data in 5 aggregated GROUP BY queries instead of N*5 per-technique queries - Rewrite calculate_organization_score to use bulk data (N*5+5 queries -> 10 fixed queries) - Rewrite calculate_tactic_score and calculate_actor_coverage_score to use bulk data - Preserve calculate_technique_score single-technique API for router-level calls --- backend/app/services/scoring_service.py | 349 ++++++++++++++++++------ 1 file changed, 267 insertions(+), 82 deletions(-) diff --git a/backend/app/services/scoring_service.py b/backend/app/services/scoring_service.py index 072eb00..5209816 100644 --- a/backend/app/services/scoring_service.py +++ b/backend/app/services/scoring_service.py @@ -2,12 +2,16 @@ Uses configurable weights from Settings to compute coverage scores with detailed breakdowns. + +Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a +fixed number of aggregated queries so that organisation-wide calculations +never produce N+1 traffic. """ from datetime import datetime, timedelta from typing import Optional -from sqlalchemy import func +from sqlalchemy import case, func from sqlalchemy.orm import Session from app.config import settings @@ -20,7 +24,219 @@ from app.models.threat_actor import ThreatActor, ThreatActorTechnique from app.models.enums import TestState, TestResult -# ── Technique-level scoring ────────────────────────────────────────── +# ── Bulk scoring helpers (5 queries for ALL techniques) ─────────────── + + +def _build_empty_stats(): + return { + "validated": 0, + "detected": 0, + "platforms": set(), + "latest_validated_at": None, + } + + +def bulk_technique_scores(db: Session) -> dict: + """Pre-fetch all scoring data and compute per-technique scores in memory. + + Executes exactly 5 queries regardless of technique count: + Q1 — Test aggregates per technique (validated / detected / platforms / freshness) + Q2 — Detection rules per mitre_id + Q3 — Triggered rules per mitre_id + Q4 — D3FEND mapping counts per technique + Q5 — All techniques + + Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``. + """ + w_tests = settings.SCORING_WEIGHT_TESTS + w_detection = settings.SCORING_WEIGHT_DETECTION_RULES + w_d3fend = settings.SCORING_WEIGHT_D3FEND + w_freshness = settings.SCORING_WEIGHT_FRESHNESS + w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY + + # Q1: test stats grouped by technique_id + test_rows = ( + db.query( + Test.technique_id, + func.count(Test.id).label("validated_count"), + func.count( + case((Test.detection_result == TestResult.detected, Test.id)) + ).label("detected_count"), + func.max(Test.red_validated_at).label("latest_validated_at"), + func.count(func.distinct(Test.platform)).label("platform_count"), + ) + .filter(Test.state == TestState.validated) + .group_by(Test.technique_id) + .all() + ) + + test_stats: dict = {} + for row in test_rows: + test_stats[row.technique_id] = { + "validated": row.validated_count, + "detected": row.detected_count, + "latest_validated_at": row.latest_validated_at, + "platform_count": row.platform_count, + } + + # Q2: active detection rules per mitre_id + rule_rows = ( + db.query( + DetectionRule.mitre_technique_id, + func.count(DetectionRule.id).label("total"), + ) + .filter(DetectionRule.is_active == True) # noqa: E712 + .group_by(DetectionRule.mitre_technique_id) + .all() + ) + rules_by_mitre: dict[str, int] = {r.mitre_technique_id: r.total for r in rule_rows} + + # Q3: triggered rules per mitre_id + triggered_rows = ( + db.query( + DetectionRule.mitre_technique_id, + func.count(TestDetectionResult.id).label("triggered"), + ) + .join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id) + .filter(TestDetectionResult.triggered == True) # noqa: E712 + .group_by(DetectionRule.mitre_technique_id) + .all() + ) + triggered_by_mitre: dict[str, int] = { + r.mitre_technique_id: r.triggered for r in triggered_rows + } + + # Q4: D3FEND mapping counts per technique + d3fend_rows = ( + db.query( + DefensiveTechniqueMapping.attack_technique_id, + func.count(DefensiveTechniqueMapping.id).label("total"), + ) + .group_by(DefensiveTechniqueMapping.attack_technique_id) + .all() + ) + d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows} + + # Q5: all techniques + techniques = db.query(Technique).all() + + now = datetime.utcnow() + results: dict = {} + + for tech in techniques: + ts = test_stats.get(tech.id, {}) + validated = ts.get("validated", 0) + detected = ts.get("detected", 0) + latest_at = ts.get("latest_validated_at") + plat_count = ts.get("platform_count", 0) + + breakdown = {} + + # 1. Tests validated with detection + if validated > 0: + test_ratio = detected / validated + test_score = round(test_ratio * w_tests, 1) + else: + test_ratio = 0 + test_score = 0 + breakdown["tests_validated"] = { + "score": test_score, + "max": w_tests, + "detail": ( + f"{detected}/{validated} tests detected" + if validated else "No validated tests" + ), + } + + # 2. Detection rules + total_rules = rules_by_mitre.get(tech.mitre_id, 0) + triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0) + if total_rules > 0: + detection_ratio = min(triggered_rules / total_rules, 1.0) + detection_score = round(detection_ratio * w_detection, 1) + else: + detection_ratio = 0 + detection_score = 0 + breakdown["detection_rules"] = { + "score": detection_score, + "max": w_detection, + "detail": ( + f"{triggered_rules}/{total_rules} rules triggered" + if total_rules > 0 else "No detection rules available" + ), + } + + # 3. D3FEND coverage + total_cm = d3fend_by_tech.get(tech.id, 0) + if total_cm > 0 and detected > 0: + verified_cm = min(detected, total_cm) + d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1) + else: + verified_cm = 0 + d3fend_score = 0 + breakdown["d3fend_coverage"] = { + "score": d3fend_score, + "max": w_d3fend, + "detail": ( + f"{verified_cm}/{total_cm} countermeasures" + if total_cm > 0 else "No D3FEND mappings" + ), + } + + # 4. Freshness + if latest_at: + days_ago = (now - latest_at).days + if days_ago < 90: + freshness_pct = 1.0 + elif days_ago < 180: + freshness_pct = 0.5 + else: + freshness_pct = 0.0 + freshness_score = round(freshness_pct * w_freshness, 1) + freshness_detail = f"Last test {days_ago} days ago" + else: + freshness_score = 0 + freshness_detail = "No validated tests" + breakdown["freshness"] = { + "score": freshness_score, + "max": w_freshness, + "detail": freshness_detail, + } + + # 5. Platform diversity + available = tech.platforms or [] + total_platforms = len(available) if available else 3 + if total_platforms > 0 and plat_count > 0: + diversity_score = round( + min(plat_count / total_platforms, 1.0) * w_diversity, 1, + ) + else: + diversity_score = 0 + breakdown["platform_diversity"] = { + "score": diversity_score, + "max": w_diversity, + "detail": ( + f"{plat_count}/{total_platforms} platforms covered" + if plat_count > 0 else "No platforms tested" + ), + } + + total = min( + test_score + detection_score + d3fend_score + + freshness_score + diversity_score, + 100, + ) + results[tech.id] = { + "total_score": round(total, 1), + "breakdown": breakdown, + "mitre_id": tech.mitre_id, + "tactic": tech.tactic, + } + + return results + + +# ── Technique-level scoring (single technique — preserved API) ──────── def calculate_technique_score(technique: Technique, db: Session) -> dict: @@ -73,7 +289,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: db.query(func.count(DetectionRule.id)) .filter( DetectionRule.mitre_technique_id == technique.mitre_id, - DetectionRule.is_active == True, + DetectionRule.is_active == True, # noqa: E712 ) .scalar() ) or 0 @@ -88,7 +304,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: ) .filter( DetectionRule.mitre_technique_id == technique.mitre_id, - TestDetectionResult.triggered == True, + TestDetectionResult.triggered == True, # noqa: E712 ) .scalar() ) or 0 @@ -114,11 +330,8 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: .scalar() ) or 0 - # Consider a countermeasure "verified" if we have validated tests - # with detection for the technique (simplified heuristic) verified_countermeasures = 0 if total_countermeasures > 0 and len(detected_tests) > 0: - # Rough heuristic: each detected test validates ~1 countermeasure verified_countermeasures = min(len(detected_tests), total_countermeasures) d3fend_ratio = verified_countermeasures / total_countermeasures d3fend_score = round(d3fend_ratio * w_d3fend, 1) @@ -135,7 +348,6 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: } # ── 4. Freshness ────────────────────────────────────────────── - # Most recent validated test date most_recent_test = ( db.query(func.max(Test.red_validated_at)) .filter( @@ -169,7 +381,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: # ── 5. Platform diversity ───────────────────────────────────── available_platforms = technique.platforms or [] - total_platforms = len(available_platforms) if available_platforms else 3 # default 3 + total_platforms = len(available_platforms) if available_platforms else 3 tested_platforms = set() for t in validated_tests: @@ -208,30 +420,19 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: def calculate_tactic_score(tactic: str, db: Session) -> dict: """Calculate average score for all techniques in a tactic.""" - techniques = ( - db.query(Technique) - .filter(Technique.tactic.ilike(f"%{tactic}%")) - .all() - ) + scores_map = bulk_technique_scores(db) - if not techniques: - return { - "tactic": tactic, - "average_score": 0, - "techniques_count": 0, - "techniques_scored": 0, - } - - scores = [] - for tech in techniques: - result = calculate_technique_score(tech, db) - scores.append(result["total_score"]) + matching = [ + v["total_score"] + for v in scores_map.values() + if v.get("tactic") and tactic.lower() in v["tactic"].lower() + ] return { "tactic": tactic, - "average_score": round(sum(scores) / len(scores), 1) if scores else 0, - "techniques_count": len(techniques), - "techniques_scored": len([s for s in scores if s > 0]), + "average_score": round(sum(matching) / len(matching), 1) if matching else 0, + "techniques_count": len(matching), + "techniques_scored": len([s for s in matching if s > 0]), } @@ -244,14 +445,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict: if not actor: return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0} - # Get all techniques used by this actor actor_techniques = ( db.query(ThreatActorTechnique) .filter(ThreatActorTechnique.threat_actor_id == actor.id) .all() ) - technique_ids = [at.technique_id for at in actor_techniques] + technique_ids = {at.technique_id for at in actor_techniques} if not technique_ids: return { "actor_id": str(actor.id), @@ -262,23 +462,21 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict: "techniques_detail": [], } - techniques = ( - db.query(Technique) - .filter(Technique.id.in_(technique_ids)) - .all() - ) + scores_map = bulk_technique_scores(db) scores = [] details = [] - for tech in techniques: - result = calculate_technique_score(tech, db) - score = result["total_score"] + for tid in technique_ids: + entry = scores_map.get(tid) + if not entry: + continue + score = entry["total_score"] scores.append(score) details.append({ - "mitre_id": tech.mitre_id, - "name": tech.name, + "mitre_id": entry["mitre_id"], + "name": entry.get("name", ""), "score": score, - "breakdown": result["breakdown"], + "breakdown": entry["breakdown"], }) avg_score = round(sum(scores) / len(scores), 1) if scores else 0 @@ -287,7 +485,7 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict: "actor_id": str(actor.id), "actor_name": actor.name, "total_score": avg_score, - "techniques_count": len(techniques), + "techniques_count": len(technique_ids), "techniques_covered": len([s for s in scores if s > 50]), "techniques_detail": details, } @@ -297,10 +495,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict: def calculate_organization_score(db: Session) -> dict: - """Calculate the overall organization security score.""" - # All techniques - all_techniques = db.query(Technique).all() - total_count = len(all_techniques) + """Calculate the overall organization security score. + + Uses ``bulk_technique_scores`` to compute all technique scores in + 5 aggregated queries instead of N*5. + """ + scores_map = bulk_technique_scores(db) + total_count = len(scores_map) if total_count == 0: return { @@ -313,27 +514,16 @@ def calculate_organization_score(db: Session) -> dict: "techniques_total": 0, } - # Calculate scores for all techniques (with caching for performance) - all_scores = [] - evaluated_count = 0 - - for tech in all_techniques: - result = calculate_technique_score(tech, db) - score = result["total_score"] - all_scores.append(score) - if score > 0: - evaluated_count += 1 - - # Total coverage: average of all evaluated techniques + all_scores = [v["total_score"] for v in scores_map.values()] evaluated_scores = [s for s in all_scores if s > 0] + evaluated_count = len(evaluated_scores) + total_coverage = ( round(sum(evaluated_scores) / len(evaluated_scores), 1) - if evaluated_scores - else 0 + if evaluated_scores else 0 ) - # Critical coverage: techniques with high-severity templates - # (simplified: techniques that have tests are "critical") + # Critical coverage: techniques with high/critical severity templates from app.models.test_template import TestTemplate critical_mitre_ids = set( @@ -344,38 +534,35 @@ def calculate_organization_score(db: Session) -> dict: .all() ) - critical_techniques = [ - t for t in all_techniques if t.mitre_id in critical_mitre_ids + critical_scores = [ + v["total_score"] + for v in scores_map.values() + if v.get("mitre_id") in critical_mitre_ids ] - if critical_techniques: - critical_scores = [] - for tech in critical_techniques: - result = calculate_technique_score(tech, db) - critical_scores.append(result["total_score"]) - critical_coverage = round(sum(critical_scores) / len(critical_scores), 1) - else: - critical_coverage = 0 + critical_coverage = ( + round(sum(critical_scores) / len(critical_scores), 1) + if critical_scores else 0 + ) - # Detection maturity: based on detection rule coverage + # Detection maturity (2 scalar queries — already efficient) total_rules = ( db.query(func.count(DetectionRule.id)) - .filter(DetectionRule.is_active == True) + .filter(DetectionRule.is_active == True) # noqa: E712 .scalar() ) or 0 triggered_total = ( db.query(func.count(TestDetectionResult.id)) - .filter(TestDetectionResult.triggered == True) + .filter(TestDetectionResult.triggered == True) # noqa: E712 .scalar() ) or 0 detection_maturity = ( round((triggered_total / total_rules) * 100, 1) - if total_rules > 0 - else 0 + if total_rules > 0 else 0 ) detection_maturity = min(detection_maturity, 100) - # Response readiness: based on remediation completion + # Response readiness (2 scalar queries — already efficient) remediation_total = ( db.query(func.count(Test.id)) .filter(Test.remediation_status.isnot(None)) @@ -389,11 +576,9 @@ def calculate_organization_score(db: Session) -> dict: response_readiness = ( round((remediation_completed / remediation_total) * 100, 1) - if remediation_total > 0 - else 0 + if remediation_total > 0 else 0 ) - # Overall score: weighted average of sub-scores overall = round( total_coverage * 0.4 + critical_coverage * 0.25