refactor(heatmap): extract business logic to dedicated service
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled

- Create heatmap_service.py with all layer-building logic (coverage, threat-actor, detection-rules, campaign)

- Service is framework-agnostic: no FastAPI imports, no HTTPException, no db.commit()

- Fix N+1 in coverage and threat-actor layers: bulk-fetch test_counts and rule_counts with GROUP BY

- Router reduced from 528 to 140 lines: validates request, calls service, returns response
This commit is contained in:
2026-02-18 13:14:41 +01:00
parent bfce1a8a0e
commit 6147abc87a
2 changed files with 492 additions and 425 deletions

View File

@@ -1,157 +1,23 @@
"""Heatmap endpoints — ATT&CK Navigator-compatible layer generation.
Provides multiple layer types (coverage, threat actor, detection rules,
campaign) and an export endpoint that produces a JSON file importable
by the official MITRE ATT&CK Navigator.
Thin router that delegates to :mod:`app.services.heatmap_service`.
"""
from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import func
from sqlalchemy.orm import Session
import io
import json
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.models.user import User
from app.models.technique import Technique
from app.models.test import Test
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.detection_rule import DetectionRule
from app.models.campaign import Campaign, CampaignTest
from app.models.defensive_technique import DefensiveTechniqueMapping
from app.models.enums import TechniqueStatus, TestState
from app.services import heatmap_service
router = APIRouter(prefix="/heatmap", tags=["heatmap"])
# ── Constants ─────────────────────────────────────────────────────────
ATTACK_VERSION = "15"
NAVIGATOR_VERSION = "5.0"
LAYER_VERSION = "4.5"
DOMAIN = "enterprise-attack"
# Score mapping for technique status_global
STATUS_SCORE_MAP = {
TechniqueStatus.validated: 100,
TechniqueStatus.partial: 60,
TechniqueStatus.in_progress: 30,
TechniqueStatus.not_covered: 10,
TechniqueStatus.not_evaluated: 0,
TechniqueStatus.review_required: 10,
}
# ── Helpers ───────────────────────────────────────────────────────────
def _score_to_color(score: int) -> str:
"""Map a 0-100 score to a red → yellow → green color hex."""
if score <= 0:
return "#d3d3d3" # gray for not evaluated
if score <= 25:
return "#ff6666" # red
if score <= 50:
return "#ff9933" # orange
if score <= 75:
return "#ffff66" # yellow
return "#66ff66" # green
def _build_layer_skeleton(
name: str,
description: str,
gradient_colors: List[str] | None = None,
) -> dict:
"""Return a base layer dict compatible with ATT&CK Navigator."""
return {
"name": name,
"versions": {
"attack": ATTACK_VERSION,
"navigator": NAVIGATOR_VERSION,
"layer": LAYER_VERSION,
},
"domain": DOMAIN,
"description": description,
"filters": {"platforms": ["windows", "linux", "macos"]},
"gradient": {
"colors": gradient_colors or ["#ff6666", "#ffff66", "#66ff66"],
"minValue": 0,
"maxValue": 100,
},
"techniques": [],
}
def _apply_filters(
query,
model,
platforms: Optional[List[str]] = None,
tactics: Optional[List[str]] = None,
):
"""Apply common platform and tactic filters to a technique query."""
if platforms:
from sqlalchemy import or_, cast, String
from sqlalchemy.dialects.postgresql import JSONB
# Filter techniques that have any of the specified platforms
platform_filters = []
for platform in platforms:
platform_filters.append(
model.platforms.op("@>")(json.dumps([platform]))
)
if platform_filters:
query = query.filter(or_(*platform_filters))
if tactics:
from sqlalchemy import or_
from app.utils import escape_like
tactic_filters = []
for tactic in tactics:
tactic_filters.append(model.tactic.ilike(f"%{escape_like(tactic)}%"))
query = query.filter(or_(*tactic_filters))
return query
def _format_tactic(tactic_str: str | None) -> str:
"""Normalize tactic string to ATT&CK Navigator format (kebab-case)."""
if not tactic_str:
return ""
# Take first tactic if comma-separated
first = tactic_str.split(",")[0].strip().lower()
return first
def _get_technique_metadata(technique, db: Session) -> list:
"""Build metadata array for a technique."""
# Count validated tests
test_count = (
db.query(func.count(Test.id))
.filter(Test.technique_id == technique.id, Test.state == TestState.validated)
.scalar()
) or 0
# Count detection rules
rule_count = (
db.query(func.count(DetectionRule.id))
.filter(DetectionRule.mitre_technique_id == technique.mitre_id)
.scalar()
) or 0
metadata = [
{"name": "tests_count", "value": str(test_count)},
{"name": "detection_rules", "value": str(rule_count)},
]
if technique.last_review_date:
metadata.append(
{"name": "last_validated", "value": technique.last_review_date.strftime("%Y-%m-%d")}
)
return metadata
# ── GET /heatmap/coverage ─────────────────────────────────────────────
@@ -165,43 +31,9 @@ def heatmap_coverage(
current_user: User = Depends(get_current_user),
):
"""Coverage layer — score based on status_global of each technique."""
layer = _build_layer_skeleton("Aegis Coverage", "Coverage layer generated by Aegis")
query = db.query(Technique)
platform_list = [p.strip() for p in platforms.split(",")] if platforms else None
tactic_list = [t.strip() for t in tactics.split(",")] if tactics else None
query = _apply_filters(query, Technique, platform_list, tactic_list)
techniques = query.all()
for tech in techniques:
score = STATUS_SCORE_MAP.get(tech.status_global, 0)
if score < min_score:
continue
comment_parts = [f"Status: {tech.status_global.value}"]
metadata = _get_technique_metadata(tech, db)
# Enrich comment with test/rule info
tests_info = next((m for m in metadata if m["name"] == "tests_count"), None)
rules_info = next((m for m in metadata if m["name"] == "detection_rules"), None)
if tests_info:
comment_parts.append(f"{tests_info['value']} tests validated")
if rules_info:
comment_parts.append(f"{rules_info['value']} detection rules")
layer["techniques"].append({
"techniqueID": tech.mitre_id,
"tactic": _format_tactic(tech.tactic),
"color": _score_to_color(score),
"score": score,
"comment": " - ".join(comment_parts),
"enabled": True,
"metadata": metadata,
})
return layer
return heatmap_service.build_coverage_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score,
)
# ── GET /heatmap/threat-actor/{actor_id} ──────────────────────────────
@@ -217,62 +49,11 @@ def heatmap_threat_actor(
current_user: User = Depends(get_current_user),
):
"""Threat actor layer — techniques used by an actor with coverage color."""
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
if not actor:
layer = heatmap_service.build_threat_actor_layer(
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
)
if layer is None:
raise HTTPException(status_code=404, detail="Threat actor not found")
layer = _build_layer_skeleton(
f"Threat Actor: {actor.name}",
f"Techniques used by {actor.name} with coverage overlay",
gradient_colors=["#808080", "#ff6666", "#66ff66"],
)
# Get actor's technique IDs
actor_technique_rows = (
db.query(ThreatActorTechnique)
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
.all()
)
actor_technique_ids = {row.technique_id for row in actor_technique_rows}
if not actor_technique_ids:
return layer
query = db.query(Technique)
platform_list = [p.strip() for p in platforms.split(",")] if platforms else None
tactic_list = [t.strip() for t in tactics.split(",")] if tactics else None
query = _apply_filters(query, Technique, platform_list, tactic_list)
techniques = query.all()
for tech in techniques:
is_actor_technique = tech.id in actor_technique_ids
score = STATUS_SCORE_MAP.get(tech.status_global, 0) if is_actor_technique else 0
if is_actor_technique and score < min_score:
continue
if is_actor_technique:
metadata = _get_technique_metadata(tech, db)
layer["techniques"].append({
"techniqueID": tech.mitre_id,
"tactic": _format_tactic(tech.tactic),
"color": _score_to_color(score),
"score": score,
"comment": f"Used by {actor.name} - Coverage: {tech.status_global.value}",
"enabled": True,
"metadata": metadata,
})
else:
layer["techniques"].append({
"techniqueID": tech.mitre_id,
"tactic": _format_tactic(tech.tactic),
"color": "",
"score": 0,
"comment": "",
"enabled": False,
"metadata": [],
})
return layer
@@ -288,76 +69,10 @@ def heatmap_detection_rules(
current_user: User = Depends(get_current_user),
):
"""Detection rules layer — score based on ratio of rules available vs total."""
layer = _build_layer_skeleton(
"Detection Rules Coverage",
"Coverage of detection rules per technique",
return heatmap_service.build_detection_rules_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score,
)
query = db.query(Technique)
platform_list = [p.strip() for p in platforms.split(",")] if platforms else None
tactic_list = [t.strip() for t in tactics.split(",")] if tactics else None
query = _apply_filters(query, Technique, platform_list, tactic_list)
techniques = query.all()
# Get rule counts per technique_mitre_id in one query
rule_counts = dict(
db.query(
DetectionRule.mitre_technique_id,
func.count(DetectionRule.id),
)
.filter(DetectionRule.is_active == True)
.group_by(DetectionRule.mitre_technique_id)
.all()
)
# Find the max rule count for normalization
max_rules = max(rule_counts.values()) if rule_counts else 1
from app.models.test_detection_result import TestDetectionResult
# Get evaluated rule counts per technique
evaluated_counts_raw = (
db.query(
DetectionRule.mitre_technique_id,
func.count(TestDetectionResult.id),
)
.join(TestDetectionResult, TestDetectionResult.detection_rule_id == DetectionRule.id)
.filter(TestDetectionResult.triggered.isnot(None))
.group_by(DetectionRule.mitre_technique_id)
.all()
)
evaluated_counts = dict(evaluated_counts_raw)
for tech in techniques:
total_rules = rule_counts.get(tech.mitre_id, 0)
evaluated_rules = evaluated_counts.get(tech.mitre_id, 0)
if total_rules > 0:
# Score based on rule availability (normalized) and evaluation ratio
availability_score = min((total_rules / max_rules) * 50, 50)
evaluation_score = (evaluated_rules / total_rules) * 50 if total_rules > 0 else 0
score = int(min(availability_score + evaluation_score, 100))
else:
score = 0
if score < min_score:
continue
layer["techniques"].append({
"techniqueID": tech.mitre_id,
"tactic": _format_tactic(tech.tactic),
"color": _score_to_color(score),
"score": score,
"comment": f"{total_rules} rules available, {evaluated_rules} evaluated",
"enabled": True,
"metadata": [
{"name": "total_rules", "value": str(total_rules)},
{"name": "evaluated_rules", "value": str(evaluated_rules)},
],
})
return layer
# ── GET /heatmap/campaign/{campaign_id} ───────────────────────────────
@@ -372,107 +87,26 @@ def heatmap_campaign(
current_user: User = Depends(get_current_user),
):
"""Campaign layer — only techniques in the campaign, colored by test state."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
layer = heatmap_service.build_campaign_layer(
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
)
if layer is None:
raise HTTPException(status_code=404, detail="Campaign not found")
layer = _build_layer_skeleton(
f"Campaign: {campaign.name}",
f"Progress of campaign '{campaign.name}'",
)
# Get campaign tests with their associated techniques
campaign_tests = (
db.query(CampaignTest)
.filter(CampaignTest.campaign_id == campaign.id)
.all()
)
if not campaign_tests:
return layer
# Map test_id -> test for all tests in campaign
test_ids = [ct.test_id for ct in campaign_tests]
tests = db.query(Test).filter(Test.id.in_(test_ids)).all()
test_map = {t.id: t for t in tests}
# Map technique_id -> technique
technique_ids = {t.technique_id for t in tests if t.technique_id}
techniques = db.query(Technique).filter(Technique.id.in_(technique_ids)).all()
tech_map = {t.id: t for t in techniques}
# Score mapping for test states
test_state_score = {
TestState.validated: 100,
TestState.in_review: 70,
TestState.blue_evaluating: 50,
TestState.red_executing: 30,
TestState.draft: 10,
TestState.rejected: 5,
}
# Group by technique (a technique may have multiple tests in a campaign)
tech_scores: dict = {}
for ct in campaign_tests:
test = test_map.get(ct.test_id)
if not test:
continue
tech = tech_map.get(test.technique_id)
if not tech:
continue
state_score = test_state_score.get(test.state, 0)
if tech.mitre_id not in tech_scores:
tech_scores[tech.mitre_id] = {
"technique": tech,
"max_score": state_score,
"tests": [],
}
else:
tech_scores[tech.mitre_id]["max_score"] = max(
tech_scores[tech.mitre_id]["max_score"], state_score
)
tech_scores[tech.mitre_id]["tests"].append(test)
platform_list = [p.strip() for p in platforms.split(",")] if platforms else None
tactic_list = [t.strip() for t in tactics.split(",")] if tactics else None
for mitre_id, info in tech_scores.items():
tech = info["technique"]
score = info["max_score"]
# Apply filters
if platform_list:
tech_platforms = tech.platforms or []
if not any(p in tech_platforms for p in platform_list):
continue
if tactic_list:
tech_tactics = (tech.tactic or "").lower().split(",")
tech_tactics = [t.strip() for t in tech_tactics]
if not any(t in tech_tactics for t in tactic_list):
continue
if score < min_score:
continue
test_states = [t.state.value for t in info["tests"]]
layer["techniques"].append({
"techniqueID": mitre_id,
"tactic": _format_tactic(tech.tactic),
"color": _score_to_color(score),
"score": score,
"comment": f"Campaign tests: {', '.join(test_states)}",
"enabled": True,
"metadata": [
{"name": "campaign_tests", "value": str(len(info["tests"]))},
{"name": "best_state", "value": max(test_states) if test_states else "none"},
],
})
return layer
# ── GET /heatmap/export-navigator ─────────────────────────────────────
_LAYER_BUILDERS = {
"coverage": lambda db, **kw: heatmap_service.build_coverage_layer(db, **kw),
"detection-rules": lambda db, **kw: heatmap_service.build_detection_rules_layer(db, **kw),
}
_LAYER_BUILDERS_WITH_ID = {
"threat-actor": lambda db, lid, **kw: heatmap_service.build_threat_actor_layer(db, lid, **kw),
"campaign": lambda db, lid, **kw: heatmap_service.build_campaign_layer(db, lid, **kw),
}
@router.get("/export-navigator")
def export_navigator(
@@ -485,43 +119,24 @@ def export_navigator(
current_user: User = Depends(get_current_user),
):
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
# Delegate to the appropriate layer endpoint
if layer == "coverage":
data = heatmap_coverage(
platforms=platforms, tactics=tactics, min_score=min_score,
db=db, current_user=current_user,
)
elif layer == "threat-actor":
kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score)
if layer in _LAYER_BUILDERS:
data = _LAYER_BUILDERS[layer](db, **kwargs)
elif layer in _LAYER_BUILDERS_WITH_ID:
if not layer_id:
raise HTTPException(status_code=400, detail="layer_id required for threat-actor layer")
data = heatmap_threat_actor(
actor_id=layer_id, platforms=platforms, tactics=tactics,
min_score=min_score, db=db, current_user=current_user,
)
elif layer == "detection-rules":
data = heatmap_detection_rules(
platforms=platforms, tactics=tactics, min_score=min_score,
db=db, current_user=current_user,
)
elif layer == "campaign":
if not layer_id:
raise HTTPException(status_code=400, detail="layer_id required for campaign layer")
data = heatmap_campaign(
campaign_id=layer_id, platforms=platforms, tactics=tactics,
min_score=min_score, db=db, current_user=current_user,
)
raise HTTPException(status_code=400, detail=f"layer_id required for {layer} layer")
data = _LAYER_BUILDERS_WITH_ID[layer](db, layer_id, **kwargs)
if data is None:
raise HTTPException(status_code=404, detail=f"{layer} not found")
else:
raise HTTPException(status_code=400, detail=f"Unknown layer type: {layer}")
# Convert to JSON and return as downloadable file
json_content = json.dumps(data, indent=2, default=str)
buffer = io.BytesIO(json_content.encode("utf-8"))
filename = f"aegis_{layer}_layer.json"
return StreamingResponse(
buffer,
media_type="application/json",
headers={
"Content-Disposition": f"attachment; filename={filename}",
},
headers={"Content-Disposition": f"attachment; filename=aegis_{layer}_layer.json"},
)