Files
Aegis/backend/app/services/scoring_service.py

666 lines
22 KiB
Python

"""Scoring service — granular 0-100 scoring for techniques, tactics, actors, and org.
Reads configurable weights from the ``scoring_config`` table (falling
back to env-var defaults) to compute coverage scores with detailed
breakdowns.
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
fixed number of aggregated queries so that organisation-wide calculations
never produce N+1 traffic.
"""
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.models.test import Test
from app.models.detection_rule import DetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.defensive_technique import DefensiveTechniqueMapping
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.enums import TestState, TestResult
from app.services.scoring_config_service import get_scoring_weights
# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
def bulk_technique_scores(db: Session) -> dict:
    """Pre-fetch all scoring data and compute per-technique scores in memory.

    Executes exactly 5 queries regardless of technique count:

    * Q1 — validated-test aggregates per technique (validated / detected /
      distinct platforms / most recent validation)
    * Q2 — active detection rules per mitre_id
    * Q3 — distinct active rules with at least one triggered result, per mitre_id
    * Q4 — D3FEND mapping counts per technique
    * Q5 — all techniques

    Weights come from ``get_scoring_weights(db)``. The component rules mirror
    ``calculate_technique_score`` so both entry points agree.

    Returns:
        ``{technique_id: {"total_score", "breakdown", "mitre_id", "name",
        "tactic"}}`` where ``total_score`` is capped at 100.
    """
    w = get_scoring_weights(db)
    w_tests = w.tests
    w_detection = w.detection_rules
    w_d3fend = w.d3fend
    w_freshness = w.freshness
    w_diversity = w.platform_diversity

    # Q1: validated-test aggregates grouped by technique_id.
    test_rows = (
        db.query(
            Test.technique_id,
            func.count(Test.id).label("validated_count"),
            func.count(
                case((Test.detection_result == TestResult.detected, Test.id))
            ).label("detected_count"),
            func.max(Test.red_validated_at).label("latest_validated_at"),
            # Lower-case and drop empty strings so this matches the
            # per-technique path, which de-duplicates ``platform.lower()``
            # and skips falsy platforms.
            func.count(
                func.distinct(func.nullif(func.lower(Test.platform), ""))
            ).label("platform_count"),
        )
        .filter(Test.state == TestState.validated)
        .group_by(Test.technique_id)
        .all()
    )
    test_stats: dict = {
        row.technique_id: {
            "validated": row.validated_count,
            "detected": row.detected_count,
            "latest_validated_at": row.latest_validated_at,
            "platform_count": row.platform_count,
        }
        for row in test_rows
    }

    # Q2: active detection rules per mitre_id.
    rule_rows = (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(DetectionRule.id).label("total"),
        )
        .filter(DetectionRule.is_active == True)  # noqa: E712
        .group_by(DetectionRule.mitre_technique_id)
        .all()
    )
    rules_by_mitre: dict[str, int] = {r.mitre_technique_id: r.total for r in rule_rows}

    # Q3: distinct *active* rules that triggered at least once, per mitre_id.
    # Counting result rows instead would let a single noisy rule inflate the
    # ratio and make "X/Y rules triggered" report X > Y.
    triggered_rows = (
        db.query(
            DetectionRule.mitre_technique_id,
            func.count(func.distinct(DetectionRule.id)).label("triggered"),
        )
        .select_from(TestDetectionResult)
        .join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
        .filter(
            TestDetectionResult.triggered == True,  # noqa: E712
            DetectionRule.is_active == True,  # noqa: E712
        )
        .group_by(DetectionRule.mitre_technique_id)
        .all()
    )
    triggered_by_mitre: dict[str, int] = {
        r.mitre_technique_id: r.triggered for r in triggered_rows
    }

    # Q4: D3FEND mapping counts per technique.
    d3fend_rows = (
        db.query(
            DefensiveTechniqueMapping.attack_technique_id,
            func.count(DefensiveTechniqueMapping.id).label("total"),
        )
        .group_by(DefensiveTechniqueMapping.attack_technique_id)
        .all()
    )
    d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows}

    # Q5: all techniques.
    techniques = db.query(Technique).all()

    now = datetime.utcnow()
    results: dict = {}
    for tech in techniques:
        ts = test_stats.get(tech.id, {})
        validated = ts.get("validated", 0)
        detected = ts.get("detected", 0)
        latest_at = ts.get("latest_validated_at")
        plat_count = ts.get("platform_count", 0)
        breakdown: dict = {}

        # 1. Share of validated tests that were detected.
        if validated > 0:
            test_score = round(detected / validated * w_tests, 1)
        else:
            test_score = 0
        breakdown["tests_validated"] = {
            "score": test_score,
            "max": w_tests,
            "detail": (
                f"{detected}/{validated} tests detected"
                if validated else "No validated tests"
            ),
        }

        # 2. Share of active rules for this MITRE ID that have triggered.
        total_rules = rules_by_mitre.get(tech.mitre_id, 0)
        triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0)
        if total_rules > 0:
            detection_score = round(
                min(triggered_rules / total_rules, 1.0) * w_detection, 1,
            )
        else:
            detection_score = 0
        breakdown["detection_rules"] = {
            "score": detection_score,
            "max": w_detection,
            "detail": (
                f"{triggered_rules}/{total_rules} rules triggered"
                if total_rules > 0 else "No detection rules available"
            ),
        }

        # 3. D3FEND: each detected test is credited as verifying one mapped
        # countermeasure, up to the number of mappings.
        total_cm = d3fend_by_tech.get(tech.id, 0)
        if total_cm > 0 and detected > 0:
            verified_cm = min(detected, total_cm)
            d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1)
        else:
            verified_cm = 0
            d3fend_score = 0
        breakdown["d3fend_coverage"] = {
            "score": d3fend_score,
            "max": w_d3fend,
            "detail": (
                f"{verified_cm}/{total_cm} countermeasures"
                if total_cm > 0 else "No D3FEND mappings"
            ),
        }

        # 4. Freshness: full credit under 90 days, half under 180, else none.
        if latest_at:
            days_ago = (now - latest_at).days
            if days_ago < 90:
                freshness_pct = 1.0
            elif days_ago < 180:
                freshness_pct = 0.5
            else:
                freshness_pct = 0.0
            freshness_score = round(freshness_pct * w_freshness, 1)
            freshness_detail = f"Last test {days_ago} days ago"
        else:
            freshness_score = 0
            freshness_detail = "No validated tests"
        breakdown["freshness"] = {
            "score": freshness_score,
            "max": w_freshness,
            "detail": freshness_detail,
        }

        # 5. Platform diversity; 3 platforms assumed when none are listed.
        available = tech.platforms or []
        total_platforms = len(available) if available else 3
        if total_platforms > 0 and plat_count > 0:
            diversity_score = round(
                min(plat_count / total_platforms, 1.0) * w_diversity, 1,
            )
        else:
            diversity_score = 0
        breakdown["platform_diversity"] = {
            "score": diversity_score,
            "max": w_diversity,
            "detail": (
                f"{plat_count}/{total_platforms} platforms covered"
                if plat_count > 0 else "No platforms tested"
            ),
        }

        total = min(
            test_score + detection_score + d3fend_score
            + freshness_score + diversity_score,
            100,
        )
        results[tech.id] = {
            "total_score": round(total, 1),
            "breakdown": breakdown,
            "mitre_id": tech.mitre_id,
            # Consumers such as actor scoring read this key.
            "name": tech.name,
            "tactic": tech.tactic,
        }
    return results
