feat(scoring): composite recency decay and severity weights persisted in DB [FASE-5.1]

This commit is contained in:
2026-05-18 15:07:12 +02:00
parent 2ee59d4e18
commit 05b221a22d
13 changed files with 588 additions and 154 deletions

View File

@@ -1,14 +1,8 @@
"""Scoring configuration persistence service.
Reads and writes scoring weights from the ``scoring_config`` table.
Falls back to environment-variable defaults (from ``Settings``) when
no row has been persisted yet.
This module is framework-agnostic: no FastAPI imports.
"""
"""Scoring configuration persistence service."""
from __future__ import annotations
import uuid
from typing import Any
from sqlalchemy.orm import Session
@@ -18,29 +12,39 @@ from app.domain.value_objects.scoring_weights import ScoringWeights
from app.models.scoring_config import ScoringConfig
def get_scoring_weights(db: Session) -> ScoringWeights:
"""Return the active scoring weights.
def _row_recency(row: ScoringConfig) -> float:
return float(getattr(row, "weight_recency", None) or getattr(row, "weight_freshness", 10.0))
Reads the single ``scoring_config`` row. If the table is empty
(first run or migration just applied), falls back to the values
from the environment / ``Settings``.
"""
def _row_severity(row: ScoringConfig) -> float:
return float(
getattr(row, "weight_severity", None)
or getattr(row, "weight_platform_diversity", 10.0)
)
def get_scoring_weights(db: Session) -> ScoringWeights:
"""Return the active scoring weights from the database or env defaults."""
row = db.query(ScoringConfig).first()
if row is not None:
return ScoringWeights(
tests=row.weight_tests,
detection_rules=row.weight_detection_rules,
d3fend=row.weight_d3fend,
freshness=row.weight_freshness,
platform_diversity=row.weight_platform_diversity,
recency=_row_recency(row),
severity=_row_severity(row),
)
return ScoringWeights(
tests=float(settings.SCORING_WEIGHT_TESTS),
detection_rules=float(settings.SCORING_WEIGHT_DETECTION_RULES),
d3fend=float(settings.SCORING_WEIGHT_D3FEND),
freshness=float(settings.SCORING_WEIGHT_FRESHNESS),
platform_diversity=float(settings.SCORING_WEIGHT_PLATFORM_DIVERSITY),
recency=float(
getattr(settings, "SCORING_WEIGHT_RECENCY", settings.SCORING_WEIGHT_FRESHNESS)
),
severity=float(
getattr(settings, "SCORING_WEIGHT_SEVERITY", settings.SCORING_WEIGHT_PLATFORM_DIVERSITY)
),
)
@@ -50,25 +54,26 @@ def update_scoring_weights(
tests: float | None = None,
detection_rules: float | None = None,
d3fend: float | None = None,
recency: float | None = None,
severity: float | None = None,
freshness: float | None = None,
platform_diversity: float | None = None,
updated_by: uuid.UUID | None = None,
) -> dict[str, Any]:
"""Upsert scoring weights into the database.
"""Upsert scoring weights. Does not commit."""
if freshness is not None and recency is None:
recency = freshness
if platform_diversity is not None and severity is None:
severity = platform_diversity
Only provided fields are overwritten; ``None`` values keep the
current (or default) value. Validates via ``ScoringWeights``
before persisting.
Returns a dict with ``weights`` and ``total``.
"""
current = get_scoring_weights(db)
new = ScoringWeights(
tests=tests if tests is not None else current.tests,
detection_rules=detection_rules if detection_rules is not None else current.detection_rules,
d3fend=d3fend if d3fend is not None else current.d3fend,
freshness=freshness if freshness is not None else current.freshness,
platform_diversity=platform_diversity if platform_diversity is not None else current.platform_diversity,
recency=recency if recency is not None else current.recency,
severity=severity if severity is not None else current.severity,
)
row = db.query(ScoringConfig).first()
@@ -79,10 +84,17 @@ def update_scoring_weights(
row.weight_tests = new.tests
row.weight_detection_rules = new.detection_rules
row.weight_d3fend = new.d3fend
row.weight_freshness = new.freshness
row.weight_platform_diversity = new.platform_diversity
if hasattr(row, "weight_recency"):
row.weight_recency = new.recency
elif hasattr(row, "weight_freshness"):
row.weight_freshness = new.recency
if hasattr(row, "weight_severity"):
row.weight_severity = new.severity
elif hasattr(row, "weight_platform_diversity"):
row.weight_platform_diversity = new.severity
if updated_by is not None and hasattr(row, "updated_by"):
row.updated_by = updated_by
# Does not commit; caller (router) uses UnitOfWork.
return _weights_dict(new)
@@ -96,10 +108,15 @@ def _weights_dict(w: ScoringWeights) -> dict[str, Any]:
"tests": w.tests,
"detection_rules": w.detection_rules,
"d3fend": w.d3fend,
"freshness": w.freshness,
"platform_diversity": w.platform_diversity,
"recency": w.recency,
"severity": w.severity,
# Legacy keys for older clients
"freshness": w.recency,
"platform_diversity": w.severity,
}
return {
"weights": weights,
"total": sum(weights.values()),
"total": sum(
[w.tests, w.detection_rules, w.d3fend, w.recency, w.severity]
),
}

