perf(scoring): eliminate N+1 in organization score calculation

- Add bulk_technique_scores() that pre-fetches all scoring data in 5 aggregated GROUP BY queries instead of N*5 per-technique queries

- Rewrite calculate_organization_score to use bulk data (N*5+5 queries -> 10 fixed queries)

- Rewrite calculate_tactic_score and calculate_actor_coverage_score to use bulk data

- Preserve calculate_technique_score single-technique API for router-level calls
This commit is contained in:
2026-02-18 12:18:48 +01:00
parent 898bb7e4e7
commit f0f59facdb

View File

@@ -2,12 +2,16 @@
Uses configurable weights from Settings to compute coverage scores with Uses configurable weights from Settings to compute coverage scores with
detailed breakdowns. detailed breakdowns.
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
fixed number of aggregated queries so that organisation-wide calculations
never produce N+1 traffic.
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional
from sqlalchemy import func from sqlalchemy import case, func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import settings from app.config import settings
@@ -20,7 +24,219 @@ from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.enums import TestState, TestResult from app.models.enums import TestState, TestResult
# ── Technique-level scoring ────────────────────────────────────────── # ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
def _build_empty_stats():
    """Return a fresh, zeroed per-technique test-statistics record.

    Every call builds a new dict (and a new ``platforms`` set), so callers
    may mutate the result without affecting other records.
    """
    # Counters start at zero; insertion order matches the record layout.
    stats = dict.fromkeys(("validated", "detected"), 0)
    stats["platforms"] = set()
    stats["latest_validated_at"] = None
    return stats
