refactor(threat-actors): extract query/business logic to threat_actor_service, fix N+1 with grouped subqueries
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled

This commit is contained in:
2026-02-19 17:40:00 +01:00
parent 560fc0c9f0
commit 93fde55389
2 changed files with 331 additions and 235 deletions

View File

@@ -7,28 +7,24 @@ threat actor profiles imported from MITRE CTI.
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy import func, or_ from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.models.user import User from app.models.user import User
from app.models.threat_actor import ThreatActor, ThreatActorTechnique from app.services.threat_actor_service import (
from app.models.technique import Technique get_actor_coverage,
from app.models.test import Test get_actor_detail,
from app.models.test_template import TestTemplate get_actor_gaps,
from app.models.enums import TechniqueStatus list_actors,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/threat-actors", tags=["threat-actors"]) router = APIRouter(prefix="/threat-actors", tags=["threat-actors"])
# ---------------------------------------------------------------------------
# GET /threat-actors — Listado con filtros
# ---------------------------------------------------------------------------
@router.get("") @router.get("")
def list_threat_actors( def list_threat_actors(
search: Optional[str] = Query(None), search: Optional[str] = Query(None),
@@ -45,92 +41,17 @@ def list_threat_actors(
**Requires** authentication (any role). **Requires** authentication (any role).
""" """
query = db.query(ThreatActor) return list_actors(
db,
search=search,
country=country,
motivation=motivation,
sophistication=sophistication,
target_sectors=target_sectors,
offset=offset,
limit=limit,
)
# Filters
if search:
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
ThreatActor.name.ilike(pattern),
ThreatActor.description.ilike(pattern),
func.cast(ThreatActor.aliases, func.text()).ilike(pattern),
)
)
if country:
query = query.filter(ThreatActor.country == country)
if motivation:
query = query.filter(ThreatActor.motivation == motivation)
if sophistication:
query = query.filter(ThreatActor.sophistication == sophistication)
if target_sectors:
from app.utils import escape_like
# JSONB contains check
query = query.filter(
func.cast(ThreatActor.target_sectors, func.text()).ilike(f"%{escape_like(target_sectors)}%")
)
# Total count
total = query.count()
# Paginate
actors = query.order_by(ThreatActor.name).offset(offset).limit(limit).all()
# For each actor, count techniques and calculate basic coverage
results = []
for actor in actors:
tech_count = (
db.query(ThreatActorTechnique)
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
.count()
)
# Quick coverage calculation
covered = (
db.query(ThreatActorTechnique)
.join(Technique, ThreatActorTechnique.technique_id == Technique.id)
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
.filter(Technique.status_global.in_([
TechniqueStatus.validated,
TechniqueStatus.partial,
]))
.count()
)
coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0
results.append({
"id": str(actor.id),
"mitre_id": actor.mitre_id,
"name": actor.name,
"aliases": actor.aliases or [],
"country": actor.country,
"target_sectors": actor.target_sectors or [],
"target_regions": actor.target_regions or [],
"motivation": actor.motivation,
"sophistication": actor.sophistication,
"mitre_url": actor.mitre_url,
"technique_count": tech_count,
"coverage_pct": coverage_pct,
"is_active": actor.is_active,
})
return {
"total": total,
"offset": offset,
"limit": limit,
"items": results,
}
# ---------------------------------------------------------------------------
# GET /threat-actors/{id} — Detalle
# ---------------------------------------------------------------------------
@router.get("/{actor_id}") @router.get("/{actor_id}")
def get_threat_actor( def get_threat_actor(
@@ -142,54 +63,8 @@ def get_threat_actor(
**Requires** authentication (any role). **Requires** authentication (any role).
""" """
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() return get_actor_detail(db, actor_id)
if not actor:
raise HTTPException(status_code=404, detail="Threat actor not found")
# Get associated techniques with their coverage status
actor_techniques = (
db.query(ThreatActorTechnique, Technique)
.join(Technique, ThreatActorTechnique.technique_id == Technique.id)
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
.order_by(Technique.mitre_id)
.all()
)
techniques_list = []
for at, tech in actor_techniques:
techniques_list.append({
"technique_id": str(tech.id),
"mitre_id": tech.mitre_id,
"name": tech.name,
"tactic": tech.tactic,
"status_global": tech.status_global.value if tech.status_global else None,
"usage_description": at.usage_description,
"first_seen_using": at.first_seen_using,
})
return {
"id": str(actor.id),
"mitre_id": actor.mitre_id,
"name": actor.name,
"aliases": actor.aliases or [],
"description": actor.description,
"country": actor.country,
"target_sectors": actor.target_sectors or [],
"target_regions": actor.target_regions or [],
"motivation": actor.motivation,
"sophistication": actor.sophistication,
"first_seen": actor.first_seen,
"last_seen": actor.last_seen,
"references": actor.references or [],
"mitre_url": actor.mitre_url,
"is_active": actor.is_active,
"techniques": techniques_list,
}
# ---------------------------------------------------------------------------
# GET /threat-actors/{id}/coverage — Cobertura
# ---------------------------------------------------------------------------
@router.get("/{actor_id}/coverage") @router.get("/{actor_id}/coverage")
def get_threat_actor_coverage( def get_threat_actor_coverage(
@@ -204,49 +79,8 @@ def get_threat_actor_coverage(
Returns the percentage of the actor's techniques that have been Returns the percentage of the actor's techniques that have been
validated or partially validated, along with a breakdown. validated or partially validated, along with a breakdown.
""" """
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() return get_actor_coverage(db, actor_id)
if not actor:
raise HTTPException(status_code=404, detail="Threat actor not found")
# Get all techniques for this actor
actor_techniques = (
db.query(Technique)
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
.all()
)
total = len(actor_techniques)
if total == 0:
return {
"actor_id": str(actor.id),
"actor_name": actor.name,
"total_techniques": 0,
"coverage_pct": 0.0,
"breakdown": {},
}
breakdown = {}
for tech in actor_techniques:
status = tech.status_global.value if tech.status_global else "not_evaluated"
breakdown[status] = breakdown.get(status, 0) + 1
covered = breakdown.get("validated", 0) + breakdown.get("partial", 0)
coverage_pct = round((covered / total * 100), 1)
return {
"actor_id": str(actor.id),
"actor_name": actor.name,
"total_techniques": total,
"covered": covered,
"coverage_pct": coverage_pct,
"breakdown": breakdown,
}
# ---------------------------------------------------------------------------
# GET /threat-actors/{id}/gaps — Gap analysis
# ---------------------------------------------------------------------------
@router.get("/{actor_id}/gaps") @router.get("/{actor_id}/gaps")
def get_threat_actor_gaps( def get_threat_actor_gaps(
@@ -260,52 +94,4 @@ def get_threat_actor_gaps(
Returns list of gap techniques with available templates. Returns list of gap techniques with available templates.
""" """
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() return get_actor_gaps(db, actor_id)
if not actor:
raise HTTPException(status_code=404, detail="Threat actor not found")
# Get techniques NOT validated
gap_techniques = (
db.query(Technique, ThreatActorTechnique)
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
.filter(Technique.status_global != TechniqueStatus.validated)
.order_by(Technique.mitre_id)
.all()
)
gaps = []
for tech, at in gap_techniques:
# Count available templates for this technique
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.mitre_technique_id == tech.mitre_id)
.filter(TestTemplate.is_active == True)
.count()
)
# Count existing tests
test_count = (
db.query(Test)
.filter(Test.technique_id == tech.id)
.count()
)
gaps.append({
"technique_id": str(tech.id),
"mitre_id": tech.mitre_id,
"name": tech.name,
"tactic": tech.tactic,
"status_global": tech.status_global.value if tech.status_global else None,
"usage_description": at.usage_description,
"available_templates": template_count,
"existing_tests": test_count,
"has_templates": template_count > 0,
})
return {
"actor_id": str(actor.id),
"actor_name": actor.name,
"total_gaps": len(gaps),
"gaps": gaps,
}

View File

@@ -0,0 +1,310 @@
"""Threat actor data service.
Extracts query and business logic from the threat_actors router so
that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations

from typing import Any

from sqlalchemy import Text, case, cast, func, or_
from sqlalchemy.orm import Session

from app.domain.errors import EntityNotFoundError
from app.models.enums import TechniqueStatus
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.utils import escape_like
# ── Public service functions ──────────────────────────────────────────
def list_actors(
    db: Session,
    *,
    search: str | None = None,
    country: str | None = None,
    motivation: str | None = None,
    sophistication: str | None = None,
    target_sectors: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> dict[str, Any]:
    """List threat actors with optional filters, pagination, and coverage stats.

    Technique totals and covered-technique totals for the whole page are
    fetched with one grouped query rather than two queries per actor.

    Args:
        db: Active SQLAlchemy session.
        search: Case-insensitive substring matched against the actor name,
            the description, and the text rendering of ``aliases``.
        country: Exact match on ``ThreatActor.country``.
        motivation: Exact match on ``ThreatActor.motivation``.
        sophistication: Exact match on ``ThreatActor.sophistication``.
        target_sectors: Case-insensitive substring matched against the text
            rendering of ``ThreatActor.target_sectors``.
        offset: Number of rows to skip (rows are ordered by actor name).
        limit: Maximum number of rows to return.

    Returns:
        A dict with ``total`` (count of actors matching the filters before
        pagination), the echoed ``offset`` and ``limit``, and ``items``, a
        list of per-actor summary dicts including ``technique_count`` and
        ``coverage_pct``.
    """
    query = db.query(ThreatActor)

    if search:
        pattern = f"%{escape_like(search)}%"
        query = query.filter(
            or_(
                ThreatActor.name.ilike(pattern),
                ThreatActor.description.ilike(pattern),
                # CAST(aliases AS TEXT). The type argument must be a
                # SQLAlchemy type such as ``Text``; ``func.text()`` is a SQL
                # function call and is not a valid cast target.
                cast(ThreatActor.aliases, Text).ilike(pattern),
            )
        )
    if country:
        query = query.filter(ThreatActor.country == country)
    if motivation:
        query = query.filter(ThreatActor.motivation == motivation)
    if sophistication:
        query = query.filter(ThreatActor.sophistication == sophistication)
    if target_sectors:
        # Substring match on the column's text rendering.
        query = query.filter(
            cast(ThreatActor.target_sectors, Text).ilike(
                f"%{escape_like(target_sectors)}%"
            )
        )

    # Count before pagination so ``total`` reflects every matching actor.
    total = query.count()
    actors = query.order_by(ThreatActor.name).offset(offset).limit(limit).all()

    if not actors:
        return {
            "total": total,
            "offset": offset,
            "limit": limit,
            "items": [],
        }

    actor_ids = [actor.id for actor in actors]

    # 1 for a technique whose global status is validated or partial, else 0;
    # summing this per actor yields the number of covered techniques.
    covered_flag = case(
        (
            Technique.status_global.in_([
                TechniqueStatus.validated,
                TechniqueStatus.partial,
            ]),
            1,
        ),
        else_=0,
    )

    # Single grouped query: tech_count and covered_count per actor on the page.
    counts_rows = (
        db.query(
            ThreatActorTechnique.threat_actor_id,
            func.count(ThreatActorTechnique.id).label("tech_count"),
            func.sum(covered_flag).label("covered_count"),
        )
        .join(Technique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id.in_(actor_ids))
        .group_by(ThreatActorTechnique.threat_actor_id)
        .all()
    )

    # Keyed by str(id) so lookups do not depend on the driver's id type.
    # SUM may come back as None or a non-int numeric; normalise to int.
    counts_map = {
        str(row.threat_actor_id): {
            "tech_count": row.tech_count,
            "covered_count": int(row.covered_count or 0),
        }
        for row in counts_rows
    }
    no_techniques = {"tech_count": 0, "covered_count": 0}

    results = []
    for actor in actors:
        # Actors with no linked techniques have no row in the grouped query.
        cnt = counts_map.get(str(actor.id), no_techniques)
        tech_count = cnt["tech_count"]
        covered = cnt["covered_count"]
        coverage_pct = (
            round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0
        )
        results.append({
            "id": str(actor.id),
            "mitre_id": actor.mitre_id,
            "name": actor.name,
            "aliases": actor.aliases or [],
            "country": actor.country,
            "target_sectors": actor.target_sectors or [],
            "target_regions": actor.target_regions or [],
            "motivation": actor.motivation,
            "sophistication": actor.sophistication,
            "mitre_url": actor.mitre_url,
            "technique_count": tech_count,
            "coverage_pct": coverage_pct,
            "is_active": actor.is_active,
        })

    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": results,
    }
def get_actor_detail(db: Session, actor_id: str) -> dict[str, Any]:
    """Return the full profile of one threat actor plus its linked techniques.

    Args:
        db: Active SQLAlchemy session.
        actor_id: Identifier compared against ``ThreatActor.id``.

    Returns:
        A dict of the actor's attributes with a ``techniques`` list; each
        entry combines technique fields with the actor-specific usage
        fields from the association row, ordered by technique MITRE id.

    Raises:
        EntityNotFoundError: If no actor with ``actor_id`` exists.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("Threat actor", actor_id)

    # Each row is an (association, technique) pair for this actor.
    link_rows = (
        db.query(ThreatActorTechnique, Technique)
        .join(Technique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .order_by(Technique.mitre_id)
        .all()
    )

    techniques_list = []
    for link, technique in link_rows:
        status = technique.status_global
        techniques_list.append({
            "technique_id": str(technique.id),
            "mitre_id": technique.mitre_id,
            "name": technique.name,
            "tactic": technique.tactic,
            # A falsy status is reported as None rather than dereferenced.
            "status_global": status.value if status else None,
            "usage_description": link.usage_description,
            "first_seen_using": link.first_seen_using,
        })

    profile = {
        "id": str(actor.id),
        "mitre_id": actor.mitre_id,
        "name": actor.name,
        "aliases": actor.aliases or [],
        "description": actor.description,
        "country": actor.country,
        "target_sectors": actor.target_sectors or [],
        "target_regions": actor.target_regions or [],
        "motivation": actor.motivation,
        "sophistication": actor.sophistication,
        "first_seen": actor.first_seen,
        "last_seen": actor.last_seen,
        "references": actor.references or [],
        "mitre_url": actor.mitre_url,
        "is_active": actor.is_active,
        "techniques": techniques_list,
    }
    return profile
def get_actor_coverage(db: Session, actor_id: str) -> dict[str, Any]:
    """Calculate coverage percentage against a specific threat actor.

    A technique counts as covered when its global status value is
    ``"validated"`` or ``"partial"``.

    Args:
        db: Active SQLAlchemy session.
        actor_id: Identifier compared against ``ThreatActor.id``.

    Returns:
        A dict with ``actor_id``, ``actor_name``, ``total_techniques``,
        ``covered``, ``coverage_pct`` (rounded to one decimal) and
        ``breakdown`` (technique count per status value). The same keys are
        returned when the actor has no techniques: the counts are 0,
        ``coverage_pct`` is 0.0 and ``breakdown`` is empty.

    Raises:
        EntityNotFoundError: If no actor with ``actor_id`` exists.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("Threat actor", actor_id)

    actor_techniques = (
        db.query(Technique)
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .all()
    )
    total = len(actor_techniques)

    # Tally techniques per status; a missing status is bucketed separately.
    breakdown: dict[str, int] = {}
    for tech in actor_techniques:
        status = tech.status_global.value if tech.status_global else "not_evaluated"
        breakdown[status] = breakdown.get(status, 0) + 1

    covered = breakdown.get("validated", 0) + breakdown.get("partial", 0)
    # Guard the division so an actor without techniques reports 0.0.
    coverage_pct = round((covered / total * 100), 1) if total else 0.0

    # One return path keeps the payload shape identical for empty and
    # non-empty actors (``covered`` is always present).
    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_techniques": total,
        "covered": covered,
        "coverage_pct": coverage_pct,
        "breakdown": breakdown,
    }
def get_actor_gaps(db: Session, actor_id: str) -> dict[str, Any]:
    """Identify techniques of this actor that are not fully validated.

    Template and test counts for all gap techniques are fetched with two
    grouped queries rather than two queries per technique.

    Args:
        db: Active SQLAlchemy session.
        actor_id: Identifier compared against ``ThreatActor.id``.

    Returns:
        A dict with ``actor_id``, ``actor_name``, ``total_gaps`` and
        ``gaps``: one entry per technique whose global status is anything
        other than validated (including no status at all), ordered by MITRE
        id, with the number of active templates and existing tests.

    Raises:
        EntityNotFoundError: If no actor with ``actor_id`` exists.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("Threat actor", actor_id)

    gap_techniques = (
        db.query(Technique, ThreatActorTechnique)
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        # In SQL ``NULL != 'validated'`` evaluates to NULL, which would drop
        # techniques that have no status yet; include them explicitly.
        .filter(
            or_(
                Technique.status_global.is_(None),
                Technique.status_global != TechniqueStatus.validated,
            )
        )
        .order_by(Technique.mitre_id)
        .all()
    )

    if not gap_techniques:
        return {
            "actor_id": str(actor.id),
            "actor_name": actor.name,
            "total_gaps": 0,
            "gaps": [],
        }

    technique_ids = [tech.id for tech, _ in gap_techniques]
    mitre_ids = [tech.mitre_id for tech, _ in gap_techniques]

    # Batch template counts by mitre_technique_id (active templates only).
    template_counts = (
        db.query(
            TestTemplate.mitre_technique_id,
            func.count(TestTemplate.id).label("cnt"),
        )
        .filter(TestTemplate.mitre_technique_id.in_(mitre_ids))
        .filter(TestTemplate.is_active.is_(True))
        .group_by(TestTemplate.mitre_technique_id)
        .all()
    )
    template_map = {row.mitre_technique_id: row.cnt for row in template_counts}

    # Batch test counts by technique_id; keyed by str(id) so lookups do not
    # depend on the driver's id type.
    test_counts = (
        db.query(Test.technique_id, func.count(Test.id).label("cnt"))
        .filter(Test.technique_id.in_(technique_ids))
        .group_by(Test.technique_id)
        .all()
    )
    test_map = {str(row.technique_id): row.cnt for row in test_counts}

    gaps = []
    for tech, at in gap_techniques:
        # Techniques absent from a grouped result have zero matching rows.
        template_count = template_map.get(tech.mitre_id, 0)
        test_count = test_map.get(str(tech.id), 0)
        gaps.append({
            "technique_id": str(tech.id),
            "mitre_id": tech.mitre_id,
            "name": tech.name,
            "tactic": tech.tactic,
            "status_global": tech.status_global.value if tech.status_global else None,
            "usage_description": at.usage_description,
            "available_templates": template_count,
            "existing_tests": test_count,
            "has_templates": template_count > 0,
        })

    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_gaps": len(gaps),
        "gaps": gaps,
    }