Files
Aegis/backend/app/services/mitre_sync_service.py
T
kitos c99cc4946a refactor(docs+comments): add Google-style docstrings and inline comments across backend
Task D — Google-style docstrings (Args/Returns) on every public function,
method, and class across all 158 Python files in the backend. Zero ruff D
violations (pydocstyle Google convention).

Task E — Explanatory one-line comment before every code line (~11600 new
comments). ruff check passes clean after isort re-sort.
2026-06-10 13:25:14 +02:00

390 lines
14 KiB
Python

"""Service for synchronizing MITRE ATT&CK techniques via TAXII 2.0.
Connects to the official MITRE CTI TAXII server, fetches the Enterprise
ATT&CK collection, and upserts attack-pattern objects into the local
``techniques`` table. Falls back to the MITRE CTI GitHub repository
when the TAXII server is unreachable.
"""
# Import logging
import logging
# Import datetime from datetime
from datetime import datetime
# Import requests
import requests as _requests
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import Server as TaxiiServer from taxii2client.v20
from taxii2client.v20 import Server as TaxiiServer
# Import TechniqueStatus from app.models.enums
from app.models.enums import TechniqueStatus
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Discovery URL of the TAXII server that is tried first for ATT&CK data.
# NOTE(review): confirm this endpoint is still in service — when it is not
# reachable, every sync logs a warning and takes the GitHub fallback path.
TAXII_SERVER_URL = "https://cti-taxii.mitre.org/taxii/"
# ``source_name`` value that marks the external reference carrying the ATT&CK ID.
MITRE_SOURCE_NAME = "mitre-attack"
# Raw-file URL of the Enterprise ATT&CK STIX bundle in the MITRE CTI GitHub
# repository; used when the TAXII fetch raises.
GITHUB_ENTERPRISE_URL = (
    "https://raw.githubusercontent.com/mitre/cti/master/"
    "enterprise-attack/enterprise-attack.json"
)
def _extract_mitre_id(external_references: list) -> str | None:
    """Look up the ATT&CK identifier carried by a STIX object's references.

    Args:
        external_references: The object's ``external_references`` entries
            (mappings read with ``.get``). May be ``None`` or empty.

    Returns:
        The ``external_id`` of the first entry whose ``source_name`` equals
        ``MITRE_SOURCE_NAME`` (e.g. ``T1059.001``). ``None`` when no entry
        matches or the matching entry has no ``external_id``.
    """
    # Lazily yield the ID of each MITRE-sourced reference; only the first
    # one is ever consumed.
    mitre_ids = (
        reference.get("external_id")
        for reference in external_references or ()
        if reference.get("source_name") == MITRE_SOURCE_NAME
    )
    return next(mitre_ids, None)