def bulk_technique_scores(db: Session) -> dict:
    """Pre-fetch all scoring data and compute per-technique scores in memory.

    Executes exactly 5 queries regardless of technique count:

      Q1 — Test aggregates per technique (validated / detected / platforms / freshness)
      Q2 — Active detection rules per mitre_id
      Q3 — Triggered rules per mitre_id
      Q4 — D3FEND mapping counts per technique
      Q5 — All techniques

    Args:
        db: SQLAlchemy session used to run the five queries.

    Returns:
        ``{technique_id: entry}`` with one entry per technique. Each entry
        contains:

        * ``total_score`` — sum of the five component scores, capped at 100
          and rounded to one decimal.
        * ``breakdown`` — dict keyed by ``tests_validated``,
          ``detection_rules``, ``d3fend_coverage``, ``freshness`` and
          ``platform_diversity``; each value has ``score``, ``max`` and
          ``detail``.
        * ``mitre_id``, ``name``, ``tactic`` — copied from the technique so
          callers can filter and label results without re-querying.
    """
    # ``timezone`` is imported locally because the module only imports
    # ``datetime`` and ``timedelta`` at file level.
    from datetime import timezone

    w_tests = settings.SCORING_WEIGHT_TESTS
    w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
    w_d3fend = settings.SCORING_WEIGHT_D3FEND
    w_freshness = settings.SCORING_WEIGHT_FRESHNESS
    w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY

    # Q1: test stats grouped by technique_id (validated tests only).
    test_rows = (
        db.query(
            Test.technique_id,
            func.count(Test.id).label("validated_count"),
            # COUNT ignores NULLs, so the CASE without an ELSE counts only
            # the rows whose detection_result is "detected".
            func.count(
                case((Test.detection_result == TestResult.detected, Test.id))
            ).label("detected_count"),
            func.max(Test.red_validated_at).label("latest_validated_at"),
            func.count(func.distinct(Test.platform)).label("platform_count"),
        )
        .filter(Test.state == TestState.validated)
        .group_by(Test.technique_id)
        .all()
    )
    test_stats: dict = {
        row.technique_id: {
            "validated": row.validated_count,
            "detected": row.detected_count,
            "latest_validated_at": row.latest_validated_at,
            "platform_count": row.platform_count,
        }
        for row in test_rows
    }

    # Q2: active detection rules per mitre_id.
    rule_rows = (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(DetectionRule.id).label("total"),
        )
        .filter(DetectionRule.is_active == True)  # noqa: E712
        .group_by(DetectionRule.mitre_technique_id)
        .all()
    )
    rules_by_mitre: dict[str, int] = {
        r.mitre_technique_id: r.total for r in rule_rows
    }

    # Q3: triggered rule results per mitre_id.
    triggered_rows = (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(TestDetectionResult.id).label("triggered"),
        )
        .join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
        .filter(TestDetectionResult.triggered == True)  # noqa: E712
        .group_by(DetectionRule.mitre_technique_id)
        .all()
    )
    triggered_by_mitre: dict[str, int] = {
        r.mitre_technique_id: r.triggered for r in triggered_rows
    }

    # Q4: D3FEND mapping counts per technique.
    d3fend_rows = (
        db.query(
            DefensiveTechniqueMapping.attack_technique_id,
            func.count(DefensiveTechniqueMapping.id).label("total"),
        )
        .group_by(DefensiveTechniqueMapping.attack_technique_id)
        .all()
    )
    d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows}

    # Q5: all techniques.
    techniques = db.query(Technique).all()

    # Aware "now" so that both naive and timezone-aware timestamps can be
    # compared without raising TypeError.
    now = datetime.now(timezone.utc)
    results: dict = {}

    for tech in techniques:
        ts = test_stats.get(tech.id, {})
        validated = ts.get("validated", 0)
        detected = ts.get("detected", 0)
        latest_at = ts.get("latest_validated_at")
        plat_count = ts.get("platform_count", 0)

        breakdown = {}

        # 1. Tests validated with detection
        if validated > 0:
            test_score = round((detected / validated) * w_tests, 1)
        else:
            test_score = 0
        breakdown["tests_validated"] = {
            "score": test_score,
            "max": w_tests,
            "detail": (
                f"{detected}/{validated} tests detected"
                if validated else "No validated tests"
            ),
        }

        # 2. Detection rules
        total_rules = rules_by_mitre.get(tech.mitre_id, 0)
        triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0)
        if total_rules > 0:
            # Triggered results can outnumber rules, so the ratio is capped.
            detection_ratio = min(triggered_rules / total_rules, 1.0)
            detection_score = round(detection_ratio * w_detection, 1)
        else:
            detection_score = 0
        breakdown["detection_rules"] = {
            "score": detection_score,
            "max": w_detection,
            "detail": (
                f"{triggered_rules}/{total_rules} rules triggered"
                if total_rules > 0 else "No detection rules available"
            ),
        }

        # 3. D3FEND coverage — each detected test counts as one verified
        # countermeasure, up to the number of mapped countermeasures.
        total_cm = d3fend_by_tech.get(tech.id, 0)
        if total_cm > 0 and detected > 0:
            verified_cm = min(detected, total_cm)
            d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1)
        else:
            verified_cm = 0
            d3fend_score = 0
        breakdown["d3fend_coverage"] = {
            "score": d3fend_score,
            "max": w_d3fend,
            "detail": (
                f"{verified_cm}/{total_cm} countermeasures"
                if total_cm > 0 else "No D3FEND mappings"
            ),
        }

        # 4. Freshness — full credit under 90 days, half under 180, else none.
        if latest_at:
            if latest_at.tzinfo is None:
                # Naive timestamps are treated as UTC, which matches the
                # previous comparison against a naive UTC "now".
                latest_at = latest_at.replace(tzinfo=timezone.utc)
            days_ago = (now - latest_at).days
            if days_ago < 90:
                freshness_pct = 1.0
            elif days_ago < 180:
                freshness_pct = 0.5
            else:
                freshness_pct = 0.0
            freshness_score = round(freshness_pct * w_freshness, 1)
            freshness_detail = f"Last test {days_ago} days ago"
        else:
            freshness_score = 0
            freshness_detail = "No validated tests"
        breakdown["freshness"] = {
            "score": freshness_score,
            "max": w_freshness,
            "detail": freshness_detail,
        }

        # 5. Platform diversity — falls back to 3 platforms when the
        # technique lists none.
        available = tech.platforms or []
        total_platforms = len(available) if available else 3
        if total_platforms > 0 and plat_count > 0:
            diversity_score = round(
                min(plat_count / total_platforms, 1.0) * w_diversity, 1,
            )
        else:
            diversity_score = 0
        breakdown["platform_diversity"] = {
            "score": diversity_score,
            "max": w_diversity,
            "detail": (
                f"{plat_count}/{total_platforms} platforms covered"
                if plat_count > 0 else "No platforms tested"
            ),
        }

        total = min(
            test_score + detection_score + d3fend_score
            + freshness_score + diversity_score,
            100,
        )

        results[tech.id] = {
            "total_score": round(total, 1),
            "breakdown": breakdown,
            "mitre_id": tech.mitre_id,
            # Callers building per-technique detail read "name" from the
            # entry; without it they fall back to an empty string.
            "name": tech.name,
            "tactic": tech.tactic,
        }

    return results
