"""Threat Actor import service (MITRE CTI / STIX 2.0).

Downloads the MITRE CTI repository, parses the STIX 2.0 bundle for
``intrusion-set`` objects (APT groups) and ``relationship`` objects linking
them to ``attack-pattern`` objects (techniques), then creates
:class:`ThreatActor` and :class:`ThreatActorTechnique` records.

STIX 2.0 structure
------------------
The enterprise-attack bundle contains:

- ``intrusion-set`` objects → our ThreatActor rows
- ``attack-pattern`` objects → already in our Technique table
- ``relationship`` objects (type=uses) → connect intrusion-set → attack-pattern

Strategy
--------
1. Download the ZIP archive of ``github.com/mitre/cti``.
2. Extract and load ``enterprise-attack/enterprise-attack.json`` (a single
   STIX bundle); nothing else in the archive is unpacked.
3. Build lookup maps for intrusion-sets and attack-patterns.
4. Parse relationships to connect actors → techniques.
5. Upsert into the database.

Idempotency
-----------
ThreatActor rows are deduplicated by ``mitre_id`` (actors without an ATT&CK
ID fall back to matching by ``name``). ThreatActorTechnique rows are
deduplicated by the ``(threat_actor_id, technique_id)`` pair, which is checked
against existing rows before inserting.
"""

import io
import json
import logging
import shutil
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path

import requests as _requests
from sqlalchemy.orm import Session

from app.models.data_source import DataSource
from app.models.technique import Technique
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.services.audit_service import log_action

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

MITRE_CTI_ZIP_URL = (
    "https://github.com/mitre/cti"
    "/archive/refs/heads/master.zip"
)
_DOWNLOAD_TIMEOUT = 300
# Top-level directory expected inside the downloaded archive.
_ZIP_ROOT_PREFIX = "cti-master"
# Location of the enterprise bundle relative to the archive's root directory.
_BUNDLE_RELATIVE_PATH = "enterprise-attack/enterprise-attack.json"
# Maximum number of non-ATT&CK references kept per actor (avoids row bloat).
_MAX_REFERENCES = 20
# Maximum number of characters stored in ThreatActorTechnique.usage_description.
_MAX_USAGE_DESCRIPTION = 5000
# ThreatActor attributes copied verbatim from a parsed actor dict on create/update.
_ACTOR_FIELDS = (
    "name",
    "aliases",
    "description",
    "mitre_url",
    "references",
    "first_seen",
    "last_seen",
)

# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _download_zip(url: str = MITRE_CTI_ZIP_URL) -> bytes:
    """Download the MITRE CTI ZIP and return its raw bytes.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.RequestException: on connection errors or timeouts.
    """
    logger.info("Downloading MITRE CTI ZIP from %s …", url)
    # The context manager releases the connection even when
    # raise_for_status() fails before the body has been read.
    with _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) as resp:
        resp.raise_for_status()
        content = resp.content
    logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
    return content


def _find_bundle_member(names: list[str]) -> str | None:
    """Return the archive member holding the enterprise bundle, or ``None``.

    Prefers ``cti-master/enterprise-attack/enterprise-attack.json`` but also
    accepts the same file under any other single top-level directory, so a
    differently-named archive root does not break the import.
    """
    preferred = f"{_ZIP_ROOT_PREFIX}/{_BUNDLE_RELATIVE_PATH}"
    if preferred in names:
        return preferred
    for name in names:
        root, _, rest = name.partition("/")
        if root and rest == _BUNDLE_RELATIVE_PATH:
            return name
    return None