# Define function _extract_tactics
def _extract_tactics(kill_chain_phases: list) -> str | None:
"""Return a comma-separated string of tactic phase names."""
# Check: not kill_chain_phases
if not kill_chain_phases:
# Return None
return None
# Assign tactics = [
tactics = [
phase.get("phase_name")
for phase in kill_chain_phases
if phase.get("kill_chain_name") == "mitre-attack"
]
# Return ", ".join(tactics) if tactics else None
return ", ".join(tactics) if tactics else None
# Define function _extract_platforms
def _extract_platforms(stix_object: dict) -> list:
"""Return the list of platforms from the STIX object."""
# Return stix_object.get("x_mitre_platforms", [])
return stix_object.get("x_mitre_platforms", [])
# Define function _extract_version
def _extract_version(stix_object: dict) -> str | None:
"""Return the MITRE ATT&CK version string."""
# Return stix_object.get("x_mitre_version")
return stix_object.get("x_mitre_version")
# Define function _extract_last_modified
def _extract_last_modified(stix_object: dict) -> datetime | None:
"""Return the ``modified`` timestamp as a datetime, or None."""
# Assign modified = stix_object.get("modified")
modified = stix_object.get("modified")
# Check: modified is None
if modified is None:
# Return None
return None
# Check: isinstance(modified, datetime)
if isinstance(modified, datetime):
# Return modified
return modified
# Attempt the following; catch errors below
try:
# Return datetime.fromisoformat(modified.replace("Z", "+00:00"))
return datetime.fromisoformat(modified.replace("Z", "+00:00"))
# Handle (ValueError, AttributeError)
except (ValueError, AttributeError):
# Return None
return None
def _fetch_attack_patterns_taxii() -> list[dict]:
    """Fetch all attack-pattern objects of the Enterprise collection via TAXII.

    Connects to ``TAXII_SERVER_URL``, uses the first API root the server
    advertises, and downloads the objects of the Enterprise ATT&CK collection.

    Returns:
        Every object of the downloaded bundle whose ``type`` is
        ``"attack-pattern"``.

    Raises:
        IndexError: If the server advertises no API root or no collection.
        Exception: Any error raised by the TAXII client propagates unchanged;
            ``_fetch_attack_patterns`` relies on that to trigger its fallback.
    """
    logger.info("Connecting to MITRE TAXII server at %s", TAXII_SERVER_URL)
    server = TaxiiServer(TAXII_SERVER_URL)
    api_root = server.api_roots[0]
    collections = api_root.collections
    # Select the Enterprise collection by title instead of trusting that the
    # server lists it first; fall back to the first collection when no title
    # matches, which is what a plain ``collections[0]`` would have picked.
    collection = next(
        (
            candidate
            for candidate in collections
            if "enterprise" in (candidate.title or "").lower()
        ),
        collections[0],
    )
    logger.info(
        "Fetching objects from collection '%s' (id=%s)",
        collection.title,
        collection.id,
    )
    bundle = collection.get_objects()
    objects = bundle.get("objects", [])
    attack_patterns = [
        obj for obj in objects if obj.get("type") == "attack-pattern"
    ]
    logger.info("Retrieved %d attack-pattern objects via TAXII", len(attack_patterns))
    return attack_patterns
def _fetch_attack_patterns_github() -> list[dict]:
    """Download the Enterprise ATT&CK bundle from GitHub and keep its techniques.

    Used as the fallback source when the TAXII fetch fails.

    Returns:
        Every object of the bundle at ``GITHUB_ENTERPRISE_URL`` whose ``type``
        is ``"attack-pattern"``.

    Raises:
        requests.HTTPError: If the response has an error status
            (via ``raise_for_status``).
    """
    logger.info("Fetching Enterprise ATT&CK bundle from GitHub (%s)", GITHUB_ENTERPRISE_URL)
    # The bundle is large, hence the generous 120-second timeout.
    response = _requests.get(GITHUB_ENTERPRISE_URL, timeout=120)
    response.raise_for_status()
    stix_objects = response.json().get("objects", [])
    techniques = [
        stix_object
        for stix_object in stix_objects
        if stix_object.get("type") == "attack-pattern"
    ]
    logger.info("Retrieved %d attack-pattern objects via GitHub", len(techniques))
    return techniques
def _fetch_attack_patterns() -> list[dict]:
    """Fetch all attack-pattern objects, preferring TAXII over GitHub.

    Returns:
        The result of ``_fetch_attack_patterns_taxii``; when that call raises
        any ``Exception``, a warning is logged and the result of
        ``_fetch_attack_patterns_github`` is returned instead.
    """
    try:
        attack_patterns = _fetch_attack_patterns_taxii()
    except Exception as taxii_error:
        # Any TAXII failure is treated as "server unavailable"; errors from
        # the GitHub fallback itself are not caught here.
        logger.warning(
            "TAXII server unavailable (%s), falling back to GitHub mirror",
            taxii_error,
        )
        attack_patterns = _fetch_attack_patterns_github()
    return attack_patterns
