feat(refactor): PEP8, type annotations, docstrings and PyJWT security fix

This commit is contained in:
kitos
2026-06-11 11:09:41 +02:00
161 changed files with 15318 additions and 811 deletions
+1
View File
@@ -0,0 +1 @@
"""Service layer — business logic orchestrating domain entities and persistence."""
@@ -1,19 +1,31 @@
"""Advanced metrics service — coverage by tactic, never-tested, avg validation time, detection trend."""
# Enable future language features for compatibility
from __future__ import annotations
# Import datetime, timedelta from datetime
from datetime import datetime, timedelta
# Import case, func from sqlalchemy
from sqlalchemy import case, func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
from app.models.technique import Technique
from app.models.test import Test
# Import TestResult from app.models.enums
from app.models.enums import TestResult
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Define function get_coverage_by_tactic
def get_coverage_by_tactic(db: Session) -> list[dict]:
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
# Assign results = (
results = (
db.query(
Technique.tactic,
@@ -31,134 +43,211 @@ def get_coverage_by_tactic(db: Session) -> list[dict]:
case((Technique.status_global == "in_progress", 1), else_=0)
).label("in_progress"),
)
# Chain .group_by() call
.group_by(Technique.tactic)
# Chain .order_by() call
.order_by(Technique.tactic)
# Chain .all() call
.all()
)
# Return [
return [
{
# Literal argument value
"tactic": r[0] or "Unknown",
# Literal argument value
"total": r[1],
# Literal argument value
"validated": int(r[2]),
# Literal argument value
"partial": int(r[3]),
# Literal argument value
"not_covered": int(r[4]),
# Literal argument value
"in_progress": int(r[5]),
# Literal argument value
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
}
for r in results
]
def get_never_tested_techniques(db: Session) -> list[dict]:
    """List the techniques that no test references.

    The technique ids referenced by at least one test are collected first
    (NULL references are ignored), then every technique whose id is not in
    that collection is returned, ordered by MITRE id.

    Returns:
        One dict per technique with the keys ``mitre_id``, ``name``,
        ``tactic`` and ``is_subtechnique``.
    """
    referenced_rows = (
        db.query(Test.technique_id)
        .filter(Test.technique_id.isnot(None))
        .distinct()
        .all()
    )
    already_tested = [row[0] for row in referenced_rows]

    untested = db.query(Technique)
    # With no tests at all, every technique qualifies and no filter is applied.
    if already_tested:
        untested = untested.filter(~Technique.id.in_(already_tested))

    report = []
    for technique in untested.order_by(Technique.mitre_id).all():
        report.append(
            {
                "mitre_id": technique.mitre_id,
                "name": technique.name,
                "tactic": technique.tactic,
                "is_subtechnique": technique.is_subtechnique,
            }
        )
    return report
def get_avg_validation_time(db: Session) -> dict:
    """Summarise how long validated tests took, in hours.

    Only tests whose ``state`` equals ``"validated"`` are considered. Three
    averages are computed, each over the tests that carry both timestamps
    needed for that measurement:

    * total: ``created_at`` to ``red_validated_at``
    * red phase: ``red_started_at`` to ``blue_started_at``, minus
      ``red_paused_seconds``
    * blue phase: ``blue_started_at`` to ``blue_validated_at``, minus
      ``blue_paused_seconds``

    Returns:
        A dict with ``total_validated`` (number of validated tests) and
        ``avg_total_hours``, ``avg_red_phase_hours``, ``avg_blue_phase_hours``
        (each rounded to two decimals, ``0`` when no data is available).
    """

    def _mean_hours(seconds: list[float]) -> float:
        # An empty sample yields 0 rather than dividing by zero.
        return round(sum(seconds) / len(seconds) / 3600, 2) if seconds else 0

    rows = db.query(Test).filter(Test.state == "validated").all()

    overall: list[float] = []
    red_phase: list[float] = []
    blue_phase: list[float] = []

    # With no validated tests the loop is skipped and every average is 0.
    for row in rows:
        # NOTE(review): the "total" span ends at red_validated_at, not
        # blue_validated_at — confirm this is the intended end of validation.
        if row.created_at and row.red_validated_at:
            overall.append((row.red_validated_at - row.created_at).total_seconds())

        if row.red_started_at and row.blue_started_at:
            elapsed = (row.blue_started_at - row.red_started_at).total_seconds()
            # Paused time is excluded; the result is clamped so it never goes negative.
            red_phase.append(max(elapsed - (row.red_paused_seconds or 0), 0))

        if row.blue_started_at and row.blue_validated_at:
            elapsed = (row.blue_validated_at - row.blue_started_at).total_seconds()
            blue_phase.append(max(elapsed - (row.blue_paused_seconds or 0), 0))

    return {
        "total_validated": len(rows),
        "avg_total_hours": _mean_hours(overall),
        "avg_red_phase_hours": _mean_hours(red_phase),
        "avg_blue_phase_hours": _mean_hours(blue_phase),
    }
def get_detection_rate_trend(db: Session) -> list[dict]:
    """Monthly detection rate trend for the last 12 calendar months.

    Each entry covers one calendar month, from the first day of that month
    (inclusive) to the first day of the following month (exclusive), oldest
    month first and the current month last. Tests are assigned to a month by
    their ``created_at`` value.

    The windows are derived with calendar-month arithmetic. Stepping back in
    fixed 30-day increments does not land on month boundaries, so the
    ``%Y-%m`` label could skip or repeat a month and would not match the rows
    actually counted.

    Returns:
        Twelve dicts with the keys ``month`` (``YYYY-MM``), ``validated``,
        ``detected`` and ``detection_rate`` (percentage rounded to one
        decimal, ``0`` when the month has no validated test).
    """
    now = datetime.utcnow()
    # Number of months since year 0; lets offsets cross year boundaries
    # without any manual carry handling.
    current_index = now.year * 12 + (now.month - 1)

    def _first_of_month(index: int) -> datetime:
        year, month_zero_based = divmod(index, 12)
        return datetime(year, month_zero_based + 1, 1)

    months = []
    for offset in range(11, -1, -1):
        month_start = _first_of_month(current_index - offset)
        month_end = _first_of_month(current_index - offset + 1)

        validated = (
            db.query(func.count(Test.id))
            .filter(
                Test.state == "validated",
                Test.created_at >= month_start,
                Test.created_at < month_end,
            )
            .scalar() or 0
        )
        detected = (
            db.query(func.count(Test.id))
            .filter(
                Test.state == "validated",
                Test.detection_result == TestResult.detected,
                Test.created_at >= month_start,
                Test.created_at < month_end,
            )
            .scalar() or 0
        )

        months.append({
            "month": month_start.strftime("%Y-%m"),
            "validated": validated,
            "detected": detected,
            "detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
        })
    return months
+63
View File
@@ -1,28 +1,50 @@
"""Analytics service — flat JSON optimized for PowerBI / BI tools."""
# Enable future language features for compatibility
from __future__ import annotations
# Import func from sqlalchemy
from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import CoverageSnapshot from app.models.coverage_snapshot
from app.models.coverage_snapshot import CoverageSnapshot
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import User from app.models.user
from app.models.user import User
# Define function get_coverage_analytics
def get_coverage_analytics(db: Session) -> list[dict]:
"""Coverage per technique — flat format for BI dashboards."""
# Assign techniques = db.query(Technique).all()
techniques = db.query(Technique).all()
# Return [
return [
{
# Literal argument value
"mitre_id": t.mitre_id,
# Literal argument value
"name": t.name,
# Literal argument value
"tactic": t.tactic,
# Literal argument value
"status": t.status_global.value if t.status_global else "not_evaluated",
# Literal argument value
"is_subtechnique": t.is_subtechnique,
# Literal argument value
"test_count": len(t.tests) if t.tests else 0,
# Literal argument value
"review_required": t.review_required,
# Literal argument value
"last_review_date": (
t.last_review_date.isoformat() if t.last_review_date else None
),
@@ -31,76 +53,117 @@ def get_coverage_analytics(db: Session) -> list[dict]:
]
def get_tests_analytics(
    db: Session,
    *,
    date_from: str | None = None,
    date_to: str | None = None,
) -> list[dict]:
    """Return all tests as flat dicts for BI dashboards.

    Args:
        db: Active SQLAlchemy session.
        date_from: Optional inclusive lower bound applied to ``Test.created_at``.
        date_to: Optional inclusive upper bound applied to ``Test.created_at``.

    Returns:
        One dict per test. Enum-valued fields are exported through ``.value``
        and date fields as ISO-8601 strings; any missing value is exported as
        ``None``.
    """
    query = db.query(Test)
    # NOTE(review): the bounds are handed to the database as raw strings; a
    # date-only ``date_to`` presumably excludes rows created later that same
    # day — confirm the format callers send.
    if date_from:
        query = query.filter(Test.created_at >= date_from)
    if date_to:
        query = query.filter(Test.created_at <= date_to)
    tests = query.all()

    return [
        {
            "id": str(t.id),
            # A test without a technique exports None instead of the literal
            # string "None" that str(None) would produce.
            "technique_id": (
                str(t.technique_id) if t.technique_id is not None else None
            ),
            "name": t.name,
            "state": t.state.value if t.state else None,
            "result": t.result.value if t.result else None,
            "detection_result": (
                t.detection_result.value if t.detection_result else None
            ),
            "created_at": t.created_at.isoformat() if t.created_at else None,
            "execution_date": (
                t.execution_date.isoformat() if t.execution_date else None
            ),
            "platform": t.platform,
            "tool_used": t.tool_used,
            "attack_success": t.attack_success,
            "remediation_status": t.remediation_status,
        }
        for t in tests
    ]
def get_trends_analytics(db: Session) -> list[dict]:
    """Export every coverage snapshot, ordered by creation time, for trend charts.

    Returns:
        One dict per snapshot with the keys ``date`` (ISO-8601 string or
        ``None``), ``name``, ``total_techniques``, ``validated_count``,
        ``partial_count``, ``not_covered_count`` and ``organization_score``.
    """
    ordered = db.query(CoverageSnapshot).order_by(CoverageSnapshot.created_at)

    timeline = []
    for snapshot in ordered.all():
        taken_at = snapshot.created_at
        timeline.append(
            {
                "date": taken_at.isoformat() if taken_at else None,
                "name": snapshot.name,
                "total_techniques": snapshot.total_techniques,
                "validated_count": snapshot.validated_count,
                "partial_count": snapshot.partial_count,
                "not_covered_count": snapshot.not_covered_count,
                "organization_score": snapshot.organization_score,
            }
        )
    return timeline
# Define function get_operators_analytics
def get_operators_analytics(db: Session) -> list[dict]:
"""Per-operator metrics — for workload management dashboards."""
# Assign results = (
results = (
db.query(
User.username,
User.role,
func.count(Test.id).label("test_count"),
)
# Chain .outerjoin() call
.outerjoin(Test, Test.created_by == User.id)
# Chain .group_by() call
.group_by(User.id, User.username, User.role)
# Chain .all() call
.all()
)
# Return [
return [
{"username": r[0], "role": r[1], "test_count": r[2]}
for r in results
+129 -7
View File
@@ -22,22 +22,39 @@ Running the import twice does **not** create duplicates. Existing
templates are identified by their ``atomic_test_id`` and simply skipped.
"""
# Import io
import io
# Import logging
import logging
import os
# Import shutil
import shutil
# Import tempfile
import tempfile
# Import zipfile
import zipfile
# Import Path from pathlib
from pathlib import Path
# Import requests
import requests as _requests
# Import yaml
import yaml
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
from app.models.technique import Technique
from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -45,7 +62,9 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# URL of the ZIP archive of the atomic-red-team repository's master branch.
# The two adjacent string literals are concatenated into a single URL.
ATOMIC_RT_ZIP_URL = (
    "https://github.com/redcanaryco/atomic-red-team"
    "/archive/refs/heads/master.zip"
)
@@ -55,6 +74,11 @@ _DOWNLOAD_TIMEOUT = 300
# Top-level directory name inside the ZIP
_ZIP_ROOT_PREFIX = "atomic-red-team-master"
# Safety limits for ZIP extraction — prevent zip-bomb DoS
# Upper bound on the summed uncompressed size of all archive members.
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB
# Upper bound on the number of entries an archive may contain.
_MAX_ENTRIES = 50_000
# ---------------------------------------------------------------------------
# Internal helpers
@@ -63,14 +87,21 @@ _ZIP_ROOT_PREFIX = "atomic-red-team-master"
def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes:
    """Download the Atomic Red Team ZIP and return its raw bytes.

    Args:
        url: Location of the archive; defaults to ``ATOMIC_RT_ZIP_URL``.

    Returns:
        The complete response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    logger.info("Downloading Atomic Red Team ZIP from %s", url)
    # The request is made with stream=True, so the connection is only handed
    # back once the body is consumed or the response is closed. The context
    # manager closes it on every path, including when raise_for_status()
    # raises before the body has been read.
    with _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) as resp:
        resp.raise_for_status()
        content = resp.content
    logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
    return content
# Define function _safe_extract_zip
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
"""Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.
@@ -78,51 +109,66 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
directory (path traversal / Zip Slip) or if the archive exceeds the
safety limits.
"""
# Maximum uncompressed size: 500 MB — prevents zip-bomb DoS
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024
# Maximum number of entries
_MAX_ENTRIES = 50_000
# Assign dest_path = Path(dest).resolve()
dest_path = Path(dest).resolve()
# Open context manager
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
# Assign entries = zf.infolist()
entries = zf.infolist()
# Check: len(entries) > _MAX_ENTRIES
if len(entries) > _MAX_ENTRIES:
# Raise ValueError
raise ValueError(
f"ZIP archive contains {len(entries)} entries "
f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
)
# Assign total_size = sum(info.file_size for info in entries)
total_size = sum(info.file_size for info in entries)
# Check: total_size > _MAX_UNCOMPRESSED_SIZE
if total_size > _MAX_UNCOMPRESSED_SIZE:
# Raise ValueError
raise ValueError(
f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
)
# Iterate over entries
for member in entries:
# Assign target = (dest_path / member.filename).resolve()
target = (dest_path / member.filename).resolve()
# Check: not target.is_relative_to(dest_path)
if not target.is_relative_to(dest_path):
# Raise ValueError
raise ValueError(
f"Zip Slip detected — member '{member.filename}' "
f"resolves outside target directory"
)
# Call zf.extractall()
zf.extractall(dest)
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
    """Unpack *zip_bytes* under *dest* and locate the ``atomics`` directory.

    Extraction is delegated to ``_safe_extract_zip``. The directory is then
    expected at ``<dest>/<_ZIP_ROOT_PREFIX>/atomics``.

    Returns:
        Path of the ``atomics`` directory.

    Raises:
        FileNotFoundError: If that directory does not exist after extraction.
    """
    _safe_extract_zip(zip_bytes, dest)

    atomics_dir = Path(dest).joinpath(_ZIP_ROOT_PREFIX, "atomics")
    if atomics_dir.is_dir():
        return atomics_dir

    raise FileNotFoundError(
        f"Expected atomics directory not found at {atomics_dir}"
    )
# Define function _parse_yaml_files
def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
"""Walk the atomics directory and parse all technique YAML files.
@@ -132,51 +178,84 @@ def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
technique_id, index, name, description, platforms,
executor_type, command, source_url
"""
# Assign results = []
results: list[dict] = []
# Assign yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
# Log info: "Found %d YAML files to parse", len(yaml_files
logger.info("Found %d YAML files to parse", len(yaml_files))
# Iterate over yaml_files
for yaml_path in yaml_files:
# Assign technique_id = yaml_path.stem # e.g. "T1059.001"
technique_id = yaml_path.stem # e.g. "T1059.001"
# Attempt the following; catch errors below
try:
# Open context manager
with open(yaml_path, "r", encoding="utf-8") as fh:
# Assign data = yaml.safe_load(fh)
data = yaml.safe_load(fh)
# Handle Exception
except Exception as exc:
# Log warning: "Failed to parse %s: %s", yaml_path, exc
logger.warning("Failed to parse %s: %s", yaml_path, exc)
# Skip to the next loop iteration
continue
# Check: not data or "atomic_tests" not in data
if not data or "atomic_tests" not in data:
# Skip to the next loop iteration
continue
# Iterate over enumerate(data["atomic_tests"])
for idx, test in enumerate(data["atomic_tests"]):
# Assign name = test.get("name", "").strip()
name = test.get("name", "").strip()
# Assign description = test.get("description", "").strip()
description = test.get("description", "").strip()
# Assign platforms = test.get("supported_platforms", [])
platforms = test.get("supported_platforms", [])
# Assign executor = test.get("executor", {})
executor = test.get("executor", {})
# Assign executor_type = executor.get("name", "") if isinstance(executor, dict) else ""
executor_type = executor.get("name", "") if isinstance(executor, dict) else ""
# Assign command = executor.get("command", "") if isinstance(executor, dict) else ""
command = executor.get("command", "") if isinstance(executor, dict) else ""
# Build an atomic_test_id in the format "T1059.001-0"
atomic_test_id = f"{technique_id}-{idx}"
# Assign source_url = (
source_url = (
f"https://github.com/redcanaryco/atomic-red-team/blob/master"
f"/atomics/{technique_id}/{technique_id}.yaml"
)
# Call results.append()
results.append({
# Literal argument value
"technique_id": technique_id,
# Literal argument value
"index": idx,
# Literal argument value
"atomic_test_id": atomic_test_id,
# Literal argument value
"name": name,
# Literal argument value
"description": description,
# Literal argument value
"platforms": ", ".join(platforms) if isinstance(platforms, list) else str(platforms),
# Literal argument value
"executor_type": executor_type,
# Literal argument value
"command": command[:4000] if command else None, # cap at 4k chars
# Literal argument value
"source_url": source_url,
})
# Log info: "Parsed %d atomic tests total", len(results
logger.info("Parsed %d atomic tests total", len(results))
# Return results
return results
@@ -193,52 +272,80 @@ def import_atomic_red_team(db: Session) -> dict:
db : Session
Active SQLAlchemy database session.
Returns
Returns:
-------
dict
Summary with keys ``created``, ``skipped_existing``,
``yaml_files_parsed``, ``total_tests_parsed``.
"""
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
# Attempt the following; catch errors below
try:
# Assign zip_bytes = _download_zip()
zip_bytes = _download_zip()
# Assign atomics_dir = _extract_zip(zip_bytes, tmp_dir)
atomics_dir = _extract_zip(zip_bytes, tmp_dir)
# Assign parsed_tests = _parse_yaml_files(atomics_dir)
parsed_tests = _parse_yaml_files(atomics_dir)
# Always execute this cleanup block
finally:
# Always clean up
shutil.rmtree(tmp_dir, ignore_errors=True)
# Log info: "Cleaned up temp directory %s", tmp_dir
logger.info("Cleaned up temp directory %s", tmp_dir)
# Pre-load existing atomic_test_ids for dedup
existing_ids: set[str] = {
row[0]
for row in db.query(TestTemplate.atomic_test_id)
# Chain .filter() call
.filter(TestTemplate.atomic_test_id.isnot(None))
# Chain .all() call
.all()
}
# Assign created = 0
created = 0
# Assign skipped = 0
skipped = 0
new_technique_ids: set[str] = set()
# Iterate over parsed_tests
for item in parsed_tests:
# Check: item["atomic_test_id"] in existing_ids
if item["atomic_test_id"] in existing_ids:
# Assign skipped = 1
skipped += 1
# Skip to the next loop iteration
continue
# Assign template = TestTemplate(
template = TestTemplate(
# Keyword argument: mitre_technique_id
mitre_technique_id=item["technique_id"],
# Keyword argument: name
name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}",
# Keyword argument: description
description=item["description"][:2000] if item["description"] else None,
# Keyword argument: source
source="atomic_red_team",
# Keyword argument: source_url
source_url=item["source_url"],
# Keyword argument: attack_procedure
attack_procedure=item["command"],
# Keyword argument: platform
platform=item["platforms"],
# Keyword argument: tool_suggested
tool_suggested=item["executor_type"] if item["executor_type"] else None,
# Keyword argument: atomic_test_id
atomic_test_id=item["atomic_test_id"],
# Keyword argument: is_active
is_active=True,
)
# Stage new record(s) for database insertion
db.add(template)
# Call existing_ids.add()
existing_ids.add(item["atomic_test_id"])
new_technique_ids.add(item["technique_id"])
created += 1
@@ -253,15 +360,23 @@ def import_atomic_red_team(db: Session) -> dict:
# Count distinct YAML files by technique_id
yaml_files_count = len({t["technique_id"] for t in parsed_tests})
# Assign summary = {
summary = {
# Literal argument value
"created": created,
# Literal argument value
"skipped_existing": skipped,
# Literal argument value
"yaml_files_parsed": yaml_files_count,
# Literal argument value
"total_tests_parsed": len(parsed_tests),
}
# Log info:
logger.info(
# Literal argument value
"Atomic Red Team import complete — created=%d, skipped=%d, "
# Literal argument value
"yaml_files=%d, total_tests=%d",
created, skipped, yaml_files_count, len(parsed_tests),
)
@@ -269,12 +384,19 @@ def import_atomic_red_team(db: Session) -> dict:
# Audit log (system action)
log_action(
db,
# Keyword argument: user_id
user_id=None,
# Keyword argument: action
action="import_atomic_red_team",
# Keyword argument: entity_type
entity_type="test_template",
# Keyword argument: entity_id
entity_id=None,
# Keyword argument: details
details=summary,
)
# Commit all pending changes to the database
db.commit()
# Return summary
return summary
@@ -4,24 +4,37 @@ Provides paginated logs and distinct action/entity-type lists.
No FastAPI imports.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import datetime from datetime
from datetime import datetime
# Import Session, joinedload from sqlalchemy.orm
from sqlalchemy.orm import Session, joinedload
# Import AuditLog from app.models.audit
from app.models.audit import AuditLog
# Define function list_logs
def list_logs(
# Entry: db
db: Session,
*,
# Entry: user_id
user_id: str | None = None,
# Entry: action
action: str | None = None,
# Entry: entity_type
entity_type: str | None = None,
# Entry: start_date
start_date: datetime | None = None,
# Entry: end_date
end_date: datetime | None = None,
# Entry: offset
offset: int = 0,
# Entry: limit
limit: int = 50,
) -> dict:
"""Return paginated audit logs with optional filters.
@@ -30,64 +43,104 @@ def list_logs(
Each item is a dict with: id, user_id, username, action, entity_type,
entity_id, timestamp, details.
"""
# Assign query = db.query(AuditLog).options(joinedload(AuditLog.user))
query = db.query(AuditLog).options(joinedload(AuditLog.user))
# Check: user_id
if user_id:
# Assign query = query.filter(AuditLog.user_id == user_id)
query = query.filter(AuditLog.user_id == user_id)
# Check: action
if action:
# Assign query = query.filter(AuditLog.action == action)
query = query.filter(AuditLog.action == action)
# Check: entity_type
if entity_type:
# Assign query = query.filter(AuditLog.entity_type == entity_type)
query = query.filter(AuditLog.entity_type == entity_type)
# Check: start_date
if start_date:
# Assign query = query.filter(AuditLog.timestamp >= start_date)
query = query.filter(AuditLog.timestamp >= start_date)
# Check: end_date
if end_date:
# Assign query = query.filter(AuditLog.timestamp <= end_date)
query = query.filter(AuditLog.timestamp <= end_date)
# Assign total = query.count()
total = query.count()
# Assign logs = (
logs = (
query
# Chain .order_by() call
.order_by(AuditLog.timestamp.desc())
# Chain .offset() call
.offset(offset)
# Chain .limit() call
.limit(limit)
# Chain .all() call
.all()
)
# Assign items = [
items = [
{
# Literal argument value
"id": log.id,
# Literal argument value
"user_id": log.user_id,
# Literal argument value
"username": log.user.username if log.user else None,
# Literal argument value
"action": log.action,
# Literal argument value
"entity_type": log.entity_type,
# Literal argument value
"entity_id": log.entity_id,
# Literal argument value
"timestamp": log.timestamp,
# Literal argument value
"details": log.details,
}
for log in logs
]
# Return {"items": items, "total": total, "offset": offset, "limit": limit}
return {"items": items, "total": total, "offset": offset, "limit": limit}
def list_distinct_actions(db: Session) -> list[str]:
    """Return the distinct ``action`` values found in the audit log, sorted by action."""
    distinct_actions = (
        db.query(AuditLog.action).distinct().order_by(AuditLog.action)
    )
    # Each result row holds a single column; unwrap it.
    return [row[0] for row in distinct_actions.all()]
def list_distinct_entity_types(db: Session) -> list[str]:
    """Return the distinct non-NULL ``entity_type`` values in the audit log, sorted."""
    column = AuditLog.entity_type
    distinct_types = (
        db.query(column).filter(column.isnot(None)).distinct().order_by(column)
    )
    # Each result row holds a single column; unwrap it.
    return [row[0] for row in distinct_types.all()]
+51 -1
View File
@@ -1,66 +1,116 @@
"""Audit logging with request context and integrity hashing."""
# Enable future language features for compatibility
from __future__ import annotations
# Import hashlib
import hashlib
# Import datetime, timezone from datetime
from datetime import datetime, timezone
# Import UUID from uuid
from uuid import UUID
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import request_ip, request_user_agent from app.middleware.request_context
from app.middleware.request_context import request_ip, request_user_agent
# Import AuditLog from app.models.audit
from app.models.audit import AuditLog
def _integrity_payload(entry: AuditLog) -> str:
    """Build the colon-separated string that is hashed for an audit entry.

    The fields are, in order: user id, action, entity type, entity id and the
    ISO-8601 timestamp. A falsy user id, entity type or entity id contributes
    an empty field. When the entry has no timestamp, the current UTC time is
    used instead.
    """
    stamp = entry.timestamp
    if stamp is None:
        stamp = datetime.now(timezone.utc)

    actor = "" if not entry.user_id else str(entry.user_id)
    kind = entry.entity_type if entry.entity_type else ""
    reference = entry.entity_id if entry.entity_id else ""

    parts = (actor, entry.action, kind, reference, stamp.isoformat())
    # Each part goes through f-string formatting, exactly as a single
    # f-string with ":" separators would render it.
    return ":".join(f"{part}" for part in parts)
def compute_integrity_hash(entry: AuditLog) -> str:
    """Return the SHA-256 hex digest of the entry's integrity payload."""
    digest = hashlib.sha256()
    digest.update(_integrity_payload(entry).encode())
    return digest.hexdigest()
def verify_audit_integrity(entry: AuditLog) -> bool:
    """Tell whether the stored hash still matches the entry's current fields.

    An entry without a stored hash is reported as not verified.
    """
    stored = entry.integrity_hash
    if stored:
        return stored == compute_integrity_hash(entry)
    return False
def log_action(
    db: Session,
    user_id: UUID | None,
    action: str,
    entity_type: str | None = None,
    entity_id: str | None = None,
    details: dict | None = None,
    *,
    ip_address: str | None = None,
    user_agent: str | None = None,
    session_id: str | None = None,
) -> AuditLog:
    """Record an audit event. Does not commit — the caller owns the transaction.

    The ``user_id`` parameter is declared exactly once; a second, untyped
    declaration of the same name would make the signature a syntax error.

    Args:
        db: Active SQLAlchemy session.
        user_id: Acting user, or ``None`` when no user is attached to the event.
        action: Name of the action being recorded.
        entity_type: Optional type of the affected entity.
        entity_id: Optional identifier of the affected entity; stored as a
            string, and any falsy value is stored as ``None``.
        details: Optional extra data stored with the entry.
        ip_address: Explicit client address; when ``None`` the value is read
            from ``request_ip``.
        user_agent: Explicit user agent; when ``None`` the value is read from
            ``request_user_agent``.
        session_id: Optional session identifier.

    Returns:
        The new ``AuditLog`` entry, added to the session and flushed, with its
        ``integrity_hash`` set.
    """
    ip = ip_address if ip_address is not None else request_ip.get("")
    ua = user_agent if user_agent is not None else request_user_agent.get("")

    entry = AuditLog(
        user_id=user_id,
        action=action,
        entity_type=entity_type,
        entity_id=str(entity_id) if entity_id else None,
        details=details,
        # Empty strings are stored as NULL.
        ip_address=ip or None,
        user_agent=ua or None,
        session_id=session_id,
        # Set explicitly so the integrity hash is computed from a known timestamp.
        timestamp=datetime.now(timezone.utc),
    )
    db.add(entry)
    db.flush()
    # The hash is computed after the flush and is written when the caller
    # next flushes or commits.
    entry.integrity_hash = compute_integrity_hash(entry)
    return entry
+25
View File
@@ -1,15 +1,24 @@
"""Authentication service — credential validation and password management."""
# Enable future language features for compatibility
from __future__ import annotations
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import hash_password, verify_password from app.auth
from app.auth import hash_password, verify_password
# Import BusinessRuleViolation, PermissionViolation from app.domain.errors
from app.domain.errors import BusinessRuleViolation, PermissionViolation
# Import User from app.models.user
from app.models.user import User
# Placeholder hash passed to verify_password() when the username is unknown,
# so a password verification still runs for non-existent accounts.
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
# Define function authenticate_user
def authenticate_user(db: Session, *, username: str, password: str) -> User:
"""Validate credentials and return the User.
@@ -17,33 +26,49 @@ def authenticate_user(db: Session, *, username: str, password: str) -> User:
Raises PermissionViolation for disabled account.
Uses constant-time comparison to prevent timing attacks.
"""
# Assign user = db.query(User).filter(User.username == username).first()
user = db.query(User).filter(User.username == username).first()
# Assign hashed = user.hashed_password if user else _DUMMY_HASH
hashed = user.hashed_password if user else _DUMMY_HASH
# Assign password_valid = verify_password(password, hashed)
password_valid = verify_password(password, hashed)
# Check: user is None or not password_valid
if user is None or not password_valid:
# Raise BusinessRuleViolation
raise BusinessRuleViolation("Incorrect username or password")
# Check: not user.is_active
if not user.is_active:
# Raise PermissionViolation
raise PermissionViolation("Account is disabled. Contact an administrator.")
# Return user
return user
# Define function change_password
def change_password(
    db: Session,
    user: User,
    *,
    current_password: str,
    new_password: str,
) -> None:
    """Replace *user*'s password hash after validating both passwords.

    Does NOT commit; the caller owns the transaction. On success the
    ``must_change_password`` flag is cleared.

    Raises BusinessRuleViolation if *current_password* does not verify
    against the stored hash, or if *new_password* verifies against it
    (i.e. the new password equals the current one).
    """
    stored_hash = user.hashed_password

    current_ok = verify_password(current_password, stored_hash)
    if not current_ok:
        raise BusinessRuleViolation("Current password is incorrect")

    # Reject a "change" that would leave the password as it was.
    unchanged = verify_password(new_password, stored_hash)
    if unchanged:
        raise BusinessRuleViolation(
            "New password must be different from the current password"
        )

    user.hashed_password = hash_password(new_password)
    user.must_change_password = False
+158 -1
View File
@@ -21,23 +21,42 @@ templates are identified by ``source = "caldera"`` + ``atomic_test_id``
(the CALDERA ability ``id``).
"""
# Import io
import io
# Import logging
import logging
# Import shutil
import shutil
# Import tempfile
import tempfile
# Import zipfile
import zipfile
# Import datetime from datetime
from datetime import datetime
# Import Path from pathlib
from pathlib import Path
# Import requests
import requests as _requests
# Import yaml
import yaml
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
from app.models.test_template import TestTemplate
# Import DataSource from app.models.data_source
from app.models.data_source import DataSource
from app.models.technique import Technique
from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -45,11 +64,15 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
CALDERA_ZIP_URL = (
# Literal argument value
"https://github.com/mitre/stockpile"
# Literal argument value
"/archive/refs/heads/master.zip"
)
# Assign _DOWNLOAD_TIMEOUT = 300
_DOWNLOAD_TIMEOUT = 300
# Assign _ZIP_ROOT_PREFIX = "stockpile-master"
_ZIP_ROOT_PREFIX = "stockpile-master"
@@ -60,26 +83,40 @@ _ZIP_ROOT_PREFIX = "stockpile-master"
def _download_zip(url: str = CALDERA_ZIP_URL) -> bytes:
    """Download the CALDERA ZIP and return raw bytes.

    Args:
        url: Location of the archive; defaults to ``CALDERA_ZIP_URL``.

    Returns:
        The complete response body.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (via ``raise_for_status``). Other ``requests`` exceptions
            propagate unchanged.
    """
    logger.info("Downloading CALDERA ZIP from %s", url)
    # With stream=True the connection stays checked out until the body is
    # consumed or the response is closed. The ``with`` block closes it on
    # every path, including when raise_for_status() raises before the body
    # has been read.
    with _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) as resp:
        resp.raise_for_status()
        content = resp.content
    logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
    return content
# Define function _extract_zip
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
    """Unpack an in-memory ZIP archive into *dest* and locate the abilities.

    Returns the path ``<dest>/<_ZIP_ROOT_PREFIX>/data/abilities``.

    Raises FileNotFoundError if that directory does not exist after
    extraction.
    """
    buffer = io.BytesIO(zip_bytes)
    with zipfile.ZipFile(buffer) as archive:
        archive.extractall(dest)

    abilities_dir = Path(dest).joinpath(_ZIP_ROOT_PREFIX, "data", "abilities")
    if abilities_dir.is_dir():
        return abilities_dir

    raise FileNotFoundError(
        f"Expected abilities directory not found at {abilities_dir}"
    )
# Define function _extract_commands
def _extract_commands(platforms_dict: dict) -> str:
    """Render every executor command of an ability as one text block.

    *platforms_dict* maps a platform name to a mapping of executor name to
    either a dict (whose ``"command"`` entry is used) or a plain string
    (used as the command itself). Anything that does not fit this shape is
    ignored.

    Returns one ``[platform/executor]`` header followed by the command per
    entry, entries separated by a blank line. Returns ``""`` when
    *platforms_dict* is not a dict or yields no commands.
    """
    if not isinstance(platforms_dict, dict):
        return ""

    sections = []
    for platform, executor_map in platforms_dict.items():
        if not isinstance(executor_map, dict):
            continue
        for executor, spec in executor_map.items():
            if isinstance(spec, str):
                # A bare string is taken verbatim, even when it is empty.
                sections.append(f"[{platform}/{executor}]\n{spec}")
            elif isinstance(spec, dict):
                command = spec.get("command", "")
                if command:
                    sections.append(f"[{platform}/{executor}]\n{command}")
    return "\n\n".join(sections)
# Define function _extract_platforms
def _extract_platforms(platforms_dict: dict) -> str:
    """Extract platform names from CALDERA platforms dict.

    Only the keys of *platforms_dict* are inspected. Each key is lowercased
    and stripped; ``windows``, ``linux``, ``darwin`` and ``macos`` are
    recognised, with ``darwin`` reported as ``macos``. A key may list
    several platforms separated by commas (e.g. ``"darwin,linux"``); each
    listed name is handled individually.
    NOTE(review): comma-separated keys are believed to be accepted by
    CALDERA's own ability loader — confirm against the CALDERA docs.

    Returns:
        The recognised platforms in first-seen order, without duplicates,
        joined by ``", "``. ``""`` when *platforms_dict* is not a dict or
        contains no recognised platform.
    """
    if not isinstance(platforms_dict, dict):
        return ""

    canonical_names = {
        "windows": "windows",
        "linux": "linux",
        "darwin": "macos",
        "macos": "macos",
    }

    platform_names = []
    for key in platforms_dict:
        # Split so that a combined key contributes each of its platforms
        # instead of being discarded as an unknown name.
        for part in str(key).split(","):
            canonical = canonical_names.get(part.lower().strip())
            if canonical is not None and canonical not in platform_names:
                platform_names.append(canonical)
    return ", ".join(platform_names)
# Define function _parse_abilities
def _parse_abilities(abilities_dir: Path) -> list[dict]:
    """Walk abilities directories and parse all YAML files.

    Every ``*.yml`` file below *abilities_dir* is loaded with
    ``yaml.safe_load_all``. Files that cannot be read or parsed are logged
    at debug level and skipped. An ability is kept only when it has a
    non-empty ``id`` and a ``technique.attack_id`` that starts with ``T``
    after upper-casing.

    Returns:
        A flat list of dicts, each representing one ability, with the keys
        ``mitre_technique_id``, ``name``, ``description``, ``source``,
        ``platform``, ``tool_suggested``, ``attack_procedure``,
        ``atomic_test_id`` and ``source_url``.
    """

    def _text(record: dict, key: str) -> str:
        """Return ``record[key]`` as a stripped string; missing/None -> ''."""
        value = record.get(key)
        # An explicit ``key: null`` or a numeric scalar in the YAML must not
        # crash the import with AttributeError on ``.strip()``.
        return "" if value is None else str(value).strip()

    results: list[dict] = []

    yaml_files = sorted(abilities_dir.rglob("*.yml"))
    logger.info("Found %d ability YAML files", len(yaml_files))

    for yaml_path in yaml_files:
        try:
            with open(yaml_path, "r", encoding="utf-8") as fh:
                data_list = list(yaml.safe_load_all(fh))
        except Exception as exc:
            # Best-effort import: one broken file must not abort the run.
            logger.debug("Failed to parse %s: %s", yaml_path, exc)
            continue

        # Stockpile YAML files may contain YAML lists of abilities
        # (e.g. [- id: ..., - id: ...]) or single-document dicts.
        # Flatten everything into individual ability dicts.
        abilities: list[dict] = []
        for data in data_list:
            if isinstance(data, dict):
                abilities.append(data)
            elif isinstance(data, list):
                abilities.extend(d for d in data if isinstance(d, dict))

        for data in abilities:
            ability_id = data.get("id", "")
            if not ability_id:
                continue

            name = _text(data, "name")
            description = _text(data, "description")
            tactic = _text(data, "tactic")

            # Extract technique info
            technique = data.get("technique", {})
            attack_id = technique.get("attack_id", "") if isinstance(technique, dict) else ""
            if not attack_id:
                continue

            # Normalise technique ID
            attack_id = str(attack_id).strip().upper()
            if not attack_id.startswith("T"):
                continue

            # Extract platforms and commands
            platforms_dict = data.get("platforms", {})
            commands = _extract_commands(platforms_dict)
            platform_str = _extract_platforms(platforms_dict)

            # Determine executor type
            executors: set[str] = set()
            if isinstance(platforms_dict, dict):
                for os_executors in platforms_dict.values():
                    if isinstance(os_executors, dict):
                        # str() keeps sorted()/join() safe for non-string keys.
                        executors.update(str(key) for key in os_executors)
            executor_str = ", ".join(sorted(executors)) if executors else None

            results.append({
                "mitre_technique_id": attack_id,
                "name": f"CALDERA: {name}"[:500] if name else f"CALDERA ability {ability_id}"[:500],
                "description": f"{description}\n\nTactic: {tactic}".strip()[:2000] if description else None,
                "source": "caldera",
                "platform": platform_str,
                "tool_suggested": executor_str,
                "attack_procedure": commands[:4000] if commands else None,
                "atomic_test_id": f"caldera:{ability_id}",
                "source_url": f"https://github.com/mitre/stockpile/tree/master/data/abilities/{tactic}",
            })

    logger.info("Parsed %d CALDERA abilities total", len(results))
    return results
@@ -218,46 +331,76 @@ def sync(db: Session) -> dict:
Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``.
"""
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_caldera_")
tmp_dir = tempfile.mkdtemp(prefix="aegis_caldera_")
# Attempt the following; catch errors below
try:
# Assign zip_bytes = _download_zip()
zip_bytes = _download_zip()
# Assign abilities_dir = _extract_zip(zip_bytes, tmp_dir)
abilities_dir = _extract_zip(zip_bytes, tmp_dir)
# Assign parsed = _parse_abilities(abilities_dir)
parsed = _parse_abilities(abilities_dir)
# Always execute this cleanup block
finally:
# Call shutil.rmtree()
shutil.rmtree(tmp_dir, ignore_errors=True)
# Log info: "Cleaned up temp directory %s", tmp_dir
logger.info("Cleaned up temp directory %s", tmp_dir)
# Pre-load existing for dedup
existing_ids: set[str] = {
row[0]
for row in db.query(TestTemplate.atomic_test_id)
# Chain .filter() call
.filter(TestTemplate.source == "caldera")
# Chain .filter() call
.filter(TestTemplate.atomic_test_id.isnot(None))
# Chain .all() call
.all()
}
# Assign created = 0
created = 0
# Assign skipped = 0
skipped = 0
new_technique_ids: set[str] = set()
# Iterate over parsed
for item in parsed:
# Check: item["atomic_test_id"] in existing_ids
if item["atomic_test_id"] in existing_ids:
# Increment the skipped counter
skipped += 1
# Skip to the next loop iteration
continue
# Assign template = TestTemplate(
template = TestTemplate(
# Keyword argument: mitre_technique_id
mitre_technique_id=item["mitre_technique_id"],
# Keyword argument: name
name=item["name"],
# Keyword argument: description
description=item["description"],
# Keyword argument: source
source=item["source"],
# Keyword argument: source_url
source_url=item["source_url"],
# Keyword argument: attack_procedure
attack_procedure=item["attack_procedure"],
# Keyword argument: platform
platform=item["platform"],
# Keyword argument: tool_suggested
tool_suggested=item["tool_suggested"],
# Keyword argument: atomic_test_id
atomic_test_id=item["atomic_test_id"],
# Keyword argument: is_active
is_active=True,
)
# Stage new record(s) for database insertion
db.add(template)
# Call existing_ids.add()
existing_ids.add(item["atomic_test_id"])
new_technique_ids.add(item["mitre_technique_id"])
created += 1
@@ -269,22 +412,36 @@ def sync(db: Session) -> dict:
db.commit()
# Assign summary = {
summary = {
# Literal argument value
"created": created,
# Literal argument value
"skipped_existing": skipped,
# Literal argument value
"total_parsed": len(parsed),
}
# Update DataSource record
ds = db.query(DataSource).filter(DataSource.name == "caldera").first()
# Check: ds
if ds:
# Assign ds.last_sync_at = datetime.utcnow()
ds.last_sync_at = datetime.utcnow()
# Assign ds.last_sync_status = "success"
ds.last_sync_status = "success"
# Assign ds.last_sync_stats = summary
ds.last_sync_stats = summary
# Commit all pending changes to the database
db.commit()
# Log info: "CALDERA import complete — %s", summary
logger.info("CALDERA import complete — %s", summary)
# Call log_action()
log_action(db, user_id=None, action="import_caldera",
# Keyword argument: entity_type
entity_type="test_template", entity_id=None, details=summary)
# Commit all pending changes to the database
db.commit()
# Return summary
return summary
+296 -8
View File
@@ -4,112 +4,191 @@ Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
# Import uuid
import uuid
# Import datetime from datetime
from datetime import datetime
# Import Optional from typing
from typing import Optional
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import from app.domain.errors
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
# Import Campaign, CampaignTest from app.models.campaign
from app.models.campaign import Campaign, CampaignTest
from app.models.test import Test
# Import Technique from app.models.technique
from app.models.technique import Technique
from app.utils import escape_like
from app.services.campaign_service import (
get_campaign_progress,
validate_no_circular_dependency,
TACTIC_TO_PHASE,
)
# Import Test from app.models.test
from app.models.test import Test
# Import calculate_next_run from app.services.campaign_scheduler_service
from app.services.campaign_scheduler_service import calculate_next_run
from app.services.status_service import recalculate_technique_status
# Import from app.services.campaign_service
from app.services.campaign_service import (
TACTIC_TO_PHASE,
get_campaign_progress,
validate_no_circular_dependency,
)
# Import escape_like from app.utils
from app.utils import escape_like
# ── Serialization helpers ────────────────────────────────────────────────
def serialize_campaign(db: Session, campaign: Campaign) -> dict:
    """Build the full dict representation of a campaign.

    The result holds the campaign's own fields, one entry per linked
    CampaignTest (ordered by ``order_index``) under ``"tests"``, and the
    dict returned by ``get_campaign_progress`` under ``"progress"``.
    Datetime fields are emitted via ``isoformat()`` and identifier fields
    via ``str()``; unset values become ``None``.
    """

    def _iso(moment):
        # Unset (falsy) datetimes are reported as None.
        return moment.isoformat() if moment else None

    def _text(value):
        # Unset (falsy) identifiers are reported as None.
        return str(value) if value else None

    progress = get_campaign_progress(db, campaign.id)

    links = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign.id)
        .order_by(CampaignTest.order_index)
        .all()
    )

    tests = []
    for link in links:
        test = link.test
        # The technique is looked up only when the link still has a test.
        technique = None
        if test:
            technique = (
                db.query(Technique)
                .filter(Technique.id == test.technique_id)
                .first()
            )
        tests.append({
            "id": str(link.id),
            "test_id": str(link.test_id),
            "order_index": link.order_index,
            "depends_on": _text(link.depends_on),
            "phase": link.phase,
            "test_name": test.name if test else None,
            "test_state": test.state.value if test and test.state else None,
            "test_result": test.result.value if test and test.result else None,
            "technique_mitre_id": technique.mitre_id if technique else None,
            "technique_name": technique.name if technique else None,
            "platform": test.platform if test else None,
        })

    actor = campaign.threat_actor

    return {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": _text(campaign.threat_actor_id),
        "threat_actor_name": actor.name if actor else None,
        "created_by": _text(campaign.created_by),
        "start_date": _iso(campaign.start_date),
        "scheduled_at": _iso(campaign.scheduled_at),
        "completed_at": _iso(campaign.completed_at),
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": _iso(campaign.created_at),
        "is_recurring": campaign.is_recurring or False,
        "recurrence_pattern": campaign.recurrence_pattern,
        "next_run_at": _iso(campaign.next_run_at),
        "last_run_at": _iso(campaign.last_run_at),
        "parent_campaign_id": _text(campaign.parent_campaign_id),
        "tests": tests,
        "progress": progress,
    }
# Define function serialize_campaign_summary
def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
    """Build the compact dict representation used for campaign lists.

    Unlike the full serialization, no per-test entries are included; only
    ``test_count`` and ``completion_pct`` are taken from the progress dict
    returned by ``get_campaign_progress``.
    """

    def _iso(moment):
        # Unset (falsy) datetimes are reported as None.
        return moment.isoformat() if moment else None

    progress = get_campaign_progress(db, campaign.id)
    actor = campaign.threat_actor

    actor_id = campaign.threat_actor_id
    summary = {
        "id": str(campaign.id),
        "name": campaign.name,
        "description": campaign.description,
        "type": campaign.type,
        "status": campaign.status,
        "threat_actor_id": str(actor_id) if actor_id else None,
        "threat_actor_name": actor.name if actor else None,
        "start_date": _iso(campaign.start_date),
        "target_platform": campaign.target_platform,
        "tags": campaign.tags or [],
        "created_at": _iso(campaign.created_at),
        "is_recurring": campaign.is_recurring or False,
        "recurrence_pattern": campaign.recurrence_pattern,
        "next_run_at": _iso(campaign.next_run_at),
        "last_run_at": _iso(campaign.last_run_at),
        "test_count": progress["total"],
        "completion_pct": progress["completion_pct"],
    }
    return summary
@@ -118,122 +197,198 @@ def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
def list_campaigns(
    db: Session,
    *,
    type: Optional[str] = None,
    status: Optional[str] = None,
    threat_actor_id: Optional[str] = None,
    search: Optional[str] = None,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """Return one page of campaigns, newest first, with optional filters.

    ``type``, ``status`` and ``threat_actor_id`` are exact-match filters;
    ``search`` is a case-insensitive substring match on name or
    description. Filters whose value is falsy are not applied.

    Returns a dict with ``total`` (count before pagination), the echoed
    ``offset`` and ``limit``, and ``items`` (campaign summaries).
    """
    query = db.query(Campaign)

    exact_filters = (
        (Campaign.type, type),
        (Campaign.status, status),
        (Campaign.threat_actor_id, threat_actor_id),
    )
    for column, wanted in exact_filters:
        if wanted:
            query = query.filter(column == wanted)

    if search:
        # escape_like is applied to the user text before it is wrapped in
        # wildcards for the ILIKE match.
        pattern = f"%{escape_like(search)}%"
        name_hit = Campaign.name.ilike(pattern)
        description_hit = Campaign.description.ilike(pattern)
        query = query.filter(name_hit | description_hit)

    total = query.count()
    page = (
        query.order_by(Campaign.created_at.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )

    items = [serialize_campaign_summary(db, row) for row in page]
    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": items,
    }
# Define function create_campaign
def create_campaign(
    db: Session,
    *,
    creator_id: uuid.UUID,
    name: str,
    description: Optional[str] = None,
    type: str = "custom",
    threat_actor_id: Optional[str] = None,
    target_platform: Optional[str] = None,
    tags: Optional[list[str]] = None,
    scheduled_at: Optional[str] = None,
    start_date: Optional[str] = None,
) -> dict:
    """Create a campaign and return its full serialization.

    ``threat_actor_id`` is parsed with ``uuid.UUID`` and ``scheduled_at`` /
    ``start_date`` with ``datetime.fromisoformat`` when provided; malformed
    values raise ``ValueError`` from those parsers.

    The new row is added and flushed but not committed; caller commits.
    """
    actor_uuid = uuid.UUID(threat_actor_id) if threat_actor_id else None
    scheduled = datetime.fromisoformat(scheduled_at) if scheduled_at else None
    started = datetime.fromisoformat(start_date) if start_date else None

    campaign = Campaign(
        name=name,
        description=description,
        type=type,
        threat_actor_id=actor_uuid,
        target_platform=target_platform,
        tags=tags or [],
        created_by=creator_id,
        scheduled_at=scheduled,
        start_date=started,
    )
    db.add(campaign)
    # Flush before serializing so the pending row is written to the database.
    db.flush()
    return serialize_campaign(db, campaign)
# Define function get_campaign_detail
def get_campaign_detail(db: Session, campaign_id: str) -> dict:
    """Return the full serialization (tests and progress) of one campaign.

    Raises EntityNotFoundError if no campaign has the given id.
    """
    match = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if match:
        return serialize_campaign(db, match)
    raise EntityNotFoundError("Campaign", campaign_id)
# Define function update_campaign
def update_campaign(
    db: Session,
    campaign_id: str,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: object,
) -> dict:
    """Apply *fields* to a campaign and return its full serialization.

    Only campaigns in ``draft`` or ``active`` status may be updated, and
    only by their creator or a user whose role is ``admin``. Truthy
    ``scheduled_at`` / ``start_date`` values are parsed with
    ``datetime.fromisoformat``; every entry of *fields* is then assigned
    onto the campaign with ``setattr``.

    Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
    Does not commit; caller commits.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)

    editable = campaign.status in ("draft", "active")
    if not editable:
        raise BusinessRuleViolation("Can only update draft or active campaigns")

    is_creator = str(campaign.created_by) == str(updater_id)
    if not is_creator and updater_role != "admin":
        raise PermissionViolation("Only the creator or admin can update this campaign")

    # Date fields arrive as ISO strings; convert them in place so the
    # generic assignment loop below stores datetime objects.
    for date_field in ("scheduled_at", "start_date"):
        raw = fields.get(date_field)
        if raw:
            fields[date_field] = datetime.fromisoformat(raw)

    for attribute, new_value in fields.items():
        setattr(campaign, attribute, new_value)

    db.flush()
    return serialize_campaign(db, campaign)
# Define function add_test_to_campaign
def add_test_to_campaign(
    db: Session,
    campaign_id: str,
    *,
    test_id: str,
    order_index: Optional[int] = None,
    depends_on: Optional[str] = None,
    phase: Optional[str] = None,
) -> dict:
    """Link an existing test to a campaign and return the new link as a dict.

    When *order_index* is omitted the test is appended after the highest
    existing index (or placed at 0 for an empty campaign). When *phase* is
    omitted it is looked up in ``TACTIC_TO_PHASE`` from the tactic of the
    test's technique, if there is one.

    Raises EntityNotFoundError if the campaign or test does not exist.
    Raises BusinessRuleViolation for invalid state or circular dependency.
    Does not commit; caller commits.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)
    if campaign.status not in ("draft", "active"):
        raise BusinessRuleViolation("Can only add tests to draft or active campaigns")

    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", test_id)

    if order_index is None:
        # Append: one past the current highest index in this campaign.
        highest = (
            db.query(CampaignTest.order_index)
            .filter(CampaignTest.campaign_id == campaign_id)
            .order_by(CampaignTest.order_index.desc())
            .first()
        )
        position = (highest[0] + 1) if highest else 0
    else:
        position = order_index

    parent_link_id = uuid.UUID(depends_on) if depends_on else None
    # The id is generated up front because the dependency check needs it
    # before the row exists.
    new_link_id = uuid.uuid4()
    if parent_link_id:
        validate_no_circular_dependency(
            db, uuid.UUID(campaign_id), new_link_id, parent_link_id
        )

    if not phase and test.technique_id:
        technique = (
            db.query(Technique)
            .filter(Technique.id == test.technique_id)
            .first()
        )
        if technique and technique.tactic:
            phase = TACTIC_TO_PHASE.get(technique.tactic)

    link = CampaignTest(
        id=new_link_id,
        campaign_id=campaign_id,
        test_id=test_id,
        order_index=position,
        depends_on=parent_link_id,
        phase=phase,
    )
    db.add(link)
    db.flush()

    return {
        "id": str(link.id),
        "campaign_id": str(link.campaign_id),
        "test_id": str(link.test_id),
        "order_index": link.order_index,
        "depends_on": str(link.depends_on) if link.depends_on else None,
        "phase": link.phase,
    }
# Define function remove_test_from_campaign
def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
"""Remove a test from a campaign.
@@ -303,27 +499,41 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s
Raises BusinessRuleViolation for invalid state.
Does not commit; caller commits.
"""
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
# Check: not campaign
if not campaign:
# Raise EntityNotFoundError
raise EntityNotFoundError("Campaign", campaign_id)
# Check: campaign.status not in ("draft", "active")
if campaign.status not in ("draft", "active"):
# Raise BusinessRuleViolation
raise BusinessRuleViolation("Can only modify draft or active campaigns")
# Assign ct = (
ct = (
db.query(CampaignTest)
# Chain .filter() call
.filter(
CampaignTest.id == campaign_test_id,
CampaignTest.campaign_id == campaign_id,
)
# Chain .first() call
.first()
)
# Check: not ct
if not ct:
# Raise EntityNotFoundError
raise EntityNotFoundError("CampaignTest", campaign_test_id)
# Assign dep_id = uuid.UUID(campaign_test_id)
dep_id = uuid.UUID(campaign_test_id)
# Assign dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
# Iterate over dependents
for dep in dependents:
# Assign dep.depends_on = None
dep.depends_on = None
# Keep a reference to the underlying test before deleting the join record
@@ -334,6 +544,7 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s
technique_id = test_obj.technique_id
db.delete(ct)
# Flush changes to DB without committing the transaction
db.flush()
# Also delete the actual test record (it was created for this campaign)
@@ -349,72 +560,110 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s
db.flush()
# Define function activate_campaign
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
    """Move a draft campaign to ``active`` and return it.

    Raises EntityNotFoundError if the campaign does not exist.
    Raises BusinessRuleViolation if it is not a draft or has no tests.
    Does not commit; caller commits.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)

    is_draft = campaign.status == "draft"
    if not is_draft:
        raise BusinessRuleViolation("Only draft campaigns can be activated")

    linked_tests = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign_id)
        .count()
    )
    if linked_tests == 0:
        raise BusinessRuleViolation("Campaign must have at least one test to activate")

    campaign.status = "active"
    db.flush()
    return campaign
def complete_campaign(db: Session, campaign_id: str) -> Campaign:
    """Close an active campaign by marking it ``completed``.

    ``completed_at`` is stamped with the current naive UTC time. The change
    is flushed to the session but not committed; the caller owns the
    transaction.

    Args:
        db: Open SQLAlchemy session.
        campaign_id: Identifier of the campaign to complete.

    Returns:
        The updated ``Campaign`` instance.

    Raises:
        EntityNotFoundError: No campaign matches ``campaign_id``.
        BusinessRuleViolation: The campaign is not currently ``active``.
    """
    target = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not target:
        raise EntityNotFoundError("Campaign", campaign_id)
    if target.status != "active":
        raise BusinessRuleViolation("Only active campaigns can be completed")

    target.status = "completed"
    target.completed_at = datetime.utcnow()
    db.flush()
    return target
def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
    """Return a campaign's identity together with its progress statistics.

    Args:
        db: Open SQLAlchemy session.
        campaign_id: Identifier of the campaign, as a string.

    Returns:
        A dict with ``campaign_id`` and ``campaign_name`` followed by every
        key produced by ``get_campaign_progress``.

    Raises:
        EntityNotFoundError: No campaign matches ``campaign_id``.
    """
    found = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not found:
        raise EntityNotFoundError("Campaign", campaign_id)

    # The progress helper takes a UUID, while this entry point receives a str.
    stats = get_campaign_progress(db, uuid.UUID(campaign_id))

    payload = {
        "campaign_id": str(found.id),
        "campaign_name": found.name,
    }
    payload.update(stats)
    return payload
# Define function schedule_campaign
def schedule_campaign(
# Entry: db
db: Session,
# Entry: campaign_id
campaign_id: str,
*,
# Entry: owner_id
owner_id: uuid.UUID,
# Entry: owner_role
owner_role: str,
# Entry: is_recurring
is_recurring: bool,
# Entry: recurrence_pattern
recurrence_pattern: Optional[str] = None,
# Entry: next_run_at
next_run_at: Optional[str] = None,
) -> Campaign:
"""Configure or update the recurrence schedule for a campaign.
@@ -422,32 +671,52 @@ def schedule_campaign(
Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation.
Does not commit; caller commits.
"""
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
# Check: not campaign
if not campaign:
# Raise EntityNotFoundError
raise EntityNotFoundError("Campaign", campaign_id)
# Check: str(campaign.created_by) != str(owner_id) and owner_role != "admin"
if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
# Raise PermissionViolation
raise PermissionViolation("Only the creator or admin can configure scheduling")
# Assign campaign.is_recurring = is_recurring
campaign.is_recurring = is_recurring
# Check: is_recurring
if is_recurring:
# Check: recurrence_pattern not in ("weekly", "monthly", "quarterly")
if recurrence_pattern not in ("weekly", "monthly", "quarterly"):
# Raise BusinessRuleViolation
raise BusinessRuleViolation(
# Literal argument value
"recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
)
# Assign campaign.recurrence_pattern = recurrence_pattern
campaign.recurrence_pattern = recurrence_pattern
# Check: next_run_at
if next_run_at:
# Assign campaign.next_run_at = datetime.fromisoformat(
campaign.next_run_at = datetime.fromisoformat(
next_run_at.replace("Z", "+00:00").replace("+00:00", "")
)
# Alternative: not campaign.next_run_at
elif not campaign.next_run_at:
# Assign campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
# Fallback: handle remaining cases
else:
# Assign campaign.recurrence_pattern = None
campaign.recurrence_pattern = None
# Assign campaign.next_run_at = None
campaign.next_run_at = None
# Flush changes to DB without committing the transaction
db.flush()
# Return campaign
return campaign
@@ -522,29 +791,48 @@ def get_campaign_history(db: Session, campaign_id: str) -> dict:
Raises EntityNotFoundError if campaign not found.
"""
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
# Check: not campaign
if not campaign:
# Raise EntityNotFoundError
raise EntityNotFoundError("Campaign", campaign_id)
# Assign campaign_uuid = uuid.UUID(campaign_id)
campaign_uuid = uuid.UUID(campaign_id)
# Assign children = (
children = (
db.query(Campaign)
# Chain .filter() call
.filter(Campaign.parent_campaign_id == campaign_uuid)
# Chain .order_by() call
.order_by(Campaign.created_at.desc())
# Chain .all() call
.all()
)
# Return {
return {
# Literal argument value
"campaign_id": str(campaign.id),
# Literal argument value
"campaign_name": campaign.name,
# Literal argument value
"items": [
{
# Literal argument value
"id": str(child.id),
# Literal argument value
"name": child.name,
# Literal argument value
"status": child.status,
# Literal argument value
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
# Literal argument value
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
# Literal argument value
"created_at": child.created_at.isoformat() if child.created_at else None,
# Literal argument value
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
}
for child in children
@@ -4,19 +4,34 @@ Handles checking which recurring campaigns are due, cloning them with
fresh tests, and computing the next run date.
"""
# Import logging
import logging
import uuid
# Import datetime, timedelta from datetime
from datetime import datetime, timedelta
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import Campaign, CampaignTest from app.models.campaign
from app.models.campaign import Campaign, CampaignTest
from app.models.test import Test
# Import TestState from app.models.enums
from app.models.enums import TestState
from app.services.notification_service import create_notification
from app.services.audit_service import log_action
# Import Test from app.models.test
from app.models.test import Test
# Import User from app.models.user
from app.models.user import User
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action
# Import create_notification from app.services.notification_service
from app.services.notification_service import create_notification
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
@@ -33,11 +48,16 @@ def calculate_next_run(current_date: datetime, pattern: str) -> datetime:
- ``monthly`` : +30 days
- ``quarterly``: +90 days
"""
# Assign offsets = {
offsets = {
# Literal argument value
"weekly": timedelta(days=7),
# Literal argument value
"monthly": timedelta(days=30),
# Literal argument value
"quarterly": timedelta(days=90),
}
# Return current_date + offsets.get(pattern, timedelta(days=30))
return current_date + offsets.get(pattern, timedelta(days=30))
@@ -54,59 +74,99 @@ def _clone_campaign(db: Session, original: Campaign) -> Campaign:
with the same base data (in ``draft`` state) and link it.
3. Activate the new campaign.
"""
# Assign now = datetime.utcnow()
now = datetime.utcnow()
# Assign run_label = now.strftime("%Y-%m-%d")
run_label = now.strftime("%Y-%m-%d")
# Assign child = Campaign(
child = Campaign(
# Keyword argument: name
name=f"{original.name} (Run {run_label})",
# Keyword argument: description
description=original.description,
# Keyword argument: type
type=original.type,
# Keyword argument: threat_actor_id
threat_actor_id=original.threat_actor_id,
# Keyword argument: status
status="active",
# Keyword argument: created_by
created_by=original.created_by,
# Keyword argument: target_platform
target_platform=original.target_platform,
# Keyword argument: tags
tags=original.tags or [],
# Keyword argument: parent_campaign_id
parent_campaign_id=original.id,
)
# Stage new record(s) for database insertion
db.add(child)
# Flush changes to DB without committing the transaction
db.flush() # get child.id
# Clone each campaign_test with a fresh Test
original_cts = (
db.query(CampaignTest)
# Chain .filter() call
.filter(CampaignTest.campaign_id == original.id)
# Chain .order_by() call
.order_by(CampaignTest.order_index)
# Chain .all() call
.all()
)
# Iterate over original_cts
for ct in original_cts:
# Assign src_test = ct.test
src_test = ct.test
# Check: not src_test
if not src_test:
# Skip to the next loop iteration
continue
# Assign new_test = Test(
new_test = Test(
# Keyword argument: technique_id
technique_id=src_test.technique_id,
# Keyword argument: name
name=src_test.name,
# Keyword argument: description
description=src_test.description,
# Keyword argument: platform
platform=src_test.platform,
# Keyword argument: procedure_text
procedure_text=src_test.procedure_text,
# Keyword argument: tool_used
tool_used=src_test.tool_used,
# Keyword argument: created_by
created_by=original.created_by,
# Keyword argument: state
state=TestState.draft,
)
# Stage new record(s) for database insertion
db.add(new_test)
# Flush changes to DB without committing the transaction
db.flush() # get new_test.id
# Assign new_ct = CampaignTest(
new_ct = CampaignTest(
# Keyword argument: campaign_id
campaign_id=child.id,
# Keyword argument: test_id
test_id=new_test.id,
# Keyword argument: order_index
order_index=ct.order_index,
# Keyword argument: phase
phase=ct.phase,
# depends_on is not copied — would need ID remapping
)
# Stage new record(s) for database insertion
db.add(new_ct)
# Flush changes to DB without committing the transaction
db.flush()
# Return child
return child
@@ -120,75 +180,119 @@ def check_and_run_recurring_campaigns(db: Session) -> int:
Returns the number of campaigns spawned.
"""
# Assign now = datetime.utcnow()
now = datetime.utcnow()
# Assign due_campaigns = (
due_campaigns = (
db.query(Campaign)
# Chain .filter() call
.filter(
Campaign.is_recurring == True, # noqa: E712
Campaign.next_run_at <= now,
)
# Chain .all() call
.all()
)
# Assign spawned = 0
spawned = 0
# Iterate over due_campaigns
for campaign in due_campaigns:
# Attempt the following; catch errors below
try:
# Assign child = _clone_campaign(db, campaign)
child = _clone_campaign(db, campaign)
# Update the original's scheduling fields
campaign.last_run_at = now
# Assign campaign.next_run_at = calculate_next_run(now, campaign.recurrence_pattern or "monthly")
campaign.next_run_at = calculate_next_run(now, campaign.recurrence_pattern or "monthly")
# Commit all pending changes to the database
db.commit()
# Reload ORM object attributes from the database
db.refresh(child)
# Audit
log_action(
db,
# Keyword argument: user_id
user_id=campaign.created_by,
# Keyword argument: action
action="recurring_campaign_run",
# Keyword argument: entity_type
entity_type="campaign",
# Keyword argument: entity_id
entity_id=child.id,
# Keyword argument: details
details={
# Literal argument value
"parent_campaign_id": str(campaign.id),
# Literal argument value
"child_campaign_name": child.name,
# Literal argument value
"pattern": campaign.recurrence_pattern,
},
)
# Commit all pending changes to the database
db.commit()
# Notify
if campaign.created_by:
# Call create_notification()
create_notification(
db,
# Keyword argument: user_id
user_id=campaign.created_by,
# Keyword argument: type
type="recurring_campaign_run",
# Keyword argument: title
title="Recurring campaign executed",
message=f'Campaign "{child.name}" was automatically created from recurring template "{campaign.name}".',
# Keyword argument: message
message=(
f'Campaign "{child.name}" was automatically created '
f'from recurring template "{campaign.name}".'
),
# Keyword argument: entity_type
entity_type="campaign",
# Keyword argument: entity_id
entity_id=child.id,
)
# Notify red_tech users
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
# Iterate over red_techs
for user in red_techs:
# Call create_notification()
create_notification(
db,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: type
type="campaign_activated",
# Keyword argument: title
title="New recurring campaign active",
# Keyword argument: message
message=f'Campaign "{child.name}" is now active and ready for execution.',
# Keyword argument: entity_type
entity_type="campaign",
# Keyword argument: entity_id
entity_id=child.id,
)
# Assign spawned = 1
spawned += 1
# Log info: "Spawned child campaign '%s' from parent '%s'", ch
logger.info("Spawned child campaign '%s' from parent '%s'", child.name, campaign.name)
# Handle Exception
except Exception:
# Roll back all uncommitted changes
db.rollback()
# Log exception: "Failed to run recurring campaign '%s'", campaign.
logger.exception("Failed to run recurring campaign '%s'", campaign.name)
# Return spawned
return spawned
+128 -6
View File
@@ -4,108 +4,183 @@ Handles circular dependency validation, campaign generation from
threat actors, and progress calculation.
"""
# Import logging
import logging
import uuid
from datetime import datetime
from typing import Optional
# Import uuid
import uuid
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError, InvalidOperationError from app.domain.exceptions
from app.domain.exceptions import EntityNotFoundError, InvalidOperationError
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.technique import Technique
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
# Import Campaign, CampaignTest from app.models.campaign
from app.models.campaign import Campaign, CampaignTest
# Import TechniqueStatus, TestState from app.models.enums
from app.models.enums import TechniqueStatus, TestState
from app.services.notification_service import create_notification
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
# Import User from app.models.user
from app.models.user import User
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Mapping from ATT&CK tactic slugs to kill chain phase names.
# Every phase name is its tactic slug with hyphens turned into underscores,
# so the table is derived from the ordered list of tactics below.
TACTIC_TO_PHASE: dict[str, str] = {
    tactic: tactic.replace("-", "_")
    for tactic in (
        "reconnaissance",
        "resource-development",
        "initial-access",
        "execution",
        "persistence",
        "privilege-escalation",
        "defense-evasion",
        "credential-access",
        "discovery",
        "lateral-movement",
        "collection",
        "command-and-control",
        "exfiltration",
        "impact",
    )
}
def validate_no_circular_dependency(
    db: Session,
    campaign_id: uuid.UUID,
    test_id: uuid.UUID,
    depends_on_id: uuid.UUID | None,
) -> None:
    """Follow the ``depends_on`` chain and reject it if it loops.

    Starting at ``depends_on_id``, each ``CampaignTest`` row is looked up by
    id and its own ``depends_on`` is followed until the chain ends.

    Args:
        db: Open SQLAlchemy session used for the lookups.
        campaign_id: Accepted for the caller's convenience; the lookups in
            this function do not filter on it.
        test_id: Id that must not appear anywhere in the chain.
        depends_on_id: First link of the chain, or ``None`` for no
            dependency (in which case nothing is checked).

    Raises:
        InvalidOperationError: The chain reaches ``test_id`` or revisits an
            id it has already seen.
    """
    seen: set[uuid.UUID] = set()
    cursor = depends_on_id

    # A ``None`` cursor means either "no dependency" or "end of chain".
    while cursor is not None:
        if cursor == test_id or cursor in seen:
            raise InvalidOperationError(
                "Circular dependency detected in campaign test chain"
            )
        seen.add(cursor)

        link = db.query(CampaignTest).filter_by(id=cursor).first()
        # A missing row ends the walk, the same as a row with no dependency.
        cursor = None if not link else link.depends_on
def get_campaign_progress(db: Session, campaign_id: uuid.UUID) -> dict:
    """Summarise how far along a campaign's tests are.

    Args:
        db: Open SQLAlchemy session.
        campaign_id: Id of the campaign whose ``CampaignTest`` rows are
            counted.

    Returns:
        A dict with:
        ``total`` - number of campaign tests,
        ``by_state`` - count of tests per state value (``"unknown"`` when a
        row has no test or the test has no state),
        ``completion_pct`` - share of tests in the ``"validated"`` state,
        as a percentage rounded to one decimal. A campaign without tests
        yields ``0``, ``{}`` and ``0.0``.
    """
    rows = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).all()
    if not rows:
        return {"total": 0, "by_state": {}, "completion_pct": 0.0}

    state_counts: dict[str, int] = {}
    for row in rows:
        linked = row.test
        if linked and linked.state:
            label = linked.state.value
        else:
            label = "unknown"
        state_counts[label] = state_counts.get(label, 0) + 1

    # ``rows`` is non-empty here, so the division below is always defined.
    total = len(rows)
    validated = state_counts.get("validated", 0)
    return {
        "total": total,
        "by_state": state_counts,
        "completion_pct": round(validated / total * 100, 1),
    }
# Define function generate_campaign_from_threat_actor
def generate_campaign_from_threat_actor(
# Entry: db
db: Session,
# Entry: actor_id
actor_id: uuid.UUID,
# Entry: user
user: User,
*,
start_date: Optional[datetime] = None,
@@ -119,75 +194,111 @@ def generate_campaign_from_threat_actor(
4. Create a campaign with tests ordered by kill chain phase
5. Return the campaign
"""
# Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
# Check: not actor
if not actor:
# Raise EntityNotFoundError
raise EntityNotFoundError("ThreatActor", str(actor_id))
# Get unvalidated techniques for this actor
gap_techniques = (
db.query(Technique, ThreatActorTechnique)
# Chain .join() call
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
# Chain .filter() call
.filter(ThreatActorTechnique.threat_actor_id == actor_id)
# Chain .filter() call
.filter(Technique.status_global != TechniqueStatus.validated)
# Chain .order_by() call
.order_by(Technique.tactic, Technique.mitre_id)
# Chain .all() call
.all()
)
# Check: not gap_techniques
if not gap_techniques:
# Raise InvalidOperationError
raise InvalidOperationError(
f"No uncovered techniques found for {actor.name}"
)
# Create the campaign
campaign = Campaign(
# Keyword argument: name
name=f"APT Emulation: {actor.name}",
# Keyword argument: description
description=f"Auto-generated campaign to test coverage against {actor.name} "
f"({actor.mitre_id or 'unknown'}). "
f"Covers {len(gap_techniques)} uncovered technique(s).",
# Keyword argument: type
type="apt_emulation",
# Keyword argument: threat_actor_id
threat_actor_id=actor_id,
# Keyword argument: status
status="draft",
# Keyword argument: created_by
created_by=user.id,
# Keyword argument: tags
tags=[actor.name, "auto-generated"],
start_date=start_date,
)
# Stage new record(s) for database insertion
db.add(campaign)
# Flush changes to DB without committing the transaction
db.flush() # Get campaign.id
# Assign order_index = 0
order_index = 0
# Iterate over gap_techniques
for tech, _at in gap_techniques:
# Find best template for this technique
template = (
db.query(TestTemplate)
# Chain .filter() call
.filter(
TestTemplate.mitre_technique_id == tech.mitre_id,
TestTemplate.is_active == True, # noqa: E712
)
# Chain .order_by() call
.order_by(
# Prioritize by severity: critical > high > medium > low
TestTemplate.severity.desc(),
TestTemplate.name,
)
# Chain .first() call
.first()
)
# Check: not template
if not template:
# continue # Skip techniques without templates
continue # Skip techniques without templates
# Create a test from the template
test = Test(
# Keyword argument: technique_id
technique_id=tech.id,
# Keyword argument: name
name=f"[Campaign] {template.name}",
# Keyword argument: description
description=template.description,
# Keyword argument: platform
platform=template.platform,
# Keyword argument: procedure_text
procedure_text=template.attack_procedure,
# Keyword argument: tool_used
tool_used=template.tool_suggested,
# Keyword argument: created_by
created_by=user.id,
# Keyword argument: state
state=TestState.draft,
created_at=datetime.utcnow(),
)
# Stage new record(s) for database insertion
db.add(test)
# Flush changes to DB without committing the transaction
db.flush() # Get test.id
# Determine kill chain phase from the technique's tactic
@@ -195,22 +306,33 @@ def generate_campaign_from_threat_actor(
# Add to campaign
campaign_test = CampaignTest(
# Keyword argument: campaign_id
campaign_id=campaign.id,
# Keyword argument: test_id
test_id=test.id,
# Keyword argument: order_index
order_index=order_index,
# Keyword argument: phase
phase=phase,
)
# Stage new record(s) for database insertion
db.add(campaign_test)
# Assign order_index = 1
order_index += 1
# Commit all pending changes to the database
db.commit()
# Reload ORM object attributes from the database
db.refresh(campaign)
# Log info:
logger.info(
# Literal argument value
"Generated campaign '%s' with %d tests for actor %s",
campaign.name,
order_index,
actor.name,
)
# Return campaign
return campaign
File diff suppressed because it is too large Load Diff
+196 -5
View File
@@ -6,111 +6,184 @@ that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import csv
import csv
# Import io
import io
# Import Any from typing
from typing import Any
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import from app.models.compliance
from app.models.compliance import (
ComplianceFramework,
ComplianceControl,
ComplianceControlMapping,
ComplianceFramework,
)
from app.models.technique import Technique
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActorTechnique
from app.services.scoring_service import calculate_technique_score
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import ThreatActorTechnique from app.models.threat_actor
from app.models.threat_actor import ThreatActorTechnique
# Import calculate_technique_score from app.services.scoring_service
from app.services.scoring_service import calculate_technique_score
# ── Helpers ───────────────────────────────────────────────────────────
def _classify_control(technique_scores: list[float]) -> str:
"""Classify a control status based on its technique scores."""
# Check: not technique_scores
if not technique_scores:
# Return "not_evaluated"
return "not_evaluated"
# Assign all_above_70 = all(s >= 70 for s in technique_scores)
all_above_70 = all(s >= 70 for s in technique_scores)
# Assign any_above_30 = any(s >= 30 for s in technique_scores)
any_above_30 = any(s >= 30 for s in technique_scores)
# Assign all_below_30 = all(s < 30 for s in technique_scores)
all_below_30 = all(s < 30 for s in technique_scores)
# Assign all_zero = all(s == 0 for s in technique_scores)
all_zero = all(s == 0 for s in technique_scores)
# Check: all_zero
if all_zero:
# Return "not_evaluated"
return "not_evaluated"
# Check: all_above_70
if all_above_70:
# Return "covered"
return "covered"
# Check: all_below_30
if all_below_30:
# Return "not_covered"
return "not_covered"
# Check: any_above_30
if any_above_30:
# Return "partially_covered"
return "partially_covered"
# Return "not_covered"
return "not_covered"
def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, Any]:
    """Build the status report for one compliance control.

    The control's mapped techniques are scored with
    ``calculate_technique_score``; the control's score is the mean of those
    scores rounded to one decimal, and its status comes from
    ``_classify_control``.

    Args:
        control: The control to evaluate.
        db: Open SQLAlchemy session.

    Returns:
        A dict with the control's ``control_id``, ``title``, ``description``
        and ``category``, plus ``status``, ``score``, ``techniques_count``,
        ``techniques_covered`` (techniques scoring at least 50) and
        ``techniques`` (per-technique details, lowest score first). A control
        with no mappings is reported as ``"not_evaluated"`` with zeroed
        counters.
    """
    links = (
        db.query(ComplianceControlMapping)
        .filter(ComplianceControlMapping.compliance_control_id == control.id)
        .all()
    )
    if not links:
        return {
            "control_id": control.control_id,
            "title": control.title,
            "description": control.description,
            "category": control.category,
            "status": "not_evaluated",
            "score": 0,
            "techniques_count": 0,
            "techniques_covered": 0,
            "techniques": [],
        }

    mapped_ids = [link.technique_id for link in links]
    matched = db.query(Technique).filter(Technique.id.in_(mapped_ids)).all()

    per_technique = []
    for technique in matched:
        technique_score = calculate_technique_score(technique, db)["total_score"]
        per_technique.append({
            "mitre_id": technique.mitre_id,
            "name": technique.name,
            "score": technique_score,
            "status": technique.status_global.value if technique.status_global else "not_evaluated",
        })

    # Collect the scores in query order, before the details are re-ordered.
    all_scores = [entry["score"] for entry in per_technique]
    covered = sum(1 for value in all_scores if value >= 50)

    # Worst-scoring techniques come first so they are seen as priorities.
    ranked = sorted(per_technique, key=lambda entry: entry["score"])

    mean_score = round(sum(all_scores) / len(all_scores), 1) if all_scores else 0
    label = _classify_control(all_scores)

    return {
        "control_id": control.control_id,
        "title": control.title,
        "description": control.description,
        "category": control.category,
        "status": label,
        "score": mean_score,
        "techniques_count": len(matched),
        "techniques_covered": covered,
        "techniques": ranked,
    }
@@ -120,95 +193,150 @@ def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, An
def list_frameworks(db: Session) -> list[dict[str, Any]]:
    """Return every active compliance framework with its number of controls.

    Args:
        db: Open SQLAlchemy session.

    Returns:
        One dict per active framework with ``id`` (as a string), ``name``,
        ``version``, ``description``, ``url``, ``is_active`` and
        ``controls_count``.
    """

    def _summarize(framework: ComplianceFramework) -> dict[str, Any]:
        # One COUNT query per framework.
        n_controls = (
            db.query(ComplianceControl)
            .filter(ComplianceControl.framework_id == framework.id)
            .count()
        )
        return {
            "id": str(framework.id),
            "name": framework.name,
            "version": framework.version,
            "description": framework.description,
            "url": framework.url,
            "is_active": framework.is_active,
            "controls_count": n_controls,
        }

    active = (
        db.query(ComplianceFramework)
        .filter(ComplianceFramework.is_active == True)  # noqa: E712
        .all()
    )
    return [_summarize(framework) for framework in active]
def get_framework(db: Session, framework_id: str) -> ComplianceFramework | None:
    """Look up a compliance framework by its id.

    Args:
        db: Open SQLAlchemy session.
        framework_id: Id of the framework to fetch.

    Returns:
        The first matching ``ComplianceFramework``, or ``None`` when no row
        matches.
    """
    matching = db.query(ComplianceFramework).filter(ComplianceFramework.id == framework_id)
    return matching.first()
def get_framework_status(db: Session, framework_id: str) -> dict[str, Any]:
    """Report the compliance status of every control in a framework.

    Args:
        db: Open SQLAlchemy session.
        framework_id: Id of the framework to report on.

    Returns:
        A dict with ``framework`` (id and name), ``summary`` (total controls,
        a count per status and ``compliance_percentage``) and ``controls``
        (one status dict per control, ordered by ``control_id``).

    Raises:
        EntityNotFoundError: The framework does not exist.
    """
    framework = get_framework(db, framework_id)
    if not framework:
        raise EntityNotFoundError("Framework", framework_id)

    ordered_controls = (
        db.query(ComplianceControl)
        .filter(ComplianceControl.framework_id == framework.id)
        .order_by(ComplianceControl.control_id)
        .all()
    )
    control_statuses = [_get_control_status(item, db) for item in ordered_controls]

    summary = {
        "total_controls": len(ordered_controls),
        "covered": 0,
        "partially_covered": 0,
        "not_covered": 0,
        "not_evaluated": 0,
    }
    for entry in control_statuses:
        label = entry["status"]
        if label in summary:
            summary[label] += 1

    # A partially covered control counts as half a covered one.
    total = summary["total_controls"]
    weighted = summary["covered"] + summary["partially_covered"] * 0.5
    summary["compliance_percentage"] = round(weighted / total * 100, 1) if total > 0 else 0

    return {
        "framework": {"id": str(framework.id), "name": framework.name},
        "summary": summary,
        "controls": control_statuses,
    }
# Define function build_framework_report_csv
def build_framework_report_csv(
# Entry: db
db: Session,
# Entry: framework_id
framework_id: str,
) -> tuple[bytes, str]:
"""Build the compliance report CSV content and filename.
@@ -217,33 +345,55 @@ def build_framework_report_csv(
Raises EntityNotFoundError if the framework does not exist.
"""
# Assign framework = get_framework(db, framework_id)
framework = get_framework(db, framework_id)
# Check: not framework
if not framework:
# Raise EntityNotFoundError
raise EntityNotFoundError("Framework", framework_id)
# Assign controls = (
controls = (
db.query(ComplianceControl)
# Chain .filter() call
.filter(ComplianceControl.framework_id == framework.id)
# Chain .order_by() call
.order_by(ComplianceControl.control_id)
# Chain .all() call
.all()
)
# Assign output = io.StringIO()
output = io.StringIO()
# Assign writer = csv.writer(output)
writer = csv.writer(output)
# Call writer.writerow()
writer.writerow([
# Literal argument value
"control_id",
# Literal argument value
"title",
# Literal argument value
"category",
# Literal argument value
"status",
# Literal argument value
"score",
# Literal argument value
"techniques_total",
# Literal argument value
"techniques_covered",
# Literal argument value
"technique_ids",
])
# Iterate over controls
for control in controls:
# Assign status_data = _get_control_status(control, db)
status_data = _get_control_status(control, db)
# Assign technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
# Call writer.writerow()
writer.writerow([
status_data["control_id"],
status_data["title"],
@@ -255,75 +405,116 @@ def build_framework_report_csv(
technique_ids,
])
# Call output.seek()
output.seek(0)
# Assign filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
# Return output.getvalue().encode("utf-8"), filename
return output.getvalue().encode("utf-8"), filename
# Define function get_framework_gaps
def get_framework_gaps(db: Session, framework_id: str) -> dict[str, Any]:
    """Get controls with techniques that are not adequately covered.

    A control is reported when its status is ``not_covered`` or
    ``partially_covered`` and at least one of its techniques scores below 70.
    Every such technique is enriched with the number of test templates that
    target its MITRE id and the number of threat-actor links pointing at it.

    Args:
        db: Active SQLAlchemy session.
        framework_id: Identifier handed to ``get_framework``.

    Returns:
        A dict with ``framework`` (``id`` and ``name``), ``total_gaps`` and
        ``gaps`` (one entry per control that has uncovered techniques).

    Raises:
        EntityNotFoundError: If the framework does not exist.
    """
    framework = get_framework(db, framework_id)
    if not framework:
        raise EntityNotFoundError("Framework", framework_id)

    controls = (
        db.query(ComplianceControl)
        .filter(ComplianceControl.framework_id == framework.id)
        .order_by(ComplianceControl.control_id)
        .all()
    )

    # The counts below depend only on the MITRE id, and one technique may be
    # attached to several controls, so remember them instead of re-running
    # up to three queries for every (control, technique) pair.
    usage_by_mitre_id: dict[Any, dict[str, int]] = {}

    def _usage_counts(mitre_id: Any) -> dict[str, int]:
        """Return template / threat-actor counts for one MITRE id (memoised)."""
        if mitre_id not in usage_by_mitre_id:
            template_count = (
                db.query(TestTemplate)
                .filter(TestTemplate.mitre_technique_id == mitre_id)
                .count()
            )
            technique = (
                db.query(Technique)
                .filter(Technique.mitre_id == mitre_id)
                .first()
            )
            # Actor links reference the technique's primary key, so they can
            # only be counted when the technique row exists.
            actor_count = 0
            if technique:
                actor_count = (
                    db.query(ThreatActorTechnique)
                    .filter(ThreatActorTechnique.technique_id == technique.id)
                    .count()
                )
            usage_by_mitre_id[mitre_id] = {
                "templates_available": template_count,
                "threat_actors_using": actor_count,
            }
        return usage_by_mitre_id[mitre_id]

    gaps = []
    for control in controls:
        status_data = _get_control_status(control, db)
        if status_data["status"] not in ("not_covered", "partially_covered"):
            continue

        # A fresh dict per entry keeps the cached counts from being shared
        # by reference between result rows.
        uncovered_techniques = [
            {**tech_info, **_usage_counts(tech_info["mitre_id"])}
            for tech_info in status_data["techniques"]
            if tech_info["score"] < 70
        ]
        if not uncovered_techniques:
            continue

        gaps.append({
            "control_id": status_data["control_id"],
            "title": status_data["title"],
            "category": status_data["category"],
            "status": status_data["status"],
            "score": status_data["score"],
            "uncovered_techniques": uncovered_techniques,
        })

    return {
        "framework": {"id": str(framework.id), "name": framework.name},
        "total_gaps": len(gaps),
        "gaps": gaps,
    }
@@ -7,120 +7,202 @@ technique/test-count pattern by using a single grouped query.
This module is framework-agnostic: no FastAPI imports.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import datetime from datetime
from datetime import datetime
# Import func from sqlalchemy
from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import escape_like from app.utils
from app.utils import escape_like
# Define function _technique_test_counts
def _technique_test_counts(
# Entry: db
db: Session,
# Entry: technique_ids
technique_ids: list,
) -> dict:
"""Return ``{technique_id: {state_str: count}}`` in a single query."""
# Check: not technique_ids
if not technique_ids:
# Return {}
return {}
# Assign rows = (
rows = (
db.query(Test.technique_id, Test.state, func.count(Test.id))
# Chain .filter() call
.filter(Test.technique_id.in_(technique_ids))
# Chain .group_by() call
.group_by(Test.technique_id, Test.state)
# Chain .all() call
.all()
)
# Assign result = {}
result: dict = {}
# Iterate over rows
for tid, state, count in rows:
# Call result.setdefault()
result.setdefault(tid, {})[str(state)] = count
# Return result
return result
# Define function build_coverage_summary
def build_coverage_summary(
    db: Session,
    *,
    tactic: str | None = None,
    platform: str | None = None,
) -> dict:
    """Build the full coverage summary report as a dict.

    Args:
        db: Active SQLAlchemy session.
        tactic: Optional case-insensitive substring filter on the technique's
            tactic (applied in SQL via ``ILIKE``).
        platform: Optional platform name; techniques whose ``platforms`` list
            does not contain it (case-insensitively) are dropped in Python.

    Returns:
        A dict with ``generated_at`` (ISO timestamp), ``summary`` (per-status
        totals plus ``coverage_percentage``) and ``techniques`` (one row per
        technique with its test counts by state).
    """
    technique_query = db.query(Technique)
    if tactic:
        pattern = f"%{escape_like(tactic)}%"
        technique_query = technique_query.filter(Technique.tactic.ilike(pattern))
    techniques = technique_query.order_by(Technique.mitre_id).all()

    # One grouped query for all techniques instead of one query per technique.
    counts_map = _technique_test_counts(db, [tech.id for tech in techniques])

    def _on_platform(tech) -> bool:
        """Tell whether *tech* passes the optional platform filter."""
        if not platform:
            return True
        return platform.lower() in [p.lower() for p in (tech.platforms or [])]

    def _as_row(tech) -> dict:
        """Serialise one technique together with its per-state test counts."""
        state_counts = counts_map.get(tech.id, {})
        return {
            "mitre_id": tech.mitre_id,
            "name": tech.name,
            "tactic": tech.tactic,
            "platforms": tech.platforms,
            "status_global": tech.status_global,
            "total_tests": sum(state_counts.values()),
            "tests_by_state": state_counts,
        }

    rows = [_as_row(tech) for tech in techniques if _on_platform(tech)]

    def _with_status(status: str) -> int:
        """Count the rows whose global status equals *status*."""
        return sum(1 for row in rows if row["status_global"] == status)

    total = len(rows)
    validated = _with_status("validated")
    partial = _with_status("partial")
    not_covered = _with_status("not_covered")
    in_progress = _with_status("in_progress")
    not_evaluated = _with_status("not_evaluated")

    return {
        "generated_at": datetime.utcnow().isoformat(),
        "summary": {
            "total_techniques": total,
            "validated": validated,
            "partial": partial,
            "not_covered": not_covered,
            "in_progress": in_progress,
            "not_evaluated": not_evaluated,
            # Only fully validated techniques count towards coverage.
            "coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
        },
        "techniques": rows,
    }
# Define function build_coverage_csv_rows
def build_coverage_csv_rows(
# Entry: db
db: Session,
*,
# Entry: tactic
tactic: str | None = None,
# Entry: platform
platform: str | None = None,
) -> list[list]:
"""Build rows for a CSV coverage export (header + data)."""
# Assign query = db.query(Technique)
query = db.query(Technique)
# Check: tactic
if tactic:
# Assign query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
# Assign techniques = query.order_by(Technique.mitre_id).all()
techniques = query.order_by(Technique.mitre_id).all()
# Assign counts_map = _technique_test_counts(db, [t.id for t in techniques])
counts_map = _technique_test_counts(db, [t.id for t in techniques])
# Assign header = [
header = [
# Literal argument value
"MITRE ID", "Name", "Tactic", "Platforms", "Status",
# Literal argument value
"Total Tests", "Validated", "In Progress", "Not Covered",
]
# Assign rows = [header]
rows = [header]
# Assign in_progress_states = {"draft", "red_executing", "blue_evaluating", "in_review"}
in_progress_states = {"draft", "red_executing", "blue_evaluating", "in_review"}
# Iterate over techniques
for t in techniques:
# Check: platform and platform.lower() not in [p.lower() for p in (t.platfor...
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
# Skip to the next loop iteration
continue
# Assign counts = counts_map.get(t.id, {})
counts = counts_map.get(t.id, {})
# Call rows.append()
rows.append([
t.mitre_id,
t.name,
t.tactic,
# Literal argument value
", ".join(t.platforms or []),
t.status_global,
sum(counts.values()),
@@ -129,65 +211,111 @@ def build_coverage_csv_rows(
counts.get("rejected", 0),
])
# Return rows
return rows
# Define function build_test_results_report
def build_test_results_report(
    db: Session,
    *,
    state: str | None = None,
    date_from: str | None = None,
    date_to: str | None = None,
) -> dict:
    """Build a test results report with optional filters.

    Args:
        db: Active SQLAlchemy session.
        state: When given, only tests whose ``state`` equals it are included.
        date_from: ISO-8601 string; inclusive lower bound on ``created_at``.
        date_to: ISO-8601 string; inclusive upper bound on ``created_at``.

    Returns:
        A dict with ``generated_at``, the echoed ``filters``, a ``summary``
        (total plus tallies by state and by detection result) and ``tests``
        (newest first).

    Note:
        A date string that ``datetime.fromisoformat`` rejects is ignored: the
        corresponding bound is simply not applied, and no error is raised.
    """

    def _label(value):
        """Return ``value.value`` when present, otherwise ``str(value)``."""
        return value.value if hasattr(value, "value") else str(value)

    def _parse_bound(text):
        """Parse an ISO date string; ``None`` when absent or malformed."""
        if not text:
            return None
        try:
            return datetime.fromisoformat(text)
        except ValueError:
            # Best-effort filter: a malformed bound is dropped, not fatal.
            return None

    query = db.query(Test)
    if state:
        query = query.filter(Test.state == state)

    earliest = _parse_bound(date_from)
    if earliest is not None:
        query = query.filter(Test.created_at >= earliest)

    latest = _parse_bound(date_to)
    if latest is not None:
        query = query.filter(Test.created_at <= latest)

    tests = query.order_by(Test.created_at.desc()).all()

    by_state: dict[str, int] = {}
    by_result: dict[str, int] = {}
    for test in tests:
        state_key = _label(test.state)
        by_state.setdefault(state_key, 0)
        by_state[state_key] += 1
        # Tests without a detection result are left out of that tally.
        if test.detection_result:
            result_key = _label(test.detection_result)
            by_result.setdefault(result_key, 0)
            by_result[result_key] += 1

    def _serialize(test) -> dict:
        """Flatten one test into JSON-friendly primitives."""
        outcome = test.detection_result
        return {
            "id": str(test.id),
            "name": test.name,
            "technique_id": str(test.technique_id),
            "state": _label(test.state),
            "platform": test.platform,
            "attack_success": test.attack_success,
            "detection_result": _label(outcome) if outcome else None,
            "red_validation_status": test.red_validation_status,
            "blue_validation_status": test.blue_validation_status,
            "created_at": test.created_at.isoformat() if test.created_at else None,
        }

    return {
        "generated_at": datetime.utcnow().isoformat(),
        "filters": {"state": state, "date_from": date_from, "date_to": date_to},
        "summary": {
            "total_tests": len(tests),
            "by_state": by_state,
            "by_detection_result": by_result,
        },
        "tests": [_serialize(test) for test in tests],
    }
# Define function build_remediation_status_report
def build_remediation_status_report(
# Entry: db
db: Session,
*,
# Entry: status
status: str | None = None,
) -> dict:
"""Build a remediation status report."""
# Assign query = db.query(Test).filter(Test.remediation_steps.isnot(None))
query = db.query(Test).filter(Test.remediation_steps.isnot(None))
# Check: status
if status:
# Assign query = query.filter(Test.remediation_status == status)
query = query.filter(Test.remediation_status == status)
# Assign tests = query.order_by(Test.created_at.desc()).all()
tests = query.order_by(Test.created_at.desc()).all()
# Assign by_status = {}
by_status: dict[str, int] = {}
# Iterate over tests
for t in tests:
# Assign s = t.remediation_status or "unset"
s = t.remediation_status or "unset"
# Assign by_status[s] = by_status.get(s, 0) + 1
by_status[s] = by_status.get(s, 0) + 1
# Return {
return {
# Literal argument value
"generated_at": datetime.utcnow().isoformat(),
# Literal argument value
"summary": {
# Literal argument value
"total_with_remediation": len(tests),
# Literal argument value
"by_status": by_status,
},
# Literal argument value
"tests": [
{
# Literal argument value
"id": str(t.id),
# Literal argument value
"name": t.name,
# Literal argument value
"technique_id": str(t.technique_id),
# Literal argument value
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
# Literal argument value
"remediation_status": t.remediation_status,
# Literal argument value
"remediation_steps": t.remediation_steps,
# Literal argument value
"remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
}
for t in tests
+329 -14
View File
@@ -1,148 +1,270 @@
"""D3FEND import service — fetches MITRE D3FEND data and creates
DefensiveTechnique records plus ATT&CK → D3FEND mappings.
"""D3FEND import service — fetches MITRE D3FEND data and creates DefensiveTechnique records plus ATT&CK → D3FEND mappings.
Uses the D3FEND public API:
- https://d3fend.mitre.org/api/technique/api-all.json (all defensive techniques)
- https://d3fend.mitre.org/api/offensive-technique/{attack_id}.json (mappings per ATT&CK technique)
"""
# Import logging
import logging
import uuid
# Import Any from typing
from typing import Any
# Import UUID from uuid
from uuid import UUID
# Import httpx
import httpx
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
from app.models.technique import Technique
# Import DefensiveTechnique, DefensiveTechniqueMapping from app.models.defensive_technique
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
# Import Technique from app.models.technique
from app.models.technique import Technique
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Assign D3FEND_TACTIC_URL = "https://d3fend.mitre.org/api/tactic/d3f:{tactic}.json"
D3FEND_TACTIC_URL = "https://d3fend.mitre.org/api/tactic/d3f:{tactic}.json"
# Assign D3FEND_MAPPING_URL = "https://d3fend.mitre.org/api/offensive-technique/{attack_id}.json"
D3FEND_MAPPING_URL = "https://d3fend.mitre.org/api/offensive-technique/{attack_id}.json"
# Assign D3FEND_BASE_URL = "https://d3fend.mitre.org/technique/d3f:{iri}"
D3FEND_BASE_URL = "https://d3fend.mitre.org/technique/d3f:{iri}"
# Assign D3FEND_TACTICS = ["Detect", "Harden", "Isolate", "Deceive", "Evict", "Model"]
D3FEND_TACTICS = ["Detect", "Harden", "Isolate", "Deceive", "Evict", "Model"]
# ── Import all D3FEND techniques ─────────────────────────────────────
def _to_str(v: Any) -> str:
"""Coerce an RDF value (str, dict with @value, or list) to a plain string."""
def _to_str(v: Any) -> str: # noqa: ANN401
"""Coerce an RDF value (str, dict with @value, or list) to a plain string.
Args:
v (Any): RDF node value — may be a plain string, a dict containing
a ``@value`` key, or a list of such values.
Returns:
str: Plain string representation; ``"; "``-joined for list inputs.
"""
# Check: isinstance(v, dict)
if isinstance(v, dict):
# Return v.get("@value", str(v))
return v.get("@value", str(v))
# Check: isinstance(v, list)
if isinstance(v, list):
# Return "; ".join(_to_str(x) for x in v)
return "; ".join(_to_str(x) for x in v)
# Return str(v) if v else ""
return str(v) if v else ""
# Define function _fetch_techniques_from_tactic_apis
def _fetch_techniques_from_tactic_apis() -> list[dict[str, Any]]:
    """Fetch all defensive techniques via the D3FEND tactic APIs.

    Requests ``/api/tactic/d3f:{tactic}.json`` once for every entry in
    ``D3FEND_TACTICS`` and reads the technique nodes from the response's
    ``techniques`` → ``@graph`` list. A tactic whose request or JSON decoding
    fails is logged and skipped; the remaining tactics are still processed.

    Returns:
        list[dict[str, Any]]: Technique dicts deduplicated by ``d3fend_id``
        (first occurrence wins), each containing ``d3fend_id``, ``iri``,
        ``name``, ``description`` (at most 500 characters, or ``None``) and
        ``tactic``.
    """
    collected: list[dict[str, Any]] = []
    known_ids: set[str] = set()

    with httpx.Client(timeout=60.0) as client:
        for tactic in D3FEND_TACTICS:
            url = D3FEND_TACTIC_URL.format(tactic=tactic)
            try:
                response = client.get(url)
                response.raise_for_status()
                payload = response.json()
            except Exception as e:
                # Best effort: one unreachable tactic must not abort the rest.
                logger.warning("Failed to fetch D3FEND tactic %s: %s", tactic, e)
                continue

            nodes = payload.get("techniques", {}).get("@graph", [])
            for node in nodes:
                node_id = node.get("@id", "")
                d3fend_id = _to_str(node.get("d3f:d3fend-id", ""))
                name = _to_str(node.get("rdfs:label", ""))
                # Prefer the formal definition; fall back to the RDFS comment.
                description = (
                    _to_str(node.get("d3f:definition", ""))
                    or _to_str(node.get("rdfs:comment", ""))
                )

                # The IRI is the node id without its "d3f:" namespace marker.
                if node_id.startswith("d3f:"):
                    iri = node_id.replace("d3f:", "")
                else:
                    iri = node_id

                # Skip nodes lacking an id or label, and ids already collected
                # under an earlier tactic.
                if not (d3fend_id and name) or d3fend_id in known_ids:
                    continue

                known_ids.add(d3fend_id)
                collected.append({
                    "d3fend_id": d3fend_id,
                    "iri": iri,
                    "name": name,
                    "description": description[:500] if description else None,
                    "tactic": tactic,
                })

            # Reports the raw node count, including nodes that were skipped.
            logger.info("D3FEND tactic %s: %d techniques", tactic, len(nodes))

    return collected
# Define function _upsert_techniques
def _upsert_techniques(db: Session, techniques: list[dict[str, Any]]) -> dict[str, int]:
    """Upsert a list of technique dicts into the DefensiveTechnique table.

    Rows are matched on ``d3fend_id``. A match has its name, description,
    tactic and URL overwritten; otherwise a new row is added. All changes are
    committed once, after the whole list has been processed.

    Args:
        db (Session): Active SQLAlchemy database session.
        techniques (list[dict[str, Any]]): Technique data dicts, each
            containing ``d3fend_id`` and ``name`` and optionally
            ``description``, ``tactic`` and ``iri``.

    Returns:
        dict[str, int]: ``created`` and ``updated`` counts for this call, and
        ``total`` — the table's row count after the commit.
    """
    tally = {"created": 0, "updated": 0}

    for entry in techniques:
        record = (
            db.query(DefensiveTechnique)
            .filter(DefensiveTechnique.d3fend_id == entry["d3fend_id"])
            .first()
        )

        # Without an ontology IRI, derive one from the name minus its spaces.
        iri = entry.get("iri") or entry["name"].replace(" ", "")
        technique_url = D3FEND_BASE_URL.format(iri=iri)

        if not record:
            db.add(
                DefensiveTechnique(
                    d3fend_id=entry["d3fend_id"],
                    name=entry["name"],
                    description=entry.get("description"),
                    tactic=entry.get("tactic"),
                    d3fend_url=technique_url,
                )
            )
            tally["created"] += 1
            continue

        record.name = entry["name"]
        record.description = entry.get("description")
        record.tactic = entry.get("tactic")
        record.d3fend_url = technique_url
        tally["updated"] += 1

    db.commit()
    tally["total"] = db.query(DefensiveTechnique).count()
    return tally
# Define function import_d3fend_techniques
def import_d3fend_techniques(db: Session) -> dict[str, int]:
    """Fetch all D3FEND defensive techniques and upsert them into the DB.

    Techniques come from the tactic-level APIs. When the fetch raises, or
    yields fewer than 50 techniques, the curated fallback list is imported
    instead.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict[str, int]: ``created``, ``updated`` and ``total`` counts from
        whichever import path ran.
    """
    logger.info("Fetching D3FEND techniques from tactic APIs")
    try:
        techniques = _fetch_techniques_from_tactic_apis()
    except Exception as e:
        # Treat any fetch failure as "nothing fetched" so the fallback runs.
        logger.error("Failed to fetch D3FEND techniques from tactic APIs: %s", e)
        techniques = []

    # Fallback: use a curated list of well-known D3FEND techniques.
    if len(techniques) < 50:
        logger.warning("Tactic APIs returned too few techniques (%d), using fallback", len(techniques))
        return _import_d3fend_fallback(db)

    logger.info("Fetched %d D3FEND techniques from tactic APIs", len(techniques))
    outcome = _upsert_techniques(db, techniques)
    logger.info(
        "D3FEND import done: %d created, %d updated, %d total",
        outcome["created"],
        outcome["updated"],
        outcome["total"],
    )
    return outcome
@@ -228,9 +350,20 @@ _FALLBACK_TECHNIQUES: list[dict[str, str | None]] = [
]
# Define function _import_d3fend_fallback
def _import_d3fend_fallback(db: Session) -> dict[str, int]:
    """Import the curated D3FEND techniques in ``_FALLBACK_TECHNIQUES``.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict[str, int]: ``created``, ``updated`` and ``total`` counts from
        upserting the fallback technique list.
    """
    fallback_entries = _FALLBACK_TECHNIQUES
    logger.info("Using fallback D3FEND technique list (%d entries)", len(fallback_entries))
    # The curated list is typed ``str | None`` per value, narrower than the
    # upsert helper's ``Any``; hence the type-checker suppression.
    return _upsert_techniques(db, fallback_entries)  # type: ignore[arg-type]
@@ -239,217 +372,399 @@ def _import_d3fend_fallback(db: Session) -> dict[str, int]:
# Curated ATT&CK → D3FEND mapping for common techniques
_ATTACK_TO_D3FEND: dict[str, list[str]] = {
# Literal argument value
"T1059": ["D3-PSA", "D3-SCA", "D3-PA", "D3-EAW", "D3-EDL", "D3-PLA"],
# Literal argument value
"T1059.001": ["D3-PSA", "D3-SCA", "D3-PA", "D3-EAW", "D3-EDL"],
# Literal argument value
"T1059.003": ["D3-PSA", "D3-SCA", "D3-PA", "D3-EAW"],
# Literal argument value
"T1059.005": ["D3-PSA", "D3-SCA", "D3-EAW"],
# Literal argument value
"T1059.007": ["D3-PSA", "D3-SCA", "D3-EAW"],
# Literal argument value
"T1055": ["D3-PA", "D3-PSA", "D3-HBPI", "D3-PMAD", "D3-PLA"],
# Literal argument value
"T1055.001": ["D3-PA", "D3-PMAD", "D3-HBPI"],
# Literal argument value
"T1055.002": ["D3-PA", "D3-PMAD", "D3-HBPI"],
# Literal argument value
"T1003": ["D3-CH", "D3-CR", "D3-MFA", "D3-PMAD"],
# Literal argument value
"T1003.001": ["D3-CH", "D3-CR", "D3-PMAD"],
# Literal argument value
"T1078": ["D3-MFA", "D3-UBA", "D3-UGLPA", "D3-CH"],
# Literal argument value
"T1078.001": ["D3-MFA", "D3-UBA", "D3-CH"],
# Literal argument value
"T1566": ["D3-EAL", "D3-FA", "D3-FH", "D3-UA", "D3-EHR"],
# Literal argument value
"T1566.001": ["D3-EAL", "D3-FA", "D3-FH", "D3-EHR"],
# Literal argument value
"T1566.002": ["D3-UA", "D3-EAL", "D3-EHR"],
# Literal argument value
"T1071": ["D3-AL", "D3-NTA", "D3-PM", "D3-CT"],
# Literal argument value
"T1071.001": ["D3-AL", "D3-NTA", "D3-PM"],
# Literal argument value
"T1053": ["D3-PSA", "D3-PA", "D3-SCHE", "D3-SSA"],
# Literal argument value
"T1053.005": ["D3-PSA", "D3-SCHE", "D3-SSA"],
# Literal argument value
"T1543": ["D3-SMRA", "D3-SSA", "D3-SBAN"],
# Literal argument value
"T1543.003": ["D3-SMRA", "D3-SSA", "D3-SBAN"],
# Literal argument value
"T1547": ["D3-SICA", "D3-SSA", "D3-RRID"],
# Literal argument value
"T1547.001": ["D3-SICA", "D3-SSA", "D3-RRID"],
# Literal argument value
"T1021": ["D3-RTSD", "D3-RPA", "D3-NTA", "D3-MFA"],
# Literal argument value
"T1021.001": ["D3-RTSD", "D3-NTA", "D3-MFA"],
# Literal argument value
"T1021.002": ["D3-RTSD", "D3-NTA", "D3-NI"],
# Literal argument value
"T1560": ["D3-FA", "D3-FCA", "D3-ORA"],
# Literal argument value
"T1560.001": ["D3-FA", "D3-FCA"],
# Literal argument value
"T1048": ["D3-ORA", "D3-NTA", "D3-OTF"],
# Literal argument value
"T1048.003": ["D3-ORA", "D3-NTA", "D3-OTF"],
# Literal argument value
"T1105": ["D3-IRA", "D3-NTA", "D3-FA", "D3-FH"],
# Literal argument value
"T1036": ["D3-FCA", "D3-FH", "D3-FA", "D3-SWI"],
# Literal argument value
"T1036.005": ["D3-FCA", "D3-FH", "D3-FA"],
# Literal argument value
"T1140": ["D3-FA", "D3-DA", "D3-SCA"],
# Literal argument value
"T1070": ["D3-SSA", "D3-LOGA", "D3-SYSM"],
# Literal argument value
"T1070.004": ["D3-SSA", "D3-FAPA"],
# Literal argument value
"T1562": ["D3-SSA", "D3-SYSM", "D3-SMRA"],
# Literal argument value
"T1562.001": ["D3-SSA", "D3-SYSM", "D3-SMRA"],
# Literal argument value
"T1027": ["D3-DA", "D3-FA", "D3-RE"],
# Literal argument value
"T1027.002": ["D3-DA", "D3-FA"],
# Literal argument value
"T1110": ["D3-MFA", "D3-UBA", "D3-CH"],
# Literal argument value
"T1110.001": ["D3-MFA", "D3-UBA", "D3-CH"],
# Literal argument value
"T1082": ["D3-PSA", "D3-PA", "D3-SYSM"],
# Literal argument value
"T1083": ["D3-FAPA", "D3-PA"],
# Literal argument value
"T1497": ["D3-DA", "D3-SE"],
# Literal argument value
"T1218": ["D3-PSA", "D3-PLA", "D3-EAW"],
# Literal argument value
"T1218.011": ["D3-PSA", "D3-PLA", "D3-EAW"],
# Literal argument value
"T1569": ["D3-SMRA", "D3-PSA", "D3-PA"],
# Literal argument value
"T1569.002": ["D3-SMRA", "D3-PSA"],
# Literal argument value
"T1012": ["D3-RRID", "D3-PA"],
# Literal argument value
"T1112": ["D3-RRID", "D3-PA", "D3-REGG"],
# Literal argument value
"T1057": ["D3-PA", "D3-PSA"],
# Literal argument value
"T1518": ["D3-SYSM", "D3-PA"],
# Literal argument value
"T1049": ["D3-NTA", "D3-PA"],
# Literal argument value
"T1016": ["D3-NTA", "D3-PA", "D3-SYSM"],
# Literal argument value
"T1033": ["D3-PA", "D3-UBA"],
# Literal argument value
"T1087": ["D3-UBA", "D3-PA", "D3-SSA"],
# Literal argument value
"T1087.001": ["D3-UBA", "D3-PA"],
# Literal argument value
"T1087.002": ["D3-UBA", "D3-PA"],
# Literal argument value
"T1018": ["D3-NTA", "D3-PA"],
# Literal argument value
"T1047": ["D3-RPA", "D3-PSA", "D3-PA"],
# Literal argument value
"T1190": ["D3-ISVA", "D3-NTA", "D3-AL"],
# Literal argument value
"T1133": ["D3-NTA", "D3-MFA", "D3-RTSD"],
# Literal argument value
"T1486": ["D3-BKUP", "D3-FBKP", "D3-ANTR", "D3-FA"],
# Literal argument value
"T1490": ["D3-BKUP", "D3-FBKP", "D3-SSA"],
# Literal argument value
"T1489": ["D3-SMRA", "D3-SSA"],
# Literal argument value
"T1098": ["D3-UBA", "D3-SSA", "D3-PGOV"],
# Literal argument value
"T1136": ["D3-UBA", "D3-SSA", "D3-UACM"],
# Literal argument value
"T1136.001": ["D3-UBA", "D3-SSA", "D3-UACM"],
# Literal argument value
"T1068": ["D3-SU", "D3-VULM", "D3-HBPI"],
# Literal argument value
"T1548": ["D3-PSEP", "D3-PSA", "D3-PA"],
# Literal argument value
"T1548.002": ["D3-PSEP", "D3-PSA"],
# Literal argument value
"T1134": ["D3-PA", "D3-PSA", "D3-PSEP"],
# Literal argument value
"T1134.001": ["D3-PA", "D3-PSA"],
# Literal argument value
"T1574": ["D3-SWI", "D3-FCA", "D3-PLA"],
# Literal argument value
"T1574.001": ["D3-SWI", "D3-FCA"],
# Literal argument value
"T1204": ["D3-EAL", "D3-FA", "D3-UA"],
# Literal argument value
"T1204.001": ["D3-UA", "D3-EAL"],
# Literal argument value
"T1204.002": ["D3-FA", "D3-EAL", "D3-DA"],
# Literal argument value
"T1071.004": ["D3-DPM", "D3-DNSSM", "D3-NTA"],
# Literal argument value
"T1571": ["D3-NTA", "D3-PM", "D3-AL"],
# Literal argument value
"T1572": ["D3-NTA", "D3-AL", "D3-PM"],
# Literal argument value
"T1041": ["D3-ORA", "D3-NTA"],
# Literal argument value
"T1005": ["D3-FAPA", "D3-PA"],
# Literal argument value
"T1113": ["D3-PA", "D3-PSA"],
# Literal argument value
"T1056": ["D3-PA", "D3-PSA", "D3-HBPI"],
# Literal argument value
"T1056.001": ["D3-PA", "D3-PSA"],
# Literal argument value
"T1560.003": ["D3-FA", "D3-ORA"],
# Literal argument value
"T1583": ["D3-IPMR", "D3-DNSRA"],
# Literal argument value
"T1584": ["D3-IPMR", "D3-DNSRA"],
# Literal argument value
"T1595": ["D3-IRA", "D3-NTA"],
# Literal argument value
"T1589": ["D3-UBA", "D3-THRT"],
# Literal argument value
"T1590": ["D3-NTA", "D3-THRT"],
# Literal argument value
"T1591": ["D3-THRT"],
# Literal argument value
"T1592": ["D3-THRT"],
}
# Define function import_d3fend_mappings
def import_d3fend_mappings(db: Session) -> dict[str, int]:
    """Create ATT&CK → D3FEND mappings from the curated ``_ATTACK_TO_D3FEND`` table.

    Only the curated in-module table is consulted; this function makes no
    D3FEND API call. Entries whose ATT&CK technique or D3FEND technique is
    not present in the database are ignored and counted neither as created
    nor as skipped.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict[str, int]: ``created`` (mappings added by this call),
        ``skipped`` (pairs that were already mapped) and ``total`` (rows in
        the mapping table after the commit). All three are ``0`` when the
        database holds no D3FEND techniques.
    """
    technique_map = {t.mitre_id: t for t in db.query(Technique).all()}
    d3fend_map = {dt.d3fend_id: dt for dt in db.query(DefensiveTechnique).all()}

    if not d3fend_map:
        logger.warning("No D3FEND techniques in DB — run import_d3fend_techniques first")
        return {"created": 0, "skipped": 0, "total": 0}

    # Load the already-mapped pairs once instead of issuing one SELECT per
    # candidate pair. Keys are stringified so the comparison does not depend
    # on how the ID columns are represented in Python.
    known_pairs = {
        (str(attack_id), str(defense_id))
        for attack_id, defense_id in db.query(
            DefensiveTechniqueMapping.attack_technique_id,
            DefensiveTechniqueMapping.defensive_technique_id,
        ).all()
    }

    created = 0
    skipped = 0

    for mitre_id, d3fend_ids in _ATTACK_TO_D3FEND.items():
        attack_tech = technique_map.get(mitre_id)
        if not attack_tech:
            continue

        for d3fend_id in d3fend_ids:
            def_tech = d3fend_map.get(d3fend_id)
            if not def_tech:
                continue

            pair = (str(attack_tech.id), str(def_tech.id))
            if pair in known_pairs:
                skipped += 1
                continue

            db.add(
                DefensiveTechniqueMapping(
                    attack_technique_id=attack_tech.id,
                    defensive_technique_id=def_tech.id,
                )
            )
            # Remember the pair so a repeated entry in the curated table is
            # reported as skipped rather than inserted twice.
            known_pairs.add(pair)
            created += 1

    db.commit()

    total = db.query(DefensiveTechniqueMapping).count()
    logger.info("D3FEND mappings: %d created, %d skipped, %d total", created, skipped, total)
    return {"created": created, "skipped": skipped, "total": total}
# Define function sync
def sync(db: Session) -> dict:
    """Import D3FEND techniques, then ATT&CK mappings, and report the counts.

    Runs ``import_d3fend_techniques`` followed by ``import_d3fend_mappings``.
    If a ``DataSource`` row named ``"d3fend"`` exists, its ``last_sync_*``
    fields are updated with the outcome and committed.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict: Flat summary with ``techniques_created``,
        ``techniques_updated``, ``techniques_total``, ``mappings_created``,
        ``mappings_skipped`` and ``mappings_total``; a count missing from an
        importer's result is reported as ``0``.
    """
    from datetime import datetime

    from app.models.data_source import DataSource

    techniques = import_d3fend_techniques(db)
    mappings = import_d3fend_mappings(db)

    summary = {
        "techniques_created": techniques.get("created", 0),
        "techniques_updated": techniques.get("updated", 0),
        "techniques_total": techniques.get("total", 0),
        "mappings_created": mappings.get("created", 0),
        "mappings_skipped": mappings.get("skipped", 0),
        "mappings_total": mappings.get("total", 0),
    }

    # Record the outcome on the D3FEND data-source row when one is registered.
    source_row = db.query(DataSource).filter(DataSource.name == "d3fend").first()
    if source_row:
        source_row.last_sync_at = datetime.utcnow()
        source_row.last_sync_status = "success"
        source_row.last_sync_stats = summary
        db.commit()

    logger.info("D3FEND sync complete — %s", summary)
    return summary
def get_defenses_for_technique(db: Session, technique_id: UUID) -> list[dict]:
    """Return all D3FEND defensive techniques mapped to a given ATT&CK technique.

    Args:
        db (Session): Active SQLAlchemy database session.
        technique_id (UUID): Primary key of the ATT&CK technique to look up.

    Returns:
        list[dict]: One dict per mapping, each containing ``id`` (as a
        string), ``d3fend_id``, ``name``, ``description``, ``tactic`` and
        ``d3fend_url`` of the mapped defensive technique. Empty when the
        technique has no mappings.
    """
    mappings = (
        db.query(DefensiveTechniqueMapping)
        .filter(DefensiveTechniqueMapping.attack_technique_id == technique_id)
        .all()
    )

    # Each mapping exposes its defensive technique through the
    # ``defensive_technique`` relationship attribute.
    return [
        {
            "id": str(dt.id),
            "d3fend_id": dt.d3fend_id,
            "name": dt.name,
            "description": dt.description,
            "tactic": dt.tactic,
            "d3fend_url": dt.d3fend_url,
        }
        for dt in (m.defensive_technique for m in mappings)
    ]
@@ -1,53 +1,92 @@
"""D3FEND query service — framework-agnostic queries for defensive techniques."""
# Enable future language features for compatibility
from __future__ import annotations
# Import Optional from typing
from typing import Optional
# Import func from sqlalchemy
from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import DefensiveTechnique from app.models.defensive_technique
from app.models.defensive_technique import DefensiveTechnique
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import get_defenses_for_technique from app.services.d3fend_import_service
from app.services.d3fend_import_service import get_defenses_for_technique
# Import escape_like from app.utils
from app.utils import escape_like
# Define function list_defensive_techniques
def list_defensive_techniques(
    db: Session,
    *,
    tactic: Optional[str] = None,
    search: Optional[str] = None,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """Return one page of D3FEND defensive techniques, optionally filtered.

    Args:
        db: Active SQLAlchemy database session.
        tactic: When truthy, keep only techniques whose tactic equals this value.
        search: When truthy, keep only techniques whose name or D3FEND ID
            contains this text (case-insensitive ``ILIKE``; the text is
            passed through ``escape_like`` first).
        offset: Number of matching rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        dict: ``total`` (matches before pagination), ``offset``, ``limit``
        and ``items`` (technique dicts ordered by ``d3fend_id``).
    """
    base = db.query(DefensiveTechnique)

    if tactic:
        base = base.filter(DefensiveTechnique.tactic == tactic)

    if search:
        like_pattern = f"%{escape_like(search)}%"
        name_match = DefensiveTechnique.name.ilike(like_pattern)
        id_match = DefensiveTechnique.d3fend_id.ilike(like_pattern)
        base = base.filter(name_match | id_match)

    # Count before pagination so ``total`` reflects every match.
    match_count = base.count()
    page = (
        base.order_by(DefensiveTechnique.d3fend_id)
        .offset(offset)
        .limit(limit)
        .all()
    )

    serialized = []
    for row in page:
        serialized.append(
            {
                "id": str(row.id),
                "d3fend_id": row.d3fend_id,
                "name": row.name,
                "description": row.description,
                "tactic": row.tactic,
                "d3fend_url": row.d3fend_url,
            }
        )

    return {
        "total": match_count,
        "offset": offset,
        "limit": limit,
        "items": serialized,
    }
# Define function list_d3fend_tactics
def list_d3fend_tactics(db: Session) -> list[dict]:
    """Return every D3FEND tactic with the number of techniques it groups.

    Args:
        db: Active SQLAlchemy database session.

    Returns:
        list[dict]: ``{"tactic", "count"}`` dicts ordered by tactic; a falsy
        tactic value (e.g. ``None``) is reported as ``"Unknown"``.
    """
    grouped = (
        db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
        .group_by(DefensiveTechnique.tactic)
        .order_by(DefensiveTechnique.tactic)
        .all()
    )

    summary = []
    for tactic_name, technique_count in grouped:
        summary.append({"tactic": tactic_name or "Unknown", "count": technique_count})
    return summary
# Define function get_defenses_for_attack_technique
def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict:
    """Look up an ATT&CK technique by MITRE ID and list its mapped defenses.

    Args:
        db: Active SQLAlchemy database session.
        mitre_id: MITRE identifier of the ATT&CK technique.

    Returns:
        dict: ``mitre_id``, ``technique_name``, ``defenses`` (as returned by
        ``get_defenses_for_technique``) and ``total`` (length of ``defenses``).

    Raises:
        EntityNotFoundError: If no technique has the given ``mitre_id``.
    """
    lookup = db.query(Technique).filter(Technique.mitre_id == mitre_id)
    technique = lookup.first()
    if technique is None:
        raise EntityNotFoundError("Technique", mitre_id)

    mapped = get_defenses_for_technique(db, technique.id)

    response = {
        "mitre_id": mitre_id,
        "technique_name": technique.name,
        "defenses": mapped,
        "total": len(mapped),
    }
    return response
+129 -1
View File
@@ -4,61 +4,99 @@ Provides list, update, sync, and stats. Sync operations commit internally
since they are long-running and self-contained.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import logging
import logging
# Import datetime from datetime
from datetime import datetime
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import BusinessRuleViolation, EntityNotFoundError from app.domain.errors
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
# Import get_import_handler from app.domain.ports.import_service
from app.domain.ports.import_service import get_import_handler
# Import DataSource from app.models.data_source
from app.models.data_source import DataSource
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Define function list_sources
def list_sources(db: Session) -> list[dict]:
    """Return every registered data source, ordered by name, as plain dicts.

    Args:
        db: Active SQLAlchemy database session.

    Returns:
        list[dict]: One dict per source. ``id`` is stringified and the
        ``last_sync_at`` / ``created_at`` timestamps are ISO-8601 strings,
        or ``None`` when unset.
    """

    def _iso(moment):
        # Timestamps may be unset; serialize those as None.
        return moment.isoformat() if moment else None

    serialized = []
    for source in db.query(DataSource).order_by(DataSource.name).all():
        serialized.append(
            {
                "id": str(source.id),
                "name": source.name,
                "display_name": source.display_name,
                "type": source.type,
                "url": source.url,
                "description": source.description,
                "is_enabled": source.is_enabled,
                "last_sync_at": _iso(source.last_sync_at),
                "last_sync_status": source.last_sync_status,
                "last_sync_stats": source.last_sync_stats,
                "sync_frequency": source.sync_frequency,
                "config": source.config,
                "created_at": _iso(source.created_at),
            }
        )
    return serialized
# Define function update_source
def update_source(db: Session, source_id: str, **fields: object) -> None:
    """Apply the supplied editable fields to a data source.

    Only ``is_enabled``, ``sync_frequency`` and ``config`` are honoured; any
    other keyword is ignored. Nothing is committed here (the original notes
    say the router handles the commit).

    Args:
        db: Active SQLAlchemy database session.
        source_id: Primary key of the data source to update.
        **fields: New values keyed by field name.

    Raises:
        EntityNotFoundError: If no data source has the given ID.
    """
    ds = db.query(DataSource).filter(DataSource.id == source_id).first()
    if not ds:
        raise EntityNotFoundError("Data source", source_id)

    # Copy over only the whitelisted attributes that the caller passed.
    for attr in ("is_enabled", "sync_frequency", "config"):
        if attr in fields:
            setattr(ds, attr, fields[attr])
# Define function sync_source
def sync_source(db: Session, source_id: str) -> dict:
    """Trigger a sync for one data source and record the outcome on it.

    The source is marked ``"in_progress"``, its import handler is run, and
    the row is then updated to ``"success"`` (with the handler's summary) or
    ``"error"`` (with the exception text). Commits internally.

    Args:
        db: Active SQLAlchemy database session.
        source_id: Primary key of the data source to sync.

    Returns:
        dict: ``message``, ``source`` (the source name) and ``stats`` (the
        handler's summary).

    Raises:
        EntityNotFoundError: If no data source has the given ID.
        BusinessRuleViolation: If no import handler is registered for the
            source's name, or if the handler raised.
    """
    ds = db.query(DataSource).filter(DataSource.id == source_id).first()
    if not ds:
        raise EntityNotFoundError("Data source", source_id)

    # Read the labels up front: with SQLAlchemy's default expire_on_commit
    # behaviour, attribute access after a commit triggers a reload, which
    # would itself fail if the handler left the session in a failed state.
    source_name = ds.name
    display_name = ds.display_name

    handler = get_import_handler(source_name)
    if handler is None:
        raise BusinessRuleViolation(f"No sync handler available for '{source_name}'")

    ds.last_sync_status = "in_progress"
    db.commit()

    try:
        summary = handler(db)
    except Exception as exc:
        logger.error("Sync failed for %s: %s", source_name, exc, exc_info=True)
        # Discard whatever the handler left pending. Without this, a failed
        # flush makes the commit below raise, and otherwise the commit would
        # persist the handler's half-finished work with the error status.
        db.rollback()
        ds.last_sync_status = "error"
        ds.last_sync_at = datetime.utcnow()
        ds.last_sync_stats = {"error": str(exc)}
        db.commit()
        raise BusinessRuleViolation(
            f"Sync failed for '{display_name}'. Check server logs for details."
        ) from exc

    ds.last_sync_at = datetime.utcnow()
    ds.last_sync_status = "success"
    ds.last_sync_stats = summary
    db.commit()

    return {
        "message": f"Sync complete for {display_name}",
        "source": source_name,
        "stats": summary,
    }
# Define function sync_all_sources
def sync_all_sources(db: Session) -> list[dict]:
    """Trigger a sync for every enabled data source, one after another.

    A failure in one source is recorded on that source and reported in the
    result list; the remaining sources are still processed. Commits
    internally.

    Args:
        db: Active SQLAlchemy database session.

    Returns:
        list[dict]: One entry per enabled source, ordered by name, with
        ``source`` and ``status`` (``"success"``, ``"error"`` or
        ``"skipped"``) plus ``stats`` on success or ``detail`` otherwise.
    """
    enabled_sources = (
        db.query(DataSource)
        .filter(DataSource.is_enabled == True)  # noqa: E712 - SQL expression, not a Python bool
        .order_by(DataSource.name)
        .all()
    )

    results = []
    for ds in enabled_sources:
        # Capture the name before any commit/rollback so logging and result
        # reporting never need to reload the row.
        source_name = ds.name

        handler = get_import_handler(source_name)
        if handler is None:
            results.append({
                "source": source_name,
                "status": "skipped",
                "detail": "No sync handler available",
            })
            continue

        ds.last_sync_status = "in_progress"
        db.commit()

        try:
            summary = handler(db)
            ds.last_sync_at = datetime.utcnow()
            ds.last_sync_status = "success"
            ds.last_sync_stats = summary
            db.commit()
            results.append({
                "source": source_name,
                "status": "success",
                "stats": summary,
            })
        except Exception as exc:
            logger.error("Sync failed for %s: %s", source_name, exc, exc_info=True)
            # Reset the session first: after a failed flush the commit below
            # would raise and abort the remaining sources, and otherwise it
            # would persist this handler's partial work.
            db.rollback()
            ds.last_sync_status = "error"
            ds.last_sync_at = datetime.utcnow()
            ds.last_sync_stats = {"error": str(exc)}
            db.commit()
            results.append({
                "source": source_name,
                "status": "error",
                "detail": "Sync failed. Check server logs for details.",
            })

    return results
# Define function get_source_stats
def get_source_stats(db: Session, source_id: str) -> dict:
    """Return sync status and content counts for one data source.

    Templates are counted only for sources of type ``"attack_procedure"``
    and rules only for type ``"detection_rule"``; the other count stays ``0``.

    Args:
        db: Active SQLAlchemy database session.
        source_id: Primary key of the data source.

    Returns:
        dict: Identification and sync fields of the source plus
        ``total_templates`` and ``total_rules``.

    Raises:
        EntityNotFoundError: If no data source has the given ID.
    """
    ds = db.query(DataSource).filter(DataSource.id == source_id).first()
    if not ds:
        raise EntityNotFoundError("Data source", source_id)

    from app.models.detection_rule import DetectionRule
    from app.models.test_template import TestTemplate

    def _owned_by_source(model):
        # Rows of ``model`` whose ``source`` column names this data source.
        return db.query(model).filter(model.source == ds.name).count()

    source_type = ds.type
    template_count = _owned_by_source(TestTemplate) if source_type == "attack_procedure" else 0
    rule_count = _owned_by_source(DetectionRule) if source_type == "detection_rule" else 0

    last_sync = ds.last_sync_at.isoformat() if ds.last_sync_at else None

    return {
        "id": str(ds.id),
        "name": ds.name,
        "display_name": ds.display_name,
        "type": ds.type,
        "is_enabled": ds.is_enabled,
        "last_sync_at": last_sync,
        "last_sync_status": ds.last_sync_status,
        "last_sync_stats": ds.last_sync_stats,
        "total_templates": template_count,
        "total_rules": rule_count,
    }
+228 -10
View File
@@ -6,76 +6,136 @@ that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import datetime from datetime
from datetime import datetime
# Import Any from typing
from typing import Any
# Import UUID from uuid
from uuid import UUID
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
from app.models.detection_rule import DetectionRule
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.technique import Technique
from app.utils import escape_like
# Import DetectionRule from app.models.detection_rule
from app.models.detection_rule import DetectionRule
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestDetectionResult from app.models.test_detection_result
from app.models.test_detection_result import TestDetectionResult
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import TestTemplateDetectionRule from app.models.test_template_detection_rule
from app.models.test_template_detection_rule import TestTemplateDetectionRule
# Import escape_like from app.utils
from app.utils import escape_like
# ── Public service functions ──────────────────────────────────────────
def list_rules(
    db: Session,
    *,
    technique: str | None = None,
    source: str | None = None,
    severity: str | None = None,
    search: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> dict[str, Any]:
    """Return one page of active detection rules, optionally filtered.

    Args:
        db: Active SQLAlchemy database session.
        technique: When truthy, keep rules whose ``mitre_technique_id`` equals it.
        source: When truthy, keep rules whose ``source`` equals it.
        severity: When truthy, keep rules whose ``severity`` equals it.
        search: When truthy, keep rules whose title or description contains
            this text (case-insensitive ``ILIKE``; the text is passed through
            ``escape_like`` first).
        offset: Number of matching rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        dict[str, Any]: ``total`` (matches before pagination), ``offset``,
        ``limit`` and ``items`` (rule dicts ordered by technique ID, then title).
    """
    rules_query = db.query(DetectionRule).filter(DetectionRule.is_active == True)  # noqa: E712

    # Exact-match filters, applied only for the arguments that were supplied.
    exact_filters = (
        (technique, DetectionRule.mitre_technique_id),
        (source, DetectionRule.source),
        (severity, DetectionRule.severity),
    )
    for wanted, column in exact_filters:
        if wanted:
            rules_query = rules_query.filter(column == wanted)

    if search:
        like_pattern = f"%{escape_like(search)}%"
        rules_query = rules_query.filter(
            DetectionRule.title.ilike(like_pattern)
            | DetectionRule.description.ilike(like_pattern)
        )

    # Count before pagination so ``total`` reflects every match.
    match_count = rules_query.count()
    ordered = rules_query.order_by(DetectionRule.mitre_technique_id, DetectionRule.title)
    page = ordered.offset(offset).limit(limit).all()

    serialized = []
    for rule in page:
        serialized.append(
            {
                "id": str(rule.id),
                "mitre_technique_id": rule.mitre_technique_id,
                "title": rule.title,
                "description": rule.description,
                "source": rule.source,
                "source_url": rule.source_url,
                "rule_format": rule.rule_format,
                "severity": rule.severity,
                "platforms": rule.platforms or [],
                "log_sources": rule.log_sources,
                "is_active": rule.is_active,
            }
        )

    return {
        "total": match_count,
        "offset": offset,
        "limit": limit,
        "items": serialized,
    }
# Define function get_rules_for_template
def get_rules_for_template(db: Session, template_id: str) -> dict[str, Any]:
    """Return the detection rules associated with a test template.

    Args:
        db: Active SQLAlchemy database session.
        template_id: Primary key of the test template.

    Returns:
        dict[str, Any]: ``template_id``, ``template_name``,
        ``mitre_technique_id``, ``rules`` (one dict per association,
        including its ``is_primary`` flag) and ``total``.

    Raises:
        EntityNotFoundError: If the template does not exist.
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if not template:
        raise EntityNotFoundError("Test template", template_id)

    links = (
        db.query(TestTemplateDetectionRule)
        .filter(TestTemplateDetectionRule.test_template_id == template_id)
        .all()
    )

    # Pair each association with its rule so ``is_primary`` stays available.
    rule_dicts = [
        {
            "id": str(rule.id),
            "mitre_technique_id": rule.mitre_technique_id,
            "title": rule.title,
            "description": rule.description,
            "source": rule.source,
            "source_url": rule.source_url,
            "rule_content": rule.rule_content,
            "rule_format": rule.rule_format,
            "severity": rule.severity,
            "platforms": rule.platforms or [],
            "log_sources": rule.log_sources,
            "is_primary": link.is_primary,
        }
        for link, rule in ((link, link.detection_rule) for link in links)
    ]

    return {
        "template_id": str(template.id),
        "template_name": template.name,
        "mitre_technique_id": template.mitre_technique_id,
        "rules": rule_dicts,
        "total": len(rule_dicts),
    }
# Define function auto_associate_rules
def auto_associate_rules(db: Session) -> dict[str, Any]:
    """Auto-associate test templates with detection rules by MITRE technique ID.

    Every active template is linked to every active rule that shares its
    ``mitre_technique_id``. Pairs that are already associated are left
    untouched. New associations are flagged primary when the rule's severity
    is ``high`` or ``critical`` (case-insensitive). Commits internally.

    Args:
        db: Active SQLAlchemy database session.

    Returns:
        dict[str, Any]: ``created`` (associations added by this call),
        ``skipped`` (pairs already associated) and ``total_associations``
        (rows in the association table after the commit).
    """
    templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all()  # noqa: E712
    rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all()  # noqa: E712

    # Group the active rules by technique so each template is one dict lookup.
    rules_by_technique: dict[str, list] = {}
    for rule in rules:
        rules_by_technique.setdefault(rule.mitre_technique_id, []).append(rule)

    # Load the existing association keys once instead of issuing one SELECT
    # per template/rule pair. Keys are stringified so the comparison does not
    # depend on how the ID columns are represented in Python.
    known_pairs = {
        (str(template_id), str(rule_id))
        for template_id, rule_id in db.query(
            TestTemplateDetectionRule.test_template_id,
            TestTemplateDetectionRule.detection_rule_id,
        ).all()
    }

    created = 0
    skipped = 0
    high_severities = {"high", "critical"}

    for template in templates:
        for rule in rules_by_technique.get(template.mitre_technique_id, []):
            pair = (str(template.id), str(rule.id))
            if pair in known_pairs:
                skipped += 1
                continue

            db.add(
                TestTemplateDetectionRule(
                    test_template_id=template.id,
                    detection_rule_id=rule.id,
                    # A missing severity counts as not primary.
                    is_primary=(rule.severity or "").lower() in high_severities,
                )
            )
            known_pairs.add(pair)
            created += 1

    db.commit()

    total = db.query(TestTemplateDetectionRule).count()
    return {
        "created": created,
        "skipped": skipped,
        "total_associations": total,
    }
# Define function get_rules_for_test
def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]:
    """Return the detection rules relevant to a test with their evaluations.

    Relevant rules are the active rules whose ``mitre_technique_id`` equals
    the MITRE ID of the test's technique. Each rule is merged with the stored
    evaluation result for this test, if there is one.

    Args:
        db: Active SQLAlchemy database session.
        test_id: Primary key of the test.

    Returns:
        dict[str, Any]: ``test_id``, ``mitre_technique_id``, ``rules``,
        ``total``, ``evaluated`` (rules whose ``triggered`` is not ``None``),
        ``triggered`` (rules with a truthy ``triggered``) and
        ``detection_rate`` (triggered / evaluated as a percentage rounded to
        one decimal, or ``0`` when nothing was evaluated).

    Raises:
        EntityNotFoundError: If the test or its technique does not exist.
    """
    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", str(test_id))

    technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
    if not technique:
        raise EntityNotFoundError("Technique", str(test.technique_id))

    relevant_rules = (
        db.query(DetectionRule)
        .filter(
            DetectionRule.mitre_technique_id == technique.mitre_id,
            DetectionRule.is_active == True,  # noqa: E712
        )
        .order_by(DetectionRule.severity.desc(), DetectionRule.title)
        .all()
    )

    stored_results = (
        db.query(TestDetectionResult)
        .filter(TestDetectionResult.test_id == test_id)
        .all()
    )
    # Index stored evaluations by stringified rule ID for the merge below.
    outcome_by_rule = {str(res.detection_rule_id): res for res in stored_results}

    items = []
    for rule in relevant_rules:
        outcome = outcome_by_rule.get(str(rule.id))
        if outcome:
            triggered = outcome.triggered
            notes = outcome.notes
            evaluated_at = outcome.evaluated_at.isoformat() if outcome.evaluated_at else None
            result_id = str(outcome.id)
        else:
            triggered = notes = evaluated_at = result_id = None

        items.append({
            "id": str(rule.id),
            "mitre_technique_id": rule.mitre_technique_id,
            "title": rule.title,
            "description": rule.description,
            "source": rule.source,
            "source_url": rule.source_url,
            "rule_content": rule.rule_content,
            "rule_format": rule.rule_format,
            "severity": rule.severity,
            "platforms": rule.platforms or [],
            "log_sources": rule.log_sources,
            "triggered": triggered,
            "notes": notes,
            "evaluated_at": evaluated_at,
            "result_id": result_id,
        })

    # A rule counts as evaluated once a True/False verdict has been stored.
    evaluated_count = sum(1 for item in items if item["triggered"] is not None)
    triggered_count = sum(1 for item in items if item["triggered"])

    if evaluated_count > 0:
        detection_rate = round(triggered_count / evaluated_count * 100, 1)
    else:
        detection_rate = 0

    return {
        "test_id": str(test.id),
        "mitre_technique_id": technique.mitre_id,
        "rules": items,
        "total": len(items),
        "evaluated": evaluated_count,
        "triggered": triggered_count,
        "detection_rate": detection_rate,
    }
# Define function evaluate_rule
def evaluate_rule(
    db: Session,
    *,
    test_id: UUID,
    detection_rule_id: UUID,
    triggered: bool | None,
    notes: str | None,
    evaluator_id: UUID,
) -> dict[str, Any]:
    """Save or update the evaluation result for a detection rule on a test.

    Looks up the ``TestDetectionResult`` row matching *test_id* and
    *detection_rule_id*. If one exists it is overwritten in place, otherwise
    a new row is inserted. The session is committed and the row refreshed
    before the summary is returned.

    Args:
        db: Active SQLAlchemy session; this function commits on it.
        test_id: Identifier of the test the evaluation belongs to.
        detection_rule_id: Identifier of the detection rule being evaluated.
        triggered: Evaluation outcome, stored as given (``None`` included).
        notes: Free-text notes, stored as given.
        evaluator_id: Stored in the row's ``evaluated_by`` column.

    Returns:
        A dict with the keys ``id`` (string), ``triggered``, ``notes`` and
        ``evaluated_at`` (ISO-8601 string, or ``None`` when unset).

    Raises:
        EntityNotFoundError: If the test or the detection rule does not exist.
    """
    # Validate both foreign keys up front so callers get a domain error
    # rather than a database integrity error at commit time.
    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", str(test_id))

    rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first()
    if not rule:
        raise EntityNotFoundError("Detection rule", str(detection_rule_id))

    existing = (
        db.query(TestDetectionResult)
        .filter(
            TestDetectionResult.test_id == test_id,
            TestDetectionResult.detection_rule_id == detection_rule_id,
        )
        .first()
    )

    if existing:
        # Re-evaluation: overwrite the previous verdict and re-stamp it.
        existing.triggered = triggered
        existing.notes = notes
        existing.evaluated_by = evaluator_id
        existing.evaluated_at = datetime.utcnow()
        record = existing
    else:
        record = TestDetectionResult(
            test_id=test_id,
            detection_rule_id=detection_rule_id,
            triggered=triggered,
            notes=notes,
            evaluated_by=evaluator_id,
            evaluated_at=datetime.utcnow(),
        )
        db.add(record)

    db.commit()
    # Refresh so database-generated values (e.g. the id of a new row) are loaded.
    db.refresh(record)

    return {
        "id": str(record.id),
        "triggered": record.triggered,
        "notes": record.notes,
        "evaluated_at": record.evaluated_at.isoformat() if record.evaluated_at else None,
    }
+190 -6
View File
@@ -21,22 +21,39 @@ rules are identified by ``source = "elastic"`` + ``source_id`` (the
TOML filename).
"""
# Import io
import io
# Import logging
import logging
# Import shutil
import shutil
# Import tempfile
import tempfile
# Import zipfile
import zipfile
# Import datetime from datetime
from datetime import datetime
# Import Path from pathlib
from pathlib import Path
# Import requests
import requests as _requests
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
from app.models.detection_rule import DetectionRule
# Import DataSource from app.models.data_source
from app.models.data_source import DataSource
from app.models.technique import Technique
from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -44,19 +61,33 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Archive of the main branch of the Elastic detection-rules repository.
ELASTIC_ZIP_URL = "https://github.com/elastic/detection-rules/archive/refs/heads/main.zip"

# Passed as the ``timeout`` argument of the download request.
_DOWNLOAD_TIMEOUT = 300

# Name of the top-level directory inside the downloaded archive.
_ZIP_ROOT_PREFIX = "detection-rules-main"

# Safety limits for ZIP extraction — prevent zip-bomb DoS
_MAX_UNCOMPRESSED_SIZE = 500 * 1024**2  # 500 MB
_MAX_ENTRIES = 50_000

# Severity normalisation: only these labels are accepted; each maps to itself,
# so ``_SEVERITY_MAP.get(...)`` yields ``None`` for anything unrecognised.
_SEVERITY_MAP = {
    level: level
    for level in ("informational", "low", "medium", "high", "critical")
}
@@ -68,14 +99,21 @@ _SEVERITY_MAP = {
def _download_zip(url: str = ELASTIC_ZIP_URL) -> bytes:
    """Download the Elastic Detection Rules ZIP and return its raw bytes.

    Args:
        url: Location of the archive; defaults to ``ELASTIC_ZIP_URL``.

    Returns:
        The complete response body.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``). Other ``requests`` exceptions
            propagate unchanged.
    """
    logger.info("Downloading Elastic Detection Rules ZIP from %s", url)
    # The request is made with stream=True, so the connection is only released
    # once the body is consumed or the response is closed. The context manager
    # guarantees the close even when raise_for_status() raises.
    with _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) as resp:
        resp.raise_for_status()
        content = resp.content
    logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
    return content
# Define function _safe_extract_zip
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
    """Unpack an in-memory ZIP archive into *dest* after validating it.

    All checks run before anything is written. The archive is rejected when
    it holds more than 50 000 entries, when the uncompressed sizes declared
    in its headers add up to more than 500 MB, or when any member name
    resolves to a location outside *dest* (path traversal / Zip Slip).

    Args:
        zip_bytes: Raw bytes of the ZIP archive.
        dest: Directory to extract into.

    Raises:
        ValueError: If the archive exceeds a safety limit or contains a
            member that would escape *dest*.
    """
    # Maximum uncompressed size: 500 MB — prevents zip-bomb DoS
    _MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024
    # Maximum number of entries
    _MAX_ENTRIES = 50_000

    root = Path(dest).resolve()
    one_mb = 1024 * 1024

    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as archive:
        members = archive.infolist()

        member_count = len(members)
        if member_count > _MAX_ENTRIES:
            raise ValueError(
                f"ZIP archive contains {member_count} entries "
                f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
            )

        declared_bytes = 0
        for info in members:
            declared_bytes += info.file_size
        if declared_bytes > _MAX_UNCOMPRESSED_SIZE:
            raise ValueError(
                f"ZIP uncompressed size {declared_bytes / one_mb:.0f} MB "
                f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / one_mb:.0f} MB"
            )

        # Resolve every member against the destination; any ".." component or
        # absolute name that lands outside the root aborts the extraction.
        for info in members:
            resolved = (root / info.filename).resolve()
            if not resolved.is_relative_to(root):
                raise ValueError(
                    f"Zip Slip detected — member '{info.filename}' "
                    f"resolves outside target directory"
                )

        archive.extractall(dest)
# Define function _extract_zip
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
    """Unpack the archive into *dest* and locate its ``rules`` directory.

    Args:
        zip_bytes: Raw bytes of the downloaded archive.
        dest: Directory to extract into.

    Returns:
        ``<dest>/<_ZIP_ROOT_PREFIX>/rules``.

    Raises:
        FileNotFoundError: If that directory does not exist after extraction.
        ValueError: Propagated from ``_safe_extract_zip`` for unsafe archives.
    """
    _safe_extract_zip(zip_bytes, dest)

    candidate = Path(dest).joinpath(_ZIP_ROOT_PREFIX, "rules")
    if candidate.is_dir():
        return candidate

    raise FileNotFoundError(
        f"Expected rules directory not found at {candidate}"
    )
# Define function _parse_toml_safe
def _parse_toml_safe(path: Path) -> dict | None:
    """Parse a TOML file with the ``toml`` library, never raising.

    Args:
        path: File to read (decoded as UTF-8).

    Returns:
        The parsed document, or ``None`` if anything goes wrong. The failure
        is logged at DEBUG level only.
    """
    try:
        # Imported lazily and inside the try block, so a missing ``toml``
        # package is reported the same way as an unreadable file.
        import toml

        with path.open("r", encoding="utf-8") as handle:
            parsed = toml.load(handle)
    except Exception as exc:
        # Deliberately broad: one malformed rule file must not abort the run.
        logger.debug("Failed to parse %s: %s", path, exc)
        return None
    return parsed
# Define function _extract_mitre_techniques
def _extract_mitre_techniques(threat_list: list) -> list[str]:
    """Extract MITRE technique IDs from Elastic's ``rule.threat`` array.

    Each threat entry is expected to be a dict with a ``framework`` string
    and a ``technique`` list. Each technique is a dict with an ``id`` and an
    optional ``subtechnique`` list of dicts that also carry an ``id``.
    Entries whose framework name does not contain "MITRE" are ignored, and
    anything with an unexpected type is skipped rather than raising.

    Args:
        threat_list: Value of the rule's ``threat`` key; any non-list value
            yields an empty result.

    Returns:
        Upper-cased technique and sub-technique IDs (those starting with
        "T"), de-duplicated and sorted. Sorting keeps the result stable
        across processes; plain set iteration order varies with string
        hash randomisation.
    """
    if not isinstance(threat_list, list):
        return []

    found: set[str] = set()

    def _collect(raw_id: object) -> None:
        """Record *raw_id* when it looks like a technique identifier."""
        normalised = str(raw_id).upper() if raw_id else ""
        if normalised.startswith("T"):
            found.add(normalised)

    for threat_entry in threat_list:
        if not isinstance(threat_entry, dict):
            continue

        # Skip non-MITRE frameworks
        framework = threat_entry.get("framework", "")
        if "MITRE" not in str(framework).upper():
            continue

        techniques = threat_entry.get("technique", [])
        if not isinstance(techniques, list):
            continue

        for tech in techniques:
            if not isinstance(tech, dict):
                continue
            _collect(tech.get("id", ""))

            subtechniques = tech.get("subtechnique", [])
            if not isinstance(subtechniques, list):
                continue
            for subtech in subtechniques:
                if isinstance(subtech, dict):
                    _collect(subtech.get("id", ""))

    return sorted(found)
# Define function _parse_elastic_rules
def _parse_elastic_rules(rules_dir: Path) -> list[dict]:
"""Walk the rules directory and parse all TOML files.
Returns a flat list of dicts, one per (rule, technique) combination.
"""
# Assign results = []
results: list[dict] = []
# Assign toml_files = sorted(rules_dir.rglob("*.toml"))
toml_files = sorted(rules_dir.rglob("*.toml"))
# Log info: "Found %d TOML files to parse", len(toml_files
logger.info("Found %d TOML files to parse", len(toml_files))
# Iterate over toml_files
for toml_path in toml_files:
# Assign data = _parse_toml_safe(toml_path)
data = _parse_toml_safe(toml_path)
# Check: not data
if not data:
# Skip to the next loop iteration
continue
# Assign rule = data.get("rule", {})
rule = data.get("rule", {})
# Check: not isinstance(rule, dict)
if not isinstance(rule, dict):
# Skip to the next loop iteration
continue
# Assign name = rule.get("name", "").strip()
name = rule.get("name", "").strip()
# Check: not name
if not name:
# Skip to the next loop iteration
continue
# Extract MITRE technique IDs
threat_list = rule.get("threat", [])
# Assign technique_ids = _extract_mitre_techniques(threat_list)
technique_ids = _extract_mitre_techniques(threat_list)
# Check: not technique_ids
if not technique_ids:
# Skip to the next loop iteration
continue
# Assign description = rule.get("description", "")
description = rule.get("description", "")
# Assign query = rule.get("query", "")
query = rule.get("query", "")
# Assign severity = _SEVERITY_MAP.get(str(rule.get("severity", "")).lower())
severity = _SEVERITY_MAP.get(str(rule.get("severity", "")).lower())
# Assign rule_type = rule.get("type", "query") # query, eql, threshold, etc.
rule_type = rule.get("type", "query") # query, eql, threshold, etc.
# Determine rule format based on type
if rule_type == "eql":
# Assign rule_format = "eql"
rule_format = "eql"
# Alternative: rule_type == "esql"
elif rule_type == "esql":
# Assign rule_format = "esql"
rule_format = "esql"
# Fallback: handle remaining cases
else:
# Assign rule_format = "kql"
rule_format = "kql"
# Use filename as source_id
@@ -239,51 +350,79 @@ def _parse_elastic_rules(rules_dir: Path) -> list[dict]:
# Read raw content
try:
# Open context manager
with open(toml_path, "r", encoding="utf-8") as fh:
# Assign raw_content = fh.read()
raw_content = fh.read()
# Handle Exception
except Exception:
# Assign raw_content = query or str(data)
raw_content = query or str(data)
# Build source URL
relative = str(toml_path.relative_to(rules_dir.parent)).replace("\\", "/")
# Assign source_url = (
source_url = (
f"https://github.com/elastic/detection-rules/blob/main/{relative}"
)
# One entry per technique
for tech_id in technique_ids:
# Call results.append()
results.append({
# Literal argument value
"mitre_technique_id": tech_id,
# Literal argument value
"title": name[:500],
# Literal argument value
"description": str(description)[:2000] if description else None,
# Literal argument value
"source_id": source_id,
# Literal argument value
"source_url": source_url,
# Literal argument value
"rule_content": query[:50000] if query else raw_content[:50000],
# Literal argument value
"rule_format": rule_format,
# Literal argument value
"severity": severity,
# Literal argument value
"platforms": _infer_platforms(rules_dir, toml_path),
})
# Log info: "Parsed %d (rule, technique) pairs total", len(res
logger.info("Parsed %d (rule, technique) pairs total", len(results))
# Return results
return results
# Define function _infer_platforms
def _infer_platforms(rules_dir: Path, toml_path: Path) -> list[str] | None:
    """Guess a rule's platforms from where it sits under the rules directory.

    Every component of *toml_path* relative to *rules_dir* is lower-cased
    and compared against the names ``windows``, ``linux`` and ``macos``
    (e.g. ``rules/windows/...``).

    Args:
        rules_dir: Root of the extracted rules tree.
        toml_path: Path of the rule file; must lie under *rules_dir*.

    Returns:
        The matching platform names, always ordered windows, linux, macos,
        or ``None`` when no component matches.
    """
    components = {part.lower() for part in toml_path.relative_to(rules_dir).parts}
    matched = [
        name for name in ("windows", "linux", "macos") if name in components
    ]
    if not matched:
        return None
    return matched
@@ -297,47 +436,78 @@ def sync(db: Session) -> dict:
Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``.
"""
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_elastic_")
tmp_dir = tempfile.mkdtemp(prefix="aegis_elastic_")
# Attempt the following; catch errors below
try:
# Assign zip_bytes = _download_zip()
zip_bytes = _download_zip()
# Assign rules_dir = _extract_zip(zip_bytes, tmp_dir)
rules_dir = _extract_zip(zip_bytes, tmp_dir)
# Assign parsed_rules = _parse_elastic_rules(rules_dir)
parsed_rules = _parse_elastic_rules(rules_dir)
# Always execute this cleanup block
finally:
# Call shutil.rmtree()
shutil.rmtree(tmp_dir, ignore_errors=True)
# Log info: "Cleaned up temp directory %s", tmp_dir
logger.info("Cleaned up temp directory %s", tmp_dir)
# Pre-load existing source_ids for dedup
existing_ids: set[str] = {
row[0]
for row in db.query(DetectionRule.source_id)
# Chain .filter() call
.filter(DetectionRule.source == "elastic")
# Chain .filter() call
.filter(DetectionRule.source_id.isnot(None))
# Chain .all() call
.all()
}
# Assign created = 0
created = 0
# Assign skipped = 0
skipped = 0
new_technique_ids: set[str] = set()
# Iterate over parsed_rules
for item in parsed_rules:
# Check: item["source_id"] in existing_ids
if item["source_id"] in existing_ids:
# Assign skipped = 1
skipped += 1
# Skip to the next loop iteration
continue
# Assign rule = DetectionRule(
rule = DetectionRule(
# Keyword argument: mitre_technique_id
mitre_technique_id=item["mitre_technique_id"],
# Keyword argument: title
title=item["title"],
# Keyword argument: description
description=item["description"],
# Keyword argument: source
source="elastic",
# Keyword argument: source_id
source_id=item["source_id"],
# Keyword argument: source_url
source_url=item["source_url"],
# Keyword argument: rule_content
rule_content=item["rule_content"],
# Keyword argument: rule_format
rule_format=item["rule_format"],
# Keyword argument: severity
severity=item["severity"],
# Keyword argument: platforms
platforms=item["platforms"],
# Keyword argument: is_active
is_active=True,
)
# Stage new record(s) for database insertion
db.add(rule)
# Call existing_ids.add()
existing_ids.add(item["source_id"])
new_technique_ids.add(item["mitre_technique_id"])
created += 1
@@ -350,22 +520,36 @@ def sync(db: Session) -> dict:
db.commit()
# Assign summary = {
summary = {
# Literal argument value
"created": created,
# Literal argument value
"skipped_existing": skipped,
# Literal argument value
"total_parsed": len(parsed_rules),
}
# Update DataSource record
ds = db.query(DataSource).filter(DataSource.name == "elastic_rules").first()
# Check: ds
if ds:
# Assign ds.last_sync_at = datetime.utcnow()
ds.last_sync_at = datetime.utcnow()
# Assign ds.last_sync_status = "success"
ds.last_sync_status = "success"
# Assign ds.last_sync_stats = summary
ds.last_sync_stats = summary
# Commit all pending changes to the database
db.commit()
# Log info: "Elastic import complete — %s", summary
logger.info("Elastic import complete — %s", summary)
# Call log_action()
log_action(db, user_id=None, action="import_elastic_rules",
# Keyword argument: entity_type
entity_type="detection_rule", entity_id=None, details=summary)
# Commit all pending changes to the database
db.commit()
# Return summary
return summary
+89
View File
@@ -5,20 +5,32 @@ The router is responsible for HTTP concerns, file I/O, MinIO upload,
audit logging, and response formatting.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import os
import os
# Import uuid
import uuid
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import from app.domain.errors
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
# Import TeamSide, TestState from app.models.enums
from app.models.enums import TeamSide, TestState
# Import Evidence from app.models.evidence
from app.models.evidence import Evidence
# Import Test from app.models.test
from app.models.test import Test
# States where red evidence can be uploaded / deleted
@@ -31,19 +43,30 @@ MAX_UPLOAD_SIZE = 50 * 1024 * 1024
# Allowed file extensions (lowercase, with leading dot)
ALLOWED_EXTENSIONS: frozenset[str] = frozenset((
    # Images
    ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
    # Documents and spreadsheets
    ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
    ".md", ".rtf", ".odt", ".ods",
    # Logs, captures and structured data
    ".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
    ".yaml", ".yml", ".toml",
    # Archives
    ".zip", ".tar", ".gz", ".7z",
    # HTTP archives and e-mail messages
    ".har", ".eml", ".msg",
))
# Define function validate_upload_permission
def validate_upload_permission(
    test: Test,
    team: TeamSide,
    user_role: str,
) -> None:
    """Check that *user_role* may upload evidence for *team* right now.

    The ``admin`` role is always allowed. For the red and blue sides the
    role is checked first, then the current state of the test. A *team*
    value that is neither red nor blue passes without any check.

    Args:
        test: Test the evidence would be attached to; only ``state`` is read.
        team: Side the evidence is uploaded for.
        user_role: Role name of the acting user.

    Raises:
        PermissionViolation: If user lacks role to upload for this team.
        BusinessRuleViolation: If test state does not allow uploading for
            this team.
    """
    if user_role == "admin":
        return

    if team == TeamSide.red:
        role_allowed = user_role in ("red_tech", "red_lead")
        if not role_allowed:
            raise PermissionViolation(
                "Only red_tech, red_lead or admin can upload red evidence"
            )
        state_allowed = test.state in RED_EDITABLE_STATES
        if not state_allowed:
            raise BusinessRuleViolation(
                f"Cannot upload red evidence in '{test.state.value}' state "
                "(allowed in: draft, red_executing)"
            )
        return

    if team == TeamSide.blue:
        role_allowed = user_role in ("blue_tech", "blue_lead")
        if not role_allowed:
            raise PermissionViolation(
                "Only blue_tech, blue_lead or admin can upload blue evidence"
            )
        state_allowed = test.state in BLUE_EDITABLE_STATES
        if not state_allowed:
            raise BusinessRuleViolation(
                f"Cannot upload blue evidence in '{test.state.value}' state "
                "(allowed in: blue_evaluating)"
            )
# Define function validate_delete_permission
def validate_delete_permission(
    test: Test,
    evidence: Evidence,
    user_role: str,
    user_id: uuid.UUID,
) -> None:
    """Check that the user may delete *evidence* in the test's current state.

    Deletion is refused for everyone, admins included, once the test is in
    review, validated or rejected. Outside those states an admin may always
    delete. Other users need the test to be in an editable state for the
    evidence's side, and must either hold one of that side's roles or be
    the original uploader. Evidence whose team is neither red nor blue
    passes without further checks.

    Args:
        test: Test the evidence belongs to; only ``state`` is read.
        evidence: Evidence to delete; ``team`` and ``uploaded_by`` are read.
        user_role: Role name of the acting user.
        user_id: Identifier of the acting user.

    Raises:
        PermissionViolation: If user cannot delete in this state or lacks
            permission.
    """
    locked_states = (TestState.in_review, TestState.validated, TestState.rejected)
    if test.state in locked_states:
        raise PermissionViolation(
            f"Cannot delete evidence when test is in '{test.state.value}' state"
        )

    if user_role == "admin":
        return

    side = evidence.team
    if side == TeamSide.red:
        editable_states = RED_EDITABLE_STATES
        side_roles = ("red_tech", "red_lead")
        state_message = "Cannot delete red evidence outside draft/red_executing"
    elif side == TeamSide.blue:
        editable_states = BLUE_EDITABLE_STATES
        side_roles = ("blue_tech", "blue_lead")
        state_message = "Cannot delete blue evidence outside blue_evaluating"
    else:
        return

    if test.state not in editable_states:
        raise PermissionViolation(state_message)

    # Members of the owning side may delete any of its evidence; anyone else
    # only what they uploaded themselves.
    if user_role not in side_roles and evidence.uploaded_by != user_id:
        raise PermissionViolation(
            "Not enough permissions to delete this evidence"
        )
# Define function validate_file
def validate_file(file_name: str, content_size: int) -> None:
    """Reject uploads with a disallowed extension or an excessive size.

    The extension is the final suffix as split by ``os.path.splitext`` and
    is compared case-insensitively against ``ALLOWED_EXTENSIONS``. A name
    without an extension is rejected.

    Args:
        file_name: Name of the uploaded file.
        content_size: Size of the content, compared against
            ``MAX_UPLOAD_SIZE``.

    Raises:
        BusinessRuleViolation: If extension is not allowed or file exceeds
            size limit.
    """
    ext = os.path.splitext(file_name)[1]
    if ext.lower() not in ALLOWED_EXTENSIONS:
        permitted = ', '.join(sorted(ALLOWED_EXTENSIONS))
        raise BusinessRuleViolation(
            f"File type '{ext}' is not allowed. "
            f"Permitted types: {permitted}"
        )

    if content_size > MAX_UPLOAD_SIZE:
        limit_mb = MAX_UPLOAD_SIZE // (1024 * 1024)
        raise BusinessRuleViolation(
            f"File exceeds maximum upload size of {limit_mb} MB"
        )
# Define function list_evidence_for_test
def list_evidence_for_test(
    db: Session,
    test_id: uuid.UUID,
    *,
    team: TeamSide | str | None = None,
) -> list[Evidence]:
    """Return the evidence attached to a test, newest upload first.

    Args:
        db: Active SQLAlchemy session.
        test_id: Identifier of the test.
        team: Optional side to restrict the result to. A string is converted
            with ``TeamSide(team)``, which raises ``ValueError`` for a value
            that is not a member.

    Returns:
        Matching ``Evidence`` rows ordered by ``uploaded_at`` descending.
    """
    criteria = [Evidence.test_id == test_id]
    if team is not None:
        side = TeamSide(team) if isinstance(team, str) else team
        criteria.append(Evidence.team == side)

    return (
        db.query(Evidence)
        .filter(*criteria)
        .order_by(Evidence.uploaded_at.desc())
        .all()
    )
# Define function get_evidence_or_raise
def get_evidence_or_raise(db: Session, evidence_id: uuid.UUID) -> Evidence:
    """Load one evidence row by primary key.

    Args:
        db: Active SQLAlchemy session.
        evidence_id: Identifier of the evidence.

    Returns:
        The matching ``Evidence`` row.

    Raises:
        EntityNotFoundError: If no row has that id.
    """
    match = db.query(Evidence).filter(Evidence.id == evidence_id).first()
    if match is not None:
        return match
    raise EntityNotFoundError("Evidence", str(evidence_id))
# Define function get_test_or_raise
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
    """Load one test row by primary key.

    Args:
        db: Active SQLAlchemy session.
        test_id: Identifier of the test.

    Returns:
        The matching ``Test`` row.

    Raises:
        EntityNotFoundError: If no row has that id.
    """
    match = db.query(Test).filter(Test.id == test_id).first()
    if match is not None:
        return match
    raise EntityNotFoundError("Test", str(test_id))
+424 -15
View File
@@ -7,32 +7,59 @@ This module is framework-agnostic: no FastAPI imports, no HTTPException,
no ``db.commit()``.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import json
import json
from typing import Optional
# Import Callable from collections.abc
from collections.abc import Callable
# Import func, or_ from sqlalchemy
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
# Import Query, Session from sqlalchemy.orm
from sqlalchemy.orm import Query, Session
# Import BusinessRuleViolation, EntityNotFoundError from app.domain.errors
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
# Import Campaign, CampaignTest from app.models.campaign
from app.models.campaign import Campaign, CampaignTest
# Import DetectionRule from app.models.detection_rule
from app.models.detection_rule import DetectionRule
from app.models.defensive_technique import DefensiveTechniqueMapping
# Import TechniqueStatus, TestState from app.models.enums
from app.models.enums import TechniqueStatus, TestState
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestDetectionResult from app.models.test_detection_result
from app.models.test_detection_result import TestDetectionResult
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
# Import escape_like from app.utils
from app.utils import escape_like
# ── Constants ─────────────────────────────────────────────────────────
# Written to the "attack" entry of each generated layer's "versions" block.
ATTACK_VERSION = "15"
# Written to the "navigator" entry of the "versions" block.
NAVIGATOR_VERSION = "5.0"
# Written to the "layer" entry of the "versions" block.
LAYER_VERSION = "4.5"
# Written to the "domain" field of each generated layer.
DOMAIN = "enterprise-attack"
# Assign STATUS_SCORE_MAP = {
STATUS_SCORE_MAP: dict[TechniqueStatus, int] = {
TechniqueStatus.validated: 100,
TechniqueStatus.partial: 60,
@@ -42,6 +69,7 @@ STATUS_SCORE_MAP: dict[TechniqueStatus, int] = {
TechniqueStatus.review_required: 10,
}
# Assign TEST_STATE_SCORE = {
TEST_STATE_SCORE: dict[TestState, int] = {
TestState.validated: 100,
TestState.in_review: 70,
@@ -56,74 +84,169 @@ TEST_STATE_SCORE: dict[TestState, int] = {
def _score_to_color(score: int) -> str:
    """Map a 0-100 score to a red-yellow-green colour hex.

    Args:
        score (int): Coverage score; values at or below 0 share the lowest
            tier and values above 75 share the highest.

    Returns:
        str: Hex colour string representing the score tier.
    """
    # (inclusive upper bound, colour), checked in ascending order.
    tiers = (
        (0, "#d3d3d3"),
        (25, "#ff6666"),
        (50, "#ff9933"),
        (75, "#ffff66"),
    )
    for upper_bound, colour in tiers:
        if score <= upper_bound:
            return colour
    return "#66ff66"
# Define function _build_layer_skeleton
def _build_layer_skeleton(
    name: str,
    description: str,
    gradient_colors: list[str] | None = None,
) -> dict:
    """Return a base layer dict compatible with ATT&CK Navigator.

    Args:
        name (str): Human-readable name for the layer.
        description (str): Description text embedded in the layer metadata.
        gradient_colors (list[str] | None): Hex colour stops for the
            gradient; a red-yellow-green default is used when this is
            ``None`` or empty.

    Returns:
        dict: Layer with version info, domain, platform filter, a 0-100
        gradient and an empty ``techniques`` list for the caller to fill.
    """
    versions = {
        "attack": ATTACK_VERSION,
        "navigator": NAVIGATOR_VERSION,
        "layer": LAYER_VERSION,
    }

    if gradient_colors:
        palette = gradient_colors
    else:
        palette = ["#ff6666", "#ffff66", "#66ff66"]
    gradient = {"colors": palette, "minValue": 0, "maxValue": 100}

    return {
        "name": name,
        "versions": versions,
        "domain": DOMAIN,
        "description": description,
        "filters": {"platforms": ["windows", "linux", "macos"]},
        "gradient": gradient,
        "techniques": [],
    }
# Define function _apply_filters
def _apply_filters(
    query: Query,  # type: ignore[type-arg]
    model: type,
    platforms: list[str] | None = None,
    tactics: list[str] | None = None,
) -> Query:  # type: ignore[type-arg]
    """Apply common platform and tactic filters to a technique query.

    Args:
        query (Query): Base SQLAlchemy query targeting a technique-like model.
        model (type): The SQLAlchemy model class that owns ``platforms`` and
            ``tactic`` columns.
        platforms (list[str] | None): Optional list of platform names to
            filter by (OR-joined).
        tactics (list[str] | None): Optional list of tactic strings to
            filter by (OR-joined, case-insensitive substring match).

    Returns:
        Query: The query with platform and tactic filters applied; returned
        unchanged when both lists are empty or ``None``.
    """
    if platforms:
        # One "@>" comparison per platform, each against the JSON text of a
        # single-element list; a row matches if any of them holds.
        platform_filters = [
            model.platforms.op("@>")(json.dumps([p])) for p in platforms
        ]
        query = query.filter(or_(*platform_filters))

    if tactics:
        # escape_like() is applied to each value before it is wrapped in "%"
        # wildcards for the ILIKE substring match.
        tactic_filters = [
            model.tactic.ilike(f"%{escape_like(t)}%") for t in tactics
        ]
        query = query.filter(or_(*tactic_filters))

    return query
# Define function _format_tactic
def _format_tactic(tactic_str: str | None) -> str:
    """Normalize tactic string to ATT&CK Navigator format (kebab-case).

    Only the first entry of a comma-separated value is used. It is
    lower-cased, and runs of whitespace or underscores are collapsed into
    single hyphens, so "Initial Access" and "initial-access" both yield
    "initial-access".

    Args:
        tactic_str (str | None): Raw tactic string, possibly comma-separated
            or mixed-case.

    Returns:
        str: Kebab-case form of the first tactic, or an empty string if the
        input is falsy or its first entry is blank.
    """
    if not tactic_str:
        return ""
    first = tactic_str.split(",")[0]
    # split() with no argument also trims the ends and drops empty pieces.
    return "-".join(first.lower().replace("_", " ").split())
# Define function _parse_csv
def _parse_csv(value: str | None) -> list[str] | None:
    """Split a comma-separated string into a trimmed list, or ``None``.

    Args:
        value (str | None): Comma-separated string to split, or ``None``.

    Returns:
        list[str] | None: Non-empty trimmed tokens in their original order,
        or ``None`` if the input is falsy or produces no tokens (for
        example ``" , "``).
    """
    if not value:
        return None
    tokens = [piece.strip() for piece in value.split(",")]
    tokens = [token for token in tokens if token]
    # Collapse "nothing usable" to None so callers see one "no filter" value.
    return tokens or None
@@ -131,132 +254,224 @@ def _parse_csv(value: str | None) -> list[str] | None:
def build_coverage_layer(
    db: Session,
    *,
    platforms: str | None = None,
    tactics: str | None = None,
    min_score: int = 0,
) -> dict:
    """Build the coverage layer, scored from each technique's ``status_global``.

    Args:
        db (Session): Active SQLAlchemy database session.
        platforms (str | None): Optional comma-separated platform names used
            to filter techniques.
        tactics (str | None): Optional comma-separated tactic names used to
            filter techniques.
        min_score (int): Techniques whose score is below this value are left
            out of the layer.

    Returns:
        dict: The layer skeleton with one entry appended to ``"techniques"``
        per technique that passed the filters and the score threshold.
    """
    layer = _build_layer_skeleton("Aegis Coverage", "Coverage layer generated by Aegis")

    platform_filter = _parse_csv(platforms)
    tactic_filter = _parse_csv(tactics)
    techniques = _apply_filters(
        db.query(Technique), Technique, platform_filter, tactic_filter,
    ).all()

    # Fetch both counters with one grouped query each instead of querying
    # per technique inside the loop below.
    technique_pks = [tech.id for tech in techniques]
    technique_mitre_ids = [tech.mitre_id for tech in techniques]

    validated_tests_by_technique: dict = {}
    if technique_pks:
        validated_tests_by_technique = dict(
            db.query(Test.technique_id, func.count(Test.id))
            .filter(Test.technique_id.in_(technique_pks), Test.state == TestState.validated)
            .group_by(Test.technique_id)
            .all()
        )

    rules_by_mitre_id: dict = {}
    if technique_mitre_ids:
        rules_by_mitre_id = dict(
            db.query(DetectionRule.mitre_technique_id, func.count(DetectionRule.id))
            .filter(DetectionRule.mitre_technique_id.in_(technique_mitre_ids))
            .group_by(DetectionRule.mitre_technique_id)
            .all()
        )

    for tech in techniques:
        # Statuses missing from the map score 0.
        score = STATUS_SCORE_MAP.get(tech.status_global, 0)
        if score < min_score:
            continue

        validated_tests = validated_tests_by_technique.get(tech.id, 0)
        rule_total = rules_by_mitre_id.get(tech.mitre_id, 0)

        metadata = [
            {"name": "tests_count", "value": str(validated_tests)},
            {"name": "detection_rules", "value": str(rule_total)},
        ]
        if tech.last_review_date:
            # The "last_validated" entry is sourced from last_review_date.
            reviewed_on = tech.last_review_date.strftime("%Y-%m-%d")
            metadata.append({"name": "last_validated", "value": reviewed_on})

        comment = " - ".join((
            f"Status: {tech.status_global.value}",
            f"{validated_tests} tests validated",
            f"{rule_total} detection rules",
        ))

        layer["techniques"].append({
            "techniqueID": tech.mitre_id,
            "tactic": _format_tactic(tech.tactic),
            "color": _score_to_color(score),
            "score": score,
            "comment": comment,
            "enabled": True,
            "metadata": metadata,
        })

    return layer
# Define function build_threat_actor_layer
def build_threat_actor_layer(
# Entry: db
db: Session,
# Entry: actor_id
actor_id: str,
*,
# Entry: platforms
platforms: str | None = None,
# Entry: tactics
tactics: str | None = None,
# Entry: min_score
min_score: int = 0,
) -> dict:
"""Threat actor layer -- techniques used by an actor with coverage colour.
Raises :class:`EntityNotFoundError` if the actor does not exist.
Args:
db (Session): Active SQLAlchemy database session.
actor_id (str): UUID string identifying the threat actor.
platforms (str | None): Optional comma-separated platform names to
filter techniques.
tactics (str | None): Optional comma-separated tactic names to filter
techniques.
min_score (int): Minimum score threshold for actor techniques.
Returns:
dict: ATT&CK Navigator-compatible layer dictionary coloured by
coverage status for the specified actor.
"""
# Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
# Check: not actor
if not actor:
# Raise EntityNotFoundError
raise EntityNotFoundError("ThreatActor", actor_id)
# Assign layer = _build_layer_skeleton(
layer = _build_layer_skeleton(
f"Threat Actor: {actor.name}",
f"Techniques used by {actor.name} with coverage overlay",
# Keyword argument: gradient_colors
gradient_colors=["#808080", "#ff6666", "#66ff66"],
)
# Assign actor_technique_ids = {
actor_technique_ids = {
row.technique_id
for row in db.query(ThreatActorTechnique.technique_id)
# Chain .filter() call
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
# Chain .all() call
.all()
}
# Check: not actor_technique_ids
if not actor_technique_ids:
# Return layer
return layer
# Assign query = _apply_filters(
query = _apply_filters(
db.query(Technique), Technique,
_parse_csv(platforms), _parse_csv(tactics),
)
# Assign techniques = query.all()
techniques = query.all()
# Bulk-fetch metadata for actor techniques only
test_counts = dict(
db.query(Test.technique_id, func.count(Test.id))
# Chain .filter() call
.filter(Test.technique_id.in_(actor_technique_ids), Test.state == TestState.validated)
# Chain .group_by() call
.group_by(Test.technique_id)
# Chain .all() call
.all()
)
# Assign actor_mitre_ids = [t.mitre_id for t in techniques if t.id in actor_technique_ids]
actor_mitre_ids = [t.mitre_id for t in techniques if t.id in actor_technique_ids]
# Assign rule_counts = dict(
rule_counts = dict(
db.query(DetectionRule.mitre_technique_id, func.count(DetectionRule.id))
# Chain .filter() call
.filter(DetectionRule.mitre_technique_id.in_(actor_mitre_ids))
# Chain .group_by() call
.group_by(DetectionRule.mitre_technique_id)
# Chain .all() call
.all()
) if actor_mitre_ids else {}
# Iterate over techniques
for tech in techniques:
# Assign is_actor_technique = tech.id in actor_technique_ids
is_actor_technique = tech.id in actor_technique_ids
# Assign score = STATUS_SCORE_MAP.get(tech.status_global, 0) if is_actor_technique e...
score = STATUS_SCORE_MAP.get(tech.status_global, 0) if is_actor_technique else 0
# Check: is_actor_technique and score < min_score
if is_actor_technique and score < min_score:
# Skip to the next loop iteration
continue
# Only include techniques actually used by this actor — skip the rest
@@ -284,14 +499,20 @@ def build_threat_actor_layer(
"metadata": metadata,
})
# Return layer
return layer
# Define function build_detection_rules_layer
def build_detection_rules_layer(
# Entry: db
db: Session,
*,
# Entry: platforms
platforms: str | None = None,
# Entry: tactics
tactics: str | None = None,
# Entry: min_score
min_score: int = 0,
) -> dict:
"""Detection rules layer -- score based on absolute rule count per technique.
@@ -305,28 +526,40 @@ def build_detection_rules_layer(
4+ rules → green (score 100)
"""
layer = _build_layer_skeleton(
# Literal argument value
"Detection Rules Coverage",
"Number of active detection rules per technique",
)
# Assign query = _apply_filters(
query = _apply_filters(
db.query(Technique), Technique,
_parse_csv(platforms), _parse_csv(tactics),
)
# Assign techniques = query.all()
techniques = query.all()
# Assign rule_counts = dict(
rule_counts = dict(
db.query(DetectionRule.mitre_technique_id, func.count(DetectionRule.id))
# Chain .filter() call
.filter(DetectionRule.is_active == True) # noqa: E712
# Chain .group_by() call
.group_by(DetectionRule.mitre_technique_id)
# Chain .all() call
.all()
)
# Assign evaluated_counts = dict(
evaluated_counts = dict(
db.query(DetectionRule.mitre_technique_id, func.count(TestDetectionResult.id))
# Chain .join() call
.join(TestDetectionResult, TestDetectionResult.detection_rule_id == DetectionRule.id)
# Chain .filter() call
.filter(TestDetectionResult.triggered.isnot(None))
# Chain .group_by() call
.group_by(DetectionRule.mitre_technique_id)
# Chain .all() call
.all()
)
@@ -334,12 +567,16 @@ def build_detection_rules_layer(
RULES_FOR_FULL_COVERAGE = 4
for tech in techniques:
# Assign total_rules = rule_counts.get(tech.mitre_id, 0)
total_rules = rule_counts.get(tech.mitre_id, 0)
# Assign evaluated_rules = evaluated_counts.get(tech.mitre_id, 0)
evaluated_rules = evaluated_counts.get(tech.mitre_id, 0)
score = min(int((total_rules / RULES_FOR_FULL_COVERAGE) * 100), 100)
# Check: score < min_score
if score < min_score:
# Skip to the next loop iteration
continue
rule_word = "rule" if total_rules == 1 else "rules"
@@ -347,113 +584,194 @@ def build_detection_rules_layer(
comment = f"{total_rules} active {rule_word}{eval_note}"
layer["techniques"].append({
# Literal argument value
"techniqueID": tech.mitre_id,
# Literal argument value
"tactic": _format_tactic(tech.tactic),
# Literal argument value
"color": _score_to_color(score),
# Literal argument value
"score": score,
"comment": comment,
"enabled": True,
# Literal argument value
"metadata": [
{"name": "total_rules", "value": str(total_rules)},
{"name": "evaluated_rules", "value": str(evaluated_rules)},
],
})
# Return layer
return layer
# Define function build_campaign_layer
def build_campaign_layer(
    db: Session,
    campaign_id: str,
    *,
    platforms: str | None = None,
    tactics: str | None = None,
    min_score: int = 0,
) -> dict:
    """Campaign layer -- techniques in a campaign, coloured by test state.

    Each technique touched by at least one of the campaign's tests gets one
    layer entry whose score is the highest ``TEST_STATE_SCORE`` among those
    tests.

    Args:
        db (Session): Active SQLAlchemy database session.
        campaign_id (str): Identifier of the campaign to export.
        platforms (str | None): Optional comma-separated platform names; a
            technique is kept when its ``platforms`` contains at least one
            of them (exact match).
        tactics (str | None): Optional comma-separated tactic names; a
            technique is kept when one of its comma-separated tactics equals
            one of them, compared case-insensitively.
        min_score (int): Minimum score threshold for techniques in the layer.

    Returns:
        dict: ATT&CK Navigator-compatible layer dictionary where each
        technique colour reflects the best test state within the campaign.

    Raises:
        EntityNotFoundError: If the campaign does not exist.
    """
    campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)

    layer = _build_layer_skeleton(
        f"Campaign: {campaign.name}",
        f"Progress of campaign '{campaign.name}'",
    )

    campaign_tests = (
        db.query(CampaignTest)
        .filter(CampaignTest.campaign_id == campaign.id)
        .all()
    )
    if not campaign_tests:
        return layer

    # Load the referenced tests and their techniques in bulk.
    test_ids = [ct.test_id for ct in campaign_tests]
    tests = db.query(Test).filter(Test.id.in_(test_ids)).all()
    test_map = {t.id: t for t in tests}

    technique_ids = {t.technique_id for t in tests if t.technique_id}
    techniques = db.query(Technique).filter(Technique.id.in_(technique_ids)).all()
    tech_map = {t.id: t for t in techniques}

    # Group tests by technique, keeping the best state score.
    tech_scores: dict = {}
    for ct in campaign_tests:
        test = test_map.get(ct.test_id)
        if not test:
            continue
        tech = tech_map.get(test.technique_id)
        if not tech:
            continue

        state_score = TEST_STATE_SCORE.get(test.state, 0)
        entry = tech_scores.get(tech.mitre_id)
        if entry is None:
            entry = {"technique": tech, "max_score": state_score, "tests": []}
            tech_scores[tech.mitre_id] = entry
        else:
            entry["max_score"] = max(entry["max_score"], state_score)
        entry["tests"].append(test)

    platform_list = _parse_csv(platforms)
    # The technique's tactics are lowercased below, so the requested tactics
    # must be lowercased too; otherwise a filter such as "Execution" could
    # never match. This also lines up with the case-insensitive tactic
    # filter used for the query-based layers.
    tactic_list = [name.lower() for name in (_parse_csv(tactics) or [])]

    for mitre_id, info in tech_scores.items():
        tech = info["technique"]
        score = info["max_score"]

        if platform_list:
            tech_platforms = tech.platforms or []
            if not any(p in tech_platforms for p in platform_list):
                continue

        if tactic_list:
            tech_tactics = [t.strip() for t in (tech.tactic or "").lower().split(",")]
            if not any(t in tech_tactics for t in tactic_list):
                continue

        if score < min_score:
            continue

        test_states = [t.state.value for t in info["tests"]]
        # "Best" means highest TEST_STATE_SCORE, matching how max_score was
        # computed; max() over the state strings would rank alphabetically.
        best_test = max(
            info["tests"],
            key=lambda t: TEST_STATE_SCORE.get(t.state, 0),
            default=None,
        )
        best_state = best_test.state.value if best_test is not None else "none"

        layer["techniques"].append({
            "techniqueID": mitre_id,
            "tactic": _format_tactic(tech.tactic),
            "color": _score_to_color(score),
            "score": score,
            "comment": f"Campaign tests: {', '.join(test_states)}",
            "enabled": True,
            "metadata": [
                {"name": "campaign_tests", "value": str(len(info["tests"]))},
                {"name": "best_state", "value": best_state},
            ],
        })

    return layer
@@ -470,67 +788,143 @@ def build_campaign_layer(
class _LayerRegistry:
    """Extensible registry that maps layer type names to builder functions.

    Builders are kept in two maps: "simple" builders called as
    ``builder(db, **filters)`` and entity-bound builders called as
    ``builder(db, layer_id, **filters)``. A name lives in exactly one map.
    """

    __slots__ = ("_simple", "_with_id")

    def __init__(self) -> None:
        """Create an empty registry."""
        self._simple: dict[str, Callable[..., dict]] = {}
        self._with_id: dict[str, Callable[..., dict]] = {}

    def register(self, name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None:
        """Register a builder function under *name*.

        Registering an existing name replaces the previous builder, even if
        the earlier registration used a different ``requires_id`` value.

        Args:
            name (str): Unique layer type identifier.
            builder (Callable[..., dict]): Layer builder function.
            requires_id (bool): Whether the builder needs a positional
                ``layer_id`` argument.
        """
        if requires_id:
            target, other = self._with_id, self._simple
        else:
            target, other = self._simple, self._with_id
        # Drop any registration of the same name in the other map; build()
        # checks the simple map first, so a stale entry there would shadow
        # the builder being registered now.
        other.pop(name, None)
        target[name] = builder

    @property
    def supported_types(self) -> set[str]:
        """Return the set of all registered layer type names.

        Returns:
            set[str]: Union of simple and entity-bound layer type names,
            as a new set on every access.
        """
        return set(self._simple) | set(self._with_id)

    def build(
        self,
        db: Session,
        layer_type: str,
        *,
        layer_id: str | None = None,
        platforms: str | None = None,
        tactics: str | None = None,
        min_score: int = 0,
    ) -> dict:
        """Dispatch to the registered builder for *layer_type*.

        Args:
            db (Session): Active SQLAlchemy database session.
            layer_type (str): Registered layer type name.
            layer_id (str | None): Entity identifier for entity-bound layer
                types; ignored by simple layer types.
            platforms (str | None): Optional comma-separated platform filter.
            tactics (str | None): Optional comma-separated tactic filter.
            min_score (int): Minimum score threshold.

        Returns:
            dict: Whatever the selected builder returns.

        Raises:
            BusinessRuleViolation: If *layer_type* is not registered, or it
                is entity-bound and *layer_id* is missing or empty.
        """
        kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score)

        if layer_type in self._simple:
            return self._simple[layer_type](db, **kwargs)

        if layer_type in self._with_id:
            if not layer_id:
                raise BusinessRuleViolation(
                    f"layer_id is required for '{layer_type}' layer"
                )
            return self._with_id[layer_type](db, layer_id, **kwargs)

        raise BusinessRuleViolation(f"Unknown layer type: {layer_type}")
# Module-level registry, pre-populated with the built-in layer builders.
# Each row is (layer type name, builder, builder needs a layer_id).
LAYER_REGISTRY = _LayerRegistry()
for _layer_name, _layer_builder, _needs_id in (
    ("coverage", build_coverage_layer, False),
    ("detection-rules", build_detection_rules_layer, False),
    ("threat-actor", build_threat_actor_layer, True),
    ("campaign", build_campaign_layer, True),
):
    LAYER_REGISTRY.register(_layer_name, _layer_builder, requires_id=_needs_id)
# Do not leave the loop variables behind as module attributes.
del _layer_name, _layer_builder, _needs_id

SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types  # snapshot of built-in types
def register_layer(name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None:
    """Add a heatmap layer type to the module-level registry.

    Args:
        name (str): Layer type identifier to register the builder under.
        builder (Callable[..., dict]): Function that builds the layer dict.
            The registry calls it as ``builder(db, platforms=..., tactics=...,
            min_score=...)``, with ``layer_id`` inserted as the second
            positional argument when ``requires_id`` is ``True``.
        requires_id (bool): ``True`` when the builder needs a ``layer_id``
            argument (as the threat-actor and campaign layers do).
    """
    registry = LAYER_REGISTRY
    registry.register(name, builder, requires_id=requires_id)
# Define function build_navigator_export
def build_navigator_export(
# Entry: db
db: Session,
# Entry: layer_type
layer_type: str,
*,
# Entry: layer_id
layer_id: str | None = None,
# Entry: platforms
platforms: str | None = None,
# Entry: tactics
tactics: str | None = None,
# Entry: min_score
min_score: int = 0,
) -> dict:
"""Build a heatmap layer dict by type name.
@@ -539,8 +933,23 @@ def build_navigator_export(
missing ``layer_id``. Raises :class:`EntityNotFoundError` when
an entity-bound layer (threat-actor, campaign) references a
non-existent record.
Args:
db (Session): Active SQLAlchemy database session.
layer_type (str): Registered layer type name (e.g. ``"coverage"``,
``"threat-actor"``).
layer_id (str | None): Entity UUID required for entity-bound layer
types such as ``"threat-actor"`` and ``"campaign"``.
platforms (str | None): Optional comma-separated platform filter.
tactics (str | None): Optional comma-separated tactic filter.
min_score (int): Minimum score; techniques below this are excluded.
Returns:
dict: ATT&CK Navigator-compatible layer dictionary.
"""
# Return LAYER_REGISTRY.build(
return LAYER_REGISTRY.build(
db, layer_type,
# Keyword argument: layer_id
layer_id=layer_id, platforms=platforms, tactics=tactics, min_score=min_score,
)
+107 -2
View File
@@ -9,18 +9,34 @@ RSS feeds and parses them with the standard-library :mod:`xml.etree`
parser. No LLMs or paid APIs are used.
"""
# Import logging
import logging
# Import re
import re
import defusedxml.ElementTree as ET
# Import datetime from datetime
from datetime import datetime
# Import defusedxml.ElementTree
import defusedxml.ElementTree as ET # noqa: N817 — ET is the universal stdlib alias for ElementTree
# Import requests
import requests as _requests
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import IntelItem from app.models.intel
from app.models.intel import IntelItem
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -29,7 +45,9 @@ logger = logging.getLogger(__name__)
RSS_FEEDS: list[dict[str, str]] = [
{
# Literal argument value
"name": "CISA Alerts",
# Literal argument value
"url": "https://www.cisa.gov/cybersecurity-advisories/all.xml",
},
{
@@ -37,19 +55,27 @@ RSS_FEEDS: list[dict[str, str]] = [
"url": "https://feeds.feedburner.com/Securityweek",
},
{
# Literal argument value
"name": "SANS ISC",
# Literal argument value
"url": "https://isc.sans.edu/rssfeed.xml",
},
{
# Literal argument value
"name": "BleepingComputer",
# Literal argument value
"url": "https://www.bleepingcomputer.com/feed/",
},
{
# Literal argument value
"name": "The Hacker News",
# Literal argument value
"url": "https://feeds.feedburner.com/TheHackersNews",
},
{
# Literal argument value
"name": "Krebs on Security",
# Literal argument value
"url": "https://krebsonsecurity.com/feed/",
},
]
@@ -73,49 +99,81 @@ def _fetch_feed(url: str) -> list[dict[str, str]]:
Each entry is a dict with keys ``title``, ``link``, and ``description``.
Returns an empty list on any error so the scan can continue.
"""
# Attempt the following; catch errors below
try:
# Assign resp = _requests.get(url, timeout=_FEED_TIMEOUT, headers={
resp = _requests.get(url, timeout=_FEED_TIMEOUT, headers={
# Literal argument value
"User-Agent": "AegisPlatform/1.0 IntelScan",
})
# Call resp.raise_for_status()
resp.raise_for_status()
# Handle Exception
except Exception as exc:
# Log warning: "Failed to fetch feed %s: %s", url, exc
logger.warning("Failed to fetch feed %s: %s", url, exc)
# Return []
return []
# Attempt the following; catch errors below
try:
# Assign root = ET.fromstring(resp.content)
root = ET.fromstring(resp.content)
# Handle ET.ParseError
except ET.ParseError as exc:
# Log warning: "Failed to parse feed %s: %s", url, exc
logger.warning("Failed to parse feed %s: %s", url, exc)
# Return []
return []
# Assign entries = []
entries: list[dict[str, str]] = []
# RSS 2.0 format: <channel><item>...
for item in root.iter("item"):
# Assign title_el = item.find("title")
title_el = item.find("title")
# Assign link_el = item.find("link")
link_el = item.find("link")
# Assign desc_el = item.find("description")
desc_el = item.find("description")
# Call entries.append()
entries.append({
# Literal argument value
"title": title_el.text.strip() if title_el is not None and title_el.text else "",
# Literal argument value
"link": link_el.text.strip() if link_el is not None and link_el.text else "",
# Literal argument value
"description": desc_el.text.strip() if desc_el is not None and desc_el.text else "",
})
# Atom format: <feed><entry>...
ns = {"atom": "http://www.w3.org/2005/Atom"}
# Iterate over root.iter("{http
for entry in root.iter("{http://www.w3.org/2005/Atom}entry"):
# Assign title_el = entry.find("atom:title", ns)
title_el = entry.find("atom:title", ns)
# Assign link_el = entry.find("atom:link", ns)
link_el = entry.find("atom:link", ns)
# Assign summary_el = entry.find("atom:summary", ns)
summary_el = entry.find("atom:summary", ns)
# Assign link_href = ""
link_href = ""
# Check: link_el is not None
if link_el is not None:
# Assign link_href = link_el.get("href", "")
link_href = link_el.get("href", "")
# Call entries.append()
entries.append({
# Literal argument value
"title": title_el.text.strip() if title_el is not None and title_el.text else "",
# Literal argument value
"link": link_href.strip(),
# Literal argument value
"description": summary_el.text.strip() if summary_el is not None and summary_el.text else "",
})
# Return entries
return entries
@@ -147,6 +205,7 @@ def _entry_matches(
name_patterns: list[re.Pattern],
) -> bool:
"""Return True if any pattern matches the entry's title or description."""
# Assign text = f"{entry.get('title', '')} {entry.get('description', '')}"
text = f"{entry.get('title', '')} {entry.get('description', '')}"
return any(p.search(text) for p in id_patterns + name_patterns)
@@ -164,20 +223,23 @@ def scan_intel(db: Session) -> dict:
db : Session
Active SQLAlchemy database session.
Returns
-------
dict
Summary with keys ``new_items``, ``duplicates_skipped``,
``techniques_flagged``, ``feeds_checked``.
"""
# Log info: "Intel scan starting..."
logger.info("Intel scan starting...")
# 1. Load all active techniques
techniques = (
db.query(Technique)
# Chain .order_by() call
.order_by(Technique.mitre_id)
.all()
)
# Log info: "Scanning %d techniques against %d feeds", len(tec
logger.info("Scanning %d techniques against %d feeds", len(techniques), len(RSS_FEEDS))
# 2. Pre-load all existing intel URLs for dedup
@@ -187,24 +249,36 @@ def scan_intel(db: Session) -> dict:
# 3. Fetch all feeds once
all_entries: list[tuple[str, dict[str, str]]] = [] # (feed_name, entry)
# Assign feeds_ok = 0
feeds_ok = 0
# Iterate over RSS_FEEDS
for feed in RSS_FEEDS:
# Assign entries = _fetch_feed(feed["url"])
entries = _fetch_feed(feed["url"])
# Check: entries
if entries:
# Assign feeds_ok = 1
feeds_ok += 1
# Iterate over entries
for entry in entries:
# Call all_entries.append()
all_entries.append((feed["name"], entry))
# Log info: "Fetched %d entries from %d/%d feeds", len(all_ent
logger.info("Fetched %d entries from %d/%d feeds", len(all_entries), feeds_ok, len(RSS_FEEDS))
# 4. Match entries to techniques
new_items = 0
# Assign duplicates_skipped = 0
duplicates_skipped = 0
# Assign techniques_flagged = set()
techniques_flagged: set[str] = set()
# Iterate over techniques
for technique in techniques:
id_patterns, name_patterns = _build_patterns(technique)
# Iterate over all_entries
for feed_name, entry in all_entries:
if not _entry_matches(entry, id_patterns, name_patterns):
continue
@@ -213,45 +287,69 @@ def scan_intel(db: Session) -> dict:
if not entry.get("title", "").strip():
continue
# Assign url = entry.get("link", "").strip()
url = entry.get("link", "").strip()
# Check: not url
if not url:
# Skip to the next loop iteration
continue
# Dedup
if url in existing_urls:
# Assign duplicates_skipped = 1
duplicates_skipped += 1
# Skip to the next loop iteration
continue
# Create IntelItem
intel_item = IntelItem(
# Keyword argument: technique_id
technique_id=technique.id,
# Keyword argument: url
url=url,
# Keyword argument: title
title=entry.get("title", "")[:500],
# Keyword argument: source
source=feed_name,
# Keyword argument: detected_at
detected_at=datetime.utcnow(),
# Keyword argument: reviewed
reviewed=False,
)
# Stage new record(s) for database insertion
db.add(intel_item)
# Call existing_urls.add()
existing_urls.add(url)
# Assign new_items = 1
new_items += 1
# Flag technique for review
if not technique.review_required:
# Assign technique.review_required = True
technique.review_required = True
# Call techniques_flagged.add()
techniques_flagged.add(technique.mitre_id)
# 5. Single commit
db.commit()
# Assign summary = {
summary = {
# Literal argument value
"new_items": new_items,
# Literal argument value
"duplicates_skipped": duplicates_skipped,
# Literal argument value
"techniques_flagged": len(techniques_flagged),
# Literal argument value
"feeds_checked": feeds_ok,
}
# Log info:
logger.info(
# Literal argument value
"Intel scan complete — new=%d, duplicates_skipped=%d, "
# Literal argument value
"techniques_flagged=%d, feeds_checked=%d",
new_items, duplicates_skipped, len(techniques_flagged), feeds_ok,
)
@@ -259,12 +357,19 @@ def scan_intel(db: Session) -> dict:
# 6. Audit log
log_action(
db,
# Keyword argument: user_id
user_id=None,
# Keyword argument: action
action="intel_scan",
# Keyword argument: entity_type
entity_type="intel_item",
# Keyword argument: entity_id
entity_id=None,
# Keyword argument: details
details=summary,
)
# Commit all pending changes to the database
db.commit()
# Return summary
return summary
+134 -1
View File
@@ -28,22 +28,44 @@ creates the Jira ticket and stores the link.
from __future__ import annotations
# Import logging
import logging
# Import datetime from datetime
from datetime import datetime
from typing import Optional
# Import Any, Optional from typing
from typing import Any, Optional
# Import UUID from uuid
from uuid import UUID
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import settings from app.config
from app.config import settings
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import InvalidOperationError from app.domain.exceptions
from app.domain.exceptions import InvalidOperationError
# Import Campaign from app.models.campaign
from app.models.campaign import Campaign
# Import JiraLink, JiraLinkEntityType, JiraSyncDirection from app.models.jira_link
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
from app.models.user import User
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -624,6 +646,7 @@ def get_jira_client():
Prefer ``get_user_jira_client()`` for new code.
"""
if not settings.JIRA_ENABLED:
# Raise InvalidOperationError
raise InvalidOperationError("Jira integration is not enabled")
if not settings.JIRA_URL or not settings.JIRA_USERNAME or not settings.JIRA_API_TOKEN:
raise InvalidOperationError(
@@ -639,75 +662,121 @@ def get_jira_client():
)
# Define function search_jira_issues
def search_jira_issues(query: str, max_results: int = 10) -> list[dict]:
    """Search Jira issues by JQL or free text (uses global credentials).

    A query containing ``=`` or ``~`` is treated as JQL and sent unchanged.
    Anything else is treated as free text and wrapped in a
    ``summary ~ "..."`` clause, with backslashes and double quotes escaped
    so the text cannot terminate the string literal.

    Args:
        query (str): JQL expression or free-text search terms.
        max_results (int): Maximum number of issues to request.

    Returns:
        list[dict]: One dict per issue with keys ``issue_key``, ``summary``,
        ``status``, ``assignee`` and ``priority``; the last two are ``None``
        when the issue has no assignee or priority.
    """
    jira = get_jira_client()

    if "=" in query or "~" in query:
        # Heuristic: an operator means the caller already wrote JQL.
        jql = query
    else:
        # Escape backslashes first so the backslashes added for quotes are
        # not themselves doubled.
        escaped = query.replace("\\", "\\\\").replace('"', '\\"')
        jql = f'summary ~ "{escaped}"'

    results = jira.jql(jql, limit=max_results)
    return [
        {
            "issue_key": issue["key"],
            "summary": issue["fields"]["summary"],
            "status": issue["fields"]["status"]["name"],
            # assignee/priority may be missing or null on an issue.
            "assignee": (issue["fields"].get("assignee") or {}).get("displayName"),
            "priority": (issue["fields"].get("priority") or {}).get("name"),
        }
        for issue in results.get("issues", [])
    ]
def create_jira_issue(
    project_key: str,
    summary: str,
    description: str,
    issue_type: str = "Task",
    labels: Optional[list[str]] = None,
    custom_fields: Optional[dict] = None,
) -> dict:
    """Create a Jira issue and return its key + id (uses global credentials).

    Args:
        project_key: Key of the Jira project that receives the issue.
        summary: Issue summary.
        description: Issue description.
        issue_type: Name of the Jira issue type.
        labels: Optional labels; omitted from the payload when empty.
        custom_fields: Optional extra fields merged into the payload last,
            so they override the standard fields on key collision.

    Returns:
        ``{"issue_key": ..., "issue_id": ...}`` taken from Jira's response.
    """
    client = get_jira_client()
    payload: dict = {
        "project": {"key": project_key},
        "summary": summary,
        "description": description,
        "issuetype": {"name": issue_type},
    }
    if labels:
        payload["labels"] = labels
    if custom_fields:
        payload.update(custom_fields)
    created = client.issue_create(fields=payload)
    return {"issue_key": created["key"], "issue_id": created["id"]}
def sync_jira_to_aegis(db: Session, link: JiraLink) -> None:
    """Pull current status from Jira into the local link record (global creds).

    Copies status, priority, assignee and story points from the linked
    Jira issue onto *link*, stamps ``last_synced_at`` and flushes the
    session. The transaction is not committed here.

    Args:
        db: Session that owns *link*.
        link: Link record whose ``jira_issue_key`` identifies the issue.
    """
    jira = get_jira_client()
    issue = jira.issue(link.jira_issue_key)
    # Every nested lookup tolerates an explicit null from Jira, not just a
    # missing key, so one unset field cannot abort the whole sync.
    fields = issue.get("fields") or {}
    link.jira_status = (fields.get("status") or {}).get("name")
    link.jira_priority = (fields.get("priority") or {}).get("name")
    link.jira_assignee = (fields.get("assignee") or {}).get("displayName")
    # A null story-points field must become "", not the literal "None".
    story_points = fields.get("customfield_10016")
    link.jira_story_points = "" if story_points is None else str(story_points)
    link.last_synced_at = datetime.utcnow()
    db.flush()
def sync_aegis_to_jira(db: Session, link: JiraLink, entity_data: dict) -> None:
    """Push an Aegis status update as a Jira comment (global creds).

    Args:
        db: Session that owns *link*; flushed but not committed.
        link: Link record naming the Jira issue to comment on.
        entity_data: Key/value pairs rendered into the comment body.
    """
    client = get_jira_client()
    client.issue_add_comment(link.jira_issue_key, _build_sync_comment(entity_data))
    link.last_synced_at = datetime.utcnow()
    db.flush()
def _build_sync_comment(data: dict) -> str:
    """Render *data* as the text of an "Aegis Sync Update" comment.

    The result is a heading line, a blank line, one ``*key:* value`` line
    per item in *data*, and a footer carrying the current UTC timestamp.
    """
    heading = ["h3. Aegis Sync Update", ""]
    entries = [f"*{label}:* {content}" for label, content in data.items()]
    footer = f"\n_Synced at {datetime.utcnow().isoformat()}_"
    return "\n".join([*heading, *entries, footer])
@@ -715,60 +784,94 @@ def _build_sync_comment(data: dict) -> str:
def create_link(
    db: Session,
    *,
    entity_type: JiraLinkEntityType,
    entity_id: UUID,
    jira_issue_key: str,
    sync_direction: JiraSyncDirection,
    created_by: UUID,
) -> JiraLink:
    """Create a link between an Aegis entity and an existing Jira issue.

    The new record is added and flushed (not committed). When Jira is
    enabled, an initial pull of the issue's fields is attempted; a failure
    there is logged as a warning and does not prevent the link from being
    returned.

    Returns:
        The newly created ``JiraLink``.
    """
    new_link = JiraLink(
        entity_type=entity_type,
        entity_id=entity_id,
        jira_issue_key=jira_issue_key,
        sync_direction=sync_direction,
        created_by=created_by,
    )
    db.add(new_link)
    db.flush()

    if not settings.JIRA_ENABLED:
        return new_link

    # Best effort: the link is still valid even if Jira is unreachable.
    try:
        sync_jira_to_aegis(db, new_link)
    except Exception as exc:
        logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, exc)
    return new_link
def list_links(
    db: Session,
    *,
    entity_type: Optional[JiraLinkEntityType] = None,
    entity_id: Optional[UUID] = None,
    entity_ids: Optional[list[UUID]] = None,
) -> list[JiraLink]:
    """Return Jira links, newest first, optionally narrowed by entity.

    Args:
        db: Session used for the query.
        entity_type: When given, keep only links of this entity type.
        entity_id: When given, keep only links to this entity. Takes
            precedence over *entity_ids*.
        entity_ids: When given (and *entity_id* is not), keep only links
            to any of these entities.
    """
    criteria = []
    if entity_type:
        criteria.append(JiraLink.entity_type == entity_type)
    if entity_id:
        criteria.append(JiraLink.entity_id == entity_id)
    elif entity_ids:
        criteria.append(JiraLink.entity_id.in_(entity_ids))

    selection = db.query(JiraLink)
    for criterion in criteria:
        selection = selection.filter(criterion)
    return selection.order_by(JiraLink.created_at.desc()).all()
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
    """Return the Jira link with id *link_id*.

    Raises:
        EntityNotFoundError: If no such link exists.
    """
    found = db.query(JiraLink).filter(JiraLink.id == link_id).first()
    if found:
        return found
    raise EntityNotFoundError("JiraLink", str(link_id))
def delete_link(db: Session, link_id: UUID) -> JiraLink:
    """Mark the Jira link *link_id* for deletion and return it.

    The deletion is only staged on the session; this function neither
    flushes nor commits.

    Raises:
        EntityNotFoundError: If no such link exists.
    """
    doomed = get_link_or_raise(db, link_id)
    db.delete(doomed)
    return doomed
@@ -776,43 +879,64 @@ def build_issue_data(
db: Session, entity_type: JiraLinkEntityType, entity_id: UUID
) -> tuple[str, str]:
"""Build Jira issue summary and description from an Aegis entity."""
# Check: entity_type == JiraLinkEntityType.test
if entity_type == JiraLinkEntityType.test:
# Assign entity = db.query(Test).filter(Test.id == entity_id).first()
entity = db.query(Test).filter(Test.id == entity_id).first()
# Check: not entity
if not entity:
# Raise EntityNotFoundError
raise EntityNotFoundError("Test", str(entity_id))
technique = db.query(Technique).filter(Technique.id == entity.technique_id).first()
return (
f"[Aegis] {technique.mitre_id if technique else 'N/A'}{entity.name}",
_build_test_description(entity, technique),
)
# Alternative: entity_type == JiraLinkEntityType.campaign
elif entity_type == JiraLinkEntityType.campaign:
# Assign entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
# Check: not entity
if not entity:
# Raise EntityNotFoundError
raise EntityNotFoundError("Campaign", str(entity_id))
# Return (
return (
f"[Aegis Campaign] {entity.name}",
f"Campaign: {entity.name}\nType: {entity.type}\nStatus: {entity.status}\n"
f"Description: {entity.description or 'N/A'}",
)
# Alternative: entity_type == JiraLinkEntityType.technique
elif entity_type == JiraLinkEntityType.technique:
# Assign entity = db.query(Technique).filter(Technique.id == entity_id).first()
entity = db.query(Technique).filter(Technique.id == entity_id).first()
# Check: not entity
if not entity:
# Raise EntityNotFoundError
raise EntityNotFoundError("Technique", str(entity_id))
# Return (
return (
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
f"Tactic: {entity.tactic or 'N/A'}\n"
f"Description: {entity.description or 'N/A'}",
)
# Fallback: handle remaining cases
else:
# Return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
# Define function create_issue_and_link
def create_issue_and_link(
# Entry: db
db: Session,
*,
# Entry: entity_type
entity_type: JiraLinkEntityType,
# Entry: entity_id
entity_id: UUID,
# Entry: created_by
created_by: UUID,
) -> dict:
"""Create a Jira issue from an Aegis entity and link them (global creds)."""
@@ -821,16 +945,25 @@ def create_issue_and_link(
result = create_jira_issue(
project_key=project_key,
summary=summary,
# Keyword argument: description
description=description,
# Keyword argument: labels
labels=["aegis", entity_type.value],
)
# Assign link = JiraLink(
link = JiraLink(
# Keyword argument: entity_type
entity_type=entity_type,
# Keyword argument: entity_id
entity_id=entity_id,
# Keyword argument: jira_issue_key
jira_issue_key=result["issue_key"],
# Keyword argument: jira_issue_id
jira_issue_id=result["issue_id"],
jira_project_key=project_key,
created_by=created_by,
)
# Stage new record(s) for database insertion
db.add(link)
# Return {"issue_key": result["issue_key"], "link_id": str(link.id)}
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
+228 -2
View File
@@ -24,24 +24,45 @@ Deduplication keys:
- GTFOBins: ``source + binary_name + function`` → stored in ``atomic_test_id``
"""
# Import io
import io
# Import logging
import logging
# Import re
import re
# Import shutil
import shutil
# Import tempfile
import tempfile
# Import zipfile
import zipfile
# Import datetime from datetime
from datetime import datetime
# Import Path from pathlib
from pathlib import Path
# Import requests
import requests as _requests
# Import yaml
import yaml
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
from app.models.test_template import TestTemplate
# Import DataSource from app.models.data_source
from app.models.data_source import DataSource
from app.models.technique import Technique
from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -49,34 +70,57 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# GitHub archive (ZIP of the ``master`` branch) of the LOLBAS project.
LOLBAS_ZIP_URL = (
    "https://github.com/LOLBAS-Project/LOLBAS"
    "/archive/refs/heads/master.zip"
)
# GitHub archive (ZIP of the ``master`` branch) of the GTFOBins site repository.
GTFOBINS_ZIP_URL = (
    "https://github.com/GTFOBins/GTFOBins.github.io"
    "/archive/refs/heads/master.zip"
)
# Value passed as ``timeout`` to ``requests.get`` when downloading an archive.
_DOWNLOAD_TIMEOUT = 300
# GTFOBins function → MITRE technique mapping
# Keys are compared against the lower-cased function names found in each
# GTFOBins file; functions missing from this map are skipped by the parser.
_GTFOBINS_FUNCTION_MAP: dict[str, str] = {
    # Shell / command execution
    "shell": "T1059",
    "command": "T1059",
    "reverse-shell": "T1059",
    "non-interactive-reverse-shell": "T1059",
    "bind-shell": "T1059",
    "non-interactive-bind-shell": "T1059",
    # File transfer and file write
    "file-upload": "T1105",
    "file-download": "T1105",
    "upload": "T1105",
    "download": "T1105",
    "file-write": "T1105",
    # File read
    "file-read": "T1005",
    # Library loading
    "library-load": "T1129",
    # Privilege-related functions
    "sudo": "T1548.003",
    "suid": "T1548.001",
    "capabilities": "T1548",
    "limited-suid": "T1548.001",
}
@@ -88,18 +132,28 @@ _GTFOBINS_FUNCTION_MAP: dict[str, str] = {
def _download_zip(url: str) -> bytes:
    """Download a ZIP from *url* and return raw bytes.

    Args:
        url: Location of the archive.

    Returns:
        The complete response body.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
    """
    logger.info("Downloading ZIP from %s", url)
    # With stream=True the connection stays checked out until the body is
    # consumed or the response is closed. The context manager guarantees
    # the release even when raise_for_status() fails before any read.
    with _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) as resp:
        resp.raise_for_status()
        content = resp.content
    logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
    return content
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
    """Extract *zip_bytes* into *dest* and return the root directory.

    Args:
        zip_bytes: Raw content of a ZIP archive.
        dest: Directory that receives the archive members.

    Returns:
        *dest* as a ``Path``.
    """
    buffer = io.BytesIO(zip_bytes)
    with zipfile.ZipFile(buffer) as archive:
        archive.extractall(path=dest)
    target = Path(dest)
    return target
@@ -110,83 +164,141 @@ def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
def _lolbas_template(binary_name: str, description, cmd_entry: dict) -> dict | None:
    """Build one template dict from a LOLBAS command entry.

    Returns ``None`` when the entry has no ``MitreID`` or the normalised
    ID does not start with ``T``.
    """
    mitre_id = cmd_entry.get("MitreID")
    if not mitre_id:
        return None
    # Normalise the MITRE ID
    mitre_id = str(mitre_id).strip().upper()
    if not mitre_id.startswith("T"):
        return None

    command = cmd_entry.get("Command", "")
    usecase = cmd_entry.get("Usecase", "")
    cmd_description = cmd_entry.get("Description", "")

    procedure = []
    if cmd_description:
        procedure.append(f"Description: {cmd_description}")
    if usecase:
        procedure.append(f"Use case: {usecase}")
    if command:
        procedure.append(f"Command: {command}")

    if description:
        full_description = f"{description}\n\n{cmd_description}".strip()[:2000]
    elif cmd_description:
        full_description = cmd_description[:2000]
    else:
        full_description = None

    return {
        "mitre_technique_id": mitre_id,
        # Separator keeps the binary name readable next to the use case.
        "name": f"LOLBAS: {binary_name} - {usecase or cmd_description or mitre_id}"[:500],
        "description": full_description,
        "source": "lolbas",
        "platform": "windows",
        "tool_suggested": binary_name,
        "attack_procedure": "\n".join(procedure)[:4000] if procedure else None,
        # Dedup key
        "atomic_test_id": f"lolbas:{binary_name}:{mitre_id}",
        "source_url": f"https://lolbas-project.github.io/lolbas/Binaries/{binary_name}/",
    }


def _parse_lolbas(root_dir: Path) -> list[dict]:
    """Parse LOLBAS YAML files and return template dicts.

    Reads every ``*.yml`` file under ``LOLBAS-master/yml/OSBinaries``,
    ``OSLibraries`` and ``OSScripts`` below *root_dir*. Each command entry
    that carries a MITRE ID starting with ``T`` yields one template dict.
    Files that cannot be read or parsed, and entries of an unexpected
    shape, are skipped so that one bad file does not abort the import.

    Args:
        root_dir: Directory into which the LOLBAS archive was extracted.

    Returns:
        Template dicts in file order (directories in the order above,
        files sorted within each directory).
    """
    lolbas_root = root_dir / "LOLBAS-master"
    yaml_files: list[Path] = []
    for section in ("OSBinaries", "OSLibraries", "OSScripts"):
        section_dir = lolbas_root / "yml" / section
        if section_dir.is_dir():
            yaml_files.extend(sorted(section_dir.rglob("*.yml")))
    logger.info("LOLBAS: Found %d YAML files", len(yaml_files))

    results: list[dict] = []
    for yaml_path in yaml_files:
        # Deliberately broad: any unreadable or malformed file is skipped.
        try:
            with open(yaml_path, "r", encoding="utf-8") as fh:
                data = yaml.safe_load(fh)
        except Exception as exc:
            logger.debug("Failed to parse %s: %s", yaml_path, exc)
            continue
        if not isinstance(data, dict):
            continue

        # ``Name`` may be null or a non-string scalar in a malformed file;
        # coerce instead of crashing on ``.strip()``.
        binary_name = str(data.get("Name") or "").strip()
        if not binary_name:
            continue

        description = data.get("Description", "")
        commands = data.get("Commands", [])
        if not isinstance(commands, list):
            continue

        for cmd_entry in commands:
            if not isinstance(cmd_entry, dict):
                continue
            template = _lolbas_template(binary_name, description, cmd_entry)
            if template is not None:
                results.append(template)

    logger.info("LOLBAS: Parsed %d templates", len(results))
    return results
@@ -197,85 +309,138 @@ def _parse_lolbas(root_dir: Path) -> list[dict]:
def _parse_gtfobins(root_dir: Path) -> list[dict]:
    """Parse GTFOBins markdown files and return template dicts.

    Reads every regular file with suffix ``.md`` or no suffix directly
    inside ``GTFOBins.github.io-master/_gtfobins`` below *root_dir*. For
    each function listed in a file's YAML front-matter that has an entry
    in ``_GTFOBINS_FUNCTION_MAP``, one template dict is produced. Files
    that cannot be read or lack usable front-matter are skipped.

    Args:
        root_dir: Directory into which the GTFOBins archive was extracted.

    Returns:
        Template dicts in sorted file order; empty when the expected
        directory is missing.
    """
    results: list[dict] = []
    gtfobins_root = root_dir / "GTFOBins.github.io-master" / "_gtfobins"
    if not gtfobins_root.is_dir():
        logger.warning("GTFOBins directory not found at %s", gtfobins_root)
        return results

    md_files = sorted(
        f for f in gtfobins_root.iterdir()
        if f.is_file() and f.suffix in (".md", "")
    )
    logger.info("GTFOBins: Found %d files", len(md_files))

    for md_path in md_files:
        binary_name = md_path.stem  # e.g. "awk"
        # Deliberately broad: any unreadable file is skipped.
        try:
            with open(md_path, "r", encoding="utf-8") as fh:
                content = fh.read()
        except Exception as exc:
            logger.debug("Failed to read %s: %s", md_path, exc)
            continue

        # Extract YAML front-matter
        front_matter = _extract_front_matter(content)
        if not front_matter:
            continue

        functions = front_matter.get("functions", {})
        if not isinstance(functions, dict):
            continue

        for func_name, func_data in functions.items():
            # Map function to MITRE technique. YAML keys are not
            # guaranteed to be strings, so coerce before lower-casing.
            mitre_id = _GTFOBINS_FUNCTION_MAP.get(str(func_name).lower())
            if not mitre_id:
                continue

            # Extract code examples from function data
            examples = []
            if isinstance(func_data, list):
                for entry in func_data:
                    if isinstance(entry, dict):
                        code = entry.get("code", "")
                        if code:
                            examples.append(str(code))
                    elif isinstance(entry, str):
                        examples.append(entry)

            procedure = "\n\n".join(examples) if examples else None

            results.append({
                "mitre_technique_id": mitre_id,
                # Separator keeps the binary name readable next to the function.
                "name": f"GTFOBins: {binary_name} - {func_name}"[:500],
                "description": f"Abuse {binary_name} binary for {func_name} on Linux/Unix."[:2000],
                "source": "gtfobins",
                "platform": "linux",
                "tool_suggested": binary_name,
                "attack_procedure": procedure[:4000] if procedure else None,
                # Dedup key
                "atomic_test_id": f"gtfobins:{binary_name}:{func_name}",
                "source_url": f"https://gtfobins.github.io/gtfobins/{binary_name}/",
            })

    logger.info("GTFOBins: Parsed %d templates", len(results))
    return results
def _extract_front_matter(content: str) -> dict | None:
    """Extract YAML front-matter from a markdown/GTFOBins file.

    Supports both ``---/---`` (standard front-matter) and ``---/...``
    (YAML document-end marker used by GTFOBins).

    Args:
        content: Full text of the file.

    Returns:
        The parsed front-matter mapping, or ``None`` when the file has no
        front-matter block, the block is not valid YAML, or it parses to
        something other than a mapping.
    """
    match = re.match(r"^---\s*\n(.*?)\n(?:---|\.\.\.)", content, re.DOTALL)
    if not match:
        return None
    # Deliberately broad: a malformed block means "no front-matter".
    try:
        parsed = yaml.safe_load(match.group(1))
    except Exception:
        return None
    # safe_load may yield a scalar or list; callers rely on ``.get``.
    return parsed if isinstance(parsed, dict) else None
@@ -286,36 +451,59 @@ def _extract_front_matter(content: str) -> dict | None:
def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict:
"""Insert templates, skipping existing ones by atomic_test_id."""
# Assign existing_ids = {
existing_ids: set[str] = {
row[0]
for row in db.query(TestTemplate.atomic_test_id)
# Chain .filter() call
.filter(TestTemplate.source == source_name)
# Chain .filter() call
.filter(TestTemplate.atomic_test_id.isnot(None))
# Chain .all() call
.all()
}
# Assign created = 0
created = 0
# Assign skipped = 0
skipped = 0
new_technique_ids: set[str] = set()
# Iterate over items
for item in items:
# Check: item["atomic_test_id"] in existing_ids
if item["atomic_test_id"] in existing_ids:
# Assign skipped = 1
skipped += 1
# Skip to the next loop iteration
continue
# Assign template = TestTemplate(
template = TestTemplate(
# Keyword argument: mitre_technique_id
mitre_technique_id=item["mitre_technique_id"],
# Keyword argument: name
name=item["name"],
# Keyword argument: description
description=item["description"],
# Keyword argument: source
source=item["source"],
# Keyword argument: source_url
source_url=item.get("source_url"),
# Keyword argument: attack_procedure
attack_procedure=item.get("attack_procedure"),
# Keyword argument: platform
platform=item["platform"],
# Keyword argument: tool_suggested
tool_suggested=item.get("tool_suggested"),
# Keyword argument: atomic_test_id
atomic_test_id=item["atomic_test_id"],
# Keyword argument: is_active
is_active=True,
)
# Stage new record(s) for database insertion
db.add(template)
# Call existing_ids.add()
existing_ids.add(item["atomic_test_id"])
new_technique_ids.add(item["mitre_technique_id"])
created += 1
@@ -326,6 +514,7 @@ def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict:
).update({"review_required": True}, synchronize_session=False)
db.commit()
# Return {"created": created, "skipped_existing": skipped, "total_parsed": l...
return {"created": created, "skipped_existing": skipped, "total_parsed": len(items)}
@@ -339,56 +528,93 @@ def sync(db: Session) -> dict:
Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``.
"""
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_lolbas_")
tmp_dir = tempfile.mkdtemp(prefix="aegis_lolbas_")
# Attempt the following; catch errors below
try:
# Assign zip_bytes = _download_zip(LOLBAS_ZIP_URL)
zip_bytes = _download_zip(LOLBAS_ZIP_URL)
# Assign root_dir = _extract_zip(zip_bytes, tmp_dir)
root_dir = _extract_zip(zip_bytes, tmp_dir)
# Assign parsed = _parse_lolbas(root_dir)
parsed = _parse_lolbas(root_dir)
# Always execute this cleanup block
finally:
# Call shutil.rmtree()
shutil.rmtree(tmp_dir, ignore_errors=True)
# Assign summary = _upsert_templates(db, parsed, "lolbas")
summary = _upsert_templates(db, parsed, "lolbas")
# Update DataSource record
ds = db.query(DataSource).filter(DataSource.name == "lolbas").first()
# Check: ds
if ds:
# Assign ds.last_sync_at = datetime.utcnow()
ds.last_sync_at = datetime.utcnow()
# Assign ds.last_sync_status = "success"
ds.last_sync_status = "success"
# Assign ds.last_sync_stats = summary
ds.last_sync_stats = summary
# Commit all pending changes to the database
db.commit()
# Log info: "LOLBAS import complete — %s", summary
logger.info("LOLBAS import complete — %s", summary)
# Call log_action()
log_action(db, user_id=None, action="import_lolbas",
# Keyword argument: entity_type
entity_type="test_template", entity_id=None, details=summary)
# Commit all pending changes to the database
db.commit()
# Return summary
return summary
def sync_gtfobins(db: Session) -> dict:
    """Import GTFOBins templates.

    Downloads and parses the GTFOBins archive in a temporary directory
    (always removed afterwards), inserts the new templates, records the
    outcome on the ``gtfobins`` ``DataSource`` row when one exists, and
    writes an audit entry.

    Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``.
    """
    workdir = tempfile.mkdtemp(prefix="aegis_gtfobins_")
    try:
        archive = _download_zip(GTFOBINS_ZIP_URL)
        templates = _parse_gtfobins(_extract_zip(archive, workdir))
    finally:
        shutil.rmtree(workdir, ignore_errors=True)

    stats = _upsert_templates(db, templates, "gtfobins")

    # Update DataSource record
    source_row = db.query(DataSource).filter(DataSource.name == "gtfobins").first()
    if source_row:
        source_row.last_sync_at = datetime.utcnow()
        source_row.last_sync_status = "success"
        source_row.last_sync_stats = stats
        db.commit()

    logger.info("GTFOBins import complete — %s", stats)
    log_action(
        db,
        user_id=None,
        action="import_gtfobins",
        entity_type="test_template",
        entity_id=None,
        details=stats,
    )
    db.commit()
    return stats
@@ -7,16 +7,28 @@ of MITRE ATT&CK technique coverage for dashboards and reporting.
This module is framework-agnostic: no FastAPI imports.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import defaultdict from collections
from collections import defaultdict
# Import func from sqlalchemy
from sqlalchemy import func
# Import Session, joinedload from sqlalchemy.orm
from sqlalchemy.orm import Session, joinedload
# Import TechniqueStatus, TestState from app.models.enums
from app.models.enums import TechniqueStatus, TestState
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import from app.schemas.metrics
from app.schemas.metrics import (
CoverageSummary,
RecentTestItem,
@@ -27,40 +39,60 @@ from app.schemas.metrics import (
)
def get_coverage_summary(db: Session) -> CoverageSummary:
    """Return a global coverage summary across all techniques.

    Techniques are counted per ``status_global``; statuses with no
    techniques report zero. The coverage percentage is the share of
    techniques that are validated or partial, rounded to two decimals
    (``0.0`` when there are no techniques).
    """
    grouped = (
        db.query(
            Technique.status_global,
            func.count(Technique.id).label("cnt"),
        )
        .group_by(Technique.status_global)
        .all()
    )

    tally: dict[str, int] = dict.fromkeys((s.value for s in TechniqueStatus), 0)
    tally.update({status.value: cnt for status, cnt in grouped})

    total = sum(tally.values())
    covered = tally["validated"] + tally["partial"]
    if total > 0:
        coverage_pct = round(covered / total * 100, 2)
    else:
        coverage_pct = 0.0

    return CoverageSummary(
        total_techniques=total,
        validated=tally["validated"],
        partial=tally["partial"],
        not_covered=tally["not_covered"],
        in_progress=tally["in_progress"],
        not_evaluated=tally["not_evaluated"],
        coverage_percentage=coverage_pct,
    )
# Define function get_coverage_by_tactic
def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]:
"""Return coverage breakdown grouped by tactic.
@@ -68,6 +100,7 @@ def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]:
comma-separated string), the technique is counted once per tactic
it belongs to.
"""
# Assign techniques = db.query(
techniques = db.query(
Technique.tactic, Technique.status_global
).all()
@@ -75,179 +108,270 @@ def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]:
# Accumulate per-tactic counters. A technique with tactic
# "persistence, privilege-escalation" is counted in both.
tactic_data: dict[str, dict[str, int]] = defaultdict(
# Entry: lambda
lambda: {s.value: 0 for s in TechniqueStatus}
)
# Iterate over techniques
for tactic_str, status in techniques:
# Check: not tactic_str
if not tactic_str:
# Assign tactics = ["unknown"]
tactics = ["unknown"]
# Fallback: handle remaining cases
else:
# Assign tactics = [t.strip() for t in tactic_str.split(",")]
tactics = [t.strip() for t in tactic_str.split(",")]
# Iterate over tactics
for tactic in tactics:
# Assign tactic_data[tactic][status.value] = 1
tactic_data[tactic][status.value] += 1
# Assign result = []
result = []
# Iterate over sorted(tactic_data)
for tactic in sorted(tactic_data):
# Assign counts = tactic_data[tactic]
counts = tactic_data[tactic]
# Assign total = sum(counts.values())
total = sum(counts.values())
# Call result.append()
result.append(
TacticCoverage(
# Keyword argument: tactic
tactic=tactic,
# Keyword argument: total
total=total,
# Keyword argument: validated
validated=counts["validated"],
# Keyword argument: partial
partial=counts["partial"],
# Keyword argument: not_covered
not_covered=counts["not_covered"],
# Keyword argument: not_evaluated
not_evaluated=counts["not_evaluated"],
# Keyword argument: in_progress
in_progress=counts["in_progress"],
)
)
# Return result
return result
def get_test_pipeline_counts(db: Session) -> TestPipelineCounts:
    """Return how many tests are in each pipeline state.

    States with no tests report zero; ``total`` is the sum over all
    states.
    """
    grouped = (
        db.query(Test.state, func.count(Test.id).label("cnt"))
        .group_by(Test.state)
        .all()
    )

    per_state: dict[str, int] = dict.fromkeys((s.value for s in TestState), 0)
    per_state.update({state.value: cnt for state, cnt in grouped})

    return TestPipelineCounts(
        draft=per_state["draft"],
        red_executing=per_state["red_executing"],
        blue_evaluating=per_state["blue_evaluating"],
        in_review=per_state["in_review"],
        validated=per_state["validated"],
        rejected=per_state["rejected"],
        total=sum(per_state.values()),
    )
def get_team_activity(db: Session) -> list[TeamActivity]:
    """Return activity summary for Red and Blue teams.

    Red Team: completed = tests past ``red_executing``; pending = tests in
    ``draft`` or ``red_executing``.
    Blue Team: completed = tests past ``blue_evaluating``; pending = tests
    in ``blue_evaluating``.
    """

    def count_tests(criterion) -> int:
        # scalar() may return None; report that as zero.
        return db.query(func.count(Test.id)).filter(criterion).scalar() or 0

    past_red = [
        TestState.blue_evaluating,
        TestState.in_review,
        TestState.validated,
        TestState.rejected,
    ]
    past_blue = [
        TestState.in_review,
        TestState.validated,
        TestState.rejected,
    ]

    red_done = count_tests(Test.state.in_(past_red))
    red_waiting = count_tests(Test.state.in_([TestState.draft, TestState.red_executing]))
    blue_done = count_tests(Test.state.in_(past_blue))
    blue_waiting = count_tests(Test.state == TestState.blue_evaluating)

    return [
        TeamActivity(
            team="Red Team",
            tests_completed=red_done,
            tests_pending=red_waiting,
        ),
        TeamActivity(
            team="Blue Team",
            tests_completed=blue_done,
            tests_pending=blue_waiting,
        ),
    ]
# Define function get_validation_rate
def get_validation_rate(db: Session) -> list[ValidationRate]:
    """Compute approval statistics for the Red Lead and Blue Lead reviewers.

    For each lead, tests whose validation status is ``"approved"`` or
    ``"rejected"`` are counted; any other status is ignored. The approval
    rate is ``approved / (approved + rejected) * 100`` rounded to one
    decimal, or ``0.0`` when the lead has reviewed nothing.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        list[ValidationRate]: Two entries, ``red_lead`` first, then
        ``blue_lead``.
    """

    def _count_status(column, status: str) -> int:
        # A falsy scalar result is normalised to 0.
        return db.query(func.count(Test.id)).filter(column == status).scalar() or 0

    lead_columns = (
        ("red_lead", Test.red_validation_status),
        ("blue_lead", Test.blue_validation_status),
    )

    # Gather the raw numbers first, then build the response objects.
    stats = []
    for role, status_column in lead_columns:
        n_approved = _count_status(status_column, "approved")
        n_rejected = _count_status(status_column, "rejected")
        n_reviewed = n_approved + n_rejected
        if n_reviewed > 0:
            rate = round(n_approved / n_reviewed * 100, 1)
        else:
            rate = 0.0
        stats.append((role, n_reviewed, n_approved, n_rejected, rate))

    return [
        ValidationRate(
            role=role,
            total_reviewed=n_reviewed,
            approved=n_approved,
            rejected=n_rejected,
            approval_rate=rate,
        )
        for role, n_reviewed, n_approved, n_rejected, rate in stats
    ]
# Define function get_recent_tests
def get_recent_tests(db: Session, *, limit: int = 10) -> list[RecentTestItem]:
"""Return the most recently created tests."""
from sqlalchemy import nullslast
tests = (
db.query(Test)
# Chain .options() call
.options(joinedload(Test.technique))
.order_by(nullslast(Test.created_at.desc()))
.limit(limit)
# Chain .all() call
.all()
)
# Return [
return [
RecentTestItem(
# Keyword argument: id
id=str(t.id),
# Keyword argument: name
name=t.name,
# Keyword argument: state
state=t.state.value,
# Keyword argument: technique_mitre_id
technique_mitre_id=t.technique.mitre_id if t.technique else None,
# Keyword argument: technique_name
technique_name=t.technique.name if t.technique else None,
# Keyword argument: created_at
created_at=t.created_at,
)
for t in tests
+143 -2
View File
@@ -6,123 +6,197 @@ ATT&CK collection, and upserts attack-pattern objects into the local
when the TAXII server is unreachable.
"""
# Import logging
import logging
# Import datetime from datetime
from datetime import datetime
# Import requests
import requests as _requests
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import Server as TaxiiServer from taxii2client.v20
from taxii2client.v20 import Server as TaxiiServer
from app.models.technique import Technique
# Import TechniqueStatus from app.models.enums
from app.models.enums import TechniqueStatus
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Assign TAXII_SERVER_URL = "https://cti-taxii.mitre.org/taxii/"
TAXII_SERVER_URL = "https://cti-taxii.mitre.org/taxii/"
# Assign MITRE_SOURCE_NAME = "mitre-attack"
MITRE_SOURCE_NAME = "mitre-attack"
# Assign GITHUB_ENTERPRISE_URL = (
GITHUB_ENTERPRISE_URL = (
# Literal argument value
"https://raw.githubusercontent.com/mitre/cti/master/"
# Literal argument value
"enterprise-attack/enterprise-attack.json"
)
# Define function _extract_mitre_id
def _extract_mitre_id(external_references: list) -> str | None:
    """Pick the ATT&CK identifier (e.g. ``T1059.001``) out of STIX references.

    The first reference whose ``source_name`` equals ``MITRE_SOURCE_NAME``
    decides the result: its ``external_id`` is returned, even when that
    value is missing (``None``). ``None`` is also returned when the list is
    empty/falsy or no reference matches.
    """
    matching_ids = (
        reference.get("external_id")
        for reference in (external_references or ())
        if reference.get("source_name") == MITRE_SOURCE_NAME
    )
    return next(matching_ids, None)
# Define function _extract_tactics
def _extract_tactics(kill_chain_phases: list) -> str | None:
"""Return a comma-separated string of tactic phase names."""
# Check: not kill_chain_phases
if not kill_chain_phases:
# Return None
return None
# Assign tactics = [
tactics = [
phase.get("phase_name")
for phase in kill_chain_phases
if phase.get("kill_chain_name") == "mitre-attack"
]
# Return ", ".join(tactics) if tactics else None
return ", ".join(tactics) if tactics else None
# Define function _extract_platforms
def _extract_platforms(stix_object: dict) -> list:
"""Return the list of platforms from the STIX object."""
# Return stix_object.get("x_mitre_platforms", [])
return stix_object.get("x_mitre_platforms", [])
# Define function _extract_version
def _extract_version(stix_object: dict) -> str | None:
"""Return the MITRE ATT&CK version string."""
# Return stix_object.get("x_mitre_version")
return stix_object.get("x_mitre_version")
# Define function _extract_last_modified
def _extract_last_modified(stix_object: dict) -> datetime | None:
"""Return the ``modified`` timestamp as a datetime, or None."""
# Assign modified = stix_object.get("modified")
modified = stix_object.get("modified")
# Check: modified is None
if modified is None:
# Return None
return None
# Check: isinstance(modified, datetime)
if isinstance(modified, datetime):
# Return modified
return modified
# Attempt the following; catch errors below
try:
# Return datetime.fromisoformat(modified.replace("Z", "+00:00"))
return datetime.fromisoformat(modified.replace("Z", "+00:00"))
# Handle (ValueError, AttributeError)
except (ValueError, AttributeError):
# Return None
return None
# Define function _fetch_attack_patterns_taxii
def _fetch_attack_patterns_taxii() -> list[dict]:
    """Download attack-pattern objects from the MITRE TAXII server.

    Connects to ``TAXII_SERVER_URL``, takes the first collection of the
    first API root, pulls every object it holds and keeps those whose
    ``type`` is ``"attack-pattern"``. Errors raised by the TAXII client are
    not handled here.
    """
    logger.info("Connecting to MITRE TAXII server at %s", TAXII_SERVER_URL)

    # The first collection is taken to be the Enterprise ATT&CK collection.
    enterprise = TaxiiServer(TAXII_SERVER_URL).api_roots[0].collections[0]
    logger.info(
        "Fetching objects from collection '%s' (id=%s)",
        enterprise.title,
        enterprise.id,
    )

    stix_objects = enterprise.get_objects().get("objects", [])
    patterns = [
        item for item in stix_objects if item.get("type") == "attack-pattern"
    ]

    logger.info("Retrieved %d attack-pattern objects via TAXII", len(patterns))
    return patterns
# Define function _fetch_attack_patterns_github
def _fetch_attack_patterns_github() -> list[dict]:
    """Download attack-pattern objects from the MITRE CTI GitHub mirror.

    Performs an HTTP GET on ``GITHUB_ENTERPRISE_URL`` with a 120-second
    timeout, raises on a non-success status via ``raise_for_status()``, and
    keeps the bundle objects whose ``type`` is ``"attack-pattern"``.
    """
    logger.info("Fetching Enterprise ATT&CK bundle from GitHub (%s)", GITHUB_ENTERPRISE_URL)

    response = _requests.get(GITHUB_ENTERPRISE_URL, timeout=120)
    response.raise_for_status()

    stix_objects = response.json().get("objects", [])
    patterns = [
        item for item in stix_objects if item.get("type") == "attack-pattern"
    ]

    logger.info("Retrieved %d attack-pattern objects via GitHub", len(patterns))
    return patterns
# Define function _fetch_attack_patterns
def _fetch_attack_patterns() -> list[dict]:
    """Fetch attack-pattern objects, preferring TAXII over the GitHub mirror.

    Any exception raised by the TAXII path is logged as a warning and
    triggers the GitHub fallback. Exceptions raised by the fallback itself
    propagate to the caller.
    """
    try:
        patterns = _fetch_attack_patterns_taxii()
    except Exception as taxii_error:
        # Deliberately broad: every TAXII failure should lead to the mirror.
        logger.warning(
            "TAXII server unavailable (%s), falling back to GitHub mirror",
            taxii_error,
        )
        patterns = _fetch_attack_patterns_github()
    return patterns
# Define function sync_mitre
def sync_mitre(db: Session) -> dict:
"""Synchronize MITRE ATT&CK techniques into the local database.
@@ -131,11 +205,12 @@ def sync_mitre(db: Session) -> dict:
db : Session
Active SQLAlchemy database session.
Returns
Returns:
-------
dict
Summary with keys ``created``, ``updated``, ``unchanged``, ``skipped``.
"""
# Assign attack_patterns = _fetch_attack_patterns()
attack_patterns = _fetch_attack_patterns()
# Pre-load existing techniques keyed by mitre_id for fast lookup
@@ -143,90 +218,149 @@ def sync_mitre(db: Session) -> dict:
t.mitre_id: t for t in db.query(Technique).all()
}
# Assign created = 0
created = 0
# Assign updated = 0
updated = 0
# Assign unchanged = 0
unchanged = 0
# Assign skipped = 0
skipped = 0
# Iterate over attack_patterns
for obj in attack_patterns:
# ------------------------------------------------------------------
# Skip revoked / deprecated objects
# ------------------------------------------------------------------
if obj.get("revoked", False) or obj.get("x_mitre_deprecated", False):
# Count this revoked/deprecated object as skipped
skipped += 1
# Skip to the next loop iteration
continue
# Assign mitre_id = _extract_mitre_id(obj.get("external_references", []))
mitre_id = _extract_mitre_id(obj.get("external_references", []))
# Check: not mitre_id
if not mitre_id:
# Count this object (no MITRE ID) as skipped
skipped += 1
# Skip to the next loop iteration
continue
# Assign name = obj.get("name", "")
name = obj.get("name", "")
# Assign description = obj.get("description", "")
description = obj.get("description", "")
# Assign tactic = _extract_tactics(obj.get("kill_chain_phases", []))
tactic = _extract_tactics(obj.get("kill_chain_phases", []))
# Assign platforms = _extract_platforms(obj)
platforms = _extract_platforms(obj)
# Assign version = _extract_version(obj)
version = _extract_version(obj)
# Assign last_modified = _extract_last_modified(obj)
last_modified = _extract_last_modified(obj)
# Assign is_subtechnique = "." in mitre_id
is_subtechnique = "." in mitre_id
# Assign parent_mitre_id = mitre_id.split(".")[0] if is_subtechnique else None
parent_mitre_id = mitre_id.split(".")[0] if is_subtechnique else None
# Assign existing = existing_techniques.get(mitre_id)
existing = existing_techniques.get(mitre_id)
# Check: existing is None
if existing is None:
# ---- Create new technique ----
technique = Technique(
# Keyword argument: mitre_id
mitre_id=mitre_id,
# Keyword argument: name
name=name,
# Keyword argument: description
description=description,
# Keyword argument: tactic
tactic=tactic,
# Keyword argument: platforms
platforms=platforms,
# Keyword argument: mitre_version
mitre_version=version,
# Keyword argument: mitre_last_modified
mitre_last_modified=last_modified,
# Keyword argument: is_subtechnique
is_subtechnique=is_subtechnique,
# Keyword argument: parent_mitre_id
parent_mitre_id=parent_mitre_id,
# Keyword argument: status_global
status_global=TechniqueStatus.not_evaluated,
# Keyword argument: review_required
review_required=False,
)
# Stage new record(s) for database insertion
db.add(technique)
# Assign existing_techniques[mitre_id] = technique
existing_techniques[mitre_id] = technique
# Increment the created counter
created += 1
# Fallback: handle remaining cases
else:
# ---- Update if name or description changed ----
changes = False
# Check: existing.name != name
if existing.name != name:
# Assign existing.name = name
existing.name = name
# Assign changes = True
changes = True
# Check: (existing.description or "") != (description or "")
if (existing.description or "") != (description or ""):
# Assign existing.description = description
existing.description = description
# Assign changes = True
changes = True
# Always keep metadata up-to-date (does not trigger review)
existing.tactic = tactic
# Assign existing.platforms = platforms
existing.platforms = platforms
# Assign existing.mitre_version = version
existing.mitre_version = version
# Assign existing.mitre_last_modified = last_modified
existing.mitre_last_modified = last_modified
# Assign existing.is_subtechnique = is_subtechnique
existing.is_subtechnique = is_subtechnique
# Assign existing.parent_mitre_id = parent_mitre_id
existing.parent_mitre_id = parent_mitre_id
# Check: changes
if changes:
# Assign existing.review_required = True
existing.review_required = True
# Increment the updated counter
updated += 1
# Fallback: handle remaining cases
else:
# Increment the unchanged counter
unchanged += 1
# Single commit for the whole batch
db.commit()
# Assign summary = {
summary = {
# Literal argument value
"created": created,
# Literal argument value
"updated": updated,
# Literal argument value
"unchanged": unchanged,
# Literal argument value
"skipped": skipped,
}
# Log info:
logger.info(
# Literal argument value
"MITRE sync complete — created=%d, updated=%d, unchanged=%d, skipped=%d",
created,
updated,
@@ -237,12 +371,19 @@ def sync_mitre(db: Session) -> dict:
# Audit log (system action → user_id=None)
log_action(
db,
# Keyword argument: user_id
user_id=None,
# Keyword argument: action
action="mitre_sync",
# Keyword argument: entity_type
entity_type="technique",
# Keyword argument: entity_id
entity_id=None,
# Keyword argument: details
details=summary,
)
# Commit all pending changes to the database
db.commit()
# Return summary
return summary
+141 -4
View File
@@ -7,16 +7,29 @@ Functions in this module stage changes via ``db.add()`` / ``db.flush()``
but do **not** commit. The caller is responsible for committing.
"""
# Import uuid
import uuid
# Import datetime, timedelta from datetime
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
# Import func from sqlalchemy
from sqlalchemy import func
from app.domain.errors import EntityNotFoundError
from app.models.notification import Notification
from app.models.user import User
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import Notification from app.models.notification
from app.models.notification import Notification
# Import Test from app.models.test
from app.models.test import Test
# Import User from app.models.user
from app.models.user import User
# ---------------------------------------------------------------------------
# Core CRUD
@@ -24,132 +37,209 @@ from app.models.user import User
def list_notifications(
    db: Session,
    user_id: uuid.UUID,
    *,
    offset: int = 0,
    limit: int = 20,
) -> list[Notification]:
    """Fetch one page of a user's notifications, most recent first.

    Args:
        db: Active SQLAlchemy database session.
        user_id: Owner of the notifications.
        offset: Number of rows to skip (keyword-only).
        limit: Maximum number of rows to return (keyword-only).

    Returns:
        Notifications ordered by ``created_at`` descending.
    """
    owned = db.query(Notification).filter(Notification.user_id == user_id)
    newest_first = owned.order_by(Notification.created_at.desc())
    page = newest_first.offset(offset).limit(limit)
    return page.all()
# Define function get_notification_or_raise
def get_notification_or_raise(
    db: Session,
    notification_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Notification:
    """Look up a notification that belongs to the given user.

    Args:
        db: Active SQLAlchemy database session.
        notification_id: Primary key of the notification.
        user_id: The notification must be owned by this user.

    Returns:
        The matching notification.

    Raises:
        EntityNotFoundError: No notification has both this ID and this owner.
    """
    owned_by_user = db.query(Notification).filter(
        Notification.id == notification_id,
        Notification.user_id == user_id,
    )
    found = owned_by_user.first()
    if found is not None:
        return found
    raise EntityNotFoundError("Notification", str(notification_id))
# Define function notify_role
def notify_role(
    db: Session,
    *,
    role: str,
    type: str,
    title: str,
    message: str,
    entity_type: str,
    entity_id: uuid.UUID,
) -> None:
    """Create the same notification for every active user holding *role*.

    Args:
        db: Active SQLAlchemy database session.
        role: Role whose active members are notified.
        type: Notification type passed through to ``create_notification``.
        title: Notification title.
        message: Notification body.
        entity_type: Kind of entity the notification refers to.
        entity_id: Identifier of that entity.
    """
    recipients = (
        db.query(User)
        .filter(User.role == role, User.is_active == True)  # noqa: E712
        .all()
    )

    # Everything except the recipient is identical for each notification.
    shared_fields = {
        "type": type,
        "title": title,
        "message": message,
        "entity_type": entity_type,
        "entity_id": entity_id,
    }
    for recipient in recipients:
        create_notification(db, user_id=recipient.id, **shared_fields)
# Define function create_notification
def create_notification(
    db: Session,
    user_id: uuid.UUID,
    type: str,
    title: str,
    message: str | None = None,
    entity_type: str | None = None,
    entity_id: uuid.UUID | None = None,
) -> Notification:
    """Stage a new notification for one user and flush it to the database.

    The row is added to the session and flushed; this function does not
    commit.

    Args:
        db: Active SQLAlchemy database session.
        user_id: Recipient of the notification.
        type: Notification type.
        title: Notification title.
        message: Optional notification body.
        entity_type: Optional kind of the related entity.
        entity_id: Optional identifier of the related entity.

    Returns:
        The newly created ``Notification`` instance.
    """
    fields = {
        "user_id": user_id,
        "type": type,
        "title": title,
        "message": message,
        "entity_type": entity_type,
        "entity_id": entity_id,
    }
    record = Notification(**fields)
    db.add(record)
    db.flush()
    return record
# Define function mark_as_read
def mark_as_read(
    db: Session,
    notification_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Notification:
    """Flag one of the user's notifications as read.

    The change is made on the loaded object only; this function does not
    flush or commit.

    Returns:
        The updated notification.

    Raises:
        EntityNotFoundError: Propagated from ``get_notification_or_raise``
            when the notification does not exist for this user.
    """
    target = get_notification_or_raise(db, notification_id, user_id)
    target.read = True
    return target
# Define function mark_all_as_read
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
    """Flag every unread notification of a user as read.

    Issues a single bulk ``UPDATE`` through the query API; this function does
    not commit.

    Returns:
        The row count reported by the bulk update.
    """
    unread = db.query(Notification).filter(
        Notification.user_id == user_id,
        Notification.read == False,  # noqa: E712
    )
    return unread.update({"read": True})
# Define function get_unread_count
def get_unread_count(db: Session, user_id: uuid.UUID) -> int:
    """Count a user's notifications that are still unread.

    Returns:
        The count, with a falsy scalar result normalised to ``0``.
    """
    unread_total = (
        db.query(func.count(Notification.id))
        .filter(
            Notification.user_id == user_id,
            Notification.read == False,  # noqa: E712
        )
        .scalar()
    )
    return unread_total or 0
# Define function cleanup_old_notifications
def cleanup_old_notifications(db: Session, days: int = 90) -> int:
    """Bulk-delete read notifications created more than *days* days ago.

    The cutoff is computed from ``datetime.utcnow()``. Unread notifications
    are never deleted. This function does not commit.

    Args:
        db: Active SQLAlchemy database session.
        days: Age threshold in days.

    Returns:
        The row count reported by the bulk delete.
    """
    threshold = datetime.utcnow() - timedelta(days=days)
    stale = db.query(Notification).filter(
        Notification.read == True,  # noqa: E712
        Notification.created_at < threshold,
    )
    return stale.delete()
@@ -204,71 +294,118 @@ def notify_test_state_change(db: Session, test, new_state: str) -> None:
- rejected -> notify creator
- validated -> notify creator
"""
# Assign test_name = test.name
test_name = test.name
# Assign test_id = test.id
test_id = test.id
# Assign creator_id = test.created_by
creator_id = test.created_by
# Check: new_state == "red_executing" and creator_id
if new_state == "red_executing" and creator_id:
# Call create_notification()
create_notification(
db,
# Keyword argument: user_id
user_id=creator_id,
# Keyword argument: type
type="test_state_changed",
# Keyword argument: title
title="Test execution started",
# Keyword argument: message
message=f'Your test "{test_name}" has moved to execution phase.',
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test_id,
)
# Alternative: new_state == "blue_evaluating"
elif new_state == "blue_evaluating":
# Notify all blue_tech users
blue_users = db.query(User).filter(User.role == "blue_tech", User.is_active == True).all() # noqa: E712
# Iterate over blue_users
for user in blue_users:
# Call create_notification()
create_notification(
db,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: type
type="test_assigned",
# Keyword argument: title
title="New test ready for blue evaluation",
# Keyword argument: message
message=f'Test "{test_name}" needs blue team evaluation.',
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test_id,
)
# Alternative: new_state == "in_review"
elif new_state == "in_review":
# Notify red_lead and blue_lead users
managers = (
db.query(User)
# Chain .filter() call
.filter(User.role.in_(["red_lead", "blue_lead"]), User.is_active == True) # noqa: E712
# Chain .all() call
.all()
)
# Iterate over managers
for user in managers:
# Call create_notification()
create_notification(
db,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: type
type="validation_needed",
# Keyword argument: title
title="Test ready for validation",
# Keyword argument: message
message=f'Test "{test_name}" is awaiting your review.',
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test_id,
)
# Alternative: new_state == "rejected" and creator_id
elif new_state == "rejected" and creator_id:
# Call create_notification()
create_notification(
db,
# Keyword argument: user_id
user_id=creator_id,
# Keyword argument: type
type="test_rejected",
# Keyword argument: title
title="Test rejected",
# Keyword argument: message
message=f'Your test "{test_name}" has been rejected. Please review and resubmit.',
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test_id,
)
# Alternative: new_state == "validated" and creator_id
elif new_state == "validated" and creator_id:
# Call create_notification()
create_notification(
db,
# Keyword argument: user_id
user_id=creator_id,
# Keyword argument: type
type="test_validated",
# Keyword argument: title
title="Test validated",
# Keyword argument: message
message=f'Your test "{test_name}" has been validated successfully.',
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test_id,
)
@@ -3,26 +3,45 @@
Calculates security operations KPIs from test data and audit logs.
"""
# Import datetime, timedelta from datetime
from datetime import datetime, timedelta
# Import Optional from typing
from typing import Optional
from sqlalchemy import func, case, and_, or_, extract
# Import func from sqlalchemy
from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
from app.models.test import Test
from app.models.technique import Technique
from app.models.test_detection_result import TestDetectionResult
# Import AuditLog from app.models.audit
from app.models.audit import AuditLog
from app.models.enums import TestState, TestResult
# Import TestResult, TestState from app.models.enums
from app.models.enums import TestResult, TestState
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestDetectionResult from app.models.test_detection_result
from app.models.test_detection_result import TestDetectionResult
# Define function _safe_stats
def _safe_stats(values: list[float]) -> dict:
"""Compute mean, median, min, max from a list of floats (in hours).
For sub-hour averages, mean_hours is stored as minutes to avoid
rounding to 0.0 which is falsy in JavaScript."""
if not values:
# Return None
return None
# Assign sorted_vals = sorted(values)
sorted_vals = sorted(values)
# Assign n = len(sorted_vals)
n = len(sorted_vals)
mean = sum(sorted_vals) / n
# Use minutes for sub-hour values to avoid JS falsy 0.0
@@ -31,8 +50,11 @@ def _safe_stats(values: list[float]) -> dict:
"mean_hours": mean_display,
"unit": "min" if mean < 1 else "hrs",
"median_hours": round(sorted_vals[n // 2], 1),
# Literal argument value
"min_hours": round(sorted_vals[0], 1),
# Literal argument value
"max_hours": round(sorted_vals[-1], 1),
# Literal argument value
"sample_size": n,
}
@@ -59,6 +81,7 @@ def calculate_mttd(db: Session) -> Optional[dict]:
.all()
)
# Assign detection_times = []
detection_times = []
for t in tests:
gross_secs = (t.blue_started_at - t.red_started_at).total_seconds()
@@ -66,6 +89,7 @@ def calculate_mttd(db: Session) -> Optional[dict]:
if net_secs > 0:
detection_times.append(net_secs / 3600)
# Return _safe_stats(detection_times)
return _safe_stats(detection_times)
@@ -83,14 +107,17 @@ def calculate_mttr(db: Session) -> Optional[dict]:
"""
tests = (
db.query(Test)
# Chain .filter() call
.filter(
Test.state == TestState.validated,
Test.red_started_at.isnot(None),
Test.blue_validated_at.isnot(None),
)
# Chain .all() call
.all()
)
# Assign response_times = []
response_times = []
for t in tests:
gross_secs = (t.blue_validated_at - t.red_started_at).total_seconds()
@@ -99,6 +126,7 @@ def calculate_mttr(db: Session) -> Optional[dict]:
if net_secs > 0:
response_times.append(net_secs / 3600)
# Return _safe_stats(response_times)
return _safe_stats(response_times)
@@ -106,34 +134,63 @@ def calculate_mttr(db: Session) -> Optional[dict]:
def calculate_detection_efficacy(db: Session) -> dict:
    """Measure how many validated tests were detected.

    All tests in the ``validated`` state are loaded and tallied by their
    ``detection_result``. The percentage is ``detected / total * 100``
    rounded to one decimal; partially detected tests do not count towards it.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict: Contains ``percentage``, ``detected``, ``partially``,
        ``not_detected``, and ``total``. Every value is ``0`` when there are
        no validated tests.
    """
    validated = db.query(Test).filter(Test.state == TestState.validated).all()

    if not validated:
        return dict.fromkeys(
            ("percentage", "detected", "partially", "not_detected", "total"), 0
        )

    def _tally(outcome) -> int:
        # Number of validated tests whose detection result equals *outcome*.
        return sum(1 for item in validated if item.detection_result == outcome)

    total = len(validated)
    n_detected = _tally(TestResult.detected)
    n_partial = _tally(TestResult.partially_detected)
    n_missed = _tally(TestResult.not_detected)

    return {
        "percentage": round((n_detected / total) * 100, 1),
        "detected": n_detected,
        "partially": n_partial,
        "not_detected": n_missed,
        "total": total,
    }
@@ -142,25 +199,45 @@ def calculate_detection_efficacy(db: Session) -> dict:
def calculate_alert_fidelity(db: Session) -> dict:
    """Measure the share of evaluated detection results that triggered.

    A detection result counts as *evaluated* when its ``triggered`` column is
    not NULL. The percentage is ``triggered / total_evaluated * 100`` rounded
    to one decimal, or ``0`` when nothing has been evaluated.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict: Contains ``percentage``, ``triggered``, ``not_triggered``,
        and ``total_evaluated``.
    """

    def _count_where(criterion) -> int:
        # A falsy scalar result is normalised to 0.
        return (
            db.query(func.count(TestDetectionResult.id)).filter(criterion).scalar()
            or 0
        )

    evaluated = _count_where(TestDetectionResult.triggered.isnot(None))
    fired = _count_where(TestDetectionResult.triggered == True)  # noqa: E712

    if evaluated > 0:
        pct = round((fired / evaluated) * 100, 1)
    else:
        pct = 0

    return {
        "percentage": pct,
        "triggered": fired,
        "not_triggered": evaluated - fired,
        "total_evaluated": evaluated,
    }
@@ -169,46 +246,78 @@ def calculate_alert_fidelity(db: Session) -> dict:
def calculate_coverage_velocity(db: Session) -> dict:
    """Calculate techniques reviewed per week and the recent trend.

    Techniques whose ``last_review_date`` falls within the last 12 weeks are
    bucketed per week with ``date_trunc('week', ...)``. The query filters on
    the review date only, not on the technique status. Weeks without any
    reviewed technique produce no bucket, so the average is taken over the
    weeks that do appear.

    The trend compares the mean of the most recent (up to four) buckets with
    the mean of the four buckets before them; with fewer than eight buckets
    the first half of the series is used as the earlier window. A change of
    more than 10 % either way is reported as ``"improving"`` or
    ``"declining"``.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict: Contains ``techniques_per_week`` (mean bucket size rounded to
        one decimal, ``0`` when there is no data) and ``trend``
        (``"improving"``, ``"stable"``, or ``"declining"``).
    """
    twelve_weeks_ago = datetime.utcnow() - timedelta(weeks=12)

    weekly_counts = (
        db.query(
            func.date_trunc("week", Technique.last_review_date).label("week"),
            func.count(Technique.id).label("count"),
        )
        .filter(
            Technique.last_review_date >= twelve_weeks_ago,
            Technique.last_review_date.isnot(None),
        )
        .group_by(func.date_trunc("week", Technique.last_review_date))
        .order_by("week")
        .all()
    )

    if not weekly_counts:
        return {
            "techniques_per_week": 0,
            "trend": "stable",
        }

    # Unpack positionally instead of reading ``row.count``: on SQLAlchemy
    # 1.4+/2.0 result rows expose the sequence method ``Row.count``, which
    # shadows a column labelled "count" and would break the sum below.
    counts = [week_total for _week, week_total in weekly_counts]
    avg_per_week = round(sum(counts) / len(counts), 1)

    # ``counts[-4:]`` is the whole list when fewer than four buckets exist.
    recent = counts[-4:]
    if len(counts) >= 8:
        earlier = counts[-8:-4]
    else:
        earlier = counts[: len(counts) // 2]

    recent_avg = sum(recent) / len(recent)
    # A single bucket leaves the earlier window empty; its mean is then 0.
    earlier_avg = sum(earlier) / len(earlier) if earlier else 0

    if recent_avg > earlier_avg * 1.1:
        trend = "improving"
    elif recent_avg < earlier_avg * 0.9:
        trend = "declining"
    else:
        trend = "stable"

    return {
        "techniques_per_week": avg_per_week,
        "trend": trend,
    }
@@ -264,6 +373,7 @@ def calculate_validation_throughput(db: Session) -> dict:
else:
trend = "stable"
# Return {
return {
"tests_per_week": conversion_rate, # reuse key for API compat
"conversion_rate": conversion_rate,
@@ -278,51 +388,84 @@ def calculate_validation_throughput(db: Session) -> dict:
def calculate_rejection_rate(db: Session) -> dict:
    """Compute test rejection percentages, overall and per validating lead.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict: ``percentage`` is the share of rejected tests among tests whose
            state is validated or rejected. ``by_red_lead`` and
            ``by_blue_lead`` are the share of ``"rejected"`` among tests whose
            red / blue validation status is ``"approved"`` or ``"rejected"``.
            Every value is rounded to one decimal place, or ``0`` when the
            denominator is zero.
    """

    def _count_where(criterion) -> int:
        # One COUNT query per criterion; a None scalar is reported as 0.
        return db.query(func.count(Test.id)).filter(criterion).scalar() or 0

    def _share(part, whole):
        # Percentage with one decimal; guards against division by zero.
        return round((part / whole) * 100, 1) if whole > 0 else 0

    decided_statuses = ["approved", "rejected"]

    # Overall: rejected vs. (validated + rejected) by test state.
    n_validated = _count_where(Test.state == TestState.validated)
    n_rejected = _count_where(Test.state == TestState.rejected)

    # Red lead: rejected vs. every test the red lead has decided on.
    n_red_rejected = _count_where(Test.red_validation_status == "rejected")
    n_red_decided = _count_where(Test.red_validation_status.in_(decided_statuses))

    # Blue lead: same computation on the blue validation status.
    n_blue_rejected = _count_where(Test.blue_validation_status == "rejected")
    n_blue_decided = _count_where(Test.blue_validation_status.in_(decided_statuses))

    return {
        "percentage": _share(n_rejected, n_validated + n_rejected),
        "by_red_lead": _share(n_red_rejected, n_red_decided),
        "by_blue_lead": _share(n_blue_rejected, n_blue_decided),
    }
@@ -331,14 +474,31 @@ def calculate_rejection_rate(db: Session) -> dict:
def get_all_operational_metrics(db: Session) -> dict:
    """Collect every operational metric into one response dict.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict: Keys ``mttd``, ``mttr``, ``detection_efficacy``,
            ``alert_fidelity``, ``coverage_velocity``,
            ``validation_throughput`` and ``rejection_rate``, each holding
            the result of the corresponding ``calculate_*`` function.
    """
    # (response key, calculator) pairs; evaluated in this order.
    calculators = (
        ("mttd", calculate_mttd),
        ("mttr", calculate_mttr),
        ("detection_efficacy", calculate_detection_efficacy),
        ("alert_fidelity", calculate_alert_fidelity),
        ("coverage_velocity", calculate_coverage_velocity),
        ("validation_throughput", calculate_validation_throughput),
        ("rejection_rate", calculate_rejection_rate),
    )
    return {key: calculate(db) for key, calculate in calculators}
@@ -347,44 +507,77 @@ def get_all_operational_metrics(db: Session) -> dict:
def get_operational_trend(db: Session, period: str = "90d") -> list:
    """Return weekly trend data for operational metrics.

    The lookback window is walked in 7-day steps (the last step is clipped to
    "now"). For each step the detection efficacy is computed cumulatively
    over all validated tests whose ``red_validated_at`` is at or before the
    end of that step.

    Args:
        db (Session): Active SQLAlchemy database session.
        period (str): Lookback period; ``"30d"`` or ``"1y"``. Any other
            value falls back to 90 days.

    Returns:
        list: One dict per step with ``date`` (step start, ``YYYY-MM-DD``),
            ``detection_efficacy`` (percentage, one decimal, ``0`` when no
            tests qualify), ``validated_tests`` and ``detected_tests``.
    """
    lookback_days = {"30d": 30, "1y": 365}.get(period, 90)
    now = datetime.utcnow()
    current = now - timedelta(days=lookback_days)

    data_points = []
    while current < now:
        week_end = min(current + timedelta(days=7), now)

        # Count in SQL instead of loading every validated Test row into
        # memory on each iteration just to take len() of the result.
        validated_up_to = db.query(func.count(Test.id)).filter(
            Test.state == TestState.validated,
            Test.red_validated_at <= week_end,
        )
        total = validated_up_to.scalar() or 0
        detected = (
            validated_up_to
            .filter(Test.detection_result == TestResult.detected)
            .scalar()
        ) or 0

        efficacy = round((detected / total) * 100, 1) if total > 0 else 0
        data_points.append({
            "date": current.strftime("%Y-%m-%d"),
            "detection_efficacy": efficacy,
            "validated_tests": total,
            "detected_tests": detected,
        })
        current = week_end

    return data_points
@@ -392,20 +585,33 @@ def get_operational_trend(db: Session, period: str = "90d") -> list:
def get_metrics_by_team(db: Session) -> dict:
"""Get metrics broken down by Red vs Blue team."""
"""Return metrics broken down by Red vs Blue team.
Args:
db (Session): Active SQLAlchemy database session.
Returns:
dict: Contains ``red_team`` and ``blue_team`` sub-dicts, each with
``tests_completed``, ``avg_completion_hours``, and
``rejection_rate``.
"""
# Red team metrics
red_tests_completed = (
db.query(func.count(Test.id))
# Chain .filter() call
.filter(Test.state.in_([
TestState.blue_evaluating,
TestState.in_review,
TestState.validated,
TestState.rejected,
]))
# Chain .scalar() call
.scalar()
) or 0
# Assign red_avg_time = None
red_avg_time = None
# Assign red_times = []
red_times = []
# Red team avg execution time: red_started_at → blue_started_at (net of paused)
tests_with_red = (
@@ -416,6 +622,7 @@ def get_metrics_by_team(db: Session) -> dict:
)
.all()
)
# Iterate over tests_with_red
for t in tests_with_red:
gross = (t.blue_started_at - t.red_started_at).total_seconds()
net = gross - (t.red_paused_seconds or 0)
@@ -429,11 +636,13 @@ def get_metrics_by_team(db: Session) -> dict:
# Blue team: count tests that reached the blue evaluation phase
blue_tests_completed = (
db.query(func.count(Test.id))
# Chain .filter() call
.filter(Test.state.in_([
TestState.in_review,
TestState.validated,
TestState.rejected,
]))
# Chain .scalar() call
.scalar()
) or 0
@@ -441,15 +650,20 @@ def get_metrics_by_team(db: Session) -> dict:
# Prefer blue_work_started_at (actual pick-up) → blue_validated_at.
# Fall back to blue_started_at if blue_work_started_at is not set.
blue_avg_time = None
# Assign blue_times = []
blue_times = []
# Assign tests_with_blue = (
tests_with_blue = (
db.query(Test)
# Chain .filter() call
.filter(
Test.blue_started_at.isnot(None),
Test.blue_validated_at.isnot(None),
)
# Chain .all() call
.all()
)
# Iterate over tests_with_blue
for t in tests_with_blue:
phase_start = t.blue_work_started_at or t.blue_started_at
gross = (t.blue_validated_at - phase_start).total_seconds()
@@ -463,15 +677,22 @@ def get_metrics_by_team(db: Session) -> dict:
red_avg_raw = sum(red_times) / len(red_times) if red_times else None
blue_avg_raw = sum(blue_times) / len(blue_times) if blue_times else None
# Return {
return {
# Literal argument value
"red_team": {
# Literal argument value
"tests_completed": red_tests_completed,
# Literal argument value
"avg_completion_hours": red_avg_time,
"avg_unit": "min" if (red_avg_raw is not None and red_avg_raw < 1) else "hrs",
"rejection_rate": calculate_rejection_rate(db)["by_red_lead"],
},
# Literal argument value
"blue_team": {
# Literal argument value
"tests_completed": blue_tests_completed,
# Literal argument value
"avg_completion_hours": blue_avg_time,
"avg_unit": "min" if (blue_avg_raw is not None and blue_avg_raw < 1) else "hrs",
"rejection_rate": calculate_rejection_rate(db)["by_blue_lead"],
+243 -10
View File
@@ -1,237 +1,433 @@
"""OSINT enrichment service — automatically discovers CVEs, advisories, and
related intelligence for MITRE ATT&CK techniques using the NVD API.
"""OSINT enrichment service — discovers CVEs, advisories, and threat intel for ATT&CK techniques via the NVD API.
Designed to run as a weekly background job. Respects NVD rate limits
(5 requests per 30 seconds without an API key, 50/30s with a key).
"""
# Import logging
import logging
# Import time
import time
# Import Optional from typing
from typing import Optional
# Import UUID from uuid
from uuid import UUID
# Import requests
import requests
# Import func from sqlalchemy
from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import settings from app.config
from app.config import settings
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import OsintItem from app.models.osint_item
from app.models.osint_item import OsintItem
# Import Technique from app.models.technique
from app.models.technique import Technique
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Assign NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0"
NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0"
# Assign NVD_RATE_LIMIT_BATCH = 5
NVD_RATE_LIMIT_BATCH = 5
# Assign NVD_RATE_LIMIT_WAIT = 31 # seconds to wait after each batch
NVD_RATE_LIMIT_WAIT = 31 # seconds to wait after each batch
# Define function enrich_technique_with_cves
def enrich_technique_with_cves(db: Session, technique: Technique) -> int:
    """Search the NVD API for CVEs related to a technique and store new ones.

    The technique name is used as the NVD keyword search (first 10 results).
    Each CVE is stored as an ``OsintItem`` with ``source_type="cve"``; a CVE
    is skipped when an item with the same NVD detail URL already exists for
    the technique, so re-runs are safe. When at least one item is added the
    technique is flagged ``review_required`` and the session is committed.

    Errors never propagate: HTTP failures and unexpected exceptions are
    logged and reported as ``0`` new items.

    Args:
        db (Session): Active SQLAlchemy database session.
        technique (Technique): The ATT&CK technique to enrich.

    Returns:
        int: Number of new CVE items added to the database.
    """
    try:
        headers = {}
        if getattr(settings, "NVD_API_KEY", ""):
            headers["apiKey"] = settings.NVD_API_KEY

        resp = requests.get(
            NVD_API_BASE,
            params={
                "keywordSearch": technique.name,
                "resultsPerPage": 10,
            },
            headers=headers,
            timeout=30,
        )
        if resp.status_code != 200:
            logger.warning(
                "NVD API error for %s: HTTP %d",
                technique.mitre_id,
                resp.status_code,
            )
            return 0

        data = resp.json()
        count = 0
        seen_in_response = set()

        for vuln in data.get("vulnerabilities", []):
            cve = vuln.get("cve", {})
            cve_id = cve.get("id")
            if not cve_id or cve_id in seen_in_response:
                continue
            seen_in_response.add(cve_id)

            source_url = f"https://nvd.nist.gov/vuln/detail/{cve_id}"

            # Deduplicate on the exact URL. A substring match would treat
            # "CVE-2021-1234" as already present whenever "CVE-2021-12345"
            # is stored, silently dropping the shorter ID.
            exists = (
                db.query(OsintItem.id)
                .filter(
                    OsintItem.technique_id == technique.id,
                    OsintItem.source_url == source_url,
                )
                .first()
            )
            if exists:
                continue

            # First English description; tolerate entries missing a key.
            desc = next(
                (
                    d.get("value", "")
                    for d in cve.get("descriptions", [])
                    if d.get("lang") == "en"
                ),
                "",
            )

            # Prefer CVSS v3.1 metrics and fall back to v3.0.
            metrics = cve.get("metrics", {})
            cvss_entries = (
                metrics.get("cvssMetricV31") or metrics.get("cvssMetricV30") or []
            )
            cvss_data = cvss_entries[0].get("cvssData", {}) if cvss_entries else {}
            severity = cvss_data.get("baseSeverity", "UNKNOWN")
            score = cvss_data.get("baseScore")

            db.add(
                OsintItem(
                    technique_id=technique.id,
                    source_type="cve",
                    source_url=source_url,
                    title=cve_id,
                    description=desc[:500] if desc else None,
                    severity=severity,
                    metadata_={"cvss_score": score, "cve_id": cve_id},
                )
            )
            count += 1

        if count > 0:
            technique.review_required = True
            db.commit()
            logger.info("Added %d CVEs for %s", count, technique.mitre_id)

        return count

    except requests.RequestException as e:
        # Raised by the request or by JSON decoding, i.e. before anything
        # was staged in the session, so there is nothing to roll back.
        logger.error(
            "HTTP error during OSINT enrichment for %s: %s",
            technique.mitre_id,
            e,
        )
        return 0
    except Exception as e:
        logger.error(
            "OSINT enrichment failed for %s: %s",
            technique.mitre_id,
            e,
            exc_info=True,
        )
        # Discard half-staged items: otherwise they would be committed with
        # the next technique processed on this session, and a failed
        # flush/commit leaves the session unusable until rolled back.
        db.rollback()
        return 0
# Define function enrich_all_techniques
def enrich_all_techniques(db: Session) -> int:
    """Run CVE enrichment over every technique, ordered by MITRE ID.

    After each group of ``NVD_RATE_LIMIT_BATCH`` techniques the function
    sleeps for ``NVD_RATE_LIMIT_WAIT`` seconds, except after the final
    technique, to stay under the NVD request limits.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        int: Total number of new OSINT items added across all techniques.
    """
    techniques = db.query(Technique).order_by(Technique.mitre_id).all()
    technique_count = len(techniques)

    logger.info(
        "Starting OSINT enrichment for %d techniques...",
        technique_count,
    )

    added = 0
    for position, technique in enumerate(techniques, start=1):
        added += enrich_technique_with_cves(db, technique)

        batch_finished = position % NVD_RATE_LIMIT_BATCH == 0
        more_to_process = position < technique_count
        if batch_finished and more_to_process:
            logger.debug(
                "Rate limit pause after %d techniques (%ds)...",
                position,
                NVD_RATE_LIMIT_WAIT,
            )
            time.sleep(NVD_RATE_LIMIT_WAIT)

    logger.info(
        "OSINT enrichment complete — %d new items across %d techniques",
        added,
        technique_count,
    )
    return added
# Define function get_osint_items_for_technique
def get_osint_items_for_technique(
    db: Session,
    technique_id: str,
    source_type: str | None = None,
    reviewed: bool | None = None,
) -> list[OsintItem]:
    """Fetch the OSINT items of one technique, newest first.

    Args:
        db (Session): Active SQLAlchemy database session.
        technique_id (str): Identifier of the technique to query.
        source_type (str | None): When truthy, keep only items of this
            source type (e.g. ``"cve"``).
        reviewed (bool | None): ``True`` / ``False`` restricts to reviewed /
            unreviewed items; ``None`` applies no review filter.

    Returns:
        list[OsintItem]: Matching items ordered by ``discovered_at``
            descending.
    """
    criteria = [OsintItem.technique_id == technique_id]
    if source_type:
        criteria.append(OsintItem.source_type == source_type)
    if reviewed is not None:
        criteria.append(OsintItem.reviewed == reviewed)

    newest_first = OsintItem.discovered_at.desc()
    return db.query(OsintItem).filter(*criteria).order_by(newest_first).all()
# Define function mark_osint_reviewed
def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
    """Flag an OSINT item as reviewed without committing.

    The change is only applied to the loaded object; committing is left to
    the caller (UnitOfWork).

    Args:
        db (Session): Active SQLAlchemy database session.
        item_id (str): Identifier of the OSINT item to mark.

    Returns:
        OsintItem | None: The updated item, or ``None`` if not found.
    """
    matching = db.query(OsintItem).filter(OsintItem.id == item_id)
    item = matching.first()
    if not item:
        return item
    item.reviewed = True
    return item
# Define function get_unreviewed_count
def get_unreviewed_count(db: Session) -> int:
    """Count the OSINT items that have not been reviewed yet.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        int: Number of OSINT items whose ``reviewed`` flag equals ``False``.
    """
    # SQLAlchemy needs ``==`` here to build the SQL comparison.
    not_reviewed = OsintItem.reviewed == False  # noqa: E712
    unreviewed_items = db.query(OsintItem).filter(not_reviewed)
    return unreviewed_items.count()
# Define function list_osint_items
def list_osint_items(
# Entry: db
db: Session,
*,
# Entry: technique_id
technique_id: Optional[UUID] = None,
# Entry: source_type
source_type: Optional[str] = None,
# Entry: reviewed
reviewed: Optional[bool] = None,
# Entry: offset
offset: int = 0,
# Entry: limit
limit: int = 50,
) -> dict:
"""List OSINT items with optional filters and pagination."""
"""List OSINT items with optional filters and pagination.
Args:
db (Session): Active SQLAlchemy database session.
technique_id (Optional[UUID]): Filter by technique UUID.
source_type (Optional[str]): Filter by source type string (e.g.
``"cve"``).
reviewed (Optional[bool]): Filter by reviewed status; ``None``
returns all.
offset (int): Number of records to skip for pagination.
limit (int): Maximum number of records to return.
Returns:
dict: Contains ``total`` count and ``items`` list of serialized OSINT
item dicts.
"""
# Assign query = db.query(OsintItem)
query = db.query(OsintItem)
# Check: technique_id
if technique_id:
# Assign query = query.filter(OsintItem.technique_id == technique_id)
query = query.filter(OsintItem.technique_id == technique_id)
# Check: source_type
if source_type:
# Assign query = query.filter(OsintItem.source_type == source_type)
query = query.filter(OsintItem.source_type == source_type)
# Check: reviewed is not None
if reviewed is not None:
# Assign query = query.filter(OsintItem.reviewed == reviewed)
query = query.filter(OsintItem.reviewed == reviewed)
# Assign total = query.count()
total = query.count()
# Assign items = (
items = (
query.order_by(OsintItem.discovered_at.desc())
# Chain .offset() call
.offset(offset)
# Chain .limit() call
.limit(limit)
# Chain .all() call
.all()
)
# Return {
return {
# Literal argument value
"total": total,
# Literal argument value
"items": [
{
# Literal argument value
"id": str(item.id),
# Literal argument value
"technique_id": str(item.technique_id),
# Literal argument value
"source_type": item.source_type,
# Literal argument value
"source_url": item.source_url,
# Literal argument value
"title": item.title,
# Literal argument value
"description": item.description,
# Literal argument value
"severity": item.severity,
# Literal argument value
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
# Literal argument value
"reviewed": item.reviewed,
# Literal argument value
"metadata": item.metadata_,
}
for item in items
@@ -239,39 +435,76 @@ def list_osint_items(
}
# Define function get_osint_summary
def get_osint_summary(db: Session) -> dict:
    """Aggregate headline statistics over all OSINT items.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict: ``total_items`` and ``unreviewed`` counts,
            ``techniques_with_items`` (number of distinct techniques that
            have at least one item), plus ``by_severity`` and ``by_type``
            mappings from severity / source type to item count.
    """

    def _item_counts_by(column) -> dict:
        # GROUP BY the column and turn the (value, count) rows into a dict.
        grouped = db.query(column, func.count(OsintItem.id)).group_by(column)
        return dict(grouped.all())

    total_items = db.query(func.count(OsintItem.id)).scalar() or 0
    unreviewed = get_unreviewed_count(db)
    severity_counts = _item_counts_by(OsintItem.severity)
    type_counts = _item_counts_by(OsintItem.source_type)

    distinct_techniques = func.count(func.distinct(OsintItem.technique_id))
    techniques_with_items = db.query(distinct_techniques).scalar() or 0

    return {
        "total_items": total_items,
        "unreviewed": unreviewed,
        "techniques_with_items": techniques_with_items,
        "by_severity": severity_counts,
        "by_type": type_counts,
    }
# Define function get_technique_or_raise
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
    """Look up a technique by primary key, failing loudly when it is absent.

    Args:
        db (Session): Active SQLAlchemy database session.
        technique_id (UUID): UUID of the technique to retrieve.

    Returns:
        Technique: The matching technique ORM object.

    Raises:
        EntityNotFoundError: If no technique has the given ID.
    """
    by_id = db.query(Technique).filter(Technique.id == technique_id)
    found = by_id.first()
    if found:
        return found
    raise EntityNotFoundError("Technique", str(technique_id))
+59 -3
View File
@@ -3,95 +3,151 @@
Uses WeasyPrint for PDF generation and docxtpl for DOCX.
"""
import os
import uuid
# Import logging
import logging
# Import os
import os
# Import uuid
import uuid
# Import datetime from datetime
from datetime import datetime
# Import Environment, FileSystemLoader from jinja2
from jinja2 import Environment, FileSystemLoader
# Import settings from app.config
from app.config import settings
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Define class ReportEngine
class ReportEngine:
    """Template-based report generator supporting PDF, DOCX, and HTML output.

    Templates are read from ``settings.REPORT_TEMPLATES_DIR``. Every
    generated file is written to ``settings.REPORT_OUTPUT_DIR`` and named
    ``<template_name>_<8 hex chars>.<extension>``.
    """

    def __init__(self) -> None:
        """Initialise the Jinja2 environment and ensure the output directory exists."""
        self.jinja_env = Environment(
            loader=FileSystemLoader(settings.REPORT_TEMPLATES_DIR),
            autoescape=True,
        )
        os.makedirs(settings.REPORT_OUTPUT_DIR, exist_ok=True)

    @staticmethod
    def _with_defaults(context: dict, timestamp_format: str) -> dict:
        """Return a copy of *context* with default report fields filled in.

        ``company_name`` and ``generated_at`` are added only when the caller
        did not supply them. The caller's dict is left untouched, so one
        context can be reused across formats without an earlier call's
        ``generated_at`` leaking into a later one.
        """
        return {
            "company_name": settings.COMPANY_NAME,
            "generated_at": datetime.utcnow().strftime(timestamp_format),
            **context,
        }

    @staticmethod
    def _output_path(template_name: str, extension: str) -> str:
        """Build a unique output file path for the given template and extension."""
        filename = f"{template_name}_{uuid.uuid4().hex[:8]}.{extension}"
        return os.path.join(settings.REPORT_OUTPUT_DIR, filename)

    def render_html(self, template_name: str, context: dict) -> str:
        """Render ``<template_name>.html`` to an HTML string.

        Args:
            template_name: Template file name without the ``.html`` suffix.
            context: Template variables; not modified.

        Returns:
            The rendered HTML.
        """
        template = self.jinja_env.get_template(f"{template_name}.html")
        return template.render(
            self._with_defaults(context, "%B %d, %Y %H:%M UTC")
        )

    def generate_pdf(self, template_name: str, context: dict) -> str:
        """Render the HTML template and convert it to PDF with WeasyPrint.

        ``styles/report.css`` inside the templates directory is applied when
        that file exists.

        Returns:
            Path of the generated PDF file.
        """
        # Imported lazily so the module loads without WeasyPrint installed.
        from weasyprint import CSS, HTML

        html_content = self.render_html(template_name, context)
        css_path = os.path.join(settings.REPORT_TEMPLATES_DIR, "styles", "report.css")
        output_path = self._output_path(template_name, "pdf")

        stylesheets = []
        if os.path.exists(css_path):
            stylesheets.append(CSS(filename=css_path))

        HTML(
            string=html_content,
            base_url=settings.REPORT_TEMPLATES_DIR,
        ).write_pdf(output_path, stylesheets=stylesheets)

        logger.info("PDF generated: %s", output_path)
        return output_path

    def generate_docx(self, template_name: str, context: dict) -> str:
        """Render ``<template_name>.docx`` with docxtpl.

        Args:
            template_name: Template file name without the ``.docx`` suffix.
            context: Template variables; not modified.

        Returns:
            Path of the generated DOCX file.
        """
        # Imported lazily so the module loads without docxtpl installed.
        from docxtpl import DocxTemplate

        template_path = os.path.join(
            settings.REPORT_TEMPLATES_DIR, f"{template_name}.docx"
        )
        output_path = self._output_path(template_name, "docx")

        doc = DocxTemplate(template_path)
        doc.render(self._with_defaults(context, "%B %d, %Y"))
        doc.save(output_path)

        logger.info("DOCX generated: %s", output_path)
        return output_path

    def generate_html(self, template_name: str, context: dict) -> str:
        """Render and save a standalone HTML report (alias for spec compatibility)."""
        return self.generate_html_file(template_name, context)

    def generate_html_file(self, template_name: str, context: dict) -> str:
        """Render the HTML template and save it as a standalone file.

        Returns:
            Path of the generated HTML file.
        """
        html_content = self.render_html(template_name, context)
        output_path = self._output_path(template_name, "html")

        with open(output_path, "w", encoding="utf-8") as f:
            f.write(html_content)

        logger.info("HTML report generated: %s", output_path)
        return output_path
# Assign report_engine = ReportEngine()
report_engine = ReportEngine()
@@ -1,115 +1,195 @@
"""High-level report generation — collects domain data and delegates to ReportEngine."""
# Import logging
import logging
# Import datetime, timedelta from datetime
from datetime import datetime, timedelta
# Import UUID from uuid
from uuid import UUID
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.exceptions
from app.domain.exceptions import EntityNotFoundError
# Import Campaign, CampaignTest from app.models.campaign
from app.models.campaign import Campaign, CampaignTest
# Import CoverageSnapshot from app.models.coverage_snapshot
from app.models.coverage_snapshot import CoverageSnapshot
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import ThreatActor from app.models.threat_actor
from app.models.threat_actor import ThreatActor
# Import report_engine from app.services.report_engine
from app.services.report_engine import report_engine
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Define function generate_purple_campaign_report
def generate_purple_campaign_report(
    db: Session,
    campaign_id: str,
    output_format: str = "pdf",
) -> str:
    """Generate the full Purple Team campaign report.

    Collects the campaign, its tests (with the technique of each test), the
    detection outcome counts over validated tests, one critical finding per
    test whose detection result is ``not_detected``, the organisation score
    and the campaign's threat actor, then delegates rendering to
    ``_generate`` with the ``"purple_campaign"`` template.

    Args:
        db (Session): Active SQLAlchemy database session.
        campaign_id (str): Campaign identifier; a ``UUID`` instance or its
            string form.
        output_format (str): Output format forwarded to ``_generate``.

    Returns:
        str: The value returned by ``_generate`` for the rendered report.

    Raises:
        EntityNotFoundError: If no campaign has the given ID.
        ValueError: If ``campaign_id`` is not a valid UUID string.
    """
    cid = campaign_id if isinstance(campaign_id, UUID) else UUID(str(campaign_id))
    campaign = db.query(Campaign).filter(Campaign.id == cid).first()
    if not campaign:
        raise EntityNotFoundError("Campaign", campaign_id)

    campaign_tests = (
        db.query(Test)
        .join(CampaignTest, CampaignTest.test_id == Test.id)
        .filter(CampaignTest.campaign_id == cid)
        .all()
    )

    # Load all referenced techniques in one query instead of one per test.
    technique_ids = {
        t.technique_id for t in campaign_tests if t.technique_id is not None
    }
    techniques_by_id = {}
    if technique_ids:
        techniques_by_id = {
            tech.id: tech
            for tech in (
                db.query(Technique)
                .filter(Technique.id.in_(list(technique_ids)))
                .all()
            )
        }

    tests_data = []
    for test in campaign_tests:
        technique = techniques_by_id.get(test.technique_id)
        tests_data.append({
            "technique_mitre_id": technique.mitre_id if technique else "N/A",
            "name": test.name,
            "tactic": technique.tactic if technique else "N/A",
            "state": test.state.value if test.state else "draft",
            "detection_result": (
                test.detection_result.value if test.detection_result else "pending"
            ),
        })

    validated = [t for t in campaign_tests if t.state and t.state.value == "validated"]
    detected = [
        t for t in validated
        if t.detection_result and t.detection_result.value == "detected"
    ]
    not_detected = [
        t for t in validated
        if t.detection_result and t.detection_result.value == "not_detected"
    ]

    # Findings cover every test that was not detected, regardless of state.
    critical_findings = [
        {
            "technique_id": t["technique_mitre_id"],
            "name": t["name"],
            "severity": "critical",
            "description": "Technique was not detected during campaign execution.",
            "recommendation": "Implement detection rule or review existing SIEM/EDR configuration.",
        }
        for t in tests_data
        if t["detection_result"] == "not_detected"
    ]

    org_score = _safe_org_score(db)

    threat_actors = []
    if campaign.threat_actor_id:
        actor = db.query(ThreatActor).filter(ThreatActor.id == campaign.threat_actor_id).first()
        if actor:
            threat_actors = [{"name": actor.name}]

    context = {
        "campaign": campaign,
        "tests": tests_data,
        "tests_validated": len(validated),
        "tests_detected": len(detected),
        "tests_not_detected": len(not_detected),
        "critical_findings": critical_findings,
        "org_score": org_score.get("overall", 0),
        # De-duplicate in first-seen order so the list is identical between
        # runs (a set has no stable iteration order across processes).
        "tactics": list(dict.fromkeys(t["tactic"] for t in tests_data)),
        "threat_actors": threat_actors,
    }

    return _generate(output_format, "purple_campaign", context)
# Define function generate_coverage_report
def generate_coverage_report(
# Entry: db
db: Session,
# Entry: output_format
output_format: str = "pdf",
) -> str:
"""Generate an organization-wide MITRE ATT&CK coverage report."""
from sqlalchemy import func, case
# Import case, func from sqlalchemy
from sqlalchemy import case, func
# Assign org_score = _safe_org_score(db)
org_score = _safe_org_score(db)
# Assign techniques = db.query(Technique).all()
techniques = db.query(Technique).all()
# Assign status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, ...
status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, "not_evaluated": 0}
# Iterate over techniques
for t in techniques:
# Assign s = t.status_global.value if t.status_global else "not_evaluated"
s = t.status_global.value if t.status_global else "not_evaluated"
# Check: s in status_counts
if s in status_counts:
# Assign status_counts[s] = 1
status_counts[s] += 1
# Assign summary = {
summary = {
# Literal argument value
"total_techniques": len(techniques),
**status_counts,
}
@@ -121,14 +201,21 @@ def generate_coverage_report(
func.count(Technique.id).label("total"),
func.sum(case((Technique.status_global == "validated", 1), else_=0)).label("validated"),
)
# Chain .group_by() call
.group_by(Technique.tactic)
# Chain .all() call
.all()
)
# Assign tactics_coverage = [
tactics_coverage = [
{
# Literal argument value
"tactic": r[0] or "Unknown",
# Literal argument value
"total": r[1],
# Literal argument value
"validated": int(r[2]),
# Literal argument value
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
}
for r in tactic_rows
@@ -136,213 +223,326 @@ def generate_coverage_report(
# Never-tested techniques
tested_ids = {t.technique_id for t in db.query(Test.technique_id).distinct().all()}
# Assign never_tested = [
never_tested = [
{"mitre_id": t.mitre_id, "name": t.name, "tactic": t.tactic}
for t in techniques
if t.id not in tested_ids
]
# Assign context = {
context = {
# Literal argument value
"org_score": org_score,
# Literal argument value
"summary": summary,
# Literal argument value
"tactics_coverage": tactics_coverage,
# Literal argument value
"never_tested": never_tested[:50],
}
# Return _generate(output_format, "coverage_report", context)
return _generate(output_format, "coverage_report", context)
def generate_executive_summary(
    db: Session,
    output_format: str = "pdf",
) -> str:
    """Generate an executive summary report.

    Gathers organisation-wide figures (technique status counts, test and
    campaign totals, activity over the last 90 days, open remediations,
    detection rate and the five lowest-coverage tactics) and renders them
    through ``_generate`` with the ``"executive_summary"`` template.

    Args:
        db: Database session used for every query.
        output_format: Forwarded unchanged to ``_generate``.

    Returns:
        Whatever ``_generate`` returns for the rendered report.
    """
    from sqlalchemy import case as sql_case
    from sqlalchemy import func

    org_score = _safe_org_score(db)

    # Tally techniques per global status; a missing status counts as
    # "not_evaluated" and unknown status values are ignored.
    techniques = db.query(Technique).all()
    status_counts = {
        "validated": 0,
        "partial": 0,
        "not_covered": 0,
        "in_progress": 0,
        "not_evaluated": 0,
    }
    for technique in techniques:
        status = technique.status_global.value if technique.status_global else "not_evaluated"
        if status in status_counts:
            status_counts[status] += 1
    summary = {"total_techniques": len(techniques), **status_counts}

    total_tests = db.query(func.count(Test.id)).scalar() or 0
    active_campaigns = (
        db.query(func.count(Campaign.id)).filter(Campaign.status == "active").scalar() or 0
    )

    # NOTE(review): naive UTC timestamp; presumably matches how
    # Test.created_at is stored — confirm before switching to aware datetimes.
    quarter_ago = datetime.utcnow() - timedelta(days=90)
    tests_this_quarter = (
        db.query(func.count(Test.id)).filter(Test.created_at >= quarter_ago).scalar() or 0
    )

    open_remediations = (
        db.query(func.count(Test.id))
        .filter(Test.remediation_status.in_(["pending", "in_progress"]))
        .scalar() or 0
    )

    # Detection rate among validated tests. Both sides of the ratio count
    # tests: dividing detected tests by validated *techniques* let the rate
    # exceed 100 % whenever a technique had several detected tests.
    validated_tests = (
        db.query(func.count(Test.id)).filter(Test.state == "validated").scalar() or 0
    )
    detected_count = (
        db.query(func.count(Test.id))
        .filter(Test.state == "validated", Test.detection_result == "detected")
        .scalar() or 0
    )
    detection_rate = (
        round((detected_count / validated_tests) * 100, 1) if validated_tests > 0 else 0
    )

    # Top gaps — lowest coverage tactics
    tactic_rows = (
        db.query(
            Technique.tactic,
            func.count(Technique.id).label("total"),
            func.sum(sql_case((Technique.status_global == "validated", 1), else_=0)).label("validated"),
        )
        .group_by(Technique.tactic)
        .all()
    )
    tactic_coverage = [
        {
            "tactic": r[0] or "Unknown",
            "coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
        }
        for r in tactic_rows
    ]
    top_gaps = sorted(tactic_coverage, key=lambda x: x["coverage_pct"])[:5]

    context = {
        "org_score": org_score,
        "summary": summary,
        "total_tests": total_tests,
        "active_campaigns": active_campaigns,
        "tests_this_quarter": tests_this_quarter,
        "open_remediations": open_remediations,
        "detection_rate": detection_rate,
        "top_gaps": top_gaps,
    }
    return _generate(output_format, "executive_summary", context)
def generate_quarterly_summary(
    db: Session,
    output_format: str = "pdf",
) -> str:
    """Quarterly summary — reuses executive metrics plus snapshot trend rows.

    Collects the organisation score, the number of tests created in the last
    90 days, the detection rate, the five lowest-coverage tactics and one
    trend row per ``CoverageSnapshot`` from the same 90-day window, then
    renders them through ``_generate`` with the ``"quarterly_summary"``
    template.

    Args:
        db: Database session used for every query.
        output_format: Forwarded unchanged to ``_generate``.

    Returns:
        Whatever ``_generate`` returns for the rendered report.
    """
    from sqlalchemy import case as sql_case
    from sqlalchemy import func

    org_score = _safe_org_score(db)

    # NOTE(review): naive UTC timestamp; presumably matches how created_at
    # columns are stored — confirm before switching to aware datetimes.
    now = datetime.utcnow()
    quarter_ago = now - timedelta(days=90)
    tests_this_quarter = (
        db.query(func.count(Test.id)).filter(Test.created_at >= quarter_ago).scalar() or 0
    )

    # Detection rate among validated tests. Both sides of the ratio count
    # tests: dividing detected tests by validated *techniques* let the rate
    # exceed 100 % whenever a technique had several detected tests.
    validated_tests = (
        db.query(func.count(Test.id)).filter(Test.state == "validated").scalar() or 0
    )
    detected_count = (
        db.query(func.count(Test.id))
        .filter(Test.state == "validated", Test.detection_result == "detected")
        .scalar() or 0
    )
    detection_rate = (
        round((detected_count / validated_tests) * 100, 1) if validated_tests > 0 else 0
    )

    tactic_rows = (
        db.query(
            Technique.tactic,
            func.count(Technique.id).label("total"),
            func.sum(sql_case((Technique.status_global == "validated", 1), else_=0)).label(
                "validated",
            ),
        )
        .group_by(Technique.tactic)
        .all()
    )
    # Ascending sort puts the weakest tactics first; keep the worst five.
    top_gaps = sorted(
        [
            {
                "tactic": r[0] or "Unknown",
                "coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
            }
            for r in tactic_rows
        ],
        key=lambda x: x["coverage_pct"],
    )[:5]

    snapshots = (
        db.query(CoverageSnapshot)
        .filter(CoverageSnapshot.created_at >= quarter_ago)
        .order_by(CoverageSnapshot.created_at)
        .all()
    )
    trend_rows = [
        {
            "date": s.created_at.strftime("%Y-%m-%d") if s.created_at else "",
            "validated_count": s.validated_count,
            "total_techniques": s.total_techniques,
            "organization_score": round(s.organization_score, 1),
        }
        for s in snapshots
    ]

    # Calendar quarter of "now": months 1-3 -> Q1, 4-6 -> Q2, and so on.
    quarter_label = f"Q{((now.month - 1) // 3) + 1} {now.year}"

    context = {
        "quarter_label": quarter_label,
        "org_score": org_score,
        "tests_this_quarter": tests_this_quarter,
        "detection_rate": detection_rate,
        "trend_rows": trend_rows,
        "top_gaps": top_gaps,
    }
    return _generate(output_format, "quarterly_summary", context)
def generate_technique_detail_report(
    db: Session,
    technique_id: str,
    output_format: str = "pdf",
) -> str:
    """Build the detail report for one technique and the tests that target it.

    Args:
        db: Database session used for the lookups.
        technique_id: Technique primary key; a ``UUID`` instance is used
            as-is, anything else is converted with ``UUID(str(...))``.
        output_format: Forwarded unchanged to ``_generate``.

    Returns:
        Whatever ``_generate`` returns for the ``"technique_detail"`` template.

    Raises:
        EntityNotFoundError: No technique exists with the given id.
    """
    if isinstance(technique_id, UUID):
        tid = technique_id
    else:
        tid = UUID(str(technique_id))

    technique = db.query(Technique).filter(Technique.id == tid).first()
    if not technique:
        raise EntityNotFoundError("Technique", str(technique_id))

    # Newest tests first.
    test_query = db.query(Test).filter(Test.technique_id == tid)
    related_tests = test_query.order_by(Test.created_at.desc()).all()

    tests_data = []
    for test in related_tests:
        state = test.state.value if test.state else "draft"
        detection = test.detection_result.value if test.detection_result else "pending"
        created = test.created_at.strftime("%Y-%m-%d") if test.created_at else ""
        tests_data.append(
            {
                "name": test.name,
                "state": state,
                "detection_result": detection,
                "created_at": created,
            }
        )

    if technique.status_global:
        technique_status = technique.status_global.value
    else:
        technique_status = "not_evaluated"

    context = {
        "technique": technique,
        "technique_status": technique_status,
        "tests": tests_data,
    }
    return _generate(output_format, "technique_detail", context)
@@ -351,19 +551,32 @@ def generate_technique_detail_report(
def _safe_org_score(db: Session) -> dict:
    """Return the organization score, or an all-zero score if scoring fails.

    The scoring service is imported lazily inside the ``try`` so that an
    import failure is handled the same way as a runtime failure: a warning is
    logged and a dict with ``overall``, ``coverage`` and
    ``detection_maturity`` set to 0 is returned instead.
    """
    try:
        from app.services.scoring_service import calculate_organization_score

        score = calculate_organization_score(db)
    except Exception as e:
        # Deliberately broad: report generation must not fail because the
        # score could not be computed.
        logger.warning("Scoring service unavailable: %s", e)
        score = {"overall": 0, "coverage": 0, "detection_maturity": 0}
    return score
def _generate(output_format: str, template_name: str, context: dict) -> str:
    """Render *template_name* with *context* via the matching ReportEngine method.

    ``"pdf"`` and ``"docx"`` select the dedicated generators; every other
    value falls back to the HTML file generator.
    """
    if output_format == "pdf":
        render = report_engine.generate_pdf
    elif output_format == "docx":
        render = report_engine.generate_docx
    else:
        render = report_engine.generate_html_file
    return render(template_name, context)
+53 -8
View File
@@ -7,78 +7,123 @@ Thread-safe: each worker process has its own dict, and the TTL ensures
stale data does not persist longer than ``CACHE_TTL`` seconds.
"""
# Import time
import time
# Import Any, Optional from typing
from typing import Any, Optional
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Assign CACHE_TTL = 300 # 5 minutes
CACHE_TTL = 300 # 5 minutes
# Assign _cache = {}
_cache: dict[str, dict[str, Any]] = {}
def get(key: str) -> Optional[Any]:  # noqa: ANN401  # generic cache returns whatever was stored
    """Look up *key* and return its value while the entry is still fresh.

    Returns ``None`` when the key is absent or when the entry is older than
    ``CACHE_TTL`` seconds; an expired entry is evicted on the way out.
    """
    record = _cache.get(key)
    if record is not None:
        age = time.time() - record["ts"]
        if age <= CACHE_TTL:
            return record["data"]
        # Stale: drop it so the next caller recomputes.
        _cache.pop(key, None)
    return None
def put(key: str, data: Any) -> None:  # noqa: ANN401  # generic cache accepts any serialisable value
    """Cache *data* under *key*, stamped with the current time for TTL checks."""
    stamped = {"data": data, "ts": time.time()}
    _cache[key] = stamped
def invalidate(key: Optional[str] = None) -> None:
    """Drop the entry for *key*, or every entry when *key* is ``None``.

    Removing a key that is not cached is a no-op.
    """
    if key is not None:
        _cache.pop(key, None)
        return
    _cache.clear()
# ── High-level helpers ────────────────────────────────────────────────
def get_organization_score_cached(db: Session) -> dict:
    """Return the organization score, computing and caching it on a miss.

    The value is stored under the ``"org_score"`` cache key.
    """
    from app.services.scoring_service import calculate_organization_score

    score = get("org_score")
    if score is None:
        score = calculate_organization_score(db)
        put("org_score", score)
    return score
def get_operational_metrics_cached(db: Session) -> dict:
    """Return the operational metrics bundle, computing and caching it on a miss.

    The bundle is stored under the ``"op_metrics"`` cache key and maps each
    metric name to the result of its ``calculate_*`` function.
    """
    from app.services.operational_metrics_service import (
        calculate_alert_fidelity,
        calculate_coverage_velocity,
        calculate_detection_efficacy,
        calculate_mttd,
        calculate_mttr,
        calculate_rejection_rate,
        calculate_validation_throughput,
    )

    cache_key = "op_metrics"
    metrics = get(cache_key)
    if metrics is not None:
        return metrics

    # (result key, calculator) pairs, evaluated in this order.
    calculators = (
        ("mttd", calculate_mttd),
        ("mttr", calculate_mttr),
        ("detection_efficacy", calculate_detection_efficacy),
        ("alert_fidelity", calculate_alert_fidelity),
        ("coverage_velocity", calculate_coverage_velocity),
        ("validation_throughput", calculate_validation_throughput),
        ("rejection_rate", calculate_rejection_rate),
    )
    metrics = {name: calculate(db) for name, calculate in calculators}
    put(cache_key, metrics)
    return metrics
@@ -1,121 +1,202 @@
"""Scoring configuration persistence service."""
# Enable future language features for compatibility
from __future__ import annotations
# Import uuid
import uuid
# Import Any from typing
from typing import Any
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import settings from app.config
from app.config import settings
# Import ScoringWeights from app.domain.value_objects.scoring_weights
from app.domain.value_objects.scoring_weights import ScoringWeights
# Import ScoringConfig from app.models.scoring_config
from app.models.scoring_config import ScoringConfig
# Define function _row_recency
def _row_recency(row: ScoringConfig) -> float:
# Return float(getattr(row, "weight_recency", None) or getattr(row, "weight_...
return float(getattr(row, "weight_recency", None) or getattr(row, "weight_freshness", 10.0))
# Define function _row_severity
def _row_severity(row: ScoringConfig) -> float:
# Return float(
return float(
getattr(row, "weight_severity", None)
or getattr(row, "weight_platform_diversity", 10.0)
)
def get_scoring_weights(db: Session) -> ScoringWeights:
    """Load the active scoring weights.

    The first ``ScoringConfig`` row wins when one exists; otherwise the
    weights come from ``settings``, where the recency/severity values fall
    back to the legacy FRESHNESS/PLATFORM_DIVERSITY settings.
    """
    row = db.query(ScoringConfig).first()

    if row is None:
        # No persisted configuration yet: use the environment defaults.
        env_weights = {
            "tests": float(settings.SCORING_WEIGHT_TESTS),
            "detection_rules": float(settings.SCORING_WEIGHT_DETECTION_RULES),
            "d3fend": float(settings.SCORING_WEIGHT_D3FEND),
            "recency": float(
                getattr(settings, "SCORING_WEIGHT_RECENCY", settings.SCORING_WEIGHT_FRESHNESS)
            ),
            "severity": float(
                getattr(settings, "SCORING_WEIGHT_SEVERITY", settings.SCORING_WEIGHT_PLATFORM_DIVERSITY)
            ),
        }
        return ScoringWeights(**env_weights)

    stored_weights = {
        "tests": row.weight_tests,
        "detection_rules": row.weight_detection_rules,
        "d3fend": row.weight_d3fend,
        "recency": _row_recency(row),
        "severity": _row_severity(row),
    }
    return ScoringWeights(**stored_weights)
def update_scoring_weights(
    db: Session,
    *,
    tests: float | None = None,
    detection_rules: float | None = None,
    d3fend: float | None = None,
    recency: float | None = None,
    severity: float | None = None,
    freshness: float | None = None,
    platform_diversity: float | None = None,
    updated_by: uuid.UUID | None = None,
) -> dict[str, Any]:
    """Upsert scoring weights without committing the session.

    Any weight left as ``None`` keeps its current value. ``freshness`` and
    ``platform_diversity`` are legacy aliases for ``recency`` and
    ``severity``; they are used only when the modern argument is ``None``.
    A ``ScoringConfig`` row is created and added to the session when none
    exists yet.

    Returns:
        The resulting weights in the ``_weights_dict`` format.
    """
    if recency is None and freshness is not None:
        recency = freshness
    if severity is None and platform_diversity is not None:
        severity = platform_diversity

    baseline = get_scoring_weights(db)
    merged = ScoringWeights(
        tests=baseline.tests if tests is None else tests,
        detection_rules=baseline.detection_rules if detection_rules is None else detection_rules,
        d3fend=baseline.d3fend if d3fend is None else d3fend,
        recency=baseline.recency if recency is None else recency,
        severity=baseline.severity if severity is None else severity,
    )

    config_row = db.query(ScoringConfig).first()
    if config_row is None:
        config_row = ScoringConfig()
        db.add(config_row)

    config_row.weight_tests = merged.tests
    config_row.weight_detection_rules = merged.detection_rules
    config_row.weight_d3fend = merged.d3fend

    # Write to the modern column when the model has it, else the legacy one.
    for column in ("weight_recency", "weight_freshness"):
        if hasattr(config_row, column):
            setattr(config_row, column, merged.recency)
            break
    for column in ("weight_severity", "weight_platform_diversity"):
        if hasattr(config_row, column):
            setattr(config_row, column, merged.severity)
            break

    if updated_by is not None and hasattr(config_row, "updated_by"):
        config_row.updated_by = updated_by

    return _weights_dict(merged)
def get_weights_dict(db: Session) -> dict[str, Any]:
    """Fetch the active weights and convert them with ``_weights_dict``."""
    active = get_scoring_weights(db)
    return _weights_dict(active)
# Define function _weights_dict
def _weights_dict(w: ScoringWeights) -> dict[str, Any]:
# Assign weights = {
weights = {
# Literal argument value
"tests": w.tests,
# Literal argument value
"detection_rules": w.detection_rules,
# Literal argument value
"d3fend": w.d3fend,
# Literal argument value
"recency": w.recency,
# Literal argument value
"severity": w.severity,
# Legacy keys for older clients
"freshness": w.recency,
# Literal argument value
"platform_diversity": w.severity,
}
# Return {
return {
# Literal argument value
"weights": weights,
# Literal argument value
"total": sum(
[w.tests, w.detection_rules, w.d3fend, w.recency, w.severity]
),
File diff suppressed because it is too large Load Diff
+184 -12
View File
@@ -22,24 +22,45 @@ rules are identified by ``source = "sigma"`` + ``source_id`` (relative
file path) and simply skipped.
"""
# Import io
import io
# Import logging
import logging
# Import re
import re
# Import shutil
import shutil
# Import tempfile
import tempfile
# Import zipfile
import zipfile
# Import datetime from datetime
from datetime import datetime
# Import Path from pathlib
from pathlib import Path
# Import requests
import requests as _requests
# Import yaml
import yaml
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
from app.models.detection_rule import DetectionRule
# Import DataSource from app.models.data_source
from app.models.data_source import DataSource
from app.models.technique import Technique
from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -47,22 +68,35 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
SIGMA_ZIP_URL = (
# Literal argument value
"https://github.com/SigmaHQ/sigma/archive/refs/heads/master.zip"
)
# Assign _DOWNLOAD_TIMEOUT = 300
_DOWNLOAD_TIMEOUT = 300
# Assign _ZIP_ROOT_PREFIX = "sigma-master"
_ZIP_ROOT_PREFIX = "sigma-master"
# Safety limits for ZIP extraction — prevent zip-bomb DoS
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB
# Assign _MAX_ENTRIES = 50_000
_MAX_ENTRIES = 50_000
# Regex to extract MITRE ATT&CK technique IDs from Sigma tags
# e.g. "attack.t1059.001" → "T1059.001"
_ATTACK_TAG_RE = re.compile(r"attack\.(t\d{4}(?:\.\d{3})?)", re.IGNORECASE)
# Sigma severity levels
_SEVERITY_MAP = {
# Literal argument value
"informational": "informational",
# Literal argument value
"low": "low",
# Literal argument value
"medium": "medium",
# Literal argument value
"high": "high",
# Literal argument value
"critical": "critical",
}
@@ -74,14 +108,21 @@ _SEVERITY_MAP = {
def _download_zip(url: str = SIGMA_ZIP_URL) -> bytes:
    """Fetch the SigmaHQ archive from *url* and return its raw bytes.

    The whole body is read into memory. A non-success HTTP status raises via
    ``raise_for_status()``.
    """
    logger.info("Downloading SigmaHQ ZIP from %s", url)
    response = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
    response.raise_for_status()
    payload = response.content
    size_mb = len(payload) / (1024 * 1024)
    logger.info("Downloaded %.1f MB", size_mb)
    return payload
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
    """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.

    Every check runs before anything is written: the entry count and the
    summed declared ``file_size`` are compared against ``_MAX_ENTRIES`` and
    ``_MAX_UNCOMPRESSED_SIZE``, and each member path is resolved to confirm
    it stays inside *dest*.

    Raises:
        ValueError: A member would be written outside the target directory
            (path traversal / Zip Slip) or the archive exceeds the safety
            limits.
    """
    root = Path(dest).resolve()

    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as archive:
        members = archive.infolist()

        entry_count = len(members)
        if entry_count > _MAX_ENTRIES:
            raise ValueError(
                f"ZIP archive contains {entry_count} entries "
                f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
            )

        declared_size = 0
        for info in members:
            declared_size += info.file_size
        if declared_size > _MAX_UNCOMPRESSED_SIZE:
            raise ValueError(
                f"ZIP uncompressed size {declared_size / (1024 * 1024):.0f} MB "
                f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
            )

        for info in members:
            resolved = (root / info.filename).resolve()
            if not resolved.is_relative_to(root):
                raise ValueError(
                    f"Zip Slip detected — member '{info.filename}' "
                    f"resolves outside target directory"
                )

        archive.extractall(dest)
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
    """Safely unpack *zip_bytes* into *dest* and locate the ``rules`` directory.

    Returns:
        ``<dest>/<_ZIP_ROOT_PREFIX>/rules``.

    Raises:
        FileNotFoundError: The unpacked archive has no such directory.
    """
    _safe_extract_zip(zip_bytes, dest)

    rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
    if rules_dir.is_dir():
        return rules_dir
    raise FileNotFoundError(
        f"Expected rules directory not found at {rules_dir}"
    )
def _extract_attack_tags(tags: list) -> list[str]:
    """Collect the distinct MITRE technique IDs named in a Sigma tag list.

    Each tag is stringified, stripped and matched against ``_ATTACK_TAG_RE``;
    matches are upper-cased and de-duplicated. The order of the returned
    list is not defined.

    Example input:  ["attack.defense_evasion", "attack.t1059.001", "cve.2021.44228"]
    Example output: ["T1059.001"]
    """
    candidates = (_ATTACK_TAG_RE.match(str(tag).strip()) for tag in tags)
    unique_ids = {found.group(1).upper() for found in candidates if found}
    return list(unique_ids)
def _parse_sigma_rules(rules_dir: Path) -> list[dict]:
    """Walk the rules directory and parse all Sigma YAML files.

    Returns a flat list of dicts, one per (rule, technique) combination.
    A single Sigma rule tagged with N techniques produces N entries.

    Files are skipped (never fatal) when they cannot be read or parsed, when
    the document is not a mapping, when ``title`` is missing, empty or not a
    string, when ``tags`` is not a list, or when no ATT&CK technique tag is
    present.

    Args:
        rules_dir: Directory searched recursively for ``*.yml`` files.

    Returns:
        One dict per (rule, technique) pair with the keys
        ``mitre_technique_id``, ``title``, ``description``, ``source_id``,
        ``source_url``, ``rule_content``, ``severity``, ``log_sources``,
        ``false_positive_rate`` and ``platforms``.
    """
    results: list[dict] = []
    yaml_files = sorted(rules_dir.rglob("*.yml"))
    logger.info("Found %d YAML files to parse", len(yaml_files))

    for yaml_path in yaml_files:
        # Path relative to the archive root, so it starts with "rules/...".
        relative_path = str(yaml_path.relative_to(rules_dir.parent))

        # Read once: the same text is parsed and stored as the rule content.
        try:
            raw_content = yaml_path.read_text(encoding="utf-8")
            data = yaml.safe_load(raw_content)
        except Exception as exc:
            # Deliberately broad: one unreadable or malformed file must not
            # abort the whole import.
            logger.debug("Failed to parse %s: %s", yaml_path, exc)
            continue

        if not isinstance(data, dict):
            continue

        # A null or non-string title (e.g. "title:" with no value) is treated
        # as missing instead of crashing on ``.strip()``.
        raw_title = data.get("title")
        title = raw_title.strip() if isinstance(raw_title, str) else ""
        if not title:
            continue

        # Extract ATT&CK technique IDs from tags
        tags = data.get("tags", [])
        if not isinstance(tags, list):
            continue
        technique_ids = _extract_attack_tags(tags)
        if not technique_ids:
            continue  # Skip rules without ATT&CK mapping

        description = data.get("description", "")
        level = str(data.get("level", "")).lower()
        severity = _SEVERITY_MAP.get(level)

        logsource = data.get("logsource", {})
        if not isinstance(logsource, dict):
            logsource = {}

        # False positive assessment: more listed scenarios -> higher rate.
        falsepositives = data.get("falsepositives", [])
        fp_count = len(falsepositives) if isinstance(falsepositives, list) else 0
        if fp_count > 3:
            fp_rate = "high"
        elif fp_count > 1:
            fp_rate = "medium"
        else:
            fp_rate = "low"

        # Same for every technique of this rule; backslashes (Windows path
        # separators) are normalised for the URL only.
        source_url = (
            f"https://github.com/SigmaHQ/sigma/blob/master/"
            f"{relative_path.replace(chr(92), '/')}"
        )

        # Create one entry per technique
        for tech_id in technique_ids:
            results.append({
                "mitre_technique_id": tech_id,
                "title": title[:500],
                "description": str(description)[:2000] if description else None,
                "source_id": relative_path,
                "source_url": source_url,
                "rule_content": raw_content,
                "severity": severity,
                "log_sources": logsource if logsource else None,
                "false_positive_rate": fp_rate,
                "platforms": _platforms_from_logsource(logsource),
            })

    logger.info("Parsed %d (rule, technique) pairs total", len(results))
    return results
# Define function _platforms_from_logsource
def _platforms_from_logsource(logsource: dict) -> list[str]:
"""Infer platform list from Sigma logsource."""
# Assign platforms = []
platforms = []
# Assign product = str(logsource.get("product", "")).lower()
product = str(logsource.get("product", "")).lower()
# Assign service = str(logsource.get("service", "")).lower()
service = str(logsource.get("service", "")).lower()
# Check: "windows" in product or "windows" in service
if "windows" in product or "windows" in service:
# Call platforms.append()
platforms.append("windows")
# Check: "linux" in product or "linux" in service
if "linux" in product or "linux" in service:
# Call platforms.append()
platforms.append("linux")
# Check: "macos" in product or "macos" in service
if "macos" in product or "macos" in service:
# Call platforms.append()
platforms.append("macos")
# Sysmon → Windows
if "sysmon" in service and "windows" not in platforms:
# Call platforms.append()
platforms.append("windows")
# Return platforms if platforms else None
return platforms if platforms else None
@@ -264,59 +389,88 @@ def sync(db: Session) -> dict:
db : Session
Active SQLAlchemy database session.
Returns
Returns:
-------
dict
Summary with ``created``, ``skipped_existing``, ``total_parsed``.
"""
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_sigma_")
tmp_dir = tempfile.mkdtemp(prefix="aegis_sigma_")
# Attempt the following; catch errors below
try:
# Assign zip_bytes = _download_zip()
zip_bytes = _download_zip()
# Assign rules_dir = _extract_zip(zip_bytes, tmp_dir)
rules_dir = _extract_zip(zip_bytes, tmp_dir)
# Assign parsed_rules = _parse_sigma_rules(rules_dir)
parsed_rules = _parse_sigma_rules(rules_dir)
# Always execute this cleanup block
finally:
# Call shutil.rmtree()
shutil.rmtree(tmp_dir, ignore_errors=True)
# Log info: "Cleaned up temp directory %s", tmp_dir
logger.info("Cleaned up temp directory %s", tmp_dir)
# Pre-load existing source_ids for dedup
existing_ids: set[str] = {
row[0]
for row in db.query(DetectionRule.source_id)
# Chain .filter() call
.filter(DetectionRule.source == "sigma")
# Chain .filter() call
.filter(DetectionRule.source_id.isnot(None))
# Chain .all() call
.all()
}
# Assign created = 0
created = 0
# Assign skipped = 0
skipped = 0
new_technique_ids: set[str] = set()
# Iterate over parsed_rules
for item in parsed_rules:
# Dedup key: source_id (relative path). A rule file may produce
# multiple entries (one per technique), but we deduplicate by
# source_id so re-runs are safe. For multi-technique rules we
# only skip if the exact same source_id is already present.
dedup_key = f"{item['source_id']}::{item['mitre_technique_id']}"
# Deduplicate by source_id: one rule file may map to multiple techniques,
# but we skip insertion if this source_id was already imported.
if item["source_id"] in existing_ids:
# Assign skipped = 1
skipped += 1
# Skip to the next loop iteration
continue
# Assign rule = DetectionRule(
rule = DetectionRule(
# Keyword argument: mitre_technique_id
mitre_technique_id=item["mitre_technique_id"],
# Keyword argument: title
title=item["title"],
# Keyword argument: description
description=item["description"],
# Keyword argument: source
source="sigma",
# Keyword argument: source_id
source_id=item["source_id"],
# Keyword argument: source_url
source_url=item["source_url"],
# Keyword argument: rule_content
rule_content=item["rule_content"],
# Keyword argument: rule_format
rule_format="sigma_yaml",
# Keyword argument: severity
severity=item["severity"],
# Keyword argument: platforms
platforms=item["platforms"],
# Keyword argument: log_sources
log_sources=item["log_sources"],
# Keyword argument: false_positive_rate
false_positive_rate=item["false_positive_rate"],
# Keyword argument: is_active
is_active=True,
)
# Stage new record(s) for database insertion
db.add(rule)
# Call existing_ids.add()
existing_ids.add(item["source_id"])
new_technique_ids.add(item["mitre_technique_id"])
created += 1
@@ -329,30 +483,48 @@ def sync(db: Session) -> dict:
db.commit()
# Assign summary = {
summary = {
# Literal argument value
"created": created,
# Literal argument value
"skipped_existing": skipped,
# Literal argument value
"total_parsed": len(parsed_rules),
}
# Update DataSource record
ds = db.query(DataSource).filter(DataSource.name == "sigma").first()
# Check: ds
if ds:
# Assign ds.last_sync_at = datetime.utcnow()
ds.last_sync_at = datetime.utcnow()
# Assign ds.last_sync_status = "success"
ds.last_sync_status = "success"
# Assign ds.last_sync_stats = summary
ds.last_sync_stats = summary
# Commit all pending changes to the database
db.commit()
# Log info: "Sigma import complete — %s", summary
logger.info("Sigma import complete — %s", summary)
# Call log_action()
log_action(
db,
# Keyword argument: user_id
user_id=None,
# Keyword argument: action
action="import_sigma_rules",
# Keyword argument: entity_type
entity_type="detection_rule",
# Keyword argument: entity_id
entity_id=None,
# Keyword argument: details
details=summary,
)
# Commit all pending changes to the database
db.commit()
# Return summary
return summary
+371 -20
View File
@@ -7,25 +7,59 @@ Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
number of SQL queries regardless of technique count.
"""
# Import logging
import logging
# Import uuid
import uuid
# Import defaultdict from collections
from collections import defaultdict
# Import datetime, timedelta, timezone from datetime
from datetime import datetime, timedelta, timezone
# Import func from sqlalchemy
from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
# Import TechniqueStatus from app.models.enums
from app.models.enums import TechniqueStatus
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import from app.services.scoring_service
from app.services.scoring_service import (
bulk_technique_scores,
calculate_organization_score,
)
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Rank of each coverage status for snapshot delta comparisons (higher = better
# coverage). ``compare_snapshots`` looks statuses up here with a default of 0,
# so a status missing from this table is ranked like "not_evaluated".
_STATUS_ORDER: dict[str, int] = {
    "not_evaluated": 0,
    "not_covered": 1,
    "in_progress": 2,
    "partial": 3,
    "validated": 4,
}
# ---------------------------------------------------------------------------
# Serialization and queries
@@ -33,97 +67,207 @@ logger = logging.getLogger(__name__)
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
    """Build the flat summary dictionary used for snapshot list views.

    Args:
        snap (CoverageSnapshot): The snapshot ORM object to serialize.

    Returns:
        dict: ``id`` as a string, the stored name, organization score and
        per-status counts, the aggregate fields (``coverage_percentage``,
        ``by_tactic``, ``by_status``, ``stale_count``,
        ``never_tested_count``), ``created_by`` as a string or ``None``,
        and ``created_at`` as an ISO-8601 string or ``None``.
    """
    payload = {"id": str(snap.id)}

    # Core columns are copied through unchanged, in output order.
    for field in (
        "name",
        "organization_score",
        "total_techniques",
        "validated_count",
        "partial_count",
        "not_covered_count",
        "in_progress_count",
        "not_evaluated_count",
    ):
        payload[field] = getattr(snap, field)

    # Aggregate fields are read defensively: an object lacking the attribute
    # yields a neutral default, and falsy tactic/status maps become {}.
    payload["coverage_percentage"] = getattr(snap, "coverage_percentage", 0.0)
    payload["by_tactic"] = getattr(snap, "by_tactic", None) or {}
    payload["by_status"] = getattr(snap, "by_status", None) or {}
    payload["stale_count"] = getattr(snap, "stale_count", 0)
    payload["never_tested_count"] = getattr(snap, "never_tested_count", 0)

    payload["created_by"] = str(snap.created_by) if snap.created_by else None
    payload["created_at"] = snap.created_at.isoformat() if snap.created_at else None
    return payload
# Define function serialize_snapshot_detail
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
    """Serialize a snapshot together with its per-technique states.

    Args:
        db (Session): Active SQLAlchemy database session.
        snap (CoverageSnapshot): The snapshot ORM object to serialize.

    Returns:
        dict: The output of :func:`serialize_snapshot_summary` extended with
        a ``technique_states`` list ordered by ``mitre_id``; each entry
        holds ``mitre_id``, ``technique_id`` (as a string), ``status`` and
        ``score``.
    """
    payload = serialize_snapshot_summary(snap)

    state_rows = (
        db.query(SnapshotTechniqueState)
        .filter(SnapshotTechniqueState.snapshot_id == snap.id)
        .order_by(SnapshotTechniqueState.mitre_id)
        .all()
    )

    serialized_states = []
    for row in state_rows:
        serialized_states.append(
            {
                "mitre_id": row.mitre_id,
                "technique_id": str(row.technique_id),
                "status": row.status,
                "score": row.score,
            }
        )

    payload["technique_states"] = serialized_states
    return payload
# Define function list_snapshots
def list_snapshots(
    db: Session,
    *,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """Return one page of coverage snapshots, newest first.

    Args:
        db (Session): Active SQLAlchemy database session.
        offset (int): Number of records to skip for pagination.
        limit (int): Maximum number of records to return.

    Returns:
        dict: ``total`` (count of all snapshots, ignoring pagination), the
        ``offset`` and ``limit`` that were applied, and ``items`` (list of
        serialized snapshot summaries).
    """
    base_query = db.query(CoverageSnapshot)
    # Count before paginating so ``total`` reflects the whole table.
    total = base_query.count()

    newest_first = base_query.order_by(CoverageSnapshot.created_at.desc())
    page = newest_first.offset(offset).limit(limit).all()

    items = []
    for snap in page:
        items.append(serialize_snapshot_summary(snap))

    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": items,
    }
# Define function get_snapshot_or_raise
def get_snapshot_or_raise(db: Session, snapshot_id: str | uuid.UUID) -> CoverageSnapshot:
    """Fetch a snapshot by ID or raise ``EntityNotFoundError``.

    Args:
        db (Session): Active SQLAlchemy database session.
        snapshot_id (str | uuid.UUID): UUID of the snapshot to retrieve,
            either as a string or as an already-parsed ``uuid.UUID``.

    Returns:
        CoverageSnapshot: The matching snapshot ORM object.

    Raises:
        EntityNotFoundError: If ``snapshot_id`` cannot be parsed as a UUID
            or no snapshot with that ID exists.
    """
    if isinstance(snapshot_id, uuid.UUID):
        # Already parsed; uuid.UUID() would reject a UUID instance.
        sid = snapshot_id
    else:
        try:
            sid = uuid.UUID(snapshot_id)
        except (ValueError, TypeError, AttributeError):
            # uuid.UUID raises AttributeError (not TypeError) for non-string
            # input such as an int, because it calls ``.replace`` on the
            # value. A malformed ID is reported as "not found"; the parsing
            # error is not useful context for callers.
            raise EntityNotFoundError("Snapshot", snapshot_id) from None

    snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
    if snapshot is None:
        raise EntityNotFoundError("Snapshot", snapshot_id)
    return snapshot
# Define function get_snapshot_detail
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
    """Look up a snapshot by ID and return its detailed serialization.

    Args:
        db (Session): Active SQLAlchemy database session.
        snapshot_id (str): UUID string of the snapshot to retrieve.

    Returns:
        dict: Full snapshot serialization from
        :func:`serialize_snapshot_detail`. Lookup failures propagate from
        :func:`get_snapshot_or_raise`.
    """
    return serialize_snapshot_detail(db, get_snapshot_or_raise(db, snapshot_id))
# Define function delete_snapshot
def delete_snapshot(db: Session, snapshot_id: str) -> None:
    """Mark a snapshot for deletion without committing.

    The caller is responsible for committing the session. Lookup failures
    propagate from :func:`get_snapshot_or_raise`.

    Args:
        db (Session): Active SQLAlchemy database session.
        snapshot_id (str): UUID string of the snapshot to delete.
    """
    doomed = get_snapshot_or_raise(db, snapshot_id)
    db.delete(doomed)
@@ -133,8 +277,11 @@ def delete_snapshot(db: Session, snapshot_id: str) -> None:
def create_snapshot(
# Entry: db
db: Session,
# Entry: name
name: str | None = None,
# Entry: user_id
user_id: uuid.UUID | None = None,
) -> CoverageSnapshot:
"""Capture the current coverage state into a new snapshot.
@@ -144,121 +291,215 @@ def create_snapshot(
3. Compute the org score from the same bulk data.
4. Persist a ``CoverageSnapshot`` with normalised
``SnapshotTechniqueState`` rows.
Args:
db (Session): Active SQLAlchemy database session.
name (str | None): Optional human-readable label for the snapshot.
user_id (uuid.UUID | None): UUID of the user creating the snapshot,
stored for auditing.
Returns:
CoverageSnapshot: The newly created and committed snapshot ORM object.
"""
# Assign scores_map = bulk_technique_scores(db)
scores_map = bulk_technique_scores(db)
# Assign techniques = db.query(Technique).all()
techniques = db.query(Technique).all()
# Assign validated_count = 0
validated_count = 0
# Assign partial_count = 0
partial_count = 0
# Assign not_covered_count = 0
not_covered_count = 0
# Assign in_progress_count = 0
in_progress_count = 0
# Assign not_evaluated_count = 0
not_evaluated_count = 0
# Assign stale_count = 0
stale_count = 0
# Assign never_tested_count = 0
never_tested_count = 0
# Assign by_tactic = defaultdict(
by_tactic: dict[str, dict] = defaultdict(
# Entry: lambda
lambda: {"total": 0, "validated": 0, "partial": 0, "score_sum": 0.0}
)
# Assign by_status = defaultdict(int)
by_status: dict[str, int] = defaultdict(int)
# Assign technique_rows = []
technique_rows: list[dict] = []
# Iterate over techniques
for tech in techniques:
# Assign status_value = (
status_value = (
tech.status_global.value
if isinstance(tech.status_global, TechniqueStatus)
else (tech.status_global or "not_evaluated")
)
# Check: status_value == "validated"
if status_value == "validated":
# Assign validated_count = 1
validated_count += 1
# Alternative: status_value == "partial"
elif status_value == "partial":
# Assign partial_count = 1
partial_count += 1
# Alternative: status_value == "not_covered"
elif status_value == "not_covered":
# Assign not_covered_count = 1
not_covered_count += 1
# Alternative: status_value == "in_progress"
elif status_value == "in_progress":
# Assign in_progress_count = 1
in_progress_count += 1
# Fallback: handle remaining cases
else:
# Assign not_evaluated_count = 1
not_evaluated_count += 1
# Assign entry = scores_map.get(tech.id, {})
entry = scores_map.get(tech.id, {})
# Assign score = entry.get("total_score", 0)
score = entry.get("total_score", 0)
# Call technique_rows.append()
technique_rows.append({
# Literal argument value
"technique_id": tech.id,
# Literal argument value
"mitre_id": tech.mitre_id,
# Literal argument value
"status": status_value,
# Literal argument value
"score": score,
})
# Assign by_status[status_value] = 1
by_status[status_value] += 1
# Assign tactic_key = tech.tactic or "unknown"
tactic_key = tech.tactic or "unknown"
# Assign bucket = by_tactic[tactic_key]
bucket = by_tactic[tactic_key]
# Assign bucket["total"] = 1
bucket["total"] += 1
# Assign bucket["score_sum"] = score
bucket["score_sum"] += score
# Check: status_value == "validated"
if status_value == "validated":
# Assign bucket["validated"] = 1
bucket["validated"] += 1
# Alternative: status_value == "partial"
elif status_value == "partial":
# Assign bucket["partial"] = 1
bucket["partial"] += 1
# Check: status_value == "not_evaluated"
if status_value == "not_evaluated":
# Assign never_tested_count = 1
never_tested_count += 1
# Check: tech.review_required
if tech.review_required:
# Assign stale_count = 1
stale_count += 1
# Assign org_data = calculate_organization_score(db)
org_data = calculate_organization_score(db)
# Assign org_score = org_data.get("overall_score", 0)
org_score = org_data.get("overall_score", 0)
# Assign total_techniques = len(techniques) or 1
total_techniques = len(techniques) or 1
# Assign coverage_pct = round((validated_count / total_techniques) * 100, 1)
coverage_pct = round((validated_count / total_techniques) * 100, 1)
# Assign by_tactic_out = {
by_tactic_out = {
# Entry: tactic
tactic: {
# Literal argument value
"total": data["total"],
# Literal argument value
"validated": data["validated"],
# Literal argument value
"partial": data["partial"],
# Literal argument value
"average_score": round(data["score_sum"] / data["total"], 1) if data["total"] else 0,
}
for tactic, data in by_tactic.items()
}
# Assign snapshot = CoverageSnapshot(
snapshot = CoverageSnapshot(
# Keyword argument: name
name=name,
# Keyword argument: organization_score
organization_score=org_score,
# Keyword argument: total_techniques
total_techniques=len(techniques),
# Keyword argument: validated_count
validated_count=validated_count,
# Keyword argument: partial_count
partial_count=partial_count,
# Keyword argument: not_covered_count
not_covered_count=not_covered_count,
# Keyword argument: in_progress_count
in_progress_count=in_progress_count,
# Keyword argument: not_evaluated_count
not_evaluated_count=not_evaluated_count,
# Keyword argument: coverage_percentage
coverage_percentage=coverage_pct,
# Keyword argument: by_tactic
by_tactic=by_tactic_out,
# Keyword argument: by_status
by_status=dict(by_status),
# Keyword argument: stale_count
stale_count=stale_count,
# Keyword argument: never_tested_count
never_tested_count=never_tested_count,
# Keyword argument: created_by
created_by=user_id,
)
# Stage new record(s) for database insertion
db.add(snapshot)
# Flush changes to DB without committing the transaction
db.flush()
# Iterate over technique_rows
for row in technique_rows:
# Assign state = SnapshotTechniqueState(
state = SnapshotTechniqueState(
# Keyword argument: snapshot_id
snapshot_id=snapshot.id,
# Keyword argument: technique_id
technique_id=row["technique_id"],
# Keyword argument: mitre_id
mitre_id=row["mitre_id"],
# Keyword argument: status
status=row["status"],
# Keyword argument: score
score=row["score"],
)
# Stage new record(s) for database insertion
db.add(state)
# Commit all pending changes to the database
db.commit()
# Reload ORM object attributes from the database
db.refresh(snapshot)
# Log info:
logger.info(
# Literal argument value
"Snapshot '%s' created — %d techniques, org score %.1f",
snapshot.name or snapshot.id,
len(techniques),
org_score,
)
# Return snapshot
return snapshot
@@ -268,99 +509,160 @@ def create_snapshot(
def compare_snapshots(
    db: Session,
    snapshot_a_id: uuid.UUID,
    snapshot_b_id: uuid.UUID,
) -> dict:
    """Compare two snapshots and return per-technique deltas.

    A technique counts as improved when its status rank (per
    ``_STATUS_ORDER``) is higher in B than in A, or when the rank is equal
    and its score is higher; worsened is the mirror case. A technique
    present in only one snapshot is compared against a ``not_evaluated``
    state with score 0 on the other side.

    Args:
        db (Session): Active SQLAlchemy database session.
        snapshot_a_id (uuid.UUID): UUID of the baseline (older) snapshot.
        snapshot_b_id (uuid.UUID): UUID of the comparison (newer) snapshot.

    Returns:
        dict: Contains ``snapshot_a``, ``snapshot_b``, ``score_delta``
        (B minus A, rounded to one decimal), ``improved``, ``worsened``,
        ``unchanged_count``, and ``summary`` (``improved_count``,
        ``worsened_count`` and ``new_count``, the number of MITRE IDs
        present in B but not in A).

    Raises:
        EntityNotFoundError: If either snapshot does not exist.
    """
    snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
    snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
    if not (snap_a and snap_b):
        raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")

    def _states_by_mitre_id(snapshot_id: uuid.UUID) -> dict:
        # Lookup table mitre_id -> {status, score}; a falsy score becomes 0.
        rows = (
            db.query(SnapshotTechniqueState)
            .filter(SnapshotTechniqueState.snapshot_id == snapshot_id)
            .all()
        )
        return {
            row.mitre_id: {"status": row.status, "score": row.score or 0}
            for row in rows
        }

    def _snap_summary(snap: CoverageSnapshot) -> dict:
        # Reduced header for each side of the comparison.
        brief = {"id": str(snap.id)}
        for field in (
            "name",
            "organization_score",
            "total_techniques",
            "validated_count",
            "partial_count",
            "not_covered_count",
            "in_progress_count",
            "not_evaluated_count",
        ):
            brief[field] = getattr(snap, field)
        brief["created_at"] = snap.created_at.isoformat() if snap.created_at else None
        return brief

    states_a = _states_by_mitre_id(snapshot_a_id)
    states_b = _states_by_mitre_id(snapshot_b_id)

    # Stand-in for a technique that is absent from one of the snapshots.
    absent = {"status": "not_evaluated", "score": 0}

    improved = []
    worsened = []
    unchanged_count = 0

    for mitre_id in sorted(states_a.keys() | states_b.keys()):
        before = states_a.get(mitre_id, absent)
        after = states_b.get(mitre_id, absent)

        rank_before = _STATUS_ORDER.get(before["status"], 0)
        rank_after = _STATUS_ORDER.get(after["status"], 0)
        same_rank = rank_after == rank_before

        # Status rank decides first; the score only breaks ties.
        if rank_after > rank_before or (same_rank and after["score"] > before["score"]):
            bucket = improved
        elif rank_after < rank_before or (same_rank and after["score"] < before["score"]):
            bucket = worsened
        else:
            unchanged_count += 1
            continue

        bucket.append({
            "mitre_id": mitre_id,
            "old_status": before["status"],
            "new_status": after["status"],
            "old_score": before["score"],
            "new_score": after["score"],
        })

    summary = {
        "improved_count": len(improved),
        "worsened_count": len(worsened),
        "new_count": len(states_b.keys() - states_a.keys()),
    }
    return {
        "snapshot_a": _snap_summary(snap_a),
        "snapshot_b": _snap_summary(snap_b),
        "score_delta": round(snap_b.organization_score - snap_a.organization_score, 1),
        "improved": improved,
        "worsened": worsened,
        "unchanged_count": unchanged_count,
        "summary": summary,
    }
@@ -372,25 +674,53 @@ def compare_snapshots(
def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
"""Return snapshot trend points for the last *months* months."""
"""Return snapshot trend points for the last *months* months.
Args:
db (Session): Active SQLAlchemy database session.
months (int): Number of months to look back; defaults to 12.
Returns:
list[dict]: Snapshot trend entries ordered by creation date ascending,
each containing ``date``, ``name``, ``org_score``,
``coverage_pct``, ``by_tactic``, ``by_status``,
``stale_count``, ``never_tested_count``, ``validated_count``,
and ``total_techniques``.
"""
# Assign cutoff = datetime.now(timezone.utc) - timedelta(days=months * 30)
cutoff = datetime.now(timezone.utc) - timedelta(days=months * 30)
# Assign snapshots = (
snapshots = (
db.query(CoverageSnapshot)
# Chain .filter() call
.filter(CoverageSnapshot.created_at >= cutoff)
# Chain .order_by() call
.order_by(CoverageSnapshot.created_at.asc())
# Chain .all() call
.all()
)
# Return [
return [
{
# Literal argument value
"date": snap.created_at.isoformat() if snap.created_at else None,
# Literal argument value
"name": snap.name,
# Literal argument value
"org_score": snap.organization_score,
# Literal argument value
"coverage_pct": getattr(snap, "coverage_percentage", 0.0),
# Literal argument value
"by_tactic": getattr(snap, "by_tactic", None) or {},
# Literal argument value
"by_status": getattr(snap, "by_status", None) or {},
# Literal argument value
"stale_count": getattr(snap, "stale_count", 0),
# Literal argument value
"never_tested_count": getattr(snap, "never_tested_count", 0),
# Literal argument value
"validated_count": snap.validated_count,
# Literal argument value
"total_techniques": snap.total_techniques,
}
for snap in snapshots
@@ -405,25 +735,46 @@ def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
    """Delete the oldest snapshots, keeping the most recent *keep_last*.

    The deletions are committed before returning.

    Args:
        db (Session): Active SQLAlchemy database session.
        keep_last (int): Number of most-recent snapshots to retain; defaults
            to 52 (one year of weekly snapshots).

    Returns:
        int: Number of snapshots deleted (0 when there is nothing to prune).
    """
    total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0

    surplus = total - keep_last
    if surplus <= 0:
        return 0

    # Oldest first, limited to exactly the surplus over ``keep_last``.
    oldest = (
        db.query(CoverageSnapshot)
        .order_by(CoverageSnapshot.created_at.asc())
        .limit(surplus)
        .all()
    )
    for snap in oldest:
        db.delete(snap)
    deleted = len(oldest)

    db.commit()
    logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
    return deleted
@@ -1,26 +1,41 @@
"""Stale coverage detection — marks techniques whose last validated test
is older than a configurable threshold.
"""Stale coverage detection — marks techniques whose last validated test is older than a configurable threshold.
This is the simple version. The full Decay Engine (Phase 8) will replace
this with a multi-factor, configurable decay model with confidence scores.
"""
# Import logging
import logging
# Import datetime, timedelta, timezone from datetime
from datetime import datetime, timedelta, timezone
# Import func from sqlalchemy
from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import settings from app.config
from app.config import settings
# Import TechniqueStatus, TestState from app.models.enums
from app.models.enums import TechniqueStatus, TestState
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Assign STALE_THRESHOLD_DAYS = settings.STALE_THRESHOLD_DAYS
STALE_THRESHOLD_DAYS = settings.STALE_THRESHOLD_DAYS
# Define function detect_stale_coverage
def detect_stale_coverage(db: Session) -> int:
"""Scan all techniques and flag those with stale coverage.
@@ -30,10 +45,17 @@ def detect_stale_coverage(db: Session) -> int:
- It has never had a validated test (but has been manually marked as
covered/partial).
Returns the number of newly-flagged techniques.
Args:
db (Session): Active SQLAlchemy database session.
Returns:
int: Number of techniques newly flagged as stale (``review_required``
set to ``True``) in this run.
"""
# Assign cutoff = datetime.now(timezone.utc) - timedelta(days=STALE_THRESHOLD_DAYS)
cutoff = datetime.now(timezone.utc) - timedelta(days=STALE_THRESHOLD_DAYS)
# Assign last_validated = func.coalesce(
last_validated = func.coalesce(
Test.blue_validated_at,
Test.red_validated_at,
@@ -46,40 +68,60 @@ def detect_stale_coverage(db: Session) -> int:
Test.technique_id,
func.max(last_validated).label("last_tested"),
)
# Chain .filter() call
.filter(Test.state == TestState.validated)
# Chain .group_by() call
.group_by(Test.technique_id)
# Chain .subquery() call
.subquery()
)
# Find techniques that are stale
stale_techniques = (
db.query(Technique)
# Chain .outerjoin() call
.outerjoin(latest_test, Technique.id == latest_test.c.technique_id)
# Chain .filter() call
.filter(
# Either tested before cutoff, or never tested at all
(latest_test.c.last_tested < cutoff)
| (latest_test.c.last_tested.is_(None))
)
# Chain .filter() call
.filter(
# Only flag techniques that have a real status (not never-evaluated ones)
Technique.status_global != TechniqueStatus.not_evaluated
)
# Chain .all() call
.all()
)
# Assign count = 0
count = 0
# Iterate over stale_techniques
for tech in stale_techniques:
# Check: not tech.review_required
if not tech.review_required:
# Assign tech.review_required = True
tech.review_required = True
# Assign count = 1
count += 1
# Log info: "Marked %s as stale coverage", tech.mitre_id
logger.info("Marked %s as stale coverage", tech.mitre_id)
# Check: count > 0
if count > 0:
# Commit all pending changes to the database
db.commit()
# Log info:
logger.info(
# Literal argument value
"Stale coverage detection complete — %d techniques flagged", count
)
# Fallback: handle remaining cases
else:
# Log info: "Stale coverage detection complete — no new stale
logger.info("Stale coverage detection complete — no new stale techniques")
# Return count
return count
+9
View File
@@ -10,21 +10,30 @@ The function mutates the technique but does **not** commit.
The caller is responsible for committing the session.
"""
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import TechniqueEntity from app.domain.entities.technique
from app.domain.entities.technique import TechniqueEntity
# Import Technique from app.models.technique
from app.models.technique import Technique
# Define function recalculate_technique_status
def recalculate_technique_status(db: Session, technique: Technique) -> None:
    """Recompute ``technique.status_global`` from its tests.

    The technique is mutated in place; nothing is committed here.

    Args:
        db (Session): Accepted for backward compatibility but not used
            directly — test data comes from the ORM relationship
            ``technique.tests``.
        technique (Technique): ORM object whose ``status_global`` is
            overwritten with the recalculated value.
    """
    domain_technique = TechniqueEntity.from_orm(technique)

    # The domain entity receives (state, detection_result) pairs, one per test.
    observed = []
    for test in technique.tests:
        observed.append((test.state, test.detection_result))

    domain_technique.recalculate_status(observed)
    technique.status_global = domain_technique.status_global
@@ -1,9 +1,14 @@
"""Technique query service — framework-agnostic queries for technique details."""
# Enable future language features for compatibility
from __future__ import annotations
# Import Session, joinedload from sqlalchemy.orm
from sqlalchemy.orm import Session, joinedload
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import Technique from app.models.technique
from app.models.technique import Technique
from app.models.detection_rule import DetectionRule
from app.models.intel import IntelItem
@@ -13,15 +18,21 @@ from app.services.d3fend_import_service import get_defenses_for_technique
_SEVERITY_ORDER = {"critical": 0, "high": 1, "medium": 2, "low": 3, "informational": 4, None: 5}
# Define function get_technique_detail
def get_technique_detail(db: Session, mitre_id: str) -> dict:
"""Fetch full technique details including tests, detection rules, and D3FEND defenses."""
technique = (
db.query(Technique)
# Chain .options() call
.options(joinedload(Technique.tests))
# Chain .filter() call
.filter(Technique.mitre_id == mitre_id)
# Chain .first() call
.first()
)
# Check: technique is None
if technique is None:
# Raise EntityNotFoundError
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
@@ -49,26 +60,46 @@ def get_technique_detail(db: Session, mitre_id: str) -> dict:
)
return {
# Literal argument value
"id": str(technique.id),
# Literal argument value
"mitre_id": technique.mitre_id,
# Literal argument value
"name": technique.name,
# Literal argument value
"description": technique.description,
# Literal argument value
"tactic": technique.tactic,
# Literal argument value
"platforms": technique.platforms or [],
# Literal argument value
"mitre_version": technique.mitre_version,
# Literal argument value
"mitre_last_modified": technique.mitre_last_modified,
# Literal argument value
"is_subtechnique": technique.is_subtechnique,
# Literal argument value
"parent_mitre_id": technique.parent_mitre_id,
# Literal argument value
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
# Literal argument value
"review_required": technique.review_required,
# Literal argument value
"last_review_date": technique.last_review_date,
# Literal argument value
"tests": [
{
# Literal argument value
"id": str(t.id),
# Literal argument value
"name": t.name,
# Literal argument value
"state": t.state.value if t.state else None,
# Literal argument value
"result": t.result.value if t.result else None,
# Literal argument value
"platform": t.platform,
# Literal argument value
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in technique.tests
+40 -3
View File
@@ -18,15 +18,31 @@ blue_work_started_at) to when they submit, so it reflects actual working time
rather than queue time.
"""
# Import logging
import logging
from typing import Optional
# Import Any, Optional from typing
from typing import Any, Optional
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import settings from app.config
from app.config import settings
# Import InvalidOperationError from app.domain.exceptions
from app.domain.exceptions import InvalidOperationError
# Import JiraLink, JiraLinkEntityType from app.models.jira_link
from app.models.jira_link import JiraLink, JiraLinkEntityType
# Import Test from app.models.test
from app.models.test import Test
# Import User from app.models.user
from app.models.user import User
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Only red team execution time goes to Tempo.
@@ -85,23 +101,31 @@ def get_user_tempo_client(user, db=None):
"Add it in Settings → Profile → Tempo Integration."
)
try:
# Import client_v4 as tempo_client from tempoapiclient
from tempoapiclient import client_v4 as tempo_client
base_url = _get_tempo_base_url(db)
logger.debug("Using Tempo base URL: %s", base_url)
return tempo_client.Tempo(auth_token=token, base_url=base_url)
except ImportError:
# Raise InvalidOperationError
raise InvalidOperationError(
# Literal argument value
"tempo-api-python-client is not installed. "
"Run: pip install tempo-api-python-client"
)
# Define function log_worklog
def log_worklog(
user,
jira_issue_id: int,
# Entry: author_account_id
author_account_id: str,
# Entry: date
date: str,
# Entry: time_spent_seconds
time_spent_seconds: int,
# Entry: description
description: str,
db=None,
) -> dict:
@@ -128,10 +152,15 @@ def log_worklog(
raise RuntimeError(f"Tempo API error: {exc}") from exc
# Define function auto_log_test_worklog
def auto_log_test_worklog(
# Entry: db
db: Session,
test,
user,
# Entry: test
test: Test,
# Entry: user
user: User,
# Entry: activity_type
activity_type: str,
duration_seconds: int,
) -> Optional[dict]:
@@ -156,6 +185,7 @@ def auto_log_test_worklog(
# Global kill-switch
if not settings.TEMPO_ENABLED:
# Return None
return None
if duration_seconds <= 0:
@@ -183,15 +213,20 @@ def auto_log_test_worklog(
# Need a Jira link with a numeric issue ID
link = (
db.query(JiraLink)
# Chain .filter() call
.filter(
JiraLink.entity_id == test.id,
JiraLink.entity_type == JiraLinkEntityType.test,
)
# Chain .first() call
.first()
)
# Check: not link or not link.jira_issue_id
if not link or not link.jira_issue_id:
# Log debug: "No Jira link for test %s, skipping Tempo worklog"
logger.debug("No Jira link for test %s, skipping Tempo worklog", test.id)
# Return None
return None
jira_account_id = (getattr(user, "jira_account_id", "") or "").strip()
@@ -202,6 +237,7 @@ def auto_log_test_worklog(
)
return None
# Attempt the following; catch errors below
try:
# Use the phase start timestamp as the worklog date so it matches when
# work actually happened (not the submission timestamp).
@@ -231,6 +267,7 @@ def auto_log_test_worklog(
test.id, getattr(user, "username", user), duration_seconds, work_date,
)
return result
# Handle Exception
except Exception as e:
logger.warning(
"Tempo worklog failed for test %s (user %s): %s",
+228 -6
View File
@@ -4,12 +4,15 @@ Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
# Import uuid
import uuid
from datetime import datetime
from typing import Any
# Import Session, joinedload from sqlalchemy.orm
from sqlalchemy.orm import Session, joinedload
# Import from app.domain.errors
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
@@ -21,19 +24,41 @@ from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.campaign import Campaign, CampaignTest
from app.models.audit import AuditLog
# Import TestState from app.models.enums
from app.models.enums import TestState
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import escape_like from app.utils
from app.utils import escape_like
# Define function list_tests
def list_tests(
# Entry: db
db: Session,
*,
# Entry: state
state: str | None = None,
# Entry: technique_id
technique_id: uuid.UUID | None = None,
# Entry: platform
platform: str | None = None,
# Entry: created_by
created_by: uuid.UUID | None = None,
# Entry: pending_validation_side
pending_validation_side: str | None = None,
not_in_any_campaign: bool = False,
offset: int = 0,
# Entry: limit
limit: int = 50,
) -> list[Test]:
"""Return a paginated list of tests with optional filters.
@@ -44,20 +69,32 @@ def list_tests(
"""
query = db.query(Test).options(joinedload(Test.technique))
# Check: state
if state:
# Assign query = query.filter(Test.state == state)
query = query.filter(Test.state == state)
# Check: technique_id
if technique_id:
# Assign query = query.filter(Test.technique_id == technique_id)
query = query.filter(Test.technique_id == technique_id)
# Check: platform
if platform:
# Assign query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
# Check: created_by
if created_by:
# Assign query = query.filter(Test.created_by == created_by)
query = query.filter(Test.created_by == created_by)
# Check: pending_validation_side == "red"
if pending_validation_side == "red":
# Assign query = query.filter(
query = query.filter(
Test.state == TestState.in_review,
Test.red_validation_status.in_(["pending", None]),
)
# Alternative: pending_validation_side == "blue"
elif pending_validation_side == "blue":
# Assign query = query.filter(
query = query.filter(
Test.state == TestState.in_review,
Test.blue_validation_status.in_(["pending", None]),
@@ -82,42 +119,72 @@ def list_tests(
)
query = query.filter(~Test.id.in_(future_draft_tests))
# Return query.order_by(Test.created_at.desc()).offset(offset).limit(limit)....
return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
# Define function create_test
def create_test(
    db: Session,
    *,
    technique_id: uuid.UUID,
    creator_id: uuid.UUID,
    **fields: object,
) -> Test:
    """Create a new test linked to an existing technique.

    Raises EntityNotFoundError if the technique does not exist.
    Does not commit; caller uses UnitOfWork.

    Args:
        db (Session): Active SQLAlchemy database session.
        technique_id (uuid.UUID): UUID of the technique this test covers.
        creator_id (uuid.UUID): UUID of the user creating the test.
        **fields (object): Additional keyword arguments passed to the
            ``Test`` constructor (e.g. ``name``, ``platform``,
            ``description``).

    Returns:
        Test: The newly created test ORM object, flushed but not committed.
    """
    # Function-scope import: the module only imports ``datetime`` itself.
    from datetime import timezone

    technique = db.query(Technique).filter(Technique.id == technique_id).first()
    if technique is None:
        raise EntityNotFoundError("Technique", str(technique_id))

    # ``datetime.utcnow()`` is deprecated since Python 3.12. Take an aware UTC
    # "now" and strip tzinfo so the stored value stays a naive UTC timestamp,
    # exactly as before.
    created_at = datetime.now(timezone.utc).replace(tzinfo=None)

    test = Test(
        technique_id=technique_id,
        created_by=creator_id,
        state=TestState.draft,
        created_at=created_at,  # explicit — DB column has no server default
        **fields,
    )
    db.add(test)
    db.flush()
    return test
# Define function create_test_from_template
def create_test_from_template(
# Entry: db
db: Session,
*,
# Entry: template_id
template_id: uuid.UUID,
# Entry: technique_id_or_mitre
technique_id_or_mitre: str,
# Entry: creator_id
creator_id: uuid.UUID,
# Optional user-edited overrides (take priority over template values)
name_override: str | None = None,
@@ -132,27 +199,53 @@ def create_test_from_template(
Override fields, when provided, take precedence over the template's values.
Raises EntityNotFoundError if template or technique not found.
Does not commit; caller uses UnitOfWork.
Args:
db (Session): Active SQLAlchemy database session.
template_id (uuid.UUID): UUID of the template to instantiate.
technique_id_or_mitre (str): UUID string or MITRE technique ID
(e.g. ``"T1059.001"``) identifying the target technique.
creator_id (uuid.UUID): UUID of the user creating the test.
Returns:
Test: The newly created test populated from template fields, flushed
but not committed.
"""
# Assign template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
# Check: template is None
if template is None:
# Raise EntityNotFoundError
raise EntityNotFoundError("TestTemplate", str(template_id))
# Assign technique = None
technique = None
# Attempt the following; catch errors below
try:
# Assign technique_uuid = uuid.UUID(technique_id_or_mitre)
technique_uuid = uuid.UUID(technique_id_or_mitre)
# Assign technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
# Handle ValueError
except ValueError:
# Intentional no-op placeholder
pass
# Check: technique is None
if technique is None:
# Assign technique = db.query(Technique).filter(
technique = db.query(Technique).filter(
Technique.mitre_id == technique_id_or_mitre
).first()
# Check: technique is None
if technique is None:
# Raise EntityNotFoundError
raise EntityNotFoundError("Technique", technique_id_or_mitre)
# Assign test = Test(
test = Test(
# Keyword argument: technique_id
technique_id=technique.id,
name=name_override if name_override is not None else template.name,
description=description_override if description_override is not None else template.description,
@@ -160,59 +253,111 @@ def create_test_from_template(
procedure_text=procedure_text_override if procedure_text_override is not None else template.attack_procedure,
tool_used=tool_used_override if tool_used_override is not None else template.tool_suggested,
remediation_steps=template.suggested_remediation,
# Keyword argument: created_by
created_by=creator_id,
# Keyword argument: state
state=TestState.draft,
created_at=datetime.utcnow(), # explicit — DB column has no server default
)
# Stage new record(s) for database insertion
db.add(test)
# Flush changes to DB without committing the transaction
db.flush()
# Return test
return test
# Define function get_test_detail
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
    """Load a single test together with its evidences and technique.

    Both relationships are requested via ``joinedload`` so they are fetched
    by the same query that loads the test.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to retrieve.

    Returns:
        Test: The matching test ORM object.

    Raises:
        EntityNotFoundError: If no test has the given ID.
    """
    eager_loads = (joinedload(Test.evidences), joinedload(Test.technique))
    found = db.query(Test).options(*eager_loads).filter(Test.id == test_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
# Define function get_test_or_raise
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
    """Look up a test by primary key, failing loudly when it is missing.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to retrieve.

    Returns:
        Test: The matching test ORM object.

    Raises:
        EntityNotFoundError: If no test has the given ID.
    """
    found = db.query(Test).filter(Test.id == test_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
# Define function get_test_with_technique
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
    """Load a test by ID with its ``technique`` relationship joined in.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to retrieve.

    Returns:
        Test: The matching test ORM object, queried with
            ``joinedload(Test.technique)``.

    Raises:
        EntityNotFoundError: If no test has the given ID.
    """
    with_technique = db.query(Test).options(joinedload(Test.technique))
    found = with_technique.filter(Test.id == test_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
# Define function update_test
def update_test(
    db: Session,
    test_id: uuid.UUID,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: object,
) -> Test:
    """Update general test fields (draft or rejected only).

    Does not commit; caller uses UnitOfWork.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to update.
        updater_id (uuid.UUID): UUID of the user performing the update.
        updater_role (str): Role of the updater; ``"admin"`` bypasses the
            creator-only restriction.
        **fields (object): Keyword arguments assigned directly onto the test
            via ``setattr``, with no filtering of attribute names.

    Returns:
        Test: The updated test ORM object, flushed but not committed.

    Raises:
        EntityNotFoundError: If the test does not exist.
        PermissionViolation: If the updater is neither an admin nor the
            test's creator.
        BusinessRuleViolation: If the test is not in the draft or rejected
            state.
    """
    test = get_test_or_raise(db, test_id)

    # Admins may edit any test; everyone else only the tests they created.
    if not (updater_role == "admin" or test.created_by == updater_id):
        raise PermissionViolation(
            "Only the test creator or an admin can update this test"
        )

    editable_states = (TestState.draft, TestState.rejected)
    if test.state not in editable_states:
        raise BusinessRuleViolation(
            f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
        )

    for attr_name, new_value in fields.items():
        setattr(test, attr_name, new_value)

    db.flush()
    return test
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
# Define function update_test_red
def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
    """Update Red Team fields (draft or red_executing only).

    Does not commit; caller uses UnitOfWork.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to update.
        **fields (object): Red-team field names and their new values,
            assigned onto the test via ``setattr``.

    Returns:
        Test: The updated test ORM object, flushed but not committed.

    Raises:
        EntityNotFoundError: If the test does not exist.
        BusinessRuleViolation: If the test is not in the draft or
            red_executing state.
    """
    test = get_test_or_raise(db, test_id)

    red_editable_states = (TestState.draft, TestState.red_executing)
    if test.state not in red_editable_states:
        raise BusinessRuleViolation(
            f"Cannot update red fields in '{test.state.value}' state "
            "(must be draft or red_executing)"
        )

    for attr_name, new_value in fields.items():
        setattr(test, attr_name, new_value)

    db.flush()
    return test
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
# Define function update_test_blue
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
    """Update Blue Team fields (blue_evaluating only).

    Does not commit; caller uses UnitOfWork.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to update.
        **fields (object): Blue-team field names and their new values,
            assigned onto the test via ``setattr``.

    Returns:
        Test: The updated test ORM object, flushed but not committed.

    Raises:
        EntityNotFoundError: If the test does not exist.
        BusinessRuleViolation: If the test is not in the blue_evaluating
            state.
    """
    test = get_test_or_raise(db, test_id)

    in_blue_phase = test.state == TestState.blue_evaluating
    if not in_blue_phase:
        raise BusinessRuleViolation(
            f"Cannot update blue fields in '{test.state.value}' state "
            "(must be blue_evaluating)"
        )

    for attr_name, new_value in fields.items():
        setattr(test, attr_name, new_value)

    db.flush()
    return test
# Define function get_test_timeline
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
"""Return chronological audit-log history for a test.
Raises EntityNotFoundError if the test does not exist.
Args:
db (Session): Active SQLAlchemy database session.
test_id (uuid.UUID): UUID of the test whose history is requested.
Returns:
list[dict[str, Any]]: Audit-log entries ordered by timestamp ascending,
each containing ``id``, ``action``, ``user_id``, ``timestamp``,
and ``details``.
"""
# Call get_test_or_raise()
get_test_or_raise(db, test_id)
# Assign logs = (
logs = (
db.query(AuditLog)
# Chain .filter() call
.filter(
AuditLog.entity_type == "test",
AuditLog.entity_id == str(test_id),
)
# Chain .order_by() call
.order_by(AuditLog.timestamp.asc())
# Chain .all() call
.all()
)
# Return [
return [
{
# Literal argument value
"id": str(log.id),
# Literal argument value
"action": log.action,
# Literal argument value
"user_id": str(log.user_id) if log.user_id else None,
# Literal argument value
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
# Literal argument value
"details": log.details,
}
for log in logs
@@ -1,43 +1,77 @@
"""Test template service — framework-agnostic CRUD and queries."""
# Enable future language features for compatibility
from __future__ import annotations
# Import uuid
import uuid
# Import func, or_ from sqlalchemy
from sqlalchemy import func, or_
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import escape_like from app.utils
from app.utils import escape_like
# Define function list_templates
def list_templates(
# Entry: db
db: Session,
*,
# Entry: source
source: str | None = None,
# Entry: platform
platform: str | None = None,
# Entry: severity
severity: str | None = None,
# Entry: mitre_technique_id
mitre_technique_id: str | None = None,
# Entry: search
search: str | None = None,
# Entry: is_active
is_active: bool | None = None,
# Entry: offset
offset: int = 0,
# Entry: limit
limit: int = 50,
) -> list:
"""Return paginated, filterable list of test templates."""
# Assign query = db.query(TestTemplate)
query = db.query(TestTemplate)
# Check: is_active is not None
if is_active is not None:
# Assign query = query.filter(TestTemplate.is_active == is_active)
query = query.filter(TestTemplate.is_active == is_active)
# Check: source
if source:
# Assign query = query.filter(TestTemplate.source == source)
query = query.filter(TestTemplate.source == source)
# Check: platform
if platform:
# Assign query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}...
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
# Check: severity
if severity:
# Assign query = query.filter(TestTemplate.severity == severity)
query = query.filter(TestTemplate.severity == severity)
# Check: mitre_technique_id
if mitre_technique_id:
# Assign query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
# Check: search
if search:
# Assign pattern = f"%{escape_like(search)}%"
pattern = f"%{escape_like(search)}%"
# Assign query = query.filter(
query = query.filter(
or_(
TestTemplate.name.ilike(pattern),
@@ -45,106 +79,166 @@ def list_templates(
)
)
# Assign templates = (
templates = (
query
# Chain .order_by() call
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
# Chain .offset() call
.offset(offset)
# Chain .limit() call
.limit(limit)
# Chain .all() call
.all()
)
# Return templates
return templates
# Define function get_template_stats
def get_template_stats(db: Session) -> dict:
    """Summarise the template catalog.

    Args:
        db (Session): Active SQLAlchemy database session.

    Returns:
        dict: ``total``, ``active`` and ``inactive`` template counts, plus
            ``by_source`` and ``by_platform`` breakdowns. The two breakdowns
            count active templates only; a falsy platform is reported under
            the key ``"unspecified"``.
    """
    template_count = func.count(TestTemplate.id)
    only_active = TestTemplate.is_active == True  # noqa: E712

    total = db.query(template_count).scalar() or 0
    active = db.query(template_count).filter(only_active).scalar() or 0

    per_source = (
        db.query(TestTemplate.source, template_count)
        .filter(only_active)
        .group_by(TestTemplate.source)
        .all()
    )
    per_platform = (
        db.query(TestTemplate.platform, template_count)
        .filter(only_active)
        .group_by(TestTemplate.platform)
        .all()
    )

    stats = {
        "total": total,
        "active": active,
        "inactive": total - active,
        "by_source": {row[0]: row[1] for row in per_source},
        "by_platform": {(row[0] or "unspecified"): row[1] for row in per_platform},
    }
    return stats
# Define function bulk_activate
def bulk_activate(db: Session, *, activate: bool) -> int:
    """Switch every template to the requested active flag. Does NOT commit.

    Only rows whose ``is_active`` differs from ``activate`` are selected for
    the bulk UPDATE.

    Args:
        db (Session): Active SQLAlchemy database session.
        activate (bool): Target value for ``is_active``.

    Returns:
        int: The value returned by the bulk ``update()`` call (the count of
            affected rows).
    """
    needs_change = db.query(TestTemplate).filter(TestTemplate.is_active != activate)
    return needs_change.update({TestTemplate.is_active: activate})
# Define function get_templates_by_technique
def get_templates_by_technique(db: Session, mitre_id: str) -> list:
    """List the active templates mapped to one MITRE technique, sorted by name.

    Args:
        db (Session): Active SQLAlchemy database session.
        mitre_id (str): Value matched against
            ``TestTemplate.mitre_technique_id``.

    Returns:
        list: Matching active ``TestTemplate`` rows ordered by ``name``.
    """
    for_technique = TestTemplate.mitre_technique_id == mitre_id
    only_active = TestTemplate.is_active == True  # noqa: E712
    matching = db.query(TestTemplate).filter(for_technique, only_active)
    return matching.order_by(TestTemplate.name).all()
# Define function get_template_or_raise
def get_template_or_raise(db: Session, template_id: uuid.UUID) -> TestTemplate:
    """Look up a template by ID, failing loudly when it is missing.

    Args:
        db (Session): Active SQLAlchemy database session.
        template_id (uuid.UUID): UUID of the template to retrieve.

    Returns:
        TestTemplate: The matching template ORM object.

    Raises:
        EntityNotFoundError: If no template has the given ID.
    """
    found = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test template", str(template_id))
# Define function create_template
def create_template(db: Session, **fields: object) -> TestTemplate:
    """Build a template from keyword args and stage it. Does NOT commit.

    Args:
        db (Session): Active SQLAlchemy database session.
        **fields (object): Constructor arguments for ``TestTemplate``
            (e.g. ``payload.model_dump()``).

    Returns:
        TestTemplate: The new template, added to the session but neither
            flushed nor committed.
    """
    new_template = TestTemplate(**fields)
    db.add(new_template)
    return new_template
# Define function update_template
def update_template(db: Session, template_id: uuid.UUID, **fields: object) -> TestTemplate:
    """Apply field updates to an existing template. Does NOT commit.

    Args:
        db (Session): Active SQLAlchemy database session.
        template_id (uuid.UUID): UUID of the template to update.
        **fields (object): Attribute names and new values. Names the
            template does not already have are silently ignored.

    Returns:
        TestTemplate: The updated template ORM object.

    Raises:
        EntityNotFoundError: If no template has the given ID.
    """
    template = get_template_or_raise(db, template_id)
    for attr_name, new_value in fields.items():
        if not hasattr(template, attr_name):
            # Unknown attribute: skip rather than create a new one.
            continue
        setattr(template, attr_name, new_value)
    return template
# Define function toggle_template_active
def toggle_template_active(db: Session, template_id: uuid.UUID) -> TestTemplate:
    """Flip a template's ``is_active`` flag. Does NOT commit.

    Args:
        db (Session): Active SQLAlchemy database session.
        template_id (uuid.UUID): UUID of the template to toggle.

    Returns:
        TestTemplate: The template with its ``is_active`` flag negated.

    Raises:
        EntityNotFoundError: If no template has the given ID.
    """
    template = get_template_or_raise(db, template_id)
    was_active = template.is_active
    template.is_active = not was_active
    return template
# Define function soft_delete_template
def soft_delete_template(db: Session, template_id: uuid.UUID) -> None:
    """Soft-delete a template by clearing its ``is_active`` flag. Does NOT commit.

    The row itself is kept; only ``is_active`` is set to ``False``.

    Args:
        db (Session): Active SQLAlchemy database session.
        template_id (uuid.UUID): UUID of the template to deactivate.

    Raises:
        EntityNotFoundError: If no template has the given ID.
    """
    get_template_or_raise(db, template_id).is_active = False
+408 -3
View File
@@ -12,11 +12,19 @@ an audit-log entry. The caller (router) is responsible for committing the
session via the Unit of Work pattern.
"""
# Import logging
import logging
# Import uuid
import uuid
# Import datetime from datetime
from datetime import datetime
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import settings from app.config
from app.config import settings
from app.domain.exceptions import InvalidOperationError, InvalidTransitionError
from app.domain.test_entity import TestEntity
@@ -27,6 +35,31 @@ from app.models.user import User
from app.services.audit_service import log_action
from app.services.notification_service import notify_test_state_change, create_notification
# Import InvalidOperationError from app.domain.exceptions
from app.domain.exceptions import InvalidOperationError
# Import TestEntity from app.domain.test_entity
from app.domain.test_entity import TestEntity
# Import TestState from app.models.enums
from app.models.enums import TestState
# Import Test from app.models.test
from app.models.test import Test
# Import User from app.models.user
from app.models.user import User
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action
# Import from app.services.notification_service
from app.services.notification_service import (
create_notification,
notify_test_state_change,
)
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -50,18 +83,35 @@ VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
def can_transition(test: Test, target_state: TestState) -> bool:
    """Report whether *test* may move to *target_state*.

    Args:
        test (Test): The test whose current state is being checked. Its
            ``state`` may be a ``TestState`` member or a raw value accepted
            by the ``TestState`` constructor.
        target_state (TestState): The state to transition to.

    Returns:
        bool: ``True`` if ``target_state`` is listed for the current state in
            ``VALID_TRANSITIONS``; ``False`` otherwise, including when the
            current state has no entry.
    """
    current = test.state
    if not isinstance(current, TestState):
        # Normalise raw values to the enum before the lookup.
        current = TestState(current)
    allowed_targets = VALID_TRANSITIONS.get(current, [])
    return target_state in allowed_targets
# Define function transition_state
def transition_state(
# Entry: db
db: Session,
# Entry: test
test: Test,
# Entry: target_state
target_state: TestState,
# Entry: user
user: User,
*,
# Entry: action_name
action_name: str = "transition_state",
# Entry: extra_details
extra_details: dict | None = None,
) -> Test:
"""Validate and perform a state transition, log it, and flush.
@@ -71,36 +121,71 @@ def transition_state(
when the transition is illegal. The entity is authoritative for which
transitions are valid; the module-level ``VALID_TRANSITIONS`` dict is
kept temporarily for backward compatibility of ``can_transition()``.
Args:
db (Session): Active SQLAlchemy database session.
test (Test): The test ORM object to transition.
target_state (TestState): Desired next state.
user (User): The user performing the transition (logged in audit).
action_name (str): Audit log action label; defaults to
``"transition_state"``.
extra_details (dict | None): Optional extra key-value pairs merged
into the audit log details.
Returns:
Test: The mutated test ORM object (state updated, flushed).
"""
# Assign entity = TestEntity.from_orm(test)
entity = TestEntity.from_orm(test)
# Assign previous_state = entity.transition_to(target_state)
previous_state = entity.transition_to(target_state)
# Assign test.state = entity.state
test.state = entity.state
# Flush changes to DB without committing the transaction
db.flush()
# Assign details = {
details: dict = {
# Literal argument value
"previous_state": previous_state,
# Literal argument value
"new_state": target_state.value,
# Literal argument value
"test_name": test.name,
# Literal argument value
"technique_id": str(test.technique_id),
}
# Check: extra_details
if extra_details:
# Call details.update()
details.update(extra_details)
# Call log_action()
log_action(
db,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: action
action=action_name,
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test.id,
# Keyword argument: details
details=details,
)
# Attempt the following; catch errors below
try:
# Call notify_test_state_change()
notify_test_state_change(db, test, target_state.value)
# Handle Exception
except Exception as e:
# Log warning: "Notification failed for test %s: %s", test.id, e
logger.warning("Notification failed for test %s: %s", test.id, e, exc_info=True)
# Return test
return test
@@ -112,27 +197,44 @@ def transition_state(
def start_execution(db: Session, test: Test, user: User) -> Test:
"""Move from ``draft`` → ``red_executing``."""
entity = TestEntity.from_orm(test)
# Call entity.start_execution()
entity.start_execution()
# Call entity.apply_to()
entity.apply_to(test)
# Flush changes to DB without committing the transaction
db.flush()
# Call log_action()
log_action(
db,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: action
action="start_execution",
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test.id,
# Keyword argument: details
details={
# Literal argument value
"previous_state": "draft",
# Literal argument value
"new_state": test.state.value,
# Literal argument value
"test_name": test.name,
# Literal argument value
"technique_id": str(test.technique_id),
},
)
# Attempt the following; catch errors below
try:
# Call notify_test_state_change()
notify_test_state_change(db, test, test.state.value)
# Handle Exception
except Exception as e:
# Log warning: "Notification failed for test %s: %s", test.id, e
logger.warning("Notification failed for test %s: %s", test.id, e, exc_info=True)
try:
@@ -144,6 +246,7 @@ def start_execution(db: Session, test: Test, user: User) -> Test:
return test
# Define function submit_red_evidence
def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
"""Move from ``red_executing`` → ``blue_evaluating``.
@@ -151,6 +254,14 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
Requires at least one Red Team evidence file to be uploaded.
Stops the Red Team timer and creates an automatic worklog.
Starts the Blue Team timer by recording ``blue_started_at``.
Args:
db (Session): Active SQLAlchemy database session.
test (Test): The test whose red-team evidence is being submitted.
user (User): The red-team user submitting the evidence.
Returns:
Test: The mutated test with state advanced and blue timer started.
"""
# Evidence is mandatory before submitting
red_evidence_count = (
@@ -167,29 +278,42 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
# Auto-resume if paused
paused_extra = 0
# Check: test.paused_at is not None
if test.paused_at is not None:
# Assign paused_extra = max(int((now - test.paused_at).total_seconds()), 0)
paused_extra = max(int((now - test.paused_at).total_seconds()), 0)
# Assign test.paused_at = None
test.paused_at = None
# Assign test = transition_state(
test = transition_state(
db, test, TestState.blue_evaluating, user,
# Keyword argument: action_name
action_name="submit_red_evidence",
)
# Create automatic worklog for Red Team phase (subtract paused time)
_create_phase_worklog(
db,
# Keyword argument: test
test=test,
# Keyword argument: user
user=user,
# Keyword argument: phase_started_at
phase_started_at=test.red_started_at,
# Keyword argument: phase_ended_at
phase_ended_at=now,
# Keyword argument: paused_seconds
paused_seconds=(test.red_paused_seconds or 0) + paused_extra,
# Keyword argument: activity_type
activity_type="red_team_execution",
# Keyword argument: description
description=f"Red Team execution: {test.name}",
)
# Start Blue Team timer
test.blue_started_at = now
# Assign test.blue_paused_seconds = 0
test.blue_paused_seconds = 0
try:
@@ -234,6 +358,7 @@ def start_blue_work(db: Session, test: Test, user: User) -> Test:
return test
# Define function submit_blue_evidence
def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
"""Move from ``blue_evaluating`` → ``in_review``.
@@ -258,12 +383,17 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
# Auto-resume if paused
paused_extra = 0
# Check: test.paused_at is not None
if test.paused_at is not None:
# Assign paused_extra = max(int((now - test.paused_at).total_seconds()), 0)
paused_extra = max(int((now - test.paused_at).total_seconds()), 0)
# Assign test.paused_at = None
test.paused_at = None
# Assign test = transition_state(
test = transition_state(
db, test, TestState.in_review, user,
# Keyword argument: action_name
action_name="submit_blue_evidence",
)
@@ -272,12 +402,17 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
# Tempo worklog reflects real working time, not just queue time.
_create_phase_worklog(
db,
# Keyword argument: test
test=test,
# Keyword argument: user
user=user,
phase_started_at=test.blue_work_started_at or test.blue_started_at,
phase_ended_at=now,
# Keyword argument: paused_seconds
paused_seconds=(test.blue_paused_seconds or 0) + paused_extra,
# Keyword argument: activity_type
activity_type="blue_team_evaluation",
# Keyword argument: description
description=f"Blue Team evaluation: {test.name}",
)
@@ -290,69 +425,125 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
return test
# Define function pause_timer
def pause_timer(db: Session, test: Test, user: User) -> Test:
    """Pause the timer of the phase that is currently running.

    Pausing is only permitted while the test is in ``red_executing`` or
    ``blue_evaluating`` and no pause is already in effect.

    Args:
        db (Session): Active SQLAlchemy database session.
        test (Test): The test whose phase timer should be paused.
        user (User): The user requesting the pause (recorded in the audit log).

    Returns:
        Test: The same test object with ``paused_at`` set to the current
            ``datetime.utcnow()`` value.

    Raises:
        InvalidOperationError: If the test is in any other state, or if
            ``paused_at`` is already set.
    """
    pausable_states = (TestState.red_executing, TestState.blue_evaluating)
    if test.state not in pausable_states:
        raise InvalidOperationError(
            f"Cannot pause timer in '{test.state.value}' state"
        )

    already_paused = test.paused_at is not None
    if already_paused:
        raise InvalidOperationError("Timer is already paused")

    # Mark the moment the pause began; resume_timer() measures from here.
    test.paused_at = datetime.utcnow()

    audit_details = {"state": test.state.value}
    log_action(
        db,
        user_id=user.id,
        action="pause_timer",
        entity_type="test",
        entity_id=test.id,
        details=audit_details,
    )
    return test
# Define function resume_timer
def resume_timer(db: Session, test: Test, user: User) -> Test:
    """Resume a paused phase timer and book the time spent paused.

    The length of the pause is added to ``red_paused_seconds`` or
    ``blue_paused_seconds`` depending on the test's current state. In any
    other state no counter is updated, but the pause is still cleared.

    Args:
        db (Session): Active SQLAlchemy database session.
        test (Test): The paused test to resume.
        user (User): The user resuming the timer (recorded in the audit log).

    Returns:
        Test: The same test object with ``paused_at`` cleared and the
            matching pause counter increased.

    Raises:
        InvalidOperationError: If ``paused_at`` is not set.
    """
    pause_began = test.paused_at
    if pause_began is None:
        raise InvalidOperationError("Timer is not paused")

    resumed_at = datetime.utcnow()
    elapsed = (resumed_at - pause_began).total_seconds()
    # Clamp at zero so a negative interval never reduces the counters.
    paused_seconds = max(int(elapsed), 0)

    if test.state == TestState.red_executing:
        already_paused = test.red_paused_seconds or 0
        test.red_paused_seconds = already_paused + paused_seconds
    elif test.state == TestState.blue_evaluating:
        already_paused = test.blue_paused_seconds or 0
        test.blue_paused_seconds = already_paused + paused_seconds

    test.paused_at = None

    audit_details = {"paused_seconds": paused_seconds, "state": test.state.value}
    log_action(
        db,
        user_id=user.id,
        action="resume_timer",
        entity_type="test",
        entity_id=test.id,
        details=audit_details,
    )
    return test
# Define function _create_phase_worklog
def _create_phase_worklog(
# Entry: db
db: Session,
*,
# Entry: test
test: Test,
# Entry: user
user: User,
# Entry: phase_started_at
phase_started_at: datetime | None,
# Entry: phase_ended_at
phase_ended_at: datetime,
# Entry: paused_seconds
paused_seconds: int = 0,
# Entry: activity_type
activity_type: str,
# Entry: description
description: str,
) -> None:
"""Create an automatic, integrity-hashed worklog for a completed phase.
@@ -360,32 +551,64 @@ def _create_phase_worklog(
Subtracts accumulated *paused_seconds* from the gross elapsed time
so the worklog reflects only active working time.
Also triggers Tempo sync if the test has a Jira link.
Args:
db (Session): Active SQLAlchemy database session.
test (Test): The test for which the worklog is being created.
user (User): The user attributed to the worklog.
phase_started_at (datetime | None): Timestamp when the phase began;
if ``None`` the worklog is skipped with a warning.
phase_ended_at (datetime): Timestamp when the phase ended.
paused_seconds (int): Accumulated paused time in seconds to subtract
from gross elapsed time.
activity_type (str): Worklog activity type label (e.g.
``"red_team_execution"``).
description (str): Human-readable description for the worklog.
"""
# Check: not phase_started_at
if not phase_started_at:
# Log warning:
logger.warning(
# Literal argument value
"No phase start timestamp for test %s (%s), skipping worklog",
test.id, activity_type,
)
# Return control to caller
return
# Assign gross_seconds = int((phase_ended_at - phase_started_at).total_seconds())
gross_seconds = int((phase_ended_at - phase_started_at).total_seconds())
# Assign duration_seconds = max(gross_seconds - paused_seconds, 1)
duration_seconds = max(gross_seconds - paused_seconds, 1)
# Attempt the following; catch errors below
try:
# Import create_worklog from app.services.worklog_service
from app.services.worklog_service import create_worklog
# Assign wl = create_worklog(
wl = create_worklog(
db,
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test.id,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: activity_type
activity_type=activity_type,
# Keyword argument: started_at
started_at=phase_started_at,
# Keyword argument: ended_at
ended_at=phase_ended_at,
# Keyword argument: duration_seconds
duration_seconds=duration_seconds,
# Keyword argument: description
description=description,
)
# Log info:
logger.info(
# Literal argument value
"Auto-worklog created for test %s: %s, %ds (worklog %s)",
test.id, activity_type, duration_seconds, wl.id,
)
@@ -393,6 +616,7 @@ def _create_phase_worklog(
# Sync to Tempo: only red_team_execution, using the already-computed
# duration so the Tempo entry is identical to the Aegis worklog.
try:
# Import auto_log_test_worklog from app.services.tempo_service
from app.services.tempo_service import auto_log_test_worklog
tempo_result = auto_log_test_worklog(db, test, user, activity_type, duration_seconds)
if tempo_result and isinstance(tempo_result, dict):
@@ -400,17 +624,26 @@ def _create_phase_worklog(
wl.tempo_worklog_id = str(tempo_result.get("tempoWorklogId", ""))
db.flush()
except Exception as e:
# Log warning: "Tempo sync failed for worklog: %s", e, exc_info=T
logger.warning("Tempo sync failed for worklog: %s", e, exc_info=True)
# Handle Exception
except Exception as e:
# Log error: "Failed to create auto-worklog for test %s: %s", t
logger.error("Failed to create auto-worklog for test %s: %s", test.id, e, exc_info=True)
# Define function validate_as_red_lead
def validate_as_red_lead(
# Entry: db
db: Session,
# Entry: test
test: Test,
# Entry: user
user: User,
# Entry: validation_status
validation_status: str,
# Entry: notes
notes: str | None = None,
) -> Test:
"""Record Red Lead's validation decision.
@@ -418,21 +651,45 @@ def validate_as_red_lead(
Delegates validation rules and state mutation entirely to
:meth:`TestEntity.validate_red`. If both leads have voted the
entity will also advance the test to ``validated`` or ``rejected``.
Args:
db (Session): Active SQLAlchemy database session.
test (Test): The test being reviewed.
user (User): The red-lead user casting their vote.
validation_status (str): Validation decision, e.g. ``"approved"`` or
``"rejected"``.
notes (str | None): Optional freeform notes explaining the decision.
Returns:
Test: The mutated test with red-lead validation fields set.
"""
# Assign entity = TestEntity.from_orm(test)
entity = TestEntity.from_orm(test)
# Call entity.validate_red()
entity.validate_red(validation_status, by=user.id, notes=notes)
# Call entity.apply_to()
entity.apply_to(test)
# Flush changes to DB without committing the transaction
db.flush()
# Call log_action()
log_action(
db,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: action
action="validate_as_red_lead",
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test.id,
# Keyword argument: details
details={
# Literal argument value
"validation_status": validation_status,
# Literal argument value
"notes": notes,
# Literal argument value
"technique_id": str(test.technique_id),
},
)
@@ -441,11 +698,17 @@ def validate_as_red_lead(
return test
# Define function validate_as_blue_lead
def validate_as_blue_lead(
# Entry: db
db: Session,
# Entry: test
test: Test,
# Entry: user
user: User,
# Entry: validation_status
validation_status: str,
# Entry: notes
notes: str | None = None,
) -> Test:
"""Record Blue Lead's validation decision.
@@ -453,21 +716,45 @@ def validate_as_blue_lead(
Delegates validation rules and state mutation entirely to
:meth:`TestEntity.validate_blue`. If both leads have voted the
entity will also advance the test to ``validated`` or ``rejected``.
Args:
db (Session): Active SQLAlchemy database session.
test (Test): The test being reviewed.
user (User): The blue-lead user casting their vote.
validation_status (str): Validation decision, e.g. ``"approved"`` or
``"rejected"``.
notes (str | None): Optional freeform notes explaining the decision.
Returns:
Test: The mutated test with blue-lead validation fields set.
"""
# Assign entity = TestEntity.from_orm(test)
entity = TestEntity.from_orm(test)
# Call entity.validate_blue()
entity.validate_blue(validation_status, by=user.id, notes=notes)
# Call entity.apply_to()
entity.apply_to(test)
# Flush changes to DB without committing the transaction
db.flush()
# Call log_action()
log_action(
db,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: action
action="validate_as_blue_lead",
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test.id,
# Keyword argument: details
details={
# Literal argument value
"validation_status": validation_status,
# Literal argument value
"notes": notes,
# Literal argument value
"technique_id": str(test.technique_id),
},
)
@@ -476,35 +763,61 @@ def validate_as_blue_lead(
return test
# Define function check_dual_validation
def check_dual_validation(db: Session, test: Test) -> Test:
    """Run the dual-validation check on *test* and dispatch its side effects.

    The decision logic lives in :meth:`TestEntity.check_dual_validation`;
    this function only builds the entity from the ORM row, copies the
    entity's result back with ``apply_to`` and hands the entity to
    ``_dispatch_dual_validation_effects``. It never assigns ``test.state``
    directly.

    Args:
        db (Session): Active SQLAlchemy database session.
        test (Test): The test to evaluate.

    Returns:
        Test: The same test object after the entity's changes were applied.
    """
    domain_test = TestEntity.from_orm(test)
    domain_test.check_dual_validation()

    # Copy the entity's state back onto the ORM object before reacting to
    # the events the entity recorded.
    domain_test.apply_to(test)
    _dispatch_dual_validation_effects(db, test, domain_test)

    return test
# Define function _dispatch_dual_validation_effects
def _dispatch_dual_validation_effects(
db: Session, test: Test, entity: TestEntity, actor: User | None = None
) -> None:
"""Dispatch side effects (notifications, cache, Jira) based on domain events."""
for event in entity.events:
# Check: event.name == "dual_validation_approved"
if event.name == "dual_validation_approved":
# Attempt the following; catch errors below
try:
# Import invalidate from app.services.score_cache
from app.services.score_cache import invalidate
# Call invalidate()
invalidate()
# Handle Exception
except Exception as e:
# Log warning: "Score cache invalidation failed: %s", e, exc_info
logger.warning("Score cache invalidation failed: %s", e, exc_info=True)
# Attempt the following; catch errors below
try:
# Call notify_test_state_change()
notify_test_state_change(db, test, "validated")
# Handle Exception
except Exception as e:
# Log warning:
logger.warning(
# Literal argument value
"Notification failed for test %s (validated): %s",
test.id, e, exc_info=True,
)
@@ -516,10 +829,15 @@ def _dispatch_dual_validation_effects(
logger.warning("Jira push failed for test %s: %s", test.id, e, exc_info=True)
elif event.name == "dual_validation_rejected":
# Attempt the following; catch errors below
try:
# Call notify_test_state_change()
notify_test_state_change(db, test, "rejected")
# Handle Exception
except Exception as e:
# Log warning:
logger.warning(
# Literal argument value
"Notification failed for test %s (rejected): %s",
test.id, e, exc_info=True,
)
@@ -585,6 +903,7 @@ def _notify_validation_conflict(db: Session, test: Test, actor: User | None) ->
)
# Define function handle_remediation_completed
def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None:
"""Create a re-test when remediation is completed.
@@ -594,121 +913,199 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test |
Prevents infinite loops by enforcing ``MAX_RETEST_COUNT``.
Returns the new retest or *None* if the limit was reached.
Args:
db (Session): Active SQLAlchemy database session.
test (Test): The test whose remediation was completed.
user (User): The user triggering the remediation completion.
Returns:
Test | None: The newly created retest, or ``None`` if the maximum
retest count has been reached.
"""
# Always reference the original test, not an intermediate retest
original_test_id = test.retest_of or test.id
# Check: test.retest_count >= settings.MAX_RETEST_COUNT
if test.retest_count >= settings.MAX_RETEST_COUNT:
# Max retests reached — notify and bail out
if test.created_by:
# Call create_notification()
create_notification(
db,
# Keyword argument: user_id
user_id=test.created_by,
# Keyword argument: type
type="max_retests_reached",
# Keyword argument: title
title="Maximum retests reached",
# Keyword argument: message
message=(
f'Test "{test.name}" has reached the maximum of '
f'{settings.MAX_RETEST_COUNT} retests. Manual review required.'
),
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test.id,
)
# Call log_action()
log_action(
db,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: action
action="max_retests_reached",
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=test.id,
# Keyword argument: details
details={
# Literal argument value
"retest_count": test.retest_count,
# Literal argument value
"max_allowed": settings.MAX_RETEST_COUNT,
# Literal argument value
"original_test_id": str(original_test_id),
},
)
# Return None
return None
# Assign retest = Test(
retest = Test(
# Keyword argument: technique_id
technique_id=test.technique_id,
# Keyword argument: name
name=f"[Retest #{test.retest_count + 1}] {test.name.replace(f'[Retest #{test.retest_count}] ', '')}",
# Keyword argument: description
description=test.description,
# Keyword argument: platform
platform=test.platform,
# Keyword argument: procedure_text
procedure_text=test.procedure_text,
# Keyword argument: tool_used
tool_used=test.tool_used,
# Keyword argument: state
state=TestState.draft,
# Keyword argument: created_by
created_by=test.created_by,
# Keyword argument: retest_of
retest_of=original_test_id,
# Keyword argument: retest_count
retest_count=test.retest_count + 1,
)
# Stage new record(s) for database insertion
db.add(retest)
# Flush changes to DB without committing the transaction
db.flush()
# Call log_action()
log_action(
db,
# Keyword argument: user_id
user_id=user.id,
# Keyword argument: action
action="create_retest",
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=retest.id,
# Keyword argument: details
details={
# Literal argument value
"original_test_id": str(original_test_id),
# Literal argument value
"retest_number": retest.retest_count,
# Literal argument value
"source_test_id": str(test.id),
},
)
# Notify the test creator and any red_tech users
if test.created_by:
# Call create_notification()
create_notification(
db,
# Keyword argument: user_id
user_id=test.created_by,
# Keyword argument: type
type="retest_created",
# Keyword argument: title
title="Re-test created",
# Keyword argument: message
message=(
f'A re-test has been automatically created for "{test.name}" '
f'after remediation was completed.'
),
# Keyword argument: entity_type
entity_type="test",
# Keyword argument: entity_id
entity_id=retest.id,
)
# Flush changes to DB without committing the transaction
db.flush()
# Return retest
return retest
def get_retest_chain(db: Session, test_id: uuid.UUID) -> list[Test]:
    """Return the full chain of retests for a given test.

    Includes the original test and all subsequent retests, ordered
    by retest_count.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of any test in the retest chain. Values
            that are not already ``UUID`` instances are coerced through
            ``uuid.UUID(str(test_id))``.

    Returns:
        list[Test]: The original test followed by all its retests in
            ascending retest-count order. Returns an empty list if the
            test is not found, and ``[test]`` alone if the test points at
            an original that no longer exists.
    """
    import uuid as _uuid

    tid = test_id if isinstance(test_id, _uuid.UUID) else _uuid.UUID(str(test_id))

    test = db.query(Test).filter(Test.id == tid).first()
    if not test:
        return []

    # Retests always reference the chain root via ``retest_of``; a test with
    # no ``retest_of`` is itself the root.
    original_id = test.retest_of or test.id

    if original_id == test.id:
        # The row we already hold is the root, so no second lookup is needed.
        original = test
    else:
        original = db.query(Test).filter(Test.id == original_id).first()
        if not original:
            return [test]

    retests = (
        db.query(Test)
        .filter(Test.retest_of == original_id)
        .order_by(Test.retest_count)
        .all()
    )

    return [original] + retests
# Define function reopen_test
def reopen_test(db: Session, test: Test, user: User) -> Test:
"""Move a ``rejected`` test back to ``draft`` for continued work.
@@ -719,20 +1116,27 @@ def reopen_test(db: Session, test: Test, user: User) -> Test:
re-validate the updated submission. Phase timing is reset so the timer
starts fresh for the new execution attempt.
"""
# Assign test = transition_state(
test = transition_state(
db, test, TestState.draft, user,
# Keyword argument: action_name
action_name="reopen_test",
)
# Clear validation DECISIONS — leads must re-validate the new attempt.
# Rejection NOTES are intentionally kept so teams see what needs fixing.
test.red_validation_status = None
# Assign test.red_validated_by = None
test.red_validated_by = None
# Assign test.red_validated_at = None
test.red_validated_at = None
# test.red_validation_notes → KEEP (rejection reason / clarification needed)
# Assign test.blue_validation_status = None
test.blue_validation_status = None
# Assign test.blue_validated_by = None
test.blue_validated_by = None
# Assign test.blue_validated_at = None
test.blue_validated_at = None
# test.blue_validation_notes → KEEP (rejection reason / clarification needed)
@@ -749,4 +1153,5 @@ def reopen_test(db: Session, test: Test, user: User) -> Test:
except Exception as e:
logger.warning("Jira push failed for test %s: %s", test.id, e, exc_info=True)
# Return test
return test
@@ -26,23 +26,49 @@ Deduplication by ``mitre_id`` for ThreatActor and by the unique
constraint ``(threat_actor_id, technique_id)`` for ThreatActorTechnique.
"""
# Import io
import io
# Import json
import json
# Import logging
import logging
# Import shutil
import shutil
# Import tempfile
import tempfile
# Import zipfile
import zipfile
# Import datetime from datetime
from datetime import datetime
# Import Path from pathlib
from pathlib import Path
# Import requests
import requests as _requests
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.technique import Technique
# Import DataSource from app.models.data_source
from app.models.data_source import DataSource
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -50,11 +76,15 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
MITRE_CTI_ZIP_URL = (
# Literal argument value
"https://github.com/mitre/cti"
# Literal argument value
"/archive/refs/heads/master.zip"
)
# Assign _DOWNLOAD_TIMEOUT = 300
_DOWNLOAD_TIMEOUT = 300
# Assign _ZIP_ROOT_PREFIX = "cti-master"
_ZIP_ROOT_PREFIX = "cti-master"
@@ -65,54 +95,86 @@ _ZIP_ROOT_PREFIX = "cti-master"
def _download_zip(url: str = MITRE_CTI_ZIP_URL) -> bytes:
    """Fetch the MITRE CTI archive over HTTP and return its raw bytes.

    Args:
        url (str): Location of the ZIP archive; defaults to
            ``MITRE_CTI_ZIP_URL``.

    Returns:
        bytes: The complete response body.

    Raises:
        Whatever ``raise_for_status()`` raises for an error HTTP status, and
        any error raised by the request itself (the call is bounded by
        ``_DOWNLOAD_TIMEOUT`` seconds).
    """
    logger.info("Downloading MITRE CTI ZIP from %s", url)

    response = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
    response.raise_for_status()

    # Reading ``.content`` pulls the whole body into memory.
    payload = response.content
    size_mb = len(payload) / (1024 * 1024)
    logger.info("Downloaded %.1f MB", size_mb)
    return payload
# Define function _extract_zip_and_load_bundle
def _extract_zip_and_load_bundle(zip_bytes: bytes, dest: str) -> dict:
    """Unpack the CTI archive into *dest* and load the enterprise STIX bundle.

    Args:
        zip_bytes (bytes): Raw bytes of the ZIP archive.
        dest (str): Directory the archive is extracted into.

    Returns:
        dict: The parsed content of
            ``<dest>/<_ZIP_ROOT_PREFIX>/enterprise-attack/enterprise-attack.json``.

    Raises:
        FileNotFoundError: If the bundle file is not present after extraction.
    """
    archive_stream = io.BytesIO(zip_bytes)
    with zipfile.ZipFile(archive_stream) as archive:
        archive.extractall(dest)

    bundle_path = Path(dest).joinpath(
        _ZIP_ROOT_PREFIX, "enterprise-attack", "enterprise-attack.json"
    )
    if not bundle_path.is_file():
        raise FileNotFoundError(
            f"STIX bundle not found at {bundle_path}"
        )

    logger.info("Loading STIX bundle from %s", bundle_path)
    with bundle_path.open("r", encoding="utf-8") as bundle_file:
        bundle = json.load(bundle_file)

    # The object count is only logged; the whole bundle is returned.
    object_count = len(bundle.get("objects", []))
    logger.info("Loaded %d STIX objects", object_count)
    return bundle
# Define function _extract_mitre_id
def _extract_mitre_id(external_references: list) -> str | None:
"""Extract the MITRE ATT&CK ID from external_references."""
# Check: not isinstance(external_references, list)
if not isinstance(external_references, list):
# Return None
return None
# Iterate over external_references
for ref in external_references:
# Check: isinstance(ref, dict) and ref.get("source_name") == "mitre-attack"
if isinstance(ref, dict) and ref.get("source_name") == "mitre-attack":
# Return ref.get("external_id")
return ref.get("external_id")
# Return None
return None
# Define function _extract_mitre_url
def _extract_mitre_url(external_references: list) -> str | None:
"""Extract the MITRE ATT&CK URL from external_references."""
# Check: not isinstance(external_references, list)
if not isinstance(external_references, list):
# Return None
return None
# Iterate over external_references
for ref in external_references:
# Check: isinstance(ref, dict) and ref.get("source_name") == "mitre-attack"
if isinstance(ref, dict) and ref.get("source_name") == "mitre-attack":
# Return ref.get("url")
return ref.get("url")
# Return None
return None
@@ -316,25 +378,41 @@ def _infer_motivation_from_description(description: str) -> str | None:
def _parse_intrusion_sets(objects: list) -> list[dict]:
"""Parse STIX intrusion-set objects into ThreatActor dicts."""
# Assign actors = []
actors = []
# Iterate over objects
for obj in objects:
# Check: obj.get("type") != "intrusion-set"
if obj.get("type") != "intrusion-set":
# Skip to the next loop iteration
continue
# Check: obj.get("revoked")
if obj.get("revoked"):
# Skip to the next loop iteration
continue
# Assign ext_refs = obj.get("external_references", [])
ext_refs = obj.get("external_references", [])
# Assign mitre_id = _extract_mitre_id(ext_refs)
mitre_id = _extract_mitre_id(ext_refs)
# Assign mitre_url = _extract_mitre_url(ext_refs)
mitre_url = _extract_mitre_url(ext_refs)
# Assign name = obj.get("name", "").strip()
name = obj.get("name", "").strip()
# Check: not name
if not name:
# Skip to the next loop iteration
continue
# Assign aliases = obj.get("aliases", [])
aliases = obj.get("aliases", [])
# Check: isinstance(aliases, list) and name in aliases
if isinstance(aliases, list) and name in aliases:
# Assign aliases = [a for a in aliases if a != name]
aliases = [a for a in aliases if a != name]
# Assign description = obj.get("description", "")
description = obj.get("description", "")
# Derive motivation: curated override > STIX field > description inference
@@ -348,80 +426,129 @@ def _parse_intrusion_sets(objects: list) -> list[dict]:
# Extract references (non-MITRE)
references = []
# Iterate over ext_refs
for ref in ext_refs:
# Check: isinstance(ref, dict) and ref.get("source_name") != "mitre-attack"
if isinstance(ref, dict) and ref.get("source_name") != "mitre-attack":
# Call references.append()
references.append({
# Literal argument value
"source": ref.get("source_name", ""),
# Literal argument value
"url": ref.get("url", ""),
# Literal argument value
"description": ref.get("description", ""),
})
# Call actors.append()
actors.append({
# Literal argument value
"stix_id": obj.get("id"), # e.g. "intrusion-set--abc123"
# Literal argument value
"mitre_id": mitre_id,
# Literal argument value
"name": name,
# Literal argument value
"aliases": aliases if aliases else [],
# Literal argument value
"description": description,
# Literal argument value
"mitre_url": mitre_url,
# Literal argument value
"references": references[:20], # cap to avoid bloat
# Literal argument value
"first_seen": obj.get("first_seen"),
# Literal argument value
"last_seen": obj.get("last_seen"),
"motivation": motivation,
"sophistication": sophistication,
})
# Log info: "Parsed %d intrusion-sets (threat actors)", len(ac
logger.info("Parsed %d intrusion-sets (threat actors)", len(actors))
# Return actors
return actors
def _parse_relationships(objects: list) -> list[dict]:
    """Parse STIX ``uses`` relationships linking intrusion-sets to attack-patterns.

    An object is retained only when all of the following hold:

    * ``type`` is ``"relationship"`` and ``relationship_type`` is ``"uses"``;
    * it is not flagged ``revoked``;
    * ``source_ref`` starts with ``"intrusion-set--"`` and ``target_ref``
      starts with ``"attack-pattern--"``.

    Args:
        objects: STIX objects (dicts) taken from the bundle.

    Returns:
        One dict per retained relationship with the keys ``source_ref``,
        ``target_ref`` and ``description`` (``""`` when the object has no
        ``description`` key).
    """
    relationships = []
    for obj in objects:
        if obj.get("type") != "relationship":
            continue
        if obj.get("relationship_type") != "uses":
            continue
        if obj.get("revoked"):
            continue

        source_ref = obj.get("source_ref", "")
        target_ref = obj.get("target_ref", "")

        # A malformed bundle may carry a null or non-string ref; such a
        # relationship cannot be resolved anyway, so skip it instead of
        # crashing on ``.startswith``.
        if not isinstance(source_ref, str) or not isinstance(target_ref, str):
            continue

        # We want intrusion-set → attack-pattern
        if not source_ref.startswith("intrusion-set--"):
            continue
        if not target_ref.startswith("attack-pattern--"):
            continue

        relationships.append({
            "source_ref": source_ref,
            "target_ref": target_ref,
            "description": obj.get("description", ""),
        })

    logger.info("Parsed %d uses-relationships (actor→technique)", len(relationships))
    return relationships
def _build_attack_pattern_map(objects: list) -> dict[str, str]:
    """Map each STIX attack-pattern ID to its MITRE technique ID.

    Example result: ``{"attack-pattern--abc123": "T1059.001"}``.

    Revoked attack-patterns are ignored, as are entries for which either the
    STIX ``id`` or the MITRE ID (as returned by ``_extract_mitre_id``) is
    empty or missing.
    """
    live_patterns = (
        item
        for item in objects
        if item.get("type") == "attack-pattern" and not item.get("revoked")
    )
    id_pairs = (
        (item.get("id", ""), _extract_mitre_id(item.get("external_references", [])))
        for item in live_patterns
    )
    mapping = {stix: mitre for stix, mitre in id_pairs if stix and mitre}

    logger.info("Built attack-pattern map with %d entries", len(mapping))
    return mapping
@@ -435,24 +562,31 @@ def sync(db: Session) -> dict:
Returns a summary dict.
"""
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_mitre_cti_")
tmp_dir = tempfile.mkdtemp(prefix="aegis_mitre_cti_")
# Attempt the following; catch errors below
try:
# Assign zip_bytes = _download_zip()
zip_bytes = _download_zip()
# Assign bundle = _extract_zip_and_load_bundle(zip_bytes, tmp_dir)
bundle = _extract_zip_and_load_bundle(zip_bytes, tmp_dir)
# Always execute this cleanup block
finally:
# Call shutil.rmtree()
shutil.rmtree(tmp_dir, ignore_errors=True)
# Log info: "Cleaned up temp directory %s", tmp_dir
logger.info("Cleaned up temp directory %s", tmp_dir)
# Assign objects = bundle.get("objects", [])
objects = bundle.get("objects", [])
# Step 1: Parse data
actor_dicts = _parse_intrusion_sets(objects)
# Assign relationships = _parse_relationships(objects)
relationships = _parse_relationships(objects)
# Assign attack_pattern_map = _build_attack_pattern_map(objects)
attack_pattern_map = _build_attack_pattern_map(objects)
# Step 2: Build STIX-ID → actor dict map
stix_to_actor = {a["stix_id"]: a for a in actor_dicts}
# Step 3: Load existing actors and techniques from DB
existing_actors = {
row.mitre_id: row
@@ -460,6 +594,7 @@ def sync(db: Session) -> dict:
if row.mitre_id
}
# Assign technique_by_mitre_id = {
technique_by_mitre_id = {
row.mitre_id: row
for row in db.query(Technique).all()
@@ -467,22 +602,35 @@ def sync(db: Session) -> dict:
# Step 4: Upsert threat actors
actors_created = 0
# Assign actors_skipped = 0
actors_skipped = 0
# Assign stix_to_db_actor = {}
stix_to_db_actor: dict[str, ThreatActor] = {}
# Iterate over actor_dicts
for actor_dict in actor_dicts:
# Assign mitre_id = actor_dict["mitre_id"]
mitre_id = actor_dict["mitre_id"]
# Assign stix_id = actor_dict["stix_id"]
stix_id = actor_dict["stix_id"]
# Check: mitre_id and mitre_id in existing_actors
if mitre_id and mitre_id in existing_actors:
# Update existing actor
db_actor = existing_actors[mitre_id]
# Assign db_actor.name = actor_dict["name"]
db_actor.name = actor_dict["name"]
# Assign db_actor.aliases = actor_dict["aliases"]
db_actor.aliases = actor_dict["aliases"]
# Assign db_actor.description = actor_dict["description"]
db_actor.description = actor_dict["description"]
# Assign db_actor.mitre_url = actor_dict["mitre_url"]
db_actor.mitre_url = actor_dict["mitre_url"]
# Assign db_actor.references = actor_dict["references"]
db_actor.references = actor_dict["references"]
# Assign db_actor.first_seen = actor_dict.get("first_seen")
db_actor.first_seen = actor_dict.get("first_seen")
# Assign db_actor.last_seen = actor_dict.get("last_seen")
db_actor.last_seen = actor_dict.get("last_seen")
# Update enrichment fields if available
if actor_dict.get("motivation"):
@@ -490,101 +638,165 @@ def sync(db: Session) -> dict:
if actor_dict.get("sophistication"):
db_actor.sophistication = actor_dict["sophistication"]
stix_to_db_actor[stix_id] = db_actor
# Assign actors_skipped = 1
actors_skipped += 1
# Fallback: handle remaining cases
else:
# Create new actor
db_actor = ThreatActor(
# Keyword argument: mitre_id
mitre_id=mitre_id,
# Keyword argument: name
name=actor_dict["name"],
# Keyword argument: aliases
aliases=actor_dict["aliases"],
# Keyword argument: description
description=actor_dict["description"],
# Keyword argument: mitre_url
mitre_url=actor_dict["mitre_url"],
# Keyword argument: references
references=actor_dict["references"],
# Keyword argument: first_seen
first_seen=actor_dict.get("first_seen"),
# Keyword argument: last_seen
last_seen=actor_dict.get("last_seen"),
motivation=actor_dict.get("motivation"),
sophistication=actor_dict.get("sophistication"),
is_active=True,
)
# Stage new record(s) for database insertion
db.add(db_actor)
# Flush changes to DB without committing the transaction
db.flush() # get the ID
# Check: mitre_id
if mitre_id:
# Assign existing_actors[mitre_id] = db_actor
existing_actors[mitre_id] = db_actor
# Assign stix_to_db_actor[stix_id] = db_actor
stix_to_db_actor[stix_id] = db_actor
# Assign actors_created = 1
actors_created += 1
# Flush changes to DB without committing the transaction
db.flush()
# Step 5: Upsert actor-technique relationships
# Load existing relationships
existing_rels: set[tuple] = set()
# Iterate over db.query(ThreatActorTechnique).all()
for row in db.query(ThreatActorTechnique).all():
# Call existing_rels.add()
existing_rels.add((str(row.threat_actor_id), str(row.technique_id)))
# Assign rels_created = 0
rels_created = 0
# Assign rels_skipped = 0
rels_skipped = 0
# Iterate over relationships
for rel in relationships:
# Assign source_ref = rel["source_ref"]
source_ref = rel["source_ref"]
# Assign target_ref = rel["target_ref"]
target_ref = rel["target_ref"]
# Resolve actor
db_actor = stix_to_db_actor.get(source_ref)
# Check: not db_actor
if not db_actor:
# Skip to the next loop iteration
continue
# Resolve technique
mitre_technique_id = attack_pattern_map.get(target_ref)
# Check: not mitre_technique_id
if not mitre_technique_id:
# Skip to the next loop iteration
continue
# Assign db_technique = technique_by_mitre_id.get(mitre_technique_id)
db_technique = technique_by_mitre_id.get(mitre_technique_id)
# Check: not db_technique
if not db_technique:
# Skip to the next loop iteration
continue
# Assign rel_key = (str(db_actor.id), str(db_technique.id))
rel_key = (str(db_actor.id), str(db_technique.id))
# Check: rel_key in existing_rels
if rel_key in existing_rels:
# Assign rels_skipped = 1
rels_skipped += 1
# Skip to the next loop iteration
continue
# Assign actor_technique = ThreatActorTechnique(
actor_technique = ThreatActorTechnique(
# Keyword argument: threat_actor_id
threat_actor_id=db_actor.id,
# Keyword argument: technique_id
technique_id=db_technique.id,
# Keyword argument: usage_description
usage_description=rel["description"][:5000] if rel["description"] else None,
)
# Stage new record(s) for database insertion
db.add(actor_technique)
# Call existing_rels.add()
existing_rels.add(rel_key)
# Assign rels_created = 1
rels_created += 1
# Commit all pending changes to the database
db.commit()
# Assign summary = {
summary = {
# Literal argument value
"actors_created": actors_created,
# Literal argument value
"actors_updated": actors_skipped,
# Literal argument value
"relationships_created": rels_created,
# Literal argument value
"relationships_skipped": rels_skipped,
# Literal argument value
"total_actors_parsed": len(actor_dicts),
# Literal argument value
"total_relationships_parsed": len(relationships),
}
# Update DataSource record
ds = db.query(DataSource).filter(DataSource.name == "mitre_cti").first()
# Check: ds
if ds:
# Assign ds.last_sync_at = datetime.utcnow()
ds.last_sync_at = datetime.utcnow()
# Assign ds.last_sync_status = "success"
ds.last_sync_status = "success"
# Assign ds.last_sync_stats = summary
ds.last_sync_stats = summary
# Commit all pending changes to the database
db.commit()
# Log info: "MITRE CTI threat actor import complete — %s", sum
logger.info("MITRE CTI threat actor import complete — %s", summary)
# Call log_action()
log_action(
db,
# Keyword argument: user_id
user_id=None,
# Keyword argument: action
action="import_threat_actors",
# Keyword argument: entity_type
entity_type="threat_actor",
# Keyword argument: entity_id
entity_id=None,
# Keyword argument: details
details=summary,
)
# Commit all pending changes to the database
db.commit()
# Return summary
return summary
+189 -6
View File
@@ -6,34 +6,56 @@ that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import Any from typing
from typing import Any
from sqlalchemy import case, cast, func, or_, Text
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
from app.models.enums import TechniqueStatus
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.technique import Technique
from app.utils import escape_like
# Import TechniqueStatus from app.models.enums
from app.models.enums import TechniqueStatus
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
# Import escape_like from app.utils
from app.utils import escape_like
# ── Public service functions ──────────────────────────────────────────
def list_actors(
# Entry: db
db: Session,
*,
# Entry: search
search: str | None = None,
# Entry: country
country: str | None = None,
# Entry: motivation
motivation: str | None = None,
# Entry: sophistication
sophistication: str | None = None,
# Entry: target_sectors
target_sectors: str | None = None,
# Entry: offset
offset: int = 0,
# Entry: limit
limit: int = 50,
) -> dict[str, Any]:
"""List threat actors with optional filters, pagination, and coverage stats.
@@ -41,10 +63,14 @@ def list_actors(
Uses grouped subqueries to avoid N+1: technique counts and coverage
counts are fetched in one query per page.
"""
# Assign query = db.query(ThreatActor)
query = db.query(ThreatActor)
# Check: search
if search:
# Assign pattern = f"%{escape_like(search)}%"
pattern = f"%{escape_like(search)}%"
# Assign query = query.filter(
query = query.filter(
or_(
ThreatActor.name.ilike(pattern),
@@ -53,35 +79,52 @@ def list_actors(
)
)
# Check: country
if country:
# Assign query = query.filter(ThreatActor.country == country)
query = query.filter(ThreatActor.country == country)
# Check: motivation
if motivation:
# Assign query = query.filter(ThreatActor.motivation == motivation)
query = query.filter(ThreatActor.motivation == motivation)
# Check: sophistication
if sophistication:
# Assign query = query.filter(ThreatActor.sophistication == sophistication)
query = query.filter(ThreatActor.sophistication == sophistication)
# Check: target_sectors
if target_sectors:
# Assign query = query.filter(
query = query.filter(
cast(ThreatActor.target_sectors, Text).ilike(
f"%{escape_like(target_sectors)}%"
)
)
# Assign total = query.count()
total = query.count()
# Assign actors = (
actors = (
query.order_by(ThreatActor.name).offset(offset).limit(limit).all()
)
# Check: not actors
if not actors:
# Return {
return {
# Literal argument value
"total": total,
# Literal argument value
"offset": offset,
# Literal argument value
"limit": limit,
# Literal argument value
"items": [],
}
# Assign actor_ids = [a.id for a in actors]
actor_ids = [a.id for a in actors]
# Single grouped query: tech_count and covered_count per actor
@@ -96,215 +139,355 @@ def list_actors(
TechniqueStatus.validated,
TechniqueStatus.partial,
]),
# Literal argument value
1,
),
# Keyword argument: else_
else_=0,
)
).label("covered_count"),
)
# Chain .join() call
.join(Technique, ThreatActorTechnique.technique_id == Technique.id)
# Chain .filter() call
.filter(ThreatActorTechnique.threat_actor_id.in_(actor_ids))
# Chain .group_by() call
.group_by(ThreatActorTechnique.threat_actor_id)
).all()
# Assign counts_map = {
counts_map = {
str(row.threat_actor_id): {
# Literal argument value
"tech_count": row.tech_count,
# Literal argument value
"covered_count": row.covered_count or 0,
}
for row in counts_rows
}
# Assign results = []
results = []
# Iterate over actors
for actor in actors:
# Assign cnt = counts_map.get(str(actor.id), {"tech_count": 0, "covered_count": 0})
cnt = counts_map.get(str(actor.id), {"tech_count": 0, "covered_count": 0})
# Assign tech_count = cnt["tech_count"]
tech_count = cnt["tech_count"]
# Assign covered = cnt["covered_count"]
covered = cnt["covered_count"]
# Assign coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0
coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0
# Call results.append()
results.append({
# Literal argument value
"id": str(actor.id),
# Literal argument value
"mitre_id": actor.mitre_id,
# Literal argument value
"name": actor.name,
# Literal argument value
"aliases": actor.aliases or [],
# Literal argument value
"country": actor.country,
# Literal argument value
"target_sectors": actor.target_sectors or [],
# Literal argument value
"target_regions": actor.target_regions or [],
# Literal argument value
"motivation": actor.motivation,
# Literal argument value
"sophistication": actor.sophistication,
# Literal argument value
"mitre_url": actor.mitre_url,
# Literal argument value
"technique_count": tech_count,
# Literal argument value
"coverage_pct": coverage_pct,
# Literal argument value
"is_active": actor.is_active,
})
# Return {
return {
# Literal argument value
"total": total,
# Literal argument value
"offset": offset,
# Literal argument value
"limit": limit,
# Literal argument value
"items": results,
}
def get_actor_detail(db: Session, actor_id: str) -> dict[str, Any]:
    """Return one threat actor's full record together with its techniques.

    The ``techniques`` list is ordered by technique MITRE ID; each entry
    combines fields of the technique with the per-actor link fields
    ``usage_description`` and ``first_seen_using``.

    Raises:
        EntityNotFoundError: If no actor with ``actor_id`` exists.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("Threat actor", actor_id)

    linked_rows = (
        db.query(ThreatActorTechnique, Technique)
        .join(Technique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .order_by(Technique.mitre_id)
        .all()
    )

    techniques_list = []
    for link, technique in linked_rows:
        # A technique without a global status is reported as None.
        status_value = (
            technique.status_global.value if technique.status_global else None
        )
        techniques_list.append({
            "technique_id": str(technique.id),
            "mitre_id": technique.mitre_id,
            "name": technique.name,
            "tactic": technique.tactic,
            "status_global": status_value,
            "usage_description": link.usage_description,
            "first_seen_using": link.first_seen_using,
        })

    detail: dict[str, Any] = {
        "id": str(actor.id),
        "mitre_id": actor.mitre_id,
        "name": actor.name,
        "aliases": actor.aliases or [],
        "description": actor.description,
        "country": actor.country,
        "target_sectors": actor.target_sectors or [],
        "target_regions": actor.target_regions or [],
        "motivation": actor.motivation,
        "sophistication": actor.sophistication,
        "first_seen": actor.first_seen,
        "last_seen": actor.last_seen,
        "references": actor.references or [],
        "mitre_url": actor.mitre_url,
        "is_active": actor.is_active,
    }
    detail["techniques"] = techniques_list
    return detail
def get_actor_coverage(db: Session, actor_id: str) -> dict[str, Any]:
    """Calculate coverage percentage against a specific threat actor.

    A technique counts as covered when its global status value is
    ``"validated"`` or ``"partial"``. Techniques without a status are tallied
    under ``"not_evaluated"`` in the breakdown.

    Args:
        db: Database session used for all queries.
        actor_id: Primary key of the threat actor.

    Returns:
        Dict with ``actor_id``, ``actor_name``, ``total_techniques``,
        ``covered``, ``coverage_pct`` (one decimal place) and ``breakdown``
        (status value -> technique count). The same keys are present when
        the actor has no techniques; the counts are then zero and the
        breakdown is empty.

    Raises:
        EntityNotFoundError: If the actor does not exist.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("Threat actor", actor_id)

    actor_techniques = (
        db.query(Technique)
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .all()
    )
    total = len(actor_techniques)

    breakdown: dict[str, int] = {}
    for tech in actor_techniques:
        status = tech.status_global.value if tech.status_global else "not_evaluated"
        breakdown[status] = breakdown.get(status, 0) + 1

    covered = breakdown.get("validated", 0) + breakdown.get("partial", 0)
    # Guard the division so an actor without techniques reports 0.0.
    coverage_pct = round((covered / total * 100), 1) if total > 0 else 0.0

    # Single return path: the empty case previously omitted "covered", giving
    # callers two different response shapes.
    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_techniques": total,
        "covered": covered,
        "coverage_pct": coverage_pct,
        "breakdown": breakdown,
    }
def get_actor_gaps(db: Session, actor_id: str) -> dict[str, Any]:
    """Identify techniques of this actor that are not fully validated.

    A technique counts as a gap when its ``status_global`` is anything other
    than ``TechniqueStatus.validated``, including a missing (NULL) status.
    Each gap is annotated with the number of active test templates for its
    MITRE ID and the number of tests attached to the technique; each of the
    two counts is fetched with a single grouped query.

    Args:
        db: Database session used for all queries.
        actor_id: Primary key of the threat actor.

    Returns:
        Dict with ``actor_id``, ``actor_name``, ``total_gaps`` and ``gaps``
        (ordered by technique MITRE ID).

    Raises:
        EntityNotFoundError: If the actor does not exist.
    """
    actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
    if not actor:
        raise EntityNotFoundError("Threat actor", actor_id)

    # In SQL, ``NULL != 'validated'`` evaluates to NULL, so a bare ``!=``
    # filter silently drops techniques that have no status at all. Match
    # those explicitly: an unevaluated technique is a gap too.
    not_validated = or_(
        Technique.status_global.is_(None),
        Technique.status_global != TechniqueStatus.validated,
    )

    gap_techniques = (
        db.query(Technique, ThreatActorTechnique)
        .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
        .filter(ThreatActorTechnique.threat_actor_id == actor.id)
        .filter(not_validated)
        .order_by(Technique.mitre_id)
        .all()
    )

    if not gap_techniques:
        return {
            "actor_id": str(actor.id),
            "actor_name": actor.name,
            "total_gaps": 0,
            "gaps": [],
        }

    technique_ids = [tech.id for tech, _ in gap_techniques]
    mitre_ids = [tech.mitre_id for tech, _ in gap_techniques]

    # Batch template counts by mitre_technique_id
    template_counts = (
        db.query(TestTemplate.mitre_technique_id, func.count(TestTemplate.id).label("cnt"))
        .filter(TestTemplate.mitre_technique_id.in_(mitre_ids))
        .filter(TestTemplate.is_active.is_(True))
        .group_by(TestTemplate.mitre_technique_id)
    ).all()
    template_map = {row.mitre_technique_id: row.cnt for row in template_counts}

    # Batch test counts by technique_id
    test_counts = (
        db.query(Test.technique_id, func.count(Test.id).label("cnt"))
        .filter(Test.technique_id.in_(technique_ids))
        .group_by(Test.technique_id)
    ).all()
    test_map = {str(row.technique_id): row.cnt for row in test_counts}

    gaps = []
    for tech, at in gap_techniques:
        template_count = template_map.get(tech.mitre_id, 0)
        test_count = test_map.get(str(tech.id), 0)
        gaps.append({
            "technique_id": str(tech.id),
            "mitre_id": tech.mitre_id,
            "name": tech.name,
            "tactic": tech.tactic,
            "status_global": tech.status_global.value if tech.status_global else None,
            "usage_description": at.usage_description,
            "available_templates": template_count,
            "existing_tests": test_count,
            "has_templates": template_count > 0,
        })

    return {
        "actor_id": str(actor.id),
        "actor_name": actor.name,
        "total_gaps": len(gaps),
        "gaps": gaps,
    }
+49 -1
View File
@@ -4,30 +4,51 @@ Uses domain exceptions from app.domain.errors. The router handles
HTTP concerns, auth, audit logging, and commit.
"""
# Enable future language features for compatibility
from __future__ import annotations
# Import uuid
import uuid
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import hash_password from app.auth
from app.auth import hash_password
from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError
# Import from app.domain.errors
from app.domain.errors import (
BusinessRuleViolation,
DuplicateEntityError,
EntityNotFoundError,
)
# Import User from app.models.user
from app.models.user import User
# Role names accepted by create_user() and update_user(); any other value is
# rejected there with BusinessRuleViolation.
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
def list_users(db: Session) -> list[User]:
    """Fetch every user, sorted by username."""
    users_by_name = db.query(User).order_by(User.username)
    return users_by_name.all()
def create_user(
    db: Session,
    *,
    username: str,
    email: str | None,
    password: str,
    role: str,
) -> User:
    """Build a new ``User`` and stage it on the session.

    The username uniqueness check runs before the role check. Only the
    result of ``hash_password(password)`` is stored on the user.

    Raises:
        DuplicateEntityError: A user with this username already exists.
        BusinessRuleViolation: ``role`` is not one of ``VALID_ROLES``.

    Does not commit; the router handles that.
    """
    username_taken = db.query(User).filter(User.username == username).first()
    if username_taken:
        raise DuplicateEntityError("User", "username", username)

    if role not in VALID_ROLES:
        allowed = ", ".join(sorted(VALID_ROLES))
        raise BusinessRuleViolation(
            f"Invalid role '{role}'. Must be one of: {allowed}"
        )

    new_user = User(
        username=username,
        email=email,
        hashed_password=hash_password(password),
        role=role,
    )
    db.add(new_user)
    return new_user
def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User:
    """Look up a user by primary key.

    Raises:
        EntityNotFoundError: If no user has the given ID.
    """
    match = db.query(User).filter(User.id == user_id).first()
    if match is not None:
        return match
    raise EntityNotFoundError("User", str(user_id))
def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User:
    """Update one or more fields of an existing user.

    Every keyword in ``fields`` is assigned onto the user with ``setattr``.
    Two keys get special treatment:

    * ``role`` must be one of ``VALID_ROLES``.
    * ``password`` is never stored as-is: it is hashed and stored as
      ``hashed_password``. A ``password`` of ``None`` is treated as
      "not provided" and leaves the stored hash untouched.

    Raises:
        EntityNotFoundError: If no user has the given ID.
        BusinessRuleViolation: If ``role`` is supplied and is not valid.

    Does not commit; the router handles that.
    """
    user = get_user_or_raise(db, user_id)
    update_data = dict(fields)

    if "role" in update_data and update_data["role"] not in VALID_ROLES:
        raise BusinessRuleViolation(
            f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
        )

    if "password" in update_data:
        raw_password = update_data.pop("password")
        # ``str(None)`` is the text "None"; hashing it would silently set the
        # account password to that literal. Skip the update instead.
        if raw_password is not None:
            update_data["hashed_password"] = hash_password(str(raw_password))

    for field, value in update_data.items():
        setattr(user, field, value)
    return user
+58
View File
@@ -1,83 +1,141 @@
"""Internal worklog service — CRUD with integrity hashing."""
# Import hashlib
import hashlib
# Import logging
import logging
# Import datetime from datetime
from datetime import datetime
# Import Optional from typing
from typing import Optional
# Import UUID from uuid
from uuid import UUID
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import Worklog from app.models.worklog
from app.models.worklog import Worklog
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def create_worklog(
    db: Session,
    *,
    entity_type: str,
    entity_id: UUID,
    user_id: UUID,
    activity_type: str,
    started_at: datetime,
    duration_seconds: int,
    ended_at: Optional[datetime] = None,
    description: Optional[str] = None,
) -> Worklog:
    """Build a ``Worklog``, stamp its integrity hash, and stage it on the session.

    The hash is computed by ``_compute_hash`` from the freshly built record
    before it is added to ``db``. The session is not committed here; the
    caller (router) does that through its UnitOfWork.
    """
    record_fields = {
        "entity_type": entity_type,
        "entity_id": entity_id,
        "user_id": user_id,
        "activity_type": activity_type,
        "started_at": started_at,
        "ended_at": ended_at,
        "duration_seconds": duration_seconds,
        "description": description,
    }
    entry = Worklog(**record_fields)
    entry.integrity_hash = _compute_hash(entry)
    db.add(entry)
    return entry
def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog:
    """Look up a worklog by primary key.

    Raises:
        EntityNotFoundError: If no worklog has the given ID.
    """
    found = db.query(Worklog).filter(Worklog.id == worklog_id).first()
    if found:
        return found
    raise EntityNotFoundError("Worklog", str(worklog_id))
def list_worklogs(
    db: Session,
    *,
    entity_type: Optional[str] = None,
    entity_id: Optional[UUID] = None,
    user_id: Optional[UUID] = None,
) -> list[Worklog]:
    """Return worklogs, newest ``started_at`` first, narrowed by any given filter.

    A filter argument that is falsy (``None``, empty string) is ignored; the
    remaining ones are combined with AND.
    """
    optional_filters = (
        (Worklog.entity_type, entity_type),
        (Worklog.entity_id, entity_id),
        (Worklog.user_id, user_id),
    )

    query = db.query(Worklog)
    for column, wanted in optional_filters:
        if wanted:
            query = query.filter(column == wanted)

    newest_first = query.order_by(Worklog.started_at.desc())
    return newest_first.all()
def verify_worklog_integrity(wl: Worklog) -> bool:
    """Tell whether the stored integrity hash still matches the worklog.

    Returns True when ``wl.integrity_hash`` equals the hash recomputed by
    ``_compute_hash`` from the worklog's current field values.
    """
    stored = wl.integrity_hash
    recomputed = _compute_hash(wl)
    return stored == recomputed
# Define function _compute_hash
def _compute_hash(wl: Worklog) -> str:
"""SHA-256 of the immutable fields for audit integrity."""
# Assign data = (
data = (
f"{wl.entity_type}:{wl.entity_id}:{wl.user_id}:"
f"{wl.activity_type}:{wl.started_at}:{wl.duration_seconds}"
)
# Return hashlib.sha256(data.encode()).hexdigest()
return hashlib.sha256(data.encode()).hexdigest()