perf(scoring): eliminate N+1 in organization score calculation
- Add bulk_technique_scores() that pre-fetches all scoring data in 5 aggregated GROUP BY queries instead of N*5 per-technique queries - Rewrite calculate_organization_score to use bulk data (N*5+5 queries -> 10 fixed queries) - Rewrite calculate_tactic_score and calculate_actor_coverage_score to use bulk data - Preserve calculate_technique_score single-technique API for router-level calls
This commit is contained in:
@@ -2,12 +2,16 @@
|
||||
|
||||
Uses configurable weights from Settings to compute coverage scores with
|
||||
detailed breakdowns.
|
||||
|
||||
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
|
||||
fixed number of aggregated queries so that organisation-wide calculations
|
||||
never produce N+1 traffic.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import case, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
@@ -20,7 +24,219 @@ from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.enums import TestState, TestResult
|
||||
|
||||
|
||||
# ── Technique-level scoring ──────────────────────────────────────────
|
||||
# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
|
||||
|
||||
|
||||
def _build_empty_stats():
|
||||
return {
|
||||
"validated": 0,
|
||||
"detected": 0,
|
||||
"platforms": set(),
|
||||
"latest_validated_at": None,
|
||||
}
|
||||
|
||||
|
||||
def bulk_technique_scores(db: Session) -> dict:
|
||||
"""Pre-fetch all scoring data and compute per-technique scores in memory.
|
||||
|
||||
Executes exactly 5 queries regardless of technique count:
|
||||
Q1 — Test aggregates per technique (validated / detected / platforms / freshness)
|
||||
Q2 — Detection rules per mitre_id
|
||||
Q3 — Triggered rules per mitre_id
|
||||
Q4 — D3FEND mapping counts per technique
|
||||
Q5 — All techniques
|
||||
|
||||
Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``.
|
||||
"""
|
||||
w_tests = settings.SCORING_WEIGHT_TESTS
|
||||
w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
|
||||
w_d3fend = settings.SCORING_WEIGHT_D3FEND
|
||||
w_freshness = settings.SCORING_WEIGHT_FRESHNESS
|
||||
w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
||||
|
||||
# Q1: test stats grouped by technique_id
|
||||
test_rows = (
|
||||
db.query(
|
||||
Test.technique_id,
|
||||
func.count(Test.id).label("validated_count"),
|
||||
func.count(
|
||||
case((Test.detection_result == TestResult.detected, Test.id))
|
||||
).label("detected_count"),
|
||||
func.max(Test.red_validated_at).label("latest_validated_at"),
|
||||
func.count(func.distinct(Test.platform)).label("platform_count"),
|
||||
)
|
||||
.filter(Test.state == TestState.validated)
|
||||
.group_by(Test.technique_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
test_stats: dict = {}
|
||||
for row in test_rows:
|
||||
test_stats[row.technique_id] = {
|
||||
"validated": row.validated_count,
|
||||
"detected": row.detected_count,
|
||||
"latest_validated_at": row.latest_validated_at,
|
||||
"platform_count": row.platform_count,
|
||||
}
|
||||
|
||||
# Q2: active detection rules per mitre_id
|
||||
rule_rows = (
|
||||
db.query(
|
||||
DetectionRule.mitre_technique_id,
|
||||
func.count(DetectionRule.id).label("total"),
|
||||
)
|
||||
.filter(DetectionRule.is_active == True) # noqa: E712
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
.all()
|
||||
)
|
||||
rules_by_mitre: dict[str, int] = {r.mitre_technique_id: r.total for r in rule_rows}
|
||||
|
||||
# Q3: triggered rules per mitre_id
|
||||
triggered_rows = (
|
||||
db.query(
|
||||
DetectionRule.mitre_technique_id,
|
||||
func.count(TestDetectionResult.id).label("triggered"),
|
||||
)
|
||||
.join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
|
||||
.filter(TestDetectionResult.triggered == True) # noqa: E712
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
.all()
|
||||
)
|
||||
triggered_by_mitre: dict[str, int] = {
|
||||
r.mitre_technique_id: r.triggered for r in triggered_rows
|
||||
}
|
||||
|
||||
# Q4: D3FEND mapping counts per technique
|
||||
d3fend_rows = (
|
||||
db.query(
|
||||
DefensiveTechniqueMapping.attack_technique_id,
|
||||
func.count(DefensiveTechniqueMapping.id).label("total"),
|
||||
)
|
||||
.group_by(DefensiveTechniqueMapping.attack_technique_id)
|
||||
.all()
|
||||
)
|
||||
d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows}
|
||||
|
||||
# Q5: all techniques
|
||||
techniques = db.query(Technique).all()
|
||||
|
||||
now = datetime.utcnow()
|
||||
results: dict = {}
|
||||
|
||||
for tech in techniques:
|
||||
ts = test_stats.get(tech.id, {})
|
||||
validated = ts.get("validated", 0)
|
||||
detected = ts.get("detected", 0)
|
||||
latest_at = ts.get("latest_validated_at")
|
||||
plat_count = ts.get("platform_count", 0)
|
||||
|
||||
breakdown = {}
|
||||
|
||||
# 1. Tests validated with detection
|
||||
if validated > 0:
|
||||
test_ratio = detected / validated
|
||||
test_score = round(test_ratio * w_tests, 1)
|
||||
else:
|
||||
test_ratio = 0
|
||||
test_score = 0
|
||||
breakdown["tests_validated"] = {
|
||||
"score": test_score,
|
||||
"max": w_tests,
|
||||
"detail": (
|
||||
f"{detected}/{validated} tests detected"
|
||||
if validated else "No validated tests"
|
||||
),
|
||||
}
|
||||
|
||||
# 2. Detection rules
|
||||
total_rules = rules_by_mitre.get(tech.mitre_id, 0)
|
||||
triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0)
|
||||
if total_rules > 0:
|
||||
detection_ratio = min(triggered_rules / total_rules, 1.0)
|
||||
detection_score = round(detection_ratio * w_detection, 1)
|
||||
else:
|
||||
detection_ratio = 0
|
||||
detection_score = 0
|
||||
breakdown["detection_rules"] = {
|
||||
"score": detection_score,
|
||||
"max": w_detection,
|
||||
"detail": (
|
||||
f"{triggered_rules}/{total_rules} rules triggered"
|
||||
if total_rules > 0 else "No detection rules available"
|
||||
),
|
||||
}
|
||||
|
||||
# 3. D3FEND coverage
|
||||
total_cm = d3fend_by_tech.get(tech.id, 0)
|
||||
if total_cm > 0 and detected > 0:
|
||||
verified_cm = min(detected, total_cm)
|
||||
d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1)
|
||||
else:
|
||||
verified_cm = 0
|
||||
d3fend_score = 0
|
||||
breakdown["d3fend_coverage"] = {
|
||||
"score": d3fend_score,
|
||||
"max": w_d3fend,
|
||||
"detail": (
|
||||
f"{verified_cm}/{total_cm} countermeasures"
|
||||
if total_cm > 0 else "No D3FEND mappings"
|
||||
),
|
||||
}
|
||||
|
||||
# 4. Freshness
|
||||
if latest_at:
|
||||
days_ago = (now - latest_at).days
|
||||
if days_ago < 90:
|
||||
freshness_pct = 1.0
|
||||
elif days_ago < 180:
|
||||
freshness_pct = 0.5
|
||||
else:
|
||||
freshness_pct = 0.0
|
||||
freshness_score = round(freshness_pct * w_freshness, 1)
|
||||
freshness_detail = f"Last test {days_ago} days ago"
|
||||
else:
|
||||
freshness_score = 0
|
||||
freshness_detail = "No validated tests"
|
||||
breakdown["freshness"] = {
|
||||
"score": freshness_score,
|
||||
"max": w_freshness,
|
||||
"detail": freshness_detail,
|
||||
}
|
||||
|
||||
# 5. Platform diversity
|
||||
available = tech.platforms or []
|
||||
total_platforms = len(available) if available else 3
|
||||
if total_platforms > 0 and plat_count > 0:
|
||||
diversity_score = round(
|
||||
min(plat_count / total_platforms, 1.0) * w_diversity, 1,
|
||||
)
|
||||
else:
|
||||
diversity_score = 0
|
||||
breakdown["platform_diversity"] = {
|
||||
"score": diversity_score,
|
||||
"max": w_diversity,
|
||||
"detail": (
|
||||
f"{plat_count}/{total_platforms} platforms covered"
|
||||
if plat_count > 0 else "No platforms tested"
|
||||
),
|
||||
}
|
||||
|
||||
total = min(
|
||||
test_score + detection_score + d3fend_score
|
||||
+ freshness_score + diversity_score,
|
||||
100,
|
||||
)
|
||||
results[tech.id] = {
|
||||
"total_score": round(total, 1),
|
||||
"breakdown": breakdown,
|
||||
"mitre_id": tech.mitre_id,
|
||||
"tactic": tech.tactic,
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ── Technique-level scoring (single technique — preserved API) ────────
|
||||
|
||||
|
||||
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||
@@ -73,7 +289,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||
db.query(func.count(DetectionRule.id))
|
||||
.filter(
|
||||
DetectionRule.mitre_technique_id == technique.mitre_id,
|
||||
DetectionRule.is_active == True,
|
||||
DetectionRule.is_active == True, # noqa: E712
|
||||
)
|
||||
.scalar()
|
||||
) or 0
|
||||
@@ -88,7 +304,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||
)
|
||||
.filter(
|
||||
DetectionRule.mitre_technique_id == technique.mitre_id,
|
||||
TestDetectionResult.triggered == True,
|
||||
TestDetectionResult.triggered == True, # noqa: E712
|
||||
)
|
||||
.scalar()
|
||||
) or 0
|
||||
@@ -114,11 +330,8 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Consider a countermeasure "verified" if we have validated tests
|
||||
# with detection for the technique (simplified heuristic)
|
||||
verified_countermeasures = 0
|
||||
if total_countermeasures > 0 and len(detected_tests) > 0:
|
||||
# Rough heuristic: each detected test validates ~1 countermeasure
|
||||
verified_countermeasures = min(len(detected_tests), total_countermeasures)
|
||||
d3fend_ratio = verified_countermeasures / total_countermeasures
|
||||
d3fend_score = round(d3fend_ratio * w_d3fend, 1)
|
||||
@@ -135,7 +348,6 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||
}
|
||||
|
||||
# ── 4. Freshness ──────────────────────────────────────────────
|
||||
# Most recent validated test date
|
||||
most_recent_test = (
|
||||
db.query(func.max(Test.red_validated_at))
|
||||
.filter(
|
||||
@@ -169,7 +381,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||
|
||||
# ── 5. Platform diversity ─────────────────────────────────────
|
||||
available_platforms = technique.platforms or []
|
||||
total_platforms = len(available_platforms) if available_platforms else 3 # default 3
|
||||
total_platforms = len(available_platforms) if available_platforms else 3
|
||||
|
||||
tested_platforms = set()
|
||||
for t in validated_tests:
|
||||
@@ -208,30 +420,19 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||
|
||||
def calculate_tactic_score(tactic: str, db: Session) -> dict:
|
||||
"""Calculate average score for all techniques in a tactic."""
|
||||
techniques = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.tactic.ilike(f"%{tactic}%"))
|
||||
.all()
|
||||
)
|
||||
scores_map = bulk_technique_scores(db)
|
||||
|
||||
if not techniques:
|
||||
return {
|
||||
"tactic": tactic,
|
||||
"average_score": 0,
|
||||
"techniques_count": 0,
|
||||
"techniques_scored": 0,
|
||||
}
|
||||
|
||||
scores = []
|
||||
for tech in techniques:
|
||||
result = calculate_technique_score(tech, db)
|
||||
scores.append(result["total_score"])
|
||||
matching = [
|
||||
v["total_score"]
|
||||
for v in scores_map.values()
|
||||
if v.get("tactic") and tactic.lower() in v["tactic"].lower()
|
||||
]
|
||||
|
||||
return {
|
||||
"tactic": tactic,
|
||||
"average_score": round(sum(scores) / len(scores), 1) if scores else 0,
|
||||
"techniques_count": len(techniques),
|
||||
"techniques_scored": len([s for s in scores if s > 0]),
|
||||
"average_score": round(sum(matching) / len(matching), 1) if matching else 0,
|
||||
"techniques_count": len(matching),
|
||||
"techniques_scored": len([s for s in matching if s > 0]),
|
||||
}
|
||||
|
||||
|
||||
@@ -244,14 +445,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
|
||||
if not actor:
|
||||
return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0}
|
||||
|
||||
# Get all techniques used by this actor
|
||||
actor_techniques = (
|
||||
db.query(ThreatActorTechnique)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
technique_ids = [at.technique_id for at in actor_techniques]
|
||||
technique_ids = {at.technique_id for at in actor_techniques}
|
||||
if not technique_ids:
|
||||
return {
|
||||
"actor_id": str(actor.id),
|
||||
@@ -262,23 +462,21 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
|
||||
"techniques_detail": [],
|
||||
}
|
||||
|
||||
techniques = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.id.in_(technique_ids))
|
||||
.all()
|
||||
)
|
||||
scores_map = bulk_technique_scores(db)
|
||||
|
||||
scores = []
|
||||
details = []
|
||||
for tech in techniques:
|
||||
result = calculate_technique_score(tech, db)
|
||||
score = result["total_score"]
|
||||
for tid in technique_ids:
|
||||
entry = scores_map.get(tid)
|
||||
if not entry:
|
||||
continue
|
||||
score = entry["total_score"]
|
||||
scores.append(score)
|
||||
details.append({
|
||||
"mitre_id": tech.mitre_id,
|
||||
"name": tech.name,
|
||||
"mitre_id": entry["mitre_id"],
|
||||
"name": entry.get("name", ""),
|
||||
"score": score,
|
||||
"breakdown": result["breakdown"],
|
||||
"breakdown": entry["breakdown"],
|
||||
})
|
||||
|
||||
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
|
||||
@@ -287,7 +485,7 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
|
||||
"actor_id": str(actor.id),
|
||||
"actor_name": actor.name,
|
||||
"total_score": avg_score,
|
||||
"techniques_count": len(techniques),
|
||||
"techniques_count": len(technique_ids),
|
||||
"techniques_covered": len([s for s in scores if s > 50]),
|
||||
"techniques_detail": details,
|
||||
}
|
||||
@@ -297,10 +495,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
|
||||
|
||||
|
||||
def calculate_organization_score(db: Session) -> dict:
|
||||
"""Calculate the overall organization security score."""
|
||||
# All techniques
|
||||
all_techniques = db.query(Technique).all()
|
||||
total_count = len(all_techniques)
|
||||
"""Calculate the overall organization security score.
|
||||
|
||||
Uses ``bulk_technique_scores`` to compute all technique scores in
|
||||
5 aggregated queries instead of N*5.
|
||||
"""
|
||||
scores_map = bulk_technique_scores(db)
|
||||
total_count = len(scores_map)
|
||||
|
||||
if total_count == 0:
|
||||
return {
|
||||
@@ -313,27 +514,16 @@ def calculate_organization_score(db: Session) -> dict:
|
||||
"techniques_total": 0,
|
||||
}
|
||||
|
||||
# Calculate scores for all techniques (with caching for performance)
|
||||
all_scores = []
|
||||
evaluated_count = 0
|
||||
|
||||
for tech in all_techniques:
|
||||
result = calculate_technique_score(tech, db)
|
||||
score = result["total_score"]
|
||||
all_scores.append(score)
|
||||
if score > 0:
|
||||
evaluated_count += 1
|
||||
|
||||
# Total coverage: average of all evaluated techniques
|
||||
all_scores = [v["total_score"] for v in scores_map.values()]
|
||||
evaluated_scores = [s for s in all_scores if s > 0]
|
||||
evaluated_count = len(evaluated_scores)
|
||||
|
||||
total_coverage = (
|
||||
round(sum(evaluated_scores) / len(evaluated_scores), 1)
|
||||
if evaluated_scores
|
||||
else 0
|
||||
if evaluated_scores else 0
|
||||
)
|
||||
|
||||
# Critical coverage: techniques with high-severity templates
|
||||
# (simplified: techniques that have tests are "critical")
|
||||
# Critical coverage: techniques with high/critical severity templates
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
critical_mitre_ids = set(
|
||||
@@ -344,38 +534,35 @@ def calculate_organization_score(db: Session) -> dict:
|
||||
.all()
|
||||
)
|
||||
|
||||
critical_techniques = [
|
||||
t for t in all_techniques if t.mitre_id in critical_mitre_ids
|
||||
critical_scores = [
|
||||
v["total_score"]
|
||||
for v in scores_map.values()
|
||||
if v.get("mitre_id") in critical_mitre_ids
|
||||
]
|
||||
if critical_techniques:
|
||||
critical_scores = []
|
||||
for tech in critical_techniques:
|
||||
result = calculate_technique_score(tech, db)
|
||||
critical_scores.append(result["total_score"])
|
||||
critical_coverage = round(sum(critical_scores) / len(critical_scores), 1)
|
||||
else:
|
||||
critical_coverage = 0
|
||||
critical_coverage = (
|
||||
round(sum(critical_scores) / len(critical_scores), 1)
|
||||
if critical_scores else 0
|
||||
)
|
||||
|
||||
# Detection maturity: based on detection rule coverage
|
||||
# Detection maturity (2 scalar queries — already efficient)
|
||||
total_rules = (
|
||||
db.query(func.count(DetectionRule.id))
|
||||
.filter(DetectionRule.is_active == True)
|
||||
.filter(DetectionRule.is_active == True) # noqa: E712
|
||||
.scalar()
|
||||
) or 0
|
||||
triggered_total = (
|
||||
db.query(func.count(TestDetectionResult.id))
|
||||
.filter(TestDetectionResult.triggered == True)
|
||||
.filter(TestDetectionResult.triggered == True) # noqa: E712
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
detection_maturity = (
|
||||
round((triggered_total / total_rules) * 100, 1)
|
||||
if total_rules > 0
|
||||
else 0
|
||||
if total_rules > 0 else 0
|
||||
)
|
||||
detection_maturity = min(detection_maturity, 100)
|
||||
|
||||
# Response readiness: based on remediation completion
|
||||
# Response readiness (2 scalar queries — already efficient)
|
||||
remediation_total = (
|
||||
db.query(func.count(Test.id))
|
||||
.filter(Test.remediation_status.isnot(None))
|
||||
@@ -389,11 +576,9 @@ def calculate_organization_score(db: Session) -> dict:
|
||||
|
||||
response_readiness = (
|
||||
round((remediation_completed / remediation_total) * 100, 1)
|
||||
if remediation_total > 0
|
||||
else 0
|
||||
if remediation_total > 0 else 0
|
||||
)
|
||||
|
||||
# Overall score: weighted average of sub-scores
|
||||
overall = round(
|
||||
total_coverage * 0.4
|
||||
+ critical_coverage * 0.25
|
||||
|
||||
Reference in New Issue
Block a user