Files
Aegis/backend/app/services/risk_intelligence_service.py
kitos 084ea4c0b2
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
fix(risk): correct TechniqueConfidenceScore fields, TechniqueStatus values, Test.result usage
2026-05-20 15:58:03 +02:00

384 lines
14 KiB
Python

"""Phase 12: Risk Intelligence service — compute and query per-technique risk scores."""
import time
from datetime import datetime, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.risk_intelligence import TechniqueRiskProfile
from app.models.technique import Technique
from app.models.threat_actor import ThreatActorTechnique
from app.models.osint_item import OsintItem
from app.models.test import Test
from app.models.detection_lifecycle import (
TechniqueConfidenceScore,
DetectionTechniqueMapping,
)
from app.models.enums import TechniqueStatus, TestResult
# ── Scoring weights & thresholds ───────────────────────────────────────────────
# Weights applied to the four normalised (0..1) factors when they are summed
# into the overall risk score in _compute_for_technique. They add up to 1.00.
WEIGHT_DETECTION_GAP = 0.35
WEIGHT_THREAT_ACTORS = 0.30
WEIGHT_OSINT = 0.20
WEIGHT_TEST_FAILURES = 0.15
# Saturation caps: a raw count is divided by its cap and clamped to 0..1, so
# counts at or above the cap contribute the maximum factor value.
MAX_THREAT_ACTORS = 5
MAX_OSINT_SIGNALS = 10
# Only OSINT items discovered within this many days are counted.
OSINT_LOOKBACK_DAYS = 30
# Inclusive lower bounds (on the 0..100 risk score) used by _risk_level;
# a score below LEVEL_LOW is labelled "info".
LEVEL_CRITICAL = 75.0
LEVEL_HIGH = 50.0
LEVEL_MEDIUM = 25.0
LEVEL_LOW = 10.0
def _risk_level(score: float) -> str:
    """Translate a numeric risk score into its textual severity label.

    The thresholds are inclusive lower bounds and are checked from the most
    severe band downwards; a score that reaches none of them is "info".
    """
    bands = (
        (LEVEL_CRITICAL, "critical"),
        (LEVEL_HIGH, "high"),
        (LEVEL_MEDIUM, "medium"),
        (LEVEL_LOW, "low"),
    )
    for floor, label in bands:
        if score >= floor:
            return label
    return "info"
