"""Elastic Detection Rules import service.

Downloads the Elastic detection-rules repository ZIP from GitHub, parses
every ``.toml`` rule file under ``rules/``, extracts MITRE ATT&CK mappings,
and creates :class:`DetectionRule` records in the database.

Strategy
--------
1. Download the full repo as a ZIP archive.
2. Extract into a temporary directory.
3. Walk all ``.toml`` files under ``rules/``.
4. Parse each TOML file — extract rule name, description, query
   (KQL / EQL / ES|QL), severity, and MITRE ATT&CK threat mappings.
5. Create / skip ``DetectionRule`` rows keyed by
   ``(source, source_id, mitre_technique_id)``.
6. Clean up.

Idempotency
-----------
Running the import twice does **not** create duplicates. A rule file that
maps to several techniques yields one row per technique; existing rows are
identified by ``source = "elastic"`` + ``source_id`` (the TOML filename) +
``mitre_technique_id``.
"""

import io
import logging
import shutil
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path

import requests as _requests
from sqlalchemy.orm import Session

from app.models.data_source import DataSource
from app.models.detection_rule import DetectionRule
from app.models.technique import Technique
from app.services.audit_service import log_action

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

ELASTIC_ZIP_URL = (
    "https://github.com/elastic/detection-rules"
    "/archive/refs/heads/main.zip"
)
_DOWNLOAD_TIMEOUT = 300  # passed to requests as ``timeout`` (seconds)
_DOWNLOAD_CHUNK_SIZE = 1024 * 1024  # 1 MB per streamed read
# Top-level directory inside the ``main`` branch archive.
_ZIP_ROOT_PREFIX = "detection-rules-main"

# Safety limits for ZIP extraction — prevent zip-bomb DoS
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024  # 500 MB
_MAX_ENTRIES = 50_000
# Cap on the downloaded archive itself, so a misbehaving server cannot make us
# buffer an unbounded body in memory. An archive this large would almost
# certainly exceed the uncompressed limit anyway, so reuse that value.
_MAX_DOWNLOAD_SIZE = _MAX_UNCOMPRESSED_SIZE

# Truncation lengths applied to stored fields.
_MAX_TITLE_LEN = 500
_MAX_DESCRIPTION_LEN = 2000
_MAX_RULE_CONTENT_LEN = 50000

# Severity normalisation: only these values are stored; anything else -> None.
_SEVERITY_MAP = {
    "informational": "informational",
    "low": "low",
    "medium": "medium",
    "high": "high",
    "critical": "critical",
}

# Elastic ``rule.type`` values that have a dedicated rule format. Every other
# type (query, threshold, etc.) is stored as KQL.
_RULE_FORMAT_BY_TYPE = {
    "eql": "eql",
    "esql": "esql",
}
_DEFAULT_RULE_FORMAT = "kql"

_PLATFORM_DIRS = ("windows", "linux", "macos")

# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _download_zip(url: str = ELASTIC_ZIP_URL) -> bytes:
    """Download the Elastic Detection Rules ZIP and return raw bytes.

    The body is streamed in chunks and the download is aborted as soon as it
    exceeds ``_MAX_DOWNLOAD_SIZE``.

    Raises:
        requests.HTTPError: if the server returns an error status.
        ValueError: if the archive is larger than ``_MAX_DOWNLOAD_SIZE``.
    """
    logger.info("Downloading Elastic Detection Rules ZIP from %s …", url)
    chunks: list[bytes] = []
    received = 0
    # ``with`` releases the connection even when we abort mid-stream.
    with _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
            received += len(chunk)
            if received > _MAX_DOWNLOAD_SIZE:
                raise ValueError(
                    f"ZIP download exceeds limit of "
                    f"{_MAX_DOWNLOAD_SIZE / (1024 * 1024):.0f} MB"
                )
            chunks.append(chunk)
    content = b"".join(chunks)
    logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
    return content


