Files
Aegis/backend/tests/test_data_sources.py

428 lines
15 KiB
Python

"""Tests for data source import parsing — T-235.
Two levels:
- TestDataSourcesParsing: Unit tests using local fixtures (fast, no network)
- TestDataSourcesIntegration: Integration tests requiring network (pytest -m integration)
"""
import json
import os
import re
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import yaml
# Directory holding the sample rule/bundle files, located next to this module.
FIXTURES = Path(__file__).parent.joinpath("fixtures")
# ---------------------------------------------------------------------------
# Helpers — lightweight parsing functions extracted from import services
# for testable, isolated verification
# ---------------------------------------------------------------------------
def _parse_sigma_yaml(content: str) -> dict | None:
    """Parse a Sigma YAML rule and extract the fields relevant to the import.

    Args:
        content: Raw text of a single Sigma rule document.

    Returns:
        A dict with the keys ``title``, ``description``, ``mitre_ids``,
        ``severity``, ``platforms`` and ``false_positives``; or ``None`` when
        the document is not a mapping, has no title, or carries no
        ``attack.tNNNN[.NNN]`` tag.
    """
    data = yaml.safe_load(content)
    if not data or not isinstance(data, dict):
        return None

    title = data.get("title")

    # A key that is present but empty (``tags:``) loads as None, so the
    # ``.get`` default is not enough; anything that is not a collection of
    # tags is treated as "no tags".
    tags = data.get("tags")
    if not isinstance(tags, (list, tuple, set)):
        tags = []

    # Extract MITRE technique IDs from tags such as ``attack.t1059.001``.
    mitre_ids = []
    for tag in tags:
        if not isinstance(tag, str):
            # re.match would raise TypeError on non-string entries.
            continue
        match = re.match(r"attack\.(t\d{4}(?:\.\d{3})?)", tag, re.IGNORECASE)
        if match:
            mitre_ids.append(match.group(1).upper())

    if not title or not mitre_ids:
        return None

    # ``logsource:`` with no value also loads as None.
    logsource = data.get("logsource")
    product = logsource.get("product", "") if isinstance(logsource, dict) else ""
    platforms = [product] if product else []

    return {
        "title": title,
        "description": data.get("description"),
        "mitre_ids": mitre_ids,
        "severity": data.get("level", "medium"),
        "platforms": platforms,
        # Normalise an empty ``falsepositives:`` key to an empty list.
        "false_positives": data.get("falsepositives") or [],
    }
def _parse_lolbas_yaml(content: str) -> list[dict]:
    """Parse a LOLBAS YAML entry and extract one template per mapped command.

    Args:
        content: Raw text of a single LOLBAS entry document.

    Returns:
        One dict per command that has a ``MitreID``, with the keys ``name``,
        ``mitre_id``, ``command``, ``description`` and ``usecase``. Returns an
        empty list when the document is not a mapping or has no usable
        ``Commands`` list.
    """
    data = yaml.safe_load(content)
    if not data or not isinstance(data, dict):
        return []

    name = data.get("Name", "")

    # ``Commands:`` with no value loads as None; anything other than a list
    # cannot be iterated as command entries.
    commands = data.get("Commands")
    if not isinstance(commands, list):
        return []

    return [
        {
            "name": name,
            "mitre_id": cmd["MitreID"],
            "command": cmd.get("Command", ""),
            "description": cmd.get("Description", ""),
            "usecase": cmd.get("Usecase", ""),
        }
        for cmd in commands
        # Skip malformed (non-mapping) entries and commands without a
        # technique mapping.
        if isinstance(cmd, dict) and cmd.get("MitreID")
    ]
