Files
Aegis/backend/app/services/scoring_service.py

642 lines
21 KiB
Python

"""Scoring service — granular 0-100 scoring for techniques, tactics, actors, and org.
Reads configurable weights from the ``scoring_config`` table (falling
back to env-var defaults) to compute coverage scores with detailed
breakdowns.
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
fixed number of aggregated queries so that organisation-wide calculations
never produce N+1 traffic.
"""
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from app.models.technique import Technique
from app.models.test import Test
from app.models.detection_rule import DetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.defensive_technique import DefensiveTechniqueMapping
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.enums import TestState, TestResult
from app.services.scoring_config_service import get_scoring_weights
# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
def bulk_technique_scores(db: Session) -> dict:
    """Pre-fetch all scoring data and compute per-technique scores in memory.

    Issues five scoring queries regardless of technique count (the weight
    lookup done by ``get_scoring_weights`` is separate):

    Q1 — Test aggregates per technique (validated / detected / platforms / freshness)
    Q2 — Active detection rules per mitre_id
    Q3 — Distinct active detection rules that triggered, per mitre_id
    Q4 — D3FEND mapping counts per technique
    Q5 — All techniques

    Args:
        db: SQLAlchemy session used for every query.

    Returns:
        ``{technique_id: {"total_score": float, "breakdown": dict,
        "mitre_id": ..., "tactic": ..., "name": str}}``.  ``breakdown`` has
        one ``{"score", "max", "detail"}`` entry for each of
        ``tests_validated``, ``detection_rules``, ``d3fend_coverage``,
        ``freshness`` and ``platform_diversity``.
    """
    w = get_scoring_weights(db)
    w_tests = w.tests
    w_detection = w.detection_rules
    w_d3fend = w.d3fend
    w_freshness = w.freshness
    w_diversity = w.platform_diversity

    # Q1: validated-test aggregates grouped by technique_id.
    # NOTE(review): platform_count is a case-sensitive DISTINCT in SQL, while
    # calculate_technique_score lowercases platforms before de-duplicating;
    # the two paths can disagree on mixed-case platform values.
    test_rows = (
        db.query(
            Test.technique_id,
            func.count(Test.id).label("validated_count"),
            func.count(
                case((Test.detection_result == TestResult.detected, Test.id))
            ).label("detected_count"),
            func.max(Test.red_validated_at).label("latest_validated_at"),
            func.count(func.distinct(Test.platform)).label("platform_count"),
        )
        .filter(Test.state == TestState.validated)
        .group_by(Test.technique_id)
        .all()
    )
    test_stats: dict = {
        row.technique_id: {
            "validated": row.validated_count,
            "detected": row.detected_count,
            "latest_validated_at": row.latest_validated_at,
            "platform_count": row.platform_count,
        }
        for row in test_rows
    }

    # Q2: active detection rules per mitre_id.
    rule_rows = (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(DetectionRule.id).label("total"),
        )
        .filter(DetectionRule.is_active == True)  # noqa: E712
        .group_by(DetectionRule.mitre_technique_id)
        .all()
    )
    rules_by_mitre: dict[str, int] = {
        r.mitre_technique_id: r.total for r in rule_rows
    }

    # Q3: distinct *active* rules with at least one triggered result, per
    # mitre_id.  Counting result rows instead would let a single rule that
    # fired on several tests be reported as several "rules triggered" and
    # push the numerator above the Q2 total.
    triggered_rows = (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(func.distinct(DetectionRule.id)).label("triggered"),
        )
        .join(
            TestDetectionResult,
            TestDetectionResult.detection_rule_id == DetectionRule.id,
        )
        .filter(
            DetectionRule.is_active == True,  # noqa: E712
            TestDetectionResult.triggered == True,  # noqa: E712
        )
        .group_by(DetectionRule.mitre_technique_id)
        .all()
    )
    triggered_by_mitre: dict[str, int] = {
        r.mitre_technique_id: r.triggered for r in triggered_rows
    }

    # Q4: D3FEND mapping counts per technique.
    d3fend_rows = (
        db.query(
            DefensiveTechniqueMapping.attack_technique_id,
            func.count(DefensiveTechniqueMapping.id).label("total"),
        )
        .group_by(DefensiveTechniqueMapping.attack_technique_id)
        .all()
    )
    d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows}

    # Q5: all techniques.
    techniques = db.query(Technique).all()

    now = datetime.utcnow()
    results: dict = {}
    for tech in techniques:
        ts = test_stats.get(tech.id, {})
        validated = ts.get("validated", 0)
        detected = ts.get("detected", 0)
        latest_at = ts.get("latest_validated_at")
        plat_count = ts.get("platform_count", 0)
        breakdown = {}

        # 1. Share of validated tests whose result was "detected".
        if validated > 0:
            test_score = round((detected / validated) * w_tests, 1)
        else:
            test_score = 0
        breakdown["tests_validated"] = {
            "score": test_score,
            "max": w_tests,
            "detail": (
                f"{detected}/{validated} tests detected"
                if validated else "No validated tests"
            ),
        }

        # 2. Share of active detection rules that triggered at least once.
        total_rules = rules_by_mitre.get(tech.mitre_id, 0)
        triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0)
        if total_rules > 0:
            detection_score = round(
                min(triggered_rules / total_rules, 1.0) * w_detection, 1,
            )
        else:
            detection_score = 0
        breakdown["detection_rules"] = {
            "score": detection_score,
            "max": w_detection,
            "detail": (
                f"{triggered_rules}/{total_rules} rules triggered"
                if total_rules > 0 else "No detection rules available"
            ),
        }

        # 3. D3FEND coverage: detected tests, capped at the mapping count,
        #    stand in for "verified" countermeasures.
        total_cm = d3fend_by_tech.get(tech.id, 0)
        if total_cm > 0 and detected > 0:
            verified_cm = min(detected, total_cm)
            d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1)
        else:
            verified_cm = 0
            d3fend_score = 0
        breakdown["d3fend_coverage"] = {
            "score": d3fend_score,
            "max": w_d3fend,
            "detail": (
                f"{verified_cm}/{total_cm} countermeasures"
                if total_cm > 0 else "No D3FEND mappings"
            ),
        }

        # 4. Freshness: full weight under 90 days, half under 180, else none.
        if latest_at:
            days_ago = (now - latest_at).days
            if days_ago < 90:
                freshness_pct = 1.0
            elif days_ago < 180:
                freshness_pct = 0.5
            else:
                freshness_pct = 0.0
            freshness_score = round(freshness_pct * w_freshness, 1)
            freshness_detail = f"Last test {days_ago} days ago"
        else:
            freshness_score = 0
            freshness_detail = "No validated tests"
        breakdown["freshness"] = {
            "score": freshness_score,
            "max": w_freshness,
            "detail": freshness_detail,
        }

        # 5. Platform diversity: tested platforms over the technique's
        #    declared platforms (denominator of 3 when none are declared).
        available = tech.platforms or []
        total_platforms = len(available) if available else 3
        if total_platforms > 0 and plat_count > 0:
            diversity_score = round(
                min(plat_count / total_platforms, 1.0) * w_diversity, 1,
            )
        else:
            diversity_score = 0
        breakdown["platform_diversity"] = {
            "score": diversity_score,
            "max": w_diversity,
            "detail": (
                f"{plat_count}/{total_platforms} platforms covered"
                if plat_count > 0 else "No platforms tested"
            ),
        }

        # Components are summed and capped at 100.
        total = min(
            test_score + detection_score + d3fend_score
            + freshness_score + diversity_score,
            100,
        )
        results[tech.id] = {
            "total_score": round(total, 1),
            "breakdown": breakdown,
            "mitre_id": tech.mitre_id,
            "tactic": tech.tactic,
            # calculate_actor_coverage_score reads entry.get("name", "");
            # without this key every technique name it reports is empty.
            # NOTE(review): assumes Technique exposes ``name`` — getattr
            # keeps this safe if it does not; confirm against the model.
            "name": getattr(tech, "name", None) or "",
        }
    return results
