"""Detection Asset CRUD service with auto-hash and change detection."""

import hashlib
import logging
from collections import Counter
from datetime import datetime, timezone
from typing import Optional
from uuid import UUID

from sqlalchemy.orm import Session, joinedload, selectinload

from app.models.detection_lifecycle import (
    DetectionAsset,
    DetectionTechniqueMapping,
    DetectionValidation,
    DetectionHealthStatus,
    InvalidationReason,
)
from app.models.technique import Technique
from app.domain.exceptions import EntityNotFoundError
from app.services import audit_service

logger = logging.getLogger(__name__)

# Attributes update_detection_asset never copies from caller input: identity and
# provenance columns, values this service derives itself (hashes, timestamps),
# and ORM relationship collections. Letting a caller set e.g. ``rule_hash``
# would also defeat the rule-change detection below.
_PROTECTED_UPDATE_FIELDS = frozenset({
    "id",
    "created_by",
    "created_at",
    "updated_at",
    "rule_hash",
    "infrastructure_hash",
    "last_rule_change_at",
    "technique_mappings",
    "validations",
})


def _compute_rule_hash(content: str) -> str:
    """Return the SHA-256 hex digest of *content* after light normalization.

    Leading/trailing whitespace is stripped and CRLF line endings are folded to
    LF, so those differences alone do not register as a rule change.
    """
    normalized = content.strip().replace('\r\n', '\n')
    return hashlib.sha256(normalized.encode()).hexdigest()


def _compute_infrastructure_hash(details: dict) -> str:
    """Return the SHA-256 hex digest of an infrastructure-details mapping.

    Top-level keys are sorted so their insertion order does not affect the
    hash. The serialization (``str`` of the sorted item list) is intentionally
    unchanged so hashes already stored in the database remain comparable.
    """
    # NOTE(review): nested dicts are stringified in their own insertion order,
    # so a reordered nested mapping yields a different hash — confirm acceptable.
    infra_str = str(sorted(details.items()))
    return hashlib.sha256(infra_str.encode()).hexdigest()


def _now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Same value as the deprecated ``datetime.utcnow()``. It is kept naive so it
    compares cleanly with timestamps this service has already written.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _existing_technique_ids(db: Session, technique_ids: list) -> list:
    """Return the IDs of the techniques in *technique_ids* that exist, using one query.

    Each existing technique appears once, even if requested several times.
    The returned IDs are the database's own values.
    """
    if not technique_ids:
        return []
    rows = db.query(Technique.id).filter(Technique.id.in_(technique_ids)).all()
    return [row[0] for row in rows]


def create_detection_asset(db: Session, data: dict, user_id: UUID) -> DetectionAsset:
    """Create a detection asset and map it to the requested techniques.

    Args:
        db: Active session; committed on success.
        data: Column values for the new asset. An optional ``technique_ids``
            entry lists techniques to map; IDs with no matching Technique row
            are skipped. The caller's dict is not modified.
        user_id: Recorded as ``created_by`` and as the audit actor.

    Returns:
        The persisted, refreshed ``DetectionAsset``.
    """
    data = dict(data)  # work on a copy so the caller's dict is not mutated
    technique_ids = data.pop("technique_ids", []) or []
    # Remove None values so defaults apply; these keys keep an explicit None.
    nullable_keys = ("log_source_config", "infrastructure_details", "tags")
    data = {k: v for k, v in data.items() if v is not None or k in nullable_keys}

    asset = DetectionAsset(**data, created_by=user_id)
    if asset.rule_content:
        asset.rule_hash = _compute_rule_hash(asset.rule_content)
        asset.last_rule_change_at = _now()
    if asset.infrastructure_details:
        asset.infrastructure_hash = _compute_infrastructure_hash(asset.infrastructure_details)

    db.add(asset)
    db.flush()  # assigns asset.id, needed by the mapping rows below

    mapped_ids = _existing_technique_ids(db, technique_ids)
    if len(mapped_ids) < len(technique_ids):
        logger.warning(
            "Only %d of %d requested technique IDs matched existing techniques",
            len(mapped_ids), len(technique_ids),
        )
    for tech_id in mapped_ids:
        db.add(DetectionTechniqueMapping(detection_asset_id=asset.id, technique_id=tech_id))

    db.commit()
    db.refresh(asset)

    audit_service.log_action(
        db, user_id, "DETECTION_ASSET_CREATED", "detection_asset", str(asset.id),
        details={
            "name": asset.name,
            "type": asset.asset_type,
            "platform": asset.platform,
            # Number of mappings actually created, not IDs requested.
            "technique_count": len(mapped_ids),
        },
    )
    return asset


