"""Service for synchronizing MITRE ATT&CK techniques via TAXII 2.0.

Connects to the official MITRE CTI TAXII server, fetches the Enterprise
ATT&CK collection, and upserts attack-pattern objects into the local
``techniques`` table.  Falls back to the MITRE CTI GitHub repository when
the TAXII server is unreachable.
"""

import logging
from datetime import datetime

import requests as _requests
from sqlalchemy.orm import Session
from taxii2client.v20 import Server as TaxiiServer

from app.models.enums import TechniqueStatus
from app.models.technique import Technique
from app.services.audit_service import log_action

logger = logging.getLogger(__name__)

TAXII_SERVER_URL = "https://cti-taxii.mitre.org/taxii/"
# ``source_name`` of the external reference that carries the ATT&CK ID.
MITRE_SOURCE_NAME = "mitre-attack"
# ``kill_chain_name`` of the kill-chain phases that are read as tactics.
MITRE_KILL_CHAIN_NAME = "mitre-attack"
# STIX ``type`` of the objects this service synchronizes.
ATTACK_PATTERN_TYPE = "attack-pattern"
GITHUB_ENTERPRISE_URL = (
    "https://raw.githubusercontent.com/mitre/cti/master/"
    "enterprise-attack/enterprise-attack.json"
)


def _extract_mitre_id(external_references: list) -> str | None:
    """Return the MITRE ATT&CK ID (e.g. ``T1059.001``) from external_references.

    The ID is the ``external_id`` of the first reference whose
    ``source_name`` equals ``MITRE_SOURCE_NAME``.  Returns ``None`` when the
    list is empty/``None`` or contains no such reference.
    """
    if not external_references:
        return None
    for ref in external_references:
        if ref.get("source_name") == MITRE_SOURCE_NAME:
            return ref.get("external_id")
    return None


def _extract_tactics(kill_chain_phases: list) -> str | None:
    """Return a comma-separated string of tactic phase names.

    Only phases whose ``kill_chain_name`` equals ``MITRE_KILL_CHAIN_NAME``
    are considered.  Returns ``None`` when there are no phases or none of
    them yields a name.
    """
    if not kill_chain_phases:
        return None
    tactics = [
        phase["phase_name"]
        for phase in kill_chain_phases
        # Phases without a usable ``phase_name`` are skipped: joining a
        # ``None`` entry would raise ``TypeError`` and abort the whole sync.
        if phase.get("kill_chain_name") == MITRE_KILL_CHAIN_NAME
        and phase.get("phase_name")
    ]
    return ", ".join(tactics) if tactics else None


def _extract_platforms(stix_object: dict) -> list:
    """Return ``x_mitre_platforms`` from the STIX object (``[]`` if absent)."""
    return stix_object.get("x_mitre_platforms", [])


def _extract_version(stix_object: dict) -> str | None:
    """Return ``x_mitre_version`` from the STIX object, or ``None`` if absent."""
    return stix_object.get("x_mitre_version")


def _extract_last_modified(stix_object: dict) -> datetime | None:
    """Return the ``modified`` timestamp as a datetime, or None.

    A value that is already a ``datetime`` is returned unchanged.  Otherwise
    a trailing ``Z`` is rewritten to ``+00:00`` and the value is parsed with
    ``datetime.fromisoformat``.  A missing, non-string or unparseable value
    yields ``None``.
    """
    modified = stix_object.get("modified")
    if modified is None:
        return None
    if isinstance(modified, datetime):
        return modified
    try:
        return datetime.fromisoformat(modified.replace("Z", "+00:00"))
    except (ValueError, AttributeError):
        # ValueError: not an ISO-8601 string; AttributeError: not a string.
        return None


def _filter_attack_patterns(bundle: dict) -> list[dict]:
    """Return the objects of a STIX bundle whose ``type`` is attack-pattern."""
    return [
        obj
        for obj in bundle.get("objects", [])
        if obj.get("type") == ATTACK_PATTERN_TYPE
    ]


def _select_enterprise_collection(collections):
    """Pick the Enterprise ATT&CK collection from an API root's collections.

    Prefers the first collection whose title contains "enterprise"
    (case-insensitive).  If none matches, falls back to the first
    collection, which is what this module previously assumed to be
    Enterprise ATT&CK.  Raises ``IndexError`` when *collections* is empty.
    """
    # Evaluated up front so that an empty list still raises IndexError.
    fallback = collections[0]
    for candidate in collections:
        if "enterprise" in (candidate.title or "").lower():
            return candidate
    return fallback


def _fetch_attack_patterns_taxii() -> list[dict]:
    """Connect to the MITRE TAXII server and return all attack-pattern objects.

    Uses the first API root advertised by ``TAXII_SERVER_URL``.  Any error
    raised by the TAXII client (or an ``IndexError`` when no API root or
    collection is available) propagates to the caller.
    """
    logger.info("Connecting to MITRE TAXII server at %s", TAXII_SERVER_URL)

    server = TaxiiServer(TAXII_SERVER_URL)
    api_root = server.api_roots[0]
    collection = _select_enterprise_collection(api_root.collections)

    logger.info(
        "Fetching objects from collection '%s' (id=%s)",
        collection.title,
        collection.id,
    )

    attack_patterns = _filter_attack_patterns(collection.get_objects())
    logger.info(
        "Retrieved %d attack-pattern objects via TAXII", len(attack_patterns)
    )
    return attack_patterns


