diff --git a/backend/alembic/versions/b034_detection_lifecycle.py b/backend/alembic/versions/b034_detection_lifecycle.py new file mode 100644 index 0000000..b52f4a5 --- /dev/null +++ b/backend/alembic/versions/b034_detection_lifecycle.py @@ -0,0 +1,174 @@ +"""Phase 8: Detection Lifecycle Management tables. + +Revision ID: b034dlm +Revises: b033syscfg +""" +from typing import Sequence, Union +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from alembic import op + +revision: str = "b034dlm" +down_revision: Union[str, None] = "b033syscfg" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _table_exists(name: str) -> bool: + bind = op.get_bind() + insp = sa.inspect(bind) + return name in insp.get_table_names() + + +def upgrade() -> None: + if not _table_exists("detection_assets"): + op.create_table( + "detection_assets", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("name", sa.String(500), nullable=False), + sa.Column("description", sa.Text), + sa.Column("asset_type", sa.String(50), nullable=False), + sa.Column("platform", sa.String(100)), + sa.Column("rule_content", sa.Text), + sa.Column("rule_language", sa.String(50)), + sa.Column("rule_repository_url", sa.Text), + sa.Column("rule_file_path", sa.String(500)), + sa.Column("rule_version", sa.String(50)), + sa.Column("rule_hash", sa.String(64)), + sa.Column("last_rule_change_at", sa.DateTime), + sa.Column("log_source_name", sa.String(200)), + sa.Column("log_source_version", sa.String(50)), + sa.Column("log_source_config", postgresql.JSONB, server_default="{}"), + sa.Column("infrastructure_hash", sa.String(64)), + sa.Column("infrastructure_details", postgresql.JSONB, server_default="{}"), + sa.Column("health_status", sa.String(20), server_default="untested", nullable=False), + sa.Column("last_alert_at", sa.DateTime), + sa.Column("alert_count_30d", sa.Integer, server_default="0"), + sa.Column("false_positive_rate", sa.Float), + sa.Column("expected_alert_frequency", sa.String(50)), + sa.Column("owner_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")), + sa.Column("backup_owner_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")), + sa.Column("team", sa.String(100)), + sa.Column("is_active", sa.Boolean, server_default="true", nullable=False), + sa.Column("tags", postgresql.JSONB, server_default="[]"), + sa.Column("asset_metadata", postgresql.JSONB, server_default="{}"), + sa.Column("created_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), + ) + op.create_index("ix_detection_assets_platform", "detection_assets", ["platform"]) + op.create_index("ix_detection_assets_health_status", "detection_assets", ["health_status"]) + op.create_index("ix_detection_assets_owner_id", "detection_assets", ["owner_id"]) + + if not _table_exists("detection_technique_mappings"): + op.create_table( + "detection_technique_mappings", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("detection_asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False), + sa.Column("coverage_type", sa.String(50), server_default="detect"), + sa.Column("confidence_level", sa.String(20), server_default="medium"), + sa.Column("notes", sa.Text), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")), + ) + op.create_index("ix_detection_technique_mappings_technique_id", "detection_technique_mappings", ["technique_id"]) + op.create_index("ix_detection_technique_mappings_asset_id", "detection_technique_mappings", ["detection_asset_id"]) + + if not _table_exists("detection_validations"): + op.create_table( + "detection_validations", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("detection_asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="SET NULL")), + sa.Column("test_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("tests.id", ondelete="SET NULL")), + sa.Column("validated_at", sa.DateTime), + sa.Column("expires_at", sa.DateTime, nullable=False), + sa.Column("is_valid", sa.Boolean, server_default="true", nullable=False), + sa.Column("validation_result", sa.String(50)), + sa.Column("validation_method", sa.String(100)), + sa.Column("rule_hash_at_validation", sa.String(64)), + sa.Column("log_source_version_at_validation", sa.String(50)), + sa.Column("infrastructure_hash_at_validation", sa.String(64)), + sa.Column("environment_snapshot", postgresql.JSONB, server_default="{}"), + sa.Column("invalidated_at", sa.DateTime), + sa.Column("invalidation_reason", sa.String(50)), + sa.Column("invalidation_details", sa.Text), + sa.Column("invalidated_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")), + sa.Column("validated_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=False), + sa.Column("integrity_hash", sa.String(64)), + sa.Column("notes", sa.Text), + sa.Column("evidence_ids", postgresql.JSONB, server_default="[]"), + ) + op.create_index("ix_detection_validations_asset_id_valid", "detection_validations", ["detection_asset_id", "is_valid"]) + op.create_index("ix_detection_validations_expires_at", "detection_validations", ["expires_at"]) + + if not _table_exists("technique_confidence_scores"): + op.create_table( + "technique_confidence_scores", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False, unique=True), + sa.Column("confidence_level", sa.String(20), server_default="unknown"), + sa.Column("confidence_score", sa.Float, server_default="0.0"), + sa.Column("detection_count", sa.Integer, server_default="0"), + sa.Column("valid_detection_count", sa.Integer, server_default="0"), + sa.Column("last_validated_at", sa.DateTime), + sa.Column("next_validation_due", sa.DateTime), + sa.Column("last_recalculated_at", sa.DateTime), + sa.Column("recency_factor", sa.Float, server_default="0.0"), + sa.Column("coverage_factor", sa.Float, server_default="0.0"), + sa.Column("health_factor", sa.Float, server_default="0.0"), + sa.Column("diversity_factor", sa.Float, server_default="0.0"), + sa.Column("score_breakdown", postgresql.JSONB, server_default="{}"), + sa.Column("risk_factors", postgresql.JSONB, server_default="[]"), + sa.Column("updated_at", sa.DateTime), + ) + op.create_index("ix_technique_confidence_scores_technique_id", "technique_confidence_scores", ["technique_id"]) + op.create_index("ix_technique_confidence_scores_confidence_level", "technique_confidence_scores", ["confidence_level"]) + + if not _table_exists("infrastructure_change_logs"): + op.create_table( + "infrastructure_change_logs", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("change_type", sa.String(100), nullable=False), + sa.Column("description", sa.Text, nullable=False), + sa.Column("affected_platforms", postgresql.JSONB, server_default="[]"), + sa.Column("affected_log_sources", postgresql.JSONB, server_default="[]"), + sa.Column("change_date", sa.DateTime), + sa.Column("reported_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")), + sa.Column("auto_invalidate", sa.Boolean, server_default="true"), + sa.Column("invalidated_count", sa.Integer, server_default="0"), + sa.Column("change_metadata", postgresql.JSONB, server_default="{}"), + sa.Column("created_at", sa.DateTime), + ) + op.create_index("ix_infrastructure_change_logs_change_date", "infrastructure_change_logs", ["change_date"]) + + if not _table_exists("decay_policies"): + op.create_table( + "decay_policies", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("name", sa.String(200), nullable=False), + sa.Column("description", sa.Text), + sa.Column("applies_to_platform", sa.String(100)), + sa.Column("applies_to_asset_type", sa.String(50)), + sa.Column("applies_to_tactic", sa.String(100)), + sa.Column("fresh_days", sa.Integer, server_default="90"), + sa.Column("aging_days", sa.Integer, server_default="180"), + sa.Column("stale_days", sa.Integer, server_default="365"), + sa.Column("default_validity_days", sa.Integer, server_default="180"), + sa.Column("silent_threshold_days", sa.Integer, server_default="30"), + sa.Column("noisy_threshold_daily", sa.Integer, server_default="100"), + sa.Column("recency_weight", sa.Float, server_default="0.3"), + sa.Column("coverage_weight", sa.Float, server_default="0.3"), + sa.Column("health_weight", sa.Float, server_default="0.25"), + sa.Column("diversity_weight", sa.Float, server_default="0.15"), + sa.Column("is_default", sa.Boolean, server_default="false"), + sa.Column("is_active", sa.Boolean, server_default="true"), + sa.Column("created_at", sa.DateTime), + sa.Column("updated_at", sa.DateTime), + ) + + +def downgrade() -> None: + for table in ["decay_policies", "infrastructure_change_logs", "technique_confidence_scores", "detection_validations", "detection_technique_mappings", "detection_assets"]: + if _table_exists(table): + op.drop_table(table) diff --git a/backend/app/jobs/mitre_sync_job.py b/backend/app/jobs/mitre_sync_job.py index 94df5d5..b5049d2 100644 --- a/backend/app/jobs/mitre_sync_job.py +++ b/backend/app/jobs/mitre_sync_job.py @@ -195,6 +195,20 @@ def _run_stale_detection() -> None: db.close() +def _run_decay_engine() -> None: + """Execute the decay engine inside its own DB session.""" + logger.info("Scheduled decay engine job starting...") + db = SessionLocal() + try: + from app.services.decay_engine_service import run_decay_engine + results = run_decay_engine(db) + logger.info("Decay engine job finished — %s", results) + except Exception: + logger.exception("Decay engine job failed") + finally: + db.close() + + # --------------------------------------------------------------------------- # Scheduler bootstrap # --------------------------------------------------------------------------- @@ -292,6 +306,15 @@ def start_scheduler() -> None: name="Data sources auto-sync (every 6h)", replace_existing=True, ) + scheduler.add_job( + _run_decay_engine, + trigger="cron", + hour=2, + minute=0, + id="decay_engine", + name="Detection decay engine (daily 02:00)", + replace_existing=True, + ) scheduler.start() logger.info( "Background scheduler started — mitre_sync (24h), intel_scan (7d), " diff --git a/backend/app/main.py b/backend/app/main.py index 73abe12..ff6fb00 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -38,6 +38,7 @@ from app.routers import analytics as analytics_router from app.routers import advanced_metrics as advanced_metrics_router from app.routers import osint as osint_router from app.routers import webhooks as webhooks_router +from app.routers import detection_lifecycle as detection_lifecycle_router from app.domain.errors import DomainError from app.middleware.error_handler import domain_exception_handler from app.middleware.request_context import RequestContextMiddleware @@ -58,6 +59,16 @@ async def lifespan(app: FastAPI): """Startup / shutdown logic.""" ensure_bucket_exists() start_scheduler() + # Seed decay policies + from app.database import SessionLocal + from app.seed_decay_policies import seed_decay_policies + db = SessionLocal() + try: + seed_decay_policies(db) + except Exception: + pass + finally: + db.close() yield # Graceful shutdown of the background scheduler scheduler.shutdown(wait=False) @@ -125,6 +136,7 @@ app.include_router(analytics_router.router, prefix="/api/v1") app.include_router(advanced_metrics_router.router, prefix="/api/v1") app.include_router(osint_router.router, prefix="/api/v1") app.include_router(webhooks_router.router, prefix="/api/v1") +app.include_router(detection_lifecycle_router.router, prefix="/api/v1") @app.get("/health", include_in_schema=False) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index af6a30d..1ae717e 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -23,6 +23,12 @@ from app.models.scoring_config import ScoringConfig from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide from app.models.webhook_config import WebhookConfig from app.models.system_config import SystemConfig +from app.models.detection_lifecycle import ( + DetectionAsset, DetectionTechniqueMapping, DetectionValidation, + TechniqueConfidenceScore, InfrastructureChangeLog, + DetectionConfidence, DetectionHealthStatus, InvalidationReason, +) +from app.models.decay_policy import DecayPolicy __all__ = [ "User", "Technique", "Test", "TestTemplate", "Evidence", @@ -37,4 +43,6 @@ __all__ = [ "Worklog", "OsintItem", "ScoringConfig", "TechniqueStatus", "TestState", "TestResult", "TeamSide", "WebhookConfig", "SystemConfig", + "DetectionAsset", "DetectionTechniqueMapping", "DetectionValidation", + "TechniqueConfidenceScore", "InfrastructureChangeLog", "DecayPolicy", ] diff --git a/backend/app/models/decay_policy.py b/backend/app/models/decay_policy.py new file mode 100644 index 0000000..c7985cf --- /dev/null +++ b/backend/app/models/decay_policy.py @@ -0,0 +1,32 @@ +"""Decay Policy model — configurable detection validity rules.""" + +import uuid +from datetime import datetime +from sqlalchemy import Column, String, Integer, Float, Boolean, DateTime, Text +from sqlalchemy.dialects.postgresql import UUID +from app.database import Base + + +class DecayPolicy(Base): + __tablename__ = "decay_policies" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String(200), nullable=False) + description = Column(Text) + applies_to_platform = Column(String(100)) + applies_to_asset_type = Column(String(50)) + applies_to_tactic = Column(String(100)) + fresh_days = Column(Integer, default=90, server_default='90') + aging_days = Column(Integer, default=180, server_default='180') + stale_days = Column(Integer, default=365, server_default='365') + default_validity_days = Column(Integer, default=180, server_default='180') + silent_threshold_days = Column(Integer, default=30, server_default='30') + noisy_threshold_daily = Column(Integer, default=100, server_default='100') + recency_weight = Column(Float, default=0.3, server_default='0.3') + coverage_weight = Column(Float, default=0.3, server_default='0.3') + health_weight = Column(Float, default=0.25, server_default='0.25') + diversity_weight = Column(Float, default=0.15, server_default='0.15') + is_default = Column(Boolean, default=False, server_default='false') + is_active = Column(Boolean, default=True, server_default='true') + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow) diff --git a/backend/app/models/detection_lifecycle.py b/backend/app/models/detection_lifecycle.py new file mode 100644 index 0000000..b632f03 --- /dev/null +++ b/backend/app/models/detection_lifecycle.py @@ -0,0 +1,168 @@ +"""Detection Lifecycle Management models.""" + +import uuid +import enum +from datetime import datetime +from sqlalchemy import ( + Column, String, Integer, Float, Boolean, DateTime, + ForeignKey, Text, Enum as SQLEnum +) +from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy.orm import relationship +from app.database import Base + + +class DetectionConfidence(str, enum.Enum): + fresh = "fresh" + aging = "aging" + stale = "stale" + broken = "broken" + unknown = "unknown" + + +class DetectionHealthStatus(str, enum.Enum): + healthy = "healthy" + silent = "silent" + noisy = "noisy" + orphan = "orphan" + deprecated = "deprecated" + untested = "untested" + + +class InvalidationReason(str, enum.Enum): + time_decay = "time_decay" + mitre_update = "mitre_update" + log_source_change = "log_source_change" + siem_update = "siem_update" + edr_update = "edr_update" + infrastructure_change = "infrastructure_change" + parser_change = "parser_change" + manual = "manual" + rule_modified = "rule_modified" + + +class DetectionAsset(Base): + __tablename__ = "detection_assets" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String(500), nullable=False) + description = Column(Text) + asset_type = Column(String(50), nullable=False) + platform = Column(String(100)) + rule_content = Column(Text) + rule_language = Column(String(50)) + rule_repository_url = Column(Text) + rule_file_path = Column(String(500)) + rule_version = Column(String(50)) + rule_hash = Column(String(64)) + last_rule_change_at = Column(DateTime) + log_source_name = Column(String(200)) + log_source_version = Column(String(50)) + log_source_config = Column(JSONB, server_default='{}') + infrastructure_hash = Column(String(64)) + infrastructure_details = Column(JSONB, server_default='{}') + health_status = Column( + SQLEnum(DetectionHealthStatus, name="detectionhealthstatus"), + default=DetectionHealthStatus.untested, + nullable=False, + server_default="untested", + ) + last_alert_at = Column(DateTime) + alert_count_30d = Column(Integer, default=0, server_default='0') + false_positive_rate = Column(Float) + expected_alert_frequency = Column(String(50)) + owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True) + backup_owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True) + team = Column(String(100)) + is_active = Column(Boolean, default=True, nullable=False, server_default='true') + tags = Column(JSONB, server_default='[]') + asset_metadata = Column(JSONB, server_default='{}') + created_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True) + created_at = Column(DateTime(timezone=True), server_default='now()') + updated_at = Column(DateTime(timezone=True), server_default='now()') + + technique_mappings = relationship("DetectionTechniqueMapping", back_populates="detection_asset", cascade="all, delete-orphan") + validations = relationship("DetectionValidation", back_populates="detection_asset", cascade="all, delete-orphan") + + +class DetectionTechniqueMapping(Base): + __tablename__ = "detection_technique_mappings" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + detection_asset_id = Column(UUID(as_uuid=True), ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False) + technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False) + coverage_type = Column(String(50), default="detect", server_default="detect") + confidence_level = Column(String(20), default="medium", server_default="medium") + notes = Column(Text) + created_at = Column(DateTime(timezone=True), server_default='now()') + + detection_asset = relationship("DetectionAsset", back_populates="technique_mappings") + + +class DetectionValidation(Base): + __tablename__ = "detection_validations" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + detection_asset_id = Column(UUID(as_uuid=True), ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False) + technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="SET NULL"), nullable=True) + test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id", ondelete="SET NULL"), nullable=True) + validated_at = Column(DateTime, default=datetime.utcnow) + expires_at = Column(DateTime, nullable=False) + is_valid = Column(Boolean, default=True, nullable=False, server_default='true') + validation_result = Column(String(50)) + validation_method = Column(String(100)) + rule_hash_at_validation = Column(String(64)) + log_source_version_at_validation = Column(String(50)) + infrastructure_hash_at_validation = Column(String(64)) + environment_snapshot = Column(JSONB, server_default='{}') + invalidated_at = Column(DateTime) + invalidation_reason = Column(SQLEnum(InvalidationReason, name="invalidationreason"), nullable=True) + invalidation_details = Column(Text) + invalidated_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True) + validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=False) + integrity_hash = Column(String(64)) + notes = Column(Text) + evidence_ids = Column(JSONB, server_default='[]') + + detection_asset = relationship("DetectionAsset", back_populates="validations") + + +class TechniqueConfidenceScore(Base): + __tablename__ = "technique_confidence_scores" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False, unique=True) + confidence_level = Column( + SQLEnum(DetectionConfidence, name="detectionconfidence"), + default=DetectionConfidence.unknown, + server_default="unknown", + ) + confidence_score = Column(Float, default=0.0, server_default='0.0') + detection_count = Column(Integer, default=0, server_default='0') + valid_detection_count = Column(Integer, default=0, server_default='0') + last_validated_at = Column(DateTime) + next_validation_due = Column(DateTime) + last_recalculated_at = Column(DateTime, default=datetime.utcnow) + recency_factor = Column(Float, default=0.0, server_default='0.0') + coverage_factor = Column(Float, default=0.0, server_default='0.0') + health_factor = Column(Float, default=0.0, server_default='0.0') + diversity_factor = Column(Float, default=0.0, server_default='0.0') + score_breakdown = Column(JSONB, server_default='{}') + risk_factors = Column(JSONB, server_default='[]') + updated_at = Column(DateTime, default=datetime.utcnow) + + +class InfrastructureChangeLog(Base): + __tablename__ = "infrastructure_change_logs" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + change_type = Column(String(100), nullable=False) + description = Column(Text, nullable=False) + affected_platforms = Column(JSONB, server_default='[]') + affected_log_sources = Column(JSONB, server_default='[]') + change_date = Column(DateTime, default=datetime.utcnow) + reported_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True) + auto_invalidate = Column(Boolean, default=True, server_default='true') + invalidated_count = Column(Integer, default=0, server_default='0') + change_metadata = Column(JSONB, server_default='{}') + created_at = Column(DateTime, default=datetime.utcnow) diff --git a/backend/app/routers/detection_lifecycle.py b/backend/app/routers/detection_lifecycle.py new file mode 100644 index 0000000..4f63ee5 --- /dev/null +++ b/backend/app/routers/detection_lifecycle.py @@ -0,0 +1,302 @@ +"""Detection Lifecycle Management router.""" + +import hashlib +from datetime import datetime, timezone, timedelta +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, Query +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.database import get_db +from app.dependencies.auth import get_current_user, require_any_role +from app.domain.exceptions import EntityNotFoundError +from app.models.detection_lifecycle import ( + DetectionAsset, DetectionTechniqueMapping, DetectionValidation, + TechniqueConfidenceScore, InfrastructureChangeLog, +) +from app.schemas.detection_lifecycle_schema import ( + DetectionAssetCreate, DetectionAssetUpdate, DetectionAssetOut, + DetectionValidationCreate, DetectionValidationOut, + TechniqueConfidenceOut, + InfrastructureChangeCreate, InfrastructureChangeOut, +) +from app.services import detection_asset_service, decay_engine_service, audit_service + +router = APIRouter(prefix="/detection-lifecycle", tags=["detection-lifecycle"]) + + +def _now() -> datetime: + return datetime.now(timezone.utc) + + +# ── Detection Assets ───────────────────────────────────────────────────────── + +@router.post("/assets", response_model=DetectionAssetOut, status_code=201) +def create_asset(body: DetectionAssetCreate, db: Session = Depends(get_db), user=Depends(get_current_user)): + asset = detection_asset_service.create_detection_asset(db, body.model_dump(), user.id) + return asset + + +@router.get("/assets", response_model=list[DetectionAssetOut]) +def list_assets( + platform: Optional[str] = None, + asset_type: Optional[str] = None, + health_status: Optional[str] = None, + technique_id: Optional[UUID] = None, + is_active: Optional[bool] = True, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + return detection_asset_service.list_assets(db, platform=platform, asset_type=asset_type, health_status=health_status, technique_id=technique_id, is_active=is_active) + + +@router.get("/assets/{asset_id}", response_model=DetectionAssetOut) +def get_asset(asset_id: UUID, db: Session = Depends(get_db), user=Depends(get_current_user)): + return detection_asset_service.get_asset_with_details(db, asset_id) + + +@router.patch("/assets/{asset_id}", response_model=DetectionAssetOut) +def update_asset(asset_id: UUID, body: DetectionAssetUpdate, db: Session = Depends(get_db), user=Depends(get_current_user)): + return detection_asset_service.update_detection_asset(db, asset_id, body.model_dump(exclude_unset=True), user.id) + + +@router.delete("/assets/{asset_id}", status_code=204) +def delete_asset(asset_id: UUID, db: Session = Depends(get_db), user=Depends(require_any_role("red_lead", "blue_lead"))): + asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first() + if not asset: + raise EntityNotFoundError("DetectionAsset", str(asset_id)) + asset.is_active = False + db.commit() + + +# ── Technique Mappings ─────────────────────────────────────────────────────── + +@router.post("/assets/{asset_id}/techniques/{technique_id}") +def map_technique( + asset_id: UUID, technique_id: UUID, + coverage_type: str = Query("detect"), + confidence_level: str = Query("medium"), + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + mapping = DetectionTechniqueMapping( + detection_asset_id=asset_id, technique_id=technique_id, + coverage_type=coverage_type, confidence_level=confidence_level, + ) + db.add(mapping) + db.commit() + return {"message": "Technique mapped", "mapping_id": str(mapping.id)} + + +@router.get("/techniques/{technique_id}/detections") +def get_technique_detections(technique_id: UUID, db: Session = Depends(get_db), user=Depends(get_current_user)): + return detection_asset_service.get_technique_detection_summary(db, technique_id) + + +# ── Validations ────────────────────────────────────────────────────────────── + +@router.post("/validations", response_model=DetectionValidationOut, status_code=201) +def create_validation(body: DetectionValidationCreate, db: Session = Depends(get_db), user=Depends(get_current_user)): + asset = db.query(DetectionAsset).filter(DetectionAsset.id == body.detection_asset_id).first() + if not asset: + raise EntityNotFoundError("DetectionAsset", str(body.detection_asset_id)) + + now = _now() + validation = DetectionValidation( + detection_asset_id=body.detection_asset_id, + technique_id=body.technique_id, + test_id=body.test_id, + validation_result=body.validation_result, + validation_method=body.validation_method, + notes=body.notes, + evidence_ids=[str(e) for e in (body.evidence_ids or [])], + validated_by=user.id, + validated_at=now, + expires_at=now + timedelta(days=body.validity_days), + rule_hash_at_validation=asset.rule_hash, + log_source_version_at_validation=asset.log_source_version, + infrastructure_hash_at_validation=asset.infrastructure_hash, + ) + data = f"{validation.detection_asset_id}:{validation.validated_by}:{validation.validation_result}:{validation.validated_at}" + validation.integrity_hash = hashlib.sha256(data.encode()).hexdigest() + + db.add(validation) + db.commit() + db.refresh(validation) + + if body.technique_id: + decay_engine_service.calculate_confidence_for_technique(db, body.technique_id) + + audit_service.log_action(db, user.id, "DETECTION_VALIDATED", "detection_validation", str(validation.id), + details={"asset_id": str(body.detection_asset_id), "result": body.validation_result, "validity_days": body.validity_days}) + + return validation + + +@router.get("/validations", response_model=list[DetectionValidationOut]) +def list_validations( + asset_id: Optional[UUID] = None, + technique_id: Optional[UUID] = None, + is_valid: Optional[bool] = None, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + query = db.query(DetectionValidation) + if asset_id: + query = query.filter(DetectionValidation.detection_asset_id == asset_id) + if technique_id: + query = query.filter(DetectionValidation.technique_id == technique_id) + if is_valid is not None: + query = query.filter(DetectionValidation.is_valid == is_valid) + return query.order_by(DetectionValidation.validated_at.desc()).all() + + +@router.post("/validations/{validation_id}/invalidate") +def invalidate_validation( + validation_id: UUID, + reason: str = Query(...), + details: Optional[str] = None, + db: Session = Depends(get_db), + user=Depends(require_any_role("admin", "blue_lead")), +): + validation = db.query(DetectionValidation).filter(DetectionValidation.id == validation_id).first() + if not validation: + raise EntityNotFoundError("DetectionValidation", str(validation_id)) + + from app.models.detection_lifecycle import InvalidationReason + try: + reason_enum = InvalidationReason(reason) + except ValueError: + reason_enum = InvalidationReason.manual + + validation.is_valid = False + validation.invalidated_at = _now() + validation.invalidation_reason = reason_enum + validation.invalidation_details = details + validation.invalidated_by = user.id + db.commit() + return {"message": "Validation invalidated"} + + +# ── Confidence Scores ──────────────────────────────────────────────────────── + +@router.get("/confidence", response_model=list[TechniqueConfidenceOut]) +def list_confidence_scores( + confidence_level: Optional[str] = None, + min_score: Optional[float] = None, + max_score: Optional[float] = None, + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + query = db.query(TechniqueConfidenceScore) + if confidence_level: + query = query.filter(TechniqueConfidenceScore.confidence_level == confidence_level) + if min_score is not None: + query = query.filter(TechniqueConfidenceScore.confidence_score >= min_score) + if max_score is not None: + query = query.filter(TechniqueConfidenceScore.confidence_score <= max_score) + return query.order_by(TechniqueConfidenceScore.confidence_score.asc()).all() + + +@router.get("/confidence/{technique_id}", response_model=TechniqueConfidenceOut) +def get_technique_confidence( + technique_id: UUID, + recalculate: bool = Query(False), + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + if recalculate: + return decay_engine_service.calculate_confidence_for_technique(db, technique_id) + score = db.query(TechniqueConfidenceScore).filter(TechniqueConfidenceScore.technique_id == technique_id).first() + if not score: + return decay_engine_service.calculate_confidence_for_technique(db, technique_id) + return score + + +# ── Infrastructure Changes ─────────────────────────────────────────────────── + +@router.post("/infrastructure-changes", response_model=InfrastructureChangeOut, status_code=201) +def report_infrastructure_change( + body: InfrastructureChangeCreate, + db: Session = Depends(get_db), + user=Depends(require_any_role("admin", "blue_lead")), +): + change = InfrastructureChangeLog( + change_type=body.change_type, + description=body.description, + affected_platforms=body.affected_platforms, + affected_log_sources=body.affected_log_sources, + change_date=body.change_date or _now(), + auto_invalidate=body.auto_invalidate, + reported_by=user.id, + ) + db.add(change) + db.commit() + db.refresh(change) + + if change.auto_invalidate: + decay_engine_service.process_infrastructure_change(db, change.id) + db.refresh(change) + + audit_service.log_action(db, user.id, "INFRASTRUCTURE_CHANGE_REPORTED", "infrastructure_change", str(change.id), + details={"type": body.change_type, "invalidated_count": change.invalidated_count}) + + return change + + +@router.get("/infrastructure-changes", response_model=list[InfrastructureChangeOut]) +def list_infrastructure_changes( + days: int = Query(90, ge=1, le=730), + db: Session = Depends(get_db), + user=Depends(get_current_user), +): + cutoff = _now() - timedelta(days=days) + return db.query(InfrastructureChangeLog).filter(InfrastructureChangeLog.change_date >= cutoff).order_by(InfrastructureChangeLog.change_date.desc()).all() + + +# ── Decay Engine Control ───────────────────────────────────────────────────── + +@router.post("/decay-engine/run") +def trigger_decay_engine(db: Session = Depends(get_db), user=Depends(require_any_role("admin"))): + results = decay_engine_service.run_decay_engine(db) + return {"message": "Decay engine completed", "results": results} + + +# ── Dashboard ──────────────────────────────────────────────────────────────── + +@router.get("/dashboard") +def lifecycle_dashboard(db: Session = Depends(get_db), user=Depends(get_current_user)): + now = _now() + + health_dist = dict( + db.query(DetectionAsset.health_status, func.count(DetectionAsset.id)) + .filter(DetectionAsset.is_active == True) + .group_by(DetectionAsset.health_status) + .all() + ) + confidence_dist = dict( + db.query(TechniqueConfidenceScore.confidence_level, func.count(TechniqueConfidenceScore.id)) + .group_by(TechniqueConfidenceScore.confidence_level) + .all() + ) + expiring_soon = db.query(func.count(DetectionValidation.id)).filter( + DetectionValidation.is_valid == True, + DetectionValidation.expires_at <= (now + timedelta(days=7)), + ).scalar() or 0 + + total_assets = db.query(func.count(DetectionAsset.id)).filter(DetectionAsset.is_active == True).scalar() or 0 + total_valid = db.query(func.count(DetectionValidation.id)).filter(DetectionValidation.is_valid == True).scalar() or 0 + recent_changes = db.query(func.count(InfrastructureChangeLog.id)).filter( + InfrastructureChangeLog.change_date >= (now - timedelta(days=30)) + ).scalar() or 0 + + return { + "total_detection_assets": total_assets, + "total_valid_validations": total_valid, + "health_distribution": {k.value if hasattr(k, "value") else str(k): v for k, v in health_dist.items()}, + "confidence_distribution": {k.value if hasattr(k, "value") else str(k): v for k, v in confidence_dist.items()}, + "validations_expiring_7d": expiring_soon, + "infrastructure_changes_30d": recent_changes, + } diff --git a/backend/app/schemas/detection_lifecycle_schema.py b/backend/app/schemas/detection_lifecycle_schema.py new file mode 100644 index 0000000..767f88d --- /dev/null +++ b/backend/app/schemas/detection_lifecycle_schema.py @@ -0,0 +1,140 @@ +"""Pydantic schemas for Detection Lifecycle endpoints.""" + +from pydantic import BaseModel, Field, ConfigDict +from typing import Optional +from uuid import UUID +from datetime import datetime +from app.models.detection_lifecycle import ( + DetectionConfidence, DetectionHealthStatus, InvalidationReason +) + + +class DetectionAssetCreate(BaseModel): + name: str = Field(..., min_length=3, max_length=500) + description: Optional[str] = None + asset_type: str = Field(..., pattern=r'^(siem_rule|edr_rule|sigma_rule|yara_rule|spl_query|kql_query|custom_script)$') + platform: Optional[str] = None + rule_content: Optional[str] = None + rule_language: Optional[str] = None + rule_repository_url: Optional[str] = None + rule_file_path: Optional[str] = None + rule_version: Optional[str] = None + log_source_name: Optional[str] = None + log_source_version: Optional[str] = None + log_source_config: Optional[dict] = Field(default_factory=dict) + infrastructure_details: Optional[dict] = Field(default_factory=dict) + expected_alert_frequency: Optional[str] = None + tags: Optional[list[str]] = Field(default_factory=list) + technique_ids: Optional[list[UUID]] = Field(default_factory=list) + + +class DetectionAssetUpdate(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + rule_content: Optional[str] = None + rule_version: Optional[str] = None + log_source_version: Optional[str] = None + infrastructure_details: Optional[dict] = None + expected_alert_frequency: Optional[str] = None + health_status: Optional[DetectionHealthStatus] = None + last_alert_at: Optional[datetime] = None + alert_count_30d: Optional[int] = None + false_positive_rate: Optional[float] = None + owner_id: Optional[UUID] = None + backup_owner_id: Optional[UUID] = None + team: Optional[str] = None + tags: Optional[list[str]] = None + is_active: Optional[bool] = None + + +class DetectionAssetOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + name: str + description: Optional[str] = None + asset_type: str + platform: Optional[str] = None + rule_language: Optional[str] = None + rule_version: Optional[str] = None + rule_hash: Optional[str] = None + health_status: DetectionHealthStatus + last_alert_at: Optional[datetime] = None + alert_count_30d: int + false_positive_rate: Optional[float] = None + expected_alert_frequency: Optional[str] = None + owner_id: Optional[UUID] = None + team: Optional[str] = None + is_active: bool + tags: list = Field(default_factory=list) + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + +class DetectionValidationCreate(BaseModel): + detection_asset_id: UUID + technique_id: Optional[UUID] = None + test_id: Optional[UUID] = None + validation_result: str = Field(..., pattern=r'^(detected|not_detected|partial|error)$') + validation_method: str + notes: Optional[str] = None + evidence_ids: Optional[list[UUID]] = Field(default_factory=list) + validity_days: int = Field(default=180, ge=30, le=730) + + +class DetectionValidationOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + detection_asset_id: UUID + technique_id: Optional[UUID] = None + validated_at: Optional[datetime] = None + expires_at: datetime + is_valid: bool + validation_result: Optional[str] = None + validation_method: Optional[str] = None + invalidated_at: Optional[datetime] = None + invalidation_reason: Optional[InvalidationReason] = None + validated_by: Optional[UUID] = None + notes: Optional[str] = None + + +class TechniqueConfidenceOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + technique_id: UUID + confidence_level: DetectionConfidence + confidence_score: float + detection_count: int + valid_detection_count: int + last_validated_at: Optional[datetime] = None + next_validation_due: Optional[datetime] = None + recency_factor: float + coverage_factor: float + health_factor: float + diversity_factor: float + risk_factors: list = Field(default_factory=list) + + +class InfrastructureChangeCreate(BaseModel): + change_type: str + description: str = Field(..., min_length=10) + affected_platforms: list[str] = Field(default_factory=list) + affected_log_sources: list[str] = Field(default_factory=list) + change_date: Optional[datetime] = None + auto_invalidate: bool = True + + +class InfrastructureChangeOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + change_type: str + description: str + affected_platforms: list = Field(default_factory=list) + affected_log_sources: list = Field(default_factory=list) + change_date: Optional[datetime] = None + auto_invalidate: bool + invalidated_count: int + reported_by: Optional[UUID] = None + created_at: Optional[datetime] = None diff --git a/backend/app/seed_decay_policies.py b/backend/app/seed_decay_policies.py new file mode 100644 index 0000000..6df7761 --- /dev/null +++ b/backend/app/seed_decay_policies.py @@ -0,0 +1,39 @@ +"""Seed default decay policies.""" +from datetime import datetime +from sqlalchemy.orm import Session +from app.models.decay_policy import DecayPolicy + + +def seed_decay_policies(db: Session) -> None: + existing = db.query(DecayPolicy).filter(DecayPolicy.is_default == True).first() + if existing: + return + + now = datetime.utcnow() + default_policy = DecayPolicy( + name="Default Decay Policy", + description="Standard: Fresh 90d, Aging 91-180d, Stale 181-365d.", + fresh_days=90, aging_days=180, stale_days=365, + default_validity_days=180, silent_threshold_days=30, + noisy_threshold_daily=100, + recency_weight=0.30, coverage_weight=0.30, + health_weight=0.25, diversity_weight=0.15, + is_default=True, is_active=True, + created_at=now, updated_at=now, + ) + db.add(default_policy) + + critical_policy = DecayPolicy( + name="Critical Techniques Policy", + description="Stricter: Fresh 60d, Aging 90d, Stale 180d.", + applies_to_tactic="initial-access", + fresh_days=60, aging_days=90, stale_days=180, + default_validity_days=90, silent_threshold_days=14, + noisy_threshold_daily=50, + recency_weight=0.35, coverage_weight=0.30, + health_weight=0.25, diversity_weight=0.10, + is_default=False, is_active=True, + created_at=now, updated_at=now, + ) + db.add(critical_policy) + db.commit() diff --git a/backend/app/services/decay_engine_service.py b/backend/app/services/decay_engine_service.py new file mode 100644 index 0000000..b8b52c7 --- /dev/null +++ b/backend/app/services/decay_engine_service.py @@ -0,0 +1,260 @@ +"""Decay Engine — calculates confidence scores and expires validations.""" + +import logging +from datetime import datetime, timezone +from typing import Optional +from uuid import UUID + +from sqlalchemy.orm import Session + +from app.models.detection_lifecycle import ( + DetectionAsset, DetectionValidation, + DetectionTechniqueMapping, TechniqueConfidenceScore, + DetectionConfidence, DetectionHealthStatus, + InfrastructureChangeLog, +) +from app.models.decay_policy import DecayPolicy +from app.models.technique import Technique + +logger = logging.getLogger(__name__) + + +def _now() -> datetime: + return datetime.now(timezone.utc) + + +def get_applicable_policy(db: Session, platform: Optional[str] = None, asset_type: Optional[str] = None, tactic: Optional[str] = None) -> DecayPolicy: + query = db.query(DecayPolicy).filter(DecayPolicy.is_active == True) + if platform: + specific = query.filter(DecayPolicy.applies_to_platform == platform).first() + if specific: + return specific + if asset_type: + specific = query.filter(DecayPolicy.applies_to_asset_type == asset_type).first() + if specific: + return specific + if tactic: + specific = query.filter(DecayPolicy.applies_to_tactic == tactic).first() + if specific: + return specific + default_policy = query.filter(DecayPolicy.is_default == True).first() + if default_policy: + return default_policy + # Return an in-memory default if no DB policy exists + p = DecayPolicy() + p.fresh_days = 90 + p.aging_days = 180 + p.stale_days = 365 + p.recency_weight = 0.30 + p.coverage_weight = 0.30 + p.health_weight = 0.25 + p.diversity_weight = 0.15 + return p + + +def calculate_confidence_for_technique(db: Session, technique_id: UUID) -> Optional[TechniqueConfidenceScore]: + technique = db.query(Technique).filter(Technique.id == technique_id).first() + if not technique: + return None + + policy = get_applicable_policy(db, tactic=technique.tactic) + mappings = db.query(DetectionTechniqueMapping).filter(DetectionTechniqueMapping.technique_id == technique_id).all() + asset_ids = [m.detection_asset_id for m in mappings] + + if not asset_ids: + return _create_or_update_score(db, technique_id, + confidence_level=DetectionConfidence.unknown, confidence_score=0.0, + factors={"recency": 0.0, "coverage": 0.0, "health": 0.0, "diversity": 0.0}, + risk_factors=["no_detection_assets"], detection_count=0, valid_count=0, + ) + + assets = db.query(DetectionAsset).filter(DetectionAsset.id.in_(asset_ids), DetectionAsset.is_active == True).all() + now = _now() + + valid_validations = db.query(DetectionValidation).filter( + DetectionValidation.detection_asset_id.in_(asset_ids), + DetectionValidation.is_valid == True, + DetectionValidation.expires_at > now, + ).all() + + recency_factor = 0.0 + last_validated = None + if valid_validations: + most_recent = max(v.validated_at for v in valid_validations) + # Make timezone-aware if needed + if most_recent.tzinfo is None: + from datetime import timezone as _tz + most_recent = most_recent.replace(tzinfo=_tz.utc) + last_validated = most_recent + days_since = (now - most_recent).days + if days_since <= policy.fresh_days: + recency_factor = 1.0 + elif days_since <= policy.aging_days: + range_days = policy.aging_days - policy.fresh_days + elapsed = days_since - policy.fresh_days + recency_factor = 1.0 - (elapsed / range_days) * 0.4 + elif days_since <= policy.stale_days: + range_days = policy.stale_days - policy.aging_days + elapsed = days_since - policy.aging_days + recency_factor = 0.6 - (elapsed / range_days) * 0.4 + else: + recency_factor = max(0.1, 0.2 - ((days_since - policy.stale_days) / 365) * 0.1) + + active_count = len(assets) + valid_count = len(set(v.detection_asset_id for v in valid_validations)) + + if active_count == 0: + coverage_factor = 0.0 + elif valid_count >= 3: + coverage_factor = 1.0 + elif valid_count >= 2: + coverage_factor = 0.8 + elif valid_count >= 1: + coverage_factor = 0.5 + else: + coverage_factor = 0.1 + + health_scores = { + DetectionHealthStatus.healthy: 1.0, + DetectionHealthStatus.silent: 0.4, + DetectionHealthStatus.noisy: 0.6, + DetectionHealthStatus.orphan: 0.3, + DetectionHealthStatus.deprecated: 0.0, + DetectionHealthStatus.untested: 0.2, + } + health_factor = sum(health_scores.get(a.health_status, 0.2) for a in assets) / max(len(assets), 1) + + platforms = set(a.platform for a in assets if a.platform) + asset_types = set(a.asset_type for a in assets) + diversity_factor = min(1.0, len(platforms) * 0.3 + len(asset_types) * 0.2) + + confidence_score = ( + recency_factor * policy.recency_weight + + coverage_factor * policy.coverage_weight + + health_factor * policy.health_weight + + diversity_factor * policy.diversity_weight + ) * 100 + + if confidence_score >= 75: + confidence_level = DetectionConfidence.fresh + elif confidence_score >= 50: + confidence_level = DetectionConfidence.aging + elif confidence_score >= 25: + confidence_level = DetectionConfidence.stale + elif confidence_score > 0: + confidence_level = DetectionConfidence.broken + else: + confidence_level = DetectionConfidence.unknown + + risk_factors = [] + if len(platforms) <= 1: + risk_factors.append("single_platform") + if valid_count == 0: + risk_factors.append("no_valid_detections") + if any(a.health_status == DetectionHealthStatus.silent for a in assets): + risk_factors.append("silent_rules_present") + if any(a.health_status == DetectionHealthStatus.orphan for a in assets): + risk_factors.append("orphan_rules_present") + if recency_factor < 0.5: + risk_factors.append("stale_validation") + if len(assets) < 2: + risk_factors.append("low_detection_diversity") + + next_due = None + if valid_validations: + earliest_expiry = min(v.expires_at for v in valid_validations) + next_due = earliest_expiry + + return _create_or_update_score( + db, technique_id, + confidence_level=confidence_level, + confidence_score=round(confidence_score, 1), + factors={"recency": round(recency_factor, 3), "coverage": round(coverage_factor, 3), "health": round(health_factor, 3), "diversity": round(diversity_factor, 3)}, + risk_factors=risk_factors, + detection_count=active_count, + valid_count=valid_count, + last_validated=last_validated, + next_due=next_due, + ) + + +def _create_or_update_score(db: Session, technique_id: UUID, **kwargs) -> TechniqueConfidenceScore: + score = db.query(TechniqueConfidenceScore).filter(TechniqueConfidenceScore.technique_id == technique_id).first() + if not score: + score = TechniqueConfidenceScore(technique_id=technique_id) + db.add(score) + + score.confidence_level = kwargs["confidence_level"] + score.confidence_score = kwargs["confidence_score"] + score.detection_count = kwargs["detection_count"] + score.valid_detection_count = kwargs["valid_count"] + score.recency_factor = kwargs["factors"]["recency"] + score.coverage_factor = kwargs["factors"]["coverage"] + score.health_factor = kwargs["factors"]["health"] + score.diversity_factor = kwargs["factors"]["diversity"] + score.risk_factors = kwargs["risk_factors"] + score.score_breakdown = kwargs["factors"] + score.last_validated_at = kwargs.get("last_validated") + score.next_validation_due = kwargs.get("next_due") + score.last_recalculated_at = _now() + score.updated_at = _now() + + db.commit() + db.refresh(score) + return score + + +def run_decay_engine(db: Session) -> dict: + techniques = db.query(Technique).all() + results = {"total_techniques": len(techniques), "fresh": 0, "aging": 0, "stale": 0, "broken": 0, "unknown": 0, "validations_expired": 0} + now = _now() + + # Expire stale validations + expired = db.query(DetectionValidation).filter( + DetectionValidation.is_valid == True, + DetectionValidation.expires_at <= now, + ).all() + from app.models.detection_lifecycle import InvalidationReason + for v in expired: + v.is_valid = False + v.invalidated_at = now + v.invalidation_reason = InvalidationReason.time_decay + results["validations_expired"] = len(expired) + if expired: + db.commit() + + for technique in techniques: + score = calculate_confidence_for_technique(db, technique.id) + if score: + level = score.confidence_level.value + results[level] = results.get(level, 0) + 1 + + logger.info("Decay engine completed: %s", results) + return results + + +def process_infrastructure_change(db: Session, change_id: UUID) -> int: + change = db.query(InfrastructureChangeLog).filter(InfrastructureChangeLog.id == change_id).first() + if not change or not change.auto_invalidate: + return 0 + + query = db.query(DetectionAsset).filter(DetectionAsset.is_active == True) + if change.affected_platforms: + query = query.filter(DetectionAsset.platform.in_(change.affected_platforms)) + + affected_assets = query.all() + total_invalidated = 0 + + from app.services.detection_asset_service import invalidate_validations_for_asset + for asset in affected_assets: + if change.affected_log_sources: + asset_log_source = asset.log_source_name or "" + if not any(ls in asset_log_source for ls in change.affected_log_sources): + continue + count = invalidate_validations_for_asset(db, asset.id, change.reported_by, "infrastructure_change") + total_invalidated += count + + change.invalidated_count = total_invalidated + db.commit() + logger.info("Infrastructure change %s: invalidated %d validations", change_id, total_invalidated) + return total_invalidated diff --git a/backend/app/services/detection_asset_service.py b/backend/app/services/detection_asset_service.py new file mode 100644 index 0000000..6d65c48 --- /dev/null +++ b/backend/app/services/detection_asset_service.py @@ -0,0 +1,211 @@ +"""Detection Asset CRUD service with auto-hash and change detection.""" + +import hashlib +import logging +from datetime import datetime, timezone +from typing import Optional +from uuid import UUID + +from sqlalchemy.orm import Session, joinedload + +from app.models.detection_lifecycle import ( + DetectionAsset, DetectionTechniqueMapping, + DetectionValidation, DetectionHealthStatus, InvalidationReason +) +from app.models.technique import Technique +from app.domain.exceptions import EntityNotFoundError +from app.services import audit_service + +logger = logging.getLogger(__name__) + + +def _compute_rule_hash(content: str) -> str: + normalized = content.strip().replace('\r\n', '\n') + return hashlib.sha256(normalized.encode()).hexdigest() + + +def _now() -> datetime: + return datetime.now(timezone.utc) + + +def create_detection_asset(db: Session, data: dict, user_id: UUID) -> DetectionAsset: + technique_ids = data.pop("technique_ids", []) or [] + # Remove None values so defaults apply + data = {k: v for k, v in data.items() if v is not None or k in ("log_source_config", "infrastructure_details", "tags")} + + asset = DetectionAsset(**data, created_by=user_id) + + if asset.rule_content: + asset.rule_hash = _compute_rule_hash(asset.rule_content) + asset.last_rule_change_at = _now() + + if asset.infrastructure_details: + infra_str = str(sorted(asset.infrastructure_details.items())) + asset.infrastructure_hash = hashlib.sha256(infra_str.encode()).hexdigest() + + db.add(asset) + db.flush() + + for tech_id in technique_ids: + technique = db.query(Technique).filter(Technique.id == tech_id).first() + if technique: + mapping = DetectionTechniqueMapping( + detection_asset_id=asset.id, + technique_id=tech_id, + ) + db.add(mapping) + + db.commit() + db.refresh(asset) + + audit_service.log_action( + db, user_id, "DETECTION_ASSET_CREATED", "detection_asset", str(asset.id), + details={"name": asset.name, "type": asset.asset_type, "platform": asset.platform, "technique_count": len(technique_ids)}, + ) + return asset + + +def update_detection_asset(db: Session, asset_id: UUID, data: dict, user_id: UUID) -> DetectionAsset: + asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first() + if not asset: + raise EntityNotFoundError("DetectionAsset", str(asset_id)) + + changes = {} + rule_changed = False + + for key, value in data.items(): + if value is not None and hasattr(asset, key): + old_value = getattr(asset, key) + if old_value != value: + changes[key] = {"old": str(old_value), "new": str(value)} + setattr(asset, key, value) + + if "rule_content" in data and data["rule_content"]: + new_hash = _compute_rule_hash(data["rule_content"]) + if new_hash != asset.rule_hash: + rule_changed = True + asset.rule_hash = new_hash + asset.last_rule_change_at = _now() + + if "infrastructure_details" in data and data["infrastructure_details"]: + infra_str = str(sorted(data["infrastructure_details"].items())) + new_hash = hashlib.sha256(infra_str.encode()).hexdigest() + if new_hash != asset.infrastructure_hash: + asset.infrastructure_hash = new_hash + changes["infrastructure_hash_changed"] = True + + asset.updated_at = _now() + db.commit() + db.refresh(asset) + + if changes: + audit_service.log_action( + db, user_id, "DETECTION_ASSET_UPDATED", "detection_asset", str(asset.id), + details={"changes": changes, "rule_changed": rule_changed}, + ) + + if rule_changed: + invalidate_validations_for_asset(db, asset.id, user_id, "rule_modified") + + return asset + + +def invalidate_validations_for_asset(db: Session, asset_id: UUID, user_id: UUID, reason: str) -> int: + try: + reason_enum = InvalidationReason(reason) + except ValueError: + reason_enum = InvalidationReason.manual + + validations = db.query(DetectionValidation).filter( + DetectionValidation.detection_asset_id == asset_id, + DetectionValidation.is_valid == True, + ).all() + + count = 0 + for v in validations: + v.is_valid = False + v.invalidated_at = _now() + v.invalidation_reason = reason_enum + v.invalidated_by = user_id + count += 1 + + if count > 0: + db.commit() + logger.info("Invalidated %d validations for asset %s due to %s", count, asset_id, reason) + + return count + + +def get_asset_with_details(db: Session, asset_id: UUID) -> DetectionAsset: + asset = ( + db.query(DetectionAsset) + .options(joinedload(DetectionAsset.technique_mappings), joinedload(DetectionAsset.validations)) + .filter(DetectionAsset.id == asset_id) + .first() + ) + if not asset: + raise EntityNotFoundError("DetectionAsset", str(asset_id)) + return asset + + +def list_assets( + db: Session, + platform: Optional[str] = None, + asset_type: Optional[str] = None, + health_status: Optional[str] = None, + technique_id: Optional[UUID] = None, + is_active: Optional[bool] = True, +) -> list: + query = db.query(DetectionAsset) + if platform: + query = query.filter(DetectionAsset.platform == platform) + if asset_type: + query = query.filter(DetectionAsset.asset_type == asset_type) + if health_status: + query = query.filter(DetectionAsset.health_status == health_status) + if is_active is not None: + query = query.filter(DetectionAsset.is_active == is_active) + if technique_id: + query = query.join(DetectionTechniqueMapping).filter( + DetectionTechniqueMapping.technique_id == technique_id + ) + return query.order_by(DetectionAsset.name).all() + + +def get_technique_detection_summary(db: Session, technique_id: UUID) -> dict: + mappings = ( + db.query(DetectionTechniqueMapping) + .options(joinedload(DetectionTechniqueMapping.detection_asset)) + .filter(DetectionTechniqueMapping.technique_id == technique_id) + .all() + ) + + assets = [m.detection_asset for m in mappings if m.detection_asset] + active_assets = [a for a in assets if a.is_active] + now = _now() + + valid_count = 0 + for asset in active_assets: + has_valid = db.query(DetectionValidation).filter( + DetectionValidation.detection_asset_id == asset.id, + DetectionValidation.is_valid == True, + DetectionValidation.expires_at > now, + ).first() + if has_valid: + valid_count += 1 + + health_distribution = {} + for asset in active_assets: + status = asset.health_status.value if asset.health_status else "unknown" + health_distribution[status] = health_distribution.get(status, 0) + 1 + + platforms = list(set(a.platform for a in active_assets if a.platform)) + + return { + "technique_id": str(technique_id), + "total_assets": len(active_assets), + "validated_assets": valid_count, + "health_distribution": health_distribution, + "platforms": platforms, + "coverage_types": list(set(m.coverage_type for m in mappings if m.coverage_type)), + }