feat(scoring): composite recency decay and severity weights persisted in DB [FASE-5.1]

This commit is contained in:
2026-05-18 15:07:12 +02:00
parent 2ee59d4e18
commit 05b221a22d
13 changed files with 588 additions and 154 deletions

View File

@@ -0,0 +1,117 @@
"""Phase 5: scoring recency/severity columns and snapshot breakdown fields.
Revision ID: b030phase5
Revises: b029phase3
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# Alembic revision identifiers: this migration is "b030phase5" and is applied
# on top of "b029phase3". No branch labels or cross-revision dependencies.
revision: str = "b030phase5"
down_revision: Union[str, None] = "b029phase3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _column_names(table: str) -> set[str]:
    """Return the names of the columns currently present on *table*.

    The schema is reflected through the migration's own bind, which lets
    ``upgrade`` and ``downgrade`` skip steps that are already in place.
    """
    inspector = sa.inspect(op.get_bind())
    names: set[str] = set()
    for column in inspector.get_columns(table):
        names.add(column["name"])
    return names
def upgrade() -> None:
    """Apply the Phase 5 schema changes, skipping any step already in place.

    * ``coverage_snapshots`` gains five NOT NULL breakdown columns, each
      created with a server default.
    * ``scoring_config`` gets ``weight_recency`` and ``weight_severity``:
      the legacy ``weight_freshness`` / ``weight_platform_diversity`` column
      is renamed when it exists, otherwise a fresh column is added.
    * ``scoring_config`` gains a nullable ``updated_by`` foreign key to
      ``users.id`` (``ON DELETE SET NULL``).
    """
    snapshots_table = "coverage_snapshots"
    config_table = "scoring_config"

    # (column name, column type, server default), in creation order.
    snapshot_additions = (
        ("by_tactic", postgresql.JSONB(), "{}"),
        ("by_status", postgresql.JSONB(), "{}"),
        ("stale_count", sa.Integer(), "0"),
        ("never_tested_count", sa.Integer(), "0"),
        ("coverage_percentage", sa.Float(), "0"),
    )
    snap_cols = _column_names(snapshots_table)
    for column_name, column_type, default in snapshot_additions:
        if column_name in snap_cols:
            continue
        op.add_column(
            snapshots_table,
            sa.Column(
                column_name,
                column_type,
                nullable=False,
                server_default=default,
            ),
        )

    cfg_cols = _column_names(config_table)

    if "weight_recency" not in cfg_cols:
        if "weight_freshness" in cfg_cols:
            op.alter_column(
                config_table,
                "weight_freshness",
                new_column_name="weight_recency",
            )
            # Keep the local view of the schema in step with the rename.
            cfg_cols.remove("weight_freshness")
            cfg_cols.add("weight_recency")
        else:
            op.add_column(
                config_table,
                sa.Column(
                    "weight_recency",
                    sa.Float(),
                    nullable=False,
                    server_default="10.0",
                ),
            )

    if "weight_severity" not in cfg_cols:
        if "weight_platform_diversity" in cfg_cols:
            op.alter_column(
                config_table,
                "weight_platform_diversity",
                new_column_name="weight_severity",
            )
        else:
            op.add_column(
                config_table,
                sa.Column(
                    "weight_severity",
                    sa.Float(),
                    nullable=False,
                    server_default="10.0",
                ),
            )

    if "updated_by" not in cfg_cols:
        op.add_column(
            config_table,
            sa.Column(
                "updated_by",
                postgresql.UUID(as_uuid=True),
                sa.ForeignKey("users.id", ondelete="SET NULL"),
                nullable=True,
            ),
        )
def downgrade() -> None:
    """Revert the Phase 5 schema changes, skipping any step already undone.

    ``scoring_config`` loses ``updated_by`` and has ``weight_severity`` /
    ``weight_recency`` renamed back to ``weight_platform_diversity`` /
    ``weight_freshness``. The five breakdown columns are then dropped from
    ``coverage_snapshots`` in the reverse of their creation order.
    """
    cfg_cols = _column_names("scoring_config")
    if "updated_by" in cfg_cols:
        op.drop_column("scoring_config", "updated_by")
    if "weight_severity" in cfg_cols:
        op.alter_column(
            "scoring_config",
            "weight_severity",
            new_column_name="weight_platform_diversity",
        )
    if "weight_recency" in cfg_cols:
        op.alter_column(
            "scoring_config",
            "weight_recency",
            new_column_name="weight_freshness",
        )

    # Reflect the table once: every column below is checked a single time,
    # so dropping one never changes the answer for a later one.
    snap_cols = _column_names("coverage_snapshots")
    for col in (
        "coverage_percentage",
        "never_tested_count",
        "stale_count",
        "by_status",
        "by_tactic",
    ):
        if col in snap_cols:
            op.drop_column("coverage_snapshots", col)

View File

@@ -72,9 +72,12 @@ class Settings(BaseSettings):
# ── Scoring weights (must sum to 100) ──────────────────────────── # ── Scoring weights (must sum to 100) ────────────────────────────
SCORING_WEIGHT_TESTS: int = 40 SCORING_WEIGHT_TESTS: int = 40
SCORING_WEIGHT_DETECTION_RULES: int = 20 SCORING_WEIGHT_DETECTION_RULES: int = 25
SCORING_WEIGHT_D3FEND: int = 15 SCORING_WEIGHT_D3FEND: int = 15
SCORING_WEIGHT_FRESHNESS: int = 15 SCORING_WEIGHT_RECENCY: int = 10
SCORING_WEIGHT_SEVERITY: int = 10
# Legacy env names (mapped in scoring_config_service)
SCORING_WEIGHT_FRESHNESS: int = 10
SCORING_WEIGHT_PLATFORM_DIVERSITY: int = 10 SCORING_WEIGHT_PLATFORM_DIVERSITY: int = 10
class Config: class Config:

View File

@@ -15,16 +15,16 @@ class ScoringWeights:
tests: float tests: float
detection_rules: float detection_rules: float
d3fend: float d3fend: float
freshness: float recency: float
platform_diversity: float severity: float
def __post_init__(self) -> None: def __post_init__(self) -> None:
fields = [ fields = [
self.tests, self.tests,
self.detection_rules, self.detection_rules,
self.d3fend, self.d3fend,
self.freshness, self.recency,
self.platform_diversity, self.severity,
] ]
for f in fields: for f in fields:
if f < 0: if f < 0:
@@ -43,6 +43,15 @@ class ScoringWeights:
tests=40.0, tests=40.0,
detection_rules=25.0, detection_rules=25.0,
d3fend=15.0, d3fend=15.0,
freshness=10.0, recency=10.0,
platform_diversity=10.0, severity=10.0,
) )
# Backward-compatible aliases for older API payloads
@property
def freshness(self) -> float:
    """Legacy name for the ``recency`` weight; returns the same value."""
    return self.recency
@property
def platform_diversity(self) -> float:
    """Legacy name for the ``severity`` weight; returns the same value."""
    return self.severity

View File

@@ -10,7 +10,7 @@ from sqlalchemy import (
Column, String, Float, Integer, DateTime, Column, String, Float, Integer, DateTime,
ForeignKey, Index, func, ForeignKey, Index, func,
) )
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from app.database import Base from app.database import Base
@@ -30,6 +30,11 @@ class CoverageSnapshot(Base):
not_covered_count = Column(Integer, nullable=False) not_covered_count = Column(Integer, nullable=False)
in_progress_count = Column(Integer, nullable=False) in_progress_count = Column(Integer, nullable=False)
not_evaluated_count = Column(Integer, nullable=False) not_evaluated_count = Column(Integer, nullable=False)
coverage_percentage = Column(Float, nullable=False, default=0.0)
by_tactic = Column(JSONB, nullable=False, default=dict)
by_status = Column(JSONB, nullable=False, default=dict)
stale_count = Column(Integer, nullable=False, default=0)
never_tested_count = Column(Integer, nullable=False, default=0)
created_by = Column( created_by = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"), ForeignKey("users.id", ondelete="SET NULL"),

View File

@@ -1,12 +1,8 @@
"""ScoringConfig — single-row table for persisted scoring weights. """ScoringConfig — single-row table for persisted scoring weights."""
Replaces the mutable-settings approach where PATCH /scores/config
mutated the in-process ``Settings`` object (lost on restart).
"""
import uuid import uuid
from sqlalchemy import Column, Float, DateTime, func from sqlalchemy import Column, Float, DateTime, ForeignKey, func
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from app.database import Base from app.database import Base
@@ -17,8 +13,13 @@ class ScoringConfig(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
weight_tests = Column(Float, nullable=False, default=40.0) weight_tests = Column(Float, nullable=False, default=40.0)
weight_detection_rules = Column(Float, nullable=False, default=20.0) weight_detection_rules = Column(Float, nullable=False, default=25.0)
weight_d3fend = Column(Float, nullable=False, default=15.0) weight_d3fend = Column(Float, nullable=False, default=15.0)
weight_freshness = Column(Float, nullable=False, default=15.0) weight_recency = Column(Float, nullable=False, default=10.0)
weight_platform_diversity = Column(Float, nullable=False, default=10.0) weight_severity = Column(Float, nullable=False, default=10.0)
updated_by = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

View File

@@ -113,6 +113,8 @@ class ScoringConfigUpdate(BaseModel):
tests: Optional[float] = None tests: Optional[float] = None
detection_rules: Optional[float] = None detection_rules: Optional[float] = None
d3fend: Optional[float] = None d3fend: Optional[float] = None
recency: Optional[float] = None
severity: Optional[float] = None
freshness: Optional[float] = None freshness: Optional[float] = None
platform_diversity: Optional[float] = None platform_diversity: Optional[float] = None
@@ -134,8 +136,11 @@ def update_scoring_config(
tests=payload.tests, tests=payload.tests,
detection_rules=payload.detection_rules, detection_rules=payload.detection_rules,
d3fend=payload.d3fend, d3fend=payload.d3fend,
recency=payload.recency,
severity=payload.severity,
freshness=payload.freshness, freshness=payload.freshness,
platform_diversity=payload.platform_diversity, platform_diversity=payload.platform_diversity,
updated_by=current_user.id,
) )
uow.commit() uow.commit()

View File

@@ -21,6 +21,7 @@ from app.services.snapshot_service import (
create_snapshot, create_snapshot,
compare_snapshots, compare_snapshots,
cleanup_old_snapshots, cleanup_old_snapshots,
get_coverage_evolution,
serialize_snapshot_summary, serialize_snapshot_summary,
list_snapshots as list_snapshots_svc, list_snapshots as list_snapshots_svc,
get_snapshot_or_raise, get_snapshot_or_raise,
@@ -82,6 +83,21 @@ def create_snapshot_endpoint(
return serialize_snapshot_summary(snapshot) return serialize_snapshot_summary(snapshot)
# ---------------------------------------------------------------------------
# GET /snapshots/evolution — Coverage trend over time
# ---------------------------------------------------------------------------
@router.get("/evolution")
def coverage_evolution(
    months: int = Query(12, ge=1, le=36),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the snapshot trend points covering the last *months* months.

    ``months`` is validated to the range 1-36 and defaults to 12.
    ``current_user`` is not read in the body; it is declared so the request
    is resolved through ``get_current_user`` (presumably to require an
    authenticated caller; confirm against that dependency).
    """
    trend_points = get_coverage_evolution(db, months=months)
    return trend_points
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /snapshots/compare — Compare two snapshots # GET /snapshots/compare — Compare two snapshots
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -1,14 +1,8 @@
"""Scoring configuration persistence service. """Scoring configuration persistence service."""
Reads and writes scoring weights from the ``scoring_config`` table.
Falls back to environment-variable defaults (from ``Settings``) when
no row has been persisted yet.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations from __future__ import annotations
import uuid
from typing import Any from typing import Any
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -18,29 +12,39 @@ from app.domain.value_objects.scoring_weights import ScoringWeights
from app.models.scoring_config import ScoringConfig from app.models.scoring_config import ScoringConfig
def get_scoring_weights(db: Session) -> ScoringWeights: def _row_recency(row: ScoringConfig) -> float:
"""Return the active scoring weights. return float(getattr(row, "weight_recency", None) or getattr(row, "weight_freshness", 10.0))
Reads the single ``scoring_config`` row. If the table is empty
(first run or migration just applied), falls back to the values def _row_severity(row: ScoringConfig) -> float:
from the environment / ``Settings``. return float(
""" getattr(row, "weight_severity", None)
or getattr(row, "weight_platform_diversity", 10.0)
)
def get_scoring_weights(db: Session) -> ScoringWeights:
"""Return the active scoring weights from the database or env defaults."""
row = db.query(ScoringConfig).first() row = db.query(ScoringConfig).first()
if row is not None: if row is not None:
return ScoringWeights( return ScoringWeights(
tests=row.weight_tests, tests=row.weight_tests,
detection_rules=row.weight_detection_rules, detection_rules=row.weight_detection_rules,
d3fend=row.weight_d3fend, d3fend=row.weight_d3fend,
freshness=row.weight_freshness, recency=_row_recency(row),
platform_diversity=row.weight_platform_diversity, severity=_row_severity(row),
) )
return ScoringWeights( return ScoringWeights(
tests=float(settings.SCORING_WEIGHT_TESTS), tests=float(settings.SCORING_WEIGHT_TESTS),
detection_rules=float(settings.SCORING_WEIGHT_DETECTION_RULES), detection_rules=float(settings.SCORING_WEIGHT_DETECTION_RULES),
d3fend=float(settings.SCORING_WEIGHT_D3FEND), d3fend=float(settings.SCORING_WEIGHT_D3FEND),
freshness=float(settings.SCORING_WEIGHT_FRESHNESS), recency=float(
platform_diversity=float(settings.SCORING_WEIGHT_PLATFORM_DIVERSITY), getattr(settings, "SCORING_WEIGHT_RECENCY", settings.SCORING_WEIGHT_FRESHNESS)
),
severity=float(
getattr(settings, "SCORING_WEIGHT_SEVERITY", settings.SCORING_WEIGHT_PLATFORM_DIVERSITY)
),
) )
@@ -50,25 +54,26 @@ def update_scoring_weights(
tests: float | None = None, tests: float | None = None,
detection_rules: float | None = None, detection_rules: float | None = None,
d3fend: float | None = None, d3fend: float | None = None,
recency: float | None = None,
severity: float | None = None,
freshness: float | None = None, freshness: float | None = None,
platform_diversity: float | None = None, platform_diversity: float | None = None,
updated_by: uuid.UUID | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Upsert scoring weights into the database. """Upsert scoring weights. Does not commit."""
if freshness is not None and recency is None:
recency = freshness
if platform_diversity is not None and severity is None:
severity = platform_diversity
Only provided fields are overwritten; ``None`` values keep the
current (or default) value. Validates via ``ScoringWeights``
before persisting.
Returns a dict with ``weights`` and ``total``.
"""
current = get_scoring_weights(db) current = get_scoring_weights(db)
new = ScoringWeights( new = ScoringWeights(
tests=tests if tests is not None else current.tests, tests=tests if tests is not None else current.tests,
detection_rules=detection_rules if detection_rules is not None else current.detection_rules, detection_rules=detection_rules if detection_rules is not None else current.detection_rules,
d3fend=d3fend if d3fend is not None else current.d3fend, d3fend=d3fend if d3fend is not None else current.d3fend,
freshness=freshness if freshness is not None else current.freshness, recency=recency if recency is not None else current.recency,
platform_diversity=platform_diversity if platform_diversity is not None else current.platform_diversity, severity=severity if severity is not None else current.severity,
) )
row = db.query(ScoringConfig).first() row = db.query(ScoringConfig).first()
@@ -79,10 +84,17 @@ def update_scoring_weights(
row.weight_tests = new.tests row.weight_tests = new.tests
row.weight_detection_rules = new.detection_rules row.weight_detection_rules = new.detection_rules
row.weight_d3fend = new.d3fend row.weight_d3fend = new.d3fend
row.weight_freshness = new.freshness if hasattr(row, "weight_recency"):
row.weight_platform_diversity = new.platform_diversity row.weight_recency = new.recency
elif hasattr(row, "weight_freshness"):
row.weight_freshness = new.recency
if hasattr(row, "weight_severity"):
row.weight_severity = new.severity
elif hasattr(row, "weight_platform_diversity"):
row.weight_platform_diversity = new.severity
if updated_by is not None and hasattr(row, "updated_by"):
row.updated_by = updated_by
# Does not commit; caller (router) uses UnitOfWork.
return _weights_dict(new) return _weights_dict(new)
@@ -96,10 +108,15 @@ def _weights_dict(w: ScoringWeights) -> dict[str, Any]:
"tests": w.tests, "tests": w.tests,
"detection_rules": w.detection_rules, "detection_rules": w.detection_rules,
"d3fend": w.d3fend, "d3fend": w.d3fend,
"freshness": w.freshness, "recency": w.recency,
"platform_diversity": w.platform_diversity, "severity": w.severity,
# Legacy keys for older clients
"freshness": w.recency,
"platform_diversity": w.severity,
} }
return { return {
"weights": weights, "weights": weights,
"total": sum(weights.values()), "total": sum(
[w.tests, w.detection_rules, w.d3fend, w.recency, w.severity]
),
} }

