refactor(heatmap): extract business logic to dedicated service
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
Move layer dispatch, entity-not-found checks, and validation from router to heatmap_service. Router now only validates requests, calls service, and formats responses (no HTTPException, no business logic). Service raises EntityNotFoundError/BusinessRuleViolation instead of returning None. Add build_navigator_export() for centralized dispatch. 29 new tests (253 total, 0 failures).
This commit is contained in:
@@ -1,13 +1,15 @@
|
|||||||
"""Heatmap endpoints — ATT&CK Navigator-compatible layer generation.
|
"""Heatmap endpoints — ATT&CK Navigator-compatible layer generation.
|
||||||
|
|
||||||
Thin router that delegates to :mod:`app.services.heatmap_service`.
|
Thin router that delegates entirely to :mod:`app.services.heatmap_service`.
|
||||||
|
No business logic lives here — only request validation and response
|
||||||
|
formatting.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -19,9 +21,6 @@ from app.services import heatmap_service
|
|||||||
router = APIRouter(prefix="/heatmap", tags=["heatmap"])
|
router = APIRouter(prefix="/heatmap", tags=["heatmap"])
|
||||||
|
|
||||||
|
|
||||||
# ── GET /heatmap/coverage ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/coverage")
|
@router.get("/coverage")
|
||||||
def heatmap_coverage(
|
def heatmap_coverage(
|
||||||
platforms: Optional[str] = Query(None, description="Comma-separated platforms"),
|
platforms: Optional[str] = Query(None, description="Comma-separated platforms"),
|
||||||
@@ -36,9 +35,6 @@ def heatmap_coverage(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── GET /heatmap/threat-actor/{actor_id} ──────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/threat-actor/{actor_id}")
|
@router.get("/threat-actor/{actor_id}")
|
||||||
def heatmap_threat_actor(
|
def heatmap_threat_actor(
|
||||||
actor_id: str,
|
actor_id: str,
|
||||||
@@ -49,15 +45,9 @@ def heatmap_threat_actor(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Threat actor layer — techniques used by an actor with coverage color."""
|
"""Threat actor layer — techniques used by an actor with coverage color."""
|
||||||
layer = heatmap_service.build_threat_actor_layer(
|
return heatmap_service.build_threat_actor_layer(
|
||||||
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||||
)
|
)
|
||||||
if layer is None:
|
|
||||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /heatmap/detection-rules ──────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/detection-rules")
|
@router.get("/detection-rules")
|
||||||
@@ -74,9 +64,6 @@ def heatmap_detection_rules(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── GET /heatmap/campaign/{campaign_id} ───────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/campaign/{campaign_id}")
|
@router.get("/campaign/{campaign_id}")
|
||||||
def heatmap_campaign(
|
def heatmap_campaign(
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
@@ -87,25 +74,9 @@ def heatmap_campaign(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Campaign layer — only techniques in the campaign, colored by test state."""
|
"""Campaign layer — only techniques in the campaign, colored by test state."""
|
||||||
layer = heatmap_service.build_campaign_layer(
|
return heatmap_service.build_campaign_layer(
|
||||||
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||||
)
|
)
|
||||||
if layer is None:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /heatmap/export-navigator ─────────────────────────────────────
|
|
||||||
|
|
||||||
_LAYER_BUILDERS = {
|
|
||||||
"coverage": lambda db, **kw: heatmap_service.build_coverage_layer(db, **kw),
|
|
||||||
"detection-rules": lambda db, **kw: heatmap_service.build_detection_rules_layer(db, **kw),
|
|
||||||
}
|
|
||||||
|
|
||||||
_LAYER_BUILDERS_WITH_ID = {
|
|
||||||
"threat-actor": lambda db, lid, **kw: heatmap_service.build_threat_actor_layer(db, lid, **kw),
|
|
||||||
"campaign": lambda db, lid, **kw: heatmap_service.build_campaign_layer(db, lid, **kw),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/export-navigator")
|
@router.get("/export-navigator")
|
||||||
@@ -119,18 +90,10 @@ def export_navigator(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
|
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
|
||||||
kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score)
|
data = heatmap_service.build_navigator_export(
|
||||||
|
db, layer, layer_id=layer_id,
|
||||||
if layer in _LAYER_BUILDERS:
|
platforms=platforms, tactics=tactics, min_score=min_score,
|
||||||
data = _LAYER_BUILDERS[layer](db, **kwargs)
|
)
|
||||||
elif layer in _LAYER_BUILDERS_WITH_ID:
|
|
||||||
if not layer_id:
|
|
||||||
raise HTTPException(status_code=400, detail=f"layer_id required for {layer} layer")
|
|
||||||
data = _LAYER_BUILDERS_WITH_ID[layer](db, layer_id, **kwargs)
|
|
||||||
if data is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"{layer} not found")
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=400, detail=f"Unknown layer type: {layer}")
|
|
||||||
|
|
||||||
json_content = json.dumps(data, indent=2, default=str)
|
json_content = json.dumps(data, indent=2, default=str)
|
||||||
buffer = io.BytesIO(json_content.encode("utf-8"))
|
buffer = io.BytesIO(json_content.encode("utf-8"))
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from typing import Optional
|
|||||||
from sqlalchemy import func, or_
|
from sqlalchemy import func, or_
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
|
||||||
from app.models.campaign import Campaign, CampaignTest
|
from app.models.campaign import Campaign, CampaignTest
|
||||||
from app.models.detection_rule import DetectionRule
|
from app.models.detection_rule import DetectionRule
|
||||||
from app.models.defensive_technique import DefensiveTechniqueMapping
|
from app.models.defensive_technique import DefensiveTechniqueMapping
|
||||||
@@ -206,14 +207,14 @@ def build_threat_actor_layer(
|
|||||||
platforms: str | None = None,
|
platforms: str | None = None,
|
||||||
tactics: str | None = None,
|
tactics: str | None = None,
|
||||||
min_score: int = 0,
|
min_score: int = 0,
|
||||||
) -> dict | None:
|
) -> dict:
|
||||||
"""Threat actor layer -- techniques used by an actor with coverage colour.
|
"""Threat actor layer -- techniques used by an actor with coverage colour.
|
||||||
|
|
||||||
Returns ``None`` if the actor does not exist.
|
Raises :class:`EntityNotFoundError` if the actor does not exist.
|
||||||
"""
|
"""
|
||||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||||
if not actor:
|
if not actor:
|
||||||
return None
|
raise EntityNotFoundError("ThreatActor", actor_id)
|
||||||
|
|
||||||
layer = _build_layer_skeleton(
|
layer = _build_layer_skeleton(
|
||||||
f"Threat Actor: {actor.name}",
|
f"Threat Actor: {actor.name}",
|
||||||
@@ -364,14 +365,14 @@ def build_campaign_layer(
|
|||||||
platforms: str | None = None,
|
platforms: str | None = None,
|
||||||
tactics: str | None = None,
|
tactics: str | None = None,
|
||||||
min_score: int = 0,
|
min_score: int = 0,
|
||||||
) -> dict | None:
|
) -> dict:
|
||||||
"""Campaign layer -- techniques in a campaign, coloured by test state.
|
"""Campaign layer -- techniques in a campaign, coloured by test state.
|
||||||
|
|
||||||
Returns ``None`` if the campaign does not exist.
|
Raises :class:`EntityNotFoundError` if the campaign does not exist.
|
||||||
"""
|
"""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
if not campaign:
|
if not campaign:
|
||||||
return None
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
layer = _build_layer_skeleton(
|
layer = _build_layer_skeleton(
|
||||||
f"Campaign: {campaign.name}",
|
f"Campaign: {campaign.name}",
|
||||||
@@ -450,3 +451,49 @@ def build_campaign_layer(
|
|||||||
})
|
})
|
||||||
|
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
# ── Layer dispatch (for Navigator export) ────────────────────────────
|
||||||
|
|
||||||
|
_LAYER_BUILDERS = {
|
||||||
|
"coverage": lambda db, **kw: build_coverage_layer(db, **kw),
|
||||||
|
"detection-rules": lambda db, **kw: build_detection_rules_layer(db, **kw),
|
||||||
|
}
|
||||||
|
|
||||||
|
_LAYER_BUILDERS_WITH_ID = {
|
||||||
|
"threat-actor": lambda db, lid, **kw: build_threat_actor_layer(db, lid, **kw),
|
||||||
|
"campaign": lambda db, lid, **kw: build_campaign_layer(db, lid, **kw),
|
||||||
|
}
|
||||||
|
|
||||||
|
SUPPORTED_LAYER_TYPES = set(_LAYER_BUILDERS) | set(_LAYER_BUILDERS_WITH_ID)
|
||||||
|
|
||||||
|
|
||||||
|
def build_navigator_export(
|
||||||
|
db: Session,
|
||||||
|
layer_type: str,
|
||||||
|
*,
|
||||||
|
layer_id: str | None = None,
|
||||||
|
platforms: str | None = None,
|
||||||
|
tactics: str | None = None,
|
||||||
|
min_score: int = 0,
|
||||||
|
) -> dict:
|
||||||
|
"""Build a heatmap layer dict by type name.
|
||||||
|
|
||||||
|
Raises :class:`BusinessRuleViolation` for unknown layer types or
|
||||||
|
missing ``layer_id``. Raises :class:`EntityNotFoundError` when
|
||||||
|
an entity-bound layer (threat-actor, campaign) references a
|
||||||
|
non-existent record.
|
||||||
|
"""
|
||||||
|
kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score)
|
||||||
|
|
||||||
|
if layer_type in _LAYER_BUILDERS:
|
||||||
|
return _LAYER_BUILDERS[layer_type](db, **kwargs)
|
||||||
|
|
||||||
|
if layer_type in _LAYER_BUILDERS_WITH_ID:
|
||||||
|
if not layer_id:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"layer_id is required for '{layer_type}' layer"
|
||||||
|
)
|
||||||
|
return _LAYER_BUILDERS_WITH_ID[layer_type](db, layer_id, **kwargs)
|
||||||
|
|
||||||
|
raise BusinessRuleViolation(f"Unknown layer type: {layer_type}")
|
||||||
|
|||||||
180
backend/tests/test_heatmap_service.py
Normal file
180
backend/tests/test_heatmap_service.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""Tests for heatmap_service — pure helpers, error paths, and dispatch."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from app.services.heatmap_service import (
|
||||||
|
_score_to_color,
|
||||||
|
_build_layer_skeleton,
|
||||||
|
_parse_csv,
|
||||||
|
_format_tactic,
|
||||||
|
build_navigator_export,
|
||||||
|
build_threat_actor_layer,
|
||||||
|
build_campaign_layer,
|
||||||
|
SUPPORTED_LAYER_TYPES,
|
||||||
|
ATTACK_VERSION,
|
||||||
|
NAVIGATOR_VERSION,
|
||||||
|
LAYER_VERSION,
|
||||||
|
DOMAIN,
|
||||||
|
)
|
||||||
|
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pure helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoreToColor:
|
||||||
|
def test_zero_returns_grey(self):
|
||||||
|
assert _score_to_color(0) == "#d3d3d3"
|
||||||
|
|
||||||
|
def test_negative_returns_grey(self):
|
||||||
|
assert _score_to_color(-10) == "#d3d3d3"
|
||||||
|
|
||||||
|
def test_low_returns_red(self):
|
||||||
|
assert _score_to_color(25) == "#ff6666"
|
||||||
|
|
||||||
|
def test_medium_returns_orange(self):
|
||||||
|
assert _score_to_color(50) == "#ff9933"
|
||||||
|
|
||||||
|
def test_high_returns_yellow(self):
|
||||||
|
assert _score_to_color(75) == "#ffff66"
|
||||||
|
|
||||||
|
def test_max_returns_green(self):
|
||||||
|
assert _score_to_color(100) == "#66ff66"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildLayerSkeleton:
|
||||||
|
def test_has_required_keys(self):
|
||||||
|
layer = _build_layer_skeleton("Test Layer", "A description")
|
||||||
|
assert layer["name"] == "Test Layer"
|
||||||
|
assert layer["description"] == "A description"
|
||||||
|
assert layer["domain"] == DOMAIN
|
||||||
|
assert layer["techniques"] == []
|
||||||
|
assert layer["versions"]["attack"] == ATTACK_VERSION
|
||||||
|
assert layer["versions"]["navigator"] == NAVIGATOR_VERSION
|
||||||
|
assert layer["versions"]["layer"] == LAYER_VERSION
|
||||||
|
|
||||||
|
def test_default_gradient(self):
|
||||||
|
layer = _build_layer_skeleton("X", "Y")
|
||||||
|
assert layer["gradient"]["minValue"] == 0
|
||||||
|
assert layer["gradient"]["maxValue"] == 100
|
||||||
|
assert len(layer["gradient"]["colors"]) == 3
|
||||||
|
|
||||||
|
def test_custom_gradient(self):
|
||||||
|
layer = _build_layer_skeleton("X", "Y", gradient_colors=["#000", "#fff"])
|
||||||
|
assert layer["gradient"]["colors"] == ["#000", "#fff"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseCsv:
|
||||||
|
def test_none_returns_none(self):
|
||||||
|
assert _parse_csv(None) is None
|
||||||
|
|
||||||
|
def test_empty_string_returns_none(self):
|
||||||
|
assert _parse_csv("") is None
|
||||||
|
|
||||||
|
def test_single_value(self):
|
||||||
|
assert _parse_csv("windows") == ["windows"]
|
||||||
|
|
||||||
|
def test_multiple_values_with_spaces(self):
|
||||||
|
assert _parse_csv("windows, linux, macos") == ["windows", "linux", "macos"]
|
||||||
|
|
||||||
|
def test_empty_elements_filtered(self):
|
||||||
|
assert _parse_csv("a,,b") == ["a", "b"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatTactic:
|
||||||
|
def test_none_returns_empty(self):
|
||||||
|
assert _format_tactic(None) == ""
|
||||||
|
|
||||||
|
def test_empty_returns_empty(self):
|
||||||
|
assert _format_tactic("") == ""
|
||||||
|
|
||||||
|
def test_lowercases(self):
|
||||||
|
assert _format_tactic("Initial Access") == "initial access"
|
||||||
|
|
||||||
|
def test_comma_separated_takes_first(self):
|
||||||
|
assert _format_tactic("Execution, Persistence") == "execution"
|
||||||
|
|
||||||
|
|
||||||
|
# ── build_navigator_export dispatch ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_db():
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildNavigatorExport:
|
||||||
|
@patch("app.services.heatmap_service.build_coverage_layer")
|
||||||
|
def test_dispatches_coverage(self, mock_build):
|
||||||
|
mock_build.return_value = {"name": "coverage"}
|
||||||
|
result = build_navigator_export(_mock_db(), "coverage")
|
||||||
|
assert result["name"] == "coverage"
|
||||||
|
mock_build.assert_called_once()
|
||||||
|
|
||||||
|
@patch("app.services.heatmap_service.build_detection_rules_layer")
|
||||||
|
def test_dispatches_detection_rules(self, mock_build):
|
||||||
|
mock_build.return_value = {"name": "rules"}
|
||||||
|
result = build_navigator_export(_mock_db(), "detection-rules")
|
||||||
|
assert result["name"] == "rules"
|
||||||
|
mock_build.assert_called_once()
|
||||||
|
|
||||||
|
@patch("app.services.heatmap_service.build_threat_actor_layer")
|
||||||
|
def test_dispatches_threat_actor_with_id(self, mock_build):
|
||||||
|
mock_build.return_value = {"name": "actor"}
|
||||||
|
result = build_navigator_export(_mock_db(), "threat-actor", layer_id="abc")
|
||||||
|
assert result["name"] == "actor"
|
||||||
|
mock_build.assert_called_once()
|
||||||
|
|
||||||
|
@patch("app.services.heatmap_service.build_campaign_layer")
|
||||||
|
def test_dispatches_campaign_with_id(self, mock_build):
|
||||||
|
mock_build.return_value = {"name": "campaign"}
|
||||||
|
result = build_navigator_export(_mock_db(), "campaign", layer_id="xyz")
|
||||||
|
assert result["name"] == "campaign"
|
||||||
|
mock_build.assert_called_once()
|
||||||
|
|
||||||
|
def test_unknown_layer_raises(self):
|
||||||
|
with pytest.raises(BusinessRuleViolation, match="Unknown layer type"):
|
||||||
|
build_navigator_export(_mock_db(), "nonexistent")
|
||||||
|
|
||||||
|
def test_missing_layer_id_for_threat_actor(self):
|
||||||
|
with pytest.raises(BusinessRuleViolation, match="layer_id is required"):
|
||||||
|
build_navigator_export(_mock_db(), "threat-actor")
|
||||||
|
|
||||||
|
def test_missing_layer_id_for_campaign(self):
|
||||||
|
with pytest.raises(BusinessRuleViolation, match="layer_id is required"):
|
||||||
|
build_navigator_export(_mock_db(), "campaign")
|
||||||
|
|
||||||
|
def test_supported_layer_types_complete(self):
|
||||||
|
assert SUPPORTED_LAYER_TYPES == {
|
||||||
|
"coverage", "detection-rules", "threat-actor", "campaign",
|
||||||
|
}
|
||||||
|
|
||||||
|
@patch("app.services.heatmap_service.build_coverage_layer")
|
||||||
|
def test_passes_filter_kwargs(self, mock_build):
|
||||||
|
mock_build.return_value = {}
|
||||||
|
build_navigator_export(
|
||||||
|
_mock_db(), "coverage",
|
||||||
|
platforms="windows", tactics="execution", min_score=50,
|
||||||
|
)
|
||||||
|
_, kwargs = mock_build.call_args
|
||||||
|
assert kwargs["platforms"] == "windows"
|
||||||
|
assert kwargs["tactics"] == "execution"
|
||||||
|
assert kwargs["min_score"] == 50
|
||||||
|
|
||||||
|
|
||||||
|
# ── Entity-not-found errors ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEntityNotFound:
|
||||||
|
def _db_returning_none(self):
|
||||||
|
db = MagicMock()
|
||||||
|
db.query.return_value.filter.return_value.first.return_value = None
|
||||||
|
return db
|
||||||
|
|
||||||
|
def test_threat_actor_not_found(self):
|
||||||
|
with pytest.raises(EntityNotFoundError, match="ThreatActor"):
|
||||||
|
build_threat_actor_layer(self._db_returning_none(), "bad-id")
|
||||||
|
|
||||||
|
def test_campaign_not_found(self):
|
||||||
|
with pytest.raises(EntityNotFoundError, match="Campaign"):
|
||||||
|
build_campaign_layer(self._db_returning_none(), "bad-id")
|
||||||
Reference in New Issue
Block a user