Files
Aegis/backend/app/services/threat_actor_service.py
T

494 lines
16 KiB
Python

"""Threat actor data service.
Extracts query and business logic from the threat_actors router so
that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import Any from typing
from typing import Any
from sqlalchemy import case, cast, func, or_, Text
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import TechniqueStatus from app.models.enums
from app.models.enums import TechniqueStatus
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
# Import escape_like from app.utils
from app.utils import escape_like
# ── Public service functions ──────────────────────────────────────────
def list_actors(
    db: Session,
    *,
    search: str | None = None,
    country: str | None = None,
    motivation: str | None = None,
    sophistication: str | None = None,
    target_sectors: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> dict[str, Any]:
    """List threat actors with optional filters, pagination, and coverage stats.

    Technique counts and coverage counts for the returned page are fetched
    with a single grouped query, avoiding one query per actor (N+1).

    Args:
        db: SQLAlchemy session.
        search: Case-insensitive substring matched against ``name``,
            ``description`` and the text form of ``aliases``.
        country: Exact match on ``ThreatActor.country``.
        motivation: Exact match on ``ThreatActor.motivation``.
        sophistication: Exact match on ``ThreatActor.sophistication``.
        target_sectors: Case-insensitive substring matched against the text
            form of ``ThreatActor.target_sectors``.
        offset: Number of matching rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        ``{"total", "offset", "limit", "items"}``. ``total`` is the filtered
        count before pagination. Each item carries ``technique_count`` and
        ``coverage_pct``: the share of the actor's techniques whose global
        status is validated or partial, as a percentage rounded to one
        decimal (0.0 when the actor has no techniques).
    """
    query = db.query(ThreatActor)
    if search:
        pattern = f"%{escape_like(search)}%"
        query = query.filter(
            or_(
                ThreatActor.name.ilike(pattern),
                ThreatActor.description.ilike(pattern),
                # Cast to text so one substring match covers every alias.
                cast(ThreatActor.aliases, Text).ilike(pattern),
            )
        )
    if country:
        query = query.filter(ThreatActor.country == country)
    if motivation:
        query = query.filter(ThreatActor.motivation == motivation)
    if sophistication:
        query = query.filter(ThreatActor.sophistication == sophistication)
    if target_sectors:
        query = query.filter(
            cast(ThreatActor.target_sectors, Text).ilike(
                f"%{escape_like(target_sectors)}%"
            )
        )

    total = query.count()
    # Tie-break on the primary key: ordering by name alone is not a total
    # order, so actors sharing a name could be duplicated or skipped across
    # OFFSET/LIMIT pages.
    actors = (
        query.order_by(ThreatActor.name, ThreatActor.id)
        .offset(offset)
        .limit(limit)
        .all()
    )
    if not actors:
        return {
            "total": total,
            "offset": offset,
            "limit": limit,
            "items": [],
        }

    actor_ids = [a.id for a in actors]
    # 1 for techniques counted as covered (validated or partial), else 0.
    covered_flag = case(
        (
            Technique.status_global.in_(
                [TechniqueStatus.validated, TechniqueStatus.partial]
            ),
            1,
        ),
        else_=0,
    )
    counts_rows = (
        db.query(
            ThreatActorTechnique.threat_actor_id,
            func.count(ThreatActorTechnique.id).label("tech_count"),
            func.sum(covered_flag).label("covered_count"),
        )
        .join(Technique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id.in_(actor_ids))
        .group_by(ThreatActorTechnique.threat_actor_id)
        .all()
    )
    # Both sides are normalised with str() so the lookup below does not
    # depend on the Python type of the id column. int() guards against
    # backends returning SUM() as Decimal.
    counts_map: dict[str, tuple[int, int]] = {
        str(row.threat_actor_id): (
            int(row.tech_count),
            int(row.covered_count or 0),
        )
        for row in counts_rows
    }

    results = []
    for actor in actors:
        # Actors with no linked techniques have no row in counts_rows.
        tech_count, covered = counts_map.get(str(actor.id), (0, 0))
        coverage_pct = (
            round(covered / tech_count * 100, 1) if tech_count > 0 else 0.0
        )
        results.append({
            "id": str(actor.id),
            "mitre_id": actor.mitre_id,
            "name": actor.name,
            "aliases": actor.aliases or [],
            "country": actor.country,
            "target_sectors": actor.target_sectors or [],
            "target_regions": actor.target_regions or [],
            "motivation": actor.motivation,
            "sophistication": actor.sophistication,
            "mitre_url": actor.mitre_url,
            "technique_count": tech_count,
            "coverage_pct": coverage_pct,
            "is_active": actor.is_active,
        })

    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": results,
    }
def get_actor_detail(db: Session, actor_id: str) -> dict[str, Any]:
    """Return one threat actor's profile together with its linked techniques.

    Args:
        db: SQLAlchemy session.
        actor_id: Identifier of the threat actor to load.

    Returns:
        The actor's profile fields (list-valued fields default to ``[]``)
        plus ``techniques``, with one entry per ThreatActorTechnique link,
        ordered by the technique's MITRE id.

    Raises:
        EntityNotFoundError: If no actor matches ``actor_id``.
    """
    actor = (
        db.query(ThreatActor)
        .filter(ThreatActor.id == actor_id)
        .first()
    )
    if not actor:
        raise EntityNotFoundError("Threat actor", actor_id)

    linked_rows = (
        db.query(ThreatActorTechnique, Technique)
        .join(Technique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .order_by(Technique.mitre_id)
        .all()
    )

    techniques: list[dict[str, Any]] = []
    for link, technique in linked_rows:
        status = technique.status_global
        techniques.append({
            "technique_id": str(technique.id),
            "mitre_id": technique.mitre_id,
            "name": technique.name,
            "tactic": technique.tactic,
            # Enum member -> its raw value; a missing status stays None.
            "status_global": status.value if status else None,
            # Per-link metadata lives on the association row.
            "usage_description": link.usage_description,
            "first_seen_using": link.first_seen_using,
        })

    return dict(
        id=str(actor.id),
        mitre_id=actor.mitre_id,
        name=actor.name,
        aliases=actor.aliases or [],
        description=actor.description,
        country=actor.country,
        target_sectors=actor.target_sectors or [],
        target_regions=actor.target_regions or [],
        motivation=actor.motivation,
        sophistication=actor.sophistication,
        first_seen=actor.first_seen,
        last_seen=actor.last_seen,
        references=actor.references or [],
        mitre_url=actor.mitre_url,
        is_active=actor.is_active,
        techniques=techniques,
    )
def get_actor_coverage(db: Session, actor_id: str) -> dict[str, Any]:
    """Calculate coverage percentage against a specific threat actor.

    Args:
        db: SQLAlchemy session.
        actor_id: Identifier of the threat actor.

    Returns:
        A dict with ``actor_id``, ``actor_name``, ``total_techniques``,
        ``covered`` (techniques whose global status is validated or partial),
        ``coverage_pct`` (``covered / total_techniques`` as a percentage,
        rounded to one decimal; 0.0 when there are no techniques) and
        ``breakdown`` (count per status value, with techniques lacking a
        status counted under ``"not_evaluated"``). The same keys are returned
        whether or not the actor has techniques.

    Raises:
        EntityNotFoundError: If no actor matches ``actor_id``.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("Threat actor", actor_id)

    actor_techniques = (
        db.query(Technique)
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .all()
    )

    breakdown: dict[str, int] = {}
    for tech in actor_techniques:
        status = tech.status_global.value if tech.status_global else "not_evaluated"
        breakdown[status] = breakdown.get(status, 0) + 1

    total = len(actor_techniques)
    # Derive the covered keys from the enum rather than hard-coded strings,
    # so this matches the validated/partial definition used by list_actors.
    covered = sum(
        breakdown.get(status.value, 0)
        for status in (TechniqueStatus.validated, TechniqueStatus.partial)
    )
    coverage_pct = round(covered / total * 100, 1) if total > 0 else 0.0

    # "covered" is always present; previously it was missing when the actor
    # had no techniques, giving an inconsistent response schema.
    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_techniques": total,
        "covered": covered,
        "coverage_pct": coverage_pct,
        "breakdown": breakdown,
    }
