"""Threat actor data service. Extracts query and business logic from the threat_actors router so that the router remains a thin HTTP adapter. This module is framework-agnostic: no FastAPI imports. """ from __future__ import annotations from typing import Any from sqlalchemy import case, cast, func, or_, Text from sqlalchemy.orm import Session from app.domain.errors import EntityNotFoundError from app.models.enums import TechniqueStatus from app.models.test import Test from app.models.test_template import TestTemplate from app.models.threat_actor import ThreatActor, ThreatActorTechnique from app.models.technique import Technique from app.utils import escape_like # ── Public service functions ────────────────────────────────────────── def list_actors( db: Session, *, search: str | None = None, country: str | None = None, motivation: str | None = None, sophistication: str | None = None, target_sectors: str | None = None, offset: int = 0, limit: int = 50, ) -> dict[str, Any]: """List threat actors with optional filters, pagination, and coverage stats. Uses grouped subqueries to avoid N+1: technique counts and coverage counts are fetched in one query per page. """ query = db.query(ThreatActor) if search: pattern = f"%{escape_like(search)}%" query = query.filter( or_( ThreatActor.name.ilike(pattern), ThreatActor.description.ilike(pattern), cast(ThreatActor.aliases, Text).ilike(pattern), ) ) if country: query = query.filter(ThreatActor.country == country) if motivation: query = query.filter(ThreatActor.motivation == motivation) if sophistication: query = query.filter(ThreatActor.sophistication == sophistication) if target_sectors: query = query.filter( cast(ThreatActor.target_sectors, Text).ilike( f"%{escape_like(target_sectors)}%" ) ) total = query.count() actors = ( query.order_by(ThreatActor.name).offset(offset).limit(limit).all() ) if not actors: return { "total": total, "offset": offset, "limit": limit, "items": [], } actor_ids = [a.id for a in actors] # Single grouped query: tech_count and covered_count per actor counts_rows = ( db.query( ThreatActorTechnique.threat_actor_id, func.count(ThreatActorTechnique.id).label("tech_count"), func.sum( case( ( Technique.status_global.in_([ TechniqueStatus.validated, TechniqueStatus.partial, ]), 1, ), else_=0, ) ).label("covered_count"), ) .join(Technique, ThreatActorTechnique.technique_id == Technique.id) .filter(ThreatActorTechnique.threat_actor_id.in_(actor_ids)) .group_by(ThreatActorTechnique.threat_actor_id) ).all() counts_map = { str(row.threat_actor_id): { "tech_count": row.tech_count, "covered_count": row.covered_count or 0, } for row in counts_rows } results = [] for actor in actors: cnt = counts_map.get(str(actor.id), {"tech_count": 0, "covered_count": 0}) tech_count = cnt["tech_count"] covered = cnt["covered_count"] coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0 results.append({ "id": str(actor.id), "mitre_id": actor.mitre_id, "name": actor.name, "aliases": actor.aliases or [], "country": actor.country, "target_sectors": actor.target_sectors or [], "target_regions": actor.target_regions or [], "motivation": actor.motivation, "sophistication": actor.sophistication, "mitre_url": actor.mitre_url, "technique_count": tech_count, "coverage_pct": coverage_pct, "is_active": actor.is_active, }) return { "total": total, "offset": offset, "limit": limit, "items": results, } def get_actor_detail(db: Session, actor_id: str) -> dict[str, Any]: """Get detailed threat actor with techniques. Raises EntityNotFoundError if the actor does not exist. """ actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() if not actor: raise EntityNotFoundError("Threat actor", actor_id) actor_techniques = ( db.query(ThreatActorTechnique, Technique) .join(Technique, ThreatActorTechnique.technique_id == Technique.id) .filter(ThreatActorTechnique.threat_actor_id == actor.id) .order_by(Technique.mitre_id) .all() ) techniques_list = [ { "technique_id": str(tech.id), "mitre_id": tech.mitre_id, "name": tech.name, "tactic": tech.tactic, "status_global": tech.status_global.value if tech.status_global else None, "usage_description": at.usage_description, "first_seen_using": at.first_seen_using, } for at, tech in actor_techniques ] return { "id": str(actor.id), "mitre_id": actor.mitre_id, "name": actor.name, "aliases": actor.aliases or [], "description": actor.description, "country": actor.country, "target_sectors": actor.target_sectors or [], "target_regions": actor.target_regions or [], "motivation": actor.motivation, "sophistication": actor.sophistication, "first_seen": actor.first_seen, "last_seen": actor.last_seen, "references": actor.references or [], "mitre_url": actor.mitre_url, "is_active": actor.is_active, "techniques": techniques_list, } def get_actor_coverage(db: Session, actor_id: str) -> dict[str, Any]: """Calculate coverage percentage against a specific threat actor. Raises EntityNotFoundError if the actor does not exist. """ actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() if not actor: raise EntityNotFoundError("Threat actor", actor_id) actor_techniques = ( db.query(Technique) .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) .filter(ThreatActorTechnique.threat_actor_id == actor.id) .all() ) total = len(actor_techniques) if total == 0: return { "actor_id": str(actor.id), "actor_name": actor.name, "total_techniques": 0, "coverage_pct": 0.0, "breakdown": {}, } breakdown: dict[str, int] = {} for tech in actor_techniques: status = tech.status_global.value if tech.status_global else "not_evaluated" breakdown[status] = breakdown.get(status, 0) + 1 covered = breakdown.get("validated", 0) + breakdown.get("partial", 0) coverage_pct = round((covered / total * 100), 1) return { "actor_id": str(actor.id), "actor_name": actor.name, "total_techniques": total, "covered": covered, "coverage_pct": coverage_pct, "breakdown": breakdown, } def get_actor_gaps(db: Session, actor_id: str) -> dict[str, Any]: """Identify techniques of this actor that are not fully validated. Raises EntityNotFoundError if the actor does not exist. """ actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() if not actor: raise EntityNotFoundError("Threat actor", actor_id) gap_techniques = ( db.query(Technique, ThreatActorTechnique) .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) .filter(ThreatActorTechnique.threat_actor_id == actor.id) .filter(Technique.status_global != TechniqueStatus.validated) .order_by(Technique.mitre_id) .all() ) if not gap_techniques: return { "actor_id": str(actor.id), "actor_name": actor.name, "total_gaps": 0, "gaps": [], } technique_ids = [tech.id for tech, _ in gap_techniques] mitre_ids = [tech.mitre_id for tech, _ in gap_techniques] # Batch template counts by mitre_technique_id template_counts = ( db.query(TestTemplate.mitre_technique_id, func.count(TestTemplate.id).label("cnt")) .filter(TestTemplate.mitre_technique_id.in_(mitre_ids)) .filter(TestTemplate.is_active == True) .group_by(TestTemplate.mitre_technique_id) ).all() template_map = {row.mitre_technique_id: row.cnt for row in template_counts} # Batch test counts by technique_id test_counts = ( db.query(Test.technique_id, func.count(Test.id).label("cnt")) .filter(Test.technique_id.in_(technique_ids)) .group_by(Test.technique_id) ).all() test_map = {str(row.technique_id): row.cnt for row in test_counts} gaps = [] for tech, at in gap_techniques: template_count = template_map.get(tech.mitre_id, 0) test_count = test_map.get(str(tech.id), 0) gaps.append({ "technique_id": str(tech.id), "mitre_id": tech.mitre_id, "name": tech.name, "tactic": tech.tactic, "status_global": tech.status_global.value if tech.status_global else None, "usage_description": at.usage_description, "available_templates": template_count, "existing_tests": test_count, "has_templates": template_count > 0, }) return { "actor_id": str(actor.id), "actor_name": actor.name, "total_gaps": len(gaps), "gaps": gaps, }