refactor(heatmap): extract business logic to dedicated service
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled

Move layer dispatch, entity-not-found checks, and validation from router to heatmap_service. Router now only validates requests, calls service, and formats responses (no HTTPException, no business logic). Service raises EntityNotFoundError/BusinessRuleViolation instead of returning None. Add build_navigator_export() for centralized dispatch. 29 new tests (253 total, 0 failures).
This commit is contained in:
2026-02-18 16:09:51 +01:00
parent 1338d52cd0
commit e651ef8a8c
3 changed files with 243 additions and 53 deletions

View File

@@ -1,13 +1,15 @@
"""Heatmap endpoints — ATT&CK Navigator-compatible layer generation. """Heatmap endpoints — ATT&CK Navigator-compatible layer generation.
Thin router that delegates to :mod:`app.services.heatmap_service`. Thin router that delegates entirely to :mod:`app.services.heatmap_service`.
No business logic lives here — only request validation and response
formatting.
""" """
import io import io
import json import json
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -19,9 +21,6 @@ from app.services import heatmap_service
router = APIRouter(prefix="/heatmap", tags=["heatmap"]) router = APIRouter(prefix="/heatmap", tags=["heatmap"])
# ── GET /heatmap/coverage ─────────────────────────────────────────────
@router.get("/coverage") @router.get("/coverage")
def heatmap_coverage( def heatmap_coverage(
platforms: Optional[str] = Query(None, description="Comma-separated platforms"), platforms: Optional[str] = Query(None, description="Comma-separated platforms"),
@@ -36,9 +35,6 @@ def heatmap_coverage(
) )
# ── GET /heatmap/threat-actor/{actor_id} ──────────────────────────────
@router.get("/threat-actor/{actor_id}") @router.get("/threat-actor/{actor_id}")
def heatmap_threat_actor( def heatmap_threat_actor(
actor_id: str, actor_id: str,
@@ -49,15 +45,9 @@ def heatmap_threat_actor(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Threat actor layer — techniques used by an actor with coverage color.""" """Threat actor layer — techniques used by an actor with coverage color."""
layer = heatmap_service.build_threat_actor_layer( return heatmap_service.build_threat_actor_layer(
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score, db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
) )
if layer is None:
raise HTTPException(status_code=404, detail="Threat actor not found")
return layer
# ── GET /heatmap/detection-rules ──────────────────────────────────────
@router.get("/detection-rules") @router.get("/detection-rules")
@@ -74,9 +64,6 @@ def heatmap_detection_rules(
) )
# ── GET /heatmap/campaign/{campaign_id} ───────────────────────────────
@router.get("/campaign/{campaign_id}") @router.get("/campaign/{campaign_id}")
def heatmap_campaign( def heatmap_campaign(
campaign_id: str, campaign_id: str,
@@ -87,25 +74,9 @@ def heatmap_campaign(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Campaign layer — only techniques in the campaign, colored by test state.""" """Campaign layer — only techniques in the campaign, colored by test state."""
layer = heatmap_service.build_campaign_layer( return heatmap_service.build_campaign_layer(
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score, db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
) )
if layer is None:
raise HTTPException(status_code=404, detail="Campaign not found")
return layer
# ── GET /heatmap/export-navigator ─────────────────────────────────────
_LAYER_BUILDERS = {
"coverage": lambda db, **kw: heatmap_service.build_coverage_layer(db, **kw),
"detection-rules": lambda db, **kw: heatmap_service.build_detection_rules_layer(db, **kw),
}
_LAYER_BUILDERS_WITH_ID = {
"threat-actor": lambda db, lid, **kw: heatmap_service.build_threat_actor_layer(db, lid, **kw),
"campaign": lambda db, lid, **kw: heatmap_service.build_campaign_layer(db, lid, **kw),
}
@router.get("/export-navigator") @router.get("/export-navigator")
@@ -119,18 +90,10 @@ def export_navigator(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator.""" """Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score) data = heatmap_service.build_navigator_export(
db, layer, layer_id=layer_id,
if layer in _LAYER_BUILDERS: platforms=platforms, tactics=tactics, min_score=min_score,
data = _LAYER_BUILDERS[layer](db, **kwargs) )
elif layer in _LAYER_BUILDERS_WITH_ID:
if not layer_id:
raise HTTPException(status_code=400, detail=f"layer_id required for {layer} layer")
data = _LAYER_BUILDERS_WITH_ID[layer](db, layer_id, **kwargs)
if data is None:
raise HTTPException(status_code=404, detail=f"{layer} not found")
else:
raise HTTPException(status_code=400, detail=f"Unknown layer type: {layer}")
json_content = json.dumps(data, indent=2, default=str) json_content = json.dumps(data, indent=2, default=str)
buffer = io.BytesIO(json_content.encode("utf-8")) buffer = io.BytesIO(json_content.encode("utf-8"))

View File

@@ -15,6 +15,7 @@ from typing import Optional
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
from app.models.campaign import Campaign, CampaignTest from app.models.campaign import Campaign, CampaignTest
from app.models.detection_rule import DetectionRule from app.models.detection_rule import DetectionRule
from app.models.defensive_technique import DefensiveTechniqueMapping from app.models.defensive_technique import DefensiveTechniqueMapping
@@ -206,14 +207,14 @@ def build_threat_actor_layer(
platforms: str | None = None, platforms: str | None = None,
tactics: str | None = None, tactics: str | None = None,
min_score: int = 0, min_score: int = 0,
) -> dict | None: ) -> dict:
"""Threat actor layer -- techniques used by an actor with coverage colour. """Threat actor layer -- techniques used by an actor with coverage colour.
Returns ``None`` if the actor does not exist. Raises :class:`EntityNotFoundError` if the actor does not exist.
""" """
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
if not actor: if not actor:
return None raise EntityNotFoundError("ThreatActor", actor_id)
layer = _build_layer_skeleton( layer = _build_layer_skeleton(
f"Threat Actor: {actor.name}", f"Threat Actor: {actor.name}",
@@ -364,14 +365,14 @@ def build_campaign_layer(
platforms: str | None = None, platforms: str | None = None,
tactics: str | None = None, tactics: str | None = None,
min_score: int = 0, min_score: int = 0,
) -> dict | None: ) -> dict:
"""Campaign layer -- techniques in a campaign, coloured by test state. """Campaign layer -- techniques in a campaign, coloured by test state.
Returns ``None`` if the campaign does not exist. Raises :class:`EntityNotFoundError` if the campaign does not exist.
""" """
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign: if not campaign:
return None raise EntityNotFoundError("Campaign", campaign_id)
layer = _build_layer_skeleton( layer = _build_layer_skeleton(
f"Campaign: {campaign.name}", f"Campaign: {campaign.name}",
@@ -450,3 +451,49 @@ def build_campaign_layer(
}) })
return layer return layer
# ── Layer dispatch (for Navigator export) ────────────────────────────
_LAYER_BUILDERS = {
"coverage": lambda db, **kw: build_coverage_layer(db, **kw),
"detection-rules": lambda db, **kw: build_detection_rules_layer(db, **kw),
}
_LAYER_BUILDERS_WITH_ID = {
"threat-actor": lambda db, lid, **kw: build_threat_actor_layer(db, lid, **kw),
"campaign": lambda db, lid, **kw: build_campaign_layer(db, lid, **kw),
}
SUPPORTED_LAYER_TYPES = set(_LAYER_BUILDERS) | set(_LAYER_BUILDERS_WITH_ID)
def build_navigator_export(
db: Session,
layer_type: str,
*,
layer_id: str | None = None,
platforms: str | None = None,
tactics: str | None = None,
min_score: int = 0,
) -> dict:
"""Build a heatmap layer dict by type name.
Raises :class:`BusinessRuleViolation` for unknown layer types or
missing ``layer_id``. Raises :class:`EntityNotFoundError` when
an entity-bound layer (threat-actor, campaign) references a
non-existent record.
"""
kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score)
if layer_type in _LAYER_BUILDERS:
return _LAYER_BUILDERS[layer_type](db, **kwargs)
if layer_type in _LAYER_BUILDERS_WITH_ID:
if not layer_id:
raise BusinessRuleViolation(
f"layer_id is required for '{layer_type}' layer"
)
return _LAYER_BUILDERS_WITH_ID[layer_type](db, layer_id, **kwargs)
raise BusinessRuleViolation(f"Unknown layer type: {layer_type}")

View File

@@ -0,0 +1,180 @@
"""Tests for heatmap_service — pure helpers, error paths, and dispatch."""
import pytest
from unittest.mock import MagicMock, patch
from app.services.heatmap_service import (
_score_to_color,
_build_layer_skeleton,
_parse_csv,
_format_tactic,
build_navigator_export,
build_threat_actor_layer,
build_campaign_layer,
SUPPORTED_LAYER_TYPES,
ATTACK_VERSION,
NAVIGATOR_VERSION,
LAYER_VERSION,
DOMAIN,
)
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
# ── Pure helpers ──────────────────────────────────────────────────────
class TestScoreToColor:
def test_zero_returns_grey(self):
assert _score_to_color(0) == "#d3d3d3"
def test_negative_returns_grey(self):
assert _score_to_color(-10) == "#d3d3d3"
def test_low_returns_red(self):
assert _score_to_color(25) == "#ff6666"
def test_medium_returns_orange(self):
assert _score_to_color(50) == "#ff9933"
def test_high_returns_yellow(self):
assert _score_to_color(75) == "#ffff66"
def test_max_returns_green(self):
assert _score_to_color(100) == "#66ff66"
class TestBuildLayerSkeleton:
def test_has_required_keys(self):
layer = _build_layer_skeleton("Test Layer", "A description")
assert layer["name"] == "Test Layer"
assert layer["description"] == "A description"
assert layer["domain"] == DOMAIN
assert layer["techniques"] == []
assert layer["versions"]["attack"] == ATTACK_VERSION
assert layer["versions"]["navigator"] == NAVIGATOR_VERSION
assert layer["versions"]["layer"] == LAYER_VERSION
def test_default_gradient(self):
layer = _build_layer_skeleton("X", "Y")
assert layer["gradient"]["minValue"] == 0
assert layer["gradient"]["maxValue"] == 100
assert len(layer["gradient"]["colors"]) == 3
def test_custom_gradient(self):
layer = _build_layer_skeleton("X", "Y", gradient_colors=["#000", "#fff"])
assert layer["gradient"]["colors"] == ["#000", "#fff"]
class TestParseCsv:
def test_none_returns_none(self):
assert _parse_csv(None) is None
def test_empty_string_returns_none(self):
assert _parse_csv("") is None
def test_single_value(self):
assert _parse_csv("windows") == ["windows"]
def test_multiple_values_with_spaces(self):
assert _parse_csv("windows, linux, macos") == ["windows", "linux", "macos"]
def test_empty_elements_filtered(self):
assert _parse_csv("a,,b") == ["a", "b"]
class TestFormatTactic:
def test_none_returns_empty(self):
assert _format_tactic(None) == ""
def test_empty_returns_empty(self):
assert _format_tactic("") == ""
def test_lowercases(self):
assert _format_tactic("Initial Access") == "initial access"
def test_comma_separated_takes_first(self):
assert _format_tactic("Execution, Persistence") == "execution"
# ── build_navigator_export dispatch ───────────────────────────────────
def _mock_db():
return MagicMock()
class TestBuildNavigatorExport:
@patch("app.services.heatmap_service.build_coverage_layer")
def test_dispatches_coverage(self, mock_build):
mock_build.return_value = {"name": "coverage"}
result = build_navigator_export(_mock_db(), "coverage")
assert result["name"] == "coverage"
mock_build.assert_called_once()
@patch("app.services.heatmap_service.build_detection_rules_layer")
def test_dispatches_detection_rules(self, mock_build):
mock_build.return_value = {"name": "rules"}
result = build_navigator_export(_mock_db(), "detection-rules")
assert result["name"] == "rules"
mock_build.assert_called_once()
@patch("app.services.heatmap_service.build_threat_actor_layer")
def test_dispatches_threat_actor_with_id(self, mock_build):
mock_build.return_value = {"name": "actor"}
result = build_navigator_export(_mock_db(), "threat-actor", layer_id="abc")
assert result["name"] == "actor"
mock_build.assert_called_once()
@patch("app.services.heatmap_service.build_campaign_layer")
def test_dispatches_campaign_with_id(self, mock_build):
mock_build.return_value = {"name": "campaign"}
result = build_navigator_export(_mock_db(), "campaign", layer_id="xyz")
assert result["name"] == "campaign"
mock_build.assert_called_once()
def test_unknown_layer_raises(self):
with pytest.raises(BusinessRuleViolation, match="Unknown layer type"):
build_navigator_export(_mock_db(), "nonexistent")
def test_missing_layer_id_for_threat_actor(self):
with pytest.raises(BusinessRuleViolation, match="layer_id is required"):
build_navigator_export(_mock_db(), "threat-actor")
def test_missing_layer_id_for_campaign(self):
with pytest.raises(BusinessRuleViolation, match="layer_id is required"):
build_navigator_export(_mock_db(), "campaign")
def test_supported_layer_types_complete(self):
assert SUPPORTED_LAYER_TYPES == {
"coverage", "detection-rules", "threat-actor", "campaign",
}
@patch("app.services.heatmap_service.build_coverage_layer")
def test_passes_filter_kwargs(self, mock_build):
mock_build.return_value = {}
build_navigator_export(
_mock_db(), "coverage",
platforms="windows", tactics="execution", min_score=50,
)
_, kwargs = mock_build.call_args
assert kwargs["platforms"] == "windows"
assert kwargs["tactics"] == "execution"
assert kwargs["min_score"] == 50
# ── Entity-not-found errors ───────────────────────────────────────────
class TestEntityNotFound:
def _db_returning_none(self):
db = MagicMock()
db.query.return_value.filter.return_value.first.return_value = None
return db
def test_threat_actor_not_found(self):
with pytest.raises(EntityNotFoundError, match="ThreatActor"):
build_threat_actor_layer(self._db_returning_none(), "bad-id")
def test_campaign_not_found(self):
with pytest.raises(EntityNotFoundError, match="Campaign"):
build_campaign_layer(self._db_returning_none(), "bad-id")