def _clamp(v: float, lo: float = 0.0, hi: float = 1.0) -> float:
return max(lo, min(hi, v))
# ── Per-technique computation ──────────────────────────────────────────────────
def _compute_for_technique(db: Session, tech: Technique) -> TechniqueRiskProfile:
    """Build a new, unsaved risk profile for one technique.

    Four factors, each normalised to 0..1, are derived from the database:

    1. detection gap   — 1 minus the better of status coverage and (when at
       least one detection mapping exists) the clamped confidence score;
    2. threat actors   — number of actor links, capped at MAX_THREAT_ACTORS;
    3. OSINT signals   — items discovered in the last OSINT_LOOKBACK_DAYS days,
       capped at MAX_OSINT_SIGNALS;
    4. test failures   — share of tests with a result that were not detected
       (0.3 is used when no test has a result).

    The weighted sum, scaled to 0..100, is the risk score. The returned object
    is not added to the session; persisting it is the caller's job.

    Args:
        db: Session used for the read-only lookups.
        tech: Technique to score.

    Returns:
        A populated TechniqueRiskProfile with ``is_stale`` set to False.
    """
    scoring: dict = {}
    advice: list = []

    # ── Factor 1: detection gap ──────────────────────────────────────────────
    mapping_total = (
        db.query(DetectionTechniqueMapping)
        .filter(DetectionTechniqueMapping.technique_id == tech.id)
        .count()
    )
    confidence_row = (
        db.query(TechniqueConfidenceScore)
        .filter(TechniqueConfidenceScore.technique_id == tech.id)
        .first()
    )
    if confidence_row:
        confidence: float = float(confidence_row.confidence_score or 0.0)
    else:
        confidence = 0.0

    # Coverage implied by the technique's workflow status alone.
    current_status = tech.status
    if current_status == TechniqueStatus.validated:
        coverage_from_status = 1.0
    elif current_status == TechniqueStatus.partial:
        coverage_from_status = 0.5
    elif current_status == TechniqueStatus.in_progress:
        coverage_from_status = 0.25
    else:
        coverage_from_status = 0.0

    # The confidence score only counts when a detection mapping exists;
    # otherwise the status is the sole evidence of coverage.
    coverage = (
        max(coverage_from_status, _clamp(confidence))
        if mapping_total > 0
        else coverage_from_status
    )
    gap = 1.0 - coverage

    scoring["detection_gap"] = {
        "mapping_count": mapping_total,
        "status": str(current_status) if current_status else None,
        "status_coverage": coverage_from_status,
        "confidence_level": round(confidence, 3),
        "detection_gap": round(gap, 3),
        "contribution": round(gap * WEIGHT_DETECTION_GAP * 100, 2),
    }
    if gap >= 0.8:
        advice.append("Implement detection coverage — technique is largely undetected.")
    elif gap >= 0.5:
        advice.append("Improve detection quality — coverage is partial.")

    # ── Factor 2: threat actor relevance ─────────────────────────────────────
    actors = (
        db.query(ThreatActorTechnique)
        .filter(ThreatActorTechnique.technique_id == tech.id)
        .count()
    )
    actor_factor = _clamp(actors / MAX_THREAT_ACTORS)
    scoring["threat_actor"] = {
        "actor_count": actors,
        "max_cap": MAX_THREAT_ACTORS,
        "normalised": round(actor_factor, 3),
        "contribution": round(actor_factor * WEIGHT_THREAT_ACTORS * 100, 2),
    }
    if actors >= 3:
        advice.append(
            f"High threat-actor relevance — {actors} tracked actors use this technique."
        )
    elif actors >= 1:
        advice.append(
            f"{actors} threat actor(s) use this technique — monitor closely."
        )

    # ── Factor 3: OSINT signals inside the look-back window ──────────────────
    window_start = datetime.utcnow() - timedelta(days=OSINT_LOOKBACK_DAYS)
    signals = (
        db.query(OsintItem)
        .filter(
            OsintItem.technique_id == tech.id,
            OsintItem.discovered_at >= window_start,
        )
        .count()
    )
    signal_factor = _clamp(signals / MAX_OSINT_SIGNALS)
    scoring["osint"] = {
        "signal_count_30d": signals,
        "max_cap": MAX_OSINT_SIGNALS,
        "normalised": round(signal_factor, 3),
        "contribution": round(signal_factor * WEIGHT_OSINT * 100, 2),
    }
    if signals >= 5:
        advice.append(
            f"High OSINT activity — {signals} signals in last 30 days. Review urgently."
        )
    elif signals >= 1:
        advice.append(
            f"{signals} OSINT signal(s) in last 30 days. Review for IoCs."
        )

    # ── Factor 4: test failure rate ──────────────────────────────────────────
    graded = 0   # tests that have any result recorded
    missed = 0   # tests whose result is TestResult.not_detected
    for test in db.query(Test).filter(Test.technique_id == tech.id).all():
        if test.result is not None:
            graded += 1
        if test.result == TestResult.not_detected:
            missed += 1

    if graded > 0:
        failure_rate = missed / graded
        failure_factor = failure_rate
    else:
        # Nothing tested yet: report a 0.0 rate but score a moderate
        # "unknown" factor instead of treating the technique as safe.
        failure_rate = 0.0
        failure_factor = 0.3

    scoring["test_failures"] = {
        "total_tests": graded,
        "not_detected": missed,
        "failure_rate": round(failure_rate, 3),
        "factor_used": round(failure_factor, 3),
        "contribution": round(failure_factor * WEIGHT_TEST_FAILURES * 100, 2),
    }
    if graded == 0:
        advice.append("No purple-team tests found — add tests to validate detection.")
    elif failure_rate >= 0.5:
        advice.append(
            f"High test failure rate ({failure_rate:.0%}) — blue team misses this technique."
        )

    # ── Aggregate scores (all scaled to 0..100) ──────────────────────────────
    weighted_sum = (
        gap * WEIGHT_DETECTION_GAP
        + actor_factor * WEIGHT_THREAT_ACTORS
        + signal_factor * WEIGHT_OSINT
        + failure_factor * WEIGHT_TEST_FAILURES
    )
    overall = _clamp(weighted_sum) * 100.0
    likelihood_score = _clamp(
        gap * 0.5 + actor_factor * 0.35 + signal_factor * 0.15
    ) * 100.0
    impact_score = _clamp(
        failure_factor * 0.6 + signal_factor * 0.25 + gap * 0.15
    ) * 100.0
    label = _risk_level(overall)

    scoring["total"] = {
        "risk_score": round(overall, 2),
        "likelihood": round(likelihood_score, 2),
        "impact": round(impact_score, 2),
        "risk_level": label,
    }

    return TechniqueRiskProfile(
        technique_id=tech.id,
        risk_score=round(overall, 4),
        likelihood=round(likelihood_score, 4),
        impact=round(impact_score, 4),
        risk_level=label,
        detection_gap=round(gap, 4),
        threat_actor_count=actors,
        osint_signal_count=signals,
        test_fail_count=missed,
        test_total_count=graded,
        test_failure_rate=round(failure_rate, 4),
        confidence_level=round(confidence, 4),
        scoring_breakdown=scoring,
        recommendations=advice or ["Risk profile looks healthy — continue monitoring."],
        computed_at=datetime.utcnow(),
        is_stale=False,
    )