def _extract_zip_and_load_bundle(zip_bytes: bytes, dest: str) -> dict:
    """Extract the enterprise-attack STIX bundle from the ZIP and load it.

    Only the bundle file is written under *dest*; the rest of the CTI
    repository is never unpacked.

    Args:
        zip_bytes: Raw bytes of the CTI repository archive.
        dest: Directory the bundle file is extracted into.

    Returns:
        The parsed JSON bundle.

    Raises:
        FileNotFoundError: if the archive has no enterprise-attack bundle.
        zipfile.BadZipFile: if *zip_bytes* is not a valid ZIP archive.
    """
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        member = _find_bundle_member(zf.namelist())
        if member is None:
            expected = Path(dest) / _ZIP_ROOT_PREFIX / _BUNDLE_RELATIVE_PATH
            raise FileNotFoundError(f"STIX bundle not found at {expected}")
        bundle_path = Path(zf.extract(member, dest))

    logger.info("Loading STIX bundle from %s …", bundle_path)
    with open(bundle_path, "r", encoding="utf-8") as fh:
        bundle = json.load(fh)
    logger.info("Loaded %d STIX objects", len(bundle.get("objects", [])))
    return bundle


def _find_mitre_reference(external_references: list) -> dict | None:
    """Return the first ``mitre-attack`` entry of *external_references*, if any."""
    if not isinstance(external_references, list):
        return None
    return next(
        (
            ref
            for ref in external_references
            if isinstance(ref, dict) and ref.get("source_name") == "mitre-attack"
        ),
        None,
    )


def _extract_mitre_id(external_references: list) -> str | None:
    """Extract the MITRE ATT&CK ID (e.g. ``T1059.001``) from external_references."""
    ref = _find_mitre_reference(external_references)
    return ref.get("external_id") if ref is not None else None


def _extract_mitre_url(external_references: list) -> str | None:
    """Extract the MITRE ATT&CK URL from external_references."""
    ref = _find_mitre_reference(external_references)
    return ref.get("url") if ref is not None else None


def _parse_intrusion_sets(objects: list) -> list[dict]:
    """Parse STIX intrusion-set objects into ThreatActor dicts.

    Revoked objects and objects without a non-blank ``name`` are skipped.
    Each dict has the keys ``stix_id``, ``mitre_id``, ``name``, ``aliases``
    (always a list, never containing ``name`` itself), ``description``,
    ``mitre_url``, ``references`` (non-ATT&CK references, at most
    ``_MAX_REFERENCES``), ``first_seen`` and ``last_seen``.
    """
    actors = []
    for obj in objects:
        if obj.get("type") != "intrusion-set" or obj.get("revoked"):
            continue
        # ``or ""`` also covers an explicit JSON null, which .get(key, "") does not.
        name = (obj.get("name") or "").strip()
        if not name:
            continue

        ext_refs = obj.get("external_references")
        if not isinstance(ext_refs, list):
            ext_refs = []

        aliases = obj.get("aliases")
        # The primary name may also be listed as an alias; drop it.
        aliases = (
            [alias for alias in aliases if alias != name]
            if isinstance(aliases, list)
            else []
        )

        references = [
            {
                "source": ref.get("source_name", ""),
                "url": ref.get("url", ""),
                "description": ref.get("description", ""),
            }
            for ref in ext_refs
            if isinstance(ref, dict) and ref.get("source_name") != "mitre-attack"
        ]

        actors.append({
            "stix_id": obj.get("id"),  # e.g. "intrusion-set--abc123"
            "mitre_id": _extract_mitre_id(ext_refs),
            "name": name,
            "aliases": aliases,
            "description": obj.get("description") or "",
            "mitre_url": _extract_mitre_url(ext_refs),
            "references": references[:_MAX_REFERENCES],
            "first_seen": obj.get("first_seen"),
            "last_seen": obj.get("last_seen"),
        })

    logger.info("Parsed %d intrusion-sets (threat actors)", len(actors))
    return actors


