feat: add ImportService protocol and registry for OCP-compliant import extensibility (LP-7)

This commit is contained in:
2026-02-20 13:31:18 +01:00
parent c0c6cda11d
commit d77075272e
3 changed files with 140 additions and 28 deletions

View File

@@ -0,0 +1,88 @@
"""Port defining the common interface for data import services.
All import services (Atomic Red Team, Sigma, CALDERA, etc.) follow the
same contract: they receive a database session and return a summary dict
with import statistics.
New import sources can be added by:
1. Implementing the ``ImportService`` protocol in a new module
2. Registering the handler in the ``IMPORT_REGISTRY``
This satisfies the Open/Closed Principle — the system is open for new
import sources without modifying existing code.
"""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
from sqlalchemy.orm import Session
@runtime_checkable
class ImportService(Protocol):
"""Contract for any data-import operation.
Each implementation is a callable ``(Session) -> dict`` that
downloads, parses, and upserts records from an external source.
"""
def __call__(self, db: Session) -> dict[str, Any]: ...
class ImportServiceEntry:
"""Lazy-loading wrapper that resolves a module-level function on first call."""
__slots__ = ("_module_path", "_func_name", "_resolved")
def __init__(self, module_path: str, func_name: str) -> None:
self._module_path = module_path
self._func_name = func_name
self._resolved: ImportService | None = None
def __call__(self, db: Session) -> dict[str, Any]:
if self._resolved is None:
import importlib
mod = importlib.import_module(self._module_path)
self._resolved = getattr(mod, self._func_name)
return self._resolved(db)
@property
def source_info(self) -> str:
return f"{self._module_path}.{self._func_name}"
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
"atomic_red_team": ImportServiceEntry(
"app.services.atomic_import_service", "import_atomic_red_team",
),
"sigma": ImportServiceEntry(
"app.services.sigma_import_service", "sync",
),
"lolbas": ImportServiceEntry(
"app.services.lolbas_import_service", "sync",
),
"gtfobins": ImportServiceEntry(
"app.services.lolbas_import_service", "sync_gtfobins",
),
"caldera": ImportServiceEntry(
"app.services.caldera_import_service", "sync",
),
"elastic_rules": ImportServiceEntry(
"app.services.elastic_import_service", "sync",
),
"mitre_cti": ImportServiceEntry(
"app.services.threat_actor_import_service", "sync",
),
"d3fend": ImportServiceEntry(
"app.services.d3fend_import_service", "sync",
),
}
def get_import_handler(source_name: str) -> ImportServiceEntry | None:
"""Look up the import handler for *source_name*.
Returns ``None`` when no handler is registered.
"""
return IMPORT_REGISTRY.get(source_name)

View File

@@ -6,43 +6,18 @@ since they are long-running and self-contained.
from __future__ import annotations from __future__ import annotations
import importlib
import logging import logging
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
from app.domain.ports.import_service import get_import_handler
from app.models.data_source import DataSource from app.models.data_source import DataSource
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _get_sync_handler(source_name: str):
"""Lazily import and return the sync function for *source_name*.
We import lazily to avoid circular imports and to only load the
modules that are actually needed.
"""
handlers = {
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
"sigma": ("app.services.sigma_import_service", "sync"),
"lolbas": ("app.services.lolbas_import_service", "sync"),
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
"caldera": ("app.services.caldera_import_service", "sync"),
"elastic_rules": ("app.services.elastic_import_service", "sync"),
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
"d3fend": ("app.services.d3fend_import_service", "sync"),
}
if source_name not in handlers:
return None
module_path, func_name = handlers[source_name]
mod = importlib.import_module(module_path)
return getattr(mod, func_name)
def list_sources(db: Session) -> list[dict]: def list_sources(db: Session) -> list[dict]:
"""Return all registered data sources as a list of dicts.""" """Return all registered data sources as a list of dicts."""
sources = db.query(DataSource).order_by(DataSource.name).all() sources = db.query(DataSource).order_by(DataSource.name).all()
@@ -96,7 +71,7 @@ def sync_source(db: Session, source_id: str) -> dict:
if not ds: if not ds:
raise EntityNotFoundError("Data source", source_id) raise EntityNotFoundError("Data source", source_id)
handler = _get_sync_handler(ds.name) handler = get_import_handler(ds.name)
if handler is None: if handler is None:
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'") raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
@@ -142,7 +117,7 @@ def sync_all_sources(db: Session) -> list[dict]:
results = [] results = []
for ds in enabled_sources: for ds in enabled_sources:
handler = _get_sync_handler(ds.name) handler = get_import_handler(ds.name)
if handler is None: if handler is None:
results.append({ results.append({
"source": ds.name, "source": ds.name,

View File

@@ -0,0 +1,49 @@
"""Tests for the ImportService protocol and IMPORT_REGISTRY."""
from app.domain.ports.import_service import (
IMPORT_REGISTRY,
ImportService,
ImportServiceEntry,
get_import_handler,
)
def test_registry_has_all_known_sources():
expected = {
"atomic_red_team",
"sigma",
"lolbas",
"gtfobins",
"caldera",
"elastic_rules",
"mitre_cti",
"d3fend",
}
assert set(IMPORT_REGISTRY.keys()) == expected
def test_all_entries_are_import_service_entries():
for name, entry in IMPORT_REGISTRY.items():
assert isinstance(entry, ImportServiceEntry), f"{name} is not ImportServiceEntry"
def test_get_import_handler_returns_entry_for_known_source():
handler = get_import_handler("sigma")
assert handler is not None
assert isinstance(handler, ImportServiceEntry)
def test_get_import_handler_returns_none_for_unknown():
assert get_import_handler("nonexistent_source") is None
def test_import_service_entry_source_info():
entry = ImportServiceEntry("app.services.sigma_import_service", "sync")
assert entry.source_info == "app.services.sigma_import_service.sync"
def test_callable_satisfies_protocol():
def mock_handler(db):
return {"created": 0}
assert isinstance(mock_handler, ImportService)