def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
    """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.

    Raises :class:`ValueError` if any member tries to escape the target
    directory (path traversal / Zip Slip) or if the archive exceeds the
    safety limits (``_MAX_ENTRIES`` members or ``_MAX_UNCOMPRESSED_SIZE``
    total declared uncompressed bytes).
    """
    dest_path = Path(dest).resolve()
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        entries = zf.infolist()
        if len(entries) > _MAX_ENTRIES:
            raise ValueError(
                f"ZIP archive contains {len(entries)} entries "
                f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
            )

        total_size = sum(info.file_size for info in entries)
        if total_size > _MAX_UNCOMPRESSED_SIZE:
            raise ValueError(
                f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
                f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
            )

        # Validate every member's resolved path before writing it, so a
        # crafted name such as "../../etc/x" can never land outside *dest*.
        for member in entries:
            target = (dest_path / member.filename).resolve()
            if not target.is_relative_to(dest_path):
                raise ValueError(
                    f"Zip Slip detected — member '{member.filename}' "
                    f"resolves outside target directory"
                )
            zf.extract(member, dest)


def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
    """Extract *zip_bytes* into *dest* and return the ``rules/`` directory.

    Raises:
        ValueError: propagated from :func:`_safe_extract_zip`.
        FileNotFoundError: if the archive has no
            ``<_ZIP_ROOT_PREFIX>/rules`` directory.
    """
    _safe_extract_zip(zip_bytes, dest)
    rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
    if not rules_dir.is_dir():
        raise FileNotFoundError(
            f"Expected rules directory not found at {rules_dir}"
        )
    return rules_dir


def _get_toml_loads():
    """Return a ``loads(text) -> dict`` TOML parser.

    Prefers the third-party ``toml`` package (the parser this module has
    always used) and falls back to the standard-library ``tomllib``
    (Python 3.11+).

    Raises:
        ImportError: if neither parser is importable. This is raised rather
            than swallowed so a missing dependency fails the sync loudly
            instead of making every rule file look unparseable.
    """
    try:
        import toml
        return toml.loads
    except ImportError:
        pass
    try:
        import tomllib
        return tomllib.loads
    except ImportError as exc:
        raise ImportError(
            "No TOML parser available: install 'toml' or use Python 3.11+"
        ) from exc


def _parse_toml_safe(path: Path) -> dict | None:
    """Parse the TOML file at *path*, returning ``None`` if it cannot be read
    or parsed (the failure is logged at DEBUG level).

    Raises:
        ImportError: if no TOML parser is installed (see :func:`_get_toml_loads`).
    """
    # Resolved outside the try so a missing parser is not mistaken for a
    # per-file parse failure.
    loads = _get_toml_loads()
    try:
        with open(path, "r", encoding="utf-8") as fh:
            return loads(fh.read())
    # Deliberately broad: one malformed rule file must not abort the import.
    except Exception as exc:
        logger.debug("Failed to parse %s: %s", path, exc)
        return None


def _normalise_technique_id(raw: object) -> str | None:
    """Return *raw* upper-cased if it looks like a technique ID (``T…``), else ``None``."""
    if not raw:
        return None
    technique_id = str(raw).upper()
    return technique_id if technique_id.startswith("T") else None


