fix(risk): correct TechniqueConfidenceScore fields, TechniqueStatus values, Test.result usage
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled

This commit is contained in:
kitos
2026-05-20 15:58:03 +02:00
parent 362a17aa1b
commit 084ea4c0b2

View File

@@ -5,7 +5,6 @@ from datetime import datetime, timedelta
from typing import List, Optional from typing import List, Optional
from uuid import UUID from uuid import UUID
from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError from app.domain.errors import EntityNotFoundError
@@ -14,25 +13,22 @@ from app.models.technique import Technique
from app.models.threat_actor import ThreatActorTechnique from app.models.threat_actor import ThreatActorTechnique
from app.models.osint_item import OsintItem from app.models.osint_item import OsintItem
from app.models.test import Test from app.models.test import Test
from app.models.test_detection_result import TestDetectionResult
from app.models.detection_lifecycle import ( from app.models.detection_lifecycle import (
TechniqueConfidenceScore, TechniqueConfidenceScore,
DetectionTechniqueMapping, DetectionTechniqueMapping,
DetectionConfidence,
) )
from app.models.enums import TechniqueStatus from app.models.enums import TechniqueStatus, TestResult
# ── Scoring constants ────────────────────────────────────────────────────────── # ── Scoring weights & thresholds ───────────────────────────────────────────────
WEIGHT_DETECTION_GAP = 0.35 WEIGHT_DETECTION_GAP = 0.35
WEIGHT_THREAT_ACTORS = 0.30 WEIGHT_THREAT_ACTORS = 0.30
WEIGHT_OSINT = 0.20 WEIGHT_OSINT = 0.20
WEIGHT_TEST_FAILURES = 0.15 WEIGHT_TEST_FAILURES = 0.15
# Normalisation caps MAX_THREAT_ACTORS = 5
MAX_THREAT_ACTORS = 5 # beyond this → factor saturates at 1.0 MAX_OSINT_SIGNALS = 10
MAX_OSINT_SIGNALS = 10 # OSINT items in last 30 days
OSINT_LOOKBACK_DAYS = 30 OSINT_LOOKBACK_DAYS = 30
LEVEL_CRITICAL = 75.0 LEVEL_CRITICAL = 75.0
@@ -53,52 +49,48 @@ def _clamp(v: float, lo: float = 0.0, hi: float = 1.0) -> float:
return max(lo, min(hi, v)) return max(lo, min(hi, v))
# ── Single-technique computation ─────────────────────────────────────────────── # ── Per-technique computation ──────────────────────────────────────────────────
def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile: def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile:
"""Calculate the risk profile for one technique and return the (unsaved) model."""
breakdown: dict = {} breakdown: dict = {}
recs: list = [] recs: list = []
# ── Factor 1: Detection gap (0=covered, 1=no coverage) ─────────────────── # ── Factor 1: Detection gap ───────────────────────────────────────────────
# Check if technique is covered (has at least one DetectionTechniqueMapping) # Count how many assets map this technique
mapping_count = db.query(DetectionTechniqueMapping).filter( mapping_count = db.query(DetectionTechniqueMapping).filter(
DetectionTechniqueMapping.technique_id == tech.id, DetectionTechniqueMapping.technique_id == tech.id,
).count() ).count()
# Get DLC confidence score if available # DLC confidence score (unique per technique — no order_by needed)
dlc_conf = db.query(TechniqueConfidenceScore).filter( dlc_conf = db.query(TechniqueConfidenceScore).filter(
TechniqueConfidenceScore.technique_id == tech.id, TechniqueConfidenceScore.technique_id == tech.id,
).order_by(TechniqueConfidenceScore.computed_at.desc()).first() ).first()
confidence_level: float = float(dlc_conf.confidence_score or 0.0) if dlc_conf else 0.0
confidence_level: float = 0.0 # Map technique status to coverage factor
if dlc_conf: status = tech.status
confidence_level = float(dlc_conf.score or 0.0) if status == TechniqueStatus.validated:
# Also factor in technique status
if tech.status == TechniqueStatus.covered:
status_coverage = 1.0 status_coverage = 1.0
elif tech.status == TechniqueStatus.partial: elif status == TechniqueStatus.partial:
status_coverage = 0.5 status_coverage = 0.5
else: # uncovered / unknown elif status == TechniqueStatus.in_progress:
status_coverage = 0.25
else:
status_coverage = 0.0 status_coverage = 0.0
if mapping_count > 0: if mapping_count > 0:
# Has at least one asset mapped — use confidence as detection quality
raw_coverage = max(status_coverage, _clamp(confidence_level)) raw_coverage = max(status_coverage, _clamp(confidence_level))
else: else:
raw_coverage = 0.0 raw_coverage = status_coverage # rely on status alone if no asset mapping
detection_gap = 1.0 - raw_coverage detection_gap = 1.0 - raw_coverage
detection_gap_factor = detection_gap # already 01
breakdown["detection_gap"] = { breakdown["detection_gap"] = {
"mapping_count": mapping_count, "mapping_count": mapping_count,
"status": str(status) if status else None,
"status_coverage": status_coverage, "status_coverage": status_coverage,
"confidence_level": confidence_level, "confidence_level": round(confidence_level, 3),
"detection_gap": round(detection_gap, 3), "detection_gap": round(detection_gap, 3),
"contribution": round(detection_gap_factor * WEIGHT_DETECTION_GAP * 100, 2), "contribution": round(detection_gap * WEIGHT_DETECTION_GAP * 100, 2),
} }
if detection_gap >= 0.8: if detection_gap >= 0.8:
recs.append("Implement detection coverage — technique is largely undetected.") recs.append("Implement detection coverage — technique is largely undetected.")
@@ -109,12 +101,11 @@ def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile
actor_count = db.query(ThreatActorTechnique).filter( actor_count = db.query(ThreatActorTechnique).filter(
ThreatActorTechnique.technique_id == tech.id, ThreatActorTechnique.technique_id == tech.id,
).count() ).count()
ta_factor = _clamp(actor_count / MAX_THREAT_ACTORS) ta_factor = _clamp(actor_count / MAX_THREAT_ACTORS)
breakdown["threat_actor"] = { breakdown["threat_actor"] = {
"actor_count": actor_count, "actor_count": actor_count,
"max_cap": MAX_THREAT_ACTORS, "max_cap": MAX_THREAT_ACTORS,
"normalised": round(ta_factor, 3), "normalised": round(ta_factor, 3),
"contribution": round(ta_factor * WEIGHT_THREAT_ACTORS * 100, 2), "contribution": round(ta_factor * WEIGHT_THREAT_ACTORS * 100, 2),
} }
if actor_count >= 3: if actor_count >= 3:
@@ -132,85 +123,64 @@ def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile
OsintItem.technique_id == tech.id, OsintItem.technique_id == tech.id,
OsintItem.discovered_at >= cutoff, OsintItem.discovered_at >= cutoff,
).count() ).count()
osint_factor = _clamp(osint_count / MAX_OSINT_SIGNALS) osint_factor = _clamp(osint_count / MAX_OSINT_SIGNALS)
breakdown["osint"] = { breakdown["osint"] = {
"signal_count_30d": osint_count, "signal_count_30d": osint_count,
"max_cap": MAX_OSINT_SIGNALS, "max_cap": MAX_OSINT_SIGNALS,
"normalised": round(osint_factor, 3), "normalised": round(osint_factor, 3),
"contribution": round(osint_factor * WEIGHT_OSINT * 100, 2), "contribution": round(osint_factor * WEIGHT_OSINT * 100, 2),
} }
if osint_count >= 5: if osint_count >= 5:
recs.append( recs.append(
f"High OSINT activity — {osint_count} signals in the last 30 days. Review urgently." f"High OSINT activity — {osint_count} signals in last 30 days. Review urgently."
) )
elif osint_count >= 1: elif osint_count >= 1:
recs.append( recs.append(
f"{osint_count} OSINT signal(s) detected in last 30 days. Review for IoCs." f"{osint_count} OSINT signal(s) in last 30 days. Review for IoCs."
) )
# ── Factor 4: Test failure rate ─────────────────────────────────────────── # ── Factor 4: Test failure rate ───────────────────────────────────────────
# Count TestDetectionResult rows for this technique's tests # Use Test.result (TestResult enum) to determine not-detected count
from app.models.enums import TestResult
tech_tests = db.query(Test).filter(Test.technique_id == tech.id).all() tech_tests = db.query(Test).filter(Test.technique_id == tech.id).all()
test_ids = [t.id for t in tech_tests] test_total = len([t for t in tech_tests if t.result is not None])
test_not_detected = sum(
test_total = 0 1 for t in tech_tests
test_not_detected = 0 if t.result == TestResult.not_detected
if test_ids: )
from app.models.test_detection_result import TestDetectionResult as TDR
results = db.query(TDR).filter(TDR.test_id.in_(test_ids)).all()
test_total = len(results)
test_not_detected = sum(
1 for r in results
if hasattr(r, 'result') and str(getattr(r, 'result', '')) == 'not_detected'
)
# Also count tests where overall result is not_detected
if test_total == 0:
for t in tech_tests:
if hasattr(t, 'result') and t.result is not None:
test_total += 1
if str(t.result) in ('not_detected', 'TestResult.not_detected'):
test_not_detected += 1
test_failure_rate = (test_not_detected / test_total) if test_total > 0 else 0.0 test_failure_rate = (test_not_detected / test_total) if test_total > 0 else 0.0
# If no tests exist at all → treat as unknown risk (moderate) # No tests → moderate unknown risk factor
test_factor = test_failure_rate if test_total > 0 else 0.3 test_factor = test_failure_rate if test_total > 0 else 0.3
breakdown["test_failures"] = { breakdown["test_failures"] = {
"total_tests": test_total, "total_tests": test_total,
"not_detected": test_not_detected, "not_detected": test_not_detected,
"failure_rate": round(test_failure_rate, 3), "failure_rate": round(test_failure_rate, 3),
"factor_used": round(test_factor, 3), "factor_used": round(test_factor, 3),
"contribution": round(test_factor * WEIGHT_TEST_FAILURES * 100, 2), "contribution": round(test_factor * WEIGHT_TEST_FAILURES * 100, 2),
} }
if test_total == 0: if test_total == 0:
recs.append("No purple-team tests found — add tests to validate detection.") recs.append("No purple-team tests found — add tests to validate detection.")
elif test_failure_rate >= 0.5: elif test_failure_rate >= 0.5:
recs.append( recs.append(
f"High test failure rate ({test_failure_rate:.0%}) — blue team is missing this technique." f"High test failure rate ({test_failure_rate:.0%}) — blue team misses this technique."
) )
# ── Weighted risk score ─────────────────────────────────────────────────── # ── Aggregate scores ──────────────────────────────────────────────────────
raw_score = ( raw_score = (
detection_gap_factor * WEIGHT_DETECTION_GAP detection_gap * WEIGHT_DETECTION_GAP
+ ta_factor * WEIGHT_THREAT_ACTORS + ta_factor * WEIGHT_THREAT_ACTORS
+ osint_factor * WEIGHT_OSINT + osint_factor * WEIGHT_OSINT
+ test_factor * WEIGHT_TEST_FAILURES + test_factor * WEIGHT_TEST_FAILURES
) )
risk_score = _clamp(raw_score) * 100.0 risk_score = _clamp(raw_score) * 100.0
# Likelihood = detection + actor contribution (exposure)
likelihood = _clamp( likelihood = _clamp(
detection_gap_factor * 0.5 + ta_factor * 0.35 + osint_factor * 0.15 detection_gap * 0.5 + ta_factor * 0.35 + osint_factor * 0.15
) * 100.0 ) * 100.0
# Impact = test failures + osint severity signal
impact = _clamp( impact = _clamp(
test_factor * 0.6 + osint_factor * 0.25 + detection_gap_factor * 0.15 test_factor * 0.6 + osint_factor * 0.25 + detection_gap * 0.15
) * 100.0 ) * 100.0
level = _risk_level(risk_score) level = _risk_level(risk_score)
breakdown["total"] = { breakdown["total"] = {
"risk_score": round(risk_score, 2), "risk_score": round(risk_score, 2),
"likelihood": round(likelihood, 2), "likelihood": round(likelihood, 2),
@@ -238,7 +208,7 @@ def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile
) )
# ── Upsert helpers ───────────────────────────────────────────────────────────── # ── Upsert ─────────────────────────────────────────────────────────────────────
def _upsert_profile(db: Session, profile: TechniqueRiskProfile) -> TechniqueRiskProfile: def _upsert_profile(db: Session, profile: TechniqueRiskProfile) -> TechniqueRiskProfile:
existing = db.query(TechniqueRiskProfile).filter( existing = db.query(TechniqueRiskProfile).filter(
@@ -265,7 +235,6 @@ def _upsert_profile(db: Session, profile: TechniqueRiskProfile) -> TechniqueRisk
# ── Public API ──────────────────────────────────────────────────────────────── # ── Public API ────────────────────────────────────────────────────────────────
def compute_technique_risk(db: Session, technique_id: UUID) -> TechniqueRiskProfile: def compute_technique_risk(db: Session, technique_id: UUID) -> TechniqueRiskProfile:
"""Compute (or recompute) risk profile for a single technique."""
tech = db.query(Technique).filter(Technique.id == technique_id).first() tech = db.query(Technique).filter(Technique.id == technique_id).first()
if not tech: if not tech:
raise EntityNotFoundError("Technique", str(technique_id)) raise EntityNotFoundError("Technique", str(technique_id))
@@ -274,12 +243,10 @@ def compute_technique_risk(db: Session, technique_id: UUID) -> TechniqueRiskProf
def compute_all_risk_scores(db: Session) -> dict: def compute_all_risk_scores(db: Session) -> dict:
"""Compute risk profiles for all techniques. Returns summary counts."""
t0 = time.monotonic() t0 = time.monotonic()
techniques = db.query(Technique).all() techniques = db.query(Technique).all()
computed = 0 computed = 0
errors = 0 errors = 0
for tech in techniques: for tech in techniques:
try: try:
profile = _compute_for_technique(db, tech) profile = _compute_for_technique(db, tech)
@@ -287,13 +254,11 @@ def compute_all_risk_scores(db: Session) -> dict:
computed += 1 computed += 1
except Exception: except Exception:
errors += 1 errors += 1
duration = time.monotonic() - t0
return { return {
"computed": computed, "computed": computed,
"skipped": 0, "skipped": 0,
"errors": errors, "errors": errors,
"duration_seconds": round(duration, 2), "duration_seconds": round(time.monotonic() - t0, 2),
} }
@@ -308,12 +273,12 @@ def get_risk_profile(db: Session, technique_id: UUID) -> TechniqueRiskProfile:
def list_risk_profiles( def list_risk_profiles(
db: Session, db: Session,
risk_level: Optional[str] = None, risk_level: Optional[str] = None,
min_score: Optional[float] = None, min_score: Optional[float] = None,
max_score: Optional[float] = None, max_score: Optional[float] = None,
stale_only: bool = False, stale_only: bool = False,
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
) -> List[TechniqueRiskProfile]: ) -> List[TechniqueRiskProfile]:
q = db.query(TechniqueRiskProfile) q = db.query(TechniqueRiskProfile)
if risk_level: if risk_level:
@@ -326,58 +291,51 @@ def list_risk_profiles(
q = q.filter(TechniqueRiskProfile.is_stale == True) q = q.filter(TechniqueRiskProfile.is_stale == True)
return ( return (
q.order_by(TechniqueRiskProfile.risk_score.desc()) q.order_by(TechniqueRiskProfile.risk_score.desc())
.offset(offset) .offset(offset).limit(limit).all()
.limit(limit)
.all()
) )
def get_risk_matrix(db: Session) -> list: def get_risk_matrix(db: Session) -> list:
"""Return all profiled techniques with name+tid for the matrix view."""
rows = ( rows = (
db.query(TechniqueRiskProfile, Technique) db.query(TechniqueRiskProfile, Technique)
.join(Technique, TechniqueRiskProfile.technique_id == Technique.id) .join(Technique, TechniqueRiskProfile.technique_id == Technique.id)
.order_by(TechniqueRiskProfile.risk_score.desc()) .order_by(TechniqueRiskProfile.risk_score.desc())
.all() .all()
) )
result = [] return [
for profile, tech in rows: {
result.append({ "technique_id": str(p.technique_id),
"technique_id": str(profile.technique_id), "technique_name": t.name,
"technique_name": tech.name, "technique_tid": t.technique_id,
"technique_tid": tech.technique_id, # MITRE T-ID string "risk_score": p.risk_score,
"risk_score": profile.risk_score, "likelihood": p.likelihood,
"likelihood": profile.likelihood, "impact": p.impact,
"impact": profile.impact, "risk_level": p.risk_level,
"risk_level": profile.risk_level, "detection_gap": p.detection_gap,
"detection_gap": profile.detection_gap, "computed_at": p.computed_at.isoformat() if p.computed_at else None,
"computed_at": profile.computed_at.isoformat() if profile.computed_at else None, }
}) for p, t in rows
return result ]
def get_risk_summary(db: Session) -> dict: def get_risk_summary(db: Session) -> dict:
"""Aggregate statistics across all risk profiles."""
all_profiles = db.query(TechniqueRiskProfile).all() all_profiles = db.query(TechniqueRiskProfile).all()
total_tech = db.query(Technique).count() total_tech = db.query(Technique).count()
scored = len(all_profiles) scored = len(all_profiles)
stale = sum(1 for p in all_profiles if p.is_stale) stale = sum(1 for p in all_profiles if p.is_stale)
by_level: dict = {lvl: 0 for lvl in ("critical", "high", "medium", "low", "info")} by_level: dict = {l: 0 for l in ("critical", "high", "medium", "low", "info")}
score_sum = 0.0 score_sum = 0.0
for p in all_profiles: for p in all_profiles:
by_level[p.risk_level] = by_level.get(p.risk_level, 0) + 1 by_level[p.risk_level] = by_level.get(p.risk_level, 0) + 1
score_sum += p.risk_score score_sum += p.risk_score
avg_score = (score_sum / scored) if scored > 0 else 0.0 avg_score = (score_sum / scored) if scored > 0 else 0.0
# Top 5 by risk score (with technique name)
top_rows = ( top_rows = (
db.query(TechniqueRiskProfile, Technique) db.query(TechniqueRiskProfile, Technique)
.join(Technique, TechniqueRiskProfile.technique_id == Technique.id) .join(Technique, TechniqueRiskProfile.technique_id == Technique.id)
.order_by(TechniqueRiskProfile.risk_score.desc()) .order_by(TechniqueRiskProfile.risk_score.desc())
.limit(5) .limit(5).all()
.all()
) )
top_risks = [ top_risks = [
{ {
@@ -393,7 +351,6 @@ def get_risk_summary(db: Session) -> dict:
} }
for p, t in top_rows for p, t in top_rows
] ]
return { return {
"total_techniques": total_tech, "total_techniques": total_tech,
"scored_techniques": scored, "scored_techniques": scored,
@@ -405,24 +362,22 @@ def get_risk_summary(db: Session) -> dict:
def get_recommendations(db: Session, limit: int = 20) -> list: def get_recommendations(db: Session, limit: int = 20) -> list:
"""Prioritised list of techniques with actionable recommendations."""
rows = ( rows = (
db.query(TechniqueRiskProfile, Technique) db.query(TechniqueRiskProfile, Technique)
.join(Technique, TechniqueRiskProfile.technique_id == Technique.id) .join(Technique, TechniqueRiskProfile.technique_id == Technique.id)
.filter(TechniqueRiskProfile.risk_score > 0) .filter(TechniqueRiskProfile.risk_score > 0)
.order_by(TechniqueRiskProfile.risk_score.desc()) .order_by(TechniqueRiskProfile.risk_score.desc())
.limit(limit) .limit(limit).all()
.all()
) )
result = [] return [
for priority, (profile, tech) in enumerate(rows, start=1): {
result.append({ "technique_id": str(p.technique_id),
"technique_id": str(profile.technique_id), "technique_name": t.name,
"technique_name": tech.name, "technique_tid": t.technique_id,
"technique_tid": tech.technique_id, "risk_level": p.risk_level,
"risk_level": profile.risk_level, "risk_score": p.risk_score,
"risk_score": profile.risk_score, "recommendations": p.recommendations or [],
"recommendations": profile.recommendations or [], "priority": i,
"priority": priority, }
}) for i, (p, t) in enumerate(rows, start=1)
return result ]