def _parse_relationships(objects: list) -> list[dict]:
    """Parse STIX ``uses`` relationships from intrusion-sets to attack-patterns.

    Returns dicts with ``source_ref`` (intrusion-set STIX ID), ``target_ref``
    (attack-pattern STIX ID) and ``description``. Revoked relationships and
    ``uses`` edges between other object types are skipped.
    """
    relationships = []
    for obj in objects:
        if (
            obj.get("type") != "relationship"
            or obj.get("relationship_type") != "uses"
            or obj.get("revoked")
        ):
            continue
        source_ref = obj.get("source_ref") or ""
        target_ref = obj.get("target_ref") or ""
        # Only actor → technique edges are imported.
        if not (
            source_ref.startswith("intrusion-set--")
            and target_ref.startswith("attack-pattern--")
        ):
            continue
        relationships.append({
            "source_ref": source_ref,
            "target_ref": target_ref,
            "description": obj.get("description", ""),
        })

    logger.info("Parsed %d uses-relationships (actor→technique)", len(relationships))
    return relationships


def _build_attack_pattern_map(objects: list) -> dict[str, str]:
    """Build a map from STIX attack-pattern ID → MITRE technique ID.

    e.g. {"attack-pattern--abc123": "T1059.001"}. Revoked attack-patterns and
    those without an ID or ATT&CK reference are left out.
    """
    mapping = {}
    for obj in objects:
        if obj.get("type") != "attack-pattern" or obj.get("revoked"):
            continue
        stix_id = obj.get("id", "")
        mitre_id = _extract_mitre_id(obj.get("external_references", []))
        if stix_id and mitre_id:
            mapping[stix_id] = mitre_id

    logger.info("Built attack-pattern map with %d entries", len(mapping))
    return mapping


def _upsert_actors(
    db: Session, actor_dicts: list[dict]
) -> tuple[dict[str, ThreatActor], int, int]:
    """Create or update ThreatActor rows for *actor_dicts*.

    Existing rows are matched by ``mitre_id``. Actors without an ATT&CK ID
    are matched by ``name`` against existing rows that also lack one, so a
    re-sync does not duplicate them. New rows get ``is_active=True``;
    updated rows keep their current ``is_active``.

    Returns:
        ``(stix_id → ThreatActor, created_count, updated_count)``. The session
        is flushed before returning, so new rows have their ``id`` set.
    """
    by_mitre_id: dict[str, ThreatActor] = {}
    by_name: dict[str, ThreatActor] = {}
    for row in db.query(ThreatActor).all():
        if row.mitre_id:
            by_mitre_id[row.mitre_id] = row
        elif row.name:
            by_name.setdefault(row.name, row)

    stix_to_db_actor: dict[str, ThreatActor] = {}
    created = updated = 0
    for actor_dict in actor_dicts:
        mitre_id = actor_dict["mitre_id"]
        if mitre_id:
            index, key = by_mitre_id, mitre_id
        else:
            index, key = by_name, actor_dict["name"]

        db_actor = index.get(key)
        if db_actor is None:
            db_actor = ThreatActor(
                mitre_id=mitre_id,
                is_active=True,
                **{field: actor_dict[field] for field in _ACTOR_FIELDS},
            )
            db.add(db_actor)
            # Register it so a duplicate later in the same bundle updates this row.
            index[key] = db_actor
            created += 1
        else:
            for field in _ACTOR_FIELDS:
                setattr(db_actor, field, actor_dict[field])
            updated += 1
        stix_to_db_actor[actor_dict["stix_id"]] = db_actor

    # A single flush assigns IDs to all new rows; they are first needed when
    # linking techniques, so flushing per row would only add round-trips.
    db.flush()
    return stix_to_db_actor, created, updated