def _extract_mitre_techniques(threat_list: list) -> list[str]:
    """Extract MITRE technique IDs from Elastic's ``rule.threat`` array.

    Each entry looks like::

        [[rule.threat]]
        framework = "MITRE ATT&CK"
        [rule.threat.tactic]
        name = "Credential Access"
        id = "TA0006"
        [[rule.threat.technique]]
        name = "OS Credential Dumping"
        id = "T1003"
        [[rule.threat.technique.subtechnique]]
        name = "LSASS Memory"
        id = "T1003.001"

    Both technique and sub-technique IDs are collected. Entries whose
    ``framework`` does not mention MITRE, and malformed entries, are ignored.

    Returns:
        De-duplicated, upper-cased IDs in sorted order. The order is
        deterministic, so repeated runs insert rows in the same order.
    """
    found: set[str] = set()
    if not isinstance(threat_list, list):
        return []

    for threat_entry in threat_list:
        if not isinstance(threat_entry, dict):
            continue
        # Skip non-MITRE frameworks
        if "MITRE" not in str(threat_entry.get("framework", "")).upper():
            continue
        techniques = threat_entry.get("technique", [])
        if not isinstance(techniques, list):
            continue

        for tech in techniques:
            if not isinstance(tech, dict):
                continue
            tech_id = _normalise_technique_id(tech.get("id", ""))
            if tech_id:
                found.add(tech_id)

            subtechniques = tech.get("subtechnique", [])
            if not isinstance(subtechniques, list):
                continue
            for subtech in subtechniques:
                if not isinstance(subtech, dict):
                    continue
                sub_id = _normalise_technique_id(subtech.get("id", ""))
                if sub_id:
                    found.add(sub_id)

    return sorted(found)


def _infer_platforms(rules_dir: Path, toml_path: Path) -> list[str] | None:
    """Infer platforms from the rule's directory structure.

    Elastic organizes rules by OS: rules/windows/, rules/linux/, etc.
    Returns the matching subset of ``["windows", "linux", "macos"]`` (in that
    order), or ``None`` when no path component matches.
    """
    parts = {part.lower() for part in toml_path.relative_to(rules_dir).parts}
    platforms = [name for name in _PLATFORM_DIRS if name in parts]
    return platforms or None


