"""Atomic Red Team import service. Downloads the Atomic Red Team repository ZIP from GitHub, parses every ``atomics/T*/T*.yaml`` file, and upserts :class:`TestTemplate` records into the database. Strategy -------- The GitHub REST API without authentication only allows 60 req/hour. Since the Atomic Red Team repo contains 1 500+ YAML files we avoid per-file requests entirely. Instead we: 1. Download the full repo as a ZIP archive (~40 MB). 2. Extract in a temporary directory. 3. Walk ``atomics/T*/T*.yaml`` files parsing them with PyYAML. 4. Create / update ``TestTemplate`` rows keyed by ``atomic_test_id``. 5. Clean up the temporary directory. Idempotency ----------- Running the import twice does **not** create duplicates. Existing templates are identified by their ``atomic_test_id`` and simply skipped. """ # Import io import io # Import logging import logging # Import shutil import shutil # Import tempfile import tempfile # Import zipfile import zipfile # Import Path from pathlib from pathlib import Path # Import requests import requests as _requests # Import yaml import yaml # Import Session from sqlalchemy.orm from sqlalchemy.orm import Session # Import TestTemplate from app.models.test_template from app.models.test_template import TestTemplate from app.models.technique import Technique from app.services.audit_service import log_action # Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- ATOMIC_RT_ZIP_URL = ( # Literal argument value "https://github.com/redcanaryco/atomic-red-team" # Literal argument value "/archive/refs/heads/master.zip" ) # Request timeout for the ZIP download (seconds) _DOWNLOAD_TIMEOUT = 300 # Top-level directory name inside the ZIP _ZIP_ROOT_PREFIX = "atomic-red-team-master" # Safety limits for ZIP extraction — prevent zip-bomb DoS 
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024  # 500 MB
_MAX_ENTRIES = 50_000

# Safety limit for the *compressed* download. The real archive is roughly
# 40 MB, so a body far larger than that is rejected while it streams in,
# before it has been buffered in full.
_MAX_DOWNLOAD_SIZE = 200 * 1024 * 1024  # 200 MB

# Chunk size used when streaming the ZIP body.
_DOWNLOAD_CHUNK_SIZE = 1024 * 1024  # 1 MB

# Executor commands longer than this many characters are truncated.
_MAX_COMMAND_LENGTH = 4000


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _download_zip(
    url: str = ATOMIC_RT_ZIP_URL,
    max_bytes: int = _MAX_DOWNLOAD_SIZE,
) -> bytes:
    """Download the Atomic Red Team ZIP and return its raw bytes.

    The body is streamed in chunks, so an oversized response is rejected
    without being fully buffered. The response is always closed, which
    releases the pooled connection, even when the request fails.

    Parameters
    ----------
    url : str
        Archive URL. Defaults to the ``master`` branch ZIP.
    max_bytes : int
        Maximum accepted download size in bytes.

    Raises
    ------
    requests.exceptions.RequestException
        On network errors or timeouts. ``HTTPError`` is raised for 4xx/5xx
        responses.
    ValueError
        If the body exceeds *max_bytes*.
    """
    logger.info("Downloading Atomic Red Team ZIP from %s …", url)
    chunks: list[bytes] = []
    received = 0
    # The context manager closes the response on every path, including when
    # raise_for_status() raises. The original code leaked it in that case.
    with _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
            received += len(chunk)
            if received > max_bytes:
                raise ValueError(
                    f"ZIP download exceeds limit of "
                    f"{max_bytes / (1024 * 1024):.0f} MB"
                )
            chunks.append(chunk)
    content = b"".join(chunks)
    logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
    return content


