feat(refactor): PEP8, type annotations, docstrings and PyJWT security fix
This commit is contained in:
@@ -22,22 +22,39 @@ Running the import twice does **not** create duplicates. Existing
|
||||
templates are identified by their ``atomic_test_id`` and simply skipped.
|
||||
"""
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Import shutil
|
||||
import shutil
|
||||
|
||||
# Import tempfile
|
||||
import tempfile
|
||||
|
||||
# Import zipfile
|
||||
import zipfile
|
||||
|
||||
# Import Path from pathlib
|
||||
from pathlib import Path
|
||||
|
||||
# Import requests
|
||||
import requests as _requests
|
||||
|
||||
# Import yaml
|
||||
import yaml
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -45,7 +62,9 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ATOMIC_RT_ZIP_URL = (
|
||||
# Literal argument value
|
||||
"https://github.com/redcanaryco/atomic-red-team"
|
||||
# Literal argument value
|
||||
"/archive/refs/heads/master.zip"
|
||||
)
|
||||
|
||||
@@ -55,6 +74,11 @@ _DOWNLOAD_TIMEOUT = 300
|
||||
# Top-level directory name inside the ZIP
|
||||
_ZIP_ROOT_PREFIX = "atomic-red-team-master"
|
||||
|
||||
# Safety limits for ZIP extraction — prevent zip-bomb DoS
|
||||
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB
|
||||
# Assign _MAX_ENTRIES = 50_000
|
||||
_MAX_ENTRIES = 50_000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
@@ -63,14 +87,21 @@ _ZIP_ROOT_PREFIX = "atomic-red-team-master"
|
||||
|
||||
def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes:
|
||||
"""Download the Atomic Red Team ZIP and return its raw bytes."""
|
||||
# Log info: "Downloading Atomic Red Team ZIP from %s …", url
|
||||
logger.info("Downloading Atomic Red Team ZIP from %s …", url)
|
||||
# Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Assign content = resp.content
|
||||
content = resp.content
|
||||
# Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024
|
||||
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
|
||||
# Return content
|
||||
return content
|
||||
|
||||
|
||||
# Define function _safe_extract_zip
|
||||
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
"""Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.
|
||||
|
||||
@@ -78,51 +109,66 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
directory (path traversal / Zip Slip) or if the archive exceeds the
|
||||
safety limits.
|
||||
"""
|
||||
# Maximum uncompressed size: 500 MB — prevents zip-bomb DoS
|
||||
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024
|
||||
# Maximum number of entries
|
||||
_MAX_ENTRIES = 50_000
|
||||
|
||||
# Assign dest_path = Path(dest).resolve()
|
||||
dest_path = Path(dest).resolve()
|
||||
|
||||
# Open context manager
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
# Assign entries = zf.infolist()
|
||||
entries = zf.infolist()
|
||||
|
||||
# Check: len(entries) > _MAX_ENTRIES
|
||||
if len(entries) > _MAX_ENTRIES:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"ZIP archive contains {len(entries)} entries "
|
||||
f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
|
||||
)
|
||||
|
||||
# Assign total_size = sum(info.file_size for info in entries)
|
||||
total_size = sum(info.file_size for info in entries)
|
||||
# Check: total_size > _MAX_UNCOMPRESSED_SIZE
|
||||
if total_size > _MAX_UNCOMPRESSED_SIZE:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
|
||||
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
|
||||
)
|
||||
|
||||
# Iterate over entries
|
||||
for member in entries:
|
||||
# Assign target = (dest_path / member.filename).resolve()
|
||||
target = (dest_path / member.filename).resolve()
|
||||
# Check: not target.is_relative_to(dest_path)
|
||||
if not target.is_relative_to(dest_path):
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"Zip Slip detected — member '{member.filename}' "
|
||||
f"resolves outside target directory"
|
||||
)
|
||||
|
||||
# Call zf.extractall()
|
||||
zf.extractall(dest)
|
||||
|
||||
|
||||
# Define function _extract_zip
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return the path to the atomics/ dir."""
|
||||
# Call _safe_extract_zip()
|
||||
_safe_extract_zip(zip_bytes, dest)
|
||||
# Assign atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
|
||||
atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
|
||||
# Check: not atomics_dir.is_dir()
|
||||
if not atomics_dir.is_dir():
|
||||
# Raise FileNotFoundError
|
||||
raise FileNotFoundError(
|
||||
f"Expected atomics directory not found at {atomics_dir}"
|
||||
)
|
||||
# Return atomics_dir
|
||||
return atomics_dir
|
||||
|
||||
|
||||
# Define function _parse_yaml_files
|
||||
def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
|
||||
"""Walk the atomics directory and parse all technique YAML files.
|
||||
|
||||
@@ -132,51 +178,84 @@ def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
|
||||
technique_id, index, name, description, platforms,
|
||||
executor_type, command, source_url
|
||||
"""
|
||||
# Assign results = []
|
||||
results: list[dict] = []
|
||||
# Assign yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
|
||||
yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
|
||||
# Log info: "Found %d YAML files to parse", len(yaml_files
|
||||
logger.info("Found %d YAML files to parse", len(yaml_files))
|
||||
|
||||
# Iterate over yaml_files
|
||||
for yaml_path in yaml_files:
|
||||
# Assign technique_id = yaml_path.stem # e.g. "T1059.001"
|
||||
technique_id = yaml_path.stem # e.g. "T1059.001"
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Open context manager
|
||||
with open(yaml_path, "r", encoding="utf-8") as fh:
|
||||
# Assign data = yaml.safe_load(fh)
|
||||
data = yaml.safe_load(fh)
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log warning: "Failed to parse %s: %s", yaml_path, exc
|
||||
logger.warning("Failed to parse %s: %s", yaml_path, exc)
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Check: not data or "atomic_tests" not in data
|
||||
if not data or "atomic_tests" not in data:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Iterate over enumerate(data["atomic_tests"])
|
||||
for idx, test in enumerate(data["atomic_tests"]):
|
||||
# Assign name = test.get("name", "").strip()
|
||||
name = test.get("name", "").strip()
|
||||
# Assign description = test.get("description", "").strip()
|
||||
description = test.get("description", "").strip()
|
||||
# Assign platforms = test.get("supported_platforms", [])
|
||||
platforms = test.get("supported_platforms", [])
|
||||
# Assign executor = test.get("executor", {})
|
||||
executor = test.get("executor", {})
|
||||
# Assign executor_type = executor.get("name", "") if isinstance(executor, dict) else ""
|
||||
executor_type = executor.get("name", "") if isinstance(executor, dict) else ""
|
||||
# Assign command = executor.get("command", "") if isinstance(executor, dict) else ""
|
||||
command = executor.get("command", "") if isinstance(executor, dict) else ""
|
||||
|
||||
# Build an atomic_test_id in the format "T1059.001-0"
|
||||
atomic_test_id = f"{technique_id}-{idx}"
|
||||
|
||||
# Assign source_url = (
|
||||
source_url = (
|
||||
f"https://github.com/redcanaryco/atomic-red-team/blob/master"
|
||||
f"/atomics/{technique_id}/{technique_id}.yaml"
|
||||
)
|
||||
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"technique_id": technique_id,
|
||||
# Literal argument value
|
||||
"index": idx,
|
||||
# Literal argument value
|
||||
"atomic_test_id": atomic_test_id,
|
||||
# Literal argument value
|
||||
"name": name,
|
||||
# Literal argument value
|
||||
"description": description,
|
||||
# Literal argument value
|
||||
"platforms": ", ".join(platforms) if isinstance(platforms, list) else str(platforms),
|
||||
# Literal argument value
|
||||
"executor_type": executor_type,
|
||||
# Literal argument value
|
||||
"command": command[:4000] if command else None, # cap at 4k chars
|
||||
# Literal argument value
|
||||
"source_url": source_url,
|
||||
})
|
||||
|
||||
# Log info: "Parsed %d atomic tests total", len(results
|
||||
logger.info("Parsed %d atomic tests total", len(results))
|
||||
# Return results
|
||||
return results
|
||||
|
||||
|
||||
@@ -193,52 +272,80 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
db : Session
|
||||
Active SQLAlchemy database session.
|
||||
|
||||
Returns
|
||||
Returns:
|
||||
-------
|
||||
dict
|
||||
Summary with keys ``created``, ``skipped_existing``,
|
||||
``yaml_files_parsed``, ``total_tests_parsed``.
|
||||
"""
|
||||
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
|
||||
tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign zip_bytes = _download_zip()
|
||||
zip_bytes = _download_zip()
|
||||
# Assign atomics_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
atomics_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
# Assign parsed_tests = _parse_yaml_files(atomics_dir)
|
||||
parsed_tests = _parse_yaml_files(atomics_dir)
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Always clean up
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
# Log info: "Cleaned up temp directory %s", tmp_dir
|
||||
logger.info("Cleaned up temp directory %s", tmp_dir)
|
||||
|
||||
# Pre-load existing atomic_test_ids for dedup
|
||||
existing_ids: set[str] = {
|
||||
row[0]
|
||||
for row in db.query(TestTemplate.atomic_test_id)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.atomic_test_id.isnot(None))
|
||||
# Chain .all() call
|
||||
.all()
|
||||
}
|
||||
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
# Iterate over parsed_tests
|
||||
for item in parsed_tests:
|
||||
# Check: item["atomic_test_id"] in existing_ids
|
||||
if item["atomic_test_id"] in existing_ids:
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign template = TestTemplate(
|
||||
template = TestTemplate(
|
||||
# Keyword argument: mitre_technique_id
|
||||
mitre_technique_id=item["technique_id"],
|
||||
# Keyword argument: name
|
||||
name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}",
|
||||
# Keyword argument: description
|
||||
description=item["description"][:2000] if item["description"] else None,
|
||||
# Keyword argument: source
|
||||
source="atomic_red_team",
|
||||
# Keyword argument: source_url
|
||||
source_url=item["source_url"],
|
||||
# Keyword argument: attack_procedure
|
||||
attack_procedure=item["command"],
|
||||
# Keyword argument: platform
|
||||
platform=item["platforms"],
|
||||
# Keyword argument: tool_suggested
|
||||
tool_suggested=item["executor_type"] if item["executor_type"] else None,
|
||||
# Keyword argument: atomic_test_id
|
||||
atomic_test_id=item["atomic_test_id"],
|
||||
# Keyword argument: is_active
|
||||
is_active=True,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(template)
|
||||
# Call existing_ids.add()
|
||||
existing_ids.add(item["atomic_test_id"])
|
||||
new_technique_ids.add(item["technique_id"])
|
||||
created += 1
|
||||
@@ -253,15 +360,23 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
# Count distinct YAML files by technique_id
|
||||
yaml_files_count = len({t["technique_id"] for t in parsed_tests})
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"created": created,
|
||||
# Literal argument value
|
||||
"skipped_existing": skipped,
|
||||
# Literal argument value
|
||||
"yaml_files_parsed": yaml_files_count,
|
||||
# Literal argument value
|
||||
"total_tests_parsed": len(parsed_tests),
|
||||
}
|
||||
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Atomic Red Team import complete — created=%d, skipped=%d, "
|
||||
# Literal argument value
|
||||
"yaml_files=%d, total_tests=%d",
|
||||
created, skipped, yaml_files_count, len(parsed_tests),
|
||||
)
|
||||
@@ -269,12 +384,19 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
# Audit log (system action)
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=None,
|
||||
# Keyword argument: action
|
||||
action="import_atomic_red_team",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=None,
|
||||
# Keyword argument: details
|
||||
details=summary,
|
||||
)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
Reference in New Issue
Block a user