"""Atomic Red Team import service. Downloads the Atomic Red Team repository ZIP from GitHub, parses every ``atomics/T*/T*.yaml`` file, and upserts :class:`TestTemplate` records into the database. Strategy -------- The GitHub REST API without authentication only allows 60 req/hour. Since the Atomic Red Team repo contains 1 500+ YAML files we avoid per-file requests entirely. Instead we: 1. Download the full repo as a ZIP archive (~40 MB). 2. Extract in a temporary directory. 3. Walk ``atomics/T*/T*.yaml`` files parsing them with PyYAML. 4. Create / update ``TestTemplate`` rows keyed by ``atomic_test_id``. 5. Clean up the temporary directory. Idempotency ----------- Running the import twice does **not** create duplicates. Existing templates are identified by their ``atomic_test_id`` and simply skipped. """ import io import logging import os import shutil import tempfile import zipfile from pathlib import Path import requests as _requests import yaml from sqlalchemy.orm import Session from app.models.test_template import TestTemplate from app.services.audit_service import log_action logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- ATOMIC_RT_ZIP_URL = ( "https://github.com/redcanaryco/atomic-red-team" "/archive/refs/heads/master.zip" ) # Request timeout for the ZIP download (seconds) _DOWNLOAD_TIMEOUT = 300 # Top-level directory name inside the ZIP _ZIP_ROOT_PREFIX = "atomic-red-team-master" # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes: """Download the Atomic Red Team ZIP and return its raw bytes.""" logger.info("Downloading Atomic Red Team ZIP from %s …", url) resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) resp.raise_for_status() content = resp.content logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024)) return content def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None: """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection. Raises :class:`ValueError` if any member tries to escape the target directory (path traversal / Zip Slip) or if the archive exceeds the safety limits. """ # Maximum uncompressed size: 500 MB — prevents zip-bomb DoS _MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # Maximum number of entries _MAX_ENTRIES = 50_000 dest_path = Path(dest).resolve() with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: entries = zf.infolist() if len(entries) > _MAX_ENTRIES: raise ValueError( f"ZIP archive contains {len(entries)} entries " f"(limit: {_MAX_ENTRIES}) — possible zip bomb" ) total_size = sum(info.file_size for info in entries) if total_size > _MAX_UNCOMPRESSED_SIZE: raise ValueError( f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB " f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB" ) for member in entries: target = (dest_path / member.filename).resolve() if not target.is_relative_to(dest_path): raise ValueError( f"Zip Slip detected — member '{member.filename}' " f"resolves outside target directory" ) zf.extractall(dest) def _extract_zip(zip_bytes: bytes, dest: str) -> Path: """Extract *zip_bytes* into *dest* and return the path to the atomics/ dir.""" _safe_extract_zip(zip_bytes, dest) atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics" if not atomics_dir.is_dir(): raise FileNotFoundError( f"Expected atomics directory not found at {atomics_dir}" ) return atomics_dir def _parse_yaml_files(atomics_dir: Path) -> list[dict]: """Walk the atomics directory and parse all technique YAML files. Returns a flat list of dicts, each representing a single atomic test with the following keys:: technique_id, index, name, description, platforms, executor_type, command, source_url """ results: list[dict] = [] yaml_files = sorted(atomics_dir.glob("T*/T*.yaml")) logger.info("Found %d YAML files to parse", len(yaml_files)) for yaml_path in yaml_files: technique_id = yaml_path.stem # e.g. "T1059.001" try: with open(yaml_path, "r", encoding="utf-8") as fh: data = yaml.safe_load(fh) except Exception as exc: logger.warning("Failed to parse %s: %s", yaml_path, exc) continue if not data or "atomic_tests" not in data: continue for idx, test in enumerate(data["atomic_tests"]): name = test.get("name", "").strip() description = test.get("description", "").strip() platforms = test.get("supported_platforms", []) executor = test.get("executor", {}) executor_type = executor.get("name", "") if isinstance(executor, dict) else "" command = executor.get("command", "") if isinstance(executor, dict) else "" # Build an atomic_test_id in the format "T1059.001-0" atomic_test_id = f"{technique_id}-{idx}" source_url = ( f"https://github.com/redcanaryco/atomic-red-team/blob/master" f"/atomics/{technique_id}/{technique_id}.yaml" ) results.append({ "technique_id": technique_id, "index": idx, "atomic_test_id": atomic_test_id, "name": name, "description": description, "platforms": ", ".join(platforms) if isinstance(platforms, list) else str(platforms), "executor_type": executor_type, "command": command[:4000] if command else None, # cap at 4k chars "source_url": source_url, }) logger.info("Parsed %d atomic tests total", len(results)) return results # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- def import_atomic_red_team(db: Session) -> dict: """Download and import Atomic Red Team tests as TestTemplates. Parameters ---------- db : Session Active SQLAlchemy database session. Returns ------- dict Summary with keys ``created``, ``skipped_existing``, ``yaml_files_parsed``, ``total_tests_parsed``. """ tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_") try: zip_bytes = _download_zip() atomics_dir = _extract_zip(zip_bytes, tmp_dir) parsed_tests = _parse_yaml_files(atomics_dir) finally: # Always clean up shutil.rmtree(tmp_dir, ignore_errors=True) logger.info("Cleaned up temp directory %s", tmp_dir) # Pre-load existing atomic_test_ids for dedup existing_ids: set[str] = { row[0] for row in db.query(TestTemplate.atomic_test_id) .filter(TestTemplate.atomic_test_id.isnot(None)) .all() } created = 0 skipped = 0 for item in parsed_tests: if item["atomic_test_id"] in existing_ids: skipped += 1 continue template = TestTemplate( mitre_technique_id=item["technique_id"], name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}", description=item["description"][:2000] if item["description"] else None, source="atomic_red_team", source_url=item["source_url"], attack_procedure=item["command"], platform=item["platforms"], tool_suggested=item["executor_type"] if item["executor_type"] else None, atomic_test_id=item["atomic_test_id"], is_active=True, ) db.add(template) existing_ids.add(item["atomic_test_id"]) created += 1 db.commit() # Count distinct YAML files by technique_id yaml_files_count = len({t["technique_id"] for t in parsed_tests}) summary = { "created": created, "skipped_existing": skipped, "yaml_files_parsed": yaml_files_count, "total_tests_parsed": len(parsed_tests), } logger.info( "Atomic Red Team import complete — created=%d, skipped=%d, " "yaml_files=%d, total_tests=%d", created, skipped, yaml_files_count, len(parsed_tests), ) # Audit log (system action) log_action( db, user_id=None, action="import_atomic_red_team", entity_type="test_template", entity_id=None, details=summary, ) return summary