refactor(threat-actors): extract query/business logic to threat_actor_service, fix N+1 with grouped subqueries
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
This commit is contained in:
@@ -7,28 +7,24 @@ threat actor profiles imported from MITRE CTI.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.enums import TechniqueStatus
|
||||
from app.services.threat_actor_service import (
|
||||
get_actor_coverage,
|
||||
get_actor_detail,
|
||||
get_actor_gaps,
|
||||
list_actors,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/threat-actors", tags=["threat-actors"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /threat-actors — Listado con filtros
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("")
|
||||
def list_threat_actors(
|
||||
search: Optional[str] = Query(None),
|
||||
@@ -45,92 +41,17 @@ def list_threat_actors(
|
||||
|
||||
**Requires** authentication (any role).
|
||||
"""
|
||||
query = db.query(ThreatActor)
|
||||
|
||||
# Filters
|
||||
if search:
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
ThreatActor.name.ilike(pattern),
|
||||
ThreatActor.description.ilike(pattern),
|
||||
func.cast(ThreatActor.aliases, func.text()).ilike(pattern),
|
||||
)
|
||||
return list_actors(
|
||||
db,
|
||||
search=search,
|
||||
country=country,
|
||||
motivation=motivation,
|
||||
sophistication=sophistication,
|
||||
target_sectors=target_sectors,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if country:
|
||||
query = query.filter(ThreatActor.country == country)
|
||||
|
||||
if motivation:
|
||||
query = query.filter(ThreatActor.motivation == motivation)
|
||||
|
||||
if sophistication:
|
||||
query = query.filter(ThreatActor.sophistication == sophistication)
|
||||
|
||||
if target_sectors:
|
||||
from app.utils import escape_like
|
||||
# JSONB contains check
|
||||
query = query.filter(
|
||||
func.cast(ThreatActor.target_sectors, func.text()).ilike(f"%{escape_like(target_sectors)}%")
|
||||
)
|
||||
|
||||
# Total count
|
||||
total = query.count()
|
||||
|
||||
# Paginate
|
||||
actors = query.order_by(ThreatActor.name).offset(offset).limit(limit).all()
|
||||
|
||||
# For each actor, count techniques and calculate basic coverage
|
||||
results = []
|
||||
for actor in actors:
|
||||
tech_count = (
|
||||
db.query(ThreatActorTechnique)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.count()
|
||||
)
|
||||
|
||||
# Quick coverage calculation
|
||||
covered = (
|
||||
db.query(ThreatActorTechnique)
|
||||
.join(Technique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.filter(Technique.status_global.in_([
|
||||
TechniqueStatus.validated,
|
||||
TechniqueStatus.partial,
|
||||
]))
|
||||
.count()
|
||||
)
|
||||
|
||||
coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0
|
||||
|
||||
results.append({
|
||||
"id": str(actor.id),
|
||||
"mitre_id": actor.mitre_id,
|
||||
"name": actor.name,
|
||||
"aliases": actor.aliases or [],
|
||||
"country": actor.country,
|
||||
"target_sectors": actor.target_sectors or [],
|
||||
"target_regions": actor.target_regions or [],
|
||||
"motivation": actor.motivation,
|
||||
"sophistication": actor.sophistication,
|
||||
"mitre_url": actor.mitre_url,
|
||||
"technique_count": tech_count,
|
||||
"coverage_pct": coverage_pct,
|
||||
"is_active": actor.is_active,
|
||||
})
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": results,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /threat-actors/{id} — Detalle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{actor_id}")
|
||||
def get_threat_actor(
|
||||
@@ -142,54 +63,8 @@ def get_threat_actor(
|
||||
|
||||
**Requires** authentication (any role).
|
||||
"""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
||||
return get_actor_detail(db, actor_id)
|
||||
|
||||
# Get associated techniques with their coverage status
|
||||
actor_techniques = (
|
||||
db.query(ThreatActorTechnique, Technique)
|
||||
.join(Technique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.order_by(Technique.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
techniques_list = []
|
||||
for at, tech in actor_techniques:
|
||||
techniques_list.append({
|
||||
"technique_id": str(tech.id),
|
||||
"mitre_id": tech.mitre_id,
|
||||
"name": tech.name,
|
||||
"tactic": tech.tactic,
|
||||
"status_global": tech.status_global.value if tech.status_global else None,
|
||||
"usage_description": at.usage_description,
|
||||
"first_seen_using": at.first_seen_using,
|
||||
})
|
||||
|
||||
return {
|
||||
"id": str(actor.id),
|
||||
"mitre_id": actor.mitre_id,
|
||||
"name": actor.name,
|
||||
"aliases": actor.aliases or [],
|
||||
"description": actor.description,
|
||||
"country": actor.country,
|
||||
"target_sectors": actor.target_sectors or [],
|
||||
"target_regions": actor.target_regions or [],
|
||||
"motivation": actor.motivation,
|
||||
"sophistication": actor.sophistication,
|
||||
"first_seen": actor.first_seen,
|
||||
"last_seen": actor.last_seen,
|
||||
"references": actor.references or [],
|
||||
"mitre_url": actor.mitre_url,
|
||||
"is_active": actor.is_active,
|
||||
"techniques": techniques_list,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /threat-actors/{id}/coverage — Cobertura
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{actor_id}/coverage")
|
||||
def get_threat_actor_coverage(
|
||||
@@ -204,49 +79,8 @@ def get_threat_actor_coverage(
|
||||
Returns the percentage of the actor's techniques that have been
|
||||
validated or partially validated, along with a breakdown.
|
||||
"""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
||||
return get_actor_coverage(db, actor_id)
|
||||
|
||||
# Get all techniques for this actor
|
||||
actor_techniques = (
|
||||
db.query(Technique)
|
||||
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
total = len(actor_techniques)
|
||||
if total == 0:
|
||||
return {
|
||||
"actor_id": str(actor.id),
|
||||
"actor_name": actor.name,
|
||||
"total_techniques": 0,
|
||||
"coverage_pct": 0.0,
|
||||
"breakdown": {},
|
||||
}
|
||||
|
||||
breakdown = {}
|
||||
for tech in actor_techniques:
|
||||
status = tech.status_global.value if tech.status_global else "not_evaluated"
|
||||
breakdown[status] = breakdown.get(status, 0) + 1
|
||||
|
||||
covered = breakdown.get("validated", 0) + breakdown.get("partial", 0)
|
||||
coverage_pct = round((covered / total * 100), 1)
|
||||
|
||||
return {
|
||||
"actor_id": str(actor.id),
|
||||
"actor_name": actor.name,
|
||||
"total_techniques": total,
|
||||
"covered": covered,
|
||||
"coverage_pct": coverage_pct,
|
||||
"breakdown": breakdown,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /threat-actors/{id}/gaps — Gap analysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{actor_id}/gaps")
|
||||
def get_threat_actor_gaps(
|
||||
@@ -260,52 +94,4 @@ def get_threat_actor_gaps(
|
||||
|
||||
Returns list of gap techniques with available templates.
|
||||
"""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
||||
|
||||
# Get techniques NOT validated
|
||||
gap_techniques = (
|
||||
db.query(Technique, ThreatActorTechnique)
|
||||
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.filter(Technique.status_global != TechniqueStatus.validated)
|
||||
.order_by(Technique.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
gaps = []
|
||||
for tech, at in gap_techniques:
|
||||
# Count available templates for this technique
|
||||
template_count = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.mitre_technique_id == tech.mitre_id)
|
||||
.filter(TestTemplate.is_active == True)
|
||||
.count()
|
||||
)
|
||||
|
||||
# Count existing tests
|
||||
test_count = (
|
||||
db.query(Test)
|
||||
.filter(Test.technique_id == tech.id)
|
||||
.count()
|
||||
)
|
||||
|
||||
gaps.append({
|
||||
"technique_id": str(tech.id),
|
||||
"mitre_id": tech.mitre_id,
|
||||
"name": tech.name,
|
||||
"tactic": tech.tactic,
|
||||
"status_global": tech.status_global.value if tech.status_global else None,
|
||||
"usage_description": at.usage_description,
|
||||
"available_templates": template_count,
|
||||
"existing_tests": test_count,
|
||||
"has_templates": template_count > 0,
|
||||
})
|
||||
|
||||
return {
|
||||
"actor_id": str(actor.id),
|
||||
"actor_name": actor.name,
|
||||
"total_gaps": len(gaps),
|
||||
"gaps": gaps,
|
||||
}
|
||||
return get_actor_gaps(db, actor_id)
|
||||
|
||||
310
backend/app/services/threat_actor_service.py
Normal file
310
backend/app/services/threat_actor_service.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Threat actor data service.
|
||||
|
||||
Extracts query and business logic from the threat_actors router so
|
||||
that the router remains a thin HTTP adapter.
|
||||
|
||||
This module is framework-agnostic: no FastAPI imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import case, func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.enums import TechniqueStatus
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.technique import Technique
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
# ── Public service functions ──────────────────────────────────────────
|
||||
|
||||
|
||||
def list_actors(
|
||||
db: Session,
|
||||
*,
|
||||
search: str | None = None,
|
||||
country: str | None = None,
|
||||
motivation: str | None = None,
|
||||
sophistication: str | None = None,
|
||||
target_sectors: str | None = None,
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> dict[str, Any]:
|
||||
"""List threat actors with optional filters, pagination, and coverage stats.
|
||||
|
||||
Uses grouped subqueries to avoid N+1: technique counts and coverage
|
||||
counts are fetched in one query per page.
|
||||
"""
|
||||
query = db.query(ThreatActor)
|
||||
|
||||
if search:
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
ThreatActor.name.ilike(pattern),
|
||||
ThreatActor.description.ilike(pattern),
|
||||
func.cast(ThreatActor.aliases, func.text()).ilike(pattern),
|
||||
)
|
||||
)
|
||||
|
||||
if country:
|
||||
query = query.filter(ThreatActor.country == country)
|
||||
|
||||
if motivation:
|
||||
query = query.filter(ThreatActor.motivation == motivation)
|
||||
|
||||
if sophistication:
|
||||
query = query.filter(ThreatActor.sophistication == sophistication)
|
||||
|
||||
if target_sectors:
|
||||
query = query.filter(
|
||||
func.cast(ThreatActor.target_sectors, func.text()).ilike(
|
||||
f"%{escape_like(target_sectors)}%"
|
||||
)
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
actors = (
|
||||
query.order_by(ThreatActor.name).offset(offset).limit(limit).all()
|
||||
)
|
||||
|
||||
if not actors:
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [],
|
||||
}
|
||||
|
||||
actor_ids = [a.id for a in actors]
|
||||
|
||||
# Single grouped query: tech_count and covered_count per actor
|
||||
counts_rows = (
|
||||
db.query(
|
||||
ThreatActorTechnique.threat_actor_id,
|
||||
func.count(ThreatActorTechnique.id).label("tech_count"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
Technique.status_global.in_([
|
||||
TechniqueStatus.validated,
|
||||
TechniqueStatus.partial,
|
||||
]),
|
||||
1,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
).label("covered_count"),
|
||||
)
|
||||
.join(Technique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
.filter(ThreatActorTechnique.threat_actor_id.in_(actor_ids))
|
||||
.group_by(ThreatActorTechnique.threat_actor_id)
|
||||
).all()
|
||||
|
||||
counts_map = {
|
||||
str(row.threat_actor_id): {
|
||||
"tech_count": row.tech_count,
|
||||
"covered_count": row.covered_count or 0,
|
||||
}
|
||||
for row in counts_rows
|
||||
}
|
||||
|
||||
results = []
|
||||
for actor in actors:
|
||||
cnt = counts_map.get(str(actor.id), {"tech_count": 0, "covered_count": 0})
|
||||
tech_count = cnt["tech_count"]
|
||||
covered = cnt["covered_count"]
|
||||
coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0
|
||||
|
||||
results.append({
|
||||
"id": str(actor.id),
|
||||
"mitre_id": actor.mitre_id,
|
||||
"name": actor.name,
|
||||
"aliases": actor.aliases or [],
|
||||
"country": actor.country,
|
||||
"target_sectors": actor.target_sectors or [],
|
||||
"target_regions": actor.target_regions or [],
|
||||
"motivation": actor.motivation,
|
||||
"sophistication": actor.sophistication,
|
||||
"mitre_url": actor.mitre_url,
|
||||
"technique_count": tech_count,
|
||||
"coverage_pct": coverage_pct,
|
||||
"is_active": actor.is_active,
|
||||
})
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": results,
|
||||
}
|
||||
|
||||
|
||||
def get_actor_detail(db: Session, actor_id: str) -> dict[str, Any]:
|
||||
"""Get detailed threat actor with techniques.
|
||||
|
||||
Raises EntityNotFoundError if the actor does not exist.
|
||||
"""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise EntityNotFoundError("Threat actor", actor_id)
|
||||
|
||||
actor_techniques = (
|
||||
db.query(ThreatActorTechnique, Technique)
|
||||
.join(Technique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.order_by(Technique.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
techniques_list = [
|
||||
{
|
||||
"technique_id": str(tech.id),
|
||||
"mitre_id": tech.mitre_id,
|
||||
"name": tech.name,
|
||||
"tactic": tech.tactic,
|
||||
"status_global": tech.status_global.value if tech.status_global else None,
|
||||
"usage_description": at.usage_description,
|
||||
"first_seen_using": at.first_seen_using,
|
||||
}
|
||||
for at, tech in actor_techniques
|
||||
]
|
||||
|
||||
return {
|
||||
"id": str(actor.id),
|
||||
"mitre_id": actor.mitre_id,
|
||||
"name": actor.name,
|
||||
"aliases": actor.aliases or [],
|
||||
"description": actor.description,
|
||||
"country": actor.country,
|
||||
"target_sectors": actor.target_sectors or [],
|
||||
"target_regions": actor.target_regions or [],
|
||||
"motivation": actor.motivation,
|
||||
"sophistication": actor.sophistication,
|
||||
"first_seen": actor.first_seen,
|
||||
"last_seen": actor.last_seen,
|
||||
"references": actor.references or [],
|
||||
"mitre_url": actor.mitre_url,
|
||||
"is_active": actor.is_active,
|
||||
"techniques": techniques_list,
|
||||
}
|
||||
|
||||
|
||||
def get_actor_coverage(db: Session, actor_id: str) -> dict[str, Any]:
|
||||
"""Calculate coverage percentage against a specific threat actor.
|
||||
|
||||
Raises EntityNotFoundError if the actor does not exist.
|
||||
"""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise EntityNotFoundError("Threat actor", actor_id)
|
||||
|
||||
actor_techniques = (
|
||||
db.query(Technique)
|
||||
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
total = len(actor_techniques)
|
||||
if total == 0:
|
||||
return {
|
||||
"actor_id": str(actor.id),
|
||||
"actor_name": actor.name,
|
||||
"total_techniques": 0,
|
||||
"coverage_pct": 0.0,
|
||||
"breakdown": {},
|
||||
}
|
||||
|
||||
breakdown: dict[str, int] = {}
|
||||
for tech in actor_techniques:
|
||||
status = tech.status_global.value if tech.status_global else "not_evaluated"
|
||||
breakdown[status] = breakdown.get(status, 0) + 1
|
||||
|
||||
covered = breakdown.get("validated", 0) + breakdown.get("partial", 0)
|
||||
coverage_pct = round((covered / total * 100), 1)
|
||||
|
||||
return {
|
||||
"actor_id": str(actor.id),
|
||||
"actor_name": actor.name,
|
||||
"total_techniques": total,
|
||||
"covered": covered,
|
||||
"coverage_pct": coverage_pct,
|
||||
"breakdown": breakdown,
|
||||
}
|
||||
|
||||
|
||||
def get_actor_gaps(db: Session, actor_id: str) -> dict[str, Any]:
|
||||
"""Identify techniques of this actor that are not fully validated.
|
||||
|
||||
Raises EntityNotFoundError if the actor does not exist.
|
||||
"""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise EntityNotFoundError("Threat actor", actor_id)
|
||||
|
||||
gap_techniques = (
|
||||
db.query(Technique, ThreatActorTechnique)
|
||||
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.filter(Technique.status_global != TechniqueStatus.validated)
|
||||
.order_by(Technique.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not gap_techniques:
|
||||
return {
|
||||
"actor_id": str(actor.id),
|
||||
"actor_name": actor.name,
|
||||
"total_gaps": 0,
|
||||
"gaps": [],
|
||||
}
|
||||
|
||||
technique_ids = [tech.id for tech, _ in gap_techniques]
|
||||
mitre_ids = [tech.mitre_id for tech, _ in gap_techniques]
|
||||
|
||||
# Batch template counts by mitre_technique_id
|
||||
template_counts = (
|
||||
db.query(TestTemplate.mitre_technique_id, func.count(TestTemplate.id).label("cnt"))
|
||||
.filter(TestTemplate.mitre_technique_id.in_(mitre_ids))
|
||||
.filter(TestTemplate.is_active == True)
|
||||
.group_by(TestTemplate.mitre_technique_id)
|
||||
).all()
|
||||
template_map = {row.mitre_technique_id: row.cnt for row in template_counts}
|
||||
|
||||
# Batch test counts by technique_id
|
||||
test_counts = (
|
||||
db.query(Test.technique_id, func.count(Test.id).label("cnt"))
|
||||
.filter(Test.technique_id.in_(technique_ids))
|
||||
.group_by(Test.technique_id)
|
||||
).all()
|
||||
test_map = {str(row.technique_id): row.cnt for row in test_counts}
|
||||
|
||||
gaps = []
|
||||
for tech, at in gap_techniques:
|
||||
template_count = template_map.get(tech.mitre_id, 0)
|
||||
test_count = test_map.get(str(tech.id), 0)
|
||||
gaps.append({
|
||||
"technique_id": str(tech.id),
|
||||
"mitre_id": tech.mitre_id,
|
||||
"name": tech.name,
|
||||
"tactic": tech.tactic,
|
||||
"status_global": tech.status_global.value if tech.status_global else None,
|
||||
"usage_description": at.usage_description,
|
||||
"available_templates": template_count,
|
||||
"existing_tests": test_count,
|
||||
"has_templates": template_count > 0,
|
||||
})
|
||||
|
||||
return {
|
||||
"actor_id": str(actor.id),
|
||||
"actor_name": actor.name,
|
||||
"total_gaps": len(gaps),
|
||||
"gaps": gaps,
|
||||
}
|
||||
Reference in New Issue
Block a user