"""Decay Engine — calculates confidence scores and expires validations.""" import logging from datetime import datetime from typing import Optional from uuid import UUID from sqlalchemy.orm import Session from app.models.detection_lifecycle import ( DetectionAsset, DetectionValidation, DetectionTechniqueMapping, TechniqueConfidenceScore, DetectionConfidence, DetectionHealthStatus, InfrastructureChangeLog, ) from app.models.decay_policy import DecayPolicy from app.models.technique import Technique logger = logging.getLogger(__name__) def _now() -> datetime: return datetime.utcnow() def get_applicable_policy(db: Session, platform: Optional[str] = None, asset_type: Optional[str] = None, tactic: Optional[str] = None) -> DecayPolicy: query = db.query(DecayPolicy).filter(DecayPolicy.is_active == True) if platform: specific = query.filter(DecayPolicy.applies_to_platform == platform).first() if specific: return specific if asset_type: specific = query.filter(DecayPolicy.applies_to_asset_type == asset_type).first() if specific: return specific if tactic: specific = query.filter(DecayPolicy.applies_to_tactic == tactic).first() if specific: return specific default_policy = query.filter(DecayPolicy.is_default == True).first() if default_policy: return default_policy # Return an in-memory default if no DB policy exists p = DecayPolicy() p.fresh_days = 90 p.aging_days = 180 p.stale_days = 365 p.recency_weight = 0.30 p.coverage_weight = 0.30 p.health_weight = 0.25 p.diversity_weight = 0.15 return p def calculate_confidence_for_technique(db: Session, technique_id: UUID) -> Optional[TechniqueConfidenceScore]: technique = db.query(Technique).filter(Technique.id == technique_id).first() if not technique: return None policy = get_applicable_policy(db, tactic=technique.tactic) mappings = db.query(DetectionTechniqueMapping).filter(DetectionTechniqueMapping.technique_id == technique_id).all() asset_ids = [m.detection_asset_id for m in mappings] if not asset_ids: return _create_or_update_score(db, technique_id, confidence_level=DetectionConfidence.unknown, confidence_score=0.0, factors={"recency": 0.0, "coverage": 0.0, "health": 0.0, "diversity": 0.0}, risk_factors=["no_detection_assets"], detection_count=0, valid_count=0, ) assets = db.query(DetectionAsset).filter(DetectionAsset.id.in_(asset_ids), DetectionAsset.is_active == True).all() now = _now() valid_validations = db.query(DetectionValidation).filter( DetectionValidation.detection_asset_id.in_(asset_ids), DetectionValidation.is_valid == True, DetectionValidation.expires_at > now, ).all() recency_factor = 0.0 last_validated = None if valid_validations: most_recent = max(v.validated_at for v in valid_validations) # Strip tzinfo if present so arithmetic stays consistent with naive UTC if most_recent.tzinfo is not None: most_recent = most_recent.replace(tzinfo=None) last_validated = most_recent days_since = (now - most_recent).days if days_since <= policy.fresh_days: recency_factor = 1.0 elif days_since <= policy.aging_days: range_days = policy.aging_days - policy.fresh_days elapsed = days_since - policy.fresh_days recency_factor = 1.0 - (elapsed / range_days) * 0.4 elif days_since <= policy.stale_days: range_days = policy.stale_days - policy.aging_days elapsed = days_since - policy.aging_days recency_factor = 0.6 - (elapsed / range_days) * 0.4 else: recency_factor = max(0.1, 0.2 - ((days_since - policy.stale_days) / 365) * 0.1) active_count = len(assets) valid_count = len(set(v.detection_asset_id for v in valid_validations)) if active_count == 0: coverage_factor = 0.0 elif valid_count >= 3: coverage_factor = 1.0 elif valid_count >= 2: coverage_factor = 0.8 elif valid_count >= 1: coverage_factor = 0.5 else: coverage_factor = 0.1 health_scores = { DetectionHealthStatus.healthy: 1.0, DetectionHealthStatus.silent: 0.4, DetectionHealthStatus.noisy: 0.6, DetectionHealthStatus.orphan: 0.3, DetectionHealthStatus.deprecated: 0.0, DetectionHealthStatus.untested: 0.2, } health_factor = sum(health_scores.get(a.health_status, 0.2) for a in assets) / max(len(assets), 1) platforms = set(a.platform for a in assets if a.platform) asset_types = set(a.asset_type for a in assets) diversity_factor = min(1.0, len(platforms) * 0.3 + len(asset_types) * 0.2) confidence_score = ( recency_factor * policy.recency_weight + coverage_factor * policy.coverage_weight + health_factor * policy.health_weight + diversity_factor * policy.diversity_weight ) * 100 if confidence_score >= 75: confidence_level = DetectionConfidence.fresh elif confidence_score >= 50: confidence_level = DetectionConfidence.aging elif confidence_score >= 25: confidence_level = DetectionConfidence.stale elif confidence_score > 0: confidence_level = DetectionConfidence.broken else: confidence_level = DetectionConfidence.unknown risk_factors = [] if len(platforms) <= 1: risk_factors.append("single_platform") if valid_count == 0: risk_factors.append("no_valid_detections") if any(a.health_status == DetectionHealthStatus.silent for a in assets): risk_factors.append("silent_rules_present") if any(a.health_status == DetectionHealthStatus.orphan for a in assets): risk_factors.append("orphan_rules_present") if recency_factor < 0.5: risk_factors.append("stale_validation") if len(assets) < 2: risk_factors.append("low_detection_diversity") next_due = None if valid_validations: earliest_expiry = min(v.expires_at for v in valid_validations) next_due = earliest_expiry return _create_or_update_score( db, technique_id, confidence_level=confidence_level, confidence_score=round(confidence_score, 1), factors={"recency": round(recency_factor, 3), "coverage": round(coverage_factor, 3), "health": round(health_factor, 3), "diversity": round(diversity_factor, 3)}, risk_factors=risk_factors, detection_count=active_count, valid_count=valid_count, last_validated=last_validated, next_due=next_due, ) def _create_or_update_score(db: Session, technique_id: UUID, **kwargs) -> TechniqueConfidenceScore: score = db.query(TechniqueConfidenceScore).filter(TechniqueConfidenceScore.technique_id == technique_id).first() if not score: score = TechniqueConfidenceScore(technique_id=technique_id) db.add(score) score.confidence_level = kwargs["confidence_level"] score.confidence_score = kwargs["confidence_score"] score.detection_count = kwargs["detection_count"] score.valid_detection_count = kwargs["valid_count"] score.recency_factor = kwargs["factors"]["recency"] score.coverage_factor = kwargs["factors"]["coverage"] score.health_factor = kwargs["factors"]["health"] score.diversity_factor = kwargs["factors"]["diversity"] score.risk_factors = kwargs["risk_factors"] score.score_breakdown = kwargs["factors"] score.last_validated_at = kwargs.get("last_validated") score.next_validation_due = kwargs.get("next_due") score.last_recalculated_at = _now() score.updated_at = _now() db.commit() db.refresh(score) return score def run_decay_engine(db: Session) -> dict: techniques = db.query(Technique).all() results = {"total_techniques": len(techniques), "fresh": 0, "aging": 0, "stale": 0, "broken": 0, "unknown": 0, "validations_expired": 0} now = _now() # Expire stale validations expired = db.query(DetectionValidation).filter( DetectionValidation.is_valid == True, DetectionValidation.expires_at <= now, ).all() from app.models.detection_lifecycle import InvalidationReason for v in expired: v.is_valid = False v.invalidated_at = now v.invalidation_reason = InvalidationReason.time_decay results["validations_expired"] = len(expired) if expired: db.commit() for technique in techniques: score = calculate_confidence_for_technique(db, technique.id) if score: level = score.confidence_level.value results[level] = results.get(level, 0) + 1 logger.info("Decay engine completed: %s", results) return results def process_infrastructure_change(db: Session, change_id: UUID) -> int: change = db.query(InfrastructureChangeLog).filter(InfrastructureChangeLog.id == change_id).first() if not change or not change.auto_invalidate: return 0 query = db.query(DetectionAsset).filter(DetectionAsset.is_active == True) if change.affected_platforms: query = query.filter(DetectionAsset.platform.in_(change.affected_platforms)) affected_assets = query.all() total_invalidated = 0 from app.services.detection_asset_service import invalidate_validations_for_asset for asset in affected_assets: if change.affected_log_sources: asset_log_source = asset.log_source_name or "" if not any(ls in asset_log_source for ls in change.affected_log_sources): continue count = invalidate_validations_for_asset(db, asset.id, change.reported_by, "infrastructure_change") total_invalidated += count change.invalidated_count = total_invalidated db.commit() logger.info("Infrastructure change %s: invalidated %d validations", change_id, total_invalidated) return total_invalidated