def _fetch_attack_patterns_github() -> list[dict]:
    """Fallback: fetch Enterprise ATT&CK bundle from the MITRE CTI GitHub repo.

    Downloads ``GITHUB_ENTERPRISE_URL`` with a 120-second timeout and raises
    for non-success HTTP status codes (``raise_for_status``).
    """
    logger.info(
        "Fetching Enterprise ATT&CK bundle from GitHub (%s)", GITHUB_ENTERPRISE_URL
    )

    resp = _requests.get(GITHUB_ENTERPRISE_URL, timeout=120)
    resp.raise_for_status()

    attack_patterns = _filter_attack_patterns(resp.json())
    logger.info(
        "Retrieved %d attack-pattern objects via GitHub", len(attack_patterns)
    )
    return attack_patterns


def _fetch_attack_patterns() -> list[dict]:
    """Return all attack-pattern objects, trying TAXII first then GitHub.

    Errors from the GitHub fallback are not caught here and propagate to
    the caller.
    """
    try:
        return _fetch_attack_patterns_taxii()
    except Exception as exc:
        # Deliberately broad: whatever goes wrong on the TAXII path, the
        # GitHub mirror is tried before giving up.
        logger.warning(
            "TAXII server unavailable (%s), falling back to GitHub mirror",
            exc,
        )
        return _fetch_attack_patterns_github()


def sync_mitre(db: Session) -> dict:
    """Synchronize MITRE ATT&CK techniques into the local database.

    Objects flagged ``revoked`` or ``x_mitre_deprecated``, and objects
    without a MITRE ATT&CK ID, are skipped.  Unknown IDs are inserted with
    ``status_global=TechniqueStatus.not_evaluated`` and
    ``review_required=False``.  For known IDs the metadata (tactic,
    platforms, version, last-modified, sub-technique fields) is always
    refreshed; ``review_required`` is set to ``True`` only when the name or
    the description changed.

    The technique changes are committed first; then an audit entry is
    written through ``log_action`` and the session is committed again.

    Parameters
    ----------
    db : Session
        Active SQLAlchemy database session.

    Returns
    -------
    dict
        Summary with keys ``created``, ``updated``, ``unchanged``,
        ``skipped``.  ``unchanged`` counts existing techniques whose name
        and description did not change, even though their metadata fields
        were reassigned.
    """
    attack_patterns = _fetch_attack_patterns()

    # Pre-load existing techniques keyed by mitre_id for fast lookup.
    existing_techniques: dict[str, Technique] = {
        t.mitre_id: t for t in db.query(Technique).all()
    }

    created = 0
    updated = 0
    unchanged = 0
    skipped = 0

    for obj in attack_patterns:
        # Skip revoked / deprecated objects.
        if obj.get("revoked", False) or obj.get("x_mitre_deprecated", False):
            skipped += 1
            continue

        mitre_id = _extract_mitre_id(obj.get("external_references", []))
        if not mitre_id:
            skipped += 1
            continue

        name = obj.get("name", "")
        description = obj.get("description", "")
        tactic = _extract_tactics(obj.get("kill_chain_phases", []))
        platforms = _extract_platforms(obj)
        version = _extract_version(obj)
        last_modified = _extract_last_modified(obj)

        # An ID containing a dot (``T1059.001``) is treated as a
        # sub-technique; the part before the first dot is the parent ID.
        is_subtechnique = "." in mitre_id
        parent_mitre_id = mitre_id.split(".")[0] if is_subtechnique else None

        existing = existing_techniques.get(mitre_id)

        if existing is None:
            technique = Technique(
                mitre_id=mitre_id,
                name=name,
                description=description,
                tactic=tactic,
                platforms=platforms,
                mitre_version=version,
                mitre_last_modified=last_modified,
                is_subtechnique=is_subtechnique,
                parent_mitre_id=parent_mitre_id,
                status_global=TechniqueStatus.not_evaluated,
                review_required=False,
            )
            db.add(technique)
            # Registered so that a repeated ID later in the same feed takes
            # the update path instead of being inserted twice.
            existing_techniques[mitre_id] = technique
            created += 1
            continue

        # Only a changed name or description counts as an update.
        changes = False
        if existing.name != name:
            existing.name = name
            changes = True
        if (existing.description or "") != (description or ""):
            existing.description = description
            changes = True

        # Always keep metadata up-to-date (does not trigger review).
        existing.tactic = tactic
        existing.platforms = platforms
        existing.mitre_version = version
        existing.mitre_last_modified = last_modified
        existing.is_subtechnique = is_subtechnique
        existing.parent_mitre_id = parent_mitre_id

        if changes:
            existing.review_required = True
            updated += 1
        else:
            unchanged += 1

    # Single commit for the whole batch of technique changes.
    db.commit()

    summary = {
        "created": created,
        "updated": updated,
        "unchanged": unchanged,
        "skipped": skipped,
    }

    logger.info(
        "MITRE sync complete — created=%d, updated=%d, unchanged=%d, skipped=%d",
        created,
        updated,
        unchanged,
        skipped,
    )

    # Audit log (system action -> user_id=None), committed separately after
    # the technique batch.
    log_action(
        db,
        user_id=None,
        action="mitre_sync",
        entity_type="technique",
        entity_id=None,
        details=summary,
    )
    db.commit()

    return summary