def _parse_caldera_yaml(content: str) -> list[dict]:
    """Parse CALDERA ability YAML and extract the abilities it contains.

    Each YAML document may be either a single ability mapping or a list of
    ability mappings; both layouts are flattened into one result list.

    Args:
        content: Raw YAML text, possibly containing several ``---`` documents.

    Returns:
        One dict per ability that declares ``technique.attack_id``, with the
        keys ``id``, ``name``, ``description``, ``attack_id``, ``tactic``,
        ``platforms`` (platform names) and ``commands`` (stripped command
        strings from every executor).
    """
    results = []
    for doc in yaml.safe_load_all(content):
        # Normalise "one mapping" and "list of mappings" to the same shape.
        entries = doc if isinstance(doc, list) else [doc]
        for data in entries:
            if not data or not isinstance(data, dict):
                continue

            # Keys that are present but empty load as None, hence the
            # isinstance checks instead of relying on ``.get`` defaults.
            technique = data.get("technique")
            attack_id = technique.get("attack_id") if isinstance(technique, dict) else None
            if not attack_id:
                continue

            platforms_dict = data.get("platforms")
            if not isinstance(platforms_dict, dict):
                platforms_dict = {}

            # Collect the command of every executor on every platform.
            commands = []
            for executors in platforms_dict.values():
                if not isinstance(executors, dict):
                    continue
                for exec_data in executors.values():
                    if not isinstance(exec_data, dict):
                        continue
                    command = exec_data.get("command")
                    if command and isinstance(command, str):
                        commands.append(command.strip())

            results.append({
                "id": data.get("id"),
                "name": data.get("name"),
                "description": data.get("description"),
                "attack_id": attack_id,
                "tactic": data.get("tactic"),
                "platforms": list(platforms_dict),
                "commands": commands,
            })
    return results
