Files
Aegis/backend/app/services/elastic_import_service.py

322 lines
9.9 KiB
Python

"""Elastic Detection Rules import service.

Downloads the Elastic detection-rules repository ZIP from GitHub, parses
every ``.toml`` rule file under ``rules/``, extracts MITRE ATT&CK
mappings, and creates :class:`DetectionRule` records in the database.

Strategy
--------
1. Download the full repository as a ZIP archive.
2. Extract it into a temporary directory.
3. Walk all ``.toml`` files under ``rules/``.
4. Parse each TOML file and extract the rule name, description, query
   (KQL, EQL or ES|QL), severity, and MITRE ATT&CK threat mappings.
5. Create or skip ``DetectionRule`` rows keyed by ``(source, source_id)``.
6. Remove the temporary directory.

Idempotency
-----------
Running the import twice does **not** create duplicates. Existing
rules are identified by ``source = "elastic"`` plus ``source_id`` (the
TOML filename).
"""
import io
import logging
import shutil
import tempfile
import zipfile
from datetime import datetime, timezone
from pathlib import Path

import requests as _requests
from sqlalchemy.orm import Session

from app.models.data_source import DataSource
from app.models.detection_rule import DetectionRule
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Archive of the ``main`` branch of the elastic/detection-rules repository.
ELASTIC_ZIP_URL = (
    "https://github.com/elastic/detection-rules/archive/refs/heads/main.zip"
)

# Timeout value handed to ``requests.get`` when downloading the archive.
_DOWNLOAD_TIMEOUT = 300

# Top-level directory expected inside the ``main`` branch archive.
_ZIP_ROOT_PREFIX = "detection-rules-main"

# Accepted (lower-cased) severity labels; lookups of any other value
# via ``.get`` yield ``None``.
_SEVERITY_MAP = {
    level: level
    for level in ("informational", "low", "medium", "high", "critical")
}
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _download_zip(url: str = ELASTIC_ZIP_URL) -> bytes:
    """Download the ZIP archive at *url* and return its raw bytes.

    Args:
        url: Archive URL; defaults to the ``main`` branch archive of
            elastic/detection-rules.

    Returns:
        The complete response body.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or when the
            ``_DOWNLOAD_TIMEOUT`` request timeout is exceeded.
    """
    logger.info("Downloading Elastic Detection Rules ZIP from %s", url)
    # With stream=True the connection is only released once the body is
    # consumed or the response is closed; the ``with`` block guarantees the
    # close even when raise_for_status() raises.
    with _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) as resp:
        resp.raise_for_status()
        content = resp.content
    logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
    return content
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
    """Extract *zip_bytes* into *dest* and return the repository's ``rules/`` dir.

    GitHub wraps an archive's contents in one top-level directory named
    after the repository and ref (``detection-rules-main`` for the
    ``main`` branch). That directory name is read from the archive itself,
    so archives of other refs also work. When the archive does not have
    exactly one top-level entry, ``_ZIP_ROOT_PREFIX`` is assumed.

    Args:
        zip_bytes: Raw ZIP archive content.
        dest: Directory to extract into.

    Returns:
        Path to ``<dest>/<archive root>/rules``.

    Raises:
        zipfile.BadZipFile: if *zip_bytes* is not a ZIP archive.
        FileNotFoundError: if the ``rules/`` directory is not present.
    """
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        # ZIP member names always use "/" as the separator.
        top_level = {name.split("/", 1)[0] for name in zf.namelist()} - {""}
        zf.extractall(dest)

    root = next(iter(top_level)) if len(top_level) == 1 else _ZIP_ROOT_PREFIX
    rules_dir = Path(dest) / root / "rules"
    if not rules_dir.is_dir():
        raise FileNotFoundError(
            f"Expected rules directory not found at {rules_dir}"
        )
    return rules_dir