def update_detection_asset(db: Session, asset_id: UUID, data: dict, user_id: UUID) -> DetectionAsset:
    """Apply a partial update to a detection asset, tracking rule/infra changes.

    Keys whose value is None are ignored, meaning "leave unchanged". Keys that
    are not asset attributes, or are service-managed (see
    ``_PROTECTED_UPDATE_FIELDS``), are also ignored. When the normalized rule
    content changes, every currently-valid validation of the asset is
    invalidated in the same transaction as the edit.

    Args:
        db: Active session; committed on success.
        asset_id: ID of the asset to update.
        data: New attribute values.
        user_id: Audit actor and ``invalidated_by`` for invalidated validations.

    Returns:
        The refreshed ``DetectionAsset``.

    Raises:
        EntityNotFoundError: If no asset with *asset_id* exists.
    """
    asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first()
    if not asset:
        raise EntityNotFoundError("DetectionAsset", str(asset_id))

    changes = {}
    for key, value in data.items():
        if value is None or key in _PROTECTED_UPDATE_FIELDS or not hasattr(asset, key):
            continue
        old_value = getattr(asset, key)
        if old_value != value:
            changes[key] = {"old": str(old_value), "new": str(value)}
            setattr(asset, key, value)

    rule_changed = False
    if data.get("rule_content") is not None:
        content = data["rule_content"]
        # Empty content has no hash, matching create_detection_asset; clearing
        # the rule must therefore also clear the old hash.
        new_rule_hash = _compute_rule_hash(content) if content else None
        if new_rule_hash != asset.rule_hash:
            rule_changed = True
            asset.rule_hash = new_rule_hash
            asset.last_rule_change_at = _now()

    if data.get("infrastructure_details") is not None:
        details = data["infrastructure_details"]
        new_infra_hash = _compute_infrastructure_hash(details) if details else None
        if new_infra_hash != asset.infrastructure_hash:
            asset.infrastructure_hash = new_infra_hash
            changes["infrastructure_hash_changed"] = True

    invalidated = 0
    if rule_changed:
        # Same transaction as the rule edit: a committed rule change can never
        # leave validations of the previous rule marked valid.
        invalidated = _mark_validations_invalid(db, asset.id, user_id, "rule_modified")

    asset.updated_at = _now()
    db.commit()
    db.refresh(asset)

    if invalidated:
        logger.info("Invalidated %d validations for asset %s due to %s",
                    invalidated, asset.id, "rule_modified")

    if changes or rule_changed:
        audit_service.log_action(
            db, user_id, "DETECTION_ASSET_UPDATED", "detection_asset", str(asset.id),
            details={"changes": changes, "rule_changed": rule_changed},
        )
    return asset


def _resolve_invalidation_reason(reason: str) -> InvalidationReason:
    """Map *reason* to ``InvalidationReason``, falling back to ``manual`` with a warning."""
    try:
        return InvalidationReason(reason)
    except ValueError:
        logger.warning("Unknown invalidation reason %r; recording as %s",
                       reason, InvalidationReason.manual)
        return InvalidationReason.manual


def _mark_validations_invalid(db: Session, asset_id: UUID, user_id: UUID, reason: str) -> int:
    """Flag every currently-valid validation of *asset_id* invalid, without committing.

    Returns:
        The number of validations flagged.
    """
    reason_enum = _resolve_invalidation_reason(reason)
    validations = db.query(DetectionValidation).filter(
        DetectionValidation.detection_asset_id == asset_id,
        DetectionValidation.is_valid.is_(True),
    ).all()
    now = _now()  # one timestamp for the whole batch
    for validation in validations:
        validation.is_valid = False
        validation.invalidated_at = now
        validation.invalidation_reason = reason_enum
        validation.invalidated_by = user_id
    return len(validations)