def _parse_elastic_toml(content: str) -> dict | None:
"""Parse an Elastic detection rule TOML and extract fields."""
try:
import toml
except ImportError:
toml = None
if toml is None:
# Fallback: parse manually enough for testing
return None
data = toml.loads(content)
rule = data.get("rule", {})
if not rule:
return None
name = rule.get("name")
threat_list = rule.get("threat", [])
mitre_ids = []
for threat_entry in threat_list:
framework = threat_entry.get("framework", "")
if "MITRE" not in framework:
continue
for tech in threat_entry.get("technique", []):
tech_id = tech.get("id")
if tech_id:
mitre_ids.append(tech_id)
for sub in tech.get("subtechnique", []):
sub_id = sub.get("id")
if sub_id:
mitre_ids.append(sub_id)
return {
"name": name,
"description": rule.get("description"),
"query": rule.get("query"),
"severity": rule.get("severity"),
"rule_type": rule.get("type"),
"mitre_ids": mitre_ids,
}
def _parse_stix_bundle(content: str) -> dict:
"""Parse a STIX 2.0 bundle and extract intrusion-sets and relationships."""
data = json.loads(content)
objects = data.get("objects", [])
intrusion_sets = []
relationships = []
attack_patterns = {}
for obj in objects:
obj_type = obj.get("type")
if obj_type == "intrusion-set":
refs = obj.get("external_references", [])
mitre_id = None
for ref in refs:
if ref.get("source_name") == "mitre-attack":
mitre_id = ref.get("external_id")
break
intrusion_sets.append({
"id": obj["id"],
"name": obj.get("name"),
"aliases": obj.get("aliases", []),
"description": obj.get("description"),
"mitre_id": mitre_id,
})
elif obj_type == "attack-pattern":
refs = obj.get("external_references", [])
for ref in refs:
if ref.get("source_name") == "mitre-attack":
attack_patterns[obj["id"]] = ref.get("external_id")
elif obj_type == "relationship":
if obj.get("relationship_type") == "uses":
relationships.append({
"source_ref": obj["source_ref"],
"target_ref": obj["target_ref"],
})
return {
"intrusion_sets": intrusion_sets,
"attack_patterns": attack_patterns,
"relationships": relationships,
}
def _parse_d3fend_api_response(data: dict) -> list[dict]:
"""Parse a mock D3FEND API response."""
results = []
def _walk(node: dict | list, depth: int = 0):
if isinstance(node, list):
for item in node:
_walk(item, depth)
elif isinstance(node, dict):
d3fend_id = node.get("@id", "")
label = node.get("rdfs:label", "")
if d3fend_id.startswith("d3f:") and label:
clean_id = d3fend_id.replace("d3f:", "")
if clean_id.startswith("D3-"):
definition = node.get("d3f:definition") or node.get("rdfs:comment", "")
results.append({
"d3fend_id": clean_id,
"name": label,
"description": definition,
})
# Recurse
for key, val in node.items():
if isinstance(val, (dict, list)):
_walk(val, depth + 1)
graph = data.get("@graph", data)
_walk(graph)
return results
# ═══════════════════════════════════════════════════════════════════════
# Unit tests — fast, no network
# ═══════════════════════════════════════════════════════════════════════
class TestDataSourcesParsing:
    """Unit tests — no network access, using sample YAML/TOML/JSON fixtures."""

    @staticmethod
    def _fixture_text(name: str) -> str:
        """Return the text of a fixture file.

        The encoding is pinned to UTF-8 so the tests do not depend on the
        platform's default locale encoding.
        """
        return (FIXTURES / name).read_text(encoding="utf-8")

    def test_sigma_yaml_parsing(self):
        """Parse a sample Sigma YAML rule and verify the extracted fields."""
        content = self._fixture_text("sample_sigma_rule.yml")
        result = _parse_sigma_yaml(content)
        assert result is not None
        assert result["title"] == "Windows PowerShell Execution Policy Bypass"
        assert "T1059.001" in result["mitre_ids"]
        assert "T1562.001" in result["mitre_ids"]
        assert result["severity"] == "high"
        assert "windows" in result["platforms"]
        assert len(result["false_positives"]) == 2

    def test_lolbas_yaml_parsing(self):
        """Parse a LOLBAS YAML entry and verify MitreID and command extraction."""
        content = self._fixture_text("sample_lolbas_entry.yml")
        results = _parse_lolbas_yaml(content)
        assert len(results) == 2
        assert results[0]["name"] == "Mshta.exe"
        assert results[0]["mitre_id"] == "T1218.005"
        assert "mshta.exe" in results[0]["command"]
        assert results[1]["mitre_id"] == "T1059.005"

    def test_caldera_yaml_parsing(self):
        """Parse a CALDERA ability YAML and verify its fields."""
        content = self._fixture_text("sample_caldera_ability.yml")
        results = _parse_caldera_yaml(content)
        assert len(results) == 2
        sys_info = results[0]
        assert sys_info["name"] == "Get System Info"
        assert sys_info["attack_id"] == "T1082"
        assert sys_info["tactic"] == "discovery"
        assert "windows" in sys_info["platforms"]
        assert "linux" in sys_info["platforms"]
        assert len(sys_info["commands"]) > 0
        net_conn = results[1]
        assert net_conn["attack_id"] == "T1049"
        assert net_conn["name"] == "List Network Connections"

    def test_elastic_toml_parsing(self):
        """Parse an Elastic TOML rule and verify query and threat-mapping extraction."""
        content = self._fixture_text("sample_elastic_rule.toml")
        try:
            import toml  # noqa: F401
        except ImportError:
            pytest.skip("toml package not installed")
        result = _parse_elastic_toml(content)
        assert result is not None
        assert result["name"] == "Scheduled Task Created via Schtasks"
        assert result["severity"] == "medium"
        assert result["rule_type"] == "eql"
        assert "T1053" in result["mitre_ids"]
        assert "T1053.005" in result["mitre_ids"]
        assert "schtasks.exe" in result["query"]

    def test_stix_threat_actor_parsing(self):
        """Parse a sample STIX bundle and verify intrusion sets and relationships."""
        content = self._fixture_text("sample_stix_bundle.json")
        result = _parse_stix_bundle(content)
        # Intrusion sets
        assert len(result["intrusion_sets"]) == 2
        apt1 = next(is_ for is_ in result["intrusion_sets"] if is_["name"] == "APT1")
        assert apt1["mitre_id"] == "G0006"
        assert "Comment Crew" in apt1["aliases"]
        apt28 = next(is_ for is_ in result["intrusion_sets"] if is_["name"] == "APT28")
        assert apt28["mitre_id"] == "G0007"
        assert "Fancy Bear" in apt28["aliases"]
        # Attack patterns
        assert len(result["attack_patterns"]) == 3
        assert "T1566" in result["attack_patterns"].values()
        assert "T1059" in result["attack_patterns"].values()
        # Relationships
        assert len(result["relationships"]) == 4
        apt1_rels = [r for r in result["relationships"] if "apt1" in r["source_ref"]]
        assert len(apt1_rels) == 2

    def test_d3fend_api_response_parsing(self):
        """Parse a mock D3FEND API response."""
        mock_response = {
            "@graph": [
                {
                    "@id": "d3f:D3-AL",
                    "rdfs:label": "Application Layer",
                    "d3f:definition": "Monitoring at the application layer.",
                },
                {
                    "@id": "d3f:D3-NI",
                    "rdfs:label": "Network Isolation",
                    "rdfs:comment": "Isolating networks to prevent lateral movement.",
                },
                {
                    "@id": "d3f:NotATechnique",
                    "rdfs:label": "Something else",
                    "d3f:definition": "Not a D3FEND technique.",
                },
                {
                    "@id": "d3f:D3-DE",
                    "rdfs:label": "Decoy Environment",
                    "d3f:definition": "Using decoys to detect attackers.",
                },
            ]
        }
        results = _parse_d3fend_api_response(mock_response)
        assert len(results) == 3  # Only D3- prefixed IDs
        ids = [r["d3fend_id"] for r in results]
        assert "D3-AL" in ids
        assert "D3-NI" in ids
        assert "D3-DE" in ids
        ni = next(r for r in results if r["d3fend_id"] == "D3-NI")
        assert ni["name"] == "Network Isolation"
        assert "lateral movement" in ni["description"].lower()

    def test_no_duplicates_on_reimport(self):
        """Verify that the deduplication logic works with mock data."""
        content = self._fixture_text("sample_sigma_rule.yml")
        # Parse twice
        result1 = _parse_sigma_yaml(content)
        result2 = _parse_sigma_yaml(content)
        # Fail with a clear assertion (not a TypeError on subscripting None)
        # if the fixture stops parsing.
        assert result1 is not None
        # Same data should produce identical output
        assert result1 == result2
        assert result1["title"] == result2["title"]
        assert result1["mitre_ids"] == result2["mitre_ids"]
        # Simulate deduplication by title+mitre_id
        seen = set()
        unique_count = 0
        for r in [result1, result2]:
            key = (r["title"], tuple(r["mitre_ids"]))
            if key not in seen:
                seen.add(key)
                unique_count += 1
        assert unique_count == 1  # Only one unique entry
