"""Threat actor data service.

Extracts query and business logic from the threat_actors router so that the
router remains a thin HTTP adapter. This module is framework-agnostic: no
FastAPI imports.
"""

from __future__ import annotations

from typing import Any

from sqlalchemy import Text, case, cast, func, or_
from sqlalchemy.orm import Session

from app.domain.errors import EntityNotFoundError
from app.models.enums import TechniqueStatus
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.utils import escape_like

# Technique statuses that count towards an actor's coverage percentage.
# Shared by list_actors and get_actor_coverage so both agree on "covered".
_COVERED_STATUSES: tuple[TechniqueStatus, ...] = (
    TechniqueStatus.validated,
    TechniqueStatus.partial,
)


# ── Private helpers ───────────────────────────────────────────────────


def _get_actor_or_raise(db: Session, actor_id: str) -> ThreatActor:
    """Fetch a threat actor by id.

    Raises:
        EntityNotFoundError: if no actor with ``actor_id`` exists.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if actor is None:
        raise EntityNotFoundError("Threat actor", actor_id)
    return actor


def _coverage_pct(covered: int, total: int) -> float:
    """Return ``covered`` as a percentage of ``total``, rounded to 1 decimal.

    Returns 0.0 when ``total`` is not positive, so callers never divide by zero.
    """
    if total <= 0:
        return 0.0
    return round(covered / total * 100, 1)


# ── Public service functions ──────────────────────────────────────────


def list_actors(
    db: Session,
    *,
    search: str | None = None,
    country: str | None = None,
    motivation: str | None = None,
    sophistication: str | None = None,
    target_sectors: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> dict[str, Any]:
    """List threat actors with optional filters, pagination, and coverage stats.

    Args:
        db: Active SQLAlchemy session.
        search: Case-insensitive substring matched against ``name``,
            ``description`` and the text rendering of ``aliases``.
        country: Exact match on ``ThreatActor.country``.
        motivation: Exact match on ``ThreatActor.motivation``.
        sophistication: Exact match on ``ThreatActor.sophistication``.
        target_sectors: Case-insensitive substring matched against the text
            rendering of ``ThreatActor.target_sectors``.
        offset: Number of matching actors to skip.
        limit: Maximum number of actors to return.

    Returns:
        A dict with ``total`` (count of all matching actors, before
        pagination), ``offset``, ``limit`` and ``items`` (one summary dict per
        actor on the page, including ``technique_count`` and ``coverage_pct``).

    Technique and coverage counts for the whole page come from a single
    grouped query, avoiding one query per actor (N+1).
    """
    query = db.query(ThreatActor)

    if search:
        pattern = f"%{escape_like(search)}%"
        query = query.filter(
            or_(
                ThreatActor.name.ilike(pattern),
                ThreatActor.description.ilike(pattern),
                # aliases is not a plain text column (it is read as a list
                # below), so cast it to TEXT to substring-search it.
                cast(ThreatActor.aliases, Text).ilike(pattern),
            )
        )
    if country:
        query = query.filter(ThreatActor.country == country)
    if motivation:
        query = query.filter(ThreatActor.motivation == motivation)
    if sophistication:
        query = query.filter(ThreatActor.sophistication == sophistication)
    if target_sectors:
        query = query.filter(
            cast(ThreatActor.target_sectors, Text).ilike(
                f"%{escape_like(target_sectors)}%"
            )
        )

    total = query.count()
    # Secondary sort on id keeps pagination deterministic when names collide.
    actors = (
        query.order_by(ThreatActor.name, ThreatActor.id)
        .offset(offset)
        .limit(limit)
        .all()
    )
    if not actors:
        return {"total": total, "offset": offset, "limit": limit, "items": []}

    actor_ids = [a.id for a in actors]
    counts_rows = (
        db.query(
            ThreatActorTechnique.threat_actor_id,
            func.count(ThreatActorTechnique.id).label("tech_count"),
            func.sum(
                case(
                    (Technique.status_global.in_(list(_COVERED_STATUSES)), 1),
                    else_=0,
                )
            ).label("covered_count"),
        )
        .join(Technique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id.in_(actor_ids))
        .group_by(ThreatActorTechnique.threat_actor_id)
        .all()
    )
    # Keyed by str(id) so lookups are independent of the id's Python type.
    # SUM may come back as None or as a Decimal on some backends; normalise it.
    counts_map = {
        str(row.threat_actor_id): (row.tech_count, int(row.covered_count or 0))
        for row in counts_rows
    }

    items = []
    for actor in actors:
        tech_count, covered = counts_map.get(str(actor.id), (0, 0))
        items.append({
            "id": str(actor.id),
            "mitre_id": actor.mitre_id,
            "name": actor.name,
            "aliases": actor.aliases or [],
            "country": actor.country,
            "target_sectors": actor.target_sectors or [],
            "target_regions": actor.target_regions or [],
            "motivation": actor.motivation,
            "sophistication": actor.sophistication,
            "mitre_url": actor.mitre_url,
            "technique_count": tech_count,
            "coverage_pct": _coverage_pct(covered, tech_count),
            "is_active": actor.is_active,
        })

    return {"total": total, "offset": offset, "limit": limit, "items": items}


def get_actor_detail(db: Session, actor_id: str) -> dict[str, Any]:
    """Get a detailed threat actor record with its techniques.

    Techniques are ordered by ``Technique.mitre_id``.

    Raises:
        EntityNotFoundError: if the actor does not exist.
    """
    actor = _get_actor_or_raise(db, actor_id)

    actor_techniques = (
        db.query(ThreatActorTechnique, Technique)
        .join(Technique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .order_by(Technique.mitre_id)
        .all()
    )

    techniques_list = [
        {
            "technique_id": str(tech.id),
            "mitre_id": tech.mitre_id,
            "name": tech.name,
            "tactic": tech.tactic,
            "status_global": tech.status_global.value if tech.status_global else None,
            "usage_description": link.usage_description,
            "first_seen_using": link.first_seen_using,
        }
        for link, tech in actor_techniques
    ]

    return {
        "id": str(actor.id),
        "mitre_id": actor.mitre_id,
        "name": actor.name,
        "aliases": actor.aliases or [],
        "description": actor.description,
        "country": actor.country,
        "target_sectors": actor.target_sectors or [],
        "target_regions": actor.target_regions or [],
        "motivation": actor.motivation,
        "sophistication": actor.sophistication,
        "first_seen": actor.first_seen,
        "last_seen": actor.last_seen,
        "references": actor.references or [],
        "mitre_url": actor.mitre_url,
        "is_active": actor.is_active,
        "techniques": techniques_list,
    }


def get_actor_coverage(db: Session, actor_id: str) -> dict[str, Any]:
    """Calculate coverage percentage against a specific threat actor.

    ``breakdown`` maps each technique status value (or ``"not_evaluated"``
    when a technique has no status) to the number of the actor's techniques
    in that status. ``covered`` counts techniques whose status is in
    ``_COVERED_STATUSES``.

    Raises:
        EntityNotFoundError: if the actor does not exist.
    """
    actor = _get_actor_or_raise(db, actor_id)

    # Aggregate in SQL instead of loading every Technique row just to count.
    status_rows = (
        db.query(Technique.status_global, func.count(Technique.id))
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .group_by(Technique.status_global)
        .all()
    )

    breakdown: dict[str, int] = {}
    for status, count in status_rows:
        key = status.value if status else "not_evaluated"
        # Accumulate rather than assign in case two groups map to one key.
        breakdown[key] = breakdown.get(key, 0) + count

    total = sum(breakdown.values())
    covered = sum(breakdown.get(s.value, 0) for s in _COVERED_STATUSES)

    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_techniques": total,
        "covered": covered,
        "coverage_pct": _coverage_pct(covered, total),
        "breakdown": breakdown,
    }


def get_actor_gaps(db: Session, actor_id: str) -> dict[str, Any]:
    """Identify techniques of this actor that are not fully validated.

    A gap is any technique linked to the actor whose ``status_global`` is not
    ``TechniqueStatus.validated``, including techniques with no status at all.
    Each gap reports how many active test templates exist for its MITRE id
    and how many tests exist for the technique.

    Raises:
        EntityNotFoundError: if the actor does not exist.
    """
    actor = _get_actor_or_raise(db, actor_id)

    gap_techniques = (
        db.query(Technique, ThreatActorTechnique)
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        # In SQL, `NULL != x` is not true, so techniques without a status
        # would silently drop out of the gap list unless matched explicitly.
        .filter(
            or_(
                Technique.status_global.is_(None),
                Technique.status_global != TechniqueStatus.validated,
            )
        )
        .order_by(Technique.mitre_id)
        .all()
    )

    if not gap_techniques:
        return {
            "actor_id": str(actor.id),
            "actor_name": actor.name,
            "total_gaps": 0,
            "gaps": [],
        }

    technique_ids = [tech.id for tech, _ in gap_techniques]
    mitre_ids = [tech.mitre_id for tech, _ in gap_techniques]

    # Batch counts to avoid per-technique queries. Templates are linked by
    # MITRE id, tests by technique primary key.
    template_counts = (
        db.query(TestTemplate.mitre_technique_id, func.count(TestTemplate.id).label("cnt"))
        .filter(TestTemplate.mitre_technique_id.in_(mitre_ids))
        .filter(TestTemplate.is_active.is_(True))
        .group_by(TestTemplate.mitre_technique_id)
        .all()
    )
    template_map = {row.mitre_technique_id: row.cnt for row in template_counts}

    test_counts = (
        db.query(Test.technique_id, func.count(Test.id).label("cnt"))
        .filter(Test.technique_id.in_(technique_ids))
        .group_by(Test.technique_id)
        .all()
    )
    test_map = {str(row.technique_id): row.cnt for row in test_counts}

    gaps = []
    for tech, link in gap_techniques:
        template_count = template_map.get(tech.mitre_id, 0)
        gaps.append({
            "technique_id": str(tech.id),
            "mitre_id": tech.mitre_id,
            "name": tech.name,
            "tactic": tech.tactic,
            "status_global": tech.status_global.value if tech.status_global else None,
            "usage_description": link.usage_description,
            "available_templates": template_count,
            "existing_tests": test_map.get(str(tech.id), 0),
            "has_templates": template_count > 0,
        })

    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_gaps": len(gaps),
        "gaps": gaps,
    }