diff --git a/backend/app/services/heatmap_service.py b/backend/app/services/heatmap_service.py index 725f638..1ac34cf 100644 --- a/backend/app/services/heatmap_service.py +++ b/backend/app/services/heatmap_service.py @@ -453,19 +453,71 @@ def build_campaign_layer( return layer -# ── Layer dispatch (for Navigator export) ──────────────────────────── +# ── Layer registry (OCP-compliant dispatch) ────────────────────────── +# +# To add a new layer type: +# 1. Write a builder function: ``def build_X_layer(db, *, platforms, tactics, min_score) -> dict`` +# 2. Call ``register_layer("x", build_X_layer)`` (or ``register_layer("x", fn, requires_id=True)``) +# 3. Optionally add a convenience endpoint in the router +# +# The ``/export-navigator?layer=x`` endpoint picks up new layers automatically. -_LAYER_BUILDERS = { - "coverage": lambda db, **kw: build_coverage_layer(db, **kw), - "detection-rules": lambda db, **kw: build_detection_rules_layer(db, **kw), -} -_LAYER_BUILDERS_WITH_ID = { - "threat-actor": lambda db, lid, **kw: build_threat_actor_layer(db, lid, **kw), - "campaign": lambda db, lid, **kw: build_campaign_layer(db, lid, **kw), -} +class _LayerRegistry: + """Extensible registry that maps layer type names to builder functions.""" -SUPPORTED_LAYER_TYPES = set(_LAYER_BUILDERS) | set(_LAYER_BUILDERS_WITH_ID) + __slots__ = ("_simple", "_with_id") + + def __init__(self) -> None: + self._simple: dict[str, object] = {} + self._with_id: dict[str, object] = {} + + def register(self, name: str, builder, *, requires_id: bool = False) -> None: + target = self._with_id if requires_id else self._simple + target[name] = builder + + @property + def supported_types(self) -> set[str]: + return set(self._simple) | set(self._with_id) + + def build( + self, + db: Session, + layer_type: str, + *, + layer_id: str | None = None, + platforms: str | None = None, + tactics: str | None = None, + min_score: int = 0, + ) -> dict: + kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score) + + if layer_type in self._simple: + return self._simple[layer_type](db, **kwargs) + + if layer_type in self._with_id: + if not layer_id: + raise BusinessRuleViolation( + f"layer_id is required for '{layer_type}' layer" + ) + return self._with_id[layer_type](db, layer_id, **kwargs) + + raise BusinessRuleViolation(f"Unknown layer type: {layer_type}") + + +LAYER_REGISTRY = _LayerRegistry() + +LAYER_REGISTRY.register("coverage", build_coverage_layer) +LAYER_REGISTRY.register("detection-rules", build_detection_rules_layer) +LAYER_REGISTRY.register("threat-actor", build_threat_actor_layer, requires_id=True) +LAYER_REGISTRY.register("campaign", build_campaign_layer, requires_id=True) + +SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types # snapshot of built-in types + + +def register_layer(name: str, builder, *, requires_id: bool = False) -> None: + """Public API to register a new heatmap layer type at import time.""" + LAYER_REGISTRY.register(name, builder, requires_id=requires_id) def build_navigator_export( @@ -484,16 +536,7 @@ def build_navigator_export( an entity-bound layer (threat-actor, campaign) references a non-existent record. """ - kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score) - - if layer_type in _LAYER_BUILDERS: - return _LAYER_BUILDERS[layer_type](db, **kwargs) - - if layer_type in _LAYER_BUILDERS_WITH_ID: - if not layer_id: - raise BusinessRuleViolation( - f"layer_id is required for '{layer_type}' layer" - ) - return _LAYER_BUILDERS_WITH_ID[layer_type](db, layer_id, **kwargs) - - raise BusinessRuleViolation(f"Unknown layer type: {layer_type}") + return LAYER_REGISTRY.build( + db, layer_type, + layer_id=layer_id, platforms=platforms, tactics=tactics, min_score=min_score, + ) diff --git a/backend/tests/test_heatmap_layer_registry.py b/backend/tests/test_heatmap_layer_registry.py new file mode 100644 index 0000000..fc4e7f3 --- /dev/null +++ b/backend/tests/test_heatmap_layer_registry.py @@ -0,0 +1,58 @@ +"""Tests for the heatmap layer registry (OCP extensibility).""" + +import pytest + +from app.services.heatmap_service import ( + LAYER_REGISTRY, + SUPPORTED_LAYER_TYPES, + register_layer, +) + + +def test_builtin_layer_types_registered(): + assert "coverage" in SUPPORTED_LAYER_TYPES + assert "detection-rules" in SUPPORTED_LAYER_TYPES + assert "threat-actor" in SUPPORTED_LAYER_TYPES + assert "campaign" in SUPPORTED_LAYER_TYPES + + +def test_register_custom_simple_layer(): + def my_layer(db, *, platforms=None, tactics=None, min_score=0): + return {"name": "custom", "techniques": []} + + register_layer("custom-test-layer", my_layer) + assert "custom-test-layer" in LAYER_REGISTRY.supported_types + + result = LAYER_REGISTRY.build( + None, "custom-test-layer", + platforms=None, tactics=None, min_score=0, + ) + assert result["name"] == "custom" + + +def test_register_custom_id_layer(): + def my_id_layer(db, layer_id, *, platforms=None, tactics=None, min_score=0): + return {"name": f"entity-{layer_id}", "techniques": []} + + register_layer("custom-id-layer", my_id_layer, requires_id=True) + assert "custom-id-layer" in LAYER_REGISTRY.supported_types + + result = LAYER_REGISTRY.build( + None, "custom-id-layer", + layer_id="abc-123", platforms=None, tactics=None, min_score=0, + ) + assert result["name"] == "entity-abc-123" + + +def test_unknown_layer_raises(): + from app.domain.errors import BusinessRuleViolation + + with pytest.raises(BusinessRuleViolation, match="Unknown layer type"): + LAYER_REGISTRY.build(None, "nonexistent-layer") + + +def test_id_layer_without_id_raises(): + from app.domain.errors import BusinessRuleViolation + + with pytest.raises(BusinessRuleViolation, match="layer_id is required"): + LAYER_REGISTRY.build(None, "threat-actor", layer_id=None) diff --git a/backend/tests/test_heatmap_service.py b/backend/tests/test_heatmap_service.py index fbc163d..1062351 100644 --- a/backend/tests/test_heatmap_service.py +++ b/backend/tests/test_heatmap_service.py @@ -104,33 +104,53 @@ def _mock_db(): class TestBuildNavigatorExport: - @patch("app.services.heatmap_service.build_coverage_layer") - def test_dispatches_coverage(self, mock_build): - mock_build.return_value = {"name": "coverage"} - result = build_navigator_export(_mock_db(), "coverage") - assert result["name"] == "coverage" - mock_build.assert_called_once() + def test_dispatches_coverage(self): + from app.services.heatmap_service import LAYER_REGISTRY + mock_build = MagicMock(return_value={"name": "coverage"}) + orig = LAYER_REGISTRY._simple["coverage"] + LAYER_REGISTRY._simple["coverage"] = mock_build + try: + result = build_navigator_export(_mock_db(), "coverage") + assert result["name"] == "coverage" + mock_build.assert_called_once() + finally: + LAYER_REGISTRY._simple["coverage"] = orig - @patch("app.services.heatmap_service.build_detection_rules_layer") - def test_dispatches_detection_rules(self, mock_build): - mock_build.return_value = {"name": "rules"} - result = build_navigator_export(_mock_db(), "detection-rules") - assert result["name"] == "rules" - mock_build.assert_called_once() + def test_dispatches_detection_rules(self): + from app.services.heatmap_service import LAYER_REGISTRY + mock_build = MagicMock(return_value={"name": "rules"}) + orig = LAYER_REGISTRY._simple["detection-rules"] + LAYER_REGISTRY._simple["detection-rules"] = mock_build + try: + result = build_navigator_export(_mock_db(), "detection-rules") + assert result["name"] == "rules" + mock_build.assert_called_once() + finally: + LAYER_REGISTRY._simple["detection-rules"] = orig - @patch("app.services.heatmap_service.build_threat_actor_layer") - def test_dispatches_threat_actor_with_id(self, mock_build): - mock_build.return_value = {"name": "actor"} - result = build_navigator_export(_mock_db(), "threat-actor", layer_id="abc") - assert result["name"] == "actor" - mock_build.assert_called_once() + def test_dispatches_threat_actor_with_id(self): + from app.services.heatmap_service import LAYER_REGISTRY + mock_build = MagicMock(return_value={"name": "actor"}) + orig = LAYER_REGISTRY._with_id["threat-actor"] + LAYER_REGISTRY._with_id["threat-actor"] = mock_build + try: + result = build_navigator_export(_mock_db(), "threat-actor", layer_id="abc") + assert result["name"] == "actor" + mock_build.assert_called_once() + finally: + LAYER_REGISTRY._with_id["threat-actor"] = orig - @patch("app.services.heatmap_service.build_campaign_layer") - def test_dispatches_campaign_with_id(self, mock_build): - mock_build.return_value = {"name": "campaign"} - result = build_navigator_export(_mock_db(), "campaign", layer_id="xyz") - assert result["name"] == "campaign" - mock_build.assert_called_once() + def test_dispatches_campaign_with_id(self): + from app.services.heatmap_service import LAYER_REGISTRY + mock_build = MagicMock(return_value={"name": "campaign"}) + orig = LAYER_REGISTRY._with_id["campaign"] + LAYER_REGISTRY._with_id["campaign"] = mock_build + try: + result = build_navigator_export(_mock_db(), "campaign", layer_id="xyz") + assert result["name"] == "campaign" + mock_build.assert_called_once() + finally: + LAYER_REGISTRY._with_id["campaign"] = orig def test_unknown_layer_raises(self): with pytest.raises(BusinessRuleViolation, match="Unknown layer type"): @@ -149,17 +169,22 @@ class TestBuildNavigatorExport: "coverage", "detection-rules", "threat-actor", "campaign", } - @patch("app.services.heatmap_service.build_coverage_layer") - def test_passes_filter_kwargs(self, mock_build): - mock_build.return_value = {} - build_navigator_export( - _mock_db(), "coverage", - platforms="windows", tactics="execution", min_score=50, - ) - _, kwargs = mock_build.call_args - assert kwargs["platforms"] == "windows" - assert kwargs["tactics"] == "execution" - assert kwargs["min_score"] == 50 + def test_passes_filter_kwargs(self): + from app.services.heatmap_service import LAYER_REGISTRY + mock_build = MagicMock(return_value={}) + orig = LAYER_REGISTRY._simple["coverage"] + LAYER_REGISTRY._simple["coverage"] = mock_build + try: + build_navigator_export( + _mock_db(), "coverage", + platforms="windows", tactics="execution", min_score=50, + ) + _, kwargs = mock_build.call_args + assert kwargs["platforms"] == "windows" + assert kwargs["tactics"] == "execution" + assert kwargs["min_score"] == 50 + finally: + LAYER_REGISTRY._simple["coverage"] = orig # ── Entity-not-found errors ───────────────────────────────────────────