"""Data source management service — framework-agnostic query and sync logic.

Provides list, update, sync, and stats. Sync operations commit internally
since they are long-running and self-contained.
"""

from __future__ import annotations

import logging
from datetime import datetime, timezone

from sqlalchemy.orm import Session

from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
from app.domain.ports.import_service import get_import_handler
from app.models.data_source import DataSource

logger = logging.getLogger(__name__)

# Fields that ``update_source`` is allowed to change; anything else passed in
# ``**fields`` is ignored.
_UPDATABLE_FIELDS = ("is_enabled", "sync_frequency", "config")


def _utc_now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Equivalent to the deprecated ``datetime.utcnow()``: the tzinfo is stripped
    so the values written to ``last_sync_at`` keep the same (naive UTC) form.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _iso_or_none(value: datetime | None) -> str | None:
    """Return ``value.isoformat()``, or ``None`` when ``value`` is falsy."""
    return value.isoformat() if value else None


def _get_source_or_raise(db: Session, source_id: str) -> DataSource:
    """Fetch a data source by id.

    Raises:
        EntityNotFoundError: if no row matches ``source_id``.
    """
    ds = db.query(DataSource).filter(DataSource.id == source_id).first()
    if not ds:
        raise EntityNotFoundError("Data source", source_id)
    return ds


def _source_to_dict(source: DataSource) -> dict:
    """Serialize one data source into the dict shape used by ``list_sources``."""
    return {
        "id": str(source.id),
        "name": source.name,
        "display_name": source.display_name,
        "type": source.type,
        "url": source.url,
        "description": source.description,
        "is_enabled": source.is_enabled,
        "last_sync_at": _iso_or_none(source.last_sync_at),
        "last_sync_status": source.last_sync_status,
        "last_sync_stats": source.last_sync_stats,
        "sync_frequency": source.sync_frequency,
        "config": source.config,
        "created_at": _iso_or_none(source.created_at),
    }


def _record_success(db: Session, ds: DataSource, summary: object) -> None:
    """Store a successful sync outcome on ``ds`` and commit."""
    ds.last_sync_at = _utc_now()
    ds.last_sync_status = "success"
    ds.last_sync_stats = summary
    db.commit()


def _record_failure(db: Session, ds: DataSource, exc: Exception) -> None:
    """Store a failed sync outcome on ``ds`` and commit.

    The session is rolled back first: if the handler failed during a flush,
    the session rejects further commits until it has been rolled back, which
    would leave the source stuck in "in_progress". The rollback also discards
    any work the handler had not committed itself.
    """
    db.rollback()
    ds.last_sync_status = "error"
    ds.last_sync_at = _utc_now()
    ds.last_sync_stats = {"error": str(exc)}
    db.commit()


def list_sources(db: Session) -> list[dict]:
    """Return all registered data sources as a list of dicts, ordered by name."""
    sources = db.query(DataSource).order_by(DataSource.name).all()
    return [_source_to_dict(source) for source in sources]


def update_source(db: Session, source_id: str, **fields: object) -> None:
    """Update a data source's fields (is_enabled, sync_frequency, config).

    Only the keys listed above are applied; other keyword arguments are
    ignored.

    Raises EntityNotFoundError if source does not exist.
    Does not commit; the router handles that.
    """
    ds = _get_source_or_raise(db, source_id)
    for field_name in _UPDATABLE_FIELDS:
        if field_name in fields:
            setattr(ds, field_name, fields[field_name])


def sync_source(db: Session, source_id: str) -> dict:
    """Trigger sync for a specific data source.

    Raises EntityNotFoundError if source does not exist.
    Raises BusinessRuleViolation if no sync handler is available, or if the
    handler raises (the original exception is chained as the cause).
    Commits internally (long-running, self-contained operation).
    Returns dict with message, source, stats.
    """
    ds = _get_source_or_raise(db, source_id)

    # Capture the names now: the commit below expires the instance, and the
    # error path should not have to reload attributes from a failed session.
    source_name = ds.name
    display_name = ds.display_name

    handler = get_import_handler(source_name)
    if handler is None:
        raise BusinessRuleViolation(
            f"No sync handler available for '{source_name}'"
        )

    ds.last_sync_status = "in_progress"
    db.commit()

    try:
        summary = handler(db)
    except Exception as exc:
        logger.error("Sync failed for %s: %s", source_name, exc, exc_info=True)
        _record_failure(db, ds, exc)
        raise BusinessRuleViolation(
            f"Sync failed for '{display_name}'. Check server logs for details."
        ) from exc

    _record_success(db, ds, summary)

    return {
        "message": f"Sync complete for {ds.display_name}",
        "source": ds.name,
        "stats": summary,
    }


def sync_all_sources(db: Session) -> list[dict]:
    """Trigger sync for all enabled data sources (sequentially).

    A failure in one source is recorded and reported in its result entry; the
    remaining sources are still processed.

    Commits internally (long-running, self-contained operation).
    Returns list of result dicts with source, status, stats/detail.
    """
    enabled_sources = (
        db.query(DataSource)
        # SQL expression, not a Python truth test, hence the explicit `== True`.
        .filter(DataSource.is_enabled == True)  # noqa: E712
        .order_by(DataSource.name)
        .all()
    )

    results: list[dict] = []
    for ds in enabled_sources:
        # Read once up front so the error path never touches an expired
        # attribute on a session that may need a rollback.
        source_name = ds.name

        handler = get_import_handler(source_name)
        if handler is None:
            results.append({
                "source": source_name,
                "status": "skipped",
                "detail": "No sync handler available",
            })
            continue

        ds.last_sync_status = "in_progress"
        db.commit()

        try:
            summary = handler(db)
            _record_success(db, ds, summary)
            results.append({
                "source": ds.name,
                "status": "success",
                "stats": summary,
            })
        except Exception as exc:
            logger.error(
                "Sync failed for %s: %s", source_name, exc, exc_info=True
            )
            _record_failure(db, ds, exc)
            results.append({
                "source": source_name,
                "status": "error",
                "detail": "Sync failed. Check server logs for details.",
            })

    return results


def get_source_stats(db: Session, source_id: str) -> dict:
    """Return detailed statistics for a data source.

    ``total_templates`` is counted only for sources of type
    "attack_procedure" and ``total_rules`` only for type "detection_rule";
    the other counter (or both, for any other type) is reported as 0.

    Raises EntityNotFoundError if source does not exist.
    """
    ds = _get_source_or_raise(db, source_id)

    # Imported at call time, as in the original module layout.
    from app.models.detection_rule import DetectionRule
    from app.models.test_template import TestTemplate

    template_count = 0
    rule_count = 0

    if ds.type == "attack_procedure":
        template_count = (
            db.query(TestTemplate)
            .filter(TestTemplate.source == ds.name)
            .count()
        )
    elif ds.type == "detection_rule":
        rule_count = (
            db.query(DetectionRule)
            .filter(DetectionRule.source == ds.name)
            .count()
        )

    return {
        "id": str(ds.id),
        "name": ds.name,
        "display_name": ds.display_name,
        "type": ds.type,
        "is_enabled": ds.is_enabled,
        "last_sync_at": _iso_or_none(ds.last_sync_at),
        "last_sync_status": ds.last_sync_status,
        "last_sync_stats": ds.last_sync_stats,
        "total_templates": template_count,
        "total_rules": rule_count,
    }