# ── Upsert ─────────────────────────────────────────────────────────────────────
def _upsert_profile(db: Session, profile: TechniqueRiskProfile) -> TechniqueRiskProfile:
    """Persist *profile*, overwriting the stored row for the same technique if any.

    When a row with the same ``technique_id`` already exists, its computed
    fields are overwritten from *profile* and that existing row is returned;
    otherwise *profile* itself is added. Either way the session is committed
    and the returned object refreshed.
    """
    copied_fields = (
        "risk_score", "likelihood", "impact", "risk_level",
        "detection_gap", "threat_actor_count", "osint_signal_count",
        "test_fail_count", "test_total_count", "test_failure_rate",
        "confidence_level", "scoring_breakdown", "recommendations",
        "computed_at", "is_stale",
    )
    stored = (
        db.query(TechniqueRiskProfile)
        .filter(TechniqueRiskProfile.technique_id == profile.technique_id)
        .first()
    )
    if not stored:
        db.add(profile)
        target = profile
    else:
        for field in copied_fields:
            setattr(stored, field, getattr(profile, field))
        target = stored
    db.commit()
    db.refresh(target)
    return target
# ── Public API ────────────────────────────────────────────────────────────────
def compute_technique_risk(db: Session, technique_id: UUID) -> TechniqueRiskProfile:
    """Recompute and store the risk profile of a single technique.

    Raises:
        EntityNotFoundError: if no Technique has the given id.
    """
    technique = (
        db.query(Technique)
        .filter(Technique.id == technique_id)
        .first()
    )
    if technique:
        fresh_profile = _compute_for_technique(db, technique)
        return _upsert_profile(db, fresh_profile)
    raise EntityNotFoundError("Technique", str(technique_id))
def compute_all_risk_scores(db: Session) -> dict:
    """Recompute and store the risk profile of every technique.

    Processing is best-effort: a failure on one technique is counted and
    logged, and the loop moves on to the next technique.

    Args:
        db: Session used for both the reads and the per-technique commits.

    Returns:
        A dict with ``computed`` (successful techniques), ``skipped``
        (always 0), ``errors`` (failed techniques) and ``duration_seconds``
        (wall time, rounded to two decimals).
    """
    # Function-scope import so this block does not depend on a module-level
    # import being present in the file.
    import logging

    logger = logging.getLogger(__name__)
    started = time.monotonic()
    techniques = db.query(Technique).all()
    computed = 0
    errors = 0
    for tech in techniques:
        tech_id = None
        try:
            # Read the id inside the try: the attribute may trigger a reload.
            tech_id = tech.id
            profile = _compute_for_technique(db, tech)
            _upsert_profile(db, profile)
        except Exception:
            errors += 1
            logger.exception(
                "Risk score computation failed for technique %s", tech_id
            )
            # A failed flush/statement leaves the session's transaction
            # unusable until it is rolled back; without this every remaining
            # technique would fail as well.
            db.rollback()
        else:
            computed += 1
    return {
        "computed": computed,
        "skipped": 0,
        "errors": errors,
        "duration_seconds": round(time.monotonic() - started, 2),
    }
def get_risk_profile(db: Session, technique_id: UUID) -> TechniqueRiskProfile:
    """Fetch the stored risk profile for a technique.

    Raises:
        EntityNotFoundError: if no profile is stored for that technique id.
    """
    lookup = db.query(TechniqueRiskProfile).filter(
        TechniqueRiskProfile.technique_id == technique_id,
    )
    found = lookup.first()
    if found:
        return found
    raise EntityNotFoundError("TechniqueRiskProfile", str(technique_id))