def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
    """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.

    Every check runs before anything is written to disk:

    * the entry count is compared against ``_MAX_ENTRIES``;
    * the sum of the declared uncompressed sizes is compared against
      ``_MAX_UNCOMPRESSED_SIZE``;
    * each member's resolved target path must lie inside *dest*.

    Only then are the members extracted, so a rejected archive never leaves
    partial output behind.

    Raises
    ------
    ValueError
        If the archive exceeds a safety limit, or if a member would be
        written outside *dest* (path traversal / Zip Slip).
    zipfile.BadZipFile
        If *zip_bytes* is not a valid ZIP archive.
    """
    dest_path = Path(dest).resolve()
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        entries = zf.infolist()

        if len(entries) > _MAX_ENTRIES:
            raise ValueError(
                f"ZIP archive contains {len(entries)} entries "
                f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
            )

        total_size = sum(info.file_size for info in entries)
        if total_size > _MAX_UNCOMPRESSED_SIZE:
            raise ValueError(
                f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
                f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
            )

        # Validate every path first so a malicious member is caught before
        # any sibling has been written.
        for member in entries:
            target = (dest_path / member.filename).resolve()
            if not target.is_relative_to(dest_path):
                raise ValueError(
                    f"Zip Slip detected — member '{member.filename}' "
                    f"resolves outside target directory"
                )

        for member in entries:
            zf.extract(member, dest)


def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
    """Safely extract *zip_bytes* into *dest* and return the ``atomics/`` dir.

    Raises
    ------
    ValueError
        Propagated from :func:`_safe_extract_zip` when a safety check fails.
    FileNotFoundError
        If ``<dest>/<_ZIP_ROOT_PREFIX>/atomics`` does not exist after extraction.
    """
    _safe_extract_zip(zip_bytes, dest)
    atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
    if not atomics_dir.is_dir():
        raise FileNotFoundError(
            f"Expected atomics directory not found at {atomics_dir}"
        )
    return atomics_dir


def _as_text(value: object) -> str:
    """Return *value* as ``str``, mapping ``None`` (an empty YAML value) to ``""``."""
    return "" if value is None else str(value)


def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
    """Walk the atomics directory and parse all technique YAML files.

    Returns a flat list of dicts, each representing a single atomic test with
    the following keys::

        technique_id, index, atomic_test_id, name, description, platforms,
        executor_type, command, source_url

    The following are logged and skipped instead of aborting the whole
    import:

    * files that cannot be read or parsed;
    * files without an ``atomic_tests`` list;
    * individual entries that are not mappings.

    ``index`` is always the entry's position in the original
    ``atomic_tests`` list. As a result, ``atomic_test_id`` (the
    de-duplication key) stays stable even when a malformed sibling entry is
    skipped.
    """
    results: list[dict] = []
    yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
    logger.info("Found %d YAML files to parse", len(yaml_files))

    for yaml_path in yaml_files:
        technique_id = yaml_path.stem  # e.g. "T1059.001"
        try:
            with open(yaml_path, "r", encoding="utf-8") as fh:
                data = yaml.safe_load(fh)
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
            logger.warning("Failed to parse %s: %s", yaml_path, exc)
            continue

        # Guard against an empty file, a non-mapping document, or
        # `atomic_tests: null`. Each of these previously raised.
        atomic_tests = data.get("atomic_tests") if isinstance(data, dict) else None
        if not isinstance(atomic_tests, list):
            continue

        # The URL is the same for every test in the file, so build it once.
        source_url = (
            f"https://github.com/redcanaryco/atomic-red-team/blob/master"
            f"/atomics/{technique_id}/{technique_id}.yaml"
        )

        for idx, test in enumerate(atomic_tests):
            if not isinstance(test, dict):
                logger.warning(
                    "Skipping malformed atomic test #%d in %s", idx, yaml_path
                )
                continue

            executor = test.get("executor")
            if not isinstance(executor, dict):
                executor = {}

            platforms = test.get("supported_platforms")
            if isinstance(platforms, list):
                platforms_text = ", ".join(str(p) for p in platforms)
            else:
                platforms_text = _as_text(platforms)

            command = _as_text(executor.get("command"))

            results.append({
                "technique_id": technique_id,
                "index": idx,
                # Stable per-test key used for de-duplication, e.g. "T1059.001-0".
                "atomic_test_id": f"{technique_id}-{idx}",
                "name": _as_text(test.get("name")).strip(),
                "description": _as_text(test.get("description")).strip(),
                "platforms": platforms_text,
                "executor_type": _as_text(executor.get("name")),
                "command": command[:_MAX_COMMAND_LENGTH] or None,
                "source_url": source_url,
            })

    logger.info("Parsed %d atomic tests total", len(results))
    return results