def _parse_toml_safe(path: Path) -> dict | None:
"""Parse a TOML file. Uses the ``toml`` library."""
try:
import toml
with open(path, "r", encoding="utf-8") as fh:
return toml.load(fh)
except Exception as exc:
logger.debug("Failed to parse %s: %s", path, exc)
return None
def _extract_mitre_techniques(threat_list: list) -> list[str]:
    """Extract MITRE technique IDs from Elastic's ``rule.threat`` array.

    Each entry looks like::

        [[rule.threat]]
        framework = "MITRE ATT&CK"
        [rule.threat.tactic]
        name = "Credential Access"
        id = "TA0006"
        [[rule.threat.technique]]
        name = "OS Credential Dumping"
        id = "T1003"
        [[rule.threat.technique.subtechnique]]
        name = "LSASS Memory"
        id = "T1003.001"

    Only entries whose ``framework`` contains "MITRE" (case-insensitive)
    are considered. Technique and sub-technique IDs are stripped and
    upper-cased, and kept only if they start with ``T``. Malformed
    (non-list / non-dict) nodes are ignored.

    Returns:
        The unique IDs in sorted order. Sorting makes the order
        deterministic across processes; iteration order of a set of
        strings is not.
    """
    if not isinstance(threat_list, list):
        return []

    technique_ids: set[str] = set()

    def _collect(raw_id) -> None:
        # Falsy IDs (missing, empty) are skipped like any non-"T" value.
        normalized = str(raw_id).strip().upper() if raw_id else ""
        if normalized.startswith("T"):
            technique_ids.add(normalized)

    for threat_entry in threat_list:
        if not isinstance(threat_entry, dict):
            continue
        # Skip non-MITRE frameworks.
        if "MITRE" not in str(threat_entry.get("framework", "")).upper():
            continue
        techniques = threat_entry.get("technique", [])
        if not isinstance(techniques, list):
            continue
        for tech in techniques:
            if not isinstance(tech, dict):
                continue
            _collect(tech.get("id", ""))
            subtechniques = tech.get("subtechnique", [])
            if isinstance(subtechniques, list):
                for subtech in subtechniques:
                    if isinstance(subtech, dict):
                        _collect(subtech.get("id", ""))

    return sorted(technique_ids)
def _parse_elastic_rules(rules_dir: Path) -> list[dict]:
    """Walk *rules_dir* recursively and parse every ``.toml`` rule file.

    A file contributes nothing when it cannot be parsed, has no ``[rule]``
    table, has an empty ``rule.name``, or maps to no MITRE ATT&CK
    technique.

    Args:
        rules_dir: The extracted ``rules/`` directory of the repository.

    Returns:
        A flat list of dicts, one per (rule, technique) combination, with
        the keys ``mitre_technique_id``, ``title``, ``description``,
        ``source_id``, ``source_url``, ``rule_content``, ``rule_format``,
        ``severity`` and ``platforms``. Entries derived from the same file
        share every value except ``mitre_technique_id``.
    """
    results: list[dict] = []
    toml_files = sorted(rules_dir.rglob("*.toml"))
    logger.info("Found %d TOML files to parse", len(toml_files))

    for toml_path in toml_files:
        data = _parse_toml_safe(toml_path)
        if not data:
            continue
        rule = data.get("rule", {})
        if not isinstance(rule, dict):
            continue

        # Coerce before stripping: a non-string ``name`` must skip or keep
        # this file, not raise AttributeError and abort the whole import.
        name = str(rule.get("name") or "").strip()
        if not name:
            continue

        technique_ids = _extract_mitre_techniques(rule.get("threat", []))
        if not technique_ids:
            continue

        description = rule.get("description", "")
        query = rule.get("query", "")
        severity = _SEVERITY_MAP.get(str(rule.get("severity", "")).lower())
        # ``type`` values other than eql / esql (e.g. query, threshold) are
        # recorded as KQL.
        rule_type = rule.get("type", "query")
        rule_format = rule_type if rule_type in ("eql", "esql") else "kql"

        # The filename is the identifier sync() deduplicates on.
        source_id = toml_path.name

        # Store the query; only read the raw TOML when there is no query.
        if query:
            rule_content = query[:50000]
        else:
            try:
                with open(toml_path, "r", encoding="utf-8") as fh:
                    raw_content = fh.read()
            except (OSError, UnicodeError):
                raw_content = str(data)
            rule_content = raw_content[:50000]

        relative = toml_path.relative_to(rules_dir.parent).as_posix()
        source_url = (
            f"https://github.com/elastic/detection-rules/blob/main/{relative}"
        )
        # Same for every technique of this file, so compute it once.
        platforms = _infer_platforms(rules_dir, toml_path)

        for tech_id in technique_ids:
            results.append({
                "mitre_technique_id": tech_id,
                "title": name[:500],
                "description": str(description)[:2000] if description else None,
                "source_id": source_id,
                "source_url": source_url,
                "rule_content": rule_content,
                "rule_format": rule_format,
                "severity": severity,
                # Fresh list per entry so no two rows share a mutable object.
                "platforms": list(platforms) if platforms else None,
            })

    logger.info("Parsed %d (rule, technique) pairs total", len(results))
    return results