def _parse_elastic_rules(rules_dir: Path) -> list[dict]:
    """Walk *rules_dir* recursively and parse every ``*.toml`` rule file.

    Files are skipped when they fail to parse, lack a ``[rule]`` table or a
    ``name``, or have no MITRE ATT&CK technique mapping.

    Returns a flat list of dicts, one per (rule, technique) combination, with
    keys ``mitre_technique_id``, ``title``, ``description``, ``source_id``,
    ``source_url``, ``rule_content``, ``rule_format``, ``severity`` and
    ``platforms``.
    """
    results: list[dict] = []
    toml_files = sorted(rules_dir.rglob("*.toml"))
    logger.info("Found %d TOML files to parse", len(toml_files))
    unparseable = 0

    for toml_path in toml_files:
        data = _parse_toml_safe(toml_path)
        if not data:
            if data is None:
                unparseable += 1
            continue

        rule = data.get("rule", {})
        if not isinstance(rule, dict):
            continue

        # str() guards against a non-string name aborting the whole import.
        name = str(rule.get("name") or "").strip()
        if not name:
            continue

        technique_ids = _extract_mitre_techniques(rule.get("threat", []))
        if not technique_ids:
            continue

        description = rule.get("description", "")
        query = str(rule.get("query") or "")
        severity = _SEVERITY_MAP.get(str(rule.get("severity", "")).lower())
        rule_type = str(rule.get("type", "query"))  # query, eql, threshold, etc.
        rule_format = _RULE_FORMAT_BY_TYPE.get(rule_type, _DEFAULT_RULE_FORMAT)

        if query:
            rule_content = query[:_MAX_RULE_CONTENT_LEN]
        else:
            # No query (e.g. some non-query rule types): store the raw TOML
            # instead, falling back to the parsed dict if the re-read fails.
            try:
                with open(toml_path, "r", encoding="utf-8") as fh:
                    raw_content = fh.read()
            except (OSError, UnicodeDecodeError):
                raw_content = str(data)
            rule_content = raw_content[:_MAX_RULE_CONTENT_LEN]

        relative = str(toml_path.relative_to(rules_dir.parent)).replace("\\", "/")
        source_url = (
            f"https://github.com/elastic/detection-rules/blob/main/{relative}"
        )
        # Depends only on the file path, so compute once per file.
        platforms = _infer_platforms(rules_dir, toml_path)

        # One entry per technique
        for tech_id in technique_ids:
            results.append({
                "mitre_technique_id": tech_id,
                "title": name[:_MAX_TITLE_LEN],
                "description": (
                    str(description)[:_MAX_DESCRIPTION_LEN] if description else None
                ),
                # The filename is the stable identifier used for dedup.
                "source_id": toml_path.name,
                "source_url": source_url,
                "rule_content": rule_content,
                "rule_format": rule_format,
                "severity": severity,
                # Fresh list per row so ORM objects never share a mutable value.
                "platforms": list(platforms) if platforms else None,
            })

    if unparseable:
        logger.warning("Skipped %d TOML files that could not be parsed", unparseable)
    logger.info("Parsed %d (rule, technique) pairs total", len(results))
    return results


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def sync(db: Session) -> dict:
    """Download and import Elastic detection rules.

    The temporary extraction directory is always removed, even if the
    download or parse step raises. For every parsed
    ``(source_id, mitre_technique_id)`` pair not already stored with
    ``source == "elastic"``, a ``DetectionRule`` is inserted. Techniques that
    received new rules get ``review_required = True``. The ``elastic_rules``
    DataSource row, if present, is updated, and an audit entry is written.

    Returns a summary dict with ``created``, ``skipped_existing``,
    ``total_parsed``.
    """
    tmp_dir = tempfile.mkdtemp(prefix="aegis_elastic_")
    try:
        zip_bytes = _download_zip()
        rules_dir = _extract_zip(zip_bytes, tmp_dir)
        del zip_bytes  # release the archive buffer before parsing
        parsed_rules = _parse_elastic_rules(rules_dir)
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        logger.info("Cleaned up temp directory %s", tmp_dir)

    # Dedup on (source_id, technique). A rule file that maps to several
    # techniques yields one item per technique, so keying on source_id alone
    # would skip every technique after the first.
    # NOTE(review): assumes DetectionRule has no unique constraint on
    # (source, source_id) alone — confirm against the model.
    existing_keys: set[tuple[str, str]] = {
        (source_id, technique_id)
        for source_id, technique_id in db.query(
            DetectionRule.source_id, DetectionRule.mitre_technique_id
        )
        .filter(DetectionRule.source == "elastic")
        .filter(DetectionRule.source_id.isnot(None))
        .all()
    }

    created = 0
    skipped = 0
    new_technique_ids: set[str] = set()

    for item in parsed_rules:
        key = (item["source_id"], item["mitre_technique_id"])
        if key in existing_keys:
            skipped += 1
            continue

        db.add(DetectionRule(
            mitre_technique_id=item["mitre_technique_id"],
            title=item["title"],
            description=item["description"],
            source="elastic",
            source_id=item["source_id"],
            source_url=item["source_url"],
            rule_content=item["rule_content"],
            rule_format=item["rule_format"],
            severity=item["severity"],
            platforms=item["platforms"],
            is_active=True,
        ))
        # Track in-run inserts too, so duplicates within one download are skipped.
        existing_keys.add(key)
        new_technique_ids.add(item["mitre_technique_id"])
        created += 1

    # Flag techniques that received new rules for review
    if new_technique_ids:
        db.query(Technique).filter(
            Technique.mitre_id.in_(new_technique_ids)
        ).update({"review_required": True}, synchronize_session=False)

    db.commit()

    summary = {
        "created": created,
        "skipped_existing": skipped,
        "total_parsed": len(parsed_rules),
    }

    # Update DataSource record
    ds = db.query(DataSource).filter(DataSource.name == "elastic_rules").first()
    if ds:
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12; kept
        # naive because the column's timezone handling is not visible here.
        ds.last_sync_at = datetime.utcnow()
        ds.last_sync_status = "success"
        ds.last_sync_stats = summary
        db.commit()

    logger.info("Elastic import complete — %s", summary)

    log_action(db, user_id=None, action="import_elastic_rules",
               entity_type="detection_rule", entity_id=None, details=summary)
    db.commit()

    return summary