View File

@@ -9,7 +9,7 @@ fixed number of aggregated queries so that organisation-wide calculations
never produce N+1 traffic. never produce N+1 traffic.
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Optional from typing import Optional
from sqlalchemy import case, func from sqlalchemy import case, func
@@ -25,6 +25,61 @@ from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.enums import TestState, TestResult from app.models.enums import TestState, TestResult
from app.services.scoring_config_service import get_scoring_weights from app.services.scoring_config_service import get_scoring_weights
# Multiplier applied to the severity weight, keyed by lower-case severity label.
# Labels missing from this table fall back to 0.7 in _severity_factor.
_SEVERITY_FACTORS: dict[str, float] = {
"critical": 1.0,
"high": 0.85,
"medium": 0.65,
"low": 0.5,
}
def _recency_factor(last_tested: datetime | None) -> float:
"""Decay factor: 1.0 when recent, decreasing over time."""
if not last_tested:
return 0.0
now = datetime.now(timezone.utc)
tested = last_tested
if tested.tzinfo is None:
tested = tested.replace(tzinfo=timezone.utc)
days_ago = (now - tested).days
if days_ago <= 90:
return 1.0
if days_ago <= 180:
return 0.8
if days_ago <= 365:
return 0.5
return 0.2
def _severity_factor(severity_label: str | None) -> float:
    """Map a template severity label to a 0–1 multiplier.

    The lookup is case-insensitive. A missing, empty or unrecognised label
    yields the default factor of 0.7.
    """
    default_factor = 0.7
    if severity_label:
        return _SEVERITY_FACTORS.get(severity_label.lower(), default_factor)
    return default_factor
