feat: add ImportService protocol and registry for OCP-compliant import extensibility (LP-7)
This commit is contained in:
88
backend/app/domain/ports/import_service.py
Normal file
88
backend/app/domain/ports/import_service.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Port defining the common interface for data import services.
|
||||
|
||||
All import services (Atomic Red Team, Sigma, CALDERA, etc.) follow the
|
||||
same contract: they receive a database session and return a summary dict
|
||||
with import statistics.
|
||||
|
||||
New import sources can be added by:
|
||||
1. Implementing the ``ImportService`` protocol in a new module
|
||||
2. Registering the handler in the ``IMPORT_REGISTRY``
|
||||
|
||||
This satisfies the Open/Closed Principle — the system is open for new
|
||||
import sources without modifying existing code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ImportService(Protocol):
|
||||
"""Contract for any data-import operation.
|
||||
|
||||
Each implementation is a callable ``(Session) -> dict`` that
|
||||
downloads, parses, and upserts records from an external source.
|
||||
"""
|
||||
|
||||
def __call__(self, db: Session) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
class ImportServiceEntry:
|
||||
"""Lazy-loading wrapper that resolves a module-level function on first call."""
|
||||
|
||||
__slots__ = ("_module_path", "_func_name", "_resolved")
|
||||
|
||||
def __init__(self, module_path: str, func_name: str) -> None:
|
||||
self._module_path = module_path
|
||||
self._func_name = func_name
|
||||
self._resolved: ImportService | None = None
|
||||
|
||||
def __call__(self, db: Session) -> dict[str, Any]:
|
||||
if self._resolved is None:
|
||||
import importlib
|
||||
mod = importlib.import_module(self._module_path)
|
||||
self._resolved = getattr(mod, self._func_name)
|
||||
return self._resolved(db)
|
||||
|
||||
@property
|
||||
def source_info(self) -> str:
|
||||
return f"{self._module_path}.{self._func_name}"
|
||||
|
||||
|
||||
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
|
||||
"atomic_red_team": ImportServiceEntry(
|
||||
"app.services.atomic_import_service", "import_atomic_red_team",
|
||||
),
|
||||
"sigma": ImportServiceEntry(
|
||||
"app.services.sigma_import_service", "sync",
|
||||
),
|
||||
"lolbas": ImportServiceEntry(
|
||||
"app.services.lolbas_import_service", "sync",
|
||||
),
|
||||
"gtfobins": ImportServiceEntry(
|
||||
"app.services.lolbas_import_service", "sync_gtfobins",
|
||||
),
|
||||
"caldera": ImportServiceEntry(
|
||||
"app.services.caldera_import_service", "sync",
|
||||
),
|
||||
"elastic_rules": ImportServiceEntry(
|
||||
"app.services.elastic_import_service", "sync",
|
||||
),
|
||||
"mitre_cti": ImportServiceEntry(
|
||||
"app.services.threat_actor_import_service", "sync",
|
||||
),
|
||||
"d3fend": ImportServiceEntry(
|
||||
"app.services.d3fend_import_service", "sync",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_import_handler(source_name: str) -> ImportServiceEntry | None:
|
||||
"""Look up the import handler for *source_name*.
|
||||
|
||||
Returns ``None`` when no handler is registered.
|
||||
"""
|
||||
return IMPORT_REGISTRY.get(source_name)
|
||||
@@ -6,43 +6,18 @@ since they are long-running and self-contained.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
|
||||
from app.domain.ports.import_service import get_import_handler
|
||||
from app.models.data_source import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_sync_handler(source_name: str):
|
||||
"""Lazily import and return the sync function for *source_name*.
|
||||
|
||||
We import lazily to avoid circular imports and to only load the
|
||||
modules that are actually needed.
|
||||
"""
|
||||
handlers = {
|
||||
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
|
||||
"sigma": ("app.services.sigma_import_service", "sync"),
|
||||
"lolbas": ("app.services.lolbas_import_service", "sync"),
|
||||
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
|
||||
"caldera": ("app.services.caldera_import_service", "sync"),
|
||||
"elastic_rules": ("app.services.elastic_import_service", "sync"),
|
||||
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
|
||||
"d3fend": ("app.services.d3fend_import_service", "sync"),
|
||||
}
|
||||
|
||||
if source_name not in handlers:
|
||||
return None
|
||||
|
||||
module_path, func_name = handlers[source_name]
|
||||
mod = importlib.import_module(module_path)
|
||||
return getattr(mod, func_name)
|
||||
|
||||
|
||||
def list_sources(db: Session) -> list[dict]:
|
||||
"""Return all registered data sources as a list of dicts."""
|
||||
sources = db.query(DataSource).order_by(DataSource.name).all()
|
||||
@@ -96,7 +71,7 @@ def sync_source(db: Session, source_id: str) -> dict:
|
||||
if not ds:
|
||||
raise EntityNotFoundError("Data source", source_id)
|
||||
|
||||
handler = _get_sync_handler(ds.name)
|
||||
handler = get_import_handler(ds.name)
|
||||
if handler is None:
|
||||
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
|
||||
|
||||
@@ -142,7 +117,7 @@ def sync_all_sources(db: Session) -> list[dict]:
|
||||
|
||||
results = []
|
||||
for ds in enabled_sources:
|
||||
handler = _get_sync_handler(ds.name)
|
||||
handler = get_import_handler(ds.name)
|
||||
if handler is None:
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
|
||||
Reference in New Issue
Block a user