"""Data source management service — framework-agnostic query and sync logic. Provides list, update, sync, and stats. Sync operations commit internally since they are long-running and self-contained. """ from __future__ import annotations import importlib import logging from datetime import datetime from sqlalchemy.orm import Session from app.domain.errors import BusinessRuleViolation, EntityNotFoundError from app.models.data_source import DataSource logger = logging.getLogger(__name__) def _get_sync_handler(source_name: str): """Lazily import and return the sync function for *source_name*. We import lazily to avoid circular imports and to only load the modules that are actually needed. """ handlers = { "atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"), "sigma": ("app.services.sigma_import_service", "sync"), "lolbas": ("app.services.lolbas_import_service", "sync"), "gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"), "caldera": ("app.services.caldera_import_service", "sync"), "elastic_rules": ("app.services.elastic_import_service", "sync"), "mitre_cti": ("app.services.threat_actor_import_service", "sync"), "d3fend": ("app.services.d3fend_import_service", "sync"), } if source_name not in handlers: return None module_path, func_name = handlers[source_name] mod = importlib.import_module(module_path) return getattr(mod, func_name) def list_sources(db: Session) -> list[dict]: """Return all registered data sources as a list of dicts.""" sources = db.query(DataSource).order_by(DataSource.name).all() return [ { "id": str(s.id), "name": s.name, "display_name": s.display_name, "type": s.type, "url": s.url, "description": s.description, "is_enabled": s.is_enabled, "last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None, "last_sync_status": s.last_sync_status, "last_sync_stats": s.last_sync_stats, "sync_frequency": s.sync_frequency, "config": s.config, "created_at": s.created_at.isoformat() if s.created_at else None, } for s in sources ] def update_source(db: Session, source_id: str, **fields: object) -> None: """Update a data source's fields (is_enabled, sync_frequency, config). Raises EntityNotFoundError if source does not exist. Does not commit; the router handles that. """ ds = db.query(DataSource).filter(DataSource.id == source_id).first() if not ds: raise EntityNotFoundError("Data source", source_id) if "is_enabled" in fields: ds.is_enabled = fields["is_enabled"] if "sync_frequency" in fields: ds.sync_frequency = fields["sync_frequency"] if "config" in fields: ds.config = fields["config"] def sync_source(db: Session, source_id: str) -> dict: """Trigger sync for a specific data source. Raises EntityNotFoundError if source does not exist. Raises BusinessRuleViolation if no sync handler is available. Commits internally (long-running, self-contained operation). Returns dict with message, source, stats. """ ds = db.query(DataSource).filter(DataSource.id == source_id).first() if not ds: raise EntityNotFoundError("Data source", source_id) handler = _get_sync_handler(ds.name) if handler is None: raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'") ds.last_sync_status = "in_progress" db.commit() try: summary = handler(db) except Exception as exc: logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True) ds.last_sync_status = "error" ds.last_sync_at = datetime.utcnow() ds.last_sync_stats = {"error": str(exc)} db.commit() raise BusinessRuleViolation( f"Sync failed for '{ds.display_name}'. Check server logs for details." ) ds.last_sync_at = datetime.utcnow() ds.last_sync_status = "success" ds.last_sync_stats = summary db.commit() return { "message": f"Sync complete for {ds.display_name}", "source": ds.name, "stats": summary, } def sync_all_sources(db: Session) -> list[dict]: """Trigger sync for all enabled data sources (sequentially). Commits internally (long-running, self-contained operation). Returns list of result dicts with source, status, stats/detail. """ enabled_sources = ( db.query(DataSource) .filter(DataSource.is_enabled == True) .order_by(DataSource.name) .all() ) results = [] for ds in enabled_sources: handler = _get_sync_handler(ds.name) if handler is None: results.append({ "source": ds.name, "status": "skipped", "detail": "No sync handler available", }) continue ds.last_sync_status = "in_progress" db.commit() try: summary = handler(db) ds.last_sync_at = datetime.utcnow() ds.last_sync_status = "success" ds.last_sync_stats = summary db.commit() results.append({ "source": ds.name, "status": "success", "stats": summary, }) except Exception as exc: logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True) ds.last_sync_status = "error" ds.last_sync_at = datetime.utcnow() ds.last_sync_stats = {"error": str(exc)} db.commit() results.append({ "source": ds.name, "status": "error", "detail": "Sync failed. Check server logs for details.", }) return results def get_source_stats(db: Session, source_id: str) -> dict: """Return detailed statistics for a data source. Raises EntityNotFoundError if source does not exist. """ ds = db.query(DataSource).filter(DataSource.id == source_id).first() if not ds: raise EntityNotFoundError("Data source", source_id) from app.models.test_template import TestTemplate from app.models.detection_rule import DetectionRule template_count = 0 rule_count = 0 if ds.type == "attack_procedure": template_count = ( db.query(TestTemplate) .filter(TestTemplate.source == ds.name) .count() ) elif ds.type == "detection_rule": rule_count = ( db.query(DetectionRule) .filter(DetectionRule.source == ds.name) .count() ) return { "id": str(ds.id), "name": ds.name, "display_name": ds.display_name, "type": ds.type, "is_enabled": ds.is_enabled, "last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None, "last_sync_status": ds.last_sync_status, "last_sync_stats": ds.last_sync_stats, "total_templates": template_count, "total_rules": rule_count, }