def get_actor_gaps(db: Session, actor_id: str) -> dict[str, Any]:
    """Identify techniques of this actor that are not fully validated.

    A technique is a gap when its global status is anything other than
    ``TechniqueStatus.validated``. Techniques with no status at all are
    included. Each gap is annotated with the number of active test templates
    for its MITRE id and the number of existing tests for the technique.

    Args:
        db: SQLAlchemy session.
        actor_id: Identifier of the threat actor.

    Returns:
        ``{"actor_id", "actor_name", "total_gaps", "gaps"}``, with gaps
        ordered by the technique's MITRE id.

    Raises:
        EntityNotFoundError: If no actor matches ``actor_id``.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("Threat actor", actor_id)

    gap_techniques = (
        db.query(Technique, ThreatActorTechnique)
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        # In SQL, "status <> 'validated'" is NULL (not TRUE) when status is
        # NULL, which would silently drop unevaluated techniques. Match the
        # NULL case explicitly.
        .filter(
            or_(
                Technique.status_global.is_(None),
                Technique.status_global != TechniqueStatus.validated,
            )
        )
        .order_by(Technique.mitre_id)
        .all()
    )
    if not gap_techniques:
        # Skip the two batch count queries when there is nothing to annotate.
        return {
            "actor_id": str(actor.id),
            "actor_name": actor.name,
            "total_gaps": 0,
            "gaps": [],
        }

    technique_ids = [tech.id for tech, _ in gap_techniques]
    mitre_ids = [tech.mitre_id for tech, _ in gap_techniques]

    # Batch active-template counts, keyed by MITRE technique id.
    template_counts = (
        db.query(TestTemplate.mitre_technique_id, func.count(TestTemplate.id).label("cnt"))
        .filter(TestTemplate.mitre_technique_id.in_(mitre_ids))
        .filter(TestTemplate.is_active.is_(True))
        .group_by(TestTemplate.mitre_technique_id)
        .all()
    )
    template_map = {row.mitre_technique_id: row.cnt for row in template_counts}

    # Batch test counts, keyed by the str() form of the technique primary key.
    test_counts = (
        db.query(Test.technique_id, func.count(Test.id).label("cnt"))
        .filter(Test.technique_id.in_(technique_ids))
        .group_by(Test.technique_id)
        .all()
    )
    test_map = {str(row.technique_id): row.cnt for row in test_counts}

    gaps = []
    for tech, link in gap_techniques:
        template_count = template_map.get(tech.mitre_id, 0)
        test_count = test_map.get(str(tech.id), 0)
        gaps.append({
            "technique_id": str(tech.id),
            "mitre_id": tech.mitre_id,
            "name": tech.name,
            "tactic": tech.tactic,
            "status_global": tech.status_global.value if tech.status_global else None,
            "usage_description": link.usage_description,
            "available_templates": template_count,
            "existing_tests": test_count,
            "has_templates": template_count > 0,
        })

    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_gaps": len(gaps),
        "gaps": gaps,
    }