View File

@@ -9,7 +9,7 @@ fixed number of aggregated queries so that organisation-wide calculations
never produce N+1 traffic.
"""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional
from sqlalchemy import case, func
@@ -25,6 +25,61 @@ from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.enums import TestState, TestResult
from app.services.scoring_config_service import get_scoring_weights
_SEVERITY_FACTORS: dict[str, float] = {
"critical": 1.0,
"high": 0.85,
"medium": 0.65,
"low": 0.5,
}
def _recency_factor(last_tested: datetime | None) -> float:
"""Decay factor: 1.0 when recent, decreasing over time."""
if not last_tested:
return 0.0
now = datetime.now(timezone.utc)
tested = last_tested
if tested.tzinfo is None:
tested = tested.replace(tzinfo=timezone.utc)
days_ago = (now - tested).days
if days_ago <= 90:
return 1.0
if days_ago <= 180:
return 0.8
if days_ago <= 365:
return 0.5
return 0.2
def _severity_factor(severity_label: str | None) -> float:
"""Map template severity to a 01 multiplier."""
if not severity_label:
return 0.7
return _SEVERITY_FACTORS.get(severity_label.lower(), 0.7)
def _max_severity_by_mitre(db: Session) -> dict[str, str]:
"""Highest severity label per MITRE id from active test templates."""
from app.models.test_template import TestTemplate
order = {"critical": 4, "high": 3, "medium": 2, "low": 1}
rows = (
db.query(TestTemplate.mitre_technique_id, TestTemplate.severity)
.filter(
TestTemplate.is_active == True, # noqa: E712
TestTemplate.severity.isnot(None),
)
.all()
)
best: dict[str, str] = {}
for mitre_id, severity in rows:
if not mitre_id or not severity:
continue
current = best.get(mitre_id)
if current is None or order.get(severity.lower(), 0) > order.get(current.lower(), 0):
best[mitre_id] = severity
return best
# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
@@ -45,8 +100,15 @@ def bulk_technique_scores(db: Session) -> dict:
w_tests = w.tests
w_detection = w.detection_rules
w_d3fend = w.d3fend
w_freshness = w.freshness
w_diversity = w.platform_diversity
w_recency = w.recency
w_severity = w.severity
severity_by_mitre = _max_severity_by_mitre(db)
last_validated = func.coalesce(
Test.blue_validated_at,
Test.red_validated_at,
Test.created_at,
)
# Q1: test stats grouped by technique_id
test_rows = (
@@ -56,8 +118,7 @@ def bulk_technique_scores(db: Session) -> dict:
func.count(
case((Test.detection_result == TestResult.detected, Test.id))
).label("detected_count"),
func.max(Test.red_validated_at).label("latest_validated_at"),
func.count(func.distinct(Test.platform)).label("platform_count"),
func.max(last_validated).label("latest_validated_at"),
)
.filter(Test.state == TestState.validated)
.group_by(Test.technique_id)
@@ -70,7 +131,6 @@ def bulk_technique_scores(db: Session) -> dict:
"validated": row.validated_count,
"detected": row.detected_count,
"latest_validated_at": row.latest_validated_at,
"platform_count": row.platform_count,
}
# Q2: active detection rules per mitre_id
@@ -114,7 +174,6 @@ def bulk_technique_scores(db: Session) -> dict:
# Q5: all techniques
techniques = db.query(Technique).all()
now = datetime.utcnow()
results: dict = {}
for tech in techniques:
@@ -122,7 +181,6 @@ def bulk_technique_scores(db: Session) -> dict:
validated = ts.get("validated", 0)
detected = ts.get("detected", 0)
latest_at = ts.get("latest_validated_at")
plat_count = ts.get("platform_count", 0)
breakdown = {}
@@ -177,47 +235,41 @@ def bulk_technique_scores(db: Session) -> dict:
),
}
# 4. Freshness
# 4. Recency decay
recency_mult = _recency_factor(latest_at)
recency_score = round(recency_mult * w_recency, 1)
if latest_at:
days_ago = (now - latest_at).days
if days_ago < 90:
freshness_pct = 1.0
elif days_ago < 180:
freshness_pct = 0.5
tested = latest_at
if tested.tzinfo is None:
days_ago = (datetime.utcnow() - tested).days
else:
freshness_pct = 0.0
freshness_score = round(freshness_pct * w_freshness, 1)
freshness_detail = f"Last test {days_ago} days ago"
days_ago = (datetime.now(timezone.utc) - tested.astimezone(timezone.utc)).days
recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})"
else:
freshness_score = 0
freshness_detail = "No validated tests"
breakdown["freshness"] = {
"score": freshness_score,
"max": w_freshness,
"detail": freshness_detail,
recency_detail = "No validated tests"
breakdown["recency"] = {
"score": recency_score,
"max": w_recency,
"detail": recency_detail,
}
# 5. Platform diversity
available = tech.platforms or []
total_platforms = len(available) if available else 3
if total_platforms > 0 and plat_count > 0:
diversity_score = round(
min(plat_count / total_platforms, 1.0) * w_diversity, 1,
)
else:
diversity_score = 0
breakdown["platform_diversity"] = {
"score": diversity_score,
"max": w_diversity,
# 5. Severity / criticality (template-driven)
sev_label = severity_by_mitre.get(tech.mitre_id)
sev_mult = _severity_factor(sev_label)
severity_score = round(sev_mult * w_severity, 1)
breakdown["severity"] = {
"score": severity_score,
"max": w_severity,
"detail": (
f"{plat_count}/{total_platforms} platforms covered"
if plat_count > 0 else "No platforms tested"
f"Template severity: {sev_label} (factor {sev_mult})"
if sev_label
else "No severity template (default factor)"
),
}
total = min(
test_score + detection_score + d3fend_score
+ freshness_score + diversity_score,
+ recency_score + severity_score,
100,
)
results[tech.id] = {
@@ -265,8 +317,9 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
w_tests = w.tests
w_detection = w.detection_rules
w_d3fend = w.d3fend
w_freshness = w.freshness
w_diversity = w.platform_diversity
w_recency = w.recency
w_severity = w.severity
severity_by_mitre = _max_severity_by_mitre(db)
breakdown = {}
@@ -360,65 +413,50 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
else "No D3FEND mappings",
}
# ── 4. Freshness ──────────────────────────────────────────────
most_recent_test = (
db.query(func.max(Test.red_validated_at))
.filter(
Test.technique_id == technique.id,
Test.state == TestState.validated,
)
.scalar()
)
# ── 4. Recency ────────────────────────────────────────────────
most_recent_test = None
for t in validated_tests:
candidate = t.blue_validated_at or t.red_validated_at or t.created_at
if candidate and (most_recent_test is None or candidate > most_recent_test):
most_recent_test = candidate
now = datetime.utcnow()
recency_mult = _recency_factor(most_recent_test)
recency_score = round(recency_mult * w_recency, 1)
if most_recent_test:
days_ago = (now - most_recent_test).days
if days_ago < 90:
freshness_pct = 1.0
elif days_ago < 180:
freshness_pct = 0.5
else:
freshness_pct = 0.0
freshness_score = round(freshness_pct * w_freshness, 1)
freshness_detail = f"Last test {days_ago} days ago"
days_ago = (
datetime.now(timezone.utc) - (
most_recent_test.replace(tzinfo=timezone.utc)
if most_recent_test.tzinfo is None
else most_recent_test.astimezone(timezone.utc)
)
).days
recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})"
else:
freshness_pct = 0
freshness_score = 0
freshness_detail = "No validated tests"
recency_detail = "No validated tests"
breakdown["freshness"] = {
"score": freshness_score,
"max": w_freshness,
"detail": freshness_detail,
breakdown["recency"] = {
"score": recency_score,
"max": w_recency,
"detail": recency_detail,
}
# ── 5. Platform diversity ─────────────────────────────────────
available_platforms = technique.platforms or []
total_platforms = len(available_platforms) if available_platforms else 3
tested_platforms = set()
for t in validated_tests:
if t.platform:
tested_platforms.add(t.platform.lower())
if total_platforms > 0 and tested_platforms:
diversity_ratio = min(len(tested_platforms) / total_platforms, 1.0)
diversity_score = round(diversity_ratio * w_diversity, 1)
else:
diversity_ratio = 0
diversity_score = 0
breakdown["platform_diversity"] = {
"score": diversity_score,
"max": w_diversity,
"detail": f"{len(tested_platforms)}/{total_platforms} platforms covered"
if tested_platforms
else "No platforms tested",
# ── 5. Severity ───────────────────────────────────────────────
sev_label = severity_by_mitre.get(technique.mitre_id)
sev_mult = _severity_factor(sev_label)
severity_score = round(sev_mult * w_severity, 1)
breakdown["severity"] = {
"score": severity_score,
"max": w_severity,
"detail": (
f"Template severity: {sev_label} (factor {sev_mult})"
if sev_label
else "No severity template (default factor)"
),
}
# ── Total ─────────────────────────────────────────────────────
total = min(
test_score + detection_score + d3fend_score + freshness_score + diversity_score,
test_score + detection_score + d3fend_score + recency_score + severity_score,
100,
)

View File

@@ -9,7 +9,8 @@ number of SQL queries regardless of technique count.
import logging
import uuid
from datetime import datetime
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -43,6 +44,11 @@ def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"coverage_percentage": getattr(snap, "coverage_percentage", 0.0),
"by_tactic": getattr(snap, "by_tactic", None) or {},
"by_status": getattr(snap, "by_status", None) or {},
"stale_count": getattr(snap, "stale_count", 0),
"never_tested_count": getattr(snap, "never_tested_count", 0),
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
@@ -148,6 +154,13 @@ def create_snapshot(
not_covered_count = 0
in_progress_count = 0
not_evaluated_count = 0
stale_count = 0
never_tested_count = 0
by_tactic: dict[str, dict] = defaultdict(
lambda: {"total": 0, "validated": 0, "partial": 0, "score_sum": 0.0}
)
by_status: dict[str, int] = defaultdict(int)
technique_rows: list[dict] = []
@@ -170,15 +183,43 @@ def create_snapshot(
not_evaluated_count += 1
entry = scores_map.get(tech.id, {})
score = entry.get("total_score", 0)
technique_rows.append({
"technique_id": tech.id,
"mitre_id": tech.mitre_id,
"status": status_value,
"score": entry.get("total_score", 0),
"score": score,
})
by_status[status_value] += 1
tactic_key = tech.tactic or "unknown"
bucket = by_tactic[tactic_key]
bucket["total"] += 1
bucket["score_sum"] += score
if status_value == "validated":
bucket["validated"] += 1
elif status_value == "partial":
bucket["partial"] += 1
if status_value == "not_evaluated":
never_tested_count += 1
if tech.review_required:
stale_count += 1
org_data = calculate_organization_score(db)
org_score = org_data.get("overall_score", 0)
total_techniques = len(techniques) or 1
coverage_pct = round((validated_count / total_techniques) * 100, 1)
by_tactic_out = {
tactic: {
"total": data["total"],
"validated": data["validated"],
"partial": data["partial"],
"average_score": round(data["score_sum"] / data["total"], 1) if data["total"] else 0,
}
for tactic, data in by_tactic.items()
}
snapshot = CoverageSnapshot(
name=name,
@@ -189,6 +230,11 @@ def create_snapshot(
not_covered_count=not_covered_count,
in_progress_count=in_progress_count,
not_evaluated_count=not_evaluated_count,
coverage_percentage=coverage_pct,
by_tactic=by_tactic_out,
by_status=dict(by_status),
stale_count=stale_count,
never_tested_count=never_tested_count,
created_by=user_id,
)
db.add(snapshot)
@@ -320,6 +366,37 @@ def compare_snapshots(
}
# ---------------------------------------------------------------------------
# Coverage evolution (trends)
# ---------------------------------------------------------------------------
def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
"""Return snapshot trend points for the last *months* months."""
cutoff = datetime.now(timezone.utc) - timedelta(days=months * 30)
snapshots = (
db.query(CoverageSnapshot)
.filter(CoverageSnapshot.created_at >= cutoff)
.order_by(CoverageSnapshot.created_at.asc())
.all()
)
return [
{
"date": snap.created_at.isoformat() if snap.created_at else None,
"name": snap.name,
"org_score": snap.organization_score,
"coverage_pct": getattr(snap, "coverage_percentage", 0.0),
"by_tactic": getattr(snap, "by_tactic", None) or {},
"by_status": getattr(snap, "by_status", None) or {},
"stale_count": getattr(snap, "stale_count", 0),
"never_tested_count": getattr(snap, "never_tested_count", 0),
"validated_count": snap.validated_count,
"total_techniques": snap.total_techniques,
}
for snap in snapshots
]
# ---------------------------------------------------------------------------
# Cleanup
# ---------------------------------------------------------------------------