# ── Technique-level scoring (single technique — preserved API) ────────
def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict:
    """Return the detailed score of the technique identified by *mitre_id*.

    The technique is scored with ``calculate_technique_score``. Its
    identifying fields are merged with that result, and score keys win on
    any collision.

    Raises:
        EntityNotFoundError: if no technique has this MITRE ID.
    """
    found = (
        db.query(Technique)
        .filter(Technique.mitre_id == mitre_id)
        .first()
    )
    if not found:
        raise EntityNotFoundError("Technique", mitre_id)

    scored = calculate_technique_score(found, db)
    status = found.status_global
    payload = {
        "mitre_id": found.mitre_id,
        "name": found.name,
        "tactic": found.tactic,
        "status_global": status.value if status else None,
    }
    payload.update(scored)
    return payload
def score_actor_by_id(db: Session, actor_id: str) -> dict:
    """Return the coverage score for threat actor *actor_id*.

    Raises:
        EntityNotFoundError: if no threat actor has this ID.
    """
    existing = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if existing:
        return calculate_actor_coverage_score(actor_id, db)
    raise EntityNotFoundError("ThreatActor", actor_id)
def calculate_technique_score(technique: Technique, db: Session) -> dict:
    """Calculate a 0-100 score for one technique with a detailed breakdown.

    Weights are read from the ``scoring_config`` table (or env defaults)
    via ``get_scoring_weights``. Each component is capped by its weight:

    1. tests_validated — detected / validated tests
    2. detection_rules — distinct active rules that triggered / active rules
       for the technique's MITRE ID
    3. d3fend_coverage — min(detected tests, mappings) / D3FEND mappings
    4. freshness — full weight if the last validation was < 90 days ago,
       half if < 180 days, otherwise zero
    5. platform_diversity — distinct tested platforms / technique platforms
       (3 assumed when the technique lists none)

    Args:
        technique: the technique to score.
        db: active SQLAlchemy session.

    Returns:
        ``{"total_score": float, "breakdown": dict}``; the total is capped at 100.
    """
    w = get_scoring_weights(db)
    w_tests = w.tests
    w_detection = w.detection_rules
    w_d3fend = w.d3fend
    w_freshness = w.freshness
    w_diversity = w.platform_diversity
    breakdown = {}

    # ── 1. Tests validated with detection ──────────────────────────
    # Only validated tests feed any component, so filter them in SQL.
    validated_tests = (
        db.query(Test)
        .filter(
            Test.technique_id == technique.id,
            Test.state == TestState.validated,
        )
        .all()
    )
    validated_count = len(validated_tests)
    detected_count = sum(
        1 for t in validated_tests if t.detection_result == TestResult.detected
    )
    if validated_count:
        test_score = round(detected_count / validated_count * w_tests, 1)
    else:
        test_score = 0
    breakdown["tests_validated"] = {
        "score": test_score,
        "max": w_tests,
        "detail": f"{detected_count}/{validated_count} tests detected"
        if validated_tests
        else "No validated tests",
    }

    # ── 2. Detection rules coverage ───────────────────────────────
    total_rules = (
        db.query(func.count(DetectionRule.id))
        .filter(
            DetectionRule.mitre_technique_id == technique.mitre_id,
            DetectionRule.is_active == True,  # noqa: E712
        )
        .scalar()
    ) or 0
    triggered_rules = 0
    if total_rules > 0:
        # Count each active rule once, however many results it produced, so
        # the numerator is comparable with ``total_rules``.
        triggered_rules = (
            db.query(func.count(func.distinct(TestDetectionResult.detection_rule_id)))
            .join(
                DetectionRule,
                DetectionRule.id == TestDetectionResult.detection_rule_id,
            )
            .filter(
                DetectionRule.mitre_technique_id == technique.mitre_id,
                DetectionRule.is_active == True,  # noqa: E712
                TestDetectionResult.triggered == True,  # noqa: E712
            )
            .scalar()
        ) or 0
        detection_score = round(
            min(triggered_rules / total_rules, 1.0) * w_detection, 1,
        )
    else:
        detection_score = 0
    breakdown["detection_rules"] = {
        "score": detection_score,
        "max": w_detection,
        "detail": f"{triggered_rules}/{total_rules} rules triggered"
        if total_rules > 0
        else "No detection rules available",
    }

    # ── 3. D3FEND coverage ────────────────────────────────────────
    total_countermeasures = (
        db.query(func.count(DefensiveTechniqueMapping.id))
        .filter(DefensiveTechniqueMapping.attack_technique_id == technique.id)
        .scalar()
    ) or 0
    verified_countermeasures = 0
    if total_countermeasures > 0 and detected_count > 0:
        # Each detected test is credited as verifying one mapped countermeasure.
        verified_countermeasures = min(detected_count, total_countermeasures)
        d3fend_score = round(
            verified_countermeasures / total_countermeasures * w_d3fend, 1,
        )
    else:
        d3fend_score = 0
    breakdown["d3fend_coverage"] = {
        "score": d3fend_score,
        "max": w_d3fend,
        "detail": f"{verified_countermeasures}/{total_countermeasures} countermeasures"
        if total_countermeasures > 0
        else "No D3FEND mappings",
    }

    # ── 4. Freshness ──────────────────────────────────────────────
    # Same value as SQL MAX(red_validated_at) over validated tests (NULLs
    # ignored), computed from rows already loaded instead of another query.
    most_recent_test = max(
        (t.red_validated_at for t in validated_tests if t.red_validated_at is not None),
        default=None,
    )
    now = datetime.utcnow()
    if most_recent_test:
        days_ago = (now - most_recent_test).days
        if days_ago < 90:
            freshness_pct = 1.0
        elif days_ago < 180:
            freshness_pct = 0.5
        else:
            freshness_pct = 0.0
        freshness_score = round(freshness_pct * w_freshness, 1)
        freshness_detail = f"Last test {days_ago} days ago"
    else:
        freshness_score = 0
        freshness_detail = "No validated tests"
    breakdown["freshness"] = {
        "score": freshness_score,
        "max": w_freshness,
        "detail": freshness_detail,
    }

    # ── 5. Platform diversity ─────────────────────────────────────
    available_platforms = technique.platforms or []
    total_platforms = len(available_platforms) if available_platforms else 3
    # Case-insensitive de-duplication; empty/None platforms are ignored.
    tested_platforms = {t.platform.lower() for t in validated_tests if t.platform}
    if total_platforms > 0 and tested_platforms:
        diversity_score = round(
            min(len(tested_platforms) / total_platforms, 1.0) * w_diversity, 1,
        )
    else:
        diversity_score = 0
    breakdown["platform_diversity"] = {
        "score": diversity_score,
        "max": w_diversity,
        "detail": f"{len(tested_platforms)}/{total_platforms} platforms covered"
        if tested_platforms
        else "No platforms tested",
    }

    # ── Total ─────────────────────────────────────────────────────
    total = min(
        test_score + detection_score + d3fend_score + freshness_score + diversity_score,
        100,
    )
    return {
        "total_score": round(total, 1),
        "breakdown": breakdown,
    }