def _upsert_actor_techniques(
    db: Session,
    relationships: list[dict],
    stix_to_db_actor: dict[str, ThreatActor],
    attack_pattern_map: dict[str, str],
    technique_by_mitre_id: dict,
) -> tuple[int, int]:
    """Add ThreatActorTechnique rows for relationships that are not yet stored.

    Relationships whose actor, attack-pattern or Technique row cannot be
    resolved are ignored without being counted.

    Returns:
        ``(created_count, skipped_existing_count)``.
    """
    existing_rels = {
        (str(row.threat_actor_id), str(row.technique_id))
        for row in db.query(ThreatActorTechnique).all()
    }

    created = skipped = 0
    for rel in relationships:
        db_actor = stix_to_db_actor.get(rel["source_ref"])
        if db_actor is None:
            continue
        mitre_technique_id = attack_pattern_map.get(rel["target_ref"])
        if not mitre_technique_id:
            continue
        db_technique = technique_by_mitre_id.get(mitre_technique_id)
        if db_technique is None:
            continue

        rel_key = (str(db_actor.id), str(db_technique.id))
        if rel_key in existing_rels:
            skipped += 1
            continue

        description = rel["description"]
        db.add(ThreatActorTechnique(
            threat_actor_id=db_actor.id,
            technique_id=db_technique.id,
            usage_description=(
                description[:_MAX_USAGE_DESCRIPTION] if description else None
            ),
        ))
        # Guards against the same pair appearing twice within the bundle.
        existing_rels.add(rel_key)
        created += 1

    return created, skipped


def _record_sync(db: Session, summary: dict) -> None:
    """Stamp the ``mitre_cti`` DataSource (if present) and write an audit entry."""
    ds = db.query(DataSource).filter(DataSource.name == "mitre_cti").first()
    if ds:
        # NOTE(review): naive UTC timestamp, kept for compatibility with the
        # existing column values — confirm the column is timezone-naive.
        ds.last_sync_at = datetime.utcnow()
        ds.last_sync_status = "success"
        ds.last_sync_stats = summary
        db.commit()

    logger.info("MITRE CTI threat actor import complete — %s", summary)
    log_action(
        db,
        user_id=None,
        action="import_threat_actors",
        entity_type="threat_actor",
        entity_id=None,
        details=summary,
    )
    db.commit()


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def sync(db: Session) -> dict:
    """Download and import threat actors from MITRE CTI.

    Args:
        db: Database session; this function commits it.

    Returns:
        A summary dict with ``actors_created``, ``actors_updated``,
        ``relationships_created``, ``relationships_skipped``,
        ``total_actors_parsed`` and ``total_relationships_parsed``.

    Raises:
        requests.RequestException: if the download fails.
        FileNotFoundError: if the archive has no enterprise-attack bundle.
    """
    tmp_dir = tempfile.mkdtemp(prefix="aegis_mitre_cti_")
    try:
        zip_bytes = _download_zip()
        bundle = _extract_zip_and_load_bundle(zip_bytes, tmp_dir)
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        logger.info("Cleaned up temp directory %s", tmp_dir)
    # The archive is no longer needed; release it before parsing.
    del zip_bytes

    # Step 1: parse the bundle.
    objects = bundle.get("objects", [])
    actor_dicts = _parse_intrusion_sets(objects)
    relationships = _parse_relationships(objects)
    attack_pattern_map = _build_attack_pattern_map(objects)

    # Step 2: load existing techniques from the DB.
    technique_by_mitre_id = {row.mitre_id: row for row in db.query(Technique).all()}

    # Step 3: upsert threat actors, then their technique links.
    stix_to_db_actor, actors_created, actors_updated = _upsert_actors(db, actor_dicts)
    rels_created, rels_skipped = _upsert_actor_techniques(
        db, relationships, stix_to_db_actor, attack_pattern_map, technique_by_mitre_id
    )
    db.commit()

    summary = {
        "actors_created": actors_created,
        "actors_updated": actors_updated,
        "relationships_created": rels_created,
        "relationships_skipped": rels_skipped,
        "total_actors_parsed": len(actor_dicts),
        "total_relationships_parsed": len(relationships),
    }

    # Step 4: record the sync on the DataSource and in the audit log.
    _record_sync(db, summary)
    return summary