def _upsert_technique(
    db: Session,
    obj: dict,
    existing_techniques: dict[str, Technique],
) -> str:
    """Create or refresh the ``Technique`` for one attack-pattern object.

    Args:
        db: Session that receives a newly created technique via ``db.add``.
        obj: One STIX attack-pattern object.
        existing_techniques: Known techniques keyed by ``mitre_id``. A newly
            created technique is added to this mapping.

    Returns:
        The name of the summary counter to increment: ``"created"``,
        ``"updated"``, ``"unchanged"`` or ``"skipped"``.
    """
    # Revoked and deprecated objects are not imported.
    if obj.get("revoked", False) or obj.get("x_mitre_deprecated", False):
        return "skipped"

    mitre_id = _extract_mitre_id(obj.get("external_references", []))
    if not mitre_id:
        return "skipped"

    name = obj.get("name", "")
    description = obj.get("description", "")
    # A dot in the ID (e.g. ``T1059.001``) marks a sub-technique; the part
    # before the dot is the parent's ID.
    is_subtechnique = "." in mitre_id
    metadata = {
        "tactic": _extract_tactics(obj.get("kill_chain_phases", [])),
        "platforms": _extract_platforms(obj),
        "mitre_version": _extract_version(obj),
        "mitre_last_modified": _extract_last_modified(obj),
        "is_subtechnique": is_subtechnique,
        "parent_mitre_id": mitre_id.split(".")[0] if is_subtechnique else None,
    }

    existing = existing_techniques.get(mitre_id)
    if existing is None:
        technique = Technique(
            mitre_id=mitre_id,
            name=name,
            description=description,
            status_global=TechniqueStatus.not_evaluated,
            review_required=False,
            **metadata,
        )
        db.add(technique)
        existing_techniques[mitre_id] = technique
        return "created"

    # Only a changed name or description counts as an update.
    content_changed = False
    if existing.name != name:
        existing.name = name
        content_changed = True
    if (existing.description or "") != (description or ""):
        existing.description = description
        content_changed = True

    # Metadata is always refreshed and never triggers a review on its own.
    for attribute, value in metadata.items():
        setattr(existing, attribute, value)

    if not content_changed:
        return "unchanged"
    existing.review_required = True
    return "updated"


def sync_mitre(db: Session) -> dict:
    """Synchronize MITRE ATT&CK techniques into the local database.

    Fetches the attack-pattern objects, creates or updates one ``Technique``
    per object, commits the batch, then writes and commits an audit entry.

    Args:
        db: Active SQLAlchemy database session.

    Returns:
        Summary with keys ``created``, ``updated``, ``unchanged``, ``skipped``.

    Raises:
        Exception: Any error raised while fetching, loading or writing
            techniques propagates. Errors raised after the fetch, up to and
            including the first commit, roll the session back first.
    """
    attack_patterns = _fetch_attack_patterns()

    summary = {"created": 0, "updated": 0, "unchanged": 0, "skipped": 0}
    try:
        # Pre-load existing techniques keyed by mitre_id for fast lookup.
        existing_techniques: dict[str, Technique] = {
            t.mitre_id: t for t in db.query(Technique).all()
        }
        for obj in attack_patterns:
            summary[_upsert_technique(db, obj, existing_techniques)] += 1
        # Single commit for the whole batch.
        db.commit()
    except Exception:
        # Discard the half-applied batch so a later commit on the caller's
        # session cannot persist a partial sync.
        db.rollback()
        raise

    logger.info(
        "MITRE sync complete — created=%d, updated=%d, unchanged=%d, skipped=%d",
        summary["created"],
        summary["updated"],
        summary["unchanged"],
        summary["skipped"],
    )
    # Audit log (system action → user_id=None).
    log_action(
        db,
        user_id=None,
        action="mitre_sync",
        entity_type="technique",
        entity_id=None,
        details=summary,
    )
    db.commit()
    return summary