def _max_severity_by_mitre(db: Session) -> dict[str, str]:
    """Return the highest severity label per MITRE id among active templates.

    Only templates that are active and have a non-null severity are read.
    Labels are ranked critical > high > medium > low, compared
    case-insensitively; an unrecognised label ranks below all of them. The
    label is returned exactly as stored on the winning template.
    """
    from app.models.test_template import TestTemplate

    rank = {"critical": 4, "high": 3, "medium": 2, "low": 1}

    query = db.query(
        TestTemplate.mitre_technique_id,
        TestTemplate.severity,
    ).filter(
        TestTemplate.is_active == True,  # noqa: E712
        TestTemplate.severity.isnot(None),
    )

    strongest: dict[str, str] = {}
    for mitre_id, severity in query.all():
        if not (mitre_id and severity):
            continue
        incumbent = strongest.get(mitre_id)
        if incumbent is not None and (
            rank.get(severity.lower(), 0) <= rank.get(incumbent.lower(), 0)
        ):
            # The label already recorded ranks at least as high; keep it.
            continue
        strongest[mitre_id] = severity
    return strongest
# ── Bulk scoring helpers (5 queries for ALL techniques) ─────────────── # ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
@@ -45,8 +100,15 @@ def bulk_technique_scores(db: Session) -> dict:
w_tests = w.tests w_tests = w.tests
w_detection = w.detection_rules w_detection = w.detection_rules
w_d3fend = w.d3fend w_d3fend = w.d3fend
w_freshness = w.freshness w_recency = w.recency
w_diversity = w.platform_diversity w_severity = w.severity
severity_by_mitre = _max_severity_by_mitre(db)
last_validated = func.coalesce(
Test.blue_validated_at,
Test.red_validated_at,
Test.created_at,
)
# Q1: test stats grouped by technique_id # Q1: test stats grouped by technique_id
test_rows = ( test_rows = (
@@ -56,8 +118,7 @@ def bulk_technique_scores(db: Session) -> dict:
func.count( func.count(
case((Test.detection_result == TestResult.detected, Test.id)) case((Test.detection_result == TestResult.detected, Test.id))
).label("detected_count"), ).label("detected_count"),
func.max(Test.red_validated_at).label("latest_validated_at"), func.max(last_validated).label("latest_validated_at"),
func.count(func.distinct(Test.platform)).label("platform_count"),
) )
.filter(Test.state == TestState.validated) .filter(Test.state == TestState.validated)
.group_by(Test.technique_id) .group_by(Test.technique_id)
@@ -70,7 +131,6 @@ def bulk_technique_scores(db: Session) -> dict:
"validated": row.validated_count, "validated": row.validated_count,
"detected": row.detected_count, "detected": row.detected_count,
"latest_validated_at": row.latest_validated_at, "latest_validated_at": row.latest_validated_at,
"platform_count": row.platform_count,
} }
# Q2: active detection rules per mitre_id # Q2: active detection rules per mitre_id
@@ -114,7 +174,6 @@ def bulk_technique_scores(db: Session) -> dict:
# Q5: all techniques # Q5: all techniques
techniques = db.query(Technique).all() techniques = db.query(Technique).all()
now = datetime.utcnow()
results: dict = {} results: dict = {}
for tech in techniques: for tech in techniques:
@@ -122,7 +181,6 @@ def bulk_technique_scores(db: Session) -> dict:
validated = ts.get("validated", 0) validated = ts.get("validated", 0)
detected = ts.get("detected", 0) detected = ts.get("detected", 0)
latest_at = ts.get("latest_validated_at") latest_at = ts.get("latest_validated_at")
plat_count = ts.get("platform_count", 0)
breakdown = {} breakdown = {}
@@ -177,47 +235,41 @@ def bulk_technique_scores(db: Session) -> dict:
), ),
} }
# 4. Freshness # 4. Recency decay
recency_mult = _recency_factor(latest_at)
recency_score = round(recency_mult * w_recency, 1)
if latest_at: if latest_at:
days_ago = (now - latest_at).days tested = latest_at
if days_ago < 90: if tested.tzinfo is None:
freshness_pct = 1.0 days_ago = (datetime.utcnow() - tested).days
elif days_ago < 180:
freshness_pct = 0.5
else: else:
freshness_pct = 0.0 days_ago = (datetime.now(timezone.utc) - tested.astimezone(timezone.utc)).days
freshness_score = round(freshness_pct * w_freshness, 1) recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})"
freshness_detail = f"Last test {days_ago} days ago"
else: else:
freshness_score = 0 recency_detail = "No validated tests"
freshness_detail = "No validated tests" breakdown["recency"] = {
breakdown["freshness"] = { "score": recency_score,
"score": freshness_score, "max": w_recency,
"max": w_freshness, "detail": recency_detail,
"detail": freshness_detail,
} }
# 5. Platform diversity # 5. Severity / criticality (template-driven)
available = tech.platforms or [] sev_label = severity_by_mitre.get(tech.mitre_id)
total_platforms = len(available) if available else 3 sev_mult = _severity_factor(sev_label)
if total_platforms > 0 and plat_count > 0: severity_score = round(sev_mult * w_severity, 1)
diversity_score = round( breakdown["severity"] = {
min(plat_count / total_platforms, 1.0) * w_diversity, 1, "score": severity_score,
) "max": w_severity,
else:
diversity_score = 0
breakdown["platform_diversity"] = {
"score": diversity_score,
"max": w_diversity,
"detail": ( "detail": (
f"{plat_count}/{total_platforms} platforms covered" f"Template severity: {sev_label} (factor {sev_mult})"
if plat_count > 0 else "No platforms tested" if sev_label
else "No severity template (default factor)"
), ),
} }
total = min( total = min(
test_score + detection_score + d3fend_score test_score + detection_score + d3fend_score
+ freshness_score + diversity_score, + recency_score + severity_score,
100, 100,
) )
results[tech.id] = { results[tech.id] = {
@@ -265,8 +317,9 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
w_tests = w.tests w_tests = w.tests
w_detection = w.detection_rules w_detection = w.detection_rules
w_d3fend = w.d3fend w_d3fend = w.d3fend
w_freshness = w.freshness w_recency = w.recency
w_diversity = w.platform_diversity w_severity = w.severity
severity_by_mitre = _max_severity_by_mitre(db)
breakdown = {} breakdown = {}
@@ -360,65 +413,50 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
else "No D3FEND mappings", else "No D3FEND mappings",
} }
# ── 4. Freshness ────────────────────────────────────────────── # ── 4. Recency ────────────────────────────────────────────────
most_recent_test = ( most_recent_test = None
db.query(func.max(Test.red_validated_at)) for t in validated_tests:
.filter( candidate = t.blue_validated_at or t.red_validated_at or t.created_at
Test.technique_id == technique.id, if candidate and (most_recent_test is None or candidate > most_recent_test):
Test.state == TestState.validated, most_recent_test = candidate
)
.scalar()
)
now = datetime.utcnow() recency_mult = _recency_factor(most_recent_test)
recency_score = round(recency_mult * w_recency, 1)
if most_recent_test: if most_recent_test:
days_ago = (now - most_recent_test).days days_ago = (
if days_ago < 90: datetime.now(timezone.utc) - (
freshness_pct = 1.0 most_recent_test.replace(tzinfo=timezone.utc)
elif days_ago < 180: if most_recent_test.tzinfo is None
freshness_pct = 0.5 else most_recent_test.astimezone(timezone.utc)
else: )
freshness_pct = 0.0 ).days
freshness_score = round(freshness_pct * w_freshness, 1) recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})"
freshness_detail = f"Last test {days_ago} days ago"
else: else:
freshness_pct = 0 recency_detail = "No validated tests"
freshness_score = 0
freshness_detail = "No validated tests"
breakdown["freshness"] = { breakdown["recency"] = {
"score": freshness_score, "score": recency_score,
"max": w_freshness, "max": w_recency,
"detail": freshness_detail, "detail": recency_detail,
} }
# ── 5. Platform diversity ───────────────────────────────────── # ── 5. Severity ───────────────────────────────────────────────
available_platforms = technique.platforms or [] sev_label = severity_by_mitre.get(technique.mitre_id)
total_platforms = len(available_platforms) if available_platforms else 3 sev_mult = _severity_factor(sev_label)
severity_score = round(sev_mult * w_severity, 1)
tested_platforms = set() breakdown["severity"] = {
for t in validated_tests: "score": severity_score,
if t.platform: "max": w_severity,
tested_platforms.add(t.platform.lower()) "detail": (
f"Template severity: {sev_label} (factor {sev_mult})"
if total_platforms > 0 and tested_platforms: if sev_label
diversity_ratio = min(len(tested_platforms) / total_platforms, 1.0) else "No severity template (default factor)"
diversity_score = round(diversity_ratio * w_diversity, 1) ),
else:
diversity_ratio = 0
diversity_score = 0
breakdown["platform_diversity"] = {
"score": diversity_score,
"max": w_diversity,
"detail": f"{len(tested_platforms)}/{total_platforms} platforms covered"
if tested_platforms
else "No platforms tested",
} }
# ── Total ───────────────────────────────────────────────────── # ── Total ─────────────────────────────────────────────────────
total = min( total = min(
test_score + detection_score + d3fend_score + freshness_score + diversity_score, test_score + detection_score + d3fend_score + recency_score + severity_score,
100, 100,
) )