def invalidate_validations_for_asset(db: Session, asset_id: UUID, user_id: UUID, reason: str) -> int:
    """Invalidate all currently-valid validations of an asset and commit.

    Args:
        db: Active session; committed only if at least one row changed.
        asset_id: Asset whose validations are invalidated.
        user_id: Recorded as ``invalidated_by``.
        reason: An ``InvalidationReason`` value. Unknown values are recorded
            as ``manual`` and a warning is logged.

    Returns:
        The number of validations invalidated.
    """
    count = _mark_validations_invalid(db, asset_id, user_id, reason)
    if count > 0:
        db.commit()
        logger.info("Invalidated %d validations for asset %s due to %s", count, asset_id, reason)
    return count


def get_asset_with_details(db: Session, asset_id: UUID) -> DetectionAsset:
    """Load an asset with its technique mappings and validations eagerly loaded.

    Raises:
        EntityNotFoundError: If no asset with *asset_id* exists.
    """
    # selectinload issues one extra query per collection instead of joining
    # both collections into a single (mappings x validations) row product.
    asset = (
        db.query(DetectionAsset)
        .options(
            selectinload(DetectionAsset.technique_mappings),
            selectinload(DetectionAsset.validations),
        )
        .filter(DetectionAsset.id == asset_id)
        .first()
    )
    if not asset:
        raise EntityNotFoundError("DetectionAsset", str(asset_id))
    return asset


def list_assets(
    db: Session,
    platform: Optional[str] = None,
    asset_type: Optional[str] = None,
    health_status: Optional[str] = None,
    technique_id: Optional[UUID] = None,
    is_active: Optional[bool] = True,
) -> list:
    """Return assets matching all given filters, ordered by name.

    Falsy filter values are ignored. ``is_active=None`` disables the
    active-state filter; the default returns only active assets.
    """
    query = db.query(DetectionAsset)
    if platform:
        query = query.filter(DetectionAsset.platform == platform)
    if asset_type:
        query = query.filter(DetectionAsset.asset_type == asset_type)
    if health_status:
        query = query.filter(DetectionAsset.health_status == health_status)
    if is_active is not None:
        query = query.filter(DetectionAsset.is_active == is_active)
    if technique_id:
        # EXISTS rather than JOIN so an asset mapped to the technique more
        # than once is still returned only once.
        query = query.filter(
            DetectionAsset.technique_mappings.any(
                DetectionTechniqueMapping.technique_id == technique_id
            )
        )
    return query.order_by(DetectionAsset.name).all()


def get_technique_detection_summary(db: Session, technique_id: UUID) -> dict:
    """Summarize detection coverage for one technique.

    Only active assets are counted. An asset counts as validated when it has
    at least one validation with ``is_valid`` true and ``expires_at`` in the
    future. ``coverage_types`` is drawn from all of the technique's mappings.
    """
    mappings = (
        db.query(DetectionTechniqueMapping)
        .options(joinedload(DetectionTechniqueMapping.detection_asset))
        .filter(DetectionTechniqueMapping.technique_id == technique_id)
        .all()
    )
    # Several mappings may reference the same asset; count each asset once.
    assets_by_id = {m.detection_asset.id: m.detection_asset for m in mappings if m.detection_asset}
    active_assets = [a for a in assets_by_id.values() if a.is_active]

    validated_ids = set()
    if active_assets:
        # One query for all assets instead of one query per asset.
        # NOTE(review): rows with a NULL expires_at are not counted — confirm
        # that NULL does not mean "never expires".
        rows = (
            db.query(DetectionValidation.detection_asset_id)
            .filter(
                DetectionValidation.detection_asset_id.in_([a.id for a in active_assets]),
                DetectionValidation.is_valid.is_(True),
                DetectionValidation.expires_at > _now(),
            )
            .distinct()
            .all()
        )
        validated_ids = {row[0] for row in rows}

    health_distribution = dict(Counter(
        a.health_status.value if a.health_status else "unknown" for a in active_assets
    ))
    # Sorted (by string form, safe for enums) so the output order is stable.
    platforms = sorted({a.platform for a in active_assets if a.platform}, key=str)
    coverage_types = sorted({m.coverage_type for m in mappings if m.coverage_type}, key=str)

    return {
        "technique_id": str(technique_id),
        "total_assets": len(active_assets),
        "validated_assets": len(validated_ids),
        "health_distribution": health_distribution,
        "platforms": platforms,
        "coverage_types": coverage_types,
    }