def _build_template(item: dict) -> TestTemplate:
    """Build an unsaved :class:`TestTemplate` from one dict produced by
    :func:`_parse_yaml_files`."""
    return TestTemplate(
        mitre_technique_id=item["technique_id"],
        # Truncated to 500 / 2000 characters. These are presumably the column
        # limits on TestTemplate; confirm against the model.
        name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}",
        description=item["description"][:2000] if item["description"] else None,
        source="atomic_red_team",
        source_url=item["source_url"],
        attack_procedure=item["command"],
        platform=item["platforms"],
        tool_suggested=item["executor_type"] or None,
        atomic_test_id=item["atomic_test_id"],
        is_active=True,
    )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def import_atomic_red_team(db: Session) -> dict:
    """Download and import Atomic Red Team tests as TestTemplates.

    New templates are inserted. Templates whose ``atomic_test_id`` already
    exists are skipped, never updated. Every technique that received at
    least one new template is flagged with ``review_required = True``. An
    ``import_atomic_red_team`` audit entry with the summary is recorded.

    Parameters
    ----------
    db : Session
        Active SQLAlchemy database session. It is committed on success and
        rolled back if any database step raises.

    Returns
    -------
    dict
        Summary with keys ``created``, ``skipped_existing``,
        ``yaml_files_parsed``, ``total_tests_parsed``.

    Raises
    ------
    ValueError
        If the download or the archive exceeds a safety limit.
    FileNotFoundError
        If the archive has no ``atomics/`` directory at the expected path.
    requests.exceptions.RequestException
        If the download fails.
    """
    tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
    try:
        # Passing the download result straight into extraction means no local
        # keeps the ZIP bytes alive while the YAML files are parsed.
        atomics_dir = _extract_zip(_download_zip(), tmp_dir)
        parsed_tests = _parse_yaml_files(atomics_dir)
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        logger.info("Cleaned up temp directory %s", tmp_dir)

    try:
        # Pre-load existing atomic_test_ids for dedup.
        existing_ids: set[str] = {
            row[0]
            for row in db.query(TestTemplate.atomic_test_id)
            .filter(TestTemplate.atomic_test_id.isnot(None))
            .all()
        }

        created = 0
        skipped = 0
        new_technique_ids: set[str] = set()
        for item in parsed_tests:
            if item["atomic_test_id"] in existing_ids:
                skipped += 1
                continue
            db.add(_build_template(item))
            # Also prevents inserting the same ID twice within this run.
            existing_ids.add(item["atomic_test_id"])
            new_technique_ids.add(item["technique_id"])
            created += 1

        if new_technique_ids:
            db.query(Technique).filter(
                Technique.mitre_id.in_(sorted(new_technique_ids))
            ).update({"review_required": True}, synchronize_session=False)

        db.commit()

        # Distinct technique IDs, i.e. YAML files that yielded at least one test.
        yaml_files_count = len({t["technique_id"] for t in parsed_tests})

        summary = {
            "created": created,
            "skipped_existing": skipped,
            "yaml_files_parsed": yaml_files_count,
            "total_tests_parsed": len(parsed_tests),
        }

        logger.info(
            "Atomic Red Team import complete — created=%d, skipped=%d, "
            "yaml_files=%d, total_tests=%d",
            created, skipped, yaml_files_count, len(parsed_tests),
        )

        # Audit log (system action)
        log_action(
            db,
            user_id=None,
            action="import_atomic_red_team",
            entity_type="test_template",
            entity_id=None,
            details=summary,
        )
        db.commit()
    except Exception:
        # Leave the session usable for the caller instead of in a failed
        # transaction state, then propagate the original error.
        db.rollback()
        raise

    return summary