def list_risk_profiles(
    db: Session,
    risk_level: Optional[str] = None,
    min_score: Optional[float] = None,
    max_score: Optional[float] = None,
    stale_only: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> List[TechniqueRiskProfile]:
    """List stored risk profiles, highest risk score first.

    Args:
        db: Session to query.
        risk_level: When non-empty, keep only profiles with this exact level.
        min_score: Inclusive lower bound on the risk score, if given.
        max_score: Inclusive upper bound on the risk score, if given.
        stale_only: When true, keep only profiles flagged as stale.
        limit: Maximum number of rows to return.
        offset: Number of leading rows to skip.
    """
    criteria = []
    if risk_level:
        criteria.append(TechniqueRiskProfile.risk_level == risk_level)
    if min_score is not None:
        criteria.append(TechniqueRiskProfile.risk_score >= min_score)
    if max_score is not None:
        criteria.append(TechniqueRiskProfile.risk_score <= max_score)
    if stale_only:
        # SQL expression, not a Python truth test — the comparison is intended.
        criteria.append(TechniqueRiskProfile.is_stale == True)  # noqa: E712

    query = db.query(TechniqueRiskProfile)
    if criteria:
        # Several criteria in one filter() call are AND-ed together.
        query = query.filter(*criteria)
    ordered = query.order_by(TechniqueRiskProfile.risk_score.desc())
    return ordered.offset(offset).limit(limit).all()
def get_risk_matrix(db: Session) -> list:
    """Return all stored profiles joined to their technique, as plain dicts.

    Rows are ordered by descending risk score. ``technique_id`` is the
    profile's technique key rendered as a string, while ``technique_tid`` is
    the ``technique_id`` attribute of the Technique row itself.
    """
    query = (
        db.query(TechniqueRiskProfile, Technique)
        .join(Technique, TechniqueRiskProfile.technique_id == Technique.id)
        .order_by(TechniqueRiskProfile.risk_score.desc())
    )
    matrix = []
    for profile, technique in query.all():
        computed_at = profile.computed_at
        matrix.append({
            "technique_id": str(profile.technique_id),
            "technique_name": technique.name,
            "technique_tid": technique.technique_id,
            "risk_score": profile.risk_score,
            "likelihood": profile.likelihood,
            "impact": profile.impact,
            "risk_level": profile.risk_level,
            "detection_gap": profile.detection_gap,
            "computed_at": computed_at.isoformat() if computed_at else None,
        })
    return matrix
def get_risk_summary(db: Session) -> dict:
    """Summarise the stored risk profiles.

    Returns a dict with the total number of techniques, the number of scored
    techniques, how many profiles are stale, a per-level count, the mean risk
    score (0.0 when nothing is scored) and the five highest-scoring profiles.
    """
    profiles = db.query(TechniqueRiskProfile).all()
    technique_total = db.query(Technique).count()
    profile_total = len(profiles)

    # Known levels are pre-seeded so they appear even with a zero count; any
    # other level found on a profile is added as it is encountered.
    level_counts: dict = dict.fromkeys(
        ("critical", "high", "medium", "low", "info"), 0
    )
    stale_total = 0
    running_score = 0.0
    for profile in profiles:
        if profile.is_stale:
            stale_total += 1
        level_counts[profile.risk_level] = level_counts.get(profile.risk_level, 0) + 1
        running_score += profile.risk_score

    if profile_total > 0:
        mean_score = running_score / profile_total
    else:
        mean_score = 0.0

    leaders = (
        db.query(TechniqueRiskProfile, Technique)
        .join(Technique, TechniqueRiskProfile.technique_id == Technique.id)
        .order_by(TechniqueRiskProfile.risk_score.desc())
        .limit(5)
        .all()
    )
    top_risks = []
    for profile, technique in leaders:
        computed_at = profile.computed_at
        top_risks.append({
            "technique_id": str(profile.technique_id),
            "technique_name": technique.name,
            "technique_tid": technique.technique_id,
            "risk_score": profile.risk_score,
            "risk_level": profile.risk_level,
            "likelihood": profile.likelihood,
            "impact": profile.impact,
            "detection_gap": profile.detection_gap,
            "computed_at": computed_at.isoformat() if computed_at else None,
        })

    return {
        "total_techniques": technique_total,
        "scored_techniques": profile_total,
        "stale_count": stale_total,
        "by_level": level_counts,
        "avg_risk_score": round(mean_score, 2),
        "top_risks": top_risks,
    }
def get_recommendations(db: Session, limit: int = 20) -> list:
    """Return the recommendations of the highest-risk techniques.

    Only profiles with a risk score above zero are considered. At most
    *limit* entries are returned, ordered by descending risk score, and each
    carries a 1-based ``priority`` reflecting that order.
    """
    query = (
        db.query(TechniqueRiskProfile, Technique)
        .join(Technique, TechniqueRiskProfile.technique_id == Technique.id)
        .filter(TechniqueRiskProfile.risk_score > 0)
        .order_by(TechniqueRiskProfile.risk_score.desc())
        .limit(limit)
    )
    entries = []
    for rank, (profile, technique) in enumerate(query.all(), start=1):
        entries.append({
            "technique_id": str(profile.technique_id),
            "technique_name": technique.name,
            "technique_tid": technique.technique_id,
            "risk_level": profile.risk_level,
            "risk_score": profile.risk_score,
            "recommendations": profile.recommendations or [],
            "priority": rank,
        })
    return entries