diff --git a/backend/app/routers/threat_actors.py b/backend/app/routers/threat_actors.py index b780ef1..733112a 100644 --- a/backend/app/routers/threat_actors.py +++ b/backend/app/routers/threat_actors.py @@ -7,28 +7,24 @@ threat actor profiles imported from MITRE CTI. import logging from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy import func, or_ -from sqlalchemy.orm import Session, joinedload +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user from app.models.user import User -from app.models.threat_actor import ThreatActor, ThreatActorTechnique -from app.models.technique import Technique -from app.models.test import Test -from app.models.test_template import TestTemplate -from app.models.enums import TechniqueStatus +from app.services.threat_actor_service import ( + get_actor_coverage, + get_actor_detail, + get_actor_gaps, + list_actors, +) logger = logging.getLogger(__name__) router = APIRouter(prefix="/threat-actors", tags=["threat-actors"]) -# --------------------------------------------------------------------------- -# GET /threat-actors — Listado con filtros -# --------------------------------------------------------------------------- - @router.get("") def list_threat_actors( search: Optional[str] = Query(None), @@ -45,92 +41,17 @@ def list_threat_actors( **Requires** authentication (any role). """ - query = db.query(ThreatActor) + return list_actors( + db, + search=search, + country=country, + motivation=motivation, + sophistication=sophistication, + target_sectors=target_sectors, + offset=offset, + limit=limit, + ) - # Filters - if search: - from app.utils import escape_like - pattern = f"%{escape_like(search)}%" - query = query.filter( - or_( - ThreatActor.name.ilike(pattern), - ThreatActor.description.ilike(pattern), - func.cast(ThreatActor.aliases, func.text()).ilike(pattern), - ) - ) - - if country: - query = query.filter(ThreatActor.country == country) - - if motivation: - query = query.filter(ThreatActor.motivation == motivation) - - if sophistication: - query = query.filter(ThreatActor.sophistication == sophistication) - - if target_sectors: - from app.utils import escape_like - # JSONB contains check - query = query.filter( - func.cast(ThreatActor.target_sectors, func.text()).ilike(f"%{escape_like(target_sectors)}%") - ) - - # Total count - total = query.count() - - # Paginate - actors = query.order_by(ThreatActor.name).offset(offset).limit(limit).all() - - # For each actor, count techniques and calculate basic coverage - results = [] - for actor in actors: - tech_count = ( - db.query(ThreatActorTechnique) - .filter(ThreatActorTechnique.threat_actor_id == actor.id) - .count() - ) - - # Quick coverage calculation - covered = ( - db.query(ThreatActorTechnique) - .join(Technique, ThreatActorTechnique.technique_id == Technique.id) - .filter(ThreatActorTechnique.threat_actor_id == actor.id) - .filter(Technique.status_global.in_([ - TechniqueStatus.validated, - TechniqueStatus.partial, - ])) - .count() - ) - - coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0 - - results.append({ - "id": str(actor.id), - "mitre_id": actor.mitre_id, - "name": actor.name, - "aliases": actor.aliases or [], - "country": actor.country, - "target_sectors": actor.target_sectors or [], - "target_regions": actor.target_regions or [], - "motivation": actor.motivation, - "sophistication": actor.sophistication, - "mitre_url": actor.mitre_url, - "technique_count": tech_count, - "coverage_pct": coverage_pct, - "is_active": actor.is_active, - }) - - return { - "total": total, - "offset": offset, - "limit": limit, - "items": results, - } - - -# --------------------------------------------------------------------------- -# GET /threat-actors/{id} — Detalle -# --------------------------------------------------------------------------- @router.get("/{actor_id}") def get_threat_actor( @@ -142,54 +63,8 @@ def get_threat_actor( **Requires** authentication (any role). """ - actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() - if not actor: - raise HTTPException(status_code=404, detail="Threat actor not found") + return get_actor_detail(db, actor_id) - # Get associated techniques with their coverage status - actor_techniques = ( - db.query(ThreatActorTechnique, Technique) - .join(Technique, ThreatActorTechnique.technique_id == Technique.id) - .filter(ThreatActorTechnique.threat_actor_id == actor.id) - .order_by(Technique.mitre_id) - .all() - ) - - techniques_list = [] - for at, tech in actor_techniques: - techniques_list.append({ - "technique_id": str(tech.id), - "mitre_id": tech.mitre_id, - "name": tech.name, - "tactic": tech.tactic, - "status_global": tech.status_global.value if tech.status_global else None, - "usage_description": at.usage_description, - "first_seen_using": at.first_seen_using, - }) - - return { - "id": str(actor.id), - "mitre_id": actor.mitre_id, - "name": actor.name, - "aliases": actor.aliases or [], - "description": actor.description, - "country": actor.country, - "target_sectors": actor.target_sectors or [], - "target_regions": actor.target_regions or [], - "motivation": actor.motivation, - "sophistication": actor.sophistication, - "first_seen": actor.first_seen, - "last_seen": actor.last_seen, - "references": actor.references or [], - "mitre_url": actor.mitre_url, - "is_active": actor.is_active, - "techniques": techniques_list, - } - - -# --------------------------------------------------------------------------- -# GET /threat-actors/{id}/coverage — Cobertura -# --------------------------------------------------------------------------- @router.get("/{actor_id}/coverage") def get_threat_actor_coverage( @@ -204,49 +79,8 @@ def get_threat_actor_coverage( Returns the percentage of the actor's techniques that have been validated or partially validated, along with a breakdown. """ - actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() - if not actor: - raise HTTPException(status_code=404, detail="Threat actor not found") + return get_actor_coverage(db, actor_id) - # Get all techniques for this actor - actor_techniques = ( - db.query(Technique) - .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) - .filter(ThreatActorTechnique.threat_actor_id == actor.id) - .all() - ) - - total = len(actor_techniques) - if total == 0: - return { - "actor_id": str(actor.id), - "actor_name": actor.name, - "total_techniques": 0, - "coverage_pct": 0.0, - "breakdown": {}, - } - - breakdown = {} - for tech in actor_techniques: - status = tech.status_global.value if tech.status_global else "not_evaluated" - breakdown[status] = breakdown.get(status, 0) + 1 - - covered = breakdown.get("validated", 0) + breakdown.get("partial", 0) - coverage_pct = round((covered / total * 100), 1) - - return { - "actor_id": str(actor.id), - "actor_name": actor.name, - "total_techniques": total, - "covered": covered, - "coverage_pct": coverage_pct, - "breakdown": breakdown, - } - - -# --------------------------------------------------------------------------- -# GET /threat-actors/{id}/gaps — Gap analysis -# --------------------------------------------------------------------------- @router.get("/{actor_id}/gaps") def get_threat_actor_gaps( @@ -260,52 +94,4 @@ def get_threat_actor_gaps( Returns list of gap techniques with available templates. """ - actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() - if not actor: - raise HTTPException(status_code=404, detail="Threat actor not found") - - # Get techniques NOT validated - gap_techniques = ( - db.query(Technique, ThreatActorTechnique) - .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) - .filter(ThreatActorTechnique.threat_actor_id == actor.id) - .filter(Technique.status_global != TechniqueStatus.validated) - .order_by(Technique.mitre_id) - .all() - ) - - gaps = [] - for tech, at in gap_techniques: - # Count available templates for this technique - template_count = ( - db.query(TestTemplate) - .filter(TestTemplate.mitre_technique_id == tech.mitre_id) - .filter(TestTemplate.is_active == True) - .count() - ) - - # Count existing tests - test_count = ( - db.query(Test) - .filter(Test.technique_id == tech.id) - .count() - ) - - gaps.append({ - "technique_id": str(tech.id), - "mitre_id": tech.mitre_id, - "name": tech.name, - "tactic": tech.tactic, - "status_global": tech.status_global.value if tech.status_global else None, - "usage_description": at.usage_description, - "available_templates": template_count, - "existing_tests": test_count, - "has_templates": template_count > 0, - }) - - return { - "actor_id": str(actor.id), - "actor_name": actor.name, - "total_gaps": len(gaps), - "gaps": gaps, - } + return get_actor_gaps(db, actor_id) diff --git a/backend/app/services/threat_actor_service.py b/backend/app/services/threat_actor_service.py new file mode 100644 index 0000000..03db9ab --- /dev/null +++ b/backend/app/services/threat_actor_service.py @@ -0,0 +1,310 @@ +"""Threat actor data service. + +Extracts query and business logic from the threat_actors router so +that the router remains a thin HTTP adapter. + +This module is framework-agnostic: no FastAPI imports. +""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy import case, func, or_ +from sqlalchemy.orm import Session + +from app.domain.errors import EntityNotFoundError +from app.models.enums import TechniqueStatus +from app.models.test import Test +from app.models.test_template import TestTemplate +from app.models.threat_actor import ThreatActor, ThreatActorTechnique +from app.models.technique import Technique +from app.utils import escape_like + + +# ── Public service functions ────────────────────────────────────────── + + +def list_actors( + db: Session, + *, + search: str | None = None, + country: str | None = None, + motivation: str | None = None, + sophistication: str | None = None, + target_sectors: str | None = None, + offset: int = 0, + limit: int = 50, +) -> dict[str, Any]: + """List threat actors with optional filters, pagination, and coverage stats. + + Uses grouped subqueries to avoid N+1: technique counts and coverage + counts are fetched in one query per page. + """ + query = db.query(ThreatActor) + + if search: + pattern = f"%{escape_like(search)}%" + query = query.filter( + or_( + ThreatActor.name.ilike(pattern), + ThreatActor.description.ilike(pattern), + func.cast(ThreatActor.aliases, func.text()).ilike(pattern), + ) + ) + + if country: + query = query.filter(ThreatActor.country == country) + + if motivation: + query = query.filter(ThreatActor.motivation == motivation) + + if sophistication: + query = query.filter(ThreatActor.sophistication == sophistication) + + if target_sectors: + query = query.filter( + func.cast(ThreatActor.target_sectors, func.text()).ilike( + f"%{escape_like(target_sectors)}%" + ) + ) + + total = query.count() + actors = ( + query.order_by(ThreatActor.name).offset(offset).limit(limit).all() + ) + + if not actors: + return { + "total": total, + "offset": offset, + "limit": limit, + "items": [], + } + + actor_ids = [a.id for a in actors] + + # Single grouped query: tech_count and covered_count per actor + counts_rows = ( + db.query( + ThreatActorTechnique.threat_actor_id, + func.count(ThreatActorTechnique.id).label("tech_count"), + func.sum( + case( + ( + Technique.status_global.in_([ + TechniqueStatus.validated, + TechniqueStatus.partial, + ]), + 1, + ), + else_=0, + ) + ).label("covered_count"), + ) + .join(Technique, ThreatActorTechnique.technique_id == Technique.id) + .filter(ThreatActorTechnique.threat_actor_id.in_(actor_ids)) + .group_by(ThreatActorTechnique.threat_actor_id) + ).all() + + counts_map = { + str(row.threat_actor_id): { + "tech_count": row.tech_count, + "covered_count": row.covered_count or 0, + } + for row in counts_rows + } + + results = [] + for actor in actors: + cnt = counts_map.get(str(actor.id), {"tech_count": 0, "covered_count": 0}) + tech_count = cnt["tech_count"] + covered = cnt["covered_count"] + coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0 + + results.append({ + "id": str(actor.id), + "mitre_id": actor.mitre_id, + "name": actor.name, + "aliases": actor.aliases or [], + "country": actor.country, + "target_sectors": actor.target_sectors or [], + "target_regions": actor.target_regions or [], + "motivation": actor.motivation, + "sophistication": actor.sophistication, + "mitre_url": actor.mitre_url, + "technique_count": tech_count, + "coverage_pct": coverage_pct, + "is_active": actor.is_active, + }) + + return { + "total": total, + "offset": offset, + "limit": limit, + "items": results, + } + + +def get_actor_detail(db: Session, actor_id: str) -> dict[str, Any]: + """Get detailed threat actor with techniques. + + Raises EntityNotFoundError if the actor does not exist. + """ + actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + if not actor: + raise EntityNotFoundError("Threat actor", actor_id) + + actor_techniques = ( + db.query(ThreatActorTechnique, Technique) + .join(Technique, ThreatActorTechnique.technique_id == Technique.id) + .filter(ThreatActorTechnique.threat_actor_id == actor.id) + .order_by(Technique.mitre_id) + .all() + ) + + techniques_list = [ + { + "technique_id": str(tech.id), + "mitre_id": tech.mitre_id, + "name": tech.name, + "tactic": tech.tactic, + "status_global": tech.status_global.value if tech.status_global else None, + "usage_description": at.usage_description, + "first_seen_using": at.first_seen_using, + } + for at, tech in actor_techniques + ] + + return { + "id": str(actor.id), + "mitre_id": actor.mitre_id, + "name": actor.name, + "aliases": actor.aliases or [], + "description": actor.description, + "country": actor.country, + "target_sectors": actor.target_sectors or [], + "target_regions": actor.target_regions or [], + "motivation": actor.motivation, + "sophistication": actor.sophistication, + "first_seen": actor.first_seen, + "last_seen": actor.last_seen, + "references": actor.references or [], + "mitre_url": actor.mitre_url, + "is_active": actor.is_active, + "techniques": techniques_list, + } + + +def get_actor_coverage(db: Session, actor_id: str) -> dict[str, Any]: + """Calculate coverage percentage against a specific threat actor. + + Raises EntityNotFoundError if the actor does not exist. + """ + actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + if not actor: + raise EntityNotFoundError("Threat actor", actor_id) + + actor_techniques = ( + db.query(Technique) + .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) + .filter(ThreatActorTechnique.threat_actor_id == actor.id) + .all() + ) + + total = len(actor_techniques) + if total == 0: + return { + "actor_id": str(actor.id), + "actor_name": actor.name, + "total_techniques": 0, + "coverage_pct": 0.0, + "breakdown": {}, + } + + breakdown: dict[str, int] = {} + for tech in actor_techniques: + status = tech.status_global.value if tech.status_global else "not_evaluated" + breakdown[status] = breakdown.get(status, 0) + 1 + + covered = breakdown.get("validated", 0) + breakdown.get("partial", 0) + coverage_pct = round((covered / total * 100), 1) + + return { + "actor_id": str(actor.id), + "actor_name": actor.name, + "total_techniques": total, + "covered": covered, + "coverage_pct": coverage_pct, + "breakdown": breakdown, + } + + +def get_actor_gaps(db: Session, actor_id: str) -> dict[str, Any]: + """Identify techniques of this actor that are not fully validated. + + Raises EntityNotFoundError if the actor does not exist. + """ + actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + if not actor: + raise EntityNotFoundError("Threat actor", actor_id) + + gap_techniques = ( + db.query(Technique, ThreatActorTechnique) + .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) + .filter(ThreatActorTechnique.threat_actor_id == actor.id) + .filter(Technique.status_global != TechniqueStatus.validated) + .order_by(Technique.mitre_id) + .all() + ) + + if not gap_techniques: + return { + "actor_id": str(actor.id), + "actor_name": actor.name, + "total_gaps": 0, + "gaps": [], + } + + technique_ids = [tech.id for tech, _ in gap_techniques] + mitre_ids = [tech.mitre_id for tech, _ in gap_techniques] + + # Batch template counts by mitre_technique_id + template_counts = ( + db.query(TestTemplate.mitre_technique_id, func.count(TestTemplate.id).label("cnt")) + .filter(TestTemplate.mitre_technique_id.in_(mitre_ids)) + .filter(TestTemplate.is_active == True) + .group_by(TestTemplate.mitre_technique_id) + ).all() + template_map = {row.mitre_technique_id: row.cnt for row in template_counts} + + # Batch test counts by technique_id + test_counts = ( + db.query(Test.technique_id, func.count(Test.id).label("cnt")) + .filter(Test.technique_id.in_(technique_ids)) + .group_by(Test.technique_id) + ).all() + test_map = {str(row.technique_id): row.cnt for row in test_counts} + + gaps = [] + for tech, at in gap_techniques: + template_count = template_map.get(tech.mitre_id, 0) + test_count = test_map.get(str(tech.id), 0) + gaps.append({ + "technique_id": str(tech.id), + "mitre_id": tech.mitre_id, + "name": tech.name, + "tactic": tech.tactic, + "status_global": tech.status_global.value if tech.status_global else None, + "usage_description": at.usage_description, + "available_templates": template_count, + "existing_tests": test_count, + "has_templates": template_count > 0, + }) + + return { + "actor_id": str(actor.id), + "actor_name": actor.name, + "total_gaps": len(gaps), + "gaps": gaps, + }