def _infer_platforms(rules_dir: Path, toml_path: Path) -> list[str] | None:
"""Infer platforms from the rule's directory structure.
Elastic organizes rules by OS: rules/windows/, rules/linux/, etc.
"""
relative = toml_path.relative_to(rules_dir)
parts = [p.lower() for p in relative.parts]
platforms = []
if "windows" in parts:
platforms.append("windows")
if "linux" in parts:
platforms.append("linux")
if "macos" in parts:
platforms.append("macos")
return platforms if platforms else None
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def sync(db: Session) -> dict:
    """Download, parse and import the Elastic detection rules into *db*.

    Steps:
    1. Download the repository ZIP.
    2. Extract it to a temporary directory, which is always removed afterwards.
    3. Parse the rule files.
    4. Insert a :class:`DetectionRule` for every parsed entry whose
       ``source_id`` is not already stored with ``source == "elastic"``.
    5. If a ``DataSource`` named ``elastic_rules`` exists, record the sync
       time, status and statistics on it.

    Deduplication is keyed on ``source_id`` alone. Once a ``source_id`` has
    been inserted in this run, later entries with the same ``source_id``
    (the other techniques of a multi-technique rule file) are counted in
    ``skipped_existing`` and not inserted.
    NOTE(review): only one technique per rule file is therefore persisted;
    confirm this is intended rather than one row per (rule, technique).

    Args:
        db: SQLAlchemy session. This function commits it.

    Returns:
        ``{"created": int, "skipped_existing": int, "total_parsed": int}``.

    Raises:
        Any download, extraction, parsing or database error. On a database
        error the session is rolled back before the exception propagates.
    """
    tmp_dir = tempfile.mkdtemp(prefix="aegis_elastic_")
    try:
        zip_bytes = _download_zip()
        rules_dir = _extract_zip(zip_bytes, tmp_dir)
        parsed_rules = _parse_elastic_rules(rules_dir)
    finally:
        # Remove the extracted repository even if download/parsing failed.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        logger.info("Cleaned up temp directory %s", tmp_dir)

    try:
        # Pre-load existing source_ids for dedup.
        existing_ids: set[str] = {
            row[0]
            for row in db.query(DetectionRule.source_id)
            .filter(DetectionRule.source == "elastic")
            .filter(DetectionRule.source_id.isnot(None))
            .all()
        }

        created = 0
        skipped = 0
        for item in parsed_rules:
            if item["source_id"] in existing_ids:
                skipped += 1
                continue
            db.add(DetectionRule(
                mitre_technique_id=item["mitre_technique_id"],
                title=item["title"],
                description=item["description"],
                source="elastic",
                source_id=item["source_id"],
                source_url=item["source_url"],
                rule_content=item["rule_content"],
                rule_format=item["rule_format"],
                severity=item["severity"],
                platforms=item["platforms"],
                is_active=True,
            ))
            existing_ids.add(item["source_id"])
            created += 1
        db.commit()

        summary = {
            "created": created,
            "skipped_existing": skipped,
            "total_parsed": len(parsed_rules),
        }

        # Update DataSource record.
        ds = db.query(DataSource).filter(DataSource.name == "elastic_rules").first()
        if ds:
            # Naive UTC timestamp, the same value datetime.utcnow() returns
            # (utcnow() is deprecated since Python 3.12).
            ds.last_sync_at = datetime.now(timezone.utc).replace(tzinfo=None)
            ds.last_sync_status = "success"
            ds.last_sync_stats = summary
            db.commit()
    except Exception:
        # Leave the caller's session usable after a failed insert/commit.
        db.rollback()
        raise

    logger.info("Elastic import complete — %s", summary)
    log_action(db, user_id=None, action="import_elastic_rules",
               entity_type="detection_rule", entity_id=None, details=summary)
    return summary