# ── Technique-level scoring (single technique — preserved API) ────────
def calculate_technique_score(technique: Technique, db: Session) -> dict:
    """Calculate a 0-100 score for a technique with detailed breakdown.

    Weights are read via ``get_scoring_weights`` (``scoring_config`` table
    or env defaults).

    Args:
        technique: Technique to score; ``id``, ``mitre_id`` and
            ``platforms`` are read.
        db: SQLAlchemy session used for the lookups.

    Returns:
        ``{"total_score": float, "breakdown": dict}`` where ``breakdown``
        has one ``{"score", "max", "detail"}`` entry for each of
        ``tests_validated``, ``detection_rules``, ``d3fend_coverage``,
        ``freshness`` and ``platform_diversity``.
    """
    w = get_scoring_weights(db)
    w_tests = w.tests
    w_detection = w.detection_rules
    w_d3fend = w.d3fend
    w_freshness = w.freshness
    w_diversity = w.platform_diversity
    breakdown = {}

    # ── 1. Tests validated with detection ──────────────────────────
    # Only validated tests feed any component, so filter in SQL rather
    # than loading every test for the technique and discarding most.
    validated_tests = (
        db.query(Test)
        .filter(
            Test.technique_id == technique.id,
            Test.state == TestState.validated,
        )
        .all()
    )
    detected_count = sum(
        1 for t in validated_tests
        if t.detection_result == TestResult.detected
    )
    validated_count = len(validated_tests)
    if validated_count:
        test_score = round((detected_count / validated_count) * w_tests, 1)
    else:
        test_score = 0
    breakdown["tests_validated"] = {
        "score": test_score,
        "max": w_tests,
        "detail": (
            f"{detected_count}/{validated_count} tests detected"
            if validated_count
            else "No validated tests"
        ),
    }

    # ── 2. Detection rules coverage ───────────────────────────────
    total_rules = (
        db.query(func.count(DetectionRule.id))
        .filter(
            DetectionRule.mitre_technique_id == technique.mitre_id,
            DetectionRule.is_active == True,  # noqa: E712
        )
        .scalar()
    ) or 0
    triggered_rules = 0
    if total_rules > 0:
        # Count distinct *active* rules with at least one triggered result.
        # Counting result rows would report one rule firing on several
        # tests as several "rules triggered" and exceed total_rules.
        triggered_rules = (
            db.query(func.count(func.distinct(DetectionRule.id)))
            .join(
                TestDetectionResult,
                TestDetectionResult.detection_rule_id == DetectionRule.id,
            )
            .filter(
                DetectionRule.mitre_technique_id == technique.mitre_id,
                DetectionRule.is_active == True,  # noqa: E712
                TestDetectionResult.triggered == True,  # noqa: E712
            )
            .scalar()
        ) or 0
        detection_score = round(
            min(triggered_rules / total_rules, 1.0) * w_detection, 1,
        )
    else:
        detection_score = 0
    breakdown["detection_rules"] = {
        "score": detection_score,
        "max": w_detection,
        "detail": (
            f"{triggered_rules}/{total_rules} rules triggered"
            if total_rules > 0
            else "No detection rules available"
        ),
    }

    # ── 3. D3FEND coverage ────────────────────────────────────────
    total_countermeasures = (
        db.query(func.count(DefensiveTechniqueMapping.id))
        .filter(DefensiveTechniqueMapping.attack_technique_id == technique.id)
        .scalar()
    ) or 0
    verified_countermeasures = 0
    if total_countermeasures > 0 and detected_count > 0:
        # Detected tests, capped at the mapping count, stand in for
        # "verified" countermeasures.
        verified_countermeasures = min(detected_count, total_countermeasures)
        d3fend_score = round(
            (verified_countermeasures / total_countermeasures) * w_d3fend, 1,
        )
    else:
        d3fend_score = 0
    breakdown["d3fend_coverage"] = {
        "score": d3fend_score,
        "max": w_d3fend,
        "detail": (
            f"{verified_countermeasures}/{total_countermeasures} countermeasures"
            if total_countermeasures > 0
            else "No D3FEND mappings"
        ),
    }

    # ── 4. Freshness ──────────────────────────────────────────────
    # The validated rows are already loaded, so take the latest validation
    # timestamp from them instead of issuing a separate MAX() query.
    validation_times = [
        t.red_validated_at
        for t in validated_tests
        if t.red_validated_at is not None
    ]
    most_recent_test = max(validation_times) if validation_times else None
    now = datetime.utcnow()
    if most_recent_test:
        days_ago = (now - most_recent_test).days
        # Full weight under 90 days, half under 180, nothing after that.
        if days_ago < 90:
            freshness_pct = 1.0
        elif days_ago < 180:
            freshness_pct = 0.5
        else:
            freshness_pct = 0.0
        freshness_score = round(freshness_pct * w_freshness, 1)
        freshness_detail = f"Last test {days_ago} days ago"
    else:
        freshness_score = 0
        freshness_detail = "No validated tests"
    breakdown["freshness"] = {
        "score": freshness_score,
        "max": w_freshness,
        "detail": freshness_detail,
    }

    # ── 5. Platform diversity ─────────────────────────────────────
    available_platforms = technique.platforms or []
    # Denominator of 3 when the technique declares no platforms.
    total_platforms = len(available_platforms) if available_platforms else 3
    tested_platforms = {
        t.platform.lower() for t in validated_tests if t.platform
    }
    if total_platforms > 0 and tested_platforms:
        diversity_score = round(
            min(len(tested_platforms) / total_platforms, 1.0) * w_diversity, 1,
        )
    else:
        diversity_score = 0
    breakdown["platform_diversity"] = {
        "score": diversity_score,
        "max": w_diversity,
        "detail": (
            f"{len(tested_platforms)}/{total_platforms} platforms covered"
            if tested_platforms
            else "No platforms tested"
        ),
    }

    # ── Total ─────────────────────────────────────────────────────
    total = min(
        test_score + detection_score + d3fend_score
        + freshness_score + diversity_score,
        100,
    )
    return {
        "total_score": round(total, 1),
        "breakdown": breakdown,
    }