# ── Technique-level scoring (single technique — preserved API) ────────
def calculate_technique_score(technique: Technique, db: Session) -> dict: def calculate_technique_score(technique: Technique, db: Session) -> dict:
@@ -73,7 +289,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
db.query(func.count(DetectionRule.id)) db.query(func.count(DetectionRule.id))
.filter( .filter(
DetectionRule.mitre_technique_id == technique.mitre_id, DetectionRule.mitre_technique_id == technique.mitre_id,
DetectionRule.is_active == True, DetectionRule.is_active == True, # noqa: E712
) )
.scalar() .scalar()
) or 0 ) or 0
@@ -88,7 +304,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
) )
.filter( .filter(
DetectionRule.mitre_technique_id == technique.mitre_id, DetectionRule.mitre_technique_id == technique.mitre_id,
TestDetectionResult.triggered == True, TestDetectionResult.triggered == True, # noqa: E712
) )
.scalar() .scalar()
) or 0 ) or 0
@@ -114,11 +330,8 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
.scalar() .scalar()
) or 0 ) or 0
# Consider a countermeasure "verified" if we have validated tests
# with detection for the technique (simplified heuristic)
verified_countermeasures = 0 verified_countermeasures = 0
if total_countermeasures > 0 and len(detected_tests) > 0: if total_countermeasures > 0 and len(detected_tests) > 0:
# Rough heuristic: each detected test validates ~1 countermeasure
verified_countermeasures = min(len(detected_tests), total_countermeasures) verified_countermeasures = min(len(detected_tests), total_countermeasures)
d3fend_ratio = verified_countermeasures / total_countermeasures d3fend_ratio = verified_countermeasures / total_countermeasures
d3fend_score = round(d3fend_ratio * w_d3fend, 1) d3fend_score = round(d3fend_ratio * w_d3fend, 1)
@@ -135,7 +348,6 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
} }
# ── 4. Freshness ────────────────────────────────────────────── # ── 4. Freshness ──────────────────────────────────────────────
# Most recent validated test date
most_recent_test = ( most_recent_test = (
db.query(func.max(Test.red_validated_at)) db.query(func.max(Test.red_validated_at))
.filter( .filter(
@@ -169,7 +381,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
# ── 5. Platform diversity ───────────────────────────────────── # ── 5. Platform diversity ─────────────────────────────────────
available_platforms = technique.platforms or [] available_platforms = technique.platforms or []
total_platforms = len(available_platforms) if available_platforms else 3 # default 3 total_platforms = len(available_platforms) if available_platforms else 3
tested_platforms = set() tested_platforms = set()
for t in validated_tests: for t in validated_tests:
@@ -208,30 +420,19 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
def calculate_tactic_score(tactic: str, db: Session) -> dict: def calculate_tactic_score(tactic: str, db: Session) -> dict:
"""Calculate average score for all techniques in a tactic.""" """Calculate average score for all techniques in a tactic."""
techniques = ( scores_map = bulk_technique_scores(db)
db.query(Technique)
.filter(Technique.tactic.ilike(f"%{tactic}%"))
.all()
)
if not techniques: matching = [
return { v["total_score"]
"tactic": tactic, for v in scores_map.values()
"average_score": 0, if v.get("tactic") and tactic.lower() in v["tactic"].lower()
"techniques_count": 0, ]
"techniques_scored": 0,
}
scores = []
for tech in techniques:
result = calculate_technique_score(tech, db)
scores.append(result["total_score"])
return { return {
"tactic": tactic, "tactic": tactic,
"average_score": round(sum(scores) / len(scores), 1) if scores else 0, "average_score": round(sum(matching) / len(matching), 1) if matching else 0,
"techniques_count": len(techniques), "techniques_count": len(matching),
"techniques_scored": len([s for s in scores if s > 0]), "techniques_scored": len([s for s in matching if s > 0]),
} }
@@ -244,14 +445,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
if not actor: if not actor:
return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0} return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0}
# Get all techniques used by this actor
actor_techniques = ( actor_techniques = (
db.query(ThreatActorTechnique) db.query(ThreatActorTechnique)
.filter(ThreatActorTechnique.threat_actor_id == actor.id) .filter(ThreatActorTechnique.threat_actor_id == actor.id)
.all() .all()
) )
technique_ids = [at.technique_id for at in actor_techniques] technique_ids = {at.technique_id for at in actor_techniques}
if not technique_ids: if not technique_ids:
return { return {
"actor_id": str(actor.id), "actor_id": str(actor.id),
@@ -262,23 +462,21 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
"techniques_detail": [], "techniques_detail": [],
} }
techniques = ( scores_map = bulk_technique_scores(db)
db.query(Technique)
.filter(Technique.id.in_(technique_ids))
.all()
)
scores = [] scores = []
details = [] details = []
for tech in techniques: for tid in technique_ids:
result = calculate_technique_score(tech, db) entry = scores_map.get(tid)
score = result["total_score"] if not entry:
continue
score = entry["total_score"]
scores.append(score) scores.append(score)
details.append({ details.append({
"mitre_id": tech.mitre_id, "mitre_id": entry["mitre_id"],
"name": tech.name, "name": entry.get("name", ""),
"score": score, "score": score,
"breakdown": result["breakdown"], "breakdown": entry["breakdown"],
}) })
avg_score = round(sum(scores) / len(scores), 1) if scores else 0 avg_score = round(sum(scores) / len(scores), 1) if scores else 0
@@ -287,7 +485,7 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
"actor_id": str(actor.id), "actor_id": str(actor.id),
"actor_name": actor.name, "actor_name": actor.name,
"total_score": avg_score, "total_score": avg_score,
"techniques_count": len(techniques), "techniques_count": len(technique_ids),
"techniques_covered": len([s for s in scores if s > 50]), "techniques_covered": len([s for s in scores if s > 50]),
"techniques_detail": details, "techniques_detail": details,
} }
@@ -297,10 +495,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
def calculate_organization_score(db: Session) -> dict: def calculate_organization_score(db: Session) -> dict:
"""Calculate the overall organization security score.""" """Calculate the overall organization security score.
# All techniques
all_techniques = db.query(Technique).all() Uses ``bulk_technique_scores`` to compute all technique scores in
total_count = len(all_techniques) 5 aggregated queries instead of N*5.
"""
scores_map = bulk_technique_scores(db)
total_count = len(scores_map)
if total_count == 0: if total_count == 0:
return { return {
@@ -313,27 +514,16 @@ def calculate_organization_score(db: Session) -> dict:
"techniques_total": 0, "techniques_total": 0,
} }
# Calculate scores for all techniques (with caching for performance) all_scores = [v["total_score"] for v in scores_map.values()]
all_scores = []
evaluated_count = 0
for tech in all_techniques:
result = calculate_technique_score(tech, db)
score = result["total_score"]
all_scores.append(score)
if score > 0:
evaluated_count += 1
# Total coverage: average of all evaluated techniques
evaluated_scores = [s for s in all_scores if s > 0] evaluated_scores = [s for s in all_scores if s > 0]
evaluated_count = len(evaluated_scores)
total_coverage = ( total_coverage = (
round(sum(evaluated_scores) / len(evaluated_scores), 1) round(sum(evaluated_scores) / len(evaluated_scores), 1)
if evaluated_scores if evaluated_scores else 0
else 0
) )
# Critical coverage: techniques with high-severity templates # Critical coverage: techniques with high/critical severity templates
# (simplified: techniques that have tests are "critical")
from app.models.test_template import TestTemplate from app.models.test_template import TestTemplate
critical_mitre_ids = set( critical_mitre_ids = set(
@@ -344,38 +534,35 @@ def calculate_organization_score(db: Session) -> dict:
.all() .all()
) )
critical_techniques = [ critical_scores = [
t for t in all_techniques if t.mitre_id in critical_mitre_ids v["total_score"]
for v in scores_map.values()
if v.get("mitre_id") in critical_mitre_ids
] ]
if critical_techniques: critical_coverage = (
critical_scores = [] round(sum(critical_scores) / len(critical_scores), 1)
for tech in critical_techniques: if critical_scores else 0
result = calculate_technique_score(tech, db) )
critical_scores.append(result["total_score"])
critical_coverage = round(sum(critical_scores) / len(critical_scores), 1)
else:
critical_coverage = 0
# Detection maturity: based on detection rule coverage # Detection maturity (2 scalar queries — already efficient)
total_rules = ( total_rules = (
db.query(func.count(DetectionRule.id)) db.query(func.count(DetectionRule.id))
.filter(DetectionRule.is_active == True) .filter(DetectionRule.is_active == True) # noqa: E712
.scalar() .scalar()
) or 0 ) or 0
triggered_total = ( triggered_total = (
db.query(func.count(TestDetectionResult.id)) db.query(func.count(TestDetectionResult.id))
.filter(TestDetectionResult.triggered == True) .filter(TestDetectionResult.triggered == True) # noqa: E712
.scalar() .scalar()
) or 0 ) or 0
detection_maturity = ( detection_maturity = (
round((triggered_total / total_rules) * 100, 1) round((triggered_total / total_rules) * 100, 1)
if total_rules > 0 if total_rules > 0 else 0
else 0
) )
detection_maturity = min(detection_maturity, 100) detection_maturity = min(detection_maturity, 100)
# Response readiness: based on remediation completion # Response readiness (2 scalar queries — already efficient)
remediation_total = ( remediation_total = (
db.query(func.count(Test.id)) db.query(func.count(Test.id))
.filter(Test.remediation_status.isnot(None)) .filter(Test.remediation_status.isnot(None))
@@ -389,11 +576,9 @@ def calculate_organization_score(db: Session) -> dict:
response_readiness = ( response_readiness = (
round((remediation_completed / remediation_total) * 100, 1) round((remediation_completed / remediation_total) * 100, 1)
if remediation_total > 0 if remediation_total > 0 else 0
else 0
) )
# Overall score: weighted average of sub-scores
overall = round( overall = round(
total_coverage * 0.4 total_coverage * 0.4
+ critical_coverage * 0.25 + critical_coverage * 0.25