# ═══════════════════════════════════════════════════════════════════════
# Integration tests — require network. Run with: pytest -m integration
# ═══════════════════════════════════════════════════════════════════════
@pytest.mark.integration
class TestDataSourcesIntegration:
    """Integration tests — require network access. Run with: pytest -m integration

    NOTE(review): every test below calls ``pytest.skip`` unconditionally, so
    they are reported as skipped even when the ``integration`` marker is
    selected.
    """

    # Shared skip reason; ``{source}`` is replaced by the data source name.
    _SKIP_REASON = "Full {source} import requires network access — run with pytest -m integration"

    def test_sigma_full_import(self):
        """Import from the real GitHub repository and verify volume."""
        # This test would clone SigmaHQ and parse all rules; it needs network
        # access and significant time, so it is skipped.
        pytest.skip(self._SKIP_REASON.format(source="Sigma"))

    def test_lolbas_full_import(self):
        """Import the complete LOLBAS data set."""
        pytest.skip(self._SKIP_REASON.format(source="LOLBAS"))

    def test_caldera_full_import(self):
        """Import the complete CALDERA data set."""
        pytest.skip(self._SKIP_REASON.format(source="CALDERA"))

    def test_elastic_full_import(self):
        """Import the complete Elastic rule set."""
        pytest.skip(self._SKIP_REASON.format(source="Elastic"))