From e651ef8a8c951b840049ea326081510534fcc809 Mon Sep 17 00:00:00 2001 From: Kitos Date: Wed, 18 Feb 2026 16:09:51 +0100 Subject: [PATCH] refactor(heatmap): extract business logic to dedicated service Move layer dispatch, entity-not-found checks, and validation from router to heatmap_service. Router now only validates requests, calls service, and formats responses (no HTTPException, no business logic). Service raises EntityNotFoundError/BusinessRuleViolation instead of returning None. Add build_navigator_export() for centralized dispatch. 29 new tests (253 total, 0 failures). --- backend/app/routers/heatmap.py | 57 ++------ backend/app/services/heatmap_service.py | 59 +++++++- backend/tests/test_heatmap_service.py | 180 ++++++++++++++++++++++++ 3 files changed, 243 insertions(+), 53 deletions(-) create mode 100644 backend/tests/test_heatmap_service.py diff --git a/backend/app/routers/heatmap.py b/backend/app/routers/heatmap.py index 102e0d0..18ec7f4 100644 --- a/backend/app/routers/heatmap.py +++ b/backend/app/routers/heatmap.py @@ -1,13 +1,15 @@ """Heatmap endpoints — ATT&CK Navigator-compatible layer generation. -Thin router that delegates to :mod:`app.services.heatmap_service`. +Thin router that delegates entirely to :mod:`app.services.heatmap_service`. +No business logic lives here — only request validation and response +formatting. """ import io import json from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session @@ -19,9 +21,6 @@ from app.services import heatmap_service router = APIRouter(prefix="/heatmap", tags=["heatmap"]) -# ── GET /heatmap/coverage ───────────────────────────────────────────── - - @router.get("/coverage") def heatmap_coverage( platforms: Optional[str] = Query(None, description="Comma-separated platforms"), @@ -36,9 +35,6 @@ def heatmap_coverage( ) -# ── GET /heatmap/threat-actor/{actor_id} ────────────────────────────── - - @router.get("/threat-actor/{actor_id}") def heatmap_threat_actor( actor_id: str, @@ -49,15 +45,9 @@ def heatmap_threat_actor( current_user: User = Depends(get_current_user), ): """Threat actor layer — techniques used by an actor with coverage color.""" - layer = heatmap_service.build_threat_actor_layer( + return heatmap_service.build_threat_actor_layer( db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score, ) - if layer is None: - raise HTTPException(status_code=404, detail="Threat actor not found") - return layer - - -# ── GET /heatmap/detection-rules ────────────────────────────────────── @router.get("/detection-rules") @@ -74,9 +64,6 @@ def heatmap_detection_rules( ) -# ── GET /heatmap/campaign/{campaign_id} ─────────────────────────────── - - @router.get("/campaign/{campaign_id}") def heatmap_campaign( campaign_id: str, @@ -87,25 +74,9 @@ def heatmap_campaign( current_user: User = Depends(get_current_user), ): """Campaign layer — only techniques in the campaign, colored by test state.""" - layer = heatmap_service.build_campaign_layer( + return heatmap_service.build_campaign_layer( db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score, ) - if layer is None: - raise HTTPException(status_code=404, detail="Campaign not found") - return layer - - -# ── GET /heatmap/export-navigator ───────────────────────────────────── - -_LAYER_BUILDERS = { - "coverage": lambda db, **kw: heatmap_service.build_coverage_layer(db, **kw), - "detection-rules": lambda db, **kw: heatmap_service.build_detection_rules_layer(db, **kw), -} - -_LAYER_BUILDERS_WITH_ID = { - "threat-actor": lambda db, lid, **kw: heatmap_service.build_threat_actor_layer(db, lid, **kw), - "campaign": lambda db, lid, **kw: heatmap_service.build_campaign_layer(db, lid, **kw), -} @router.get("/export-navigator") @@ -119,18 +90,10 @@ def export_navigator( current_user: User = Depends(get_current_user), ): """Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator.""" - kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score) - - if layer in _LAYER_BUILDERS: - data = _LAYER_BUILDERS[layer](db, **kwargs) - elif layer in _LAYER_BUILDERS_WITH_ID: - if not layer_id: - raise HTTPException(status_code=400, detail=f"layer_id required for {layer} layer") - data = _LAYER_BUILDERS_WITH_ID[layer](db, layer_id, **kwargs) - if data is None: - raise HTTPException(status_code=404, detail=f"{layer} not found") - else: - raise HTTPException(status_code=400, detail=f"Unknown layer type: {layer}") + data = heatmap_service.build_navigator_export( + db, layer, layer_id=layer_id, + platforms=platforms, tactics=tactics, min_score=min_score, + ) json_content = json.dumps(data, indent=2, default=str) buffer = io.BytesIO(json_content.encode("utf-8")) diff --git a/backend/app/services/heatmap_service.py b/backend/app/services/heatmap_service.py index 838341e..725f638 100644 --- a/backend/app/services/heatmap_service.py +++ b/backend/app/services/heatmap_service.py @@ -15,6 +15,7 @@ from typing import Optional from sqlalchemy import func, or_ from sqlalchemy.orm import Session +from app.domain.errors import BusinessRuleViolation, EntityNotFoundError from app.models.campaign import Campaign, CampaignTest from app.models.detection_rule import DetectionRule from app.models.defensive_technique import DefensiveTechniqueMapping @@ -206,14 +207,14 @@ def build_threat_actor_layer( platforms: str | None = None, tactics: str | None = None, min_score: int = 0, -) -> dict | None: +) -> dict: """Threat actor layer -- techniques used by an actor with coverage colour. - Returns ``None`` if the actor does not exist. + Raises :class:`EntityNotFoundError` if the actor does not exist. """ actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() if not actor: - return None + raise EntityNotFoundError("ThreatActor", actor_id) layer = _build_layer_skeleton( f"Threat Actor: {actor.name}", @@ -364,14 +365,14 @@ def build_campaign_layer( platforms: str | None = None, tactics: str | None = None, min_score: int = 0, -) -> dict | None: +) -> dict: """Campaign layer -- techniques in a campaign, coloured by test state. - Returns ``None`` if the campaign does not exist. + Raises :class:`EntityNotFoundError` if the campaign does not exist. """ campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() if not campaign: - return None + raise EntityNotFoundError("Campaign", campaign_id) layer = _build_layer_skeleton( f"Campaign: {campaign.name}", @@ -450,3 +451,49 @@ def build_campaign_layer( }) return layer + + +# ── Layer dispatch (for Navigator export) ──────────────────────────── + +_LAYER_BUILDERS = { + "coverage": lambda db, **kw: build_coverage_layer(db, **kw), + "detection-rules": lambda db, **kw: build_detection_rules_layer(db, **kw), +} + +_LAYER_BUILDERS_WITH_ID = { + "threat-actor": lambda db, lid, **kw: build_threat_actor_layer(db, lid, **kw), + "campaign": lambda db, lid, **kw: build_campaign_layer(db, lid, **kw), +} + +SUPPORTED_LAYER_TYPES = set(_LAYER_BUILDERS) | set(_LAYER_BUILDERS_WITH_ID) + + +def build_navigator_export( + db: Session, + layer_type: str, + *, + layer_id: str | None = None, + platforms: str | None = None, + tactics: str | None = None, + min_score: int = 0, +) -> dict: + """Build a heatmap layer dict by type name. + + Raises :class:`BusinessRuleViolation` for unknown layer types or + missing ``layer_id``. Raises :class:`EntityNotFoundError` when + an entity-bound layer (threat-actor, campaign) references a + non-existent record. + """ + kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score) + + if layer_type in _LAYER_BUILDERS: + return _LAYER_BUILDERS[layer_type](db, **kwargs) + + if layer_type in _LAYER_BUILDERS_WITH_ID: + if not layer_id: + raise BusinessRuleViolation( + f"layer_id is required for '{layer_type}' layer" + ) + return _LAYER_BUILDERS_WITH_ID[layer_type](db, layer_id, **kwargs) + + raise BusinessRuleViolation(f"Unknown layer type: {layer_type}") diff --git a/backend/tests/test_heatmap_service.py b/backend/tests/test_heatmap_service.py new file mode 100644 index 0000000..fbc163d --- /dev/null +++ b/backend/tests/test_heatmap_service.py @@ -0,0 +1,180 @@ +"""Tests for heatmap_service — pure helpers, error paths, and dispatch.""" + +import pytest +from unittest.mock import MagicMock, patch + +from app.services.heatmap_service import ( + _score_to_color, + _build_layer_skeleton, + _parse_csv, + _format_tactic, + build_navigator_export, + build_threat_actor_layer, + build_campaign_layer, + SUPPORTED_LAYER_TYPES, + ATTACK_VERSION, + NAVIGATOR_VERSION, + LAYER_VERSION, + DOMAIN, +) +from app.domain.errors import BusinessRuleViolation, EntityNotFoundError + + +# ── Pure helpers ────────────────────────────────────────────────────── + + +class TestScoreToColor: + def test_zero_returns_grey(self): + assert _score_to_color(0) == "#d3d3d3" + + def test_negative_returns_grey(self): + assert _score_to_color(-10) == "#d3d3d3" + + def test_low_returns_red(self): + assert _score_to_color(25) == "#ff6666" + + def test_medium_returns_orange(self): + assert _score_to_color(50) == "#ff9933" + + def test_high_returns_yellow(self): + assert _score_to_color(75) == "#ffff66" + + def test_max_returns_green(self): + assert _score_to_color(100) == "#66ff66" + + +class TestBuildLayerSkeleton: + def test_has_required_keys(self): + layer = _build_layer_skeleton("Test Layer", "A description") + assert layer["name"] == "Test Layer" + assert layer["description"] == "A description" + assert layer["domain"] == DOMAIN + assert layer["techniques"] == [] + assert layer["versions"]["attack"] == ATTACK_VERSION + assert layer["versions"]["navigator"] == NAVIGATOR_VERSION + assert layer["versions"]["layer"] == LAYER_VERSION + + def test_default_gradient(self): + layer = _build_layer_skeleton("X", "Y") + assert layer["gradient"]["minValue"] == 0 + assert layer["gradient"]["maxValue"] == 100 + assert len(layer["gradient"]["colors"]) == 3 + + def test_custom_gradient(self): + layer = _build_layer_skeleton("X", "Y", gradient_colors=["#000", "#fff"]) + assert layer["gradient"]["colors"] == ["#000", "#fff"] + + +class TestParseCsv: + def test_none_returns_none(self): + assert _parse_csv(None) is None + + def test_empty_string_returns_none(self): + assert _parse_csv("") is None + + def test_single_value(self): + assert _parse_csv("windows") == ["windows"] + + def test_multiple_values_with_spaces(self): + assert _parse_csv("windows, linux, macos") == ["windows", "linux", "macos"] + + def test_empty_elements_filtered(self): + assert _parse_csv("a,,b") == ["a", "b"] + + +class TestFormatTactic: + def test_none_returns_empty(self): + assert _format_tactic(None) == "" + + def test_empty_returns_empty(self): + assert _format_tactic("") == "" + + def test_lowercases(self): + assert _format_tactic("Initial Access") == "initial access" + + def test_comma_separated_takes_first(self): + assert _format_tactic("Execution, Persistence") == "execution" + + +# ── build_navigator_export dispatch ─────────────────────────────────── + + +def _mock_db(): + return MagicMock() + + +class TestBuildNavigatorExport: + @patch("app.services.heatmap_service.build_coverage_layer") + def test_dispatches_coverage(self, mock_build): + mock_build.return_value = {"name": "coverage"} + result = build_navigator_export(_mock_db(), "coverage") + assert result["name"] == "coverage" + mock_build.assert_called_once() + + @patch("app.services.heatmap_service.build_detection_rules_layer") + def test_dispatches_detection_rules(self, mock_build): + mock_build.return_value = {"name": "rules"} + result = build_navigator_export(_mock_db(), "detection-rules") + assert result["name"] == "rules" + mock_build.assert_called_once() + + @patch("app.services.heatmap_service.build_threat_actor_layer") + def test_dispatches_threat_actor_with_id(self, mock_build): + mock_build.return_value = {"name": "actor"} + result = build_navigator_export(_mock_db(), "threat-actor", layer_id="abc") + assert result["name"] == "actor" + mock_build.assert_called_once() + + @patch("app.services.heatmap_service.build_campaign_layer") + def test_dispatches_campaign_with_id(self, mock_build): + mock_build.return_value = {"name": "campaign"} + result = build_navigator_export(_mock_db(), "campaign", layer_id="xyz") + assert result["name"] == "campaign" + mock_build.assert_called_once() + + def test_unknown_layer_raises(self): + with pytest.raises(BusinessRuleViolation, match="Unknown layer type"): + build_navigator_export(_mock_db(), "nonexistent") + + def test_missing_layer_id_for_threat_actor(self): + with pytest.raises(BusinessRuleViolation, match="layer_id is required"): + build_navigator_export(_mock_db(), "threat-actor") + + def test_missing_layer_id_for_campaign(self): + with pytest.raises(BusinessRuleViolation, match="layer_id is required"): + build_navigator_export(_mock_db(), "campaign") + + def test_supported_layer_types_complete(self): + assert SUPPORTED_LAYER_TYPES == { + "coverage", "detection-rules", "threat-actor", "campaign", + } + + @patch("app.services.heatmap_service.build_coverage_layer") + def test_passes_filter_kwargs(self, mock_build): + mock_build.return_value = {} + build_navigator_export( + _mock_db(), "coverage", + platforms="windows", tactics="execution", min_score=50, + ) + _, kwargs = mock_build.call_args + assert kwargs["platforms"] == "windows" + assert kwargs["tactics"] == "execution" + assert kwargs["min_score"] == 50 + + +# ── Entity-not-found errors ─────────────────────────────────────────── + + +class TestEntityNotFound: + def _db_returning_none(self): + db = MagicMock() + db.query.return_value.filter.return_value.first.return_value = None + return db + + def test_threat_actor_not_found(self): + with pytest.raises(EntityNotFoundError, match="ThreatActor"): + build_threat_actor_layer(self._db_returning_none(), "bad-id") + + def test_campaign_not_found(self): + with pytest.raises(EntityNotFoundError, match="Campaign"): + build_campaign_layer(self._db_returning_none(), "bad-id")