# ── Tactic-level scoring ─────────────────────────────────────────────
def calculate_tactic_score(tactic: str, db: Session) -> dict:
    """Average the technique scores whose tactic contains *tactic*.

    Matching is a case-insensitive substring test against each technique's
    ``tactic`` field. Techniques without a tactic are skipped.

    Returns:
        A dict with the queried ``tactic``, ``average_score`` (0 when nothing
        matches), ``techniques_count`` (matches) and ``techniques_scored``
        (matches scoring above 0).
    """
    scores_map = bulk_technique_scores(db)
    tactic_scores = []
    for entry in scores_map.values():
        entry_tactic = entry.get("tactic")
        if not entry_tactic:
            continue
        if tactic.lower() in entry_tactic.lower():
            tactic_scores.append(entry["total_score"])

    count = len(tactic_scores)
    average = round(sum(tactic_scores) / count, 1) if count else 0
    return {
        "tactic": tactic,
        "average_score": average,
        "techniques_count": count,
        "techniques_scored": sum(1 for s in tactic_scores if s > 0),
    }
# ── Threat actor scoring ─────────────────────────────────────────────
def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
    """Calculate coverage score for a specific threat actor's techniques.

    The score is the mean ``total_score`` (from ``bulk_technique_scores``) of
    the actor's techniques that have a score entry.

    Args:
        actor_id: primary key of the threat actor.
        db: active SQLAlchemy session.

    Returns:
        A dict with ``actor_id``, ``actor_name``, ``total_score``,
        ``techniques_count`` (techniques linked to the actor),
        ``techniques_covered`` (scored techniques above 50) and
        ``techniques_detail``. The detail is a list sorted by ``mitre_id`` and
        carries the technique name when the score entry provides one. An
        unknown actor yields the same keys with zero/empty values and
        ``actor_name`` set to ``None``.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        # Same key set as the success path so callers can rely on the shape.
        return {
            "actor_id": str(actor_id),
            "actor_name": None,
            "total_score": 0,
            "techniques_count": 0,
            "techniques_covered": 0,
            "techniques_detail": [],
        }

    # Only the technique id is needed from the association rows.
    technique_ids = {
        row.technique_id
        for row in db.query(ThreatActorTechnique.technique_id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .all()
    }
    if not technique_ids:
        return {
            "actor_id": str(actor.id),
            "actor_name": actor.name,
            "total_score": 0,
            "techniques_count": 0,
            "techniques_covered": 0,
            "techniques_detail": [],
        }

    scores_map = bulk_technique_scores(db)
    # Sort so the detail order is stable rather than set-iteration order.
    entries = sorted(
        (scores_map[tid] for tid in technique_ids if tid in scores_map),
        key=lambda e: e.get("mitre_id") or "",
    )
    scores = [e["total_score"] for e in entries]
    details = [
        {
            "mitre_id": e["mitre_id"],
            "name": e.get("name", ""),
            "score": e["total_score"],
            "breakdown": e["breakdown"],
        }
        for e in entries
    ]

    avg_score = round(sum(scores) / len(scores), 1) if scores else 0
    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_score": avg_score,
        "techniques_count": len(technique_ids),
        # A technique counts as covered once it scores above 50/100.
        "techniques_covered": sum(1 for s in scores if s > 50),
        "techniques_detail": details,
    }
# ── Organization-level scoring ────────────────────────────────────────
def calculate_organization_score(db: Session) -> dict:
    """Calculate the overall organization security score.

    Technique scores come from ``bulk_technique_scores``, which uses a fixed
    number of aggregated queries instead of per-technique queries. The
    overall score is a weighted blend of four 0-100 sub-scores:

    * ``total_coverage`` (40%) — mean score of techniques scoring above 0
    * ``critical_coverage`` (25%) — mean score of techniques whose MITRE ID
      has a high/critical-severity test template
    * ``detection_maturity`` (20%) — distinct active rules that triggered at
      least once / active rules
    * ``response_readiness`` (15%) — tests with remediation status
      ``"completed"`` / tests with any remediation status

    Returns:
        A dict with ``overall_score``, the four sub-scores,
        ``techniques_evaluated`` (techniques scoring above 0) and
        ``techniques_total``. All values are 0 when no techniques exist.
    """
    scores_map = bulk_technique_scores(db)
    total_count = len(scores_map)
    if total_count == 0:
        return {
            "overall_score": 0,
            "total_coverage": 0,
            "critical_coverage": 0,
            "detection_maturity": 0,
            "response_readiness": 0,
            "techniques_evaluated": 0,
            "techniques_total": 0,
        }

    all_scores = [v["total_score"] for v in scores_map.values()]
    evaluated_scores = [s for s in all_scores if s > 0]
    evaluated_count = len(evaluated_scores)
    # NOTE(review): zero-score techniques are excluded from this mean (unlike
    # critical_coverage), so it reflects the quality of evaluated techniques
    # rather than breadth of coverage — confirm this is intended.
    total_coverage = (
        round(sum(evaluated_scores) / evaluated_count, 1)
        if evaluated_scores else 0
    )

    # Critical coverage: techniques with high/critical severity templates.
    from app.models.test_template import TestTemplate
    critical_mitre_ids = {
        row[0]
        for row in db.query(TestTemplate.mitre_technique_id)
        .filter(TestTemplate.severity.in_(["high", "critical"]))
        .distinct()
        .all()
    }
    critical_scores = [
        v["total_score"]
        for v in scores_map.values()
        if v.get("mitre_id") in critical_mitre_ids
    ]
    critical_coverage = (
        round(sum(critical_scores) / len(critical_scores), 1)
        if critical_scores else 0
    )

    # Detection maturity: share of active rules that have ever fired.
    total_rules = (
        db.query(func.count(DetectionRule.id))
        .filter(DetectionRule.is_active == True)  # noqa: E712
        .scalar()
    ) or 0
    # Count each active rule once, however many triggered results it has,
    # so the numerator is comparable with ``total_rules``.
    triggered_total = (
        db.query(func.count(func.distinct(TestDetectionResult.detection_rule_id)))
        .join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
        .filter(
            TestDetectionResult.triggered == True,  # noqa: E712
            DetectionRule.is_active == True,  # noqa: E712
        )
        .scalar()
    ) or 0
    detection_maturity = (
        round((triggered_total / total_rules) * 100, 1)
        if total_rules > 0 else 0
    )
    detection_maturity = min(detection_maturity, 100)

    # Response readiness: completed remediations among tests that have one.
    remediation_total = (
        db.query(func.count(Test.id))
        .filter(Test.remediation_status.isnot(None))
        .scalar()
    ) or 0
    remediation_completed = (
        db.query(func.count(Test.id))
        .filter(Test.remediation_status == "completed")
        .scalar()
    ) or 0
    response_readiness = (
        round((remediation_completed / remediation_total) * 100, 1)
        if remediation_total > 0 else 0
    )

    overall = round(
        total_coverage * 0.4
        + critical_coverage * 0.25
        + detection_maturity * 0.2
        + response_readiness * 0.15,
        1,
    )
    return {
        "overall_score": overall,
        "total_coverage": total_coverage,
        "critical_coverage": critical_coverage,
        "detection_maturity": detection_maturity,
        "response_readiness": response_readiness,
        "techniques_evaluated": evaluated_count,
        "techniques_total": total_count,
    }
# ── Score history ────────────────────────────────────────────────────
def get_score_history(db: Session, period: str = "90d") -> list:
    """Get approximate historical score snapshots, one per week.

    Since there is no dedicated history table, each point is derived from
    the ``red_validated_at`` timestamps of validated tests:

    * ``score`` — percentage of techniques with at least one validated test
      by the end of that week (capped at 100)
    * ``validated_tests`` — cumulative number of validated tests by then

    Args:
        db: active SQLAlchemy session.
        period: ``"30d"``, ``"90d"`` or ``"1y"``; any other value falls back
            to ``"90d"``.

    Returns:
        A chronological list of ``{"date", "score", "validated_tests"}``
        dicts, where ``date`` (``YYYY-MM-DD``) is the start of the week.
    """
    from bisect import bisect_right

    now = datetime.utcnow()
    period_days = {"30d": 30, "1y": 365}.get(period, 90)
    start = now - timedelta(days=period_days)

    # Both series are cumulative, so fetch the timestamps once and count per
    # week in memory instead of issuing queries for every week.
    # NOTE(review): assumes red_validated_at is stored as naive UTC, matching
    # the naive ``utcnow()`` used throughout this module.
    validation_times = sorted(
        row[0]
        for row in db.query(Test.red_validated_at)
        .filter(
            Test.state == TestState.validated,
            Test.red_validated_at.isnot(None),
        )
        .all()
    )
    # Earliest validation per technique: a technique is covered from then on.
    first_validation_times = sorted(
        row[0]
        for row in db.query(func.min(Test.red_validated_at))
        .filter(
            Test.state == TestState.validated,
            Test.red_validated_at.isnot(None),
            Test.technique_id.isnot(None),
        )
        .group_by(Test.technique_id)
        .all()
    )
    # ``or 1`` avoids division by zero when no techniques exist.
    total_techniques = db.query(func.count(Technique.id)).scalar() or 1

    weeks = []
    current = start
    while current < now:
        week_end = min(current + timedelta(days=7), now)
        # bisect_right counts timestamps <= week_end.
        validated_up_to = bisect_right(validation_times, week_end)
        techniques_covered = bisect_right(first_validation_times, week_end)
        score_approx = round((techniques_covered / total_techniques) * 100, 1)
        weeks.append({
            "date": current.strftime("%Y-%m-%d"),
            "score": min(score_approx, 100),
            "validated_tests": validated_up_to,
        })
        current = week_end
    return weeks