From d77075272e7bb0d003dcf91f4693f2a60f9bece4 Mon Sep 17 00:00:00 2001 From: Kitos Date: Fri, 20 Feb 2026 13:31:18 +0100 Subject: [PATCH] feat: add ImportService protocol and registry for OCP-compliant import extensibility (LP-7) --- backend/app/domain/ports/import_service.py | 88 +++++++++++++++++++ backend/app/services/data_source_service.py | 31 +------ backend/tests/test_import_service_protocol.py | 49 +++++++++++ 3 files changed, 140 insertions(+), 28 deletions(-) create mode 100644 backend/app/domain/ports/import_service.py create mode 100644 backend/tests/test_import_service_protocol.py diff --git a/backend/app/domain/ports/import_service.py b/backend/app/domain/ports/import_service.py new file mode 100644 index 0000000..be7b56d --- /dev/null +++ b/backend/app/domain/ports/import_service.py @@ -0,0 +1,88 @@ +"""Port defining the common interface for data import services. + +All import services (Atomic Red Team, Sigma, CALDERA, etc.) follow the +same contract: they receive a database session and return a summary dict +with import statistics. + +New import sources can be added by: +1. Implementing the ``ImportService`` protocol in a new module +2. Registering the handler in the ``IMPORT_REGISTRY`` + +This satisfies the Open/Closed Principle — the system is open for new +import sources without modifying existing code. +""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +from sqlalchemy.orm import Session + + +@runtime_checkable +class ImportService(Protocol): + """Contract for any data-import operation. + + Each implementation is a callable ``(Session) -> dict`` that + downloads, parses, and upserts records from an external source. + """ + + def __call__(self, db: Session) -> dict[str, Any]: ... + + +class ImportServiceEntry: + """Lazy-loading wrapper that resolves a module-level function on first call.""" + + __slots__ = ("_module_path", "_func_name", "_resolved") + + def __init__(self, module_path: str, func_name: str) -> None: + self._module_path = module_path + self._func_name = func_name + self._resolved: ImportService | None = None + + def __call__(self, db: Session) -> dict[str, Any]: + if self._resolved is None: + import importlib + mod = importlib.import_module(self._module_path) + self._resolved = getattr(mod, self._func_name) + return self._resolved(db) + + @property + def source_info(self) -> str: + return f"{self._module_path}.{self._func_name}" + + +IMPORT_REGISTRY: dict[str, ImportServiceEntry] = { + "atomic_red_team": ImportServiceEntry( + "app.services.atomic_import_service", "import_atomic_red_team", + ), + "sigma": ImportServiceEntry( + "app.services.sigma_import_service", "sync", + ), + "lolbas": ImportServiceEntry( + "app.services.lolbas_import_service", "sync", + ), + "gtfobins": ImportServiceEntry( + "app.services.lolbas_import_service", "sync_gtfobins", + ), + "caldera": ImportServiceEntry( + "app.services.caldera_import_service", "sync", + ), + "elastic_rules": ImportServiceEntry( + "app.services.elastic_import_service", "sync", + ), + "mitre_cti": ImportServiceEntry( + "app.services.threat_actor_import_service", "sync", + ), + "d3fend": ImportServiceEntry( + "app.services.d3fend_import_service", "sync", + ), +} + + +def get_import_handler(source_name: str) -> ImportServiceEntry | None: + """Look up the import handler for *source_name*. + + Returns ``None`` when no handler is registered. + """ + return IMPORT_REGISTRY.get(source_name) diff --git a/backend/app/services/data_source_service.py b/backend/app/services/data_source_service.py index 6de64ef..25ad0ca 100644 --- a/backend/app/services/data_source_service.py +++ b/backend/app/services/data_source_service.py @@ -6,43 +6,18 @@ since they are long-running and self-contained. from __future__ import annotations -import importlib import logging from datetime import datetime from sqlalchemy.orm import Session from app.domain.errors import BusinessRuleViolation, EntityNotFoundError +from app.domain.ports.import_service import get_import_handler from app.models.data_source import DataSource logger = logging.getLogger(__name__) -def _get_sync_handler(source_name: str): - """Lazily import and return the sync function for *source_name*. - - We import lazily to avoid circular imports and to only load the - modules that are actually needed. - """ - handlers = { - "atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"), - "sigma": ("app.services.sigma_import_service", "sync"), - "lolbas": ("app.services.lolbas_import_service", "sync"), - "gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"), - "caldera": ("app.services.caldera_import_service", "sync"), - "elastic_rules": ("app.services.elastic_import_service", "sync"), - "mitre_cti": ("app.services.threat_actor_import_service", "sync"), - "d3fend": ("app.services.d3fend_import_service", "sync"), - } - - if source_name not in handlers: - return None - - module_path, func_name = handlers[source_name] - mod = importlib.import_module(module_path) - return getattr(mod, func_name) - - def list_sources(db: Session) -> list[dict]: """Return all registered data sources as a list of dicts.""" sources = db.query(DataSource).order_by(DataSource.name).all() @@ -96,7 +71,7 @@ def sync_source(db: Session, source_id: str) -> dict: if not ds: raise EntityNotFoundError("Data source", source_id) - handler = _get_sync_handler(ds.name) + handler = get_import_handler(ds.name) if handler is None: raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'") @@ -142,7 +117,7 @@ def sync_all_sources(db: Session) -> list[dict]: results = [] for ds in enabled_sources: - handler = _get_sync_handler(ds.name) + handler = get_import_handler(ds.name) if handler is None: results.append({ "source": ds.name, diff --git a/backend/tests/test_import_service_protocol.py b/backend/tests/test_import_service_protocol.py new file mode 100644 index 0000000..8ff2043 --- /dev/null +++ b/backend/tests/test_import_service_protocol.py @@ -0,0 +1,49 @@ +"""Tests for the ImportService protocol and IMPORT_REGISTRY.""" + +from app.domain.ports.import_service import ( + IMPORT_REGISTRY, + ImportService, + ImportServiceEntry, + get_import_handler, +) + + +def test_registry_has_all_known_sources(): + expected = { + "atomic_red_team", + "sigma", + "lolbas", + "gtfobins", + "caldera", + "elastic_rules", + "mitre_cti", + "d3fend", + } + assert set(IMPORT_REGISTRY.keys()) == expected + + +def test_all_entries_are_import_service_entries(): + for name, entry in IMPORT_REGISTRY.items(): + assert isinstance(entry, ImportServiceEntry), f"{name} is not ImportServiceEntry" + + +def test_get_import_handler_returns_entry_for_known_source(): + handler = get_import_handler("sigma") + assert handler is not None + assert isinstance(handler, ImportServiceEntry) + + +def test_get_import_handler_returns_none_for_unknown(): + assert get_import_handler("nonexistent_source") is None + + +def test_import_service_entry_source_info(): + entry = ImportServiceEntry("app.services.sigma_import_service", "sync") + assert entry.source_info == "app.services.sigma_import_service.sync" + + +def test_callable_satisfies_protocol(): + def mock_handler(db): + return {"created": 0} + + assert isinstance(mock_handler, ImportService)