feat: add ImportService protocol and registry for OCP-compliant import extensibility (LP-7)

This commit is contained in:
2026-02-20 13:31:18 +01:00
parent c0c6cda11d
commit d77075272e
3 changed files with 140 additions and 28 deletions

View File

@@ -6,43 +6,18 @@ since they are long-running and self-contained.
from __future__ import annotations
import importlib
import logging
from datetime import datetime
from sqlalchemy.orm import Session
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
from app.domain.ports.import_service import get_import_handler
from app.models.data_source import DataSource
logger = logging.getLogger(__name__)
def _get_sync_handler(source_name: str):
"""Lazily import and return the sync function for *source_name*.
We import lazily to avoid circular imports and to only load the
modules that are actually needed.
"""
handlers = {
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
"sigma": ("app.services.sigma_import_service", "sync"),
"lolbas": ("app.services.lolbas_import_service", "sync"),
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
"caldera": ("app.services.caldera_import_service", "sync"),
"elastic_rules": ("app.services.elastic_import_service", "sync"),
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
"d3fend": ("app.services.d3fend_import_service", "sync"),
}
if source_name not in handlers:
return None
module_path, func_name = handlers[source_name]
mod = importlib.import_module(module_path)
return getattr(mod, func_name)
def list_sources(db: Session) -> list[dict]:
"""Return all registered data sources as a list of dicts."""
sources = db.query(DataSource).order_by(DataSource.name).all()
@@ -96,7 +71,7 @@ def sync_source(db: Session, source_id: str) -> dict:
if not ds:
raise EntityNotFoundError("Data source", source_id)
handler = _get_sync_handler(ds.name)
handler = get_import_handler(ds.name)
if handler is None:
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
@@ -142,7 +117,7 @@ def sync_all_sources(db: Session) -> list[dict]:
results = []
for ds in enabled_sources:
handler = _get_sync_handler(ds.name)
handler = get_import_handler(ds.name)
if handler is None:
results.append({
"source": ds.name,