# ── Tactic-level scoring ─────────────────────────────────────────────
def calculate_tactic_score(tactic: str, db: Session) -> dict:
    """Average the technique scores whose tactic contains ``tactic``.

    Matching is a case-insensitive substring test against each technique's
    ``tactic`` value from ``bulk_technique_scores``; techniques with an
    empty tactic are ignored.

    Returns:
        Dict with ``tactic``, ``average_score`` (0 when nothing matches),
        ``techniques_count`` and ``techniques_scored`` (matches scoring
        above zero).
    """
    technique_scores = bulk_technique_scores(db)

    selected = []
    for entry in technique_scores.values():
        entry_tactic = entry.get("tactic")
        if not entry_tactic:
            continue
        if tactic.lower() in entry_tactic.lower():
            selected.append(entry["total_score"])

    if selected:
        average = round(sum(selected) / len(selected), 1)
    else:
        average = 0

    positive = [value for value in selected if value > 0]
    summary = {
        "tactic": tactic,
        "average_score": average,
        "techniques_count": len(selected),
        "techniques_scored": len(positive),
    }
    return summary
# ── Threat actor scoring ─────────────────────────────────────────────
def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
    """Calculate coverage score for a specific threat actor's techniques.

    Args:
        actor_id: Identifier matched against ``ThreatActor.id``.
        db: SQLAlchemy session used for the lookups.

    Returns:
        Dict with ``actor_id``, ``actor_name``, ``total_score`` (mean of the
        actor's technique scores, 0 when there are none),
        ``techniques_count`` (techniques linked to the actor),
        ``techniques_covered`` (those scoring above 50) and
        ``techniques_detail`` (per-technique entries sorted by ``mitre_id``).
        For an unknown actor the same keys are returned with zero/empty
        values and ``actor_name`` set to ``None``.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        # Same key set as the other branches so callers can read the result
        # uniformly; the original three keys keep their original values.
        return {
            "actor_id": str(actor_id),
            "actor_name": None,
            "total_score": 0,
            "techniques_count": 0,
            "techniques_covered": 0,
            "techniques_detail": [],
        }

    # Only the technique ids are needed, so select just that column.
    technique_ids = {
        row[0]
        for row in db.query(ThreatActorTechnique.technique_id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .all()
    }
    if not technique_ids:
        return {
            "actor_id": str(actor.id),
            "actor_name": actor.name,
            "total_score": 0,
            "techniques_count": 0,
            "techniques_covered": 0,
            "techniques_detail": [],
        }

    scores_map = bulk_technique_scores(db)
    details = []
    for tid in technique_ids:
        entry = scores_map.get(tid)
        if not entry:
            # Linked technique has no score entry: it still counts toward
            # techniques_count but contributes nothing to the average.
            continue
        details.append({
            "mitre_id": entry["mitre_id"],
            "name": entry.get("name", ""),
            "score": entry["total_score"],
            "breakdown": entry["breakdown"],
        })

    # Set iteration order is arbitrary; sort so repeated calls return the
    # detail list in a stable order.
    details.sort(key=lambda item: item["mitre_id"] or "")
    scores = [item["score"] for item in details]

    avg_score = round(sum(scores) / len(scores), 1) if scores else 0
    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_score": avg_score,
        "techniques_count": len(technique_ids),
        "techniques_covered": sum(1 for s in scores if s > 50),
        "techniques_detail": details,
    }
# ── Organization-level scoring ────────────────────────────────────────
def calculate_organization_score(db: Session) -> dict:
    """Calculate the overall organization security score.

    Technique scores come from ``bulk_technique_scores`` so the number of
    queries does not grow with the number of techniques.

    The overall score is a weighted blend:
    40% total coverage (mean of non-zero technique scores),
    25% critical coverage (mean score of techniques whose ``mitre_id`` has a
    high/critical ``TestTemplate``),
    20% detection maturity (share of active detection rules that triggered
    at least once) and
    15% response readiness (share of tests with a remediation status that
    are ``"completed"``).

    Returns:
        Dict with ``overall_score``, the four component percentages,
        ``techniques_evaluated`` (techniques scoring above zero) and
        ``techniques_total``.
    """
    scores_map = bulk_technique_scores(db)
    total_count = len(scores_map)
    if total_count == 0:
        return {
            "overall_score": 0,
            "total_coverage": 0,
            "critical_coverage": 0,
            "detection_maturity": 0,
            "response_readiness": 0,
            "techniques_evaluated": 0,
            "techniques_total": 0,
        }

    all_scores = [v["total_score"] for v in scores_map.values()]
    evaluated_scores = [s for s in all_scores if s > 0]
    evaluated_count = len(evaluated_scores)
    total_coverage = (
        round(sum(evaluated_scores) / evaluated_count, 1)
        if evaluated_scores else 0
    )

    # Critical coverage: techniques with high/critical severity templates.
    from app.models.test_template import TestTemplate

    critical_mitre_ids = {
        row[0]
        for row in db.query(TestTemplate.mitre_technique_id)
        .filter(TestTemplate.severity.in_(["high", "critical"]))
        .distinct()
        .all()
    }
    critical_scores = [
        v["total_score"]
        for v in scores_map.values()
        if v.get("mitre_id") in critical_mitre_ids
    ]
    critical_coverage = (
        round(sum(critical_scores) / len(critical_scores), 1)
        if critical_scores else 0
    )

    # Detection maturity: distinct active rules that triggered at least
    # once, over all active rules.  Counting result rows would let one
    # noisy rule (or inactive rules) saturate the metric.
    total_rules = (
        db.query(func.count(DetectionRule.id))
        .filter(DetectionRule.is_active == True)  # noqa: E712
        .scalar()
    ) or 0
    triggered_total = (
        db.query(func.count(func.distinct(DetectionRule.id)))
        .join(
            TestDetectionResult,
            TestDetectionResult.detection_rule_id == DetectionRule.id,
        )
        .filter(
            DetectionRule.is_active == True,  # noqa: E712
            TestDetectionResult.triggered == True,  # noqa: E712
        )
        .scalar()
    ) or 0
    detection_maturity = (
        round((triggered_total / total_rules) * 100, 1)
        if total_rules > 0 else 0
    )
    detection_maturity = min(detection_maturity, 100)

    # Response readiness: completed remediations over tests that have any
    # remediation status.
    remediation_total = (
        db.query(func.count(Test.id))
        .filter(Test.remediation_status.isnot(None))
        .scalar()
    ) or 0
    remediation_completed = (
        db.query(func.count(Test.id))
        .filter(Test.remediation_status == "completed")
        .scalar()
    ) or 0
    response_readiness = (
        round((remediation_completed / remediation_total) * 100, 1)
        if remediation_total > 0 else 0
    )

    overall = round(
        total_coverage * 0.4
        + critical_coverage * 0.25
        + detection_maturity * 0.2
        + response_readiness * 0.15,
        1,
    )
    return {
        "overall_score": overall,
        "total_coverage": total_coverage,
        "critical_coverage": critical_coverage,
        "detection_maturity": detection_maturity,
        "response_readiness": response_readiness,
        "techniques_evaluated": evaluated_count,
        "techniques_total": total_count,
    }
# ── Score history ────────────────────────────────────────────────────
def get_score_history(db: Session, period: str = "90d") -> list:
    """Get approximate historical score snapshots.

    There is no dedicated history table, so each data point is derived from
    the cumulative number of validated tests up to the end of a 7-day
    window, divided by the total number of techniques and capped at 100.

    Args:
        db: SQLAlchemy session used for the counts.
        period: ``"30d"``, ``"90d"`` or ``"1y"``; any other value falls back
            to 90 days.

    Returns:
        A list of ``{"date", "score", "validated_tests"}`` dicts, one per
        window, oldest first.  ``date`` is the window's start
        (``YYYY-MM-DD``) while the counts run up to the window's end.
    """
    period_days = {"30d": 30, "1y": 365}
    now = datetime.utcnow()
    start = now - timedelta(days=period_days.get(period, 90))

    # The technique count does not depend on the window, so query it once
    # instead of once per iteration.  ``or 1`` avoids dividing by zero when
    # there are no techniques.
    total_techniques = (
        db.query(func.count(Technique.id)).scalar()
    ) or 1

    weeks = []
    current = start
    while current < now:
        # The last window is clipped to ``now``.
        week_end = min(current + timedelta(days=7), now)

        # Cumulative count of validated tests up to the end of this window.
        validated_up_to = (
            db.query(func.count(Test.id))
            .filter(
                Test.state == TestState.validated,
                Test.red_validated_at <= week_end,
            )
            .scalar()
        ) or 0

        # Simple approximation: validated tests per technique as a
        # percentage, used as a score proxy.
        score_approx = round((validated_up_to / total_techniques) * 100, 1)
        weeks.append({
            "date": current.strftime("%Y-%m-%d"),
            "score": min(score_approx, 100),
            "validated_tests": validated_up_to,
        })
        current = week_end
    return weeks