"""Elastic Detection Rules import service. Downloads the Elastic detection-rules repository ZIP from GitHub, parses every ``.toml`` rule file under ``rules/``, extracts MITRE ATT&CK mappings, and creates :class:`DetectionRule` records in the database. Strategy -------- 1. Download the full repo as a ZIP archive. 2. Extract into a temporary directory. 3. Walk all ``.toml`` files under ``rules/``. 4. Parse each TOML file — extract rule name, description, query (KQL), severity, and MITRE ATT&CK threat mappings. 5. Create / skip ``DetectionRule`` rows keyed by ``(source, source_id)``. 6. Clean up. Idempotency ----------- Running the import twice does **not** create duplicates. Existing rules are identified by ``source = "elastic"`` + ``source_id`` (the TOML filename). """ import io import logging import shutil import tempfile import zipfile from datetime import datetime from pathlib import Path import requests as _requests from sqlalchemy.orm import Session from app.models.detection_rule import DetectionRule from app.models.data_source import DataSource from app.services.audit_service import log_action logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- ELASTIC_ZIP_URL = ( "https://github.com/elastic/detection-rules" "/archive/refs/heads/main.zip" ) _DOWNLOAD_TIMEOUT = 300 _ZIP_ROOT_PREFIX = "detection-rules-main" # Severity normalisation _SEVERITY_MAP = { "informational": "informational", "low": "low", "medium": "medium", "high": "high", "critical": "critical", } # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _download_zip(url: str = ELASTIC_ZIP_URL) -> bytes: """Download the Elastic Detection Rules ZIP and return raw bytes.""" logger.info("Downloading Elastic Detection Rules ZIP from %s …", url) resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) resp.raise_for_status() content = resp.content logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024)) return content def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None: """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection. Raises :class:`ValueError` if any member tries to escape the target directory (path traversal / Zip Slip) or if the archive exceeds the safety limits. """ # Maximum uncompressed size: 500 MB — prevents zip-bomb DoS _MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # Maximum number of entries _MAX_ENTRIES = 50_000 dest_path = Path(dest).resolve() with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: entries = zf.infolist() if len(entries) > _MAX_ENTRIES: raise ValueError( f"ZIP archive contains {len(entries)} entries " f"(limit: {_MAX_ENTRIES}) — possible zip bomb" ) total_size = sum(info.file_size for info in entries) if total_size > _MAX_UNCOMPRESSED_SIZE: raise ValueError( f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB " f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB" ) for member in entries: target = (dest_path / member.filename).resolve() if not target.is_relative_to(dest_path): raise ValueError( f"Zip Slip detected — member '{member.filename}' " f"resolves outside target directory" ) zf.extractall(dest) def _extract_zip(zip_bytes: bytes, dest: str) -> Path: """Extract *zip_bytes* into *dest* and return rules/ dir.""" _safe_extract_zip(zip_bytes, dest) rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules" if not rules_dir.is_dir(): raise FileNotFoundError( f"Expected rules directory not found at {rules_dir}" ) return rules_dir def _parse_toml_safe(path: Path) -> dict | None: """Parse a TOML file. Uses the ``toml`` library.""" try: import toml with open(path, "r", encoding="utf-8") as fh: return toml.load(fh) except Exception as exc: logger.debug("Failed to parse %s: %s", path, exc) return None def _extract_mitre_techniques(threat_list: list) -> list[str]: """Extract MITRE technique IDs from Elastic's ``rule.threat`` array. Each entry looks like:: [[rule.threat]] framework = "MITRE ATT&CK" [rule.threat.tactic] name = "Credential Access" id = "TA0006" [[rule.threat.technique]] name = "OS Credential Dumping" id = "T1003" [[rule.threat.technique.subtechnique]] name = "LSASS Memory" id = "T1003.001" """ technique_ids = [] if not isinstance(threat_list, list): return technique_ids for threat_entry in threat_list: if not isinstance(threat_entry, dict): continue # Skip non-MITRE frameworks framework = threat_entry.get("framework", "") if "MITRE" not in str(framework).upper(): continue techniques = threat_entry.get("technique", []) if not isinstance(techniques, list): continue for tech in techniques: if not isinstance(tech, dict): continue tech_id = tech.get("id", "") if tech_id and str(tech_id).upper().startswith("T"): technique_ids.append(str(tech_id).upper()) # Check subtechniques subtechniques = tech.get("subtechnique", []) if isinstance(subtechniques, list): for subtech in subtechniques: if isinstance(subtech, dict): sub_id = subtech.get("id", "") if sub_id and str(sub_id).upper().startswith("T"): technique_ids.append(str(sub_id).upper()) return list(set(technique_ids)) def _parse_elastic_rules(rules_dir: Path) -> list[dict]: """Walk the rules directory and parse all TOML files. Returns a flat list of dicts, one per (rule, technique) combination. """ results: list[dict] = [] toml_files = sorted(rules_dir.rglob("*.toml")) logger.info("Found %d TOML files to parse", len(toml_files)) for toml_path in toml_files: data = _parse_toml_safe(toml_path) if not data: continue rule = data.get("rule", {}) if not isinstance(rule, dict): continue name = rule.get("name", "").strip() if not name: continue # Extract MITRE technique IDs threat_list = rule.get("threat", []) technique_ids = _extract_mitre_techniques(threat_list) if not technique_ids: continue description = rule.get("description", "") query = rule.get("query", "") severity = _SEVERITY_MAP.get(str(rule.get("severity", "")).lower()) rule_type = rule.get("type", "query") # query, eql, threshold, etc. # Determine rule format based on type if rule_type == "eql": rule_format = "eql" elif rule_type == "esql": rule_format = "esql" else: rule_format = "kql" # Use filename as source_id source_id = toml_path.name # Read raw content try: with open(toml_path, "r", encoding="utf-8") as fh: raw_content = fh.read() except Exception: raw_content = query or str(data) # Build source URL relative = str(toml_path.relative_to(rules_dir.parent)).replace("\\", "/") source_url = ( f"https://github.com/elastic/detection-rules/blob/main/{relative}" ) # One entry per technique for tech_id in technique_ids: results.append({ "mitre_technique_id": tech_id, "title": name[:500], "description": str(description)[:2000] if description else None, "source_id": source_id, "source_url": source_url, "rule_content": query[:50000] if query else raw_content[:50000], "rule_format": rule_format, "severity": severity, "platforms": _infer_platforms(rules_dir, toml_path), }) logger.info("Parsed %d (rule, technique) pairs total", len(results)) return results def _infer_platforms(rules_dir: Path, toml_path: Path) -> list[str] | None: """Infer platforms from the rule's directory structure. Elastic organizes rules by OS: rules/windows/, rules/linux/, etc. """ relative = toml_path.relative_to(rules_dir) parts = [p.lower() for p in relative.parts] platforms = [] if "windows" in parts: platforms.append("windows") if "linux" in parts: platforms.append("linux") if "macos" in parts: platforms.append("macos") return platforms if platforms else None # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- def sync(db: Session) -> dict: """Download and import Elastic detection rules. Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``. """ tmp_dir = tempfile.mkdtemp(prefix="aegis_elastic_") try: zip_bytes = _download_zip() rules_dir = _extract_zip(zip_bytes, tmp_dir) parsed_rules = _parse_elastic_rules(rules_dir) finally: shutil.rmtree(tmp_dir, ignore_errors=True) logger.info("Cleaned up temp directory %s", tmp_dir) # Pre-load existing source_ids for dedup existing_ids: set[str] = { row[0] for row in db.query(DetectionRule.source_id) .filter(DetectionRule.source == "elastic") .filter(DetectionRule.source_id.isnot(None)) .all() } created = 0 skipped = 0 for item in parsed_rules: if item["source_id"] in existing_ids: skipped += 1 continue rule = DetectionRule( mitre_technique_id=item["mitre_technique_id"], title=item["title"], description=item["description"], source="elastic", source_id=item["source_id"], source_url=item["source_url"], rule_content=item["rule_content"], rule_format=item["rule_format"], severity=item["severity"], platforms=item["platforms"], is_active=True, ) db.add(rule) existing_ids.add(item["source_id"]) created += 1 db.commit() summary = { "created": created, "skipped_existing": skipped, "total_parsed": len(parsed_rules), } # Update DataSource record ds = db.query(DataSource).filter(DataSource.name == "elastic_rules").first() if ds: ds.last_sync_at = datetime.utcnow() ds.last_sync_status = "success" ds.last_sync_stats = summary db.commit() logger.info("Elastic import complete — %s", summary) log_action(db, user_id=None, action="import_elastic_rules", entity_type="detection_rule", entity_id=None, details=summary) db.commit() return summary