View File

@@ -9,7 +9,8 @@ number of SQL queries regardless of technique count.
import logging import logging
import uuid import uuid
from datetime import datetime from collections import defaultdict
from datetime import datetime, timedelta, timezone
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -43,6 +44,11 @@ def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"not_covered_count": snap.not_covered_count, "not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count, "in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count, "not_evaluated_count": snap.not_evaluated_count,
"coverage_percentage": getattr(snap, "coverage_percentage", 0.0),
"by_tactic": getattr(snap, "by_tactic", None) or {},
"by_status": getattr(snap, "by_status", None) or {},
"stale_count": getattr(snap, "stale_count", 0),
"never_tested_count": getattr(snap, "never_tested_count", 0),
"created_by": str(snap.created_by) if snap.created_by else None, "created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None, "created_at": snap.created_at.isoformat() if snap.created_at else None,
} }
@@ -148,6 +154,13 @@ def create_snapshot(
not_covered_count = 0 not_covered_count = 0
in_progress_count = 0 in_progress_count = 0
not_evaluated_count = 0 not_evaluated_count = 0
stale_count = 0
never_tested_count = 0
by_tactic: dict[str, dict] = defaultdict(
lambda: {"total": 0, "validated": 0, "partial": 0, "score_sum": 0.0}
)
by_status: dict[str, int] = defaultdict(int)
technique_rows: list[dict] = [] technique_rows: list[dict] = []
@@ -170,15 +183,43 @@ def create_snapshot(
not_evaluated_count += 1 not_evaluated_count += 1
entry = scores_map.get(tech.id, {}) entry = scores_map.get(tech.id, {})
score = entry.get("total_score", 0)
technique_rows.append({ technique_rows.append({
"technique_id": tech.id, "technique_id": tech.id,
"mitre_id": tech.mitre_id, "mitre_id": tech.mitre_id,
"status": status_value, "status": status_value,
"score": entry.get("total_score", 0), "score": score,
}) })
by_status[status_value] += 1
tactic_key = tech.tactic or "unknown"
bucket = by_tactic[tactic_key]
bucket["total"] += 1
bucket["score_sum"] += score
if status_value == "validated":
bucket["validated"] += 1
elif status_value == "partial":
bucket["partial"] += 1
if status_value == "not_evaluated":
never_tested_count += 1
if tech.review_required:
stale_count += 1
org_data = calculate_organization_score(db) org_data = calculate_organization_score(db)
org_score = org_data.get("overall_score", 0) org_score = org_data.get("overall_score", 0)
total_techniques = len(techniques) or 1
coverage_pct = round((validated_count / total_techniques) * 100, 1)
by_tactic_out = {
tactic: {
"total": data["total"],
"validated": data["validated"],
"partial": data["partial"],
"average_score": round(data["score_sum"] / data["total"], 1) if data["total"] else 0,
}
for tactic, data in by_tactic.items()
}
snapshot = CoverageSnapshot( snapshot = CoverageSnapshot(
name=name, name=name,
@@ -189,6 +230,11 @@ def create_snapshot(
not_covered_count=not_covered_count, not_covered_count=not_covered_count,
in_progress_count=in_progress_count, in_progress_count=in_progress_count,
not_evaluated_count=not_evaluated_count, not_evaluated_count=not_evaluated_count,
coverage_percentage=coverage_pct,
by_tactic=by_tactic_out,
by_status=dict(by_status),
stale_count=stale_count,
never_tested_count=never_tested_count,
created_by=user_id, created_by=user_id,
) )
db.add(snapshot) db.add(snapshot)
@@ -320,6 +366,37 @@ def compare_snapshots(
} }
# ---------------------------------------------------------------------------
# Coverage evolution (trends)
# ---------------------------------------------------------------------------
def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
    """Return one trend point per snapshot created in the last *months* months.

    A month is approximated as 30 days. Points are ordered oldest first.
    The Phase 5 breakdown fields are read defensively and fall back to
    zero / empty values when a snapshot does not carry them.
    """
    window_start = datetime.now(timezone.utc) - timedelta(days=months * 30)
    query = (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.created_at >= window_start)
        .order_by(CoverageSnapshot.created_at.asc())
    )

    points: list[dict] = []
    for snap in query.all():
        created_at = snap.created_at
        points.append(
            {
                "date": created_at.isoformat() if created_at else None,
                "name": snap.name,
                "org_score": snap.organization_score,
                "coverage_pct": getattr(snap, "coverage_percentage", 0.0),
                "by_tactic": getattr(snap, "by_tactic", None) or {},
                "by_status": getattr(snap, "by_status", None) or {},
                "stale_count": getattr(snap, "stale_count", 0),
                "never_tested_count": getattr(snap, "never_tested_count", 0),
                "validated_count": snap.validated_count,
                "total_techniques": snap.total_techniques,
            }
        )
    return points
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Cleanup # Cleanup
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -0,0 +1,102 @@
"""Tests for Phase 5 snapshot evolution and breakdown fields."""
from datetime import datetime, timedelta, timezone
from app.models.coverage_snapshot import CoverageSnapshot
from app.models.enums import TechniqueStatus
from app.models.technique import Technique
from app.services.snapshot_service import create_snapshot, get_coverage_evolution
def test_create_snapshot_includes_tactic_breakdown(db, admin_user):
    """A freshly created snapshot exposes tactic and status breakdowns."""
    # Seed a single validated technique under the "execution" tactic.
    db.add(
        Technique(
            mitre_id="T1059",
            name="Command and Scripting Interpreter",
            tactic="execution",
            status_global=TechniqueStatus.validated,
        )
    )
    db.commit()

    snapshot = create_snapshot(db, name="Phase5 test", user_id=admin_user.id)

    assert snapshot.by_tactic
    assert "execution" in snapshot.by_tactic
    assert snapshot.by_status
    assert snapshot.coverage_percentage >= 0
    assert snapshot.never_tested_count >= 0
def test_coverage_evolution_filters_by_months(db, admin_user):
    """Only snapshots inside the requested window appear in the evolution."""
    reference = datetime.now(timezone.utc)

    def make_snapshot(name, score, validated, not_covered, not_evaluated, stale, never_tested):
        # Both fixtures share the same totals; coverage_percentage mirrors
        # organization_score in each of them.
        return CoverageSnapshot(
            name=name,
            organization_score=score,
            total_techniques=10,
            validated_count=validated,
            partial_count=1,
            not_covered_count=not_covered,
            in_progress_count=1,
            not_evaluated_count=not_evaluated,
            coverage_percentage=score,
            by_tactic={"execution": {"total": 10, "validated": validated}},
            by_status={"validated": validated},
            stale_count=stale,
            never_tested_count=never_tested,
            created_by=admin_user.id,
        )

    # Backdated 400 days, so it falls outside the 6-month window queried below.
    outdated = make_snapshot("old", 50.0, 5, 1, 2, 0, 2)
    outdated.created_at = reference - timedelta(days=400)
    # No explicit created_at here; the test leaves it to the model.
    current = make_snapshot("recent", 70.0, 7, 0, 1, 1, 1)

    db.add_all([outdated, current])
    db.commit()

    evolution = get_coverage_evolution(db, months=6)

    assert len(evolution) == 1
    point = evolution[0]
    assert point["org_score"] == 70.0
    assert point["coverage_pct"] == 70.0
    assert point["stale_count"] == 1
def test_evolution_endpoint(client, db, admin_user, admin_token):
    """The evolution endpoint answers 200 with a non-empty list of points."""
    db.add(
        CoverageSnapshot(
            name="api-evolution",
            organization_score=60.0,
            total_techniques=5,
            validated_count=3,
            partial_count=0,
            not_covered_count=0,
            in_progress_count=0,
            not_evaluated_count=2,
            coverage_percentage=60.0,
            by_tactic={},
            by_status={"validated": 3},
            stale_count=0,
            never_tested_count=2,
        )
    )
    db.commit()

    auth_headers = {"Authorization": f"Bearer {admin_token}"}
    response = client.get(
        "/api/v1/snapshots/evolution?months=12",
        headers=auth_headers,
    )

    assert response.status_code == 200
    payload = response.json()
    assert isinstance(payload, list)
    assert len(payload) >= 1
    assert "org_score" in payload[0]

View File

@@ -18,7 +18,10 @@ from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqu
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
from app.models.audit import AuditLog from app.models.audit import AuditLog
from app.models.enums import TestState, TestResult, TechniqueStatus from app.models.enums import TestState, TestResult, TechniqueStatus
from app.models.scoring_config import ScoringConfig
from app.services.scoring_config_service import get_scoring_weights, update_scoring_weights
from app.services.scoring_service import ( from app.services.scoring_service import (
_recency_factor,
calculate_technique_score, calculate_technique_score,
calculate_tactic_score, calculate_tactic_score,
calculate_organization_score, calculate_organization_score,
@@ -175,9 +178,11 @@ class TestScoring:
assert result["breakdown"]["tests_validated"]["score"] > 0 assert result["breakdown"]["tests_validated"]["score"] > 0
def test_technique_score_no_tests(self, db, sample_technique_no_tests): def test_technique_score_no_tests(self, db, sample_technique_no_tests):
"""Técnica sin tests → score 0.""" """Técnica sin tests → solo puede puntuar severidad por defecto."""
result = calculate_technique_score(sample_technique_no_tests, db) result = calculate_technique_score(sample_technique_no_tests, db)
assert result["total_score"] == 0 assert result["breakdown"]["tests_validated"]["score"] == 0
assert result["breakdown"]["recency"]["score"] == 0
assert result["total_score"] == result["breakdown"]["severity"]["score"]
def test_technique_score_partial_detection(self, db, sample_technique, validated_tests): def test_technique_score_partial_detection(self, db, sample_technique, validated_tests):
"""Técnica con detección parcial → score intermedio.""" """Técnica con detección parcial → score intermedio."""
@@ -187,8 +192,8 @@ class TestScoring:
breakdown = result["breakdown"] breakdown = result["breakdown"]
assert "2/3" in breakdown["tests_validated"]["detail"] assert "2/3" in breakdown["tests_validated"]["detail"]
def test_technique_score_freshness_penalty(self, db, sample_technique, admin_user): def test_technique_score_recency_penalty(self, db, sample_technique, admin_user):
"""Tests > 180 días → penalización en freshness.""" """Tests > 180 días → factor de recencia reducido."""
old_date = datetime.utcnow() - timedelta(days=200) old_date = datetime.utcnow() - timedelta(days=200)
test = Test( test = Test(
technique_id=sample_technique.id, technique_id=sample_technique.id,
@@ -198,14 +203,53 @@ class TestScoring:
created_by=admin_user.id, created_by=admin_user.id,
platform="windows", platform="windows",
red_validated_at=old_date, red_validated_at=old_date,
blue_validated_at=old_date,
) )
db.add(test) db.add(test)
db.commit() db.commit()
result = calculate_technique_score(sample_technique, db) result = calculate_technique_score(sample_technique, db)
# Freshness should be 0 for tests > 180 days old assert result["breakdown"]["recency"]["score"] == 5.0 # 0.5 * 10 (181365 días)
assert result["breakdown"]["freshness"]["score"] == 0 assert "200" in result["breakdown"]["recency"]["detail"]
assert "200" in result["breakdown"]["freshness"]["detail"]
def test_recency_recent_scores_higher_than_old(self, db, sample_technique, admin_user):
    """Same detection result: a recent test scores higher than a year-old one."""
    reference = datetime.utcnow()
    recent_at = reference - timedelta(days=1)
    old_at = reference - timedelta(days=400)

    # Two validated, detected tests that differ only in validation age.
    for age_days, validated_at in ((1, recent_at), (400, old_at)):
        aged_test = Test(
            technique_id=sample_technique.id,
            name=f"Recency {age_days}",
            state=TestState.validated,
            detection_result=TestResult.detected,
            created_by=admin_user.id,
            platform="windows",
            red_validated_at=validated_at,
            blue_validated_at=validated_at,
        )
        db.add(aged_test)
    db.commit()

    result = calculate_technique_score(sample_technique, db)

    assert _recency_factor(recent_at) == 1.0
    assert _recency_factor(old_at) == 0.2
    assert result["breakdown"]["recency"]["score"] == 10.0
def test_scoring_weights_persist_in_database(self, db, sample_technique, validated_tests):
    """Changing the weights in the database shows up in the score breakdown."""
    first_weights = dict(tests=50, detection_rules=20, d3fend=10, recency=10, severity=10)
    second_weights = dict(tests=40, detection_rules=25, d3fend=15, recency=10, severity=10)

    update_scoring_weights(db, **first_weights)
    db.commit()

    breakdown = calculate_technique_score(sample_technique, db)["breakdown"]
    assert breakdown["tests_validated"]["max"] == 50

    stored = db.query(ScoringConfig).first()
    assert stored is not None
    assert stored.weight_tests == 50

    # A second update must be readable back through the service.
    update_scoring_weights(db, **second_weights)
    db.commit()
    assert get_scoring_weights(db).tests == 40
def test_scoring_weights_configurable(self, db, sample_technique, validated_tests): def test_scoring_weights_configurable(self, db, sample_technique, validated_tests):
"""Scoring weights are reflected in the breakdown max values.""" """Scoring weights are reflected in the breakdown max values."""

View File

@@ -84,13 +84,13 @@ class TestScoringWeights:
assert w.tests == 40.0 assert w.tests == 40.0
assert w.detection_rules == 25.0 assert w.detection_rules == 25.0
assert w.d3fend == 15.0 assert w.d3fend == 15.0
assert w.freshness == 10.0 assert w.recency == 10.0
assert w.platform_diversity == 10.0 assert w.severity == 10.0
def test_valid_custom(self): def test_valid_custom(self):
w = ScoringWeights( w = ScoringWeights(
tests=50, detection_rules=20, d3fend=10, tests=50, detection_rules=20, d3fend=10,
freshness=10, platform_diversity=10, recency=10, severity=10,
) )
assert w.tests == 50 assert w.tests == 50
@@ -98,14 +98,14 @@ class TestScoringWeights:
with pytest.raises(ValueError, match="sum to 100"): with pytest.raises(ValueError, match="sum to 100"):
ScoringWeights( ScoringWeights(
tests=50, detection_rules=20, d3fend=10, tests=50, detection_rules=20, d3fend=10,
freshness=10, platform_diversity=5, recency=10, severity=5,
) )
def test_invalid_negative_weight(self): def test_invalid_negative_weight(self):
with pytest.raises(ValueError, match="non-negative"): with pytest.raises(ValueError, match="non-negative"):
ScoringWeights( ScoringWeights(
tests=-10, detection_rules=40, d3fend=30, tests=-10, detection_rules=40, d3fend=30,
freshness=20, platform_diversity=20, recency=20, severity=20,
) )
def test_immutable(self): def test_immutable(self):