feat(refactor): PEP8, type annotations, docstrings and PyJWT security fix
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Service layer — business logic orchestrating domain entities and persistence."""
|
||||
|
||||
@@ -1,19 +1,31 @@
|
||||
"""Advanced metrics service — coverage by tactic, never-tested, avg validation time, detection trend."""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import datetime, timedelta from datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import case, func from sqlalchemy
|
||||
from sqlalchemy import case, func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
# Import TestResult from app.models.enums
|
||||
from app.models.enums import TestResult
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
|
||||
# Define function get_coverage_by_tactic
|
||||
def get_coverage_by_tactic(db: Session) -> list[dict]:
|
||||
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
||||
# Assign results = (
|
||||
results = (
|
||||
db.query(
|
||||
Technique.tactic,
|
||||
@@ -31,134 +43,211 @@ def get_coverage_by_tactic(db: Session) -> list[dict]:
|
||||
case((Technique.status_global == "in_progress", 1), else_=0)
|
||||
).label("in_progress"),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(Technique.tactic)
|
||||
# Chain .order_by() call
|
||||
.order_by(Technique.tactic)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"tactic": r[0] or "Unknown",
|
||||
# Literal argument value
|
||||
"total": r[1],
|
||||
# Literal argument value
|
||||
"validated": int(r[2]),
|
||||
# Literal argument value
|
||||
"partial": int(r[3]),
|
||||
# Literal argument value
|
||||
"not_covered": int(r[4]),
|
||||
# Literal argument value
|
||||
"in_progress": int(r[5]),
|
||||
# Literal argument value
|
||||
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
# Define function get_never_tested_techniques
|
||||
def get_never_tested_techniques(db: Session) -> list[dict]:
|
||||
"""Techniques that have never had a test created."""
|
||||
# Assign tested_ids = [
|
||||
tested_ids = [
|
||||
row[0]
|
||||
for row in db.query(Test.technique_id)
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id.isnot(None))
|
||||
# Chain .distinct() call
|
||||
.distinct()
|
||||
# Chain .all() call
|
||||
.all()
|
||||
]
|
||||
# Assign query = db.query(Technique)
|
||||
query = db.query(Technique)
|
||||
# Check: tested_ids
|
||||
if tested_ids:
|
||||
# Assign query = query.filter(~Technique.id.in_(tested_ids))
|
||||
query = query.filter(~Technique.id.in_(tested_ids))
|
||||
# Assign techniques = query.order_by(Technique.mitre_id).all()
|
||||
techniques = query.order_by(Technique.mitre_id).all()
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"mitre_id": t.mitre_id,
|
||||
# Literal argument value
|
||||
"name": t.name,
|
||||
# Literal argument value
|
||||
"tactic": t.tactic,
|
||||
# Literal argument value
|
||||
"is_subtechnique": t.is_subtechnique,
|
||||
}
|
||||
for t in techniques
|
||||
]
|
||||
|
||||
|
||||
# Define function get_avg_validation_time
|
||||
def get_avg_validation_time(db: Session) -> dict:
|
||||
"""Average time from test creation to validation, computed from validated tests.
|
||||
|
||||
Returns overall average and per-phase averages where data is available.
|
||||
"""
|
||||
# Assign validated_tests = (
|
||||
validated_tests = (
|
||||
db.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(Test.state == "validated")
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Check: not validated_tests
|
||||
if not validated_tests:
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total_validated": 0,
|
||||
# Literal argument value
|
||||
"avg_total_hours": 0,
|
||||
# Literal argument value
|
||||
"avg_red_phase_hours": 0,
|
||||
# Literal argument value
|
||||
"avg_blue_phase_hours": 0,
|
||||
}
|
||||
|
||||
# Assign total_durations = []
|
||||
total_durations = []
|
||||
# Assign red_durations = []
|
||||
red_durations = []
|
||||
# Assign blue_durations = []
|
||||
blue_durations = []
|
||||
|
||||
# Iterate over validated_tests
|
||||
for test in validated_tests:
|
||||
# Check: test.created_at and test.red_validated_at
|
||||
if test.created_at and test.red_validated_at:
|
||||
# Assign total_seconds = (test.red_validated_at - test.created_at).total_seconds()
|
||||
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
|
||||
# Call total_durations.append()
|
||||
total_durations.append(total_seconds)
|
||||
|
||||
# Check: test.red_started_at and test.blue_started_at
|
||||
if test.red_started_at and test.blue_started_at:
|
||||
# Assign red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
|
||||
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
|
||||
# Assign red_paused = test.red_paused_seconds or 0
|
||||
red_paused = test.red_paused_seconds or 0
|
||||
# Call red_durations.append()
|
||||
red_durations.append(max(red_sec - red_paused, 0))
|
||||
|
||||
# Check: test.blue_started_at and test.blue_validated_at
|
||||
if test.blue_started_at and test.blue_validated_at:
|
||||
# Assign blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
|
||||
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
|
||||
# Assign blue_paused = test.blue_paused_seconds or 0
|
||||
blue_paused = test.blue_paused_seconds or 0
|
||||
# Call blue_durations.append()
|
||||
blue_durations.append(max(blue_sec - blue_paused, 0))
|
||||
|
||||
# Define function avg_hours
|
||||
def avg_hours(durations: list[float]) -> float:
|
||||
# Check: not durations
|
||||
if not durations:
|
||||
# Return 0
|
||||
return 0
|
||||
# Return round(sum(durations) / len(durations) / 3600, 2)
|
||||
return round(sum(durations) / len(durations) / 3600, 2)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total_validated": len(validated_tests),
|
||||
# Literal argument value
|
||||
"avg_total_hours": avg_hours(total_durations),
|
||||
# Literal argument value
|
||||
"avg_red_phase_hours": avg_hours(red_durations),
|
||||
# Literal argument value
|
||||
"avg_blue_phase_hours": avg_hours(blue_durations),
|
||||
}
|
||||
|
||||
|
||||
# Define function get_detection_rate_trend
|
||||
def get_detection_rate_trend(db: Session) -> list[dict]:
|
||||
"""Monthly detection rate trend for the last 12 months."""
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign months = []
|
||||
months = []
|
||||
|
||||
# Iterate over range(11, -1, -1)
|
||||
for i in range(11, -1, -1):
|
||||
# Assign month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
|
||||
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
|
||||
# Assign month_end = month_start + timedelta(days=30)
|
||||
month_end = month_start + timedelta(days=30)
|
||||
|
||||
# Assign validated = (
|
||||
validated = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
Test.state == "validated",
|
||||
Test.created_at >= month_start,
|
||||
Test.created_at < month_end,
|
||||
)
|
||||
# Chain .scalar() call
|
||||
.scalar() or 0
|
||||
)
|
||||
|
||||
# Assign detected = (
|
||||
detected = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
Test.state == "validated",
|
||||
Test.detection_result == TestResult.detected,
|
||||
Test.created_at >= month_start,
|
||||
Test.created_at < month_end,
|
||||
)
|
||||
# Chain .scalar() call
|
||||
.scalar() or 0
|
||||
)
|
||||
|
||||
# Call months.append()
|
||||
months.append({
|
||||
# Literal argument value
|
||||
"month": month_start.strftime("%Y-%m"),
|
||||
# Literal argument value
|
||||
"validated": validated,
|
||||
# Literal argument value
|
||||
"detected": detected,
|
||||
# Literal argument value
|
||||
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
|
||||
})
|
||||
|
||||
# Return months
|
||||
return months
|
||||
|
||||
@@ -1,28 +1,50 @@
|
||||
"""Analytics service — flat JSON optimized for PowerBI / BI tools."""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import CoverageSnapshot from app.models.coverage_snapshot
|
||||
from app.models.coverage_snapshot import CoverageSnapshot
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
# Define function get_coverage_analytics
|
||||
def get_coverage_analytics(db: Session) -> list[dict]:
|
||||
"""Coverage per technique — flat format for BI dashboards."""
|
||||
# Assign techniques = db.query(Technique).all()
|
||||
techniques = db.query(Technique).all()
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"mitre_id": t.mitre_id,
|
||||
# Literal argument value
|
||||
"name": t.name,
|
||||
# Literal argument value
|
||||
"tactic": t.tactic,
|
||||
# Literal argument value
|
||||
"status": t.status_global.value if t.status_global else "not_evaluated",
|
||||
# Literal argument value
|
||||
"is_subtechnique": t.is_subtechnique,
|
||||
# Literal argument value
|
||||
"test_count": len(t.tests) if t.tests else 0,
|
||||
# Literal argument value
|
||||
"review_required": t.review_required,
|
||||
# Literal argument value
|
||||
"last_review_date": (
|
||||
t.last_review_date.isoformat() if t.last_review_date else None
|
||||
),
|
||||
@@ -31,76 +53,117 @@ def get_coverage_analytics(db: Session) -> list[dict]:
|
||||
]
|
||||
|
||||
|
||||
# Define function get_tests_analytics
|
||||
def get_tests_analytics(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: date_from
|
||||
date_from: str | None = None,
|
||||
# Entry: date_to
|
||||
date_to: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""All tests with timestamps — flat format for BI dashboards."""
|
||||
# Assign query = db.query(Test)
|
||||
query = db.query(Test)
|
||||
# Check: date_from
|
||||
if date_from:
|
||||
# Assign query = query.filter(Test.created_at >= date_from)
|
||||
query = query.filter(Test.created_at >= date_from)
|
||||
# Check: date_to
|
||||
if date_to:
|
||||
# Assign query = query.filter(Test.created_at <= date_to)
|
||||
query = query.filter(Test.created_at <= date_to)
|
||||
# Assign tests = query.all()
|
||||
tests = query.all()
|
||||
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(t.id),
|
||||
# Literal argument value
|
||||
"technique_id": str(t.technique_id),
|
||||
# Literal argument value
|
||||
"name": t.name,
|
||||
# Literal argument value
|
||||
"state": t.state.value if t.state else None,
|
||||
# Literal argument value
|
||||
"result": t.result.value if t.result else None,
|
||||
# Literal argument value
|
||||
"detection_result": (
|
||||
t.detection_result.value if t.detection_result else None
|
||||
),
|
||||
# Literal argument value
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
# Literal argument value
|
||||
"execution_date": (
|
||||
t.execution_date.isoformat() if t.execution_date else None
|
||||
),
|
||||
# Literal argument value
|
||||
"platform": t.platform,
|
||||
# Literal argument value
|
||||
"tool_used": t.tool_used,
|
||||
# Literal argument value
|
||||
"attack_success": t.attack_success,
|
||||
# Literal argument value
|
||||
"remediation_status": t.remediation_status,
|
||||
}
|
||||
for t in tests
|
||||
]
|
||||
|
||||
|
||||
# Define function get_trends_analytics
|
||||
def get_trends_analytics(db: Session) -> list[dict]:
|
||||
"""Historical coverage snapshots for trend visualization."""
|
||||
# Assign snapshots = (
|
||||
snapshots = (
|
||||
db.query(CoverageSnapshot)
|
||||
# Chain .order_by() call
|
||||
.order_by(CoverageSnapshot.created_at)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"date": s.created_at.isoformat() if s.created_at else None,
|
||||
# Literal argument value
|
||||
"name": s.name,
|
||||
# Literal argument value
|
||||
"total_techniques": s.total_techniques,
|
||||
# Literal argument value
|
||||
"validated_count": s.validated_count,
|
||||
# Literal argument value
|
||||
"partial_count": s.partial_count,
|
||||
# Literal argument value
|
||||
"not_covered_count": s.not_covered_count,
|
||||
# Literal argument value
|
||||
"organization_score": s.organization_score,
|
||||
}
|
||||
for s in snapshots
|
||||
]
|
||||
|
||||
|
||||
# Define function get_operators_analytics
|
||||
def get_operators_analytics(db: Session) -> list[dict]:
|
||||
"""Per-operator metrics — for workload management dashboards."""
|
||||
# Assign results = (
|
||||
results = (
|
||||
db.query(
|
||||
User.username,
|
||||
User.role,
|
||||
func.count(Test.id).label("test_count"),
|
||||
)
|
||||
# Chain .outerjoin() call
|
||||
.outerjoin(Test, Test.created_by == User.id)
|
||||
# Chain .group_by() call
|
||||
.group_by(User.id, User.username, User.role)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [
|
||||
return [
|
||||
{"username": r[0], "role": r[1], "test_count": r[2]}
|
||||
for r in results
|
||||
|
||||
@@ -22,22 +22,39 @@ Running the import twice does **not** create duplicates. Existing
|
||||
templates are identified by their ``atomic_test_id`` and simply skipped.
|
||||
"""
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Import shutil
|
||||
import shutil
|
||||
|
||||
# Import tempfile
|
||||
import tempfile
|
||||
|
||||
# Import zipfile
|
||||
import zipfile
|
||||
|
||||
# Import Path from pathlib
|
||||
from pathlib import Path
|
||||
|
||||
# Import requests
|
||||
import requests as _requests
|
||||
|
||||
# Import yaml
|
||||
import yaml
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -45,7 +62,9 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ATOMIC_RT_ZIP_URL = (
|
||||
# Literal argument value
|
||||
"https://github.com/redcanaryco/atomic-red-team"
|
||||
# Literal argument value
|
||||
"/archive/refs/heads/master.zip"
|
||||
)
|
||||
|
||||
@@ -55,6 +74,11 @@ _DOWNLOAD_TIMEOUT = 300
|
||||
# Top-level directory name inside the ZIP
|
||||
_ZIP_ROOT_PREFIX = "atomic-red-team-master"
|
||||
|
||||
# Safety limits for ZIP extraction — prevent zip-bomb DoS
|
||||
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB
|
||||
# Assign _MAX_ENTRIES = 50_000
|
||||
_MAX_ENTRIES = 50_000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
@@ -63,14 +87,21 @@ _ZIP_ROOT_PREFIX = "atomic-red-team-master"
|
||||
|
||||
def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes:
|
||||
"""Download the Atomic Red Team ZIP and return its raw bytes."""
|
||||
# Log info: "Downloading Atomic Red Team ZIP from %s …", url
|
||||
logger.info("Downloading Atomic Red Team ZIP from %s …", url)
|
||||
# Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Assign content = resp.content
|
||||
content = resp.content
|
||||
# Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024
|
||||
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
|
||||
# Return content
|
||||
return content
|
||||
|
||||
|
||||
# Define function _safe_extract_zip
|
||||
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
"""Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.
|
||||
|
||||
@@ -78,51 +109,66 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
directory (path traversal / Zip Slip) or if the archive exceeds the
|
||||
safety limits.
|
||||
"""
|
||||
# Maximum uncompressed size: 500 MB — prevents zip-bomb DoS
|
||||
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024
|
||||
# Maximum number of entries
|
||||
_MAX_ENTRIES = 50_000
|
||||
|
||||
# Assign dest_path = Path(dest).resolve()
|
||||
dest_path = Path(dest).resolve()
|
||||
|
||||
# Open context manager
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
# Assign entries = zf.infolist()
|
||||
entries = zf.infolist()
|
||||
|
||||
# Check: len(entries) > _MAX_ENTRIES
|
||||
if len(entries) > _MAX_ENTRIES:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"ZIP archive contains {len(entries)} entries "
|
||||
f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
|
||||
)
|
||||
|
||||
# Assign total_size = sum(info.file_size for info in entries)
|
||||
total_size = sum(info.file_size for info in entries)
|
||||
# Check: total_size > _MAX_UNCOMPRESSED_SIZE
|
||||
if total_size > _MAX_UNCOMPRESSED_SIZE:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
|
||||
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
|
||||
)
|
||||
|
||||
# Iterate over entries
|
||||
for member in entries:
|
||||
# Assign target = (dest_path / member.filename).resolve()
|
||||
target = (dest_path / member.filename).resolve()
|
||||
# Check: not target.is_relative_to(dest_path)
|
||||
if not target.is_relative_to(dest_path):
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"Zip Slip detected — member '{member.filename}' "
|
||||
f"resolves outside target directory"
|
||||
)
|
||||
|
||||
# Call zf.extractall()
|
||||
zf.extractall(dest)
|
||||
|
||||
|
||||
# Define function _extract_zip
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return the path to the atomics/ dir."""
|
||||
# Call _safe_extract_zip()
|
||||
_safe_extract_zip(zip_bytes, dest)
|
||||
# Assign atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
|
||||
atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
|
||||
# Check: not atomics_dir.is_dir()
|
||||
if not atomics_dir.is_dir():
|
||||
# Raise FileNotFoundError
|
||||
raise FileNotFoundError(
|
||||
f"Expected atomics directory not found at {atomics_dir}"
|
||||
)
|
||||
# Return atomics_dir
|
||||
return atomics_dir
|
||||
|
||||
|
||||
# Define function _parse_yaml_files
|
||||
def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
|
||||
"""Walk the atomics directory and parse all technique YAML files.
|
||||
|
||||
@@ -132,51 +178,84 @@ def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
|
||||
technique_id, index, name, description, platforms,
|
||||
executor_type, command, source_url
|
||||
"""
|
||||
# Assign results = []
|
||||
results: list[dict] = []
|
||||
# Assign yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
|
||||
yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
|
||||
# Log info: "Found %d YAML files to parse", len(yaml_files
|
||||
logger.info("Found %d YAML files to parse", len(yaml_files))
|
||||
|
||||
# Iterate over yaml_files
|
||||
for yaml_path in yaml_files:
|
||||
# Assign technique_id = yaml_path.stem # e.g. "T1059.001"
|
||||
technique_id = yaml_path.stem # e.g. "T1059.001"
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Open context manager
|
||||
with open(yaml_path, "r", encoding="utf-8") as fh:
|
||||
# Assign data = yaml.safe_load(fh)
|
||||
data = yaml.safe_load(fh)
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log warning: "Failed to parse %s: %s", yaml_path, exc
|
||||
logger.warning("Failed to parse %s: %s", yaml_path, exc)
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Check: not data or "atomic_tests" not in data
|
||||
if not data or "atomic_tests" not in data:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Iterate over enumerate(data["atomic_tests"])
|
||||
for idx, test in enumerate(data["atomic_tests"]):
|
||||
# Assign name = test.get("name", "").strip()
|
||||
name = test.get("name", "").strip()
|
||||
# Assign description = test.get("description", "").strip()
|
||||
description = test.get("description", "").strip()
|
||||
# Assign platforms = test.get("supported_platforms", [])
|
||||
platforms = test.get("supported_platforms", [])
|
||||
# Assign executor = test.get("executor", {})
|
||||
executor = test.get("executor", {})
|
||||
# Assign executor_type = executor.get("name", "") if isinstance(executor, dict) else ""
|
||||
executor_type = executor.get("name", "") if isinstance(executor, dict) else ""
|
||||
# Assign command = executor.get("command", "") if isinstance(executor, dict) else ""
|
||||
command = executor.get("command", "") if isinstance(executor, dict) else ""
|
||||
|
||||
# Build an atomic_test_id in the format "T1059.001-0"
|
||||
atomic_test_id = f"{technique_id}-{idx}"
|
||||
|
||||
# Assign source_url = (
|
||||
source_url = (
|
||||
f"https://github.com/redcanaryco/atomic-red-team/blob/master"
|
||||
f"/atomics/{technique_id}/{technique_id}.yaml"
|
||||
)
|
||||
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"technique_id": technique_id,
|
||||
# Literal argument value
|
||||
"index": idx,
|
||||
# Literal argument value
|
||||
"atomic_test_id": atomic_test_id,
|
||||
# Literal argument value
|
||||
"name": name,
|
||||
# Literal argument value
|
||||
"description": description,
|
||||
# Literal argument value
|
||||
"platforms": ", ".join(platforms) if isinstance(platforms, list) else str(platforms),
|
||||
# Literal argument value
|
||||
"executor_type": executor_type,
|
||||
# Literal argument value
|
||||
"command": command[:4000] if command else None, # cap at 4k chars
|
||||
# Literal argument value
|
||||
"source_url": source_url,
|
||||
})
|
||||
|
||||
# Log info: "Parsed %d atomic tests total", len(results
|
||||
logger.info("Parsed %d atomic tests total", len(results))
|
||||
# Return results
|
||||
return results
|
||||
|
||||
|
||||
@@ -193,52 +272,80 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
db : Session
|
||||
Active SQLAlchemy database session.
|
||||
|
||||
Returns
|
||||
Returns:
|
||||
-------
|
||||
dict
|
||||
Summary with keys ``created``, ``skipped_existing``,
|
||||
``yaml_files_parsed``, ``total_tests_parsed``.
|
||||
"""
|
||||
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
|
||||
tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign zip_bytes = _download_zip()
|
||||
zip_bytes = _download_zip()
|
||||
# Assign atomics_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
atomics_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
# Assign parsed_tests = _parse_yaml_files(atomics_dir)
|
||||
parsed_tests = _parse_yaml_files(atomics_dir)
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Always clean up
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
# Log info: "Cleaned up temp directory %s", tmp_dir
|
||||
logger.info("Cleaned up temp directory %s", tmp_dir)
|
||||
|
||||
# Pre-load existing atomic_test_ids for dedup
|
||||
existing_ids: set[str] = {
|
||||
row[0]
|
||||
for row in db.query(TestTemplate.atomic_test_id)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.atomic_test_id.isnot(None))
|
||||
# Chain .all() call
|
||||
.all()
|
||||
}
|
||||
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
# Iterate over parsed_tests
|
||||
for item in parsed_tests:
|
||||
# Check: item["atomic_test_id"] in existing_ids
|
||||
if item["atomic_test_id"] in existing_ids:
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign template = TestTemplate(
|
||||
template = TestTemplate(
|
||||
# Keyword argument: mitre_technique_id
|
||||
mitre_technique_id=item["technique_id"],
|
||||
# Keyword argument: name
|
||||
name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}",
|
||||
# Keyword argument: description
|
||||
description=item["description"][:2000] if item["description"] else None,
|
||||
# Keyword argument: source
|
||||
source="atomic_red_team",
|
||||
# Keyword argument: source_url
|
||||
source_url=item["source_url"],
|
||||
# Keyword argument: attack_procedure
|
||||
attack_procedure=item["command"],
|
||||
# Keyword argument: platform
|
||||
platform=item["platforms"],
|
||||
# Keyword argument: tool_suggested
|
||||
tool_suggested=item["executor_type"] if item["executor_type"] else None,
|
||||
# Keyword argument: atomic_test_id
|
||||
atomic_test_id=item["atomic_test_id"],
|
||||
# Keyword argument: is_active
|
||||
is_active=True,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(template)
|
||||
# Call existing_ids.add()
|
||||
existing_ids.add(item["atomic_test_id"])
|
||||
new_technique_ids.add(item["technique_id"])
|
||||
created += 1
|
||||
@@ -253,15 +360,23 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
# Count distinct YAML files by technique_id
|
||||
yaml_files_count = len({t["technique_id"] for t in parsed_tests})
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"created": created,
|
||||
# Literal argument value
|
||||
"skipped_existing": skipped,
|
||||
# Literal argument value
|
||||
"yaml_files_parsed": yaml_files_count,
|
||||
# Literal argument value
|
||||
"total_tests_parsed": len(parsed_tests),
|
||||
}
|
||||
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Atomic Red Team import complete — created=%d, skipped=%d, "
|
||||
# Literal argument value
|
||||
"yaml_files=%d, total_tests=%d",
|
||||
created, skipped, yaml_files_count, len(parsed_tests),
|
||||
)
|
||||
@@ -269,12 +384,19 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
# Audit log (system action)
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=None,
|
||||
# Keyword argument: action
|
||||
action="import_atomic_red_team",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=None,
|
||||
# Keyword argument: details
|
||||
details=summary,
|
||||
)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
@@ -4,24 +4,37 @@ Provides paginated logs and distinct action/entity-type lists.
|
||||
No FastAPI imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Session, joinedload from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
# Import AuditLog from app.models.audit
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
|
||||
# Define function list_logs
|
||||
def list_logs(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: user_id
|
||||
user_id: str | None = None,
|
||||
# Entry: action
|
||||
action: str | None = None,
|
||||
# Entry: entity_type
|
||||
entity_type: str | None = None,
|
||||
# Entry: start_date
|
||||
start_date: datetime | None = None,
|
||||
# Entry: end_date
|
||||
end_date: datetime | None = None,
|
||||
# Entry: offset
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""Return paginated audit logs with optional filters.
|
||||
@@ -30,64 +43,104 @@ def list_logs(
|
||||
Each item is a dict with: id, user_id, username, action, entity_type,
|
||||
entity_id, timestamp, details.
|
||||
"""
|
||||
# Assign query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
||||
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
||||
|
||||
# Check: user_id
|
||||
if user_id:
|
||||
# Assign query = query.filter(AuditLog.user_id == user_id)
|
||||
query = query.filter(AuditLog.user_id == user_id)
|
||||
# Check: action
|
||||
if action:
|
||||
# Assign query = query.filter(AuditLog.action == action)
|
||||
query = query.filter(AuditLog.action == action)
|
||||
# Check: entity_type
|
||||
if entity_type:
|
||||
# Assign query = query.filter(AuditLog.entity_type == entity_type)
|
||||
query = query.filter(AuditLog.entity_type == entity_type)
|
||||
# Check: start_date
|
||||
if start_date:
|
||||
# Assign query = query.filter(AuditLog.timestamp >= start_date)
|
||||
query = query.filter(AuditLog.timestamp >= start_date)
|
||||
# Check: end_date
|
||||
if end_date:
|
||||
# Assign query = query.filter(AuditLog.timestamp <= end_date)
|
||||
query = query.filter(AuditLog.timestamp <= end_date)
|
||||
|
||||
# Assign total = query.count()
|
||||
total = query.count()
|
||||
|
||||
# Assign logs = (
|
||||
logs = (
|
||||
query
|
||||
# Chain .order_by() call
|
||||
.order_by(AuditLog.timestamp.desc())
|
||||
# Chain .offset() call
|
||||
.offset(offset)
|
||||
# Chain .limit() call
|
||||
.limit(limit)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign items = [
|
||||
items = [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": log.id,
|
||||
# Literal argument value
|
||||
"user_id": log.user_id,
|
||||
# Literal argument value
|
||||
"username": log.user.username if log.user else None,
|
||||
# Literal argument value
|
||||
"action": log.action,
|
||||
# Literal argument value
|
||||
"entity_type": log.entity_type,
|
||||
# Literal argument value
|
||||
"entity_id": log.entity_id,
|
||||
# Literal argument value
|
||||
"timestamp": log.timestamp,
|
||||
# Literal argument value
|
||||
"details": log.details,
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
|
||||
# Return {"items": items, "total": total, "offset": offset, "limit": limit}
|
||||
return {"items": items, "total": total, "offset": offset, "limit": limit}
|
||||
|
||||
|
||||
# Define function list_distinct_actions
|
||||
def list_distinct_actions(db: Session) -> list[str]:
|
||||
"""Return a list of distinct action types in the audit log."""
|
||||
# Assign actions = (
|
||||
actions = (
|
||||
db.query(AuditLog.action)
|
||||
# Chain .distinct() call
|
||||
.distinct()
|
||||
# Chain .order_by() call
|
||||
.order_by(AuditLog.action)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [a[0] for a in actions]
|
||||
return [a[0] for a in actions]
|
||||
|
||||
|
||||
# Define function list_distinct_entity_types
|
||||
def list_distinct_entity_types(db: Session) -> list[str]:
|
||||
"""Return a list of distinct entity types in the audit log."""
|
||||
# Assign types = (
|
||||
types = (
|
||||
db.query(AuditLog.entity_type)
|
||||
# Chain .filter() call
|
||||
.filter(AuditLog.entity_type.isnot(None))
|
||||
# Chain .distinct() call
|
||||
.distinct()
|
||||
# Chain .order_by() call
|
||||
.order_by(AuditLog.entity_type)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [t[0] for t in types]
|
||||
return [t[0] for t in types]
|
||||
|
||||
@@ -1,66 +1,116 @@
|
||||
"""Audit logging with request context and integrity hashing."""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import hashlib
|
||||
import hashlib
|
||||
|
||||
# Import datetime, timezone from datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import request_ip, request_user_agent from app.middleware.request_context
|
||||
from app.middleware.request_context import request_ip, request_user_agent
|
||||
|
||||
# Import AuditLog from app.models.audit
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
|
||||
# Define function _integrity_payload
|
||||
def _integrity_payload(entry: AuditLog) -> str:
|
||||
# Assign ts = entry.timestamp
|
||||
ts = entry.timestamp
|
||||
# Check: ts is None
|
||||
if ts is None:
|
||||
# Assign ts = datetime.now(timezone.utc)
|
||||
ts = datetime.now(timezone.utc)
|
||||
# Assign user_part = str(entry.user_id) if entry.user_id else ""
|
||||
user_part = str(entry.user_id) if entry.user_id else ""
|
||||
# Assign entity_type = entry.entity_type or ""
|
||||
entity_type = entry.entity_type or ""
|
||||
# Assign entity_id = entry.entity_id or ""
|
||||
entity_id = entry.entity_id or ""
|
||||
# Return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoforma...
|
||||
return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoformat()}"
|
||||
|
||||
|
||||
# Define function compute_integrity_hash
|
||||
def compute_integrity_hash(entry: AuditLog) -> str:
|
||||
"""Return the SHA-256 hex digest for an audit log entry."""
|
||||
# Return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest()
|
||||
return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest()
|
||||
|
||||
|
||||
# Define function verify_audit_integrity
|
||||
def verify_audit_integrity(entry: AuditLog) -> bool:
|
||||
"""Return whether the stored hash matches the entry's current fields."""
|
||||
# Check: not entry.integrity_hash
|
||||
if not entry.integrity_hash:
|
||||
# Return False
|
||||
return False
|
||||
# Return entry.integrity_hash == compute_integrity_hash(entry)
|
||||
return entry.integrity_hash == compute_integrity_hash(entry)
|
||||
|
||||
|
||||
# Define function log_action
|
||||
def log_action(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
user_id,
|
||||
# Entry: user_id
|
||||
user_id: UUID | None,
|
||||
# Entry: action
|
||||
action: str,
|
||||
# Entry: entity_type
|
||||
entity_type: str | None = None,
|
||||
# Entry: entity_id
|
||||
entity_id: str | None = None,
|
||||
# Entry: details
|
||||
details: dict | None = None,
|
||||
*,
|
||||
# Entry: ip_address
|
||||
ip_address: str | None = None,
|
||||
# Entry: user_agent
|
||||
user_agent: str | None = None,
|
||||
# Entry: session_id
|
||||
session_id: str | None = None,
|
||||
) -> AuditLog:
|
||||
"""Record an audit event. Does not commit — the caller owns the transaction."""
|
||||
# Assign ip = ip_address if ip_address is not None else request_ip.get("")
|
||||
ip = ip_address if ip_address is not None else request_ip.get("")
|
||||
# Assign ua = user_agent if user_agent is not None else request_user_agent.get("")
|
||||
ua = user_agent if user_agent is not None else request_user_agent.get("")
|
||||
|
||||
# Assign entry = AuditLog(
|
||||
entry = AuditLog(
|
||||
# Keyword argument: user_id
|
||||
user_id=user_id,
|
||||
# Keyword argument: action
|
||||
action=action,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=str(entity_id) if entity_id else None,
|
||||
# Keyword argument: details
|
||||
details=details,
|
||||
# Keyword argument: ip_address
|
||||
ip_address=ip or None,
|
||||
# Keyword argument: user_agent
|
||||
user_agent=ua or None,
|
||||
# Keyword argument: session_id
|
||||
session_id=session_id,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(entry)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Assign entry.integrity_hash = compute_integrity_hash(entry)
|
||||
entry.integrity_hash = compute_integrity_hash(entry)
|
||||
# Return entry
|
||||
return entry
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
"""Authentication service — credential validation and password management."""
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import hash_password, verify_password from app.auth
|
||||
from app.auth import hash_password, verify_password
|
||||
|
||||
# Import BusinessRuleViolation, PermissionViolation from app.domain.errors
|
||||
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Assign _DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
|
||||
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
|
||||
|
||||
|
||||
# Define function authenticate_user
|
||||
def authenticate_user(db: Session, *, username: str, password: str) -> User:
|
||||
"""Validate credentials and return the User.
|
||||
|
||||
@@ -17,33 +26,49 @@ def authenticate_user(db: Session, *, username: str, password: str) -> User:
|
||||
Raises PermissionViolation for disabled account.
|
||||
Uses constant-time comparison to prevent timing attacks.
|
||||
"""
|
||||
# Assign user = db.query(User).filter(User.username == username).first()
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
# Assign hashed = user.hashed_password if user else _DUMMY_HASH
|
||||
hashed = user.hashed_password if user else _DUMMY_HASH
|
||||
# Assign password_valid = verify_password(password, hashed)
|
||||
password_valid = verify_password(password, hashed)
|
||||
|
||||
# Check: user is None or not password_valid
|
||||
if user is None or not password_valid:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Incorrect username or password")
|
||||
# Check: not user.is_active
|
||||
if not user.is_active:
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation("Account is disabled. Contact an administrator.")
|
||||
# Return user
|
||||
return user
|
||||
|
||||
|
||||
# Define function change_password
|
||||
def change_password(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: user
|
||||
user: User,
|
||||
*,
|
||||
# Entry: current_password
|
||||
current_password: str,
|
||||
# Entry: new_password
|
||||
new_password: str,
|
||||
) -> None:
|
||||
"""Change a user's password. Does NOT commit.
|
||||
|
||||
Raises BusinessRuleViolation if current password is wrong.
|
||||
"""
|
||||
# Check: not verify_password(current_password, user.hashed_password)
|
||||
if not verify_password(current_password, user.hashed_password):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Current password is incorrect")
|
||||
if verify_password(new_password, user.hashed_password):
|
||||
raise BusinessRuleViolation(
|
||||
"New password must be different from the current password"
|
||||
)
|
||||
user.hashed_password = hash_password(new_password)
|
||||
# Assign user.must_change_password = False
|
||||
user.must_change_password = False
|
||||
|
||||
@@ -21,23 +21,42 @@ templates are identified by ``source = "caldera"`` + ``atomic_test_id``
|
||||
(the CALDERA ability ``id``).
|
||||
"""
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import shutil
|
||||
import shutil
|
||||
|
||||
# Import tempfile
|
||||
import tempfile
|
||||
|
||||
# Import zipfile
|
||||
import zipfile
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Path from pathlib
|
||||
from pathlib import Path
|
||||
|
||||
# Import requests
|
||||
import requests as _requests
|
||||
|
||||
# Import yaml
|
||||
import yaml
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.test_template import TestTemplate
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -45,11 +64,15 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CALDERA_ZIP_URL = (
|
||||
# Literal argument value
|
||||
"https://github.com/mitre/stockpile"
|
||||
# Literal argument value
|
||||
"/archive/refs/heads/master.zip"
|
||||
)
|
||||
|
||||
# Assign _DOWNLOAD_TIMEOUT = 300
|
||||
_DOWNLOAD_TIMEOUT = 300
|
||||
# Assign _ZIP_ROOT_PREFIX = "stockpile-master"
|
||||
_ZIP_ROOT_PREFIX = "stockpile-master"
|
||||
|
||||
|
||||
@@ -60,26 +83,40 @@ _ZIP_ROOT_PREFIX = "stockpile-master"
|
||||
|
||||
def _download_zip(url: str = CALDERA_ZIP_URL) -> bytes:
|
||||
"""Download the CALDERA ZIP and return raw bytes."""
|
||||
# Log info: "Downloading CALDERA ZIP from %s …", url
|
||||
logger.info("Downloading CALDERA ZIP from %s …", url)
|
||||
# Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Assign content = resp.content
|
||||
content = resp.content
|
||||
# Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024
|
||||
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
|
||||
# Return content
|
||||
return content
|
||||
|
||||
|
||||
# Define function _extract_zip
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return abilities dir."""
|
||||
# Open context manager
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
# Call zf.extractall()
|
||||
zf.extractall(dest)
|
||||
# Assign abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities"
|
||||
abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities"
|
||||
# Check: not abilities_dir.is_dir()
|
||||
if not abilities_dir.is_dir():
|
||||
# Raise FileNotFoundError
|
||||
raise FileNotFoundError(
|
||||
f"Expected abilities directory not found at {abilities_dir}"
|
||||
)
|
||||
# Return abilities_dir
|
||||
return abilities_dir
|
||||
|
||||
|
||||
# Define function _extract_commands
|
||||
def _extract_commands(platforms_dict: dict) -> str:
|
||||
"""Extract executor commands from CALDERA platforms dict.
|
||||
|
||||
@@ -95,116 +132,192 @@ def _extract_commands(platforms_dict: dict) -> str:
|
||||
|
||||
Returns a formatted string with all commands.
|
||||
"""
|
||||
# Assign lines = []
|
||||
lines = []
|
||||
# Check: not isinstance(platforms_dict, dict)
|
||||
if not isinstance(platforms_dict, dict):
|
||||
# Return ""
|
||||
return ""
|
||||
|
||||
# Iterate over platforms_dict.items()
|
||||
for os_name, executors in platforms_dict.items():
|
||||
# Check: not isinstance(executors, dict)
|
||||
if not isinstance(executors, dict):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Iterate over executors.items()
|
||||
for executor_name, executor_data in executors.items():
|
||||
# Check: isinstance(executor_data, dict)
|
||||
if isinstance(executor_data, dict):
|
||||
# Assign cmd = executor_data.get("command", "")
|
||||
cmd = executor_data.get("command", "")
|
||||
# Check: cmd
|
||||
if cmd:
|
||||
# Call lines.append()
|
||||
lines.append(f"[{os_name}/{executor_name}]\n{cmd}")
|
||||
# Alternative: isinstance(executor_data, str)
|
||||
elif isinstance(executor_data, str):
|
||||
# Call lines.append()
|
||||
lines.append(f"[{os_name}/{executor_name}]\n{executor_data}")
|
||||
|
||||
# Return "\n\n".join(lines)
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
# Define function _extract_platforms
|
||||
def _extract_platforms(platforms_dict: dict) -> str:
|
||||
"""Extract platform names from CALDERA platforms dict."""
|
||||
# Check: not isinstance(platforms_dict, dict)
|
||||
if not isinstance(platforms_dict, dict):
|
||||
# Return ""
|
||||
return ""
|
||||
# Assign platform_names = []
|
||||
platform_names = []
|
||||
# Iterate over platforms_dict
|
||||
for os_name in platforms_dict:
|
||||
# Assign normalized = str(os_name).lower().strip()
|
||||
normalized = str(os_name).lower().strip()
|
||||
# Check: normalized in ("windows", "linux", "darwin", "macos")
|
||||
if normalized in ("windows", "linux", "darwin", "macos"):
|
||||
# Check: normalized == "darwin"
|
||||
if normalized == "darwin":
|
||||
# Assign normalized = "macos"
|
||||
normalized = "macos"
|
||||
# Check: normalized not in platform_names
|
||||
if normalized not in platform_names:
|
||||
# Call platform_names.append()
|
||||
platform_names.append(normalized)
|
||||
# Return ", ".join(platform_names)
|
||||
return ", ".join(platform_names)
|
||||
|
||||
|
||||
# Define function _parse_abilities
|
||||
def _parse_abilities(abilities_dir: Path) -> list[dict]:
|
||||
"""Walk abilities directories and parse all YAML files.
|
||||
|
||||
Returns a flat list of dicts, each representing one ability.
|
||||
"""
|
||||
# Assign results = []
|
||||
results: list[dict] = []
|
||||
# Assign yaml_files = sorted(abilities_dir.rglob("*.yml"))
|
||||
yaml_files = sorted(abilities_dir.rglob("*.yml"))
|
||||
# Log info: "Found %d ability YAML files", len(yaml_files
|
||||
logger.info("Found %d ability YAML files", len(yaml_files))
|
||||
|
||||
# Iterate over yaml_files
|
||||
for yaml_path in yaml_files:
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Open context manager
|
||||
with open(yaml_path, "r", encoding="utf-8") as fh:
|
||||
# Assign data_list = list(yaml.safe_load_all(fh))
|
||||
data_list = list(yaml.safe_load_all(fh))
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log debug: "Failed to parse %s: %s", yaml_path, exc
|
||||
logger.debug("Failed to parse %s: %s", yaml_path, exc)
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Stockpile YAML files may contain YAML lists of abilities
|
||||
# (e.g. [- id: ..., - id: ...]) or single-document dicts.
|
||||
# Flatten everything into individual ability dicts.
|
||||
abilities: list[dict] = []
|
||||
# Iterate over data_list
|
||||
for data in data_list:
|
||||
# Check: isinstance(data, dict)
|
||||
if isinstance(data, dict):
|
||||
# Call abilities.append()
|
||||
abilities.append(data)
|
||||
# Alternative: isinstance(data, list)
|
||||
elif isinstance(data, list):
|
||||
# Call abilities.extend()
|
||||
abilities.extend(d for d in data if isinstance(d, dict))
|
||||
|
||||
# Iterate over abilities
|
||||
for data in abilities:
|
||||
# Assign ability_id = data.get("id", "")
|
||||
ability_id = data.get("id", "")
|
||||
# Check: not ability_id
|
||||
if not ability_id:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign name = data.get("name", "").strip()
|
||||
name = data.get("name", "").strip()
|
||||
# Assign description = data.get("description", "").strip()
|
||||
description = data.get("description", "").strip()
|
||||
# Assign tactic = data.get("tactic", "").strip()
|
||||
tactic = data.get("tactic", "").strip()
|
||||
|
||||
# Extract technique info
|
||||
technique = data.get("technique", {})
|
||||
# Check: isinstance(technique, dict)
|
||||
if isinstance(technique, dict):
|
||||
# Assign attack_id = technique.get("attack_id", "")
|
||||
attack_id = technique.get("attack_id", "")
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign attack_id = ""
|
||||
attack_id = ""
|
||||
|
||||
# Check: not attack_id
|
||||
if not attack_id:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Normalise technique ID
|
||||
attack_id = str(attack_id).strip().upper()
|
||||
# Check: not attack_id.startswith("T")
|
||||
if not attack_id.startswith("T"):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Extract platforms and commands
|
||||
platforms_dict = data.get("platforms", {})
|
||||
# Assign commands = _extract_commands(platforms_dict)
|
||||
commands = _extract_commands(platforms_dict)
|
||||
# Assign platform_str = _extract_platforms(platforms_dict)
|
||||
platform_str = _extract_platforms(platforms_dict)
|
||||
|
||||
# Determine executor type
|
||||
executors = set()
|
||||
# Check: isinstance(platforms_dict, dict)
|
||||
if isinstance(platforms_dict, dict):
|
||||
# Iterate over platforms_dict.values()
|
||||
for os_executors in platforms_dict.values():
|
||||
# Check: isinstance(os_executors, dict)
|
||||
if isinstance(os_executors, dict):
|
||||
# Call executors.update()
|
||||
executors.update(os_executors.keys())
|
||||
# Assign executor_str = ", ".join(sorted(executors)) if executors else None
|
||||
executor_str = ", ".join(sorted(executors)) if executors else None
|
||||
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"mitre_technique_id": attack_id,
|
||||
# Literal argument value
|
||||
"name": f"CALDERA: {name}"[:500] if name else f"CALDERA ability {ability_id}"[:500],
|
||||
# Literal argument value
|
||||
"description": f"{description}\n\nTactic: {tactic}".strip()[:2000] if description else None,
|
||||
# Literal argument value
|
||||
"source": "caldera",
|
||||
# Literal argument value
|
||||
"platform": platform_str,
|
||||
# Literal argument value
|
||||
"tool_suggested": executor_str,
|
||||
# Literal argument value
|
||||
"attack_procedure": commands[:4000] if commands else None,
|
||||
# Literal argument value
|
||||
"atomic_test_id": f"caldera:{ability_id}",
|
||||
# Literal argument value
|
||||
"source_url": f"https://github.com/mitre/stockpile/tree/master/data/abilities/{tactic}",
|
||||
})
|
||||
|
||||
# Log info: "Parsed %d CALDERA abilities total", len(results
|
||||
logger.info("Parsed %d CALDERA abilities total", len(results))
|
||||
# Return results
|
||||
return results
|
||||
|
||||
|
||||
@@ -218,46 +331,76 @@ def sync(db: Session) -> dict:
|
||||
|
||||
Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``.
|
||||
"""
|
||||
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_caldera_")
|
||||
tmp_dir = tempfile.mkdtemp(prefix="aegis_caldera_")
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign zip_bytes = _download_zip()
|
||||
zip_bytes = _download_zip()
|
||||
# Assign abilities_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
abilities_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
# Assign parsed = _parse_abilities(abilities_dir)
|
||||
parsed = _parse_abilities(abilities_dir)
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Call shutil.rmtree()
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
# Log info: "Cleaned up temp directory %s", tmp_dir
|
||||
logger.info("Cleaned up temp directory %s", tmp_dir)
|
||||
|
||||
# Pre-load existing for dedup
|
||||
existing_ids: set[str] = {
|
||||
row[0]
|
||||
for row in db.query(TestTemplate.atomic_test_id)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.source == "caldera")
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.atomic_test_id.isnot(None))
|
||||
# Chain .all() call
|
||||
.all()
|
||||
}
|
||||
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
# Iterate over parsed
|
||||
for item in parsed:
|
||||
# Check: item["atomic_test_id"] in existing_ids
|
||||
if item["atomic_test_id"] in existing_ids:
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign template = TestTemplate(
|
||||
template = TestTemplate(
|
||||
# Keyword argument: mitre_technique_id
|
||||
mitre_technique_id=item["mitre_technique_id"],
|
||||
# Keyword argument: name
|
||||
name=item["name"],
|
||||
# Keyword argument: description
|
||||
description=item["description"],
|
||||
# Keyword argument: source
|
||||
source=item["source"],
|
||||
# Keyword argument: source_url
|
||||
source_url=item["source_url"],
|
||||
# Keyword argument: attack_procedure
|
||||
attack_procedure=item["attack_procedure"],
|
||||
# Keyword argument: platform
|
||||
platform=item["platform"],
|
||||
# Keyword argument: tool_suggested
|
||||
tool_suggested=item["tool_suggested"],
|
||||
# Keyword argument: atomic_test_id
|
||||
atomic_test_id=item["atomic_test_id"],
|
||||
# Keyword argument: is_active
|
||||
is_active=True,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(template)
|
||||
# Call existing_ids.add()
|
||||
existing_ids.add(item["atomic_test_id"])
|
||||
new_technique_ids.add(item["mitre_technique_id"])
|
||||
created += 1
|
||||
@@ -269,22 +412,36 @@ def sync(db: Session) -> dict:
|
||||
|
||||
db.commit()
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"created": created,
|
||||
# Literal argument value
|
||||
"skipped_existing": skipped,
|
||||
# Literal argument value
|
||||
"total_parsed": len(parsed),
|
||||
}
|
||||
|
||||
# Update DataSource record
|
||||
ds = db.query(DataSource).filter(DataSource.name == "caldera").first()
|
||||
# Check: ds
|
||||
if ds:
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_status = "success"
|
||||
ds.last_sync_status = "success"
|
||||
# Assign ds.last_sync_stats = summary
|
||||
ds.last_sync_stats = summary
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Log info: "CALDERA import complete — %s", summary
|
||||
logger.info("CALDERA import complete — %s", summary)
|
||||
# Call log_action()
|
||||
log_action(db, user_id=None, action="import_caldera",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template", entity_id=None, details=summary)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
@@ -4,112 +4,191 @@ Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||
The router is responsible for HTTP concerns, auth, audit logging, and commit.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import from app.domain.errors
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
EntityNotFoundError,
|
||||
PermissionViolation,
|
||||
)
|
||||
|
||||
# Import Campaign, CampaignTest from app.models.campaign
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.test import Test
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
from app.utils import escape_like
|
||||
from app.services.campaign_service import (
|
||||
get_campaign_progress,
|
||||
validate_no_circular_dependency,
|
||||
TACTIC_TO_PHASE,
|
||||
)
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import calculate_next_run from app.services.campaign_scheduler_service
|
||||
from app.services.campaign_scheduler_service import calculate_next_run
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
|
||||
# Import from app.services.campaign_service
|
||||
from app.services.campaign_service import (
|
||||
TACTIC_TO_PHASE,
|
||||
get_campaign_progress,
|
||||
validate_no_circular_dependency,
|
||||
)
|
||||
|
||||
# Import escape_like from app.utils
|
||||
from app.utils import escape_like
|
||||
|
||||
# ── Serialization helpers ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||
"""Serialize a campaign with its tests and progress."""
|
||||
# Assign progress = get_campaign_progress(db, campaign.id)
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
|
||||
# Assign campaign_tests = (
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest)
|
||||
# Chain .filter() call
|
||||
.filter(CampaignTest.campaign_id == campaign.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(CampaignTest.order_index)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign tests = []
|
||||
tests = []
|
||||
# Iterate over campaign_tests
|
||||
for ct in campaign_tests:
|
||||
# Assign test = ct.test
|
||||
test = ct.test
|
||||
# Assign technique = db.query(Technique).filter(Technique.id == test.technique_id).first...
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
|
||||
|
||||
# Call tests.append()
|
||||
tests.append({
|
||||
# Literal argument value
|
||||
"id": str(ct.id),
|
||||
# Literal argument value
|
||||
"test_id": str(ct.test_id),
|
||||
# Literal argument value
|
||||
"order_index": ct.order_index,
|
||||
# Literal argument value
|
||||
"depends_on": str(ct.depends_on) if ct.depends_on else None,
|
||||
# Literal argument value
|
||||
"phase": ct.phase,
|
||||
# Literal argument value
|
||||
"test_name": test.name if test else None,
|
||||
# Literal argument value
|
||||
"test_state": test.state.value if test and test.state else None,
|
||||
# Literal argument value
|
||||
"test_result": test.result.value if test and test.result else None,
|
||||
# Literal argument value
|
||||
"technique_mitre_id": technique.mitre_id if technique else None,
|
||||
# Literal argument value
|
||||
"technique_name": technique.name if technique else None,
|
||||
# Literal argument value
|
||||
"platform": test.platform if test else None,
|
||||
})
|
||||
|
||||
# Assign actor = campaign.threat_actor
|
||||
actor = campaign.threat_actor
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(campaign.id),
|
||||
# Literal argument value
|
||||
"name": campaign.name,
|
||||
# Literal argument value
|
||||
"description": campaign.description,
|
||||
# Literal argument value
|
||||
"type": campaign.type,
|
||||
# Literal argument value
|
||||
"status": campaign.status,
|
||||
# Literal argument value
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
# Literal argument value
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
# Literal argument value
|
||||
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
||||
"start_date": campaign.start_date.isoformat() if campaign.start_date else None,
|
||||
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
||||
# Literal argument value
|
||||
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
||||
# Literal argument value
|
||||
"target_platform": campaign.target_platform,
|
||||
# Literal argument value
|
||||
"tags": campaign.tags or [],
|
||||
# Literal argument value
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
# Literal argument value
|
||||
"is_recurring": campaign.is_recurring or False,
|
||||
# Literal argument value
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
# Literal argument value
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
# Literal argument value
|
||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||
# Literal argument value
|
||||
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
|
||||
# Literal argument value
|
||||
"tests": tests,
|
||||
# Literal argument value
|
||||
"progress": progress,
|
||||
}
|
||||
|
||||
|
||||
# Define function serialize_campaign_summary
|
||||
def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||
"""Lightweight campaign serialization for list views."""
|
||||
# Assign progress = get_campaign_progress(db, campaign.id)
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
# Assign actor = campaign.threat_actor
|
||||
actor = campaign.threat_actor
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(campaign.id),
|
||||
# Literal argument value
|
||||
"name": campaign.name,
|
||||
# Literal argument value
|
||||
"description": campaign.description,
|
||||
# Literal argument value
|
||||
"type": campaign.type,
|
||||
# Literal argument value
|
||||
"status": campaign.status,
|
||||
# Literal argument value
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
# Literal argument value
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"start_date": campaign.start_date.isoformat() if campaign.start_date else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
# Literal argument value
|
||||
"tags": campaign.tags or [],
|
||||
# Literal argument value
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
# Literal argument value
|
||||
"is_recurring": campaign.is_recurring or False,
|
||||
# Literal argument value
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
# Literal argument value
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
# Literal argument value
|
||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||
# Literal argument value
|
||||
"test_count": progress["total"],
|
||||
# Literal argument value
|
||||
"completion_pct": progress["completion_pct"],
|
||||
}
|
||||
|
||||
@@ -118,122 +197,198 @@ def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||
|
||||
|
||||
def list_campaigns(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: type
|
||||
type: Optional[str] = None,
|
||||
# Entry: status
|
||||
status: Optional[str] = None,
|
||||
# Entry: threat_actor_id
|
||||
threat_actor_id: Optional[str] = None,
|
||||
# Entry: search
|
||||
search: Optional[str] = None,
|
||||
# Entry: offset
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""Return a paginated list of campaigns with optional filters."""
|
||||
# Assign query = db.query(Campaign)
|
||||
query = db.query(Campaign)
|
||||
|
||||
# Check: type
|
||||
if type:
|
||||
# Assign query = query.filter(Campaign.type == type)
|
||||
query = query.filter(Campaign.type == type)
|
||||
# Check: status
|
||||
if status:
|
||||
# Assign query = query.filter(Campaign.status == status)
|
||||
query = query.filter(Campaign.status == status)
|
||||
# Check: threat_actor_id
|
||||
if threat_actor_id:
|
||||
# Assign query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
||||
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
||||
# Check: search
|
||||
if search:
|
||||
# Assign pattern = f"%{escape_like(search)}%"
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
# Assign query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.il...
|
||||
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
||||
|
||||
# Assign total = query.count()
|
||||
total = query.count()
|
||||
# Assign campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(lim...
|
||||
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
# Literal argument value
|
||||
"offset": offset,
|
||||
# Literal argument value
|
||||
"limit": limit,
|
||||
# Literal argument value
|
||||
"items": [serialize_campaign_summary(db, c) for c in campaigns],
|
||||
}
|
||||
|
||||
|
||||
# Define function create_campaign
|
||||
def create_campaign(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: creator_id
|
||||
creator_id: uuid.UUID,
|
||||
# Entry: name
|
||||
name: str,
|
||||
# Entry: description
|
||||
description: Optional[str] = None,
|
||||
# Entry: type
|
||||
type: str = "custom",
|
||||
# Entry: threat_actor_id
|
||||
threat_actor_id: Optional[str] = None,
|
||||
# Entry: target_platform
|
||||
target_platform: Optional[str] = None,
|
||||
# Entry: tags
|
||||
tags: Optional[list[str]] = None,
|
||||
# Entry: scheduled_at
|
||||
scheduled_at: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create a new campaign. Does not commit; caller commits."""
|
||||
# Assign campaign = Campaign(
|
||||
campaign = Campaign(
|
||||
# Keyword argument: name
|
||||
name=name,
|
||||
# Keyword argument: description
|
||||
description=description,
|
||||
# Keyword argument: type
|
||||
type=type,
|
||||
# Keyword argument: threat_actor_id
|
||||
threat_actor_id=uuid.UUID(threat_actor_id) if threat_actor_id else None,
|
||||
# Keyword argument: target_platform
|
||||
target_platform=target_platform,
|
||||
# Keyword argument: tags
|
||||
tags=tags or [],
|
||||
# Keyword argument: created_by
|
||||
created_by=creator_id,
|
||||
# Keyword argument: scheduled_at
|
||||
scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None,
|
||||
start_date=datetime.fromisoformat(start_date) if start_date else None,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(campaign)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# Define function get_campaign_detail
|
||||
def get_campaign_detail(db: Session, campaign_id: str) -> dict:
|
||||
"""Get detailed campaign info including tests and progress.
|
||||
|
||||
Raises EntityNotFoundError if campaign not found.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
# Return serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# Define function update_campaign
|
||||
def update_campaign(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
*,
|
||||
# Entry: updater_id
|
||||
updater_id: uuid.UUID,
|
||||
# Entry: updater_role
|
||||
updater_role: str,
|
||||
**fields,
|
||||
**fields: object,
|
||||
) -> dict:
|
||||
"""Update a campaign. Only allowed in draft or active state.
|
||||
|
||||
Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Check: campaign.status not in ("draft", "active")
|
||||
if campaign.status not in ("draft", "active"):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Can only update draft or active campaigns")
|
||||
|
||||
# Check: str(campaign.created_by) != str(updater_id) and updater_role != "ad...
|
||||
if str(campaign.created_by) != str(updater_id) and updater_role != "admin":
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation("Only the creator or admin can update this campaign")
|
||||
|
||||
# Check: "scheduled_at" in fields and fields["scheduled_at"]
|
||||
if "scheduled_at" in fields and fields["scheduled_at"]:
|
||||
# Assign fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
|
||||
fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
|
||||
if "start_date" in fields and fields["start_date"]:
|
||||
fields["start_date"] = datetime.fromisoformat(fields["start_date"])
|
||||
|
||||
# Iterate over fields.items()
|
||||
for field, value in fields.items():
|
||||
# Call setattr()
|
||||
setattr(campaign, field, value)
|
||||
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# Define function add_test_to_campaign
|
||||
def add_test_to_campaign(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
*,
|
||||
# Entry: test_id
|
||||
test_id: str,
|
||||
# Entry: order_index
|
||||
order_index: Optional[int] = None,
|
||||
# Entry: depends_on
|
||||
depends_on: Optional[str] = None,
|
||||
# Entry: phase
|
||||
phase: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Add a test to a campaign with optional ordering and dependency.
|
||||
@@ -242,60 +397,101 @@ def add_test_to_campaign(
|
||||
Raises BusinessRuleViolation for invalid state or circular dependency.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Check: campaign.status not in ("draft", "active")
|
||||
if campaign.status not in ("draft", "active"):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Can only add tests to draft or active campaigns")
|
||||
|
||||
# Assign test = db.query(Test).filter(Test.id == test_id).first()
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
# Check: not test
|
||||
if not test:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test", test_id)
|
||||
|
||||
# Check: order_index is not None
|
||||
if order_index is not None:
|
||||
# Assign final_order_index = order_index
|
||||
final_order_index = order_index
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign max_order = (
|
||||
max_order = (
|
||||
db.query(CampaignTest.order_index)
|
||||
# Chain .filter() call
|
||||
.filter(CampaignTest.campaign_id == campaign_id)
|
||||
# Chain .order_by() call
|
||||
.order_by(CampaignTest.order_index.desc())
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Assign final_order_index = (max_order[0] + 1) if max_order else 0
|
||||
final_order_index = (max_order[0] + 1) if max_order else 0
|
||||
|
||||
# Assign depends_on_uuid = uuid.UUID(depends_on) if depends_on else None
|
||||
depends_on_uuid = uuid.UUID(depends_on) if depends_on else None
|
||||
|
||||
# Assign ct_id = uuid.uuid4()
|
||||
ct_id = uuid.uuid4()
|
||||
# Check: depends_on_uuid
|
||||
if depends_on_uuid:
|
||||
# Call validate_no_circular_dependency()
|
||||
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on_uuid)
|
||||
|
||||
# Check: not phase and test.technique_id
|
||||
if not phase and test.technique_id:
|
||||
# Assign technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
# Check: technique and technique.tactic
|
||||
if technique and technique.tactic:
|
||||
# Assign phase = TACTIC_TO_PHASE.get(technique.tactic, None)
|
||||
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
|
||||
|
||||
# Assign campaign_test = CampaignTest(
|
||||
campaign_test = CampaignTest(
|
||||
# Keyword argument: id
|
||||
id=ct_id,
|
||||
# Keyword argument: campaign_id
|
||||
campaign_id=campaign_id,
|
||||
# Keyword argument: test_id
|
||||
test_id=test_id,
|
||||
# Keyword argument: order_index
|
||||
order_index=final_order_index,
|
||||
# Keyword argument: depends_on
|
||||
depends_on=depends_on_uuid,
|
||||
# Keyword argument: phase
|
||||
phase=phase,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(campaign_test)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(campaign_test.id),
|
||||
# Literal argument value
|
||||
"campaign_id": str(campaign_test.campaign_id),
|
||||
# Literal argument value
|
||||
"test_id": str(campaign_test.test_id),
|
||||
# Literal argument value
|
||||
"order_index": campaign_test.order_index,
|
||||
# Literal argument value
|
||||
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
|
||||
# Literal argument value
|
||||
"phase": campaign_test.phase,
|
||||
}
|
||||
|
||||
|
||||
# Define function remove_test_from_campaign
|
||||
def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
|
||||
"""Remove a test from a campaign.
|
||||
|
||||
@@ -303,27 +499,41 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s
|
||||
Raises BusinessRuleViolation for invalid state.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Check: campaign.status not in ("draft", "active")
|
||||
if campaign.status not in ("draft", "active"):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Can only modify draft or active campaigns")
|
||||
|
||||
# Assign ct = (
|
||||
ct = (
|
||||
db.query(CampaignTest)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
CampaignTest.id == campaign_test_id,
|
||||
CampaignTest.campaign_id == campaign_id,
|
||||
)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Check: not ct
|
||||
if not ct:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("CampaignTest", campaign_test_id)
|
||||
|
||||
# Assign dep_id = uuid.UUID(campaign_test_id)
|
||||
dep_id = uuid.UUID(campaign_test_id)
|
||||
# Assign dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
|
||||
dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
|
||||
# Iterate over dependents
|
||||
for dep in dependents:
|
||||
# Assign dep.depends_on = None
|
||||
dep.depends_on = None
|
||||
|
||||
# Keep a reference to the underlying test before deleting the join record
|
||||
@@ -334,6 +544,7 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s
|
||||
technique_id = test_obj.technique_id
|
||||
|
||||
db.delete(ct)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Also delete the actual test record (it was created for this campaign)
|
||||
@@ -349,72 +560,110 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s
|
||||
db.flush()
|
||||
|
||||
|
||||
# Define function activate_campaign
|
||||
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||
"""Activate a campaign, moving it from draft to active.
|
||||
|
||||
Raises EntityNotFoundError, BusinessRuleViolation.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Check: campaign.status != "draft"
|
||||
if campaign.status != "draft":
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Only draft campaigns can be activated")
|
||||
|
||||
# Assign test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_...
|
||||
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
|
||||
# Check: test_count == 0
|
||||
if test_count == 0:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Campaign must have at least one test to activate")
|
||||
|
||||
# Assign campaign.status = "active"
|
||||
campaign.status = "active"
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return campaign
|
||||
return campaign
|
||||
|
||||
|
||||
# Define function complete_campaign
|
||||
def complete_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||
"""Mark a campaign as completed.
|
||||
|
||||
Raises EntityNotFoundError, BusinessRuleViolation.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Check: campaign.status != "active"
|
||||
if campaign.status != "active":
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Only active campaigns can be completed")
|
||||
|
||||
# Assign campaign.status = "completed"
|
||||
campaign.status = "completed"
|
||||
# Assign campaign.completed_at = datetime.utcnow()
|
||||
campaign.completed_at = datetime.utcnow()
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return campaign
|
||||
return campaign
|
||||
|
||||
|
||||
# Define function get_campaign_progress_data
|
||||
def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
|
||||
"""Get progress statistics for a campaign.
|
||||
|
||||
Raises EntityNotFoundError if campaign not found.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Assign progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
||||
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"campaign_id": str(campaign.id),
|
||||
# Literal argument value
|
||||
"campaign_name": campaign.name,
|
||||
**progress,
|
||||
}
|
||||
|
||||
|
||||
# Define function schedule_campaign
|
||||
def schedule_campaign(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
*,
|
||||
# Entry: owner_id
|
||||
owner_id: uuid.UUID,
|
||||
# Entry: owner_role
|
||||
owner_role: str,
|
||||
# Entry: is_recurring
|
||||
is_recurring: bool,
|
||||
# Entry: recurrence_pattern
|
||||
recurrence_pattern: Optional[str] = None,
|
||||
# Entry: next_run_at
|
||||
next_run_at: Optional[str] = None,
|
||||
) -> Campaign:
|
||||
"""Configure or update the recurrence schedule for a campaign.
|
||||
@@ -422,32 +671,52 @@ def schedule_campaign(
|
||||
Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Check: str(campaign.created_by) != str(owner_id) and owner_role != "admin"
|
||||
if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation("Only the creator or admin can configure scheduling")
|
||||
|
||||
# Assign campaign.is_recurring = is_recurring
|
||||
campaign.is_recurring = is_recurring
|
||||
|
||||
# Check: is_recurring
|
||||
if is_recurring:
|
||||
# Check: recurrence_pattern not in ("weekly", "monthly", "quarterly")
|
||||
if recurrence_pattern not in ("weekly", "monthly", "quarterly"):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
# Literal argument value
|
||||
"recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
|
||||
)
|
||||
# Assign campaign.recurrence_pattern = recurrence_pattern
|
||||
campaign.recurrence_pattern = recurrence_pattern
|
||||
# Check: next_run_at
|
||||
if next_run_at:
|
||||
# Assign campaign.next_run_at = datetime.fromisoformat(
|
||||
campaign.next_run_at = datetime.fromisoformat(
|
||||
next_run_at.replace("Z", "+00:00").replace("+00:00", "")
|
||||
)
|
||||
# Alternative: not campaign.next_run_at
|
||||
elif not campaign.next_run_at:
|
||||
# Assign campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
|
||||
campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign campaign.recurrence_pattern = None
|
||||
campaign.recurrence_pattern = None
|
||||
# Assign campaign.next_run_at = None
|
||||
campaign.next_run_at = None
|
||||
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return campaign
|
||||
return campaign
|
||||
|
||||
|
||||
@@ -522,29 +791,48 @@ def get_campaign_history(db: Session, campaign_id: str) -> dict:
|
||||
|
||||
Raises EntityNotFoundError if campaign not found.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Assign campaign_uuid = uuid.UUID(campaign_id)
|
||||
campaign_uuid = uuid.UUID(campaign_id)
|
||||
# Assign children = (
|
||||
children = (
|
||||
db.query(Campaign)
|
||||
# Chain .filter() call
|
||||
.filter(Campaign.parent_campaign_id == campaign_uuid)
|
||||
# Chain .order_by() call
|
||||
.order_by(Campaign.created_at.desc())
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"campaign_id": str(campaign.id),
|
||||
# Literal argument value
|
||||
"campaign_name": campaign.name,
|
||||
# Literal argument value
|
||||
"items": [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(child.id),
|
||||
# Literal argument value
|
||||
"name": child.name,
|
||||
# Literal argument value
|
||||
"status": child.status,
|
||||
# Literal argument value
|
||||
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
|
||||
# Literal argument value
|
||||
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
|
||||
# Literal argument value
|
||||
"created_at": child.created_at.isoformat() if child.created_at else None,
|
||||
# Literal argument value
|
||||
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
|
||||
}
|
||||
for child in children
|
||||
|
||||
@@ -4,19 +4,34 @@ Handles checking which recurring campaigns are due, cloning them with
|
||||
fresh tests, and computing the next run date.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
# Import datetime, timedelta from datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import Campaign, CampaignTest from app.models.campaign
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.test import Test
|
||||
|
||||
# Import TestState from app.models.enums
|
||||
from app.models.enums import TestState
|
||||
from app.services.notification_service import create_notification
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import create_notification from app.services.notification_service
|
||||
from app.services.notification_service import create_notification
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -33,11 +48,16 @@ def calculate_next_run(current_date: datetime, pattern: str) -> datetime:
|
||||
- ``monthly`` : +30 days
|
||||
- ``quarterly``: +90 days
|
||||
"""
|
||||
# Assign offsets = {
|
||||
offsets = {
|
||||
# Literal argument value
|
||||
"weekly": timedelta(days=7),
|
||||
# Literal argument value
|
||||
"monthly": timedelta(days=30),
|
||||
# Literal argument value
|
||||
"quarterly": timedelta(days=90),
|
||||
}
|
||||
# Return current_date + offsets.get(pattern, timedelta(days=30))
|
||||
return current_date + offsets.get(pattern, timedelta(days=30))
|
||||
|
||||
|
||||
@@ -54,59 +74,99 @@ def _clone_campaign(db: Session, original: Campaign) -> Campaign:
|
||||
with the same base data (in ``draft`` state) and link it.
|
||||
3. Activate the new campaign.
|
||||
"""
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign run_label = now.strftime("%Y-%m-%d")
|
||||
run_label = now.strftime("%Y-%m-%d")
|
||||
|
||||
# Assign child = Campaign(
|
||||
child = Campaign(
|
||||
# Keyword argument: name
|
||||
name=f"{original.name} (Run {run_label})",
|
||||
# Keyword argument: description
|
||||
description=original.description,
|
||||
# Keyword argument: type
|
||||
type=original.type,
|
||||
# Keyword argument: threat_actor_id
|
||||
threat_actor_id=original.threat_actor_id,
|
||||
# Keyword argument: status
|
||||
status="active",
|
||||
# Keyword argument: created_by
|
||||
created_by=original.created_by,
|
||||
# Keyword argument: target_platform
|
||||
target_platform=original.target_platform,
|
||||
# Keyword argument: tags
|
||||
tags=original.tags or [],
|
||||
# Keyword argument: parent_campaign_id
|
||||
parent_campaign_id=original.id,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(child)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush() # get child.id
|
||||
|
||||
# Clone each campaign_test with a fresh Test
|
||||
original_cts = (
|
||||
db.query(CampaignTest)
|
||||
# Chain .filter() call
|
||||
.filter(CampaignTest.campaign_id == original.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(CampaignTest.order_index)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Iterate over original_cts
|
||||
for ct in original_cts:
|
||||
# Assign src_test = ct.test
|
||||
src_test = ct.test
|
||||
# Check: not src_test
|
||||
if not src_test:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign new_test = Test(
|
||||
new_test = Test(
|
||||
# Keyword argument: technique_id
|
||||
technique_id=src_test.technique_id,
|
||||
# Keyword argument: name
|
||||
name=src_test.name,
|
||||
# Keyword argument: description
|
||||
description=src_test.description,
|
||||
# Keyword argument: platform
|
||||
platform=src_test.platform,
|
||||
# Keyword argument: procedure_text
|
||||
procedure_text=src_test.procedure_text,
|
||||
# Keyword argument: tool_used
|
||||
tool_used=src_test.tool_used,
|
||||
# Keyword argument: created_by
|
||||
created_by=original.created_by,
|
||||
# Keyword argument: state
|
||||
state=TestState.draft,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(new_test)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush() # get new_test.id
|
||||
|
||||
# Assign new_ct = CampaignTest(
|
||||
new_ct = CampaignTest(
|
||||
# Keyword argument: campaign_id
|
||||
campaign_id=child.id,
|
||||
# Keyword argument: test_id
|
||||
test_id=new_test.id,
|
||||
# Keyword argument: order_index
|
||||
order_index=ct.order_index,
|
||||
# Keyword argument: phase
|
||||
phase=ct.phase,
|
||||
# depends_on is not copied — would need ID remapping
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(new_ct)
|
||||
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return child
|
||||
return child
|
||||
|
||||
|
||||
@@ -120,75 +180,119 @@ def check_and_run_recurring_campaigns(db: Session) -> int:
|
||||
|
||||
Returns the number of campaigns spawned.
|
||||
"""
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Assign due_campaigns = (
|
||||
due_campaigns = (
|
||||
db.query(Campaign)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
Campaign.is_recurring == True, # noqa: E712
|
||||
Campaign.next_run_at <= now,
|
||||
)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign spawned = 0
|
||||
spawned = 0
|
||||
|
||||
# Iterate over due_campaigns
|
||||
for campaign in due_campaigns:
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign child = _clone_campaign(db, campaign)
|
||||
child = _clone_campaign(db, campaign)
|
||||
|
||||
# Update the original's scheduling fields
|
||||
campaign.last_run_at = now
|
||||
# Assign campaign.next_run_at = calculate_next_run(now, campaign.recurrence_pattern or "monthly")
|
||||
campaign.next_run_at = calculate_next_run(now, campaign.recurrence_pattern or "monthly")
|
||||
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(child)
|
||||
|
||||
# Audit
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=campaign.created_by,
|
||||
# Keyword argument: action
|
||||
action="recurring_campaign_run",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=child.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"parent_campaign_id": str(campaign.id),
|
||||
# Literal argument value
|
||||
"child_campaign_name": child.name,
|
||||
# Literal argument value
|
||||
"pattern": campaign.recurrence_pattern,
|
||||
},
|
||||
)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Notify
|
||||
if campaign.created_by:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=campaign.created_by,
|
||||
# Keyword argument: type
|
||||
type="recurring_campaign_run",
|
||||
# Keyword argument: title
|
||||
title="Recurring campaign executed",
|
||||
message=f'Campaign "{child.name}" was automatically created from recurring template "{campaign.name}".',
|
||||
# Keyword argument: message
|
||||
message=(
|
||||
f'Campaign "{child.name}" was automatically created '
|
||||
f'from recurring template "{campaign.name}".'
|
||||
),
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=child.id,
|
||||
)
|
||||
|
||||
# Notify red_tech users
|
||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
||||
# Iterate over red_techs
|
||||
for user in red_techs:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: type
|
||||
type="campaign_activated",
|
||||
# Keyword argument: title
|
||||
title="New recurring campaign active",
|
||||
# Keyword argument: message
|
||||
message=f'Campaign "{child.name}" is now active and ready for execution.',
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=child.id,
|
||||
)
|
||||
|
||||
# Assign spawned = 1
|
||||
spawned += 1
|
||||
# Log info: "Spawned child campaign '%s' from parent '%s'", ch
|
||||
logger.info("Spawned child campaign '%s' from parent '%s'", child.name, campaign.name)
|
||||
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Roll back all uncommitted changes
|
||||
db.rollback()
|
||||
# Log exception: "Failed to run recurring campaign '%s'", campaign.
|
||||
logger.exception("Failed to run recurring campaign '%s'", campaign.name)
|
||||
|
||||
# Return spawned
|
||||
return spawned
|
||||
|
||||
@@ -4,108 +4,183 @@ Handles circular dependency validation, campaign generation from
|
||||
threat actors, and progress calculation.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError, InvalidOperationError from app.domain.exceptions
|
||||
from app.domain.exceptions import EntityNotFoundError, InvalidOperationError
|
||||
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.technique import Technique
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
|
||||
# Import Campaign, CampaignTest from app.models.campaign
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
|
||||
# Import TechniqueStatus, TestState from app.models.enums
|
||||
from app.models.enums import TechniqueStatus, TestState
|
||||
from app.services.notification_service import create_notification
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mapping from ATT&CK tactics to kill chain phases
|
||||
TACTIC_TO_PHASE: dict[str, str] = {
|
||||
# Literal argument value
|
||||
"reconnaissance": "reconnaissance",
|
||||
# Literal argument value
|
||||
"resource-development": "resource_development",
|
||||
# Literal argument value
|
||||
"initial-access": "initial_access",
|
||||
# Literal argument value
|
||||
"execution": "execution",
|
||||
# Literal argument value
|
||||
"persistence": "persistence",
|
||||
# Literal argument value
|
||||
"privilege-escalation": "privilege_escalation",
|
||||
# Literal argument value
|
||||
"defense-evasion": "defense_evasion",
|
||||
# Literal argument value
|
||||
"credential-access": "credential_access",
|
||||
# Literal argument value
|
||||
"discovery": "discovery",
|
||||
# Literal argument value
|
||||
"lateral-movement": "lateral_movement",
|
||||
# Literal argument value
|
||||
"collection": "collection",
|
||||
# Literal argument value
|
||||
"command-and-control": "command_and_control",
|
||||
# Literal argument value
|
||||
"exfiltration": "exfiltration",
|
||||
# Literal argument value
|
||||
"impact": "impact",
|
||||
}
|
||||
|
||||
|
||||
# Define function validate_no_circular_dependency
|
||||
def validate_no_circular_dependency(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: campaign_id
|
||||
campaign_id: uuid.UUID,
|
||||
# Entry: test_id
|
||||
test_id: uuid.UUID,
|
||||
# Entry: depends_on_id
|
||||
depends_on_id: uuid.UUID | None,
|
||||
) -> None:
|
||||
"""Walk the depends_on chain and verify no cycle is formed.
|
||||
|
||||
Raises :class:`InvalidOperationError` if a circular dependency is detected.
|
||||
"""
|
||||
# Check: depends_on_id is None
|
||||
if depends_on_id is None:
|
||||
# Return control to caller
|
||||
return
|
||||
|
||||
# Assign visited = set()
|
||||
visited: set[uuid.UUID] = set()
|
||||
# Assign current = depends_on_id
|
||||
current = depends_on_id
|
||||
|
||||
# Loop while current is not None
|
||||
while current is not None:
|
||||
# Check: current in visited or current == test_id
|
||||
if current in visited or current == test_id:
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError(
|
||||
# Literal argument value
|
||||
"Circular dependency detected in campaign test chain"
|
||||
)
|
||||
# Call visited.add()
|
||||
visited.add(current)
|
||||
# Assign parent = db.query(CampaignTest).filter_by(id=current).first()
|
||||
parent = db.query(CampaignTest).filter_by(id=current).first()
|
||||
# Assign current = parent.depends_on if parent else None
|
||||
current = parent.depends_on if parent else None
|
||||
|
||||
|
||||
# Define function get_campaign_progress
|
||||
def get_campaign_progress(db: Session, campaign_id: uuid.UUID) -> dict:
|
||||
"""Calculate progress statistics for a campaign.
|
||||
|
||||
Returns counts of tests by state, plus total and completion percentage.
|
||||
"""
|
||||
# Assign campaign_tests = (
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest)
|
||||
# Chain .filter() call
|
||||
.filter(CampaignTest.campaign_id == campaign_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Check: not campaign_tests
|
||||
if not campaign_tests:
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": 0,
|
||||
# Literal argument value
|
||||
"by_state": {},
|
||||
# Literal argument value
|
||||
"completion_pct": 0.0,
|
||||
}
|
||||
|
||||
# Assign by_state = {}
|
||||
by_state: dict[str, int] = {}
|
||||
# Iterate over campaign_tests
|
||||
for ct in campaign_tests:
|
||||
# Assign test = ct.test
|
||||
test = ct.test
|
||||
# Assign state = test.state.value if test and test.state else "unknown"
|
||||
state = test.state.value if test and test.state else "unknown"
|
||||
# Assign by_state[state] = by_state.get(state, 0) + 1
|
||||
by_state[state] = by_state.get(state, 0) + 1
|
||||
|
||||
# Assign total = len(campaign_tests)
|
||||
total = len(campaign_tests)
|
||||
# Assign completed = by_state.get("validated", 0)
|
||||
completed = by_state.get("validated", 0)
|
||||
# Assign completion_pct = round(completed / total * 100, 1) if total > 0 else 0.0
|
||||
completion_pct = round(completed / total * 100, 1) if total > 0 else 0.0
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
# Literal argument value
|
||||
"by_state": by_state,
|
||||
# Literal argument value
|
||||
"completion_pct": completion_pct,
|
||||
}
|
||||
|
||||
|
||||
# Define function generate_campaign_from_threat_actor
|
||||
def generate_campaign_from_threat_actor(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: actor_id
|
||||
actor_id: uuid.UUID,
|
||||
# Entry: user
|
||||
user: User,
|
||||
*,
|
||||
start_date: Optional[datetime] = None,
|
||||
@@ -119,75 +194,111 @@ def generate_campaign_from_threat_actor(
|
||||
4. Create a campaign with tests ordered by kill chain phase
|
||||
5. Return the campaign
|
||||
"""
|
||||
# Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
# Check: not actor
|
||||
if not actor:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("ThreatActor", str(actor_id))
|
||||
|
||||
# Get unvalidated techniques for this actor
|
||||
gap_techniques = (
|
||||
db.query(Technique, ThreatActorTechnique)
|
||||
# Chain .join() call
|
||||
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
# Chain .filter() call
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor_id)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.status_global != TechniqueStatus.validated)
|
||||
# Chain .order_by() call
|
||||
.order_by(Technique.tactic, Technique.mitre_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Check: not gap_techniques
|
||||
if not gap_techniques:
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError(
|
||||
f"No uncovered techniques found for {actor.name}"
|
||||
)
|
||||
|
||||
# Create the campaign
|
||||
campaign = Campaign(
|
||||
# Keyword argument: name
|
||||
name=f"APT Emulation: {actor.name}",
|
||||
# Keyword argument: description
|
||||
description=f"Auto-generated campaign to test coverage against {actor.name} "
|
||||
f"({actor.mitre_id or 'unknown'}). "
|
||||
f"Covers {len(gap_techniques)} uncovered technique(s).",
|
||||
# Keyword argument: type
|
||||
type="apt_emulation",
|
||||
# Keyword argument: threat_actor_id
|
||||
threat_actor_id=actor_id,
|
||||
# Keyword argument: status
|
||||
status="draft",
|
||||
# Keyword argument: created_by
|
||||
created_by=user.id,
|
||||
# Keyword argument: tags
|
||||
tags=[actor.name, "auto-generated"],
|
||||
start_date=start_date,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(campaign)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush() # Get campaign.id
|
||||
|
||||
# Assign order_index = 0
|
||||
order_index = 0
|
||||
|
||||
# Iterate over gap_techniques
|
||||
for tech, _at in gap_techniques:
|
||||
# Find best template for this technique
|
||||
template = (
|
||||
db.query(TestTemplate)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
TestTemplate.mitre_technique_id == tech.mitre_id,
|
||||
TestTemplate.is_active == True, # noqa: E712
|
||||
)
|
||||
# Chain .order_by() call
|
||||
.order_by(
|
||||
# Prioritize by severity: critical > high > medium > low
|
||||
TestTemplate.severity.desc(),
|
||||
TestTemplate.name,
|
||||
)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check: not template
|
||||
if not template:
|
||||
# continue # Skip techniques without templates
|
||||
continue # Skip techniques without templates
|
||||
|
||||
# Create a test from the template
|
||||
test = Test(
|
||||
# Keyword argument: technique_id
|
||||
technique_id=tech.id,
|
||||
# Keyword argument: name
|
||||
name=f"[Campaign] {template.name}",
|
||||
# Keyword argument: description
|
||||
description=template.description,
|
||||
# Keyword argument: platform
|
||||
platform=template.platform,
|
||||
# Keyword argument: procedure_text
|
||||
procedure_text=template.attack_procedure,
|
||||
# Keyword argument: tool_used
|
||||
tool_used=template.tool_suggested,
|
||||
# Keyword argument: created_by
|
||||
created_by=user.id,
|
||||
# Keyword argument: state
|
||||
state=TestState.draft,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(test)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush() # Get test.id
|
||||
|
||||
# Determine kill chain phase from the technique's tactic
|
||||
@@ -195,22 +306,33 @@ def generate_campaign_from_threat_actor(
|
||||
|
||||
# Add to campaign
|
||||
campaign_test = CampaignTest(
|
||||
# Keyword argument: campaign_id
|
||||
campaign_id=campaign.id,
|
||||
# Keyword argument: test_id
|
||||
test_id=test.id,
|
||||
# Keyword argument: order_index
|
||||
order_index=order_index,
|
||||
# Keyword argument: phase
|
||||
phase=phase,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(campaign_test)
|
||||
# Assign order_index = 1
|
||||
order_index += 1
|
||||
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(campaign)
|
||||
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Generated campaign '%s' with %d tests for actor %s",
|
||||
campaign.name,
|
||||
order_index,
|
||||
actor.name,
|
||||
)
|
||||
|
||||
# Return campaign
|
||||
return campaign
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,111 +6,184 @@ that the router remains a thin HTTP adapter.
|
||||
This module is framework-agnostic: no FastAPI imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import csv
|
||||
import csv
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import Any from typing
|
||||
from typing import Any
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
|
||||
# Import from app.models.compliance
|
||||
from app.models.compliance import (
|
||||
ComplianceFramework,
|
||||
ComplianceControl,
|
||||
ComplianceControlMapping,
|
||||
ComplianceFramework,
|
||||
)
|
||||
from app.models.technique import Technique
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.threat_actor import ThreatActorTechnique
|
||||
from app.services.scoring_service import calculate_technique_score
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import ThreatActorTechnique from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActorTechnique
|
||||
|
||||
# Import calculate_technique_score from app.services.scoring_service
|
||||
from app.services.scoring_service import calculate_technique_score
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _classify_control(technique_scores: list[float]) -> str:
|
||||
"""Classify a control status based on its technique scores."""
|
||||
# Check: not technique_scores
|
||||
if not technique_scores:
|
||||
# Return "not_evaluated"
|
||||
return "not_evaluated"
|
||||
|
||||
# Assign all_above_70 = all(s >= 70 for s in technique_scores)
|
||||
all_above_70 = all(s >= 70 for s in technique_scores)
|
||||
# Assign any_above_30 = any(s >= 30 for s in technique_scores)
|
||||
any_above_30 = any(s >= 30 for s in technique_scores)
|
||||
# Assign all_below_30 = all(s < 30 for s in technique_scores)
|
||||
all_below_30 = all(s < 30 for s in technique_scores)
|
||||
# Assign all_zero = all(s == 0 for s in technique_scores)
|
||||
all_zero = all(s == 0 for s in technique_scores)
|
||||
|
||||
# Check: all_zero
|
||||
if all_zero:
|
||||
# Return "not_evaluated"
|
||||
return "not_evaluated"
|
||||
# Check: all_above_70
|
||||
if all_above_70:
|
||||
# Return "covered"
|
||||
return "covered"
|
||||
# Check: all_below_30
|
||||
if all_below_30:
|
||||
# Return "not_covered"
|
||||
return "not_covered"
|
||||
# Check: any_above_30
|
||||
if any_above_30:
|
||||
# Return "partially_covered"
|
||||
return "partially_covered"
|
||||
# Return "not_covered"
|
||||
return "not_covered"
|
||||
|
||||
|
||||
# Define function _get_control_status
|
||||
def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, Any]:
|
||||
"""Compute the status and score for a single control."""
|
||||
# Assign mappings = (
|
||||
mappings = (
|
||||
db.query(ComplianceControlMapping)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControlMapping.compliance_control_id == control.id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Check: not mappings
|
||||
if not mappings:
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"control_id": control.control_id,
|
||||
# Literal argument value
|
||||
"title": control.title,
|
||||
"description": control.description,
|
||||
"category": control.category,
|
||||
# Literal argument value
|
||||
"status": "not_evaluated",
|
||||
# Literal argument value
|
||||
"score": 0,
|
||||
# Literal argument value
|
||||
"techniques_count": 0,
|
||||
# Literal argument value
|
||||
"techniques_covered": 0,
|
||||
# Literal argument value
|
||||
"techniques": [],
|
||||
}
|
||||
|
||||
# Assign technique_ids = [m.technique_id for m in mappings]
|
||||
technique_ids = [m.technique_id for m in mappings]
|
||||
# Assign techniques = (
|
||||
techniques = (
|
||||
db.query(Technique)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.id.in_(technique_ids))
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign tech_details = []
|
||||
tech_details = []
|
||||
# Assign scores = []
|
||||
scores = []
|
||||
# Assign covered_count = 0
|
||||
covered_count = 0
|
||||
|
||||
# Iterate over techniques
|
||||
for tech in techniques:
|
||||
# Assign result = calculate_technique_score(tech, db)
|
||||
result = calculate_technique_score(tech, db)
|
||||
# Assign score = result["total_score"]
|
||||
score = result["total_score"]
|
||||
# Call scores.append()
|
||||
scores.append(score)
|
||||
# Check: score >= 50
|
||||
if score >= 50:
|
||||
# Assign covered_count = 1
|
||||
covered_count += 1
|
||||
|
||||
# Call tech_details.append()
|
||||
tech_details.append({
|
||||
# Literal argument value
|
||||
"mitre_id": tech.mitre_id,
|
||||
# Literal argument value
|
||||
"name": tech.name,
|
||||
# Literal argument value
|
||||
"score": score,
|
||||
# Literal argument value
|
||||
"status": tech.status_global.value if tech.status_global else "not_evaluated",
|
||||
})
|
||||
|
||||
# Sort techniques by score ascending (worst first for priority)
|
||||
tech_details.sort(key=lambda t: t["score"])
|
||||
|
||||
# Assign avg_score = round(sum(scores) / len(scores), 1) if scores else 0
|
||||
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
|
||||
# Assign status = _classify_control(scores)
|
||||
status = _classify_control(scores)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"control_id": control.control_id,
|
||||
# Literal argument value
|
||||
"title": control.title,
|
||||
"description": control.description,
|
||||
"category": control.category,
|
||||
# Literal argument value
|
||||
"status": status,
|
||||
# Literal argument value
|
||||
"score": avg_score,
|
||||
# Literal argument value
|
||||
"techniques_count": len(techniques),
|
||||
# Literal argument value
|
||||
"techniques_covered": covered_count,
|
||||
# Literal argument value
|
||||
"techniques": tech_details,
|
||||
}
|
||||
|
||||
@@ -120,95 +193,150 @@ def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, An
|
||||
|
||||
def list_frameworks(db: Session) -> list[dict[str, Any]]:
|
||||
"""List all available compliance frameworks with control counts."""
|
||||
# Assign frameworks = (
|
||||
frameworks = (
|
||||
db.query(ComplianceFramework)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceFramework.is_active == True)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign result = []
|
||||
result = []
|
||||
# Iterate over frameworks
|
||||
for fw in frameworks:
|
||||
# Assign control_count = (
|
||||
control_count = (
|
||||
db.query(ComplianceControl)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControl.framework_id == fw.id)
|
||||
# Chain .count() call
|
||||
.count()
|
||||
)
|
||||
# Call result.append()
|
||||
result.append({
|
||||
# Literal argument value
|
||||
"id": str(fw.id),
|
||||
# Literal argument value
|
||||
"name": fw.name,
|
||||
# Literal argument value
|
||||
"version": fw.version,
|
||||
# Literal argument value
|
||||
"description": fw.description,
|
||||
# Literal argument value
|
||||
"url": fw.url,
|
||||
# Literal argument value
|
||||
"is_active": fw.is_active,
|
||||
# Literal argument value
|
||||
"controls_count": control_count,
|
||||
})
|
||||
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
# Define function get_framework
|
||||
def get_framework(db: Session, framework_id: str) -> ComplianceFramework | None:
|
||||
"""Get a framework by ID, or None if not found."""
|
||||
# Return (
|
||||
return (
|
||||
db.query(ComplianceFramework)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceFramework.id == framework_id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
# Define function get_framework_status
|
||||
def get_framework_status(db: Session, framework_id: str) -> dict[str, Any]:
|
||||
"""Get compliance status for each control in a framework.
|
||||
|
||||
Raises EntityNotFoundError if the framework does not exist.
|
||||
"""
|
||||
# Assign framework = get_framework(db, framework_id)
|
||||
framework = get_framework(db, framework_id)
|
||||
# Check: not framework
|
||||
if not framework:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Framework", framework_id)
|
||||
|
||||
# Assign controls = (
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(ComplianceControl.control_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign control_statuses = []
|
||||
control_statuses = []
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"total_controls": len(controls),
|
||||
# Literal argument value
|
||||
"covered": 0,
|
||||
# Literal argument value
|
||||
"partially_covered": 0,
|
||||
# Literal argument value
|
||||
"not_covered": 0,
|
||||
# Literal argument value
|
||||
"not_evaluated": 0,
|
||||
}
|
||||
|
||||
# Iterate over controls
|
||||
for control in controls:
|
||||
# Assign status_data = _get_control_status(control, db)
|
||||
status_data = _get_control_status(control, db)
|
||||
# Call control_statuses.append()
|
||||
control_statuses.append(status_data)
|
||||
|
||||
# Assign status = status_data["status"]
|
||||
status = status_data["status"]
|
||||
# Check: status in summary
|
||||
if status in summary:
|
||||
# Assign summary[status] = 1
|
||||
summary[status] += 1
|
||||
|
||||
# Compliance percentage: (covered + partially_covered*0.5) / total * 100
|
||||
total = summary["total_controls"]
|
||||
# Check: total > 0
|
||||
if total > 0:
|
||||
# Assign compliance_pct = round(
|
||||
compliance_pct = round(
|
||||
(summary["covered"] + summary["partially_covered"] * 0.5) / total * 100,
|
||||
# Literal argument value
|
||||
1,
|
||||
)
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign compliance_pct = 0
|
||||
compliance_pct = 0
|
||||
|
||||
# Assign summary["compliance_percentage"] = compliance_pct
|
||||
summary["compliance_percentage"] = compliance_pct
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"framework": {"id": str(framework.id), "name": framework.name},
|
||||
# Literal argument value
|
||||
"summary": summary,
|
||||
# Literal argument value
|
||||
"controls": control_statuses,
|
||||
}
|
||||
|
||||
|
||||
# Define function build_framework_report_csv
|
||||
def build_framework_report_csv(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
) -> tuple[bytes, str]:
|
||||
"""Build the compliance report CSV content and filename.
|
||||
@@ -217,33 +345,55 @@ def build_framework_report_csv(
|
||||
|
||||
Raises EntityNotFoundError if the framework does not exist.
|
||||
"""
|
||||
# Assign framework = get_framework(db, framework_id)
|
||||
framework = get_framework(db, framework_id)
|
||||
# Check: not framework
|
||||
if not framework:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Framework", framework_id)
|
||||
|
||||
# Assign controls = (
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(ComplianceControl.control_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign output = io.StringIO()
|
||||
output = io.StringIO()
|
||||
# Assign writer = csv.writer(output)
|
||||
writer = csv.writer(output)
|
||||
# Call writer.writerow()
|
||||
writer.writerow([
|
||||
# Literal argument value
|
||||
"control_id",
|
||||
# Literal argument value
|
||||
"title",
|
||||
# Literal argument value
|
||||
"category",
|
||||
# Literal argument value
|
||||
"status",
|
||||
# Literal argument value
|
||||
"score",
|
||||
# Literal argument value
|
||||
"techniques_total",
|
||||
# Literal argument value
|
||||
"techniques_covered",
|
||||
# Literal argument value
|
||||
"technique_ids",
|
||||
])
|
||||
|
||||
# Iterate over controls
|
||||
for control in controls:
|
||||
# Assign status_data = _get_control_status(control, db)
|
||||
status_data = _get_control_status(control, db)
|
||||
# Assign technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
|
||||
technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
|
||||
# Call writer.writerow()
|
||||
writer.writerow([
|
||||
status_data["control_id"],
|
||||
status_data["title"],
|
||||
@@ -255,75 +405,116 @@ def build_framework_report_csv(
|
||||
technique_ids,
|
||||
])
|
||||
|
||||
# Call output.seek()
|
||||
output.seek(0)
|
||||
# Assign filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
|
||||
filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
|
||||
# Return output.getvalue().encode("utf-8"), filename
|
||||
return output.getvalue().encode("utf-8"), filename
|
||||
|
||||
|
||||
# Define function get_framework_gaps
|
||||
def get_framework_gaps(db: Session, framework_id: str) -> dict[str, Any]:
|
||||
"""Get controls with techniques that are not adequately covered.
|
||||
|
||||
Raises EntityNotFoundError if the framework does not exist.
|
||||
"""
|
||||
# Assign framework = get_framework(db, framework_id)
|
||||
framework = get_framework(db, framework_id)
|
||||
# Check: not framework
|
||||
if not framework:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Framework", framework_id)
|
||||
|
||||
# Assign controls = (
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(ComplianceControl.control_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign gaps = []
|
||||
gaps = []
|
||||
# Iterate over controls
|
||||
for control in controls:
|
||||
# Assign status_data = _get_control_status(control, db)
|
||||
status_data = _get_control_status(control, db)
|
||||
|
||||
# Check: status_data["status"] in ("not_covered", "partially_covered")
|
||||
if status_data["status"] in ("not_covered", "partially_covered"):
|
||||
# Find uncovered techniques
|
||||
uncovered_techniques = []
|
||||
# Iterate over status_data["techniques"]
|
||||
for tech_info in status_data["techniques"]:
|
||||
# Check: tech_info["score"] < 70
|
||||
if tech_info["score"] < 70:
|
||||
# Count available templates
|
||||
template_count = (
|
||||
db.query(TestTemplate)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"])
|
||||
# Chain .count() call
|
||||
.count()
|
||||
)
|
||||
|
||||
# Count threat actors using this technique
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.mitre_id == tech_info["mitre_id"])
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Assign actor_count = 0
|
||||
actor_count = 0
|
||||
# Check: technique
|
||||
if technique:
|
||||
# Assign actor_count = (
|
||||
actor_count = (
|
||||
db.query(ThreatActorTechnique)
|
||||
# Chain .filter() call
|
||||
.filter(ThreatActorTechnique.technique_id == technique.id)
|
||||
# Chain .count() call
|
||||
.count()
|
||||
)
|
||||
|
||||
# Call uncovered_techniques.append()
|
||||
uncovered_techniques.append({
|
||||
**tech_info,
|
||||
# Literal argument value
|
||||
"templates_available": template_count,
|
||||
# Literal argument value
|
||||
"threat_actors_using": actor_count,
|
||||
})
|
||||
|
||||
# Check: uncovered_techniques
|
||||
if uncovered_techniques:
|
||||
# Call gaps.append()
|
||||
gaps.append({
|
||||
# Literal argument value
|
||||
"control_id": status_data["control_id"],
|
||||
# Literal argument value
|
||||
"title": status_data["title"],
|
||||
# Literal argument value
|
||||
"category": status_data["category"],
|
||||
# Literal argument value
|
||||
"status": status_data["status"],
|
||||
# Literal argument value
|
||||
"score": status_data["score"],
|
||||
# Literal argument value
|
||||
"uncovered_techniques": uncovered_techniques,
|
||||
})
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"framework": {"id": str(framework.id), "name": framework.name},
|
||||
# Literal argument value
|
||||
"total_gaps": len(gaps),
|
||||
# Literal argument value
|
||||
"gaps": gaps,
|
||||
}
|
||||
|
||||
@@ -7,120 +7,202 @@ technique/test-count pattern by using a single grouped query.
|
||||
This module is framework-agnostic: no FastAPI imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import escape_like from app.utils
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
# Define function _technique_test_counts
|
||||
def _technique_test_counts(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: technique_ids
|
||||
technique_ids: list,
|
||||
) -> dict:
|
||||
"""Return ``{technique_id: {state_str: count}}`` in a single query."""
|
||||
# Check: not technique_ids
|
||||
if not technique_ids:
|
||||
# Return {}
|
||||
return {}
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
db.query(Test.technique_id, Test.state, func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id.in_(technique_ids))
|
||||
# Chain .group_by() call
|
||||
.group_by(Test.technique_id, Test.state)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign result = {}
|
||||
result: dict = {}
|
||||
# Iterate over rows
|
||||
for tid, state, count in rows:
|
||||
# Call result.setdefault()
|
||||
result.setdefault(tid, {})[str(state)] = count
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
# Define function build_coverage_summary
|
||||
def build_coverage_summary(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: tactic
|
||||
tactic: str | None = None,
|
||||
# Entry: platform
|
||||
platform: str | None = None,
|
||||
) -> dict:
|
||||
"""Build the full coverage summary report as a dict."""
|
||||
# Assign query = db.query(Technique)
|
||||
query = db.query(Technique)
|
||||
# Check: tactic
|
||||
if tactic:
|
||||
# Assign query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
|
||||
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
|
||||
|
||||
# Assign techniques = query.order_by(Technique.mitre_id).all()
|
||||
techniques = query.order_by(Technique.mitre_id).all()
|
||||
# Assign counts_map = _technique_test_counts(db, [t.id for t in techniques])
|
||||
counts_map = _technique_test_counts(db, [t.id for t in techniques])
|
||||
|
||||
# Assign rows = []
|
||||
rows = []
|
||||
# Iterate over techniques
|
||||
for t in techniques:
|
||||
# Check: platform and platform.lower() not in [p.lower() for p in (t.platfor...
|
||||
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign counts = counts_map.get(t.id, {})
|
||||
counts = counts_map.get(t.id, {})
|
||||
# Call rows.append()
|
||||
rows.append({
|
||||
# Literal argument value
|
||||
"mitre_id": t.mitre_id,
|
||||
# Literal argument value
|
||||
"name": t.name,
|
||||
# Literal argument value
|
||||
"tactic": t.tactic,
|
||||
# Literal argument value
|
||||
"platforms": t.platforms,
|
||||
# Literal argument value
|
||||
"status_global": t.status_global,
|
||||
# Literal argument value
|
||||
"total_tests": sum(counts.values()),
|
||||
# Literal argument value
|
||||
"tests_by_state": counts,
|
||||
})
|
||||
|
||||
# Assign total = len(rows)
|
||||
total = len(rows)
|
||||
# Assign validated = sum(1 for r in rows if r["status_global"] == "validated")
|
||||
validated = sum(1 for r in rows if r["status_global"] == "validated")
|
||||
# Assign partial = sum(1 for r in rows if r["status_global"] == "partial")
|
||||
partial = sum(1 for r in rows if r["status_global"] == "partial")
|
||||
# Assign not_covered = sum(1 for r in rows if r["status_global"] == "not_covered")
|
||||
not_covered = sum(1 for r in rows if r["status_global"] == "not_covered")
|
||||
# Assign in_progress = sum(1 for r in rows if r["status_global"] == "in_progress")
|
||||
in_progress = sum(1 for r in rows if r["status_global"] == "in_progress")
|
||||
# Assign not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated")
|
||||
not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated")
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
# Literal argument value
|
||||
"summary": {
|
||||
# Literal argument value
|
||||
"total_techniques": total,
|
||||
# Literal argument value
|
||||
"validated": validated,
|
||||
# Literal argument value
|
||||
"partial": partial,
|
||||
# Literal argument value
|
||||
"not_covered": not_covered,
|
||||
# Literal argument value
|
||||
"in_progress": in_progress,
|
||||
# Literal argument value
|
||||
"not_evaluated": not_evaluated,
|
||||
# Literal argument value
|
||||
"coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
|
||||
},
|
||||
# Literal argument value
|
||||
"techniques": rows,
|
||||
}
|
||||
|
||||
|
||||
# Define function build_coverage_csv_rows
|
||||
def build_coverage_csv_rows(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: tactic
|
||||
tactic: str | None = None,
|
||||
# Entry: platform
|
||||
platform: str | None = None,
|
||||
) -> list[list]:
|
||||
"""Build rows for a CSV coverage export (header + data)."""
|
||||
# Assign query = db.query(Technique)
|
||||
query = db.query(Technique)
|
||||
# Check: tactic
|
||||
if tactic:
|
||||
# Assign query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
|
||||
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
|
||||
|
||||
# Assign techniques = query.order_by(Technique.mitre_id).all()
|
||||
techniques = query.order_by(Technique.mitre_id).all()
|
||||
# Assign counts_map = _technique_test_counts(db, [t.id for t in techniques])
|
||||
counts_map = _technique_test_counts(db, [t.id for t in techniques])
|
||||
|
||||
# Assign header = [
|
||||
header = [
|
||||
# Literal argument value
|
||||
"MITRE ID", "Name", "Tactic", "Platforms", "Status",
|
||||
# Literal argument value
|
||||
"Total Tests", "Validated", "In Progress", "Not Covered",
|
||||
]
|
||||
# Assign rows = [header]
|
||||
rows = [header]
|
||||
|
||||
# Assign in_progress_states = {"draft", "red_executing", "blue_evaluating", "in_review"}
|
||||
in_progress_states = {"draft", "red_executing", "blue_evaluating", "in_review"}
|
||||
|
||||
# Iterate over techniques
|
||||
for t in techniques:
|
||||
# Check: platform and platform.lower() not in [p.lower() for p in (t.platfor...
|
||||
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign counts = counts_map.get(t.id, {})
|
||||
counts = counts_map.get(t.id, {})
|
||||
# Call rows.append()
|
||||
rows.append([
|
||||
t.mitre_id,
|
||||
t.name,
|
||||
t.tactic,
|
||||
# Literal argument value
|
||||
", ".join(t.platforms or []),
|
||||
t.status_global,
|
||||
sum(counts.values()),
|
||||
@@ -129,65 +211,111 @@ def build_coverage_csv_rows(
|
||||
counts.get("rejected", 0),
|
||||
])
|
||||
|
||||
# Return rows
|
||||
return rows
|
||||
|
||||
|
||||
# Define function build_test_results_report
|
||||
def build_test_results_report(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: state
|
||||
state: str | None = None,
|
||||
# Entry: date_from
|
||||
date_from: str | None = None,
|
||||
# Entry: date_to
|
||||
date_to: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a test results report with optional filters."""
|
||||
# Assign query = db.query(Test)
|
||||
query = db.query(Test)
|
||||
|
||||
# Check: state
|
||||
if state:
|
||||
# Assign query = query.filter(Test.state == state)
|
||||
query = query.filter(Test.state == state)
|
||||
# Check: date_from
|
||||
if date_from:
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign query = query.filter(Test.created_at >= datetime.fromisoformat(date_from))
|
||||
query = query.filter(Test.created_at >= datetime.fromisoformat(date_from))
|
||||
# Handle ValueError
|
||||
except ValueError:
|
||||
# Intentional no-op placeholder
|
||||
pass
|
||||
# Check: date_to
|
||||
if date_to:
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign query = query.filter(Test.created_at <= datetime.fromisoformat(date_to))
|
||||
query = query.filter(Test.created_at <= datetime.fromisoformat(date_to))
|
||||
# Handle ValueError
|
||||
except ValueError:
|
||||
# Intentional no-op placeholder
|
||||
pass
|
||||
|
||||
# Assign tests = query.order_by(Test.created_at.desc()).all()
|
||||
tests = query.order_by(Test.created_at.desc()).all()
|
||||
|
||||
# Assign by_state = {}
|
||||
by_state: dict[str, int] = {}
|
||||
# Assign by_result = {}
|
||||
by_result: dict[str, int] = {}
|
||||
# Iterate over tests
|
||||
for t in tests:
|
||||
# Assign s = t.state.value if hasattr(t.state, "value") else str(t.state)
|
||||
s = t.state.value if hasattr(t.state, "value") else str(t.state)
|
||||
# Assign by_state[s] = by_state.get(s, 0) + 1
|
||||
by_state[s] = by_state.get(s, 0) + 1
|
||||
# Check: t.detection_result
|
||||
if t.detection_result:
|
||||
# Assign r = t.detection_result.value if hasattr(t.detection_result, "value") el...
|
||||
r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result)
|
||||
# Assign by_result[r] = by_result.get(r, 0) + 1
|
||||
by_result[r] = by_result.get(r, 0) + 1
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
# Literal argument value
|
||||
"filters": {"state": state, "date_from": date_from, "date_to": date_to},
|
||||
# Literal argument value
|
||||
"summary": {
|
||||
# Literal argument value
|
||||
"total_tests": len(tests),
|
||||
# Literal argument value
|
||||
"by_state": by_state,
|
||||
# Literal argument value
|
||||
"by_detection_result": by_result,
|
||||
},
|
||||
# Literal argument value
|
||||
"tests": [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(t.id),
|
||||
# Literal argument value
|
||||
"name": t.name,
|
||||
# Literal argument value
|
||||
"technique_id": str(t.technique_id),
|
||||
# Literal argument value
|
||||
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
|
||||
# Literal argument value
|
||||
"platform": t.platform,
|
||||
# Literal argument value
|
||||
"attack_success": t.attack_success,
|
||||
# Literal argument value
|
||||
"detection_result": (
|
||||
t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value")
|
||||
else str(t.detection_result) if t.detection_result else None
|
||||
),
|
||||
# Literal argument value
|
||||
"red_validation_status": t.red_validation_status,
|
||||
# Literal argument value
|
||||
"blue_validation_status": t.blue_validation_status,
|
||||
# Literal argument value
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in tests
|
||||
@@ -195,38 +323,62 @@ def build_test_results_report(
|
||||
}
|
||||
|
||||
|
||||
# Define function build_remediation_status_report
|
||||
def build_remediation_status_report(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: status
|
||||
status: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a remediation status report."""
|
||||
# Assign query = db.query(Test).filter(Test.remediation_steps.isnot(None))
|
||||
query = db.query(Test).filter(Test.remediation_steps.isnot(None))
|
||||
|
||||
# Check: status
|
||||
if status:
|
||||
# Assign query = query.filter(Test.remediation_status == status)
|
||||
query = query.filter(Test.remediation_status == status)
|
||||
|
||||
# Assign tests = query.order_by(Test.created_at.desc()).all()
|
||||
tests = query.order_by(Test.created_at.desc()).all()
|
||||
|
||||
# Assign by_status = {}
|
||||
by_status: dict[str, int] = {}
|
||||
# Iterate over tests
|
||||
for t in tests:
|
||||
# Assign s = t.remediation_status or "unset"
|
||||
s = t.remediation_status or "unset"
|
||||
# Assign by_status[s] = by_status.get(s, 0) + 1
|
||||
by_status[s] = by_status.get(s, 0) + 1
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
# Literal argument value
|
||||
"summary": {
|
||||
# Literal argument value
|
||||
"total_with_remediation": len(tests),
|
||||
# Literal argument value
|
||||
"by_status": by_status,
|
||||
},
|
||||
# Literal argument value
|
||||
"tests": [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(t.id),
|
||||
# Literal argument value
|
||||
"name": t.name,
|
||||
# Literal argument value
|
||||
"technique_id": str(t.technique_id),
|
||||
# Literal argument value
|
||||
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
|
||||
# Literal argument value
|
||||
"remediation_status": t.remediation_status,
|
||||
# Literal argument value
|
||||
"remediation_steps": t.remediation_steps,
|
||||
# Literal argument value
|
||||
"remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
|
||||
}
|
||||
for t in tests
|
||||
|
||||
@@ -1,148 +1,270 @@
|
||||
"""D3FEND import service — fetches MITRE D3FEND data and creates
|
||||
DefensiveTechnique records plus ATT&CK → D3FEND mappings.
|
||||
"""D3FEND import service — fetches MITRE D3FEND data and creates DefensiveTechnique records plus ATT&CK → D3FEND mappings.
|
||||
|
||||
Uses the D3FEND public API:
|
||||
- https://d3fend.mitre.org/api/technique/api-all.json (all defensive techniques)
|
||||
- https://d3fend.mitre.org/api/offensive-technique/{attack_id}.json (mappings per ATT&CK technique)
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
# Import Any from typing
|
||||
from typing import Any
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
# Import httpx
|
||||
import httpx
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.technique import Technique
|
||||
# Import DefensiveTechnique, DefensiveTechniqueMapping from app.models.defensive_technique
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign D3FEND_TACTIC_URL = "https://d3fend.mitre.org/api/tactic/d3f:{tactic}.json"
|
||||
D3FEND_TACTIC_URL = "https://d3fend.mitre.org/api/tactic/d3f:{tactic}.json"
|
||||
# Assign D3FEND_MAPPING_URL = "https://d3fend.mitre.org/api/offensive-technique/{attack_id}.json"
|
||||
D3FEND_MAPPING_URL = "https://d3fend.mitre.org/api/offensive-technique/{attack_id}.json"
|
||||
# Assign D3FEND_BASE_URL = "https://d3fend.mitre.org/technique/d3f:{iri}"
|
||||
D3FEND_BASE_URL = "https://d3fend.mitre.org/technique/d3f:{iri}"
|
||||
# Assign D3FEND_TACTICS = ["Detect", "Harden", "Isolate", "Deceive", "Evict", "Model"]
|
||||
D3FEND_TACTICS = ["Detect", "Harden", "Isolate", "Deceive", "Evict", "Model"]
|
||||
|
||||
|
||||
# ── Import all D3FEND techniques ─────────────────────────────────────
|
||||
|
||||
|
||||
def _to_str(v: Any) -> str:
|
||||
"""Coerce an RDF value (str, dict with @value, or list) to a plain string."""
|
||||
def _to_str(v: Any) -> str: # noqa: ANN401
|
||||
"""Coerce an RDF value (str, dict with @value, or list) to a plain string.
|
||||
|
||||
Args:
|
||||
v (Any): RDF node value — may be a plain string, a dict containing
|
||||
a ``@value`` key, or a list of such values.
|
||||
|
||||
Returns:
|
||||
str: Plain string representation; ``"; "``-joined for list inputs.
|
||||
"""
|
||||
# Check: isinstance(v, dict)
|
||||
if isinstance(v, dict):
|
||||
# Return v.get("@value", str(v))
|
||||
return v.get("@value", str(v))
|
||||
# Check: isinstance(v, list)
|
||||
if isinstance(v, list):
|
||||
# Return "; ".join(_to_str(x) for x in v)
|
||||
return "; ".join(_to_str(x) for x in v)
|
||||
# Return str(v) if v else ""
|
||||
return str(v) if v else ""
|
||||
|
||||
|
||||
# Define function _fetch_techniques_from_tactic_apis
|
||||
def _fetch_techniques_from_tactic_apis() -> list[dict[str, Any]]:
|
||||
"""Fetch all defensive techniques via D3FEND tactic APIs.
|
||||
|
||||
Uses ``/api/tactic/d3f:{tactic}.json`` which is reliable and returns
|
||||
full metadata including the ontology IRI for each technique.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: Deduplicated list of technique dicts, each
|
||||
containing ``d3fend_id``, ``iri``, ``name``, ``description``,
|
||||
and ``tactic``.
|
||||
"""
|
||||
# Assign all_techniques = []
|
||||
all_techniques: list[dict[str, Any]] = []
|
||||
# Assign seen = set()
|
||||
seen: set[str] = set()
|
||||
|
||||
# Open context manager
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
# Iterate over D3FEND_TACTICS
|
||||
for tactic in D3FEND_TACTICS:
|
||||
# Assign url = D3FEND_TACTIC_URL.format(tactic=tactic)
|
||||
url = D3FEND_TACTIC_URL.format(tactic=tactic)
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign resp = client.get(url)
|
||||
resp = client.get(url)
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Assign data = resp.json()
|
||||
data = resp.json()
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log warning: "Failed to fetch D3FEND tactic %s: %s", tactic, e
|
||||
logger.warning("Failed to fetch D3FEND tactic %s: %s", tactic, e)
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign graph = data.get("techniques", {}).get("@graph", [])
|
||||
graph = data.get("techniques", {}).get("@graph", [])
|
||||
# Iterate over graph
|
||||
for node in graph:
|
||||
# Assign nid = node.get("@id", "")
|
||||
nid = node.get("@id", "")
|
||||
# Assign d3id = _to_str(node.get("d3f:d3fend-id", ""))
|
||||
d3id = _to_str(node.get("d3f:d3fend-id", ""))
|
||||
# Assign label = _to_str(node.get("rdfs:label", ""))
|
||||
label = _to_str(node.get("rdfs:label", ""))
|
||||
# Assign defn = _to_str(node.get("d3f:definition", ""))
|
||||
defn = _to_str(node.get("d3f:definition", ""))
|
||||
# Check: not defn
|
||||
if not defn:
|
||||
# Assign defn = _to_str(node.get("rdfs:comment", ""))
|
||||
defn = _to_str(node.get("rdfs:comment", ""))
|
||||
|
||||
# Assign iri = nid.replace("d3f:", "") if nid.startswith("d3f:") else nid
|
||||
iri = nid.replace("d3f:", "") if nid.startswith("d3f:") else nid
|
||||
|
||||
# Check: d3id and label and d3id not in seen
|
||||
if d3id and label and d3id not in seen:
|
||||
# Call seen.add()
|
||||
seen.add(d3id)
|
||||
# Call all_techniques.append()
|
||||
all_techniques.append({
|
||||
# Literal argument value
|
||||
"d3fend_id": d3id,
|
||||
# Literal argument value
|
||||
"iri": iri,
|
||||
# Literal argument value
|
||||
"name": label,
|
||||
# Literal argument value
|
||||
"description": defn[:500] if defn else None,
|
||||
# Literal argument value
|
||||
"tactic": tactic,
|
||||
})
|
||||
|
||||
# Log info: "D3FEND tactic %s: %d techniques", tactic, len(gra
|
||||
logger.info("D3FEND tactic %s: %d techniques", tactic, len(graph))
|
||||
|
||||
# Return all_techniques
|
||||
return all_techniques
|
||||
|
||||
|
||||
# Define function _upsert_techniques
|
||||
def _upsert_techniques(db: Session, techniques: list[dict[str, Any]]) -> dict[str, int]:
|
||||
"""Upsert a list of technique dicts into the DefensiveTechnique table."""
|
||||
"""Upsert a list of technique dicts into the DefensiveTechnique table.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
techniques (list[dict[str, Any]]): List of technique data dicts, each
|
||||
containing ``d3fend_id``, ``name``, and optionally ``description``,
|
||||
``tactic``, and ``iri``.
|
||||
|
||||
Returns:
|
||||
dict[str, int]: Contains ``created``, ``updated``, and ``total``
|
||||
counts after the upsert.
|
||||
"""
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign updated = 0
|
||||
updated = 0
|
||||
|
||||
# Iterate over techniques
|
||||
for tech_data in techniques:
|
||||
# Assign existing = (
|
||||
existing = (
|
||||
db.query(DefensiveTechnique)
|
||||
# Chain .filter() call
|
||||
.filter(DefensiveTechnique.d3fend_id == tech_data["d3fend_id"])
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Assign iri = tech_data.get("iri") or tech_data["name"].replace(" ", "")
|
||||
iri = tech_data.get("iri") or tech_data["name"].replace(" ", "")
|
||||
# Assign d3fend_url = D3FEND_BASE_URL.format(iri=iri)
|
||||
d3fend_url = D3FEND_BASE_URL.format(iri=iri)
|
||||
|
||||
# Check: existing
|
||||
if existing:
|
||||
# Assign existing.name = tech_data["name"]
|
||||
existing.name = tech_data["name"]
|
||||
# Assign existing.description = tech_data.get("description")
|
||||
existing.description = tech_data.get("description")
|
||||
# Assign existing.tactic = tech_data.get("tactic")
|
||||
existing.tactic = tech_data.get("tactic")
|
||||
# Assign existing.d3fend_url = d3fend_url
|
||||
existing.d3fend_url = d3fend_url
|
||||
# Assign updated = 1
|
||||
updated += 1
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign new_tech = DefensiveTechnique(
|
||||
new_tech = DefensiveTechnique(
|
||||
# Keyword argument: d3fend_id
|
||||
d3fend_id=tech_data["d3fend_id"],
|
||||
# Keyword argument: name
|
||||
name=tech_data["name"],
|
||||
# Keyword argument: description
|
||||
description=tech_data.get("description"),
|
||||
# Keyword argument: tactic
|
||||
tactic=tech_data.get("tactic"),
|
||||
# Keyword argument: d3fend_url
|
||||
d3fend_url=d3fend_url,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(new_tech)
|
||||
# Assign created = 1
|
||||
created += 1
|
||||
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Assign total = db.query(DefensiveTechnique).count()
|
||||
total = db.query(DefensiveTechnique).count()
|
||||
# Return {"created": created, "updated": updated, "total": total}
|
||||
return {"created": created, "updated": updated, "total": total}
|
||||
|
||||
|
||||
# Define function import_d3fend_techniques
|
||||
def import_d3fend_techniques(db: Session) -> dict[str, int]:
|
||||
"""Fetch all D3FEND defensive techniques and upsert into DB.
|
||||
|
||||
Uses the tactic-level APIs which are reliable and provide full metadata
|
||||
including ontology IRIs for correct URL generation.
|
||||
|
||||
Returns a dict with counts: {created, updated, total}.
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict[str, int]: Contains ``created``, ``updated``, and ``total``
|
||||
counts; falls back to curated list when the API returns fewer
|
||||
than 50 techniques.
|
||||
"""
|
||||
# Log info: "Fetching D3FEND techniques from tactic APIs"
|
||||
logger.info("Fetching D3FEND techniques from tactic APIs")
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign techniques = _fetch_techniques_from_tactic_apis()
|
||||
techniques = _fetch_techniques_from_tactic_apis()
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log error: "Failed to fetch D3FEND techniques from tactic API
|
||||
logger.error("Failed to fetch D3FEND techniques from tactic APIs: %s", e)
|
||||
# Assign techniques = []
|
||||
techniques = []
|
||||
|
||||
# Check: len(techniques) >= 50
|
||||
if len(techniques) >= 50:
|
||||
# Log info: "Fetched %d D3FEND techniques from tactic APIs", l
|
||||
logger.info("Fetched %d D3FEND techniques from tactic APIs", len(techniques))
|
||||
# Assign result = _upsert_techniques(db, techniques)
|
||||
result = _upsert_techniques(db, techniques)
|
||||
# Log info: "D3FEND import done: %d created, %d updated, %d to
|
||||
logger.info("D3FEND import done: %d created, %d updated, %d total",
|
||||
result["created"], result["updated"], result["total"])
|
||||
# Return result
|
||||
return result
|
||||
|
||||
# Fallback: use a curated list of well-known D3FEND techniques
|
||||
logger.warning("Tactic APIs returned too few techniques (%d), using fallback", len(techniques))
|
||||
# Return _import_d3fend_fallback(db)
|
||||
return _import_d3fend_fallback(db)
|
||||
|
||||
|
||||
@@ -228,9 +350,20 @@ _FALLBACK_TECHNIQUES: list[dict[str, str | None]] = [
|
||||
]
|
||||
|
||||
|
||||
# Define function _import_d3fend_fallback
|
||||
def _import_d3fend_fallback(db: Session) -> dict[str, int]:
|
||||
"""Import curated D3FEND techniques when the tactic APIs are unreachable."""
|
||||
"""Import curated D3FEND techniques when the tactic APIs are unreachable.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict[str, int]: Contains ``created``, ``updated``, and ``total``
|
||||
counts from upserting the fallback technique list.
|
||||
"""
|
||||
# Log info: "Using fallback D3FEND technique list (%d entries
|
||||
logger.info("Using fallback D3FEND technique list (%d entries)", len(_FALLBACK_TECHNIQUES))
|
||||
# Return _upsert_techniques(db, _FALLBACK_TECHNIQUES) # type: ignore[arg-type]
|
||||
return _upsert_techniques(db, _FALLBACK_TECHNIQUES) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@@ -239,217 +372,399 @@ def _import_d3fend_fallback(db: Session) -> dict[str, int]:
|
||||
|
||||
# Curated ATT&CK → D3FEND mapping for common techniques
|
||||
_ATTACK_TO_D3FEND: dict[str, list[str]] = {
|
||||
# Literal argument value
|
||||
"T1059": ["D3-PSA", "D3-SCA", "D3-PA", "D3-EAW", "D3-EDL", "D3-PLA"],
|
||||
# Literal argument value
|
||||
"T1059.001": ["D3-PSA", "D3-SCA", "D3-PA", "D3-EAW", "D3-EDL"],
|
||||
# Literal argument value
|
||||
"T1059.003": ["D3-PSA", "D3-SCA", "D3-PA", "D3-EAW"],
|
||||
# Literal argument value
|
||||
"T1059.005": ["D3-PSA", "D3-SCA", "D3-EAW"],
|
||||
# Literal argument value
|
||||
"T1059.007": ["D3-PSA", "D3-SCA", "D3-EAW"],
|
||||
# Literal argument value
|
||||
"T1055": ["D3-PA", "D3-PSA", "D3-HBPI", "D3-PMAD", "D3-PLA"],
|
||||
# Literal argument value
|
||||
"T1055.001": ["D3-PA", "D3-PMAD", "D3-HBPI"],
|
||||
# Literal argument value
|
||||
"T1055.002": ["D3-PA", "D3-PMAD", "D3-HBPI"],
|
||||
# Literal argument value
|
||||
"T1003": ["D3-CH", "D3-CR", "D3-MFA", "D3-PMAD"],
|
||||
# Literal argument value
|
||||
"T1003.001": ["D3-CH", "D3-CR", "D3-PMAD"],
|
||||
# Literal argument value
|
||||
"T1078": ["D3-MFA", "D3-UBA", "D3-UGLPA", "D3-CH"],
|
||||
# Literal argument value
|
||||
"T1078.001": ["D3-MFA", "D3-UBA", "D3-CH"],
|
||||
# Literal argument value
|
||||
"T1566": ["D3-EAL", "D3-FA", "D3-FH", "D3-UA", "D3-EHR"],
|
||||
# Literal argument value
|
||||
"T1566.001": ["D3-EAL", "D3-FA", "D3-FH", "D3-EHR"],
|
||||
# Literal argument value
|
||||
"T1566.002": ["D3-UA", "D3-EAL", "D3-EHR"],
|
||||
# Literal argument value
|
||||
"T1071": ["D3-AL", "D3-NTA", "D3-PM", "D3-CT"],
|
||||
# Literal argument value
|
||||
"T1071.001": ["D3-AL", "D3-NTA", "D3-PM"],
|
||||
# Literal argument value
|
||||
"T1053": ["D3-PSA", "D3-PA", "D3-SCHE", "D3-SSA"],
|
||||
# Literal argument value
|
||||
"T1053.005": ["D3-PSA", "D3-SCHE", "D3-SSA"],
|
||||
# Literal argument value
|
||||
"T1543": ["D3-SMRA", "D3-SSA", "D3-SBAN"],
|
||||
# Literal argument value
|
||||
"T1543.003": ["D3-SMRA", "D3-SSA", "D3-SBAN"],
|
||||
# Literal argument value
|
||||
"T1547": ["D3-SICA", "D3-SSA", "D3-RRID"],
|
||||
# Literal argument value
|
||||
"T1547.001": ["D3-SICA", "D3-SSA", "D3-RRID"],
|
||||
# Literal argument value
|
||||
"T1021": ["D3-RTSD", "D3-RPA", "D3-NTA", "D3-MFA"],
|
||||
# Literal argument value
|
||||
"T1021.001": ["D3-RTSD", "D3-NTA", "D3-MFA"],
|
||||
# Literal argument value
|
||||
"T1021.002": ["D3-RTSD", "D3-NTA", "D3-NI"],
|
||||
# Literal argument value
|
||||
"T1560": ["D3-FA", "D3-FCA", "D3-ORA"],
|
||||
# Literal argument value
|
||||
"T1560.001": ["D3-FA", "D3-FCA"],
|
||||
# Literal argument value
|
||||
"T1048": ["D3-ORA", "D3-NTA", "D3-OTF"],
|
||||
# Literal argument value
|
||||
"T1048.003": ["D3-ORA", "D3-NTA", "D3-OTF"],
|
||||
# Literal argument value
|
||||
"T1105": ["D3-IRA", "D3-NTA", "D3-FA", "D3-FH"],
|
||||
# Literal argument value
|
||||
"T1036": ["D3-FCA", "D3-FH", "D3-FA", "D3-SWI"],
|
||||
# Literal argument value
|
||||
"T1036.005": ["D3-FCA", "D3-FH", "D3-FA"],
|
||||
# Literal argument value
|
||||
"T1140": ["D3-FA", "D3-DA", "D3-SCA"],
|
||||
# Literal argument value
|
||||
"T1070": ["D3-SSA", "D3-LOGA", "D3-SYSM"],
|
||||
# Literal argument value
|
||||
"T1070.004": ["D3-SSA", "D3-FAPA"],
|
||||
# Literal argument value
|
||||
"T1562": ["D3-SSA", "D3-SYSM", "D3-SMRA"],
|
||||
# Literal argument value
|
||||
"T1562.001": ["D3-SSA", "D3-SYSM", "D3-SMRA"],
|
||||
# Literal argument value
|
||||
"T1027": ["D3-DA", "D3-FA", "D3-RE"],
|
||||
# Literal argument value
|
||||
"T1027.002": ["D3-DA", "D3-FA"],
|
||||
# Literal argument value
|
||||
"T1110": ["D3-MFA", "D3-UBA", "D3-CH"],
|
||||
# Literal argument value
|
||||
"T1110.001": ["D3-MFA", "D3-UBA", "D3-CH"],
|
||||
# Literal argument value
|
||||
"T1082": ["D3-PSA", "D3-PA", "D3-SYSM"],
|
||||
# Literal argument value
|
||||
"T1083": ["D3-FAPA", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1497": ["D3-DA", "D3-SE"],
|
||||
# Literal argument value
|
||||
"T1218": ["D3-PSA", "D3-PLA", "D3-EAW"],
|
||||
# Literal argument value
|
||||
"T1218.011": ["D3-PSA", "D3-PLA", "D3-EAW"],
|
||||
# Literal argument value
|
||||
"T1569": ["D3-SMRA", "D3-PSA", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1569.002": ["D3-SMRA", "D3-PSA"],
|
||||
# Literal argument value
|
||||
"T1012": ["D3-RRID", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1112": ["D3-RRID", "D3-PA", "D3-REGG"],
|
||||
# Literal argument value
|
||||
"T1057": ["D3-PA", "D3-PSA"],
|
||||
# Literal argument value
|
||||
"T1518": ["D3-SYSM", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1049": ["D3-NTA", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1016": ["D3-NTA", "D3-PA", "D3-SYSM"],
|
||||
# Literal argument value
|
||||
"T1033": ["D3-PA", "D3-UBA"],
|
||||
# Literal argument value
|
||||
"T1087": ["D3-UBA", "D3-PA", "D3-SSA"],
|
||||
# Literal argument value
|
||||
"T1087.001": ["D3-UBA", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1087.002": ["D3-UBA", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1018": ["D3-NTA", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1047": ["D3-RPA", "D3-PSA", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1190": ["D3-ISVA", "D3-NTA", "D3-AL"],
|
||||
# Literal argument value
|
||||
"T1133": ["D3-NTA", "D3-MFA", "D3-RTSD"],
|
||||
# Literal argument value
|
||||
"T1486": ["D3-BKUP", "D3-FBKP", "D3-ANTR", "D3-FA"],
|
||||
# Literal argument value
|
||||
"T1490": ["D3-BKUP", "D3-FBKP", "D3-SSA"],
|
||||
# Literal argument value
|
||||
"T1489": ["D3-SMRA", "D3-SSA"],
|
||||
# Literal argument value
|
||||
"T1098": ["D3-UBA", "D3-SSA", "D3-PGOV"],
|
||||
# Literal argument value
|
||||
"T1136": ["D3-UBA", "D3-SSA", "D3-UACM"],
|
||||
# Literal argument value
|
||||
"T1136.001": ["D3-UBA", "D3-SSA", "D3-UACM"],
|
||||
# Literal argument value
|
||||
"T1068": ["D3-SU", "D3-VULM", "D3-HBPI"],
|
||||
# Literal argument value
|
||||
"T1548": ["D3-PSEP", "D3-PSA", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1548.002": ["D3-PSEP", "D3-PSA"],
|
||||
# Literal argument value
|
||||
"T1134": ["D3-PA", "D3-PSA", "D3-PSEP"],
|
||||
# Literal argument value
|
||||
"T1134.001": ["D3-PA", "D3-PSA"],
|
||||
# Literal argument value
|
||||
"T1574": ["D3-SWI", "D3-FCA", "D3-PLA"],
|
||||
# Literal argument value
|
||||
"T1574.001": ["D3-SWI", "D3-FCA"],
|
||||
# Literal argument value
|
||||
"T1204": ["D3-EAL", "D3-FA", "D3-UA"],
|
||||
# Literal argument value
|
||||
"T1204.001": ["D3-UA", "D3-EAL"],
|
||||
# Literal argument value
|
||||
"T1204.002": ["D3-FA", "D3-EAL", "D3-DA"],
|
||||
# Literal argument value
|
||||
"T1071.004": ["D3-DPM", "D3-DNSSM", "D3-NTA"],
|
||||
# Literal argument value
|
||||
"T1571": ["D3-NTA", "D3-PM", "D3-AL"],
|
||||
# Literal argument value
|
||||
"T1572": ["D3-NTA", "D3-AL", "D3-PM"],
|
||||
# Literal argument value
|
||||
"T1041": ["D3-ORA", "D3-NTA"],
|
||||
# Literal argument value
|
||||
"T1005": ["D3-FAPA", "D3-PA"],
|
||||
# Literal argument value
|
||||
"T1113": ["D3-PA", "D3-PSA"],
|
||||
# Literal argument value
|
||||
"T1056": ["D3-PA", "D3-PSA", "D3-HBPI"],
|
||||
# Literal argument value
|
||||
"T1056.001": ["D3-PA", "D3-PSA"],
|
||||
# Literal argument value
|
||||
"T1560.003": ["D3-FA", "D3-ORA"],
|
||||
# Literal argument value
|
||||
"T1583": ["D3-IPMR", "D3-DNSRA"],
|
||||
# Literal argument value
|
||||
"T1584": ["D3-IPMR", "D3-DNSRA"],
|
||||
# Literal argument value
|
||||
"T1595": ["D3-IRA", "D3-NTA"],
|
||||
# Literal argument value
|
||||
"T1589": ["D3-UBA", "D3-THRT"],
|
||||
# Literal argument value
|
||||
"T1590": ["D3-NTA", "D3-THRT"],
|
||||
# Literal argument value
|
||||
"T1591": ["D3-THRT"],
|
||||
# Literal argument value
|
||||
"T1592": ["D3-THRT"],
|
||||
}
|
||||
|
||||
|
||||
# Define function import_d3fend_mappings
|
||||
def import_d3fend_mappings(db: Session) -> dict[str, int]:
|
||||
"""Create ATT&CK → D3FEND mappings.
|
||||
|
||||
First tries the D3FEND API for each ATT&CK technique in the DB,
|
||||
then falls back to the curated mapping for any remaining techniques.
|
||||
|
||||
Returns a dict with counts: {created, skipped, total}.
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict[str, int]: Contains ``created``, ``skipped``, and ``total``
|
||||
mapping counts.
|
||||
"""
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
|
||||
# Get all ATT&CK techniques from the DB
|
||||
attack_techniques = db.query(Technique).all()
|
||||
# Assign technique_map = {t.mitre_id: t for t in attack_techniques}
|
||||
technique_map = {t.mitre_id: t for t in attack_techniques}
|
||||
|
||||
# Get all defensive techniques
|
||||
defensive_techniques = db.query(DefensiveTechnique).all()
|
||||
# Assign d3fend_map = {dt.d3fend_id: dt for dt in defensive_techniques}
|
||||
d3fend_map = {dt.d3fend_id: dt for dt in defensive_techniques}
|
||||
|
||||
# Check: not d3fend_map
|
||||
if not d3fend_map:
|
||||
# Log warning: "No D3FEND techniques in DB — run import_d3fend_te
|
||||
logger.warning("No D3FEND techniques in DB — run import_d3fend_techniques first")
|
||||
# Return {"created": 0, "skipped": 0, "total": 0}
|
||||
return {"created": 0, "skipped": 0, "total": 0}
|
||||
|
||||
# Use the curated mapping for now (API per-technique is very slow for 700+ techniques)
|
||||
for mitre_id, d3fend_ids in _ATTACK_TO_D3FEND.items():
|
||||
# Assign attack_tech = technique_map.get(mitre_id)
|
||||
attack_tech = technique_map.get(mitre_id)
|
||||
# Check: not attack_tech
|
||||
if not attack_tech:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Iterate over d3fend_ids
|
||||
for d3fend_id in d3fend_ids:
|
||||
# Assign def_tech = d3fend_map.get(d3fend_id)
|
||||
def_tech = d3fend_map.get(d3fend_id)
|
||||
# Check: not def_tech
|
||||
if not def_tech:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Check if mapping already exists
|
||||
existing = (
|
||||
db.query(DefensiveTechniqueMapping)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
DefensiveTechniqueMapping.attack_technique_id == attack_tech.id,
|
||||
DefensiveTechniqueMapping.defensive_technique_id == def_tech.id,
|
||||
)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check: existing
|
||||
if existing:
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign mapping = DefensiveTechniqueMapping(
|
||||
mapping = DefensiveTechniqueMapping(
|
||||
# Keyword argument: attack_technique_id
|
||||
attack_technique_id=attack_tech.id,
|
||||
# Keyword argument: defensive_technique_id
|
||||
defensive_technique_id=def_tech.id,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(mapping)
|
||||
# Assign created = 1
|
||||
created += 1
|
||||
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Assign total = db.query(DefensiveTechniqueMapping).count()
|
||||
total = db.query(DefensiveTechniqueMapping).count()
|
||||
# Log info: "D3FEND mappings: %d created, %d skipped, %d total
|
||||
logger.info("D3FEND mappings: %d created, %d skipped, %d total", created, skipped, total)
|
||||
# Return {"created": created, "skipped": skipped, "total": total}
|
||||
return {"created": created, "skipped": skipped, "total": total}
|
||||
|
||||
|
||||
# Define function sync
|
||||
def sync(db: Session) -> dict:
|
||||
"""Sync D3FEND techniques and ATT&CK mappings.
|
||||
|
||||
Called by the Data Sources router when the user clicks Sync for D3FEND.
|
||||
Returns a flat summary dict suitable for ``last_sync_stats``.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict: Flat summary dict suitable for ``last_sync_stats``, containing
|
||||
``techniques_created``, ``techniques_updated``,
|
||||
``techniques_total``, ``mappings_created``,
|
||||
``mappings_skipped``, and ``mappings_total``.
|
||||
"""
|
||||
from app.models.data_source import DataSource
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
|
||||
# Assign tech_result = import_d3fend_techniques(db)
|
||||
tech_result = import_d3fend_techniques(db)
|
||||
# Assign mapping_result = import_d3fend_mappings(db)
|
||||
mapping_result = import_d3fend_mappings(db)
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"techniques_created": tech_result.get("created", 0),
|
||||
# Literal argument value
|
||||
"techniques_updated": tech_result.get("updated", 0),
|
||||
# Literal argument value
|
||||
"techniques_total": tech_result.get("total", 0),
|
||||
# Literal argument value
|
||||
"mappings_created": mapping_result.get("created", 0),
|
||||
# Literal argument value
|
||||
"mappings_skipped": mapping_result.get("skipped", 0),
|
||||
# Literal argument value
|
||||
"mappings_total": mapping_result.get("total", 0),
|
||||
}
|
||||
|
||||
# Update DataSource record
|
||||
ds = db.query(DataSource).filter(DataSource.name == "d3fend").first()
|
||||
# Check: ds
|
||||
if ds:
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_status = "success"
|
||||
ds.last_sync_status = "success"
|
||||
# Assign ds.last_sync_stats = summary
|
||||
ds.last_sync_stats = summary
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Log info: "D3FEND sync complete — %s", summary
|
||||
logger.info("D3FEND sync complete — %s", summary)
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
|
||||
def get_defenses_for_technique(db: Session, technique_id) -> list[dict]:
|
||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||
# Define function get_defenses_for_technique
|
||||
def get_defenses_for_technique(db: Session, technique_id: UUID) -> list[dict]:
|
||||
"""Return all D3FEND defensive techniques mapped to a given ATT&CK technique.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
technique_id (UUID): UUID of the ATT&CK technique to look up.
|
||||
|
||||
Returns:
|
||||
list[dict]: List of defensive technique dicts, each containing
|
||||
``id``, ``d3fend_id``, ``name``, ``description``, ``tactic``,
|
||||
and ``d3fend_url``.
|
||||
"""
|
||||
# Assign mappings = (
|
||||
mappings = (
|
||||
db.query(DefensiveTechniqueMapping)
|
||||
# Chain .filter() call
|
||||
.filter(DefensiveTechniqueMapping.attack_technique_id == technique_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign results = []
|
||||
results = []
|
||||
# Iterate over mappings
|
||||
for m in mappings:
|
||||
# Assign dt = m.defensive_technique
|
||||
dt = m.defensive_technique
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"id": str(dt.id),
|
||||
# Literal argument value
|
||||
"d3fend_id": dt.d3fend_id,
|
||||
# Literal argument value
|
||||
"name": dt.name,
|
||||
# Literal argument value
|
||||
"description": dt.description,
|
||||
# Literal argument value
|
||||
"tactic": dt.tactic,
|
||||
# Literal argument value
|
||||
"d3fend_url": dt.d3fend_url,
|
||||
})
|
||||
|
||||
# Return results
|
||||
return results
|
||||
|
||||
@@ -1,53 +1,92 @@
|
||||
"""D3FEND query service — framework-agnostic queries for defensive techniques."""
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
|
||||
# Import DefensiveTechnique from app.models.defensive_technique
|
||||
from app.models.defensive_technique import DefensiveTechnique
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import get_defenses_for_technique from app.services.d3fend_import_service
|
||||
from app.services.d3fend_import_service import get_defenses_for_technique
|
||||
|
||||
# Import escape_like from app.utils
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
# Define function list_defensive_techniques
|
||||
def list_defensive_techniques(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: tactic
|
||||
tactic: Optional[str] = None,
|
||||
# Entry: search
|
||||
search: Optional[str] = None,
|
||||
# Entry: offset
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""List D3FEND defensive techniques with optional filters."""
|
||||
# Assign query = db.query(DefensiveTechnique)
|
||||
query = db.query(DefensiveTechnique)
|
||||
|
||||
# Check: tactic
|
||||
if tactic:
|
||||
# Assign query = query.filter(DefensiveTechnique.tactic == tactic)
|
||||
query = query.filter(DefensiveTechnique.tactic == tactic)
|
||||
|
||||
# Check: search
|
||||
if search:
|
||||
# Assign pattern = f"%{escape_like(search)}%"
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
# Assign query = query.filter(
|
||||
query = query.filter(
|
||||
DefensiveTechnique.name.ilike(pattern)
|
||||
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
||||
)
|
||||
|
||||
# Assign total = query.count()
|
||||
total = query.count()
|
||||
# Assign items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(l...
|
||||
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
# Literal argument value
|
||||
"offset": offset,
|
||||
# Literal argument value
|
||||
"limit": limit,
|
||||
# Literal argument value
|
||||
"items": [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(dt.id),
|
||||
# Literal argument value
|
||||
"d3fend_id": dt.d3fend_id,
|
||||
# Literal argument value
|
||||
"name": dt.name,
|
||||
# Literal argument value
|
||||
"description": dt.description,
|
||||
# Literal argument value
|
||||
"tactic": dt.tactic,
|
||||
# Literal argument value
|
||||
"d3fend_url": dt.d3fend_url,
|
||||
}
|
||||
for dt in items
|
||||
@@ -55,28 +94,44 @@ def list_defensive_techniques(
|
||||
}
|
||||
|
||||
|
||||
# Define function list_d3fend_tactics
|
||||
def list_d3fend_tactics(db: Session) -> list[dict]:
|
||||
"""Return a list of all D3FEND tactics with counts."""
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
|
||||
# Chain .group_by() call
|
||||
.group_by(DefensiveTechnique.tactic)
|
||||
# Chain .order_by() call
|
||||
.order_by(DefensiveTechnique.tactic)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [{"tactic": tactic or "Unknown", "count": count} for tactic, count ...
|
||||
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
|
||||
|
||||
|
||||
# Define function get_defenses_for_attack_technique
|
||||
def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict:
|
||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||
# Assign technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||
# Check: technique is None
|
||||
if technique is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", mitre_id)
|
||||
|
||||
# Assign defenses = get_defenses_for_technique(db, technique.id)
|
||||
defenses = get_defenses_for_technique(db, technique.id)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"mitre_id": mitre_id,
|
||||
# Literal argument value
|
||||
"technique_name": technique.name,
|
||||
# Literal argument value
|
||||
"defenses": defenses,
|
||||
# Literal argument value
|
||||
"total": len(defenses),
|
||||
}
|
||||
|
||||
@@ -4,61 +4,99 @@ Provides list, update, sync, and stats. Sync operations commit internally
|
||||
since they are long-running and self-contained.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import BusinessRuleViolation, EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
|
||||
|
||||
# Import get_import_handler from app.domain.ports.import_service
|
||||
from app.domain.ports.import_service import get_import_handler
|
||||
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define function list_sources
|
||||
def list_sources(db: Session) -> list[dict]:
|
||||
"""Return all registered data sources as a list of dicts."""
|
||||
# Assign sources = db.query(DataSource).order_by(DataSource.name).all()
|
||||
sources = db.query(DataSource).order_by(DataSource.name).all()
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(s.id),
|
||||
# Literal argument value
|
||||
"name": s.name,
|
||||
# Literal argument value
|
||||
"display_name": s.display_name,
|
||||
# Literal argument value
|
||||
"type": s.type,
|
||||
# Literal argument value
|
||||
"url": s.url,
|
||||
# Literal argument value
|
||||
"description": s.description,
|
||||
# Literal argument value
|
||||
"is_enabled": s.is_enabled,
|
||||
# Literal argument value
|
||||
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
|
||||
# Literal argument value
|
||||
"last_sync_status": s.last_sync_status,
|
||||
# Literal argument value
|
||||
"last_sync_stats": s.last_sync_stats,
|
||||
# Literal argument value
|
||||
"sync_frequency": s.sync_frequency,
|
||||
# Literal argument value
|
||||
"config": s.config,
|
||||
# Literal argument value
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in sources
|
||||
]
|
||||
|
||||
|
||||
# Define function update_source
|
||||
def update_source(db: Session, source_id: str, **fields: object) -> None:
|
||||
"""Update a data source's fields (is_enabled, sync_frequency, config).
|
||||
|
||||
Raises EntityNotFoundError if source does not exist.
|
||||
Does not commit; the router handles that.
|
||||
"""
|
||||
# Assign ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
# Check: not ds
|
||||
if not ds:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Data source", source_id)
|
||||
|
||||
# Check: "is_enabled" in fields
|
||||
if "is_enabled" in fields:
|
||||
# Assign ds.is_enabled = fields["is_enabled"]
|
||||
ds.is_enabled = fields["is_enabled"]
|
||||
# Check: "sync_frequency" in fields
|
||||
if "sync_frequency" in fields:
|
||||
# Assign ds.sync_frequency = fields["sync_frequency"]
|
||||
ds.sync_frequency = fields["sync_frequency"]
|
||||
# Check: "config" in fields
|
||||
if "config" in fields:
|
||||
# Assign ds.config = fields["config"]
|
||||
ds.config = fields["config"]
|
||||
|
||||
|
||||
# Define function sync_source
|
||||
def sync_source(db: Session, source_id: str) -> dict:
|
||||
"""Trigger sync for a specific data source.
|
||||
|
||||
@@ -67,131 +105,221 @@ def sync_source(db: Session, source_id: str) -> dict:
|
||||
Commits internally (long-running, self-contained operation).
|
||||
Returns dict with message, source, stats.
|
||||
"""
|
||||
# Assign ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
# Check: not ds
|
||||
if not ds:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Data source", source_id)
|
||||
|
||||
# Assign handler = get_import_handler(ds.name)
|
||||
handler = get_import_handler(ds.name)
|
||||
# Check: handler is None
|
||||
if handler is None:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
|
||||
|
||||
# Assign ds.last_sync_status = "in_progress"
|
||||
ds.last_sync_status = "in_progress"
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign summary = handler(db)
|
||||
summary = handler(db)
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log error: "Sync failed for %s: %s", ds.name, exc, exc_info=T
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
# Assign ds.last_sync_status = "error"
|
||||
ds.last_sync_status = "error"
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_stats = {"error": str(exc)}
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Sync failed for '{ds.display_name}'. Check server logs for details."
|
||||
)
|
||||
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_status = "success"
|
||||
ds.last_sync_status = "success"
|
||||
# Assign ds.last_sync_stats = summary
|
||||
ds.last_sync_stats = summary
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"message": f"Sync complete for {ds.display_name}",
|
||||
# Literal argument value
|
||||
"source": ds.name,
|
||||
# Literal argument value
|
||||
"stats": summary,
|
||||
}
|
||||
|
||||
|
||||
# Define function sync_all_sources
|
||||
def sync_all_sources(db: Session) -> list[dict]:
|
||||
"""Trigger sync for all enabled data sources (sequentially).
|
||||
|
||||
Commits internally (long-running, self-contained operation).
|
||||
Returns list of result dicts with source, status, stats/detail.
|
||||
"""
|
||||
# Assign enabled_sources = (
|
||||
enabled_sources = (
|
||||
db.query(DataSource)
|
||||
# Chain .filter() call
|
||||
.filter(DataSource.is_enabled == True)
|
||||
# Chain .order_by() call
|
||||
.order_by(DataSource.name)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign results = []
|
||||
results = []
|
||||
# Iterate over enabled_sources
|
||||
for ds in enabled_sources:
|
||||
# Assign handler = get_import_handler(ds.name)
|
||||
handler = get_import_handler(ds.name)
|
||||
# Check: handler is None
|
||||
if handler is None:
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"source": ds.name,
|
||||
# Literal argument value
|
||||
"status": "skipped",
|
||||
# Literal argument value
|
||||
"detail": "No sync handler available",
|
||||
})
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign ds.last_sync_status = "in_progress"
|
||||
ds.last_sync_status = "in_progress"
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign summary = handler(db)
|
||||
summary = handler(db)
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_status = "success"
|
||||
ds.last_sync_status = "success"
|
||||
# Assign ds.last_sync_stats = summary
|
||||
ds.last_sync_stats = summary
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"source": ds.name,
|
||||
# Literal argument value
|
||||
"status": "success",
|
||||
# Literal argument value
|
||||
"stats": summary,
|
||||
})
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log error: "Sync failed for %s: %s", ds.name, exc, exc_info=T
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
# Assign ds.last_sync_status = "error"
|
||||
ds.last_sync_status = "error"
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_stats = {"error": str(exc)}
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"source": ds.name,
|
||||
# Literal argument value
|
||||
"status": "error",
|
||||
# Literal argument value
|
||||
"detail": "Sync failed. Check server logs for details.",
|
||||
})
|
||||
|
||||
# Return results
|
||||
return results
|
||||
|
||||
|
||||
# Define function get_source_stats
|
||||
def get_source_stats(db: Session, source_id: str) -> dict:
|
||||
"""Return detailed statistics for a data source.
|
||||
|
||||
Raises EntityNotFoundError if source does not exist.
|
||||
"""
|
||||
# Assign ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
# Check: not ds
|
||||
if not ds:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Data source", source_id)
|
||||
|
||||
from app.models.test_template import TestTemplate
|
||||
# Import DetectionRule from app.models.detection_rule
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Assign template_count = 0
|
||||
template_count = 0
|
||||
# Assign rule_count = 0
|
||||
rule_count = 0
|
||||
|
||||
# Check: ds.type == "attack_procedure"
|
||||
if ds.type == "attack_procedure":
|
||||
# Assign template_count = (
|
||||
template_count = (
|
||||
db.query(TestTemplate)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.source == ds.name)
|
||||
# Chain .count() call
|
||||
.count()
|
||||
)
|
||||
# Alternative: ds.type == "detection_rule"
|
||||
elif ds.type == "detection_rule":
|
||||
# Assign rule_count = (
|
||||
rule_count = (
|
||||
db.query(DetectionRule)
|
||||
# Chain .filter() call
|
||||
.filter(DetectionRule.source == ds.name)
|
||||
# Chain .count() call
|
||||
.count()
|
||||
)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(ds.id),
|
||||
# Literal argument value
|
||||
"name": ds.name,
|
||||
# Literal argument value
|
||||
"display_name": ds.display_name,
|
||||
# Literal argument value
|
||||
"type": ds.type,
|
||||
# Literal argument value
|
||||
"is_enabled": ds.is_enabled,
|
||||
# Literal argument value
|
||||
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
|
||||
# Literal argument value
|
||||
"last_sync_status": ds.last_sync_status,
|
||||
# Literal argument value
|
||||
"last_sync_stats": ds.last_sync_stats,
|
||||
# Literal argument value
|
||||
"total_templates": template_count,
|
||||
# Literal argument value
|
||||
"total_rules": rule_count,
|
||||
}
|
||||
|
||||
@@ -6,76 +6,136 @@ that the router remains a thin HTTP adapter.
|
||||
This module is framework-agnostic: no FastAPI imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Any from typing
|
||||
from typing import Any
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
from app.models.technique import Technique
|
||||
from app.utils import escape_like
|
||||
|
||||
# Import DetectionRule from app.models.detection_rule
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import TestDetectionResult from app.models.test_detection_result
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import TestTemplateDetectionRule from app.models.test_template_detection_rule
|
||||
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
|
||||
# Import escape_like from app.utils
|
||||
from app.utils import escape_like
|
||||
|
||||
# ── Public service functions ──────────────────────────────────────────
|
||||
|
||||
|
||||
def list_rules(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: technique
|
||||
technique: str | None = None,
|
||||
# Entry: source
|
||||
source: str | None = None,
|
||||
# Entry: severity
|
||||
severity: str | None = None,
|
||||
# Entry: search
|
||||
search: str | None = None,
|
||||
# Entry: offset
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 50,
|
||||
) -> dict[str, Any]:
|
||||
"""List detection rules with optional filters and pagination."""
|
||||
# Assign query = db.query(DetectionRule).filter(DetectionRule.is_active == True)
|
||||
query = db.query(DetectionRule).filter(DetectionRule.is_active == True)
|
||||
|
||||
# Check: technique
|
||||
if technique:
|
||||
# Assign query = query.filter(DetectionRule.mitre_technique_id == technique)
|
||||
query = query.filter(DetectionRule.mitre_technique_id == technique)
|
||||
# Check: source
|
||||
if source:
|
||||
# Assign query = query.filter(DetectionRule.source == source)
|
||||
query = query.filter(DetectionRule.source == source)
|
||||
# Check: severity
|
||||
if severity:
|
||||
# Assign query = query.filter(DetectionRule.severity == severity)
|
||||
query = query.filter(DetectionRule.severity == severity)
|
||||
# Check: search
|
||||
if search:
|
||||
# Assign pattern = f"%{escape_like(search)}%"
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
# Assign query = query.filter(
|
||||
query = query.filter(
|
||||
DetectionRule.title.ilike(pattern)
|
||||
| DetectionRule.description.ilike(pattern)
|
||||
)
|
||||
|
||||
# Assign total = query.count()
|
||||
total = query.count()
|
||||
# Assign items = (
|
||||
items = (
|
||||
query.order_by(DetectionRule.mitre_technique_id, DetectionRule.title)
|
||||
# Chain .offset() call
|
||||
.offset(offset)
|
||||
# Chain .limit() call
|
||||
.limit(limit)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
# Literal argument value
|
||||
"offset": offset,
|
||||
# Literal argument value
|
||||
"limit": limit,
|
||||
# Literal argument value
|
||||
"items": [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(r.id),
|
||||
# Literal argument value
|
||||
"mitre_technique_id": r.mitre_technique_id,
|
||||
# Literal argument value
|
||||
"title": r.title,
|
||||
# Literal argument value
|
||||
"description": r.description,
|
||||
# Literal argument value
|
||||
"source": r.source,
|
||||
# Literal argument value
|
||||
"source_url": r.source_url,
|
||||
# Literal argument value
|
||||
"rule_format": r.rule_format,
|
||||
# Literal argument value
|
||||
"severity": r.severity,
|
||||
# Literal argument value
|
||||
"platforms": r.platforms or [],
|
||||
# Literal argument value
|
||||
"log_sources": r.log_sources,
|
||||
# Literal argument value
|
||||
"is_active": r.is_active,
|
||||
}
|
||||
for r in items
|
||||
@@ -83,48 +143,78 @@ def list_rules(
|
||||
}
|
||||
|
||||
|
||||
# Define function get_rules_for_template
|
||||
def get_rules_for_template(db: Session, template_id: str) -> dict[str, Any]:
|
||||
"""Get detection rules associated with a test template.
|
||||
|
||||
Raises EntityNotFoundError if the template does not exist.
|
||||
"""
|
||||
# Assign template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
# Check: not template
|
||||
if not template:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test template", template_id)
|
||||
|
||||
# Assign associations = (
|
||||
associations = (
|
||||
db.query(TestTemplateDetectionRule)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplateDetectionRule.test_template_id == template_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign rules = []
|
||||
rules = []
|
||||
# Iterate over associations
|
||||
for assoc in associations:
|
||||
# Assign r = assoc.detection_rule
|
||||
r = assoc.detection_rule
|
||||
# Call rules.append()
|
||||
rules.append({
|
||||
# Literal argument value
|
||||
"id": str(r.id),
|
||||
# Literal argument value
|
||||
"mitre_technique_id": r.mitre_technique_id,
|
||||
# Literal argument value
|
||||
"title": r.title,
|
||||
# Literal argument value
|
||||
"description": r.description,
|
||||
# Literal argument value
|
||||
"source": r.source,
|
||||
# Literal argument value
|
||||
"source_url": r.source_url,
|
||||
# Literal argument value
|
||||
"rule_content": r.rule_content,
|
||||
# Literal argument value
|
||||
"rule_format": r.rule_format,
|
||||
# Literal argument value
|
||||
"severity": r.severity,
|
||||
# Literal argument value
|
||||
"platforms": r.platforms or [],
|
||||
# Literal argument value
|
||||
"log_sources": r.log_sources,
|
||||
# Literal argument value
|
||||
"is_primary": assoc.is_primary,
|
||||
})
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"template_id": str(template.id),
|
||||
# Literal argument value
|
||||
"template_name": template.name,
|
||||
# Literal argument value
|
||||
"mitre_technique_id": template.mitre_technique_id,
|
||||
# Literal argument value
|
||||
"rules": rules,
|
||||
# Literal argument value
|
||||
"total": len(rules),
|
||||
}
|
||||
|
||||
|
||||
# Define function auto_associate_rules
|
||||
def auto_associate_rules(db: Session) -> dict[str, Any]:
|
||||
"""Auto-associate test templates with detection rules by MITRE technique ID.
|
||||
|
||||
@@ -132,188 +222,316 @@ def auto_associate_rules(db: Session) -> dict[str, Any]:
|
||||
technique and creates associations. Rules with severity high/critical
|
||||
are marked as primary. Performs commit internally.
|
||||
"""
|
||||
# Assign templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all()
|
||||
templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all()
|
||||
# Assign rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all()
|
||||
rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all()
|
||||
|
||||
# Assign rules_by_technique = {}
|
||||
rules_by_technique: dict[str, list] = {}
|
||||
# Iterate over rules
|
||||
for rule in rules:
|
||||
# Assign tid = rule.mitre_technique_id
|
||||
tid = rule.mitre_technique_id
|
||||
# Check: tid not in rules_by_technique
|
||||
if tid not in rules_by_technique:
|
||||
# Assign rules_by_technique[tid] = []
|
||||
rules_by_technique[tid] = []
|
||||
# rules_by_technique[tid].append(rule)
|
||||
rules_by_technique[tid].append(rule)
|
||||
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
# Assign high_severities = {"high", "critical"}
|
||||
high_severities = {"high", "critical"}
|
||||
|
||||
# Iterate over templates
|
||||
for template in templates:
|
||||
# Assign matching_rules = rules_by_technique.get(template.mitre_technique_id, [])
|
||||
matching_rules = rules_by_technique.get(template.mitre_technique_id, [])
|
||||
# Iterate over matching_rules
|
||||
for rule in matching_rules:
|
||||
# Assign existing = (
|
||||
existing = (
|
||||
db.query(TestTemplateDetectionRule)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
TestTemplateDetectionRule.test_template_id == template.id,
|
||||
TestTemplateDetectionRule.detection_rule_id == rule.id,
|
||||
)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Check: existing
|
||||
if existing:
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign is_primary = (rule.severity or "").lower() in high_severities
|
||||
is_primary = (rule.severity or "").lower() in high_severities
|
||||
|
||||
# Assign assoc = TestTemplateDetectionRule(
|
||||
assoc = TestTemplateDetectionRule(
|
||||
# Keyword argument: test_template_id
|
||||
test_template_id=template.id,
|
||||
# Keyword argument: detection_rule_id
|
||||
detection_rule_id=rule.id,
|
||||
# Keyword argument: is_primary
|
||||
is_primary=is_primary,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(assoc)
|
||||
# Assign created = 1
|
||||
created += 1
|
||||
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Assign total = db.query(TestTemplateDetectionRule).count()
|
||||
total = db.query(TestTemplateDetectionRule).count()
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"created": created,
|
||||
# Literal argument value
|
||||
"skipped": skipped,
|
||||
# Literal argument value
|
||||
"total_associations": total,
|
||||
}
|
||||
|
||||
|
||||
# Define function get_rules_for_test
|
||||
def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]:
|
||||
"""Get detection rules relevant to a test, along with their evaluation results.
|
||||
|
||||
Finds rules by matching the test's technique to detection rules.
|
||||
Raises EntityNotFoundError if the test or its technique does not exist.
|
||||
"""
|
||||
# Assign test = db.query(Test).filter(Test.id == test_id).first()
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
# Check: not test
|
||||
if not test:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
|
||||
# Assign technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
# Check: not technique
|
||||
if not technique:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", str(test.technique_id))
|
||||
|
||||
# Assign rules = (
|
||||
rules = (
|
||||
db.query(DetectionRule)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
DetectionRule.mitre_technique_id == technique.mitre_id,
|
||||
DetectionRule.is_active == True,
|
||||
)
|
||||
# Chain .order_by() call
|
||||
.order_by(DetectionRule.severity.desc(), DetectionRule.title)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign existing_results = (
|
||||
existing_results = (
|
||||
db.query(TestDetectionResult)
|
||||
# Chain .filter() call
|
||||
.filter(TestDetectionResult.test_id == test_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign results_map = {str(r.detection_rule_id): r for r in existing_results}
|
||||
results_map = {str(r.detection_rule_id): r for r in existing_results}
|
||||
|
||||
# Assign items = []
|
||||
items = []
|
||||
# Assign triggered_count = 0
|
||||
triggered_count = 0
|
||||
# Assign evaluated_count = 0
|
||||
evaluated_count = 0
|
||||
|
||||
# Iterate over rules
|
||||
for rule in rules:
|
||||
# Assign result = results_map.get(str(rule.id))
|
||||
result = results_map.get(str(rule.id))
|
||||
# Assign triggered = result.triggered if result else None
|
||||
triggered = result.triggered if result else None
|
||||
# Assign notes = result.notes if result else None
|
||||
notes = result.notes if result else None
|
||||
# Assign evaluated_at = result.evaluated_at.isoformat() if result and result.evaluated_at e...
|
||||
evaluated_at = result.evaluated_at.isoformat() if result and result.evaluated_at else None
|
||||
|
||||
# Check: triggered is not None
|
||||
if triggered is not None:
|
||||
# Assign evaluated_count = 1
|
||||
evaluated_count += 1
|
||||
# Check: triggered
|
||||
if triggered:
|
||||
# Assign triggered_count = 1
|
||||
triggered_count += 1
|
||||
|
||||
# Call items.append()
|
||||
items.append({
|
||||
# Literal argument value
|
||||
"id": str(rule.id),
|
||||
# Literal argument value
|
||||
"mitre_technique_id": rule.mitre_technique_id,
|
||||
# Literal argument value
|
||||
"title": rule.title,
|
||||
# Literal argument value
|
||||
"description": rule.description,
|
||||
# Literal argument value
|
||||
"source": rule.source,
|
||||
# Literal argument value
|
||||
"source_url": rule.source_url,
|
||||
# Literal argument value
|
||||
"rule_content": rule.rule_content,
|
||||
# Literal argument value
|
||||
"rule_format": rule.rule_format,
|
||||
# Literal argument value
|
||||
"severity": rule.severity,
|
||||
# Literal argument value
|
||||
"platforms": rule.platforms or [],
|
||||
# Literal argument value
|
||||
"log_sources": rule.log_sources,
|
||||
# Literal argument value
|
||||
"triggered": triggered,
|
||||
# Literal argument value
|
||||
"notes": notes,
|
||||
# Literal argument value
|
||||
"evaluated_at": evaluated_at,
|
||||
# Literal argument value
|
||||
"result_id": str(result.id) if result else None,
|
||||
})
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"test_id": str(test.id),
|
||||
# Literal argument value
|
||||
"mitre_technique_id": technique.mitre_id,
|
||||
# Literal argument value
|
||||
"rules": items,
|
||||
# Literal argument value
|
||||
"total": len(items),
|
||||
# Literal argument value
|
||||
"evaluated": evaluated_count,
|
||||
# Literal argument value
|
||||
"triggered": triggered_count,
|
||||
# Literal argument value
|
||||
"detection_rate": round(triggered_count / evaluated_count * 100, 1) if evaluated_count > 0 else 0,
|
||||
}
|
||||
|
||||
|
||||
# Define function evaluate_rule
|
||||
def evaluate_rule(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
test_id: Any,
|
||||
detection_rule_id: Any,
|
||||
# Entry: test_id
|
||||
test_id: UUID,
|
||||
# Entry: detection_rule_id
|
||||
detection_rule_id: UUID,
|
||||
# Entry: triggered
|
||||
triggered: bool | None,
|
||||
# Entry: notes
|
||||
notes: str | None,
|
||||
evaluator_id: Any,
|
||||
# Entry: evaluator_id
|
||||
evaluator_id: UUID,
|
||||
) -> dict[str, Any]:
|
||||
"""Save or update the evaluation result for a detection rule on a test.
|
||||
|
||||
Raises EntityNotFoundError if the test or detection rule does not exist.
|
||||
"""
|
||||
# Assign test = db.query(Test).filter(Test.id == test_id).first()
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
# Check: not test
|
||||
if not test:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
|
||||
# Assign rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_i...
|
||||
rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first()
|
||||
# Check: not rule
|
||||
if not rule:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Detection rule", str(detection_rule_id))
|
||||
|
||||
# Assign existing = (
|
||||
existing = (
|
||||
db.query(TestDetectionResult)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
TestDetectionResult.test_id == test_id,
|
||||
TestDetectionResult.detection_rule_id == detection_rule_id,
|
||||
)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check: existing
|
||||
if existing:
|
||||
# Assign existing.triggered = triggered
|
||||
existing.triggered = triggered
|
||||
# Assign existing.notes = notes
|
||||
existing.notes = notes
|
||||
# Assign existing.evaluated_by = evaluator_id
|
||||
existing.evaluated_by = evaluator_id
|
||||
# Assign existing.evaluated_at = datetime.utcnow()
|
||||
existing.evaluated_at = datetime.utcnow()
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(existing)
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(existing.id),
|
||||
# Literal argument value
|
||||
"triggered": existing.triggered,
|
||||
# Literal argument value
|
||||
"notes": existing.notes,
|
||||
# Literal argument value
|
||||
"evaluated_at": existing.evaluated_at.isoformat() if existing.evaluated_at else None,
|
||||
}
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign result = TestDetectionResult(
|
||||
result = TestDetectionResult(
|
||||
# Keyword argument: test_id
|
||||
test_id=test_id,
|
||||
# Keyword argument: detection_rule_id
|
||||
detection_rule_id=detection_rule_id,
|
||||
# Keyword argument: triggered
|
||||
triggered=triggered,
|
||||
# Keyword argument: notes
|
||||
notes=notes,
|
||||
# Keyword argument: evaluated_by
|
||||
evaluated_by=evaluator_id,
|
||||
# Keyword argument: evaluated_at
|
||||
evaluated_at=datetime.utcnow(),
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(result)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(result)
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(result.id),
|
||||
# Literal argument value
|
||||
"triggered": result.triggered,
|
||||
# Literal argument value
|
||||
"notes": result.notes,
|
||||
# Literal argument value
|
||||
"evaluated_at": result.evaluated_at.isoformat() if result.evaluated_at else None,
|
||||
}
|
||||
|
||||
@@ -21,22 +21,39 @@ rules are identified by ``source = "elastic"`` + ``source_id`` (the
|
||||
TOML filename).
|
||||
"""
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import shutil
|
||||
import shutil
|
||||
|
||||
# Import tempfile
|
||||
import tempfile
|
||||
|
||||
# Import zipfile
|
||||
import zipfile
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Path from pathlib
|
||||
from pathlib import Path
|
||||
|
||||
# Import requests
|
||||
import requests as _requests
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.detection_rule import DetectionRule
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -44,19 +61,33 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ELASTIC_ZIP_URL = (
|
||||
# Literal argument value
|
||||
"https://github.com/elastic/detection-rules"
|
||||
# Literal argument value
|
||||
"/archive/refs/heads/main.zip"
|
||||
)
|
||||
|
||||
# Assign _DOWNLOAD_TIMEOUT = 300
|
||||
_DOWNLOAD_TIMEOUT = 300
|
||||
# Assign _ZIP_ROOT_PREFIX = "detection-rules-main"
|
||||
_ZIP_ROOT_PREFIX = "detection-rules-main"
|
||||
|
||||
# Safety limits for ZIP extraction — prevent zip-bomb DoS
|
||||
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB
|
||||
# Assign _MAX_ENTRIES = 50_000
|
||||
_MAX_ENTRIES = 50_000
|
||||
|
||||
# Severity normalisation
|
||||
_SEVERITY_MAP = {
|
||||
# Literal argument value
|
||||
"informational": "informational",
|
||||
# Literal argument value
|
||||
"low": "low",
|
||||
# Literal argument value
|
||||
"medium": "medium",
|
||||
# Literal argument value
|
||||
"high": "high",
|
||||
# Literal argument value
|
||||
"critical": "critical",
|
||||
}
|
||||
|
||||
@@ -68,14 +99,21 @@ _SEVERITY_MAP = {
|
||||
|
||||
def _download_zip(url: str = ELASTIC_ZIP_URL) -> bytes:
|
||||
"""Download the Elastic Detection Rules ZIP and return raw bytes."""
|
||||
# Log info: "Downloading Elastic Detection Rules ZIP from %s …
|
||||
logger.info("Downloading Elastic Detection Rules ZIP from %s …", url)
|
||||
# Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Assign content = resp.content
|
||||
content = resp.content
|
||||
# Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024
|
||||
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
|
||||
# Return content
|
||||
return content
|
||||
|
||||
|
||||
# Define function _safe_extract_zip
|
||||
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
"""Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.
|
||||
|
||||
@@ -83,62 +121,85 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
directory (path traversal / Zip Slip) or if the archive exceeds the
|
||||
safety limits.
|
||||
"""
|
||||
# Maximum uncompressed size: 500 MB — prevents zip-bomb DoS
|
||||
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024
|
||||
# Maximum number of entries
|
||||
_MAX_ENTRIES = 50_000
|
||||
|
||||
# Assign dest_path = Path(dest).resolve()
|
||||
dest_path = Path(dest).resolve()
|
||||
|
||||
# Open context manager
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
# Assign entries = zf.infolist()
|
||||
entries = zf.infolist()
|
||||
|
||||
# Check: len(entries) > _MAX_ENTRIES
|
||||
if len(entries) > _MAX_ENTRIES:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"ZIP archive contains {len(entries)} entries "
|
||||
f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
|
||||
)
|
||||
|
||||
# Assign total_size = sum(info.file_size for info in entries)
|
||||
total_size = sum(info.file_size for info in entries)
|
||||
# Check: total_size > _MAX_UNCOMPRESSED_SIZE
|
||||
if total_size > _MAX_UNCOMPRESSED_SIZE:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
|
||||
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
|
||||
)
|
||||
|
||||
# Iterate over entries
|
||||
for member in entries:
|
||||
# Assign target = (dest_path / member.filename).resolve()
|
||||
target = (dest_path / member.filename).resolve()
|
||||
# Check: not target.is_relative_to(dest_path)
|
||||
if not target.is_relative_to(dest_path):
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"Zip Slip detected — member '{member.filename}' "
|
||||
f"resolves outside target directory"
|
||||
)
|
||||
|
||||
# Call zf.extractall()
|
||||
zf.extractall(dest)
|
||||
|
||||
|
||||
# Define function _extract_zip
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return rules/ dir."""
|
||||
# Call _safe_extract_zip()
|
||||
_safe_extract_zip(zip_bytes, dest)
|
||||
# Assign rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
|
||||
rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
|
||||
# Check: not rules_dir.is_dir()
|
||||
if not rules_dir.is_dir():
|
||||
# Raise FileNotFoundError
|
||||
raise FileNotFoundError(
|
||||
f"Expected rules directory not found at {rules_dir}"
|
||||
)
|
||||
# Return rules_dir
|
||||
return rules_dir
|
||||
|
||||
|
||||
# Define function _parse_toml_safe
|
||||
def _parse_toml_safe(path: Path) -> dict | None:
|
||||
"""Parse a TOML file. Uses the ``toml`` library."""
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Import toml
|
||||
import toml
|
||||
# Open context manager
|
||||
with open(path, "r", encoding="utf-8") as fh:
|
||||
# Return toml.load(fh)
|
||||
return toml.load(fh)
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log debug: "Failed to parse %s: %s", path, exc
|
||||
logger.debug("Failed to parse %s: %s", path, exc)
|
||||
# Return None
|
||||
return None
|
||||
|
||||
|
||||
# Define function _extract_mitre_techniques
|
||||
def _extract_mitre_techniques(threat_list: list) -> list[str]:
|
||||
"""Extract MITRE technique IDs from Elastic's ``rule.threat`` array.
|
||||
|
||||
@@ -156,82 +217,132 @@ def _extract_mitre_techniques(threat_list: list) -> list[str]:
|
||||
name = "LSASS Memory"
|
||||
id = "T1003.001"
|
||||
"""
|
||||
# Assign technique_ids = []
|
||||
technique_ids = []
|
||||
|
||||
# Check: not isinstance(threat_list, list)
|
||||
if not isinstance(threat_list, list):
|
||||
# Return technique_ids
|
||||
return technique_ids
|
||||
|
||||
# Iterate over threat_list
|
||||
for threat_entry in threat_list:
|
||||
# Check: not isinstance(threat_entry, dict)
|
||||
if not isinstance(threat_entry, dict):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Skip non-MITRE frameworks
|
||||
framework = threat_entry.get("framework", "")
|
||||
# Check: "MITRE" not in str(framework).upper()
|
||||
if "MITRE" not in str(framework).upper():
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign techniques = threat_entry.get("technique", [])
|
||||
techniques = threat_entry.get("technique", [])
|
||||
# Check: not isinstance(techniques, list)
|
||||
if not isinstance(techniques, list):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Iterate over techniques
|
||||
for tech in techniques:
|
||||
# Check: not isinstance(tech, dict)
|
||||
if not isinstance(tech, dict):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Assign tech_id = tech.get("id", "")
|
||||
tech_id = tech.get("id", "")
|
||||
# Check: tech_id and str(tech_id).upper().startswith("T")
|
||||
if tech_id and str(tech_id).upper().startswith("T"):
|
||||
# Call technique_ids.append()
|
||||
technique_ids.append(str(tech_id).upper())
|
||||
|
||||
# Check subtechniques
|
||||
subtechniques = tech.get("subtechnique", [])
|
||||
# Check: isinstance(subtechniques, list)
|
||||
if isinstance(subtechniques, list):
|
||||
# Iterate over subtechniques
|
||||
for subtech in subtechniques:
|
||||
# Check: isinstance(subtech, dict)
|
||||
if isinstance(subtech, dict):
|
||||
# Assign sub_id = subtech.get("id", "")
|
||||
sub_id = subtech.get("id", "")
|
||||
# Check: sub_id and str(sub_id).upper().startswith("T")
|
||||
if sub_id and str(sub_id).upper().startswith("T"):
|
||||
# Call technique_ids.append()
|
||||
technique_ids.append(str(sub_id).upper())
|
||||
|
||||
# Return list(set(technique_ids))
|
||||
return list(set(technique_ids))
|
||||
|
||||
|
||||
# Define function _parse_elastic_rules
|
||||
def _parse_elastic_rules(rules_dir: Path) -> list[dict]:
|
||||
"""Walk the rules directory and parse all TOML files.
|
||||
|
||||
Returns a flat list of dicts, one per (rule, technique) combination.
|
||||
"""
|
||||
# Assign results = []
|
||||
results: list[dict] = []
|
||||
# Assign toml_files = sorted(rules_dir.rglob("*.toml"))
|
||||
toml_files = sorted(rules_dir.rglob("*.toml"))
|
||||
# Log info: "Found %d TOML files to parse", len(toml_files
|
||||
logger.info("Found %d TOML files to parse", len(toml_files))
|
||||
|
||||
# Iterate over toml_files
|
||||
for toml_path in toml_files:
|
||||
# Assign data = _parse_toml_safe(toml_path)
|
||||
data = _parse_toml_safe(toml_path)
|
||||
# Check: not data
|
||||
if not data:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign rule = data.get("rule", {})
|
||||
rule = data.get("rule", {})
|
||||
# Check: not isinstance(rule, dict)
|
||||
if not isinstance(rule, dict):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign name = rule.get("name", "").strip()
|
||||
name = rule.get("name", "").strip()
|
||||
# Check: not name
|
||||
if not name:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Extract MITRE technique IDs
|
||||
threat_list = rule.get("threat", [])
|
||||
# Assign technique_ids = _extract_mitre_techniques(threat_list)
|
||||
technique_ids = _extract_mitre_techniques(threat_list)
|
||||
# Check: not technique_ids
|
||||
if not technique_ids:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign description = rule.get("description", "")
|
||||
description = rule.get("description", "")
|
||||
# Assign query = rule.get("query", "")
|
||||
query = rule.get("query", "")
|
||||
# Assign severity = _SEVERITY_MAP.get(str(rule.get("severity", "")).lower())
|
||||
severity = _SEVERITY_MAP.get(str(rule.get("severity", "")).lower())
|
||||
# Assign rule_type = rule.get("type", "query") # query, eql, threshold, etc.
|
||||
rule_type = rule.get("type", "query") # query, eql, threshold, etc.
|
||||
|
||||
# Determine rule format based on type
|
||||
if rule_type == "eql":
|
||||
# Assign rule_format = "eql"
|
||||
rule_format = "eql"
|
||||
# Alternative: rule_type == "esql"
|
||||
elif rule_type == "esql":
|
||||
# Assign rule_format = "esql"
|
||||
rule_format = "esql"
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign rule_format = "kql"
|
||||
rule_format = "kql"
|
||||
|
||||
# Use filename as source_id
|
||||
@@ -239,51 +350,79 @@ def _parse_elastic_rules(rules_dir: Path) -> list[dict]:
|
||||
|
||||
# Read raw content
|
||||
try:
|
||||
# Open context manager
|
||||
with open(toml_path, "r", encoding="utf-8") as fh:
|
||||
# Assign raw_content = fh.read()
|
||||
raw_content = fh.read()
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Assign raw_content = query or str(data)
|
||||
raw_content = query or str(data)
|
||||
|
||||
# Build source URL
|
||||
relative = str(toml_path.relative_to(rules_dir.parent)).replace("\\", "/")
|
||||
# Assign source_url = (
|
||||
source_url = (
|
||||
f"https://github.com/elastic/detection-rules/blob/main/{relative}"
|
||||
)
|
||||
|
||||
# One entry per technique
|
||||
for tech_id in technique_ids:
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"mitre_technique_id": tech_id,
|
||||
# Literal argument value
|
||||
"title": name[:500],
|
||||
# Literal argument value
|
||||
"description": str(description)[:2000] if description else None,
|
||||
# Literal argument value
|
||||
"source_id": source_id,
|
||||
# Literal argument value
|
||||
"source_url": source_url,
|
||||
# Literal argument value
|
||||
"rule_content": query[:50000] if query else raw_content[:50000],
|
||||
# Literal argument value
|
||||
"rule_format": rule_format,
|
||||
# Literal argument value
|
||||
"severity": severity,
|
||||
# Literal argument value
|
||||
"platforms": _infer_platforms(rules_dir, toml_path),
|
||||
})
|
||||
|
||||
# Log info: "Parsed %d (rule, technique) pairs total", len(res
|
||||
logger.info("Parsed %d (rule, technique) pairs total", len(results))
|
||||
# Return results
|
||||
return results
|
||||
|
||||
|
||||
# Define function _infer_platforms
|
||||
def _infer_platforms(rules_dir: Path, toml_path: Path) -> list[str] | None:
|
||||
"""Infer platforms from the rule's directory structure.
|
||||
|
||||
Elastic organizes rules by OS: rules/windows/, rules/linux/, etc.
|
||||
"""
|
||||
# Assign relative = toml_path.relative_to(rules_dir)
|
||||
relative = toml_path.relative_to(rules_dir)
|
||||
# Assign parts = [p.lower() for p in relative.parts]
|
||||
parts = [p.lower() for p in relative.parts]
|
||||
|
||||
# Assign platforms = []
|
||||
platforms = []
|
||||
# Check: "windows" in parts
|
||||
if "windows" in parts:
|
||||
# Call platforms.append()
|
||||
platforms.append("windows")
|
||||
# Check: "linux" in parts
|
||||
if "linux" in parts:
|
||||
# Call platforms.append()
|
||||
platforms.append("linux")
|
||||
# Check: "macos" in parts
|
||||
if "macos" in parts:
|
||||
# Call platforms.append()
|
||||
platforms.append("macos")
|
||||
|
||||
# Return platforms if platforms else None
|
||||
return platforms if platforms else None
|
||||
|
||||
|
||||
@@ -297,47 +436,78 @@ def sync(db: Session) -> dict:
|
||||
|
||||
Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``.
|
||||
"""
|
||||
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_elastic_")
|
||||
tmp_dir = tempfile.mkdtemp(prefix="aegis_elastic_")
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign zip_bytes = _download_zip()
|
||||
zip_bytes = _download_zip()
|
||||
# Assign rules_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
rules_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
# Assign parsed_rules = _parse_elastic_rules(rules_dir)
|
||||
parsed_rules = _parse_elastic_rules(rules_dir)
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Call shutil.rmtree()
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
# Log info: "Cleaned up temp directory %s", tmp_dir
|
||||
logger.info("Cleaned up temp directory %s", tmp_dir)
|
||||
|
||||
# Pre-load existing source_ids for dedup
|
||||
existing_ids: set[str] = {
|
||||
row[0]
|
||||
for row in db.query(DetectionRule.source_id)
|
||||
# Chain .filter() call
|
||||
.filter(DetectionRule.source == "elastic")
|
||||
# Chain .filter() call
|
||||
.filter(DetectionRule.source_id.isnot(None))
|
||||
# Chain .all() call
|
||||
.all()
|
||||
}
|
||||
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
# Iterate over parsed_rules
|
||||
for item in parsed_rules:
|
||||
# Check: item["source_id"] in existing_ids
|
||||
if item["source_id"] in existing_ids:
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign rule = DetectionRule(
|
||||
rule = DetectionRule(
|
||||
# Keyword argument: mitre_technique_id
|
||||
mitre_technique_id=item["mitre_technique_id"],
|
||||
# Keyword argument: title
|
||||
title=item["title"],
|
||||
# Keyword argument: description
|
||||
description=item["description"],
|
||||
# Keyword argument: source
|
||||
source="elastic",
|
||||
# Keyword argument: source_id
|
||||
source_id=item["source_id"],
|
||||
# Keyword argument: source_url
|
||||
source_url=item["source_url"],
|
||||
# Keyword argument: rule_content
|
||||
rule_content=item["rule_content"],
|
||||
# Keyword argument: rule_format
|
||||
rule_format=item["rule_format"],
|
||||
# Keyword argument: severity
|
||||
severity=item["severity"],
|
||||
# Keyword argument: platforms
|
||||
platforms=item["platforms"],
|
||||
# Keyword argument: is_active
|
||||
is_active=True,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(rule)
|
||||
# Call existing_ids.add()
|
||||
existing_ids.add(item["source_id"])
|
||||
new_technique_ids.add(item["mitre_technique_id"])
|
||||
created += 1
|
||||
@@ -350,22 +520,36 @@ def sync(db: Session) -> dict:
|
||||
|
||||
db.commit()
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"created": created,
|
||||
# Literal argument value
|
||||
"skipped_existing": skipped,
|
||||
# Literal argument value
|
||||
"total_parsed": len(parsed_rules),
|
||||
}
|
||||
|
||||
# Update DataSource record
|
||||
ds = db.query(DataSource).filter(DataSource.name == "elastic_rules").first()
|
||||
# Check: ds
|
||||
if ds:
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_status = "success"
|
||||
ds.last_sync_status = "success"
|
||||
# Assign ds.last_sync_stats = summary
|
||||
ds.last_sync_stats = summary
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Log info: "Elastic import complete — %s", summary
|
||||
logger.info("Elastic import complete — %s", summary)
|
||||
# Call log_action()
|
||||
log_action(db, user_id=None, action="import_elastic_rules",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="detection_rule", entity_id=None, details=summary)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
@@ -5,20 +5,32 @@ The router is responsible for HTTP concerns, file I/O, MinIO upload,
|
||||
audit logging, and response formatting.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import os
|
||||
import os
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import from app.domain.errors
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
EntityNotFoundError,
|
||||
PermissionViolation,
|
||||
)
|
||||
|
||||
# Import TeamSide, TestState from app.models.enums
|
||||
from app.models.enums import TeamSide, TestState
|
||||
|
||||
# Import Evidence from app.models.evidence
|
||||
from app.models.evidence import Evidence
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# States where red evidence can be uploaded / deleted
|
||||
@@ -31,19 +43,30 @@ MAX_UPLOAD_SIZE = 50 * 1024 * 1024
|
||||
|
||||
# Allowed file extensions (lowercase, with leading dot)
|
||||
ALLOWED_EXTENSIONS: frozenset[str] = frozenset({
|
||||
# Literal argument value
|
||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
|
||||
# Literal argument value
|
||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
|
||||
# Literal argument value
|
||||
".md", ".rtf", ".odt", ".ods",
|
||||
# Literal argument value
|
||||
".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
|
||||
# Literal argument value
|
||||
".yaml", ".yml", ".toml",
|
||||
# Literal argument value
|
||||
".zip", ".tar", ".gz", ".7z",
|
||||
# Literal argument value
|
||||
".har", ".eml", ".msg",
|
||||
})
|
||||
|
||||
|
||||
# Define function validate_upload_permission
|
||||
def validate_upload_permission(
|
||||
# Entry: test
|
||||
test: Test,
|
||||
# Entry: team
|
||||
team: TeamSide,
|
||||
# Entry: user_role
|
||||
user_role: str,
|
||||
) -> None:
|
||||
"""Validate that the user can upload evidence for the given team in the current state.
|
||||
@@ -52,35 +75,56 @@ def validate_upload_permission(
|
||||
PermissionViolation: If user lacks role to upload for this team.
|
||||
BusinessRuleViolation: If test state does not allow uploading for this team.
|
||||
"""
|
||||
# Check: user_role == "admin"
|
||||
if user_role == "admin":
|
||||
# Return control to caller
|
||||
return
|
||||
|
||||
# Check: team == TeamSide.red
|
||||
if team == TeamSide.red:
|
||||
# Check: user_role not in ("red_tech", "red_lead")
|
||||
if user_role not in ("red_tech", "red_lead"):
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation(
|
||||
# Literal argument value
|
||||
"Only red_tech, red_lead or admin can upload red evidence"
|
||||
)
|
||||
# Check: test.state not in RED_EDITABLE_STATES
|
||||
if test.state not in RED_EDITABLE_STATES:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot upload red evidence in '{test.state.value}' state "
|
||||
# Literal argument value
|
||||
"(allowed in: draft, red_executing)"
|
||||
)
|
||||
# Alternative: team == TeamSide.blue
|
||||
elif team == TeamSide.blue:
|
||||
# Check: user_role not in ("blue_tech", "blue_lead")
|
||||
if user_role not in ("blue_tech", "blue_lead"):
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation(
|
||||
# Literal argument value
|
||||
"Only blue_tech, blue_lead or admin can upload blue evidence"
|
||||
)
|
||||
# Check: test.state not in BLUE_EDITABLE_STATES
|
||||
if test.state not in BLUE_EDITABLE_STATES:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot upload blue evidence in '{test.state.value}' state "
|
||||
# Literal argument value
|
||||
"(allowed in: blue_evaluating)"
|
||||
)
|
||||
|
||||
|
||||
# Define function validate_delete_permission
|
||||
def validate_delete_permission(
|
||||
# Entry: test
|
||||
test: Test,
|
||||
# Entry: evidence
|
||||
evidence: Evidence,
|
||||
# Entry: user_role
|
||||
user_role: str,
|
||||
# Entry: user_id
|
||||
user_id: uuid.UUID,
|
||||
) -> None:
|
||||
"""Validate that the user can delete this evidence in the current state.
|
||||
@@ -88,80 +132,125 @@ def validate_delete_permission(
|
||||
Raises:
|
||||
PermissionViolation: If user cannot delete in this state or lacks permission.
|
||||
"""
|
||||
# Check: test.state in (TestState.in_review, TestState.validated, TestState....
|
||||
if test.state in (TestState.in_review, TestState.validated, TestState.rejected):
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation(
|
||||
f"Cannot delete evidence when test is in '{test.state.value}' state"
|
||||
)
|
||||
|
||||
# Check: user_role == "admin"
|
||||
if user_role == "admin":
|
||||
# Return control to caller
|
||||
return
|
||||
|
||||
# Assign ev_team = evidence.team
|
||||
ev_team = evidence.team
|
||||
|
||||
# Check: ev_team == TeamSide.red
|
||||
if ev_team == TeamSide.red:
|
||||
# Check: test.state not in RED_EDITABLE_STATES
|
||||
if test.state not in RED_EDITABLE_STATES:
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation(
|
||||
# Literal argument value
|
||||
"Cannot delete red evidence outside draft/red_executing"
|
||||
)
|
||||
# Check: user_role not in ("red_tech", "red_lead") and evidence.uploaded_by ...
|
||||
if user_role not in ("red_tech", "red_lead") and evidence.uploaded_by != user_id:
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation(
|
||||
# Literal argument value
|
||||
"Not enough permissions to delete this evidence"
|
||||
)
|
||||
# Alternative: ev_team == TeamSide.blue
|
||||
elif ev_team == TeamSide.blue:
|
||||
# Check: test.state not in BLUE_EDITABLE_STATES
|
||||
if test.state not in BLUE_EDITABLE_STATES:
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation(
|
||||
# Literal argument value
|
||||
"Cannot delete blue evidence outside blue_evaluating"
|
||||
)
|
||||
# Check: user_role not in ("blue_tech", "blue_lead") and evidence.uploaded_b...
|
||||
if user_role not in ("blue_tech", "blue_lead") and evidence.uploaded_by != user_id:
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation(
|
||||
# Literal argument value
|
||||
"Not enough permissions to delete this evidence"
|
||||
)
|
||||
|
||||
|
||||
# Define function validate_file
|
||||
def validate_file(file_name: str, content_size: int) -> None:
|
||||
"""Validate file extension and size.
|
||||
|
||||
Raises:
|
||||
BusinessRuleViolation: If extension is not allowed or file exceeds size limit.
|
||||
"""
|
||||
# _, ext = os.path.splitext(file_name)
|
||||
_, ext = os.path.splitext(file_name)
|
||||
# Assign ext_lower = ext.lower() if ext else ""
|
||||
ext_lower = ext.lower() if ext else ""
|
||||
# Check: ext_lower not in ALLOWED_EXTENSIONS
|
||||
if ext_lower not in ALLOWED_EXTENSIONS:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"File type '{ext}' is not allowed. "
|
||||
f"Permitted types: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
|
||||
)
|
||||
# Check: content_size > MAX_UPLOAD_SIZE
|
||||
if content_size > MAX_UPLOAD_SIZE:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"File exceeds maximum upload size of {MAX_UPLOAD_SIZE // (1024 * 1024)} MB"
|
||||
)
|
||||
|
||||
|
||||
# Define function list_evidence_for_test
|
||||
def list_evidence_for_test(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: test_id
|
||||
test_id: uuid.UUID,
|
||||
*,
|
||||
# Entry: team
|
||||
team: TeamSide | str | None = None,
|
||||
) -> list[Evidence]:
|
||||
"""Return evidence for a test, optionally filtered by team."""
|
||||
# Assign query = db.query(Evidence).filter(Evidence.test_id == test_id)
|
||||
query = db.query(Evidence).filter(Evidence.test_id == test_id)
|
||||
# Check: team is not None
|
||||
if team is not None:
|
||||
# Assign team_enum = TeamSide(team) if isinstance(team, str) else team
|
||||
team_enum = TeamSide(team) if isinstance(team, str) else team
|
||||
# Assign query = query.filter(Evidence.team == team_enum)
|
||||
query = query.filter(Evidence.team == team_enum)
|
||||
# Return query.order_by(Evidence.uploaded_at.desc()).all()
|
||||
return query.order_by(Evidence.uploaded_at.desc()).all()
|
||||
|
||||
|
||||
# Define function get_evidence_or_raise
|
||||
def get_evidence_or_raise(db: Session, evidence_id: uuid.UUID) -> Evidence:
|
||||
"""Fetch evidence by ID. Raises EntityNotFoundError if not found."""
|
||||
# Assign evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
||||
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
||||
# Check: evidence is None
|
||||
if evidence is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Evidence", str(evidence_id))
|
||||
# Return evidence
|
||||
return evidence
|
||||
|
||||
|
||||
# Define function get_test_or_raise
|
||||
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch test by ID. Raises EntityNotFoundError if not found."""
|
||||
# Assign test = db.query(Test).filter(Test.id == test_id).first()
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
# Check: test is None
|
||||
if test is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
# Return test
|
||||
return test
|
||||
|
||||
@@ -7,32 +7,59 @@ This module is framework-agnostic: no FastAPI imports, no HTTPException,
|
||||
no ``db.commit()``.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import json
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
# Import Callable from collections.abc
|
||||
from collections.abc import Callable
|
||||
|
||||
# Import func, or_ from sqlalchemy
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import Query, Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Query, Session
|
||||
|
||||
# Import BusinessRuleViolation, EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
|
||||
|
||||
# Import Campaign, CampaignTest from app.models.campaign
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
|
||||
# Import DetectionRule from app.models.detection_rule
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.defensive_technique import DefensiveTechniqueMapping
|
||||
|
||||
# Import TechniqueStatus, TestState from app.models.enums
|
||||
from app.models.enums import TechniqueStatus, TestState
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import TestDetectionResult from app.models.test_detection_result
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
|
||||
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
|
||||
# Import escape_like from app.utils
|
||||
from app.utils import escape_like
|
||||
|
||||
# ── Constants ─────────────────────────────────────────────────────────
|
||||
|
||||
ATTACK_VERSION = "15"
|
||||
# Assign NAVIGATOR_VERSION = "5.0"
|
||||
NAVIGATOR_VERSION = "5.0"
|
||||
# Assign LAYER_VERSION = "4.5"
|
||||
LAYER_VERSION = "4.5"
|
||||
# Assign DOMAIN = "enterprise-attack"
|
||||
DOMAIN = "enterprise-attack"
|
||||
|
||||
# Assign STATUS_SCORE_MAP = {
|
||||
STATUS_SCORE_MAP: dict[TechniqueStatus, int] = {
|
||||
TechniqueStatus.validated: 100,
|
||||
TechniqueStatus.partial: 60,
|
||||
@@ -42,6 +69,7 @@ STATUS_SCORE_MAP: dict[TechniqueStatus, int] = {
|
||||
TechniqueStatus.review_required: 10,
|
||||
}
|
||||
|
||||
# Assign TEST_STATE_SCORE = {
|
||||
TEST_STATE_SCORE: dict[TestState, int] = {
|
||||
TestState.validated: 100,
|
||||
TestState.in_review: 70,
|
||||
@@ -56,74 +84,169 @@ TEST_STATE_SCORE: dict[TestState, int] = {
|
||||
|
||||
|
||||
def _score_to_color(score: int) -> str:
|
||||
"""Map a 0-100 score to a red-yellow-green colour hex."""
|
||||
"""Map a 0-100 score to a red-yellow-green colour hex.
|
||||
|
||||
Args:
|
||||
score (int): Coverage score between 0 and 100 inclusive.
|
||||
|
||||
Returns:
|
||||
str: Hex colour string representing the score tier.
|
||||
"""
|
||||
# Check: score <= 0
|
||||
if score <= 0:
|
||||
# Return "#d3d3d3"
|
||||
return "#d3d3d3"
|
||||
# Check: score <= 25
|
||||
if score <= 25:
|
||||
# Return "#ff6666"
|
||||
return "#ff6666"
|
||||
# Check: score <= 50
|
||||
if score <= 50:
|
||||
# Return "#ff9933"
|
||||
return "#ff9933"
|
||||
# Check: score <= 75
|
||||
if score <= 75:
|
||||
# Return "#ffff66"
|
||||
return "#ffff66"
|
||||
# Return "#66ff66"
|
||||
return "#66ff66"
|
||||
|
||||
|
||||
# Define function _build_layer_skeleton
|
||||
def _build_layer_skeleton(
|
||||
# Entry: name
|
||||
name: str,
|
||||
# Entry: description
|
||||
description: str,
|
||||
# Entry: gradient_colors
|
||||
gradient_colors: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Return a base layer dict compatible with ATT&CK Navigator."""
|
||||
"""Return a base layer dict compatible with ATT&CK Navigator.
|
||||
|
||||
Args:
|
||||
name (str): Human-readable name for the layer.
|
||||
description (str): Description text embedded in the layer metadata.
|
||||
gradient_colors (list[str] | None): Optional list of hex colour stops
|
||||
for the gradient; defaults to red-yellow-green if omitted.
|
||||
|
||||
Returns:
|
||||
dict: Skeleton layer dictionary with versions, domain, and empty
|
||||
techniques list.
|
||||
"""
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"name": name,
|
||||
# Literal argument value
|
||||
"versions": {
|
||||
# Literal argument value
|
||||
"attack": ATTACK_VERSION,
|
||||
# Literal argument value
|
||||
"navigator": NAVIGATOR_VERSION,
|
||||
# Literal argument value
|
||||
"layer": LAYER_VERSION,
|
||||
},
|
||||
# Literal argument value
|
||||
"domain": DOMAIN,
|
||||
# Literal argument value
|
||||
"description": description,
|
||||
# Literal argument value
|
||||
"filters": {"platforms": ["windows", "linux", "macos"]},
|
||||
# Literal argument value
|
||||
"gradient": {
|
||||
# Literal argument value
|
||||
"colors": gradient_colors or ["#ff6666", "#ffff66", "#66ff66"],
|
||||
# Literal argument value
|
||||
"minValue": 0,
|
||||
# Literal argument value
|
||||
"maxValue": 100,
|
||||
},
|
||||
# Literal argument value
|
||||
"techniques": [],
|
||||
}
|
||||
|
||||
|
||||
# Define function _apply_filters
|
||||
def _apply_filters(
|
||||
query,
|
||||
model,
|
||||
# Entry: query
|
||||
query: Query, # type: ignore[type-arg]
|
||||
# Entry: model
|
||||
model: type,
|
||||
# Entry: platforms
|
||||
platforms: list[str] | None = None,
|
||||
# Entry: tactics
|
||||
tactics: list[str] | None = None,
|
||||
):
|
||||
"""Apply common platform and tactic filters to a technique query."""
|
||||
) -> Query: # type: ignore[type-arg]
|
||||
"""Apply common platform and tactic filters to a technique query.
|
||||
|
||||
Args:
|
||||
query (Query): Base SQLAlchemy query targeting a technique-like model.
|
||||
model (type): The SQLAlchemy model class that owns ``platforms`` and
|
||||
``tactic`` columns.
|
||||
platforms (list[str] | None): Optional list of platform names to
|
||||
filter by (OR-joined).
|
||||
tactics (list[str] | None): Optional list of tactic strings to
|
||||
filter by (OR-joined, case-insensitive substring match).
|
||||
|
||||
Returns:
|
||||
Query: The query with platform and tactic filters applied.
|
||||
"""
|
||||
# Check: platforms
|
||||
if platforms:
|
||||
# Assign platform_filters = [
|
||||
platform_filters = [
|
||||
model.platforms.op("@>")(json.dumps([p])) for p in platforms
|
||||
]
|
||||
# Assign query = query.filter(or_(*platform_filters))
|
||||
query = query.filter(or_(*platform_filters))
|
||||
# Check: tactics
|
||||
if tactics:
|
||||
# Assign tactic_filters = [
|
||||
tactic_filters = [
|
||||
model.tactic.ilike(f"%{escape_like(t)}%") for t in tactics
|
||||
]
|
||||
# Assign query = query.filter(or_(*tactic_filters))
|
||||
query = query.filter(or_(*tactic_filters))
|
||||
# Return query
|
||||
return query
|
||||
|
||||
|
||||
# Define function _format_tactic
|
||||
def _format_tactic(tactic_str: str | None) -> str:
|
||||
"""Normalize tactic string to ATT&CK Navigator format (kebab-case)."""
|
||||
"""Normalize tactic string to ATT&CK Navigator format (kebab-case).
|
||||
|
||||
Args:
|
||||
tactic_str (str | None): Raw tactic string, possibly comma-separated
|
||||
or mixed-case.
|
||||
|
||||
Returns:
|
||||
str: First tactic value lowercased and trimmed, or empty string if
|
||||
the input is falsy.
|
||||
"""
|
||||
# Check: not tactic_str
|
||||
if not tactic_str:
|
||||
# Return ""
|
||||
return ""
|
||||
# Return tactic_str.split(",")[0].strip().lower()
|
||||
return tactic_str.split(",")[0].strip().lower()
|
||||
|
||||
|
||||
# Define function _parse_csv
|
||||
def _parse_csv(value: str | None) -> list[str] | None:
|
||||
"""Split a comma-separated string into a trimmed list, or ``None``."""
|
||||
"""Split a comma-separated string into a trimmed list, or ``None``.
|
||||
|
||||
Args:
|
||||
value (str | None): Comma-separated string to split, or ``None``.
|
||||
|
||||
Returns:
|
||||
list[str] | None: Non-empty trimmed tokens, or ``None`` if the input
|
||||
is falsy or produces no tokens.
|
||||
"""
|
||||
# Check: not value
|
||||
if not value:
|
||||
# Return None
|
||||
return None
|
||||
# Return [v.strip() for v in value.split(",") if v.strip()]
|
||||
return [v.strip() for v in value.split(",") if v.strip()]
|
||||
|
||||
|
||||
@@ -131,132 +254,224 @@ def _parse_csv(value: str | None) -> list[str] | None:
|
||||
|
||||
|
||||
def build_coverage_layer(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: platforms
|
||||
platforms: str | None = None,
|
||||
# Entry: tactics
|
||||
tactics: str | None = None,
|
||||
# Entry: min_score
|
||||
min_score: int = 0,
|
||||
) -> dict:
|
||||
"""Coverage layer -- score based on ``status_global`` of each technique."""
|
||||
"""Coverage layer -- score based on ``status_global`` of each technique.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
platforms (str | None): Optional comma-separated platform names to
|
||||
filter techniques.
|
||||
tactics (str | None): Optional comma-separated tactic names to filter
|
||||
techniques.
|
||||
min_score (int): Minimum score threshold; techniques below this are
|
||||
omitted from the layer.
|
||||
|
||||
Returns:
|
||||
dict: ATT&CK Navigator-compatible layer dictionary.
|
||||
"""
|
||||
# Assign layer = _build_layer_skeleton("Aegis Coverage", "Coverage layer generated b...
|
||||
layer = _build_layer_skeleton("Aegis Coverage", "Coverage layer generated by Aegis")
|
||||
|
||||
# Assign query = _apply_filters(
|
||||
query = _apply_filters(
|
||||
db.query(Technique), Technique,
|
||||
_parse_csv(platforms), _parse_csv(tactics),
|
||||
)
|
||||
# Assign techniques = query.all()
|
||||
techniques = query.all()
|
||||
|
||||
# Bulk-fetch test counts and rule counts to avoid N+1
|
||||
tech_ids = [t.id for t in techniques]
|
||||
# Assign mitre_ids = [t.mitre_id for t in techniques]
|
||||
mitre_ids = [t.mitre_id for t in techniques]
|
||||
|
||||
# Assign test_counts = dict(
|
||||
test_counts = dict(
|
||||
db.query(Test.technique_id, func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id.in_(tech_ids), Test.state == TestState.validated)
|
||||
# Chain .group_by() call
|
||||
.group_by(Test.technique_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
) if tech_ids else {}
|
||||
|
||||
# Assign rule_counts = dict(
|
||||
rule_counts = dict(
|
||||
db.query(DetectionRule.mitre_technique_id, func.count(DetectionRule.id))
|
||||
# Chain .filter() call
|
||||
.filter(DetectionRule.mitre_technique_id.in_(mitre_ids))
|
||||
# Chain .group_by() call
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
) if mitre_ids else {}
|
||||
|
||||
# Iterate over techniques
|
||||
for tech in techniques:
|
||||
# Assign score = STATUS_SCORE_MAP.get(tech.status_global, 0)
|
||||
score = STATUS_SCORE_MAP.get(tech.status_global, 0)
|
||||
# Check: score < min_score
|
||||
if score < min_score:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign tc = test_counts.get(tech.id, 0)
|
||||
tc = test_counts.get(tech.id, 0)
|
||||
# Assign rc = rule_counts.get(tech.mitre_id, 0)
|
||||
rc = rule_counts.get(tech.mitre_id, 0)
|
||||
|
||||
# Assign metadata = [
|
||||
metadata = [
|
||||
{"name": "tests_count", "value": str(tc)},
|
||||
{"name": "detection_rules", "value": str(rc)},
|
||||
]
|
||||
# Check: tech.last_review_date
|
||||
if tech.last_review_date:
|
||||
# Call metadata.append()
|
||||
metadata.append(
|
||||
{"name": "last_validated", "value": tech.last_review_date.strftime("%Y-%m-%d")}
|
||||
)
|
||||
|
||||
# Assign comment_parts = [
|
||||
comment_parts = [
|
||||
f"Status: {tech.status_global.value}",
|
||||
f"{tc} tests validated",
|
||||
f"{rc} detection rules",
|
||||
]
|
||||
|
||||
# layer["techniques"].append({
|
||||
layer["techniques"].append({
|
||||
# Literal argument value
|
||||
"techniqueID": tech.mitre_id,
|
||||
# Literal argument value
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
# Literal argument value
|
||||
"color": _score_to_color(score),
|
||||
# Literal argument value
|
||||
"score": score,
|
||||
# Literal argument value
|
||||
"comment": " - ".join(comment_parts),
|
||||
# Literal argument value
|
||||
"enabled": True,
|
||||
# Literal argument value
|
||||
"metadata": metadata,
|
||||
})
|
||||
|
||||
# Return layer
|
||||
return layer
|
||||
|
||||
|
||||
# Define function build_threat_actor_layer
|
||||
def build_threat_actor_layer(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: actor_id
|
||||
actor_id: str,
|
||||
*,
|
||||
# Entry: platforms
|
||||
platforms: str | None = None,
|
||||
# Entry: tactics
|
||||
tactics: str | None = None,
|
||||
# Entry: min_score
|
||||
min_score: int = 0,
|
||||
) -> dict:
|
||||
"""Threat actor layer -- techniques used by an actor with coverage colour.
|
||||
|
||||
Raises :class:`EntityNotFoundError` if the actor does not exist.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
actor_id (str): UUID string identifying the threat actor.
|
||||
platforms (str | None): Optional comma-separated platform names to
|
||||
filter techniques.
|
||||
tactics (str | None): Optional comma-separated tactic names to filter
|
||||
techniques.
|
||||
min_score (int): Minimum score threshold for actor techniques.
|
||||
|
||||
Returns:
|
||||
dict: ATT&CK Navigator-compatible layer dictionary coloured by
|
||||
coverage status for the specified actor.
|
||||
"""
|
||||
# Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
# Check: not actor
|
||||
if not actor:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("ThreatActor", actor_id)
|
||||
|
||||
# Assign layer = _build_layer_skeleton(
|
||||
layer = _build_layer_skeleton(
|
||||
f"Threat Actor: {actor.name}",
|
||||
f"Techniques used by {actor.name} with coverage overlay",
|
||||
# Keyword argument: gradient_colors
|
||||
gradient_colors=["#808080", "#ff6666", "#66ff66"],
|
||||
)
|
||||
|
||||
# Assign actor_technique_ids = {
|
||||
actor_technique_ids = {
|
||||
row.technique_id
|
||||
for row in db.query(ThreatActorTechnique.technique_id)
|
||||
# Chain .filter() call
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
}
|
||||
# Check: not actor_technique_ids
|
||||
if not actor_technique_ids:
|
||||
# Return layer
|
||||
return layer
|
||||
|
||||
# Assign query = _apply_filters(
|
||||
query = _apply_filters(
|
||||
db.query(Technique), Technique,
|
||||
_parse_csv(platforms), _parse_csv(tactics),
|
||||
)
|
||||
# Assign techniques = query.all()
|
||||
techniques = query.all()
|
||||
|
||||
# Bulk-fetch metadata for actor techniques only
|
||||
test_counts = dict(
|
||||
db.query(Test.technique_id, func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id.in_(actor_technique_ids), Test.state == TestState.validated)
|
||||
# Chain .group_by() call
|
||||
.group_by(Test.technique_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign actor_mitre_ids = [t.mitre_id for t in techniques if t.id in actor_technique_ids]
|
||||
actor_mitre_ids = [t.mitre_id for t in techniques if t.id in actor_technique_ids]
|
||||
# Assign rule_counts = dict(
|
||||
rule_counts = dict(
|
||||
db.query(DetectionRule.mitre_technique_id, func.count(DetectionRule.id))
|
||||
# Chain .filter() call
|
||||
.filter(DetectionRule.mitre_technique_id.in_(actor_mitre_ids))
|
||||
# Chain .group_by() call
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
) if actor_mitre_ids else {}
|
||||
|
||||
# Iterate over techniques
|
||||
for tech in techniques:
|
||||
# Assign is_actor_technique = tech.id in actor_technique_ids
|
||||
is_actor_technique = tech.id in actor_technique_ids
|
||||
# Assign score = STATUS_SCORE_MAP.get(tech.status_global, 0) if is_actor_technique e...
|
||||
score = STATUS_SCORE_MAP.get(tech.status_global, 0) if is_actor_technique else 0
|
||||
|
||||
# Check: is_actor_technique and score < min_score
|
||||
if is_actor_technique and score < min_score:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Only include techniques actually used by this actor — skip the rest
|
||||
@@ -284,14 +499,20 @@ def build_threat_actor_layer(
|
||||
"metadata": metadata,
|
||||
})
|
||||
|
||||
# Return layer
|
||||
return layer
|
||||
|
||||
|
||||
# Define function build_detection_rules_layer
|
||||
def build_detection_rules_layer(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: platforms
|
||||
platforms: str | None = None,
|
||||
# Entry: tactics
|
||||
tactics: str | None = None,
|
||||
# Entry: min_score
|
||||
min_score: int = 0,
|
||||
) -> dict:
|
||||
"""Detection rules layer -- score based on absolute rule count per technique.
|
||||
@@ -305,28 +526,40 @@ def build_detection_rules_layer(
|
||||
4+ rules → green (score 100)
|
||||
"""
|
||||
layer = _build_layer_skeleton(
|
||||
# Literal argument value
|
||||
"Detection Rules Coverage",
|
||||
"Number of active detection rules per technique",
|
||||
)
|
||||
|
||||
# Assign query = _apply_filters(
|
||||
query = _apply_filters(
|
||||
db.query(Technique), Technique,
|
||||
_parse_csv(platforms), _parse_csv(tactics),
|
||||
)
|
||||
# Assign techniques = query.all()
|
||||
techniques = query.all()
|
||||
|
||||
# Assign rule_counts = dict(
|
||||
rule_counts = dict(
|
||||
db.query(DetectionRule.mitre_technique_id, func.count(DetectionRule.id))
|
||||
# Chain .filter() call
|
||||
.filter(DetectionRule.is_active == True) # noqa: E712
|
||||
# Chain .group_by() call
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign evaluated_counts = dict(
|
||||
evaluated_counts = dict(
|
||||
db.query(DetectionRule.mitre_technique_id, func.count(TestDetectionResult.id))
|
||||
# Chain .join() call
|
||||
.join(TestDetectionResult, TestDetectionResult.detection_rule_id == DetectionRule.id)
|
||||
# Chain .filter() call
|
||||
.filter(TestDetectionResult.triggered.isnot(None))
|
||||
# Chain .group_by() call
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -334,12 +567,16 @@ def build_detection_rules_layer(
|
||||
RULES_FOR_FULL_COVERAGE = 4
|
||||
|
||||
for tech in techniques:
|
||||
# Assign total_rules = rule_counts.get(tech.mitre_id, 0)
|
||||
total_rules = rule_counts.get(tech.mitre_id, 0)
|
||||
# Assign evaluated_rules = evaluated_counts.get(tech.mitre_id, 0)
|
||||
evaluated_rules = evaluated_counts.get(tech.mitre_id, 0)
|
||||
|
||||
score = min(int((total_rules / RULES_FOR_FULL_COVERAGE) * 100), 100)
|
||||
|
||||
# Check: score < min_score
|
||||
if score < min_score:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
rule_word = "rule" if total_rules == 1 else "rules"
|
||||
@@ -347,113 +584,194 @@ def build_detection_rules_layer(
|
||||
comment = f"{total_rules} active {rule_word}{eval_note}"
|
||||
|
||||
layer["techniques"].append({
|
||||
# Literal argument value
|
||||
"techniqueID": tech.mitre_id,
|
||||
# Literal argument value
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
# Literal argument value
|
||||
"color": _score_to_color(score),
|
||||
# Literal argument value
|
||||
"score": score,
|
||||
"comment": comment,
|
||||
"enabled": True,
|
||||
# Literal argument value
|
||||
"metadata": [
|
||||
{"name": "total_rules", "value": str(total_rules)},
|
||||
{"name": "evaluated_rules", "value": str(evaluated_rules)},
|
||||
],
|
||||
})
|
||||
|
||||
# Return layer
|
||||
return layer
|
||||
|
||||
|
||||
# Define function build_campaign_layer
|
||||
def build_campaign_layer(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
*,
|
||||
# Entry: platforms
|
||||
platforms: str | None = None,
|
||||
# Entry: tactics
|
||||
tactics: str | None = None,
|
||||
# Entry: min_score
|
||||
min_score: int = 0,
|
||||
) -> dict:
|
||||
"""Campaign layer -- techniques in a campaign, coloured by test state.
|
||||
|
||||
Raises :class:`EntityNotFoundError` if the campaign does not exist.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
campaign_id (str): UUID string identifying the campaign.
|
||||
platforms (str | None): Optional comma-separated platform names to
|
||||
filter techniques.
|
||||
tactics (str | None): Optional comma-separated tactic names to filter
|
||||
techniques.
|
||||
min_score (int): Minimum score threshold for techniques in the layer.
|
||||
|
||||
Returns:
|
||||
dict: ATT&CK Navigator-compatible layer dictionary where each
|
||||
technique colour reflects the best test state within the campaign.
|
||||
"""
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Assign layer = _build_layer_skeleton(
|
||||
layer = _build_layer_skeleton(
|
||||
f"Campaign: {campaign.name}",
|
||||
f"Progress of campaign '{campaign.name}'",
|
||||
)
|
||||
|
||||
# Assign campaign_tests = (
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest)
|
||||
# Chain .filter() call
|
||||
.filter(CampaignTest.campaign_id == campaign.id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Check: not campaign_tests
|
||||
if not campaign_tests:
|
||||
# Return layer
|
||||
return layer
|
||||
|
||||
# Assign test_ids = [ct.test_id for ct in campaign_tests]
|
||||
test_ids = [ct.test_id for ct in campaign_tests]
|
||||
# Assign tests = db.query(Test).filter(Test.id.in_(test_ids)).all()
|
||||
tests = db.query(Test).filter(Test.id.in_(test_ids)).all()
|
||||
# Assign test_map = {t.id: t for t in tests}
|
||||
test_map = {t.id: t for t in tests}
|
||||
|
||||
# Assign technique_ids = {t.technique_id for t in tests if t.technique_id}
|
||||
technique_ids = {t.technique_id for t in tests if t.technique_id}
|
||||
# Assign techniques = db.query(Technique).filter(Technique.id.in_(technique_ids)).all()
|
||||
techniques = db.query(Technique).filter(Technique.id.in_(technique_ids)).all()
|
||||
# Assign tech_map = {t.id: t for t in techniques}
|
||||
tech_map = {t.id: t for t in techniques}
|
||||
|
||||
# Group tests by technique, keeping the best state score
|
||||
tech_scores: dict = {}
|
||||
# Iterate over campaign_tests
|
||||
for ct in campaign_tests:
|
||||
# Assign test = test_map.get(ct.test_id)
|
||||
test = test_map.get(ct.test_id)
|
||||
# Check: not test
|
||||
if not test:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Assign tech = tech_map.get(test.technique_id)
|
||||
tech = tech_map.get(test.technique_id)
|
||||
# Check: not tech
|
||||
if not tech:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign state_score = TEST_STATE_SCORE.get(test.state, 0)
|
||||
state_score = TEST_STATE_SCORE.get(test.state, 0)
|
||||
# Check: tech.mitre_id not in tech_scores
|
||||
if tech.mitre_id not in tech_scores:
|
||||
# Assign tech_scores[tech.mitre_id] = {
|
||||
tech_scores[tech.mitre_id] = {
|
||||
# Literal argument value
|
||||
"technique": tech,
|
||||
# Literal argument value
|
||||
"max_score": state_score,
|
||||
# Literal argument value
|
||||
"tests": [],
|
||||
}
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign tech_scores[tech.mitre_id]["max_score"] = max(
|
||||
tech_scores[tech.mitre_id]["max_score"] = max(
|
||||
tech_scores[tech.mitre_id]["max_score"], state_score,
|
||||
)
|
||||
# tech_scores[tech.mitre_id]["tests"].append(test)
|
||||
tech_scores[tech.mitre_id]["tests"].append(test)
|
||||
|
||||
# Assign platform_list = _parse_csv(platforms)
|
||||
platform_list = _parse_csv(platforms)
|
||||
# Assign tactic_list = _parse_csv(tactics)
|
||||
tactic_list = _parse_csv(tactics)
|
||||
|
||||
# Iterate over tech_scores.items()
|
||||
for mitre_id, info in tech_scores.items():
|
||||
# Assign tech = info["technique"]
|
||||
tech = info["technique"]
|
||||
# Assign score = info["max_score"]
|
||||
score = info["max_score"]
|
||||
|
||||
# Check: platform_list
|
||||
if platform_list:
|
||||
# Assign tech_platforms = tech.platforms or []
|
||||
tech_platforms = tech.platforms or []
|
||||
# Check: not any(p in tech_platforms for p in platform_list)
|
||||
if not any(p in tech_platforms for p in platform_list):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Check: tactic_list
|
||||
if tactic_list:
|
||||
# Assign tech_tactics = [t.strip() for t in (tech.tactic or "").lower().split(",")]
|
||||
tech_tactics = [t.strip() for t in (tech.tactic or "").lower().split(",")]
|
||||
# Check: not any(t in tech_tactics for t in tactic_list)
|
||||
if not any(t in tech_tactics for t in tactic_list):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Check: score < min_score
|
||||
if score < min_score:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign test_states = [t.state.value for t in info["tests"]]
|
||||
test_states = [t.state.value for t in info["tests"]]
|
||||
# layer["techniques"].append({
|
||||
layer["techniques"].append({
|
||||
# Literal argument value
|
||||
"techniqueID": mitre_id,
|
||||
# Literal argument value
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
# Literal argument value
|
||||
"color": _score_to_color(score),
|
||||
# Literal argument value
|
||||
"score": score,
|
||||
# Literal argument value
|
||||
"comment": f"Campaign tests: {', '.join(test_states)}",
|
||||
# Literal argument value
|
||||
"enabled": True,
|
||||
# Literal argument value
|
||||
"metadata": [
|
||||
{"name": "campaign_tests", "value": str(len(info["tests"]))},
|
||||
{"name": "best_state", "value": max(test_states) if test_states else "none"},
|
||||
],
|
||||
})
|
||||
|
||||
# Return layer
|
||||
return layer
|
||||
|
||||
|
||||
@@ -470,67 +788,143 @@ def build_campaign_layer(
|
||||
class _LayerRegistry:
|
||||
"""Extensible registry that maps layer type names to builder functions."""
|
||||
|
||||
# Assign __slots__ = ("_simple", "_with_id")
|
||||
__slots__ = ("_simple", "_with_id")
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self) -> None:
|
||||
# Assign self._simple = {}
|
||||
self._simple: dict[str, object] = {}
|
||||
# Assign self._with_id = {}
|
||||
self._with_id: dict[str, object] = {}
|
||||
|
||||
def register(self, name: str, builder, *, requires_id: bool = False) -> None:
|
||||
# Define function register
|
||||
def register(self, name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None:
|
||||
"""Register a builder function under *name*.
|
||||
|
||||
Args:
|
||||
name (str): Unique layer type identifier.
|
||||
builder (Callable[..., dict]): Layer builder function.
|
||||
requires_id (bool): Whether the builder needs a positional
|
||||
``layer_id`` argument.
|
||||
"""
|
||||
# Assign target = self._with_id if requires_id else self._simple
|
||||
target = self._with_id if requires_id else self._simple
|
||||
# Assign target[name] = builder
|
||||
target[name] = builder
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function supported_types
|
||||
def supported_types(self) -> set[str]:
|
||||
"""Return the set of all registered layer type names.
|
||||
|
||||
Returns:
|
||||
set[str]: Union of simple and entity-bound layer type names.
|
||||
"""
|
||||
# Return set(self._simple) | set(self._with_id)
|
||||
return set(self._simple) | set(self._with_id)
|
||||
|
||||
# Define function build
|
||||
def build(
|
||||
self,
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: layer_type
|
||||
layer_type: str,
|
||||
*,
|
||||
# Entry: layer_id
|
||||
layer_id: str | None = None,
|
||||
# Entry: platforms
|
||||
platforms: str | None = None,
|
||||
# Entry: tactics
|
||||
tactics: str | None = None,
|
||||
# Entry: min_score
|
||||
min_score: int = 0,
|
||||
) -> dict:
|
||||
"""Dispatch to the registered builder for *layer_type*.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
layer_type (str): Registered layer type name.
|
||||
layer_id (str | None): Entity UUID for entity-bound layer types.
|
||||
platforms (str | None): Optional comma-separated platform filter.
|
||||
tactics (str | None): Optional comma-separated tactic filter.
|
||||
min_score (int): Minimum score threshold.
|
||||
|
||||
Returns:
|
||||
dict: ATT&CK Navigator-compatible layer dictionary.
|
||||
"""
|
||||
# Assign kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score)
|
||||
kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score)
|
||||
|
||||
# Check: layer_type in self._simple
|
||||
if layer_type in self._simple:
|
||||
# Return self._simple[layer_type](db, **kwargs)
|
||||
return self._simple[layer_type](db, **kwargs)
|
||||
|
||||
# Check: layer_type in self._with_id
|
||||
if layer_type in self._with_id:
|
||||
# Check: not layer_id
|
||||
if not layer_id:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"layer_id is required for '{layer_type}' layer"
|
||||
)
|
||||
# Return self._with_id[layer_type](db, layer_id, **kwargs)
|
||||
return self._with_id[layer_type](db, layer_id, **kwargs)
|
||||
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(f"Unknown layer type: {layer_type}")
|
||||
|
||||
|
||||
# Assign LAYER_REGISTRY = _LayerRegistry()
|
||||
LAYER_REGISTRY = _LayerRegistry()
|
||||
|
||||
# Call LAYER_REGISTRY.register()
|
||||
LAYER_REGISTRY.register("coverage", build_coverage_layer)
|
||||
# Call LAYER_REGISTRY.register()
|
||||
LAYER_REGISTRY.register("detection-rules", build_detection_rules_layer)
|
||||
# Call LAYER_REGISTRY.register()
|
||||
LAYER_REGISTRY.register("threat-actor", build_threat_actor_layer, requires_id=True)
|
||||
# Call LAYER_REGISTRY.register()
|
||||
LAYER_REGISTRY.register("campaign", build_campaign_layer, requires_id=True)
|
||||
|
||||
# Assign SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types # snapshot of built-in types
|
||||
SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types # snapshot of built-in types
|
||||
|
||||
|
||||
def register_layer(name: str, builder, *, requires_id: bool = False) -> None:
|
||||
"""Public API to register a new heatmap layer type at import time."""
|
||||
# Define function register_layer
|
||||
def register_layer(name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None:
|
||||
"""Register a new heatmap layer type at import time.
|
||||
|
||||
Args:
|
||||
name (str): Unique identifier for the layer type used in API requests.
|
||||
builder (Callable[..., dict]): Function that builds the layer dict;
|
||||
must accept ``(db, *, platforms, tactics, min_score)`` and
|
||||
optionally a positional ``layer_id`` when ``requires_id`` is
|
||||
``True``.
|
||||
requires_id (bool): Set to ``True`` when the builder needs a
|
||||
``layer_id`` argument (e.g. threat-actor, campaign layers).
|
||||
"""
|
||||
# Call LAYER_REGISTRY.register()
|
||||
LAYER_REGISTRY.register(name, builder, requires_id=requires_id)
|
||||
|
||||
|
||||
# Define function build_navigator_export
|
||||
def build_navigator_export(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: layer_type
|
||||
layer_type: str,
|
||||
*,
|
||||
# Entry: layer_id
|
||||
layer_id: str | None = None,
|
||||
# Entry: platforms
|
||||
platforms: str | None = None,
|
||||
# Entry: tactics
|
||||
tactics: str | None = None,
|
||||
# Entry: min_score
|
||||
min_score: int = 0,
|
||||
) -> dict:
|
||||
"""Build a heatmap layer dict by type name.
|
||||
@@ -539,8 +933,23 @@ def build_navigator_export(
|
||||
missing ``layer_id``. Raises :class:`EntityNotFoundError` when
|
||||
an entity-bound layer (threat-actor, campaign) references a
|
||||
non-existent record.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
layer_type (str): Registered layer type name (e.g. ``"coverage"``,
|
||||
``"threat-actor"``).
|
||||
layer_id (str | None): Entity UUID required for entity-bound layer
|
||||
types such as ``"threat-actor"`` and ``"campaign"``.
|
||||
platforms (str | None): Optional comma-separated platform filter.
|
||||
tactics (str | None): Optional comma-separated tactic filter.
|
||||
min_score (int): Minimum score; techniques below this are excluded.
|
||||
|
||||
Returns:
|
||||
dict: ATT&CK Navigator-compatible layer dictionary.
|
||||
"""
|
||||
# Return LAYER_REGISTRY.build(
|
||||
return LAYER_REGISTRY.build(
|
||||
db, layer_type,
|
||||
# Keyword argument: layer_id
|
||||
layer_id=layer_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
@@ -9,18 +9,34 @@ RSS feeds and parses them with the standard-library :mod:`xml.etree`
|
||||
parser. No LLMs or paid APIs are used.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import re
|
||||
import re
|
||||
import defusedxml.ElementTree as ET
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import defusedxml.ElementTree
|
||||
import defusedxml.ElementTree as ET # noqa: N817 — ET is the universal stdlib alias for ElementTree
|
||||
|
||||
# Import requests
|
||||
import requests as _requests
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import IntelItem from app.models.intel
|
||||
from app.models.intel import IntelItem
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -29,7 +45,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
RSS_FEEDS: list[dict[str, str]] = [
|
||||
{
|
||||
# Literal argument value
|
||||
"name": "CISA Alerts",
|
||||
# Literal argument value
|
||||
"url": "https://www.cisa.gov/cybersecurity-advisories/all.xml",
|
||||
},
|
||||
{
|
||||
@@ -37,19 +55,27 @@ RSS_FEEDS: list[dict[str, str]] = [
|
||||
"url": "https://feeds.feedburner.com/Securityweek",
|
||||
},
|
||||
{
|
||||
# Literal argument value
|
||||
"name": "SANS ISC",
|
||||
# Literal argument value
|
||||
"url": "https://isc.sans.edu/rssfeed.xml",
|
||||
},
|
||||
{
|
||||
# Literal argument value
|
||||
"name": "BleepingComputer",
|
||||
# Literal argument value
|
||||
"url": "https://www.bleepingcomputer.com/feed/",
|
||||
},
|
||||
{
|
||||
# Literal argument value
|
||||
"name": "The Hacker News",
|
||||
# Literal argument value
|
||||
"url": "https://feeds.feedburner.com/TheHackersNews",
|
||||
},
|
||||
{
|
||||
# Literal argument value
|
||||
"name": "Krebs on Security",
|
||||
# Literal argument value
|
||||
"url": "https://krebsonsecurity.com/feed/",
|
||||
},
|
||||
]
|
||||
@@ -73,49 +99,81 @@ def _fetch_feed(url: str) -> list[dict[str, str]]:
|
||||
Each entry is a dict with keys ``title``, ``link``, and ``description``.
|
||||
Returns an empty list on any error so the scan can continue.
|
||||
"""
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign resp = _requests.get(url, timeout=_FEED_TIMEOUT, headers={
|
||||
resp = _requests.get(url, timeout=_FEED_TIMEOUT, headers={
|
||||
# Literal argument value
|
||||
"User-Agent": "AegisPlatform/1.0 IntelScan",
|
||||
})
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log warning: "Failed to fetch feed %s: %s", url, exc
|
||||
logger.warning("Failed to fetch feed %s: %s", url, exc)
|
||||
# Return []
|
||||
return []
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign root = ET.fromstring(resp.content)
|
||||
root = ET.fromstring(resp.content)
|
||||
# Handle ET.ParseError
|
||||
except ET.ParseError as exc:
|
||||
# Log warning: "Failed to parse feed %s: %s", url, exc
|
||||
logger.warning("Failed to parse feed %s: %s", url, exc)
|
||||
# Return []
|
||||
return []
|
||||
|
||||
# Assign entries = []
|
||||
entries: list[dict[str, str]] = []
|
||||
|
||||
# RSS 2.0 format: <channel><item>...
|
||||
for item in root.iter("item"):
|
||||
# Assign title_el = item.find("title")
|
||||
title_el = item.find("title")
|
||||
# Assign link_el = item.find("link")
|
||||
link_el = item.find("link")
|
||||
# Assign desc_el = item.find("description")
|
||||
desc_el = item.find("description")
|
||||
# Call entries.append()
|
||||
entries.append({
|
||||
# Literal argument value
|
||||
"title": title_el.text.strip() if title_el is not None and title_el.text else "",
|
||||
# Literal argument value
|
||||
"link": link_el.text.strip() if link_el is not None and link_el.text else "",
|
||||
# Literal argument value
|
||||
"description": desc_el.text.strip() if desc_el is not None and desc_el.text else "",
|
||||
})
|
||||
|
||||
# Atom format: <feed><entry>...
|
||||
ns = {"atom": "http://www.w3.org/2005/Atom"}
|
||||
# Iterate over root.iter("{http
|
||||
for entry in root.iter("{http://www.w3.org/2005/Atom}entry"):
|
||||
# Assign title_el = entry.find("atom:title", ns)
|
||||
title_el = entry.find("atom:title", ns)
|
||||
# Assign link_el = entry.find("atom:link", ns)
|
||||
link_el = entry.find("atom:link", ns)
|
||||
# Assign summary_el = entry.find("atom:summary", ns)
|
||||
summary_el = entry.find("atom:summary", ns)
|
||||
# Assign link_href = ""
|
||||
link_href = ""
|
||||
# Check: link_el is not None
|
||||
if link_el is not None:
|
||||
# Assign link_href = link_el.get("href", "")
|
||||
link_href = link_el.get("href", "")
|
||||
# Call entries.append()
|
||||
entries.append({
|
||||
# Literal argument value
|
||||
"title": title_el.text.strip() if title_el is not None and title_el.text else "",
|
||||
# Literal argument value
|
||||
"link": link_href.strip(),
|
||||
# Literal argument value
|
||||
"description": summary_el.text.strip() if summary_el is not None and summary_el.text else "",
|
||||
})
|
||||
|
||||
# Return entries
|
||||
return entries
|
||||
|
||||
|
||||
@@ -147,6 +205,7 @@ def _entry_matches(
|
||||
name_patterns: list[re.Pattern],
|
||||
) -> bool:
|
||||
"""Return True if any pattern matches the entry's title or description."""
|
||||
# Assign text = f"{entry.get('title', '')} {entry.get('description', '')}"
|
||||
text = f"{entry.get('title', '')} {entry.get('description', '')}"
|
||||
return any(p.search(text) for p in id_patterns + name_patterns)
|
||||
|
||||
@@ -164,20 +223,23 @@ def scan_intel(db: Session) -> dict:
|
||||
db : Session
|
||||
Active SQLAlchemy database session.
|
||||
|
||||
Returns
|
||||
Returns:
|
||||
-------
|
||||
dict
|
||||
Summary with keys ``new_items``, ``duplicates_skipped``,
|
||||
``techniques_flagged``, ``feeds_checked``.
|
||||
"""
|
||||
# Log info: "Intel scan starting..."
|
||||
logger.info("Intel scan starting...")
|
||||
|
||||
# 1. Load all active techniques
|
||||
techniques = (
|
||||
db.query(Technique)
|
||||
# Chain .order_by() call
|
||||
.order_by(Technique.mitre_id)
|
||||
.all()
|
||||
)
|
||||
# Log info: "Scanning %d techniques against %d feeds", len(tec
|
||||
logger.info("Scanning %d techniques against %d feeds", len(techniques), len(RSS_FEEDS))
|
||||
|
||||
# 2. Pre-load all existing intel URLs for dedup
|
||||
@@ -187,24 +249,36 @@ def scan_intel(db: Session) -> dict:
|
||||
|
||||
# 3. Fetch all feeds once
|
||||
all_entries: list[tuple[str, dict[str, str]]] = [] # (feed_name, entry)
|
||||
# Assign feeds_ok = 0
|
||||
feeds_ok = 0
|
||||
# Iterate over RSS_FEEDS
|
||||
for feed in RSS_FEEDS:
|
||||
# Assign entries = _fetch_feed(feed["url"])
|
||||
entries = _fetch_feed(feed["url"])
|
||||
# Check: entries
|
||||
if entries:
|
||||
# Assign feeds_ok = 1
|
||||
feeds_ok += 1
|
||||
# Iterate over entries
|
||||
for entry in entries:
|
||||
# Call all_entries.append()
|
||||
all_entries.append((feed["name"], entry))
|
||||
|
||||
# Log info: "Fetched %d entries from %d/%d feeds", len(all_ent
|
||||
logger.info("Fetched %d entries from %d/%d feeds", len(all_entries), feeds_ok, len(RSS_FEEDS))
|
||||
|
||||
# 4. Match entries to techniques
|
||||
new_items = 0
|
||||
# Assign duplicates_skipped = 0
|
||||
duplicates_skipped = 0
|
||||
# Assign techniques_flagged = set()
|
||||
techniques_flagged: set[str] = set()
|
||||
|
||||
# Iterate over techniques
|
||||
for technique in techniques:
|
||||
id_patterns, name_patterns = _build_patterns(technique)
|
||||
|
||||
# Iterate over all_entries
|
||||
for feed_name, entry in all_entries:
|
||||
if not _entry_matches(entry, id_patterns, name_patterns):
|
||||
continue
|
||||
@@ -213,45 +287,69 @@ def scan_intel(db: Session) -> dict:
|
||||
if not entry.get("title", "").strip():
|
||||
continue
|
||||
|
||||
# Assign url = entry.get("link", "").strip()
|
||||
url = entry.get("link", "").strip()
|
||||
# Check: not url
|
||||
if not url:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Dedup
|
||||
if url in existing_urls:
|
||||
# Assign duplicates_skipped = 1
|
||||
duplicates_skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Create IntelItem
|
||||
intel_item = IntelItem(
|
||||
# Keyword argument: technique_id
|
||||
technique_id=technique.id,
|
||||
# Keyword argument: url
|
||||
url=url,
|
||||
# Keyword argument: title
|
||||
title=entry.get("title", "")[:500],
|
||||
# Keyword argument: source
|
||||
source=feed_name,
|
||||
# Keyword argument: detected_at
|
||||
detected_at=datetime.utcnow(),
|
||||
# Keyword argument: reviewed
|
||||
reviewed=False,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(intel_item)
|
||||
# Call existing_urls.add()
|
||||
existing_urls.add(url)
|
||||
# Assign new_items = 1
|
||||
new_items += 1
|
||||
|
||||
# Flag technique for review
|
||||
if not technique.review_required:
|
||||
# Assign technique.review_required = True
|
||||
technique.review_required = True
|
||||
# Call techniques_flagged.add()
|
||||
techniques_flagged.add(technique.mitre_id)
|
||||
|
||||
# 5. Single commit
|
||||
db.commit()
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"new_items": new_items,
|
||||
# Literal argument value
|
||||
"duplicates_skipped": duplicates_skipped,
|
||||
# Literal argument value
|
||||
"techniques_flagged": len(techniques_flagged),
|
||||
# Literal argument value
|
||||
"feeds_checked": feeds_ok,
|
||||
}
|
||||
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Intel scan complete — new=%d, duplicates_skipped=%d, "
|
||||
# Literal argument value
|
||||
"techniques_flagged=%d, feeds_checked=%d",
|
||||
new_items, duplicates_skipped, len(techniques_flagged), feeds_ok,
|
||||
)
|
||||
@@ -259,12 +357,19 @@ def scan_intel(db: Session) -> dict:
|
||||
# 6. Audit log
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=None,
|
||||
# Keyword argument: action
|
||||
action="intel_scan",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="intel_item",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=None,
|
||||
# Keyword argument: details
|
||||
details=summary,
|
||||
)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
@@ -28,22 +28,44 @@ creates the Jira ticket and stores the link.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# Import Any, Optional from typing
|
||||
from typing import Any, Optional
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
|
||||
# Import InvalidOperationError from app.domain.exceptions
|
||||
from app.domain.exceptions import InvalidOperationError
|
||||
|
||||
# Import Campaign from app.models.campaign
|
||||
from app.models.campaign import Campaign
|
||||
|
||||
# Import JiraLink, JiraLinkEntityType, JiraSyncDirection from app.models.jira_link
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -624,6 +646,7 @@ def get_jira_client():
|
||||
Prefer ``get_user_jira_client()`` for new code.
|
||||
"""
|
||||
if not settings.JIRA_ENABLED:
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError("Jira integration is not enabled")
|
||||
if not settings.JIRA_URL or not settings.JIRA_USERNAME or not settings.JIRA_API_TOKEN:
|
||||
raise InvalidOperationError(
|
||||
@@ -639,75 +662,121 @@ def get_jira_client():
|
||||
)
|
||||
|
||||
|
||||
# Define function search_jira_issues
|
||||
def search_jira_issues(query: str, max_results: int = 10) -> list[dict]:
|
||||
"""Search Jira issues by JQL or free text (uses global credentials)."""
|
||||
jira = get_jira_client()
|
||||
# Assign jql = query if "=" in query or "~" in query else f'summary ~ "{query}"'
|
||||
jql = query if "=" in query or "~" in query else f'summary ~ "{query}"'
|
||||
# Assign results = jira.jql(jql, limit=max_results)
|
||||
results = jira.jql(jql, limit=max_results)
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"issue_key": issue["key"],
|
||||
# Literal argument value
|
||||
"summary": issue["fields"]["summary"],
|
||||
# Literal argument value
|
||||
"status": issue["fields"]["status"]["name"],
|
||||
# Literal argument value
|
||||
"assignee": (issue["fields"].get("assignee") or {}).get("displayName"),
|
||||
# Literal argument value
|
||||
"priority": (issue["fields"].get("priority") or {}).get("name"),
|
||||
}
|
||||
for issue in results.get("issues", [])
|
||||
]
|
||||
|
||||
|
||||
# Define function create_jira_issue
|
||||
def create_jira_issue(
|
||||
# Entry: project_key
|
||||
project_key: str,
|
||||
# Entry: summary
|
||||
summary: str,
|
||||
# Entry: description
|
||||
description: str,
|
||||
# Entry: issue_type
|
||||
issue_type: str = "Task",
|
||||
# Entry: labels
|
||||
labels: Optional[list[str]] = None,
|
||||
# Entry: custom_fields
|
||||
custom_fields: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Create a Jira issue and return its key + id (uses global credentials)."""
|
||||
jira = get_jira_client()
|
||||
# Assign fields = {
|
||||
fields: dict = {
|
||||
# Literal argument value
|
||||
"project": {"key": project_key},
|
||||
# Literal argument value
|
||||
"summary": summary,
|
||||
# Literal argument value
|
||||
"description": description,
|
||||
# Literal argument value
|
||||
"issuetype": {"name": issue_type},
|
||||
}
|
||||
# Check: labels
|
||||
if labels:
|
||||
# Assign fields["labels"] = labels
|
||||
fields["labels"] = labels
|
||||
# Check: custom_fields
|
||||
if custom_fields:
|
||||
# Call fields.update()
|
||||
fields.update(custom_fields)
|
||||
|
||||
# Assign result = jira.issue_create(fields=fields)
|
||||
result = jira.issue_create(fields=fields)
|
||||
# Return {"issue_key": result["key"], "issue_id": result["id"]}
|
||||
return {"issue_key": result["key"], "issue_id": result["id"]}
|
||||
|
||||
|
||||
# Define function sync_jira_to_aegis
|
||||
def sync_jira_to_aegis(db: Session, link: JiraLink) -> None:
|
||||
"""Pull current status from Jira into the local link record (global creds)."""
|
||||
jira = get_jira_client()
|
||||
# Assign issue = jira.issue(link.jira_issue_key)
|
||||
issue = jira.issue(link.jira_issue_key)
|
||||
# Assign fields = issue.get("fields", {})
|
||||
fields = issue.get("fields", {})
|
||||
# Assign link.jira_status = fields.get("status", {}).get("name")
|
||||
link.jira_status = fields.get("status", {}).get("name")
|
||||
# Assign link.jira_priority = (fields.get("priority") or {}).get("name")
|
||||
link.jira_priority = (fields.get("priority") or {}).get("name")
|
||||
# Assign link.jira_assignee = (fields.get("assignee") or {}).get("displayName")
|
||||
link.jira_assignee = (fields.get("assignee") or {}).get("displayName")
|
||||
# Assign link.jira_story_points = str(fields.get("customfield_10016", ""))
|
||||
link.jira_story_points = str(fields.get("customfield_10016", ""))
|
||||
# Assign link.last_synced_at = datetime.utcnow()
|
||||
link.last_synced_at = datetime.utcnow()
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
|
||||
# Define function sync_aegis_to_jira
|
||||
def sync_aegis_to_jira(db: Session, link: JiraLink, entity_data: dict) -> None:
|
||||
"""Push an Aegis status update as a Jira comment (global creds)."""
|
||||
jira = get_jira_client()
|
||||
# Assign comment_body = _build_sync_comment(entity_data)
|
||||
comment_body = _build_sync_comment(entity_data)
|
||||
# Call jira.issue_add_comment()
|
||||
jira.issue_add_comment(link.jira_issue_key, comment_body)
|
||||
# Assign link.last_synced_at = datetime.utcnow()
|
||||
link.last_synced_at = datetime.utcnow()
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
|
||||
# Define function _build_sync_comment
|
||||
def _build_sync_comment(data: dict) -> str:
|
||||
lines = ["h3. Aegis Sync Update", ""]
|
||||
# Iterate over data.items()
|
||||
for key, value in data.items():
|
||||
# Call lines.append()
|
||||
lines.append(f"*{key}:* {value}")
|
||||
# Call lines.append()
|
||||
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
|
||||
# Return "\n".join(lines)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -715,60 +784,94 @@ def _build_sync_comment(data: dict) -> str:
|
||||
|
||||
|
||||
def create_link(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: entity_type
|
||||
entity_type: JiraLinkEntityType,
|
||||
# Entry: entity_id
|
||||
entity_id: UUID,
|
||||
# Entry: jira_issue_key
|
||||
jira_issue_key: str,
|
||||
# Entry: sync_direction
|
||||
sync_direction: JiraSyncDirection,
|
||||
# Entry: created_by
|
||||
created_by: UUID,
|
||||
) -> JiraLink:
|
||||
link = JiraLink(
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=entity_id,
|
||||
# Keyword argument: jira_issue_key
|
||||
jira_issue_key=jira_issue_key,
|
||||
# Keyword argument: sync_direction
|
||||
sync_direction=sync_direction,
|
||||
# Keyword argument: created_by
|
||||
created_by=created_by,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(link)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Check: settings.JIRA_ENABLED
|
||||
if settings.JIRA_ENABLED:
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Call sync_jira_to_aegis()
|
||||
sync_jira_to_aegis(db, link)
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log warning: "Initial Jira sync failed for %s: %s", jira_issue_
|
||||
logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e)
|
||||
|
||||
# Return link
|
||||
return link
|
||||
|
||||
|
||||
# Define function list_links
|
||||
def list_links(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: entity_type
|
||||
entity_type: Optional[JiraLinkEntityType] = None,
|
||||
# Entry: entity_id
|
||||
entity_id: Optional[UUID] = None,
|
||||
entity_ids: Optional[list[UUID]] = None,
|
||||
) -> list[JiraLink]:
|
||||
query = db.query(JiraLink)
|
||||
# Check: entity_type
|
||||
if entity_type:
|
||||
# Assign query = query.filter(JiraLink.entity_type == entity_type)
|
||||
query = query.filter(JiraLink.entity_type == entity_type)
|
||||
# Check: entity_id
|
||||
if entity_id:
|
||||
# Assign query = query.filter(JiraLink.entity_id == entity_id)
|
||||
query = query.filter(JiraLink.entity_id == entity_id)
|
||||
elif entity_ids:
|
||||
query = query.filter(JiraLink.entity_id.in_(entity_ids))
|
||||
return query.order_by(JiraLink.created_at.desc()).all()
|
||||
|
||||
|
||||
# Define function get_link_or_raise
|
||||
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
|
||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
||||
# Check: not link
|
||||
if not link:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
||||
# Return link
|
||||
return link
|
||||
|
||||
|
||||
# Define function delete_link
|
||||
def delete_link(db: Session, link_id: UUID) -> JiraLink:
|
||||
link = get_link_or_raise(db, link_id)
|
||||
# Mark record for deletion on next commit
|
||||
db.delete(link)
|
||||
# Return link
|
||||
return link
|
||||
|
||||
|
||||
@@ -776,43 +879,64 @@ def build_issue_data(
|
||||
db: Session, entity_type: JiraLinkEntityType, entity_id: UUID
|
||||
) -> tuple[str, str]:
|
||||
"""Build Jira issue summary and description from an Aegis entity."""
|
||||
# Check: entity_type == JiraLinkEntityType.test
|
||||
if entity_type == JiraLinkEntityType.test:
|
||||
# Assign entity = db.query(Test).filter(Test.id == entity_id).first()
|
||||
entity = db.query(Test).filter(Test.id == entity_id).first()
|
||||
# Check: not entity
|
||||
if not entity:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test", str(entity_id))
|
||||
technique = db.query(Technique).filter(Technique.id == entity.technique_id).first()
|
||||
return (
|
||||
f"[Aegis] {technique.mitre_id if technique else 'N/A'} — {entity.name}",
|
||||
_build_test_description(entity, technique),
|
||||
)
|
||||
# Alternative: entity_type == JiraLinkEntityType.campaign
|
||||
elif entity_type == JiraLinkEntityType.campaign:
|
||||
# Assign entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
||||
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
||||
# Check: not entity
|
||||
if not entity:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", str(entity_id))
|
||||
# Return (
|
||||
return (
|
||||
f"[Aegis Campaign] {entity.name}",
|
||||
f"Campaign: {entity.name}\nType: {entity.type}\nStatus: {entity.status}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
# Alternative: entity_type == JiraLinkEntityType.technique
|
||||
elif entity_type == JiraLinkEntityType.technique:
|
||||
# Assign entity = db.query(Technique).filter(Technique.id == entity_id).first()
|
||||
entity = db.query(Technique).filter(Technique.id == entity_id).first()
|
||||
# Check: not entity
|
||||
if not entity:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", str(entity_id))
|
||||
# Return (
|
||||
return (
|
||||
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
|
||||
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
|
||||
f"Tactic: {entity.tactic or 'N/A'}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
|
||||
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
|
||||
|
||||
|
||||
# Define function create_issue_and_link
|
||||
def create_issue_and_link(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: entity_type
|
||||
entity_type: JiraLinkEntityType,
|
||||
# Entry: entity_id
|
||||
entity_id: UUID,
|
||||
# Entry: created_by
|
||||
created_by: UUID,
|
||||
) -> dict:
|
||||
"""Create a Jira issue from an Aegis entity and link them (global creds)."""
|
||||
@@ -821,16 +945,25 @@ def create_issue_and_link(
|
||||
result = create_jira_issue(
|
||||
project_key=project_key,
|
||||
summary=summary,
|
||||
# Keyword argument: description
|
||||
description=description,
|
||||
# Keyword argument: labels
|
||||
labels=["aegis", entity_type.value],
|
||||
)
|
||||
# Assign link = JiraLink(
|
||||
link = JiraLink(
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=entity_id,
|
||||
# Keyword argument: jira_issue_key
|
||||
jira_issue_key=result["issue_key"],
|
||||
# Keyword argument: jira_issue_id
|
||||
jira_issue_id=result["issue_id"],
|
||||
jira_project_key=project_key,
|
||||
created_by=created_by,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(link)
|
||||
# Return {"issue_key": result["issue_key"], "link_id": str(link.id)}
|
||||
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
|
||||
|
||||
@@ -24,24 +24,45 @@ Deduplication keys:
|
||||
- GTFOBins: ``source + binary_name + function`` → stored in ``atomic_test_id``
|
||||
"""
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import re
|
||||
import re
|
||||
|
||||
# Import shutil
|
||||
import shutil
|
||||
|
||||
# Import tempfile
|
||||
import tempfile
|
||||
|
||||
# Import zipfile
|
||||
import zipfile
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Path from pathlib
|
||||
from pathlib import Path
|
||||
|
||||
# Import requests
|
||||
import requests as _requests
|
||||
|
||||
# Import yaml
|
||||
import yaml
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.test_template import TestTemplate
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -49,34 +70,57 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
LOLBAS_ZIP_URL = (
|
||||
# Literal argument value
|
||||
"https://github.com/LOLBAS-Project/LOLBAS"
|
||||
# Literal argument value
|
||||
"/archive/refs/heads/master.zip"
|
||||
)
|
||||
# Assign GTFOBINS_ZIP_URL = (
|
||||
GTFOBINS_ZIP_URL = (
|
||||
# Literal argument value
|
||||
"https://github.com/GTFOBins/GTFOBins.github.io"
|
||||
# Literal argument value
|
||||
"/archive/refs/heads/master.zip"
|
||||
)
|
||||
|
||||
# Assign _DOWNLOAD_TIMEOUT = 300
|
||||
_DOWNLOAD_TIMEOUT = 300
|
||||
|
||||
# GTFOBins function → MITRE technique mapping
|
||||
_GTFOBINS_FUNCTION_MAP: dict[str, str] = {
|
||||
# Literal argument value
|
||||
"shell": "T1059",
|
||||
# Literal argument value
|
||||
"command": "T1059",
|
||||
# Literal argument value
|
||||
"reverse-shell": "T1059",
|
||||
# Literal argument value
|
||||
"non-interactive-reverse-shell": "T1059",
|
||||
# Literal argument value
|
||||
"bind-shell": "T1059",
|
||||
# Literal argument value
|
||||
"non-interactive-bind-shell": "T1059",
|
||||
# Literal argument value
|
||||
"file-upload": "T1105",
|
||||
# Literal argument value
|
||||
"file-download": "T1105",
|
||||
# Literal argument value
|
||||
"upload": "T1105",
|
||||
# Literal argument value
|
||||
"download": "T1105",
|
||||
# Literal argument value
|
||||
"file-write": "T1105",
|
||||
# Literal argument value
|
||||
"file-read": "T1005",
|
||||
# Literal argument value
|
||||
"library-load": "T1129",
|
||||
# Literal argument value
|
||||
"sudo": "T1548.003",
|
||||
# Literal argument value
|
||||
"suid": "T1548.001",
|
||||
# Literal argument value
|
||||
"capabilities": "T1548",
|
||||
# Literal argument value
|
||||
"limited-suid": "T1548.001",
|
||||
}
|
||||
|
||||
@@ -88,18 +132,28 @@ _GTFOBINS_FUNCTION_MAP: dict[str, str] = {
|
||||
|
||||
def _download_zip(url: str) -> bytes:
|
||||
"""Download a ZIP from *url* and return raw bytes."""
|
||||
# Log info: "Downloading ZIP from %s …", url
|
||||
logger.info("Downloading ZIP from %s …", url)
|
||||
# Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Assign content = resp.content
|
||||
content = resp.content
|
||||
# Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024
|
||||
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
|
||||
# Return content
|
||||
return content
|
||||
|
||||
|
||||
# Define function _extract_zip
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return the root directory."""
|
||||
# Open context manager
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
# Call zf.extractall()
|
||||
zf.extractall(dest)
|
||||
# Return Path(dest)
|
||||
return Path(dest)
|
||||
|
||||
|
||||
@@ -110,83 +164,141 @@ def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
|
||||
def _parse_lolbas(root_dir: Path) -> list[dict]:
|
||||
"""Parse LOLBAS YAML files and return template dicts."""
|
||||
# Assign results = []
|
||||
results: list[dict] = []
|
||||
|
||||
# Assign lolbas_root = root_dir / "LOLBAS-master"
|
||||
lolbas_root = root_dir / "LOLBAS-master"
|
||||
# Assign yaml_dirs = [
|
||||
yaml_dirs = [
|
||||
lolbas_root / "yml" / "OSBinaries",
|
||||
lolbas_root / "yml" / "OSLibraries",
|
||||
lolbas_root / "yml" / "OSScripts",
|
||||
]
|
||||
|
||||
# Assign yaml_files = []
|
||||
yaml_files = []
|
||||
# Iterate over yaml_dirs
|
||||
for d in yaml_dirs:
|
||||
# Check: d.is_dir()
|
||||
if d.is_dir():
|
||||
# Call yaml_files.extend()
|
||||
yaml_files.extend(sorted(d.rglob("*.yml")))
|
||||
|
||||
# Log info: "LOLBAS: Found %d YAML files", len(yaml_files
|
||||
logger.info("LOLBAS: Found %d YAML files", len(yaml_files))
|
||||
|
||||
# Iterate over yaml_files
|
||||
for yaml_path in yaml_files:
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Open context manager
|
||||
with open(yaml_path, "r", encoding="utf-8") as fh:
|
||||
# Assign data = yaml.safe_load(fh)
|
||||
data = yaml.safe_load(fh)
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log debug: "Failed to parse %s: %s", yaml_path, exc
|
||||
logger.debug("Failed to parse %s: %s", yaml_path, exc)
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Check: not isinstance(data, dict)
|
||||
if not isinstance(data, dict):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign binary_name = data.get("Name", "").strip()
|
||||
binary_name = data.get("Name", "").strip()
|
||||
# Check: not binary_name
|
||||
if not binary_name:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign description = data.get("Description", "")
|
||||
description = data.get("Description", "")
|
||||
# Assign commands = data.get("Commands", [])
|
||||
commands = data.get("Commands", [])
|
||||
# Check: not isinstance(commands, list)
|
||||
if not isinstance(commands, list):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Iterate over commands
|
||||
for cmd_entry in commands:
|
||||
# Check: not isinstance(cmd_entry, dict)
|
||||
if not isinstance(cmd_entry, dict):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign mitre_id = cmd_entry.get("MitreID")
|
||||
mitre_id = cmd_entry.get("MitreID")
|
||||
# Check: not mitre_id
|
||||
if not mitre_id:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Normalise the MITRE ID
|
||||
mitre_id = str(mitre_id).strip().upper()
|
||||
# Check: not mitre_id.startswith("T")
|
||||
if not mitre_id.startswith("T"):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign command = cmd_entry.get("Command", "")
|
||||
command = cmd_entry.get("Command", "")
|
||||
# Assign usecase = cmd_entry.get("Usecase", "")
|
||||
usecase = cmd_entry.get("Usecase", "")
|
||||
# Assign cmd_description = cmd_entry.get("Description", "")
|
||||
cmd_description = cmd_entry.get("Description", "")
|
||||
|
||||
# Dedup key
|
||||
dedup_key = f"lolbas:{binary_name}:{mitre_id}"
|
||||
|
||||
# Assign procedure = []
|
||||
procedure = []
|
||||
# Check: cmd_description
|
||||
if cmd_description:
|
||||
# Call procedure.append()
|
||||
procedure.append(f"Description: {cmd_description}")
|
||||
# Check: usecase
|
||||
if usecase:
|
||||
# Call procedure.append()
|
||||
procedure.append(f"Use case: {usecase}")
|
||||
# Check: command
|
||||
if command:
|
||||
# Call procedure.append()
|
||||
procedure.append(f"Command: {command}")
|
||||
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"mitre_technique_id": mitre_id,
|
||||
# Literal argument value
|
||||
"name": f"LOLBAS: {binary_name} — {usecase or cmd_description or mitre_id}"[:500],
|
||||
"description": f"{description}\n\n{cmd_description}".strip()[:2000] if description else cmd_description[:2000] if cmd_description else None,
|
||||
# Literal argument value
|
||||
"description": (
|
||||
f"{description}\n\n{cmd_description}".strip()[:2000]
|
||||
if description
|
||||
else cmd_description[:2000] if cmd_description else None
|
||||
),
|
||||
# Literal argument value
|
||||
"source": "lolbas",
|
||||
# Literal argument value
|
||||
"platform": "windows",
|
||||
# Literal argument value
|
||||
"tool_suggested": binary_name,
|
||||
# Literal argument value
|
||||
"attack_procedure": "\n".join(procedure)[:4000] if procedure else None,
|
||||
# Literal argument value
|
||||
"atomic_test_id": dedup_key,
|
||||
# Literal argument value
|
||||
"source_url": f"https://lolbas-project.github.io/lolbas/Binaries/{binary_name}/",
|
||||
})
|
||||
|
||||
# Log info: "LOLBAS: Parsed %d templates", len(results
|
||||
logger.info("LOLBAS: Parsed %d templates", len(results))
|
||||
# Return results
|
||||
return results
|
||||
|
||||
|
||||
@@ -197,85 +309,138 @@ def _parse_lolbas(root_dir: Path) -> list[dict]:
|
||||
|
||||
def _parse_gtfobins(root_dir: Path) -> list[dict]:
|
||||
"""Parse GTFOBins markdown files and return template dicts."""
|
||||
# Assign results = []
|
||||
results: list[dict] = []
|
||||
|
||||
# Assign gtfobins_root = root_dir / "GTFOBins.github.io-master" / "_gtfobins"
|
||||
gtfobins_root = root_dir / "GTFOBins.github.io-master" / "_gtfobins"
|
||||
# Check: not gtfobins_root.is_dir()
|
||||
if not gtfobins_root.is_dir():
|
||||
# Log warning: "GTFOBins directory not found at %s", gtfobins_roo
|
||||
logger.warning("GTFOBins directory not found at %s", gtfobins_root)
|
||||
# Return results
|
||||
return results
|
||||
|
||||
# Assign md_files = sorted(
|
||||
md_files = sorted(
|
||||
f for f in gtfobins_root.iterdir()
|
||||
if f.is_file() and f.suffix in (".md", "")
|
||||
)
|
||||
# Log info: "GTFOBins: Found %d files", len(md_files
|
||||
logger.info("GTFOBins: Found %d files", len(md_files))
|
||||
|
||||
# Iterate over md_files
|
||||
for md_path in md_files:
|
||||
# Assign binary_name = md_path.stem # e.g. "awk"
|
||||
binary_name = md_path.stem # e.g. "awk"
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Open context manager
|
||||
with open(md_path, "r", encoding="utf-8") as fh:
|
||||
# Assign content = fh.read()
|
||||
content = fh.read()
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log debug: "Failed to read %s: %s", md_path, exc
|
||||
logger.debug("Failed to read %s: %s", md_path, exc)
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Extract YAML front-matter
|
||||
front_matter = _extract_front_matter(content)
|
||||
# Check: not front_matter
|
||||
if not front_matter:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign functions = front_matter.get("functions", {})
|
||||
functions = front_matter.get("functions", {})
|
||||
# Check: not isinstance(functions, dict)
|
||||
if not isinstance(functions, dict):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Iterate over functions.items()
|
||||
for func_name, func_data in functions.items():
|
||||
# Map function to MITRE technique
|
||||
mitre_id = _GTFOBINS_FUNCTION_MAP.get(func_name.lower())
|
||||
# Check: not mitre_id
|
||||
if not mitre_id:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Extract code examples from function data
|
||||
examples = []
|
||||
# Check: isinstance(func_data, list)
|
||||
if isinstance(func_data, list):
|
||||
# Iterate over func_data
|
||||
for entry in func_data:
|
||||
# Check: isinstance(entry, dict)
|
||||
if isinstance(entry, dict):
|
||||
# Assign code = entry.get("code", "")
|
||||
code = entry.get("code", "")
|
||||
# Check: code
|
||||
if code:
|
||||
# Call examples.append()
|
||||
examples.append(str(code))
|
||||
# Alternative: isinstance(entry, str)
|
||||
elif isinstance(entry, str):
|
||||
# Call examples.append()
|
||||
examples.append(entry)
|
||||
|
||||
# Assign procedure = "\n\n".join(examples) if examples else None
|
||||
procedure = "\n\n".join(examples) if examples else None
|
||||
# Assign dedup_key = f"gtfobins:{binary_name}:{func_name}"
|
||||
dedup_key = f"gtfobins:{binary_name}:{func_name}"
|
||||
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"mitre_technique_id": mitre_id,
|
||||
# Literal argument value
|
||||
"name": f"GTFOBins: {binary_name} — {func_name}"[:500],
|
||||
# Literal argument value
|
||||
"description": f"Abuse {binary_name} binary for {func_name} on Linux/Unix."[:2000],
|
||||
# Literal argument value
|
||||
"source": "gtfobins",
|
||||
# Literal argument value
|
||||
"platform": "linux",
|
||||
# Literal argument value
|
||||
"tool_suggested": binary_name,
|
||||
# Literal argument value
|
||||
"attack_procedure": procedure[:4000] if procedure else None,
|
||||
# Literal argument value
|
||||
"atomic_test_id": dedup_key,
|
||||
# Literal argument value
|
||||
"source_url": f"https://gtfobins.github.io/gtfobins/{binary_name}/",
|
||||
})
|
||||
|
||||
# Log info: "GTFOBins: Parsed %d templates", len(results
|
||||
logger.info("GTFOBins: Parsed %d templates", len(results))
|
||||
# Return results
|
||||
return results
|
||||
|
||||
|
||||
# Define function _extract_front_matter
|
||||
def _extract_front_matter(content: str) -> dict | None:
|
||||
"""Extract YAML front-matter from a markdown/GTFOBins file.
|
||||
|
||||
Supports both ``---/---`` (standard front-matter) and ``---/...``
|
||||
(YAML document-end marker used by GTFOBins).
|
||||
"""
|
||||
# Assign match = re.match(r"^---\s*\n(.*?)\n(?:---|\.\.\.)", content, re.DOTALL)
|
||||
match = re.match(r"^---\s*\n(.*?)\n(?:---|\.\.\.)", content, re.DOTALL)
|
||||
# Check: not match
|
||||
if not match:
|
||||
# Return None
|
||||
return None
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Return yaml.safe_load(match.group(1))
|
||||
return yaml.safe_load(match.group(1))
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Return None
|
||||
return None
|
||||
|
||||
|
||||
@@ -286,36 +451,59 @@ def _extract_front_matter(content: str) -> dict | None:
|
||||
|
||||
def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict:
|
||||
"""Insert templates, skipping existing ones by atomic_test_id."""
|
||||
# Assign existing_ids = {
|
||||
existing_ids: set[str] = {
|
||||
row[0]
|
||||
for row in db.query(TestTemplate.atomic_test_id)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.source == source_name)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.atomic_test_id.isnot(None))
|
||||
# Chain .all() call
|
||||
.all()
|
||||
}
|
||||
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
# Iterate over items
|
||||
for item in items:
|
||||
# Check: item["atomic_test_id"] in existing_ids
|
||||
if item["atomic_test_id"] in existing_ids:
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign template = TestTemplate(
|
||||
template = TestTemplate(
|
||||
# Keyword argument: mitre_technique_id
|
||||
mitre_technique_id=item["mitre_technique_id"],
|
||||
# Keyword argument: name
|
||||
name=item["name"],
|
||||
# Keyword argument: description
|
||||
description=item["description"],
|
||||
# Keyword argument: source
|
||||
source=item["source"],
|
||||
# Keyword argument: source_url
|
||||
source_url=item.get("source_url"),
|
||||
# Keyword argument: attack_procedure
|
||||
attack_procedure=item.get("attack_procedure"),
|
||||
# Keyword argument: platform
|
||||
platform=item["platform"],
|
||||
# Keyword argument: tool_suggested
|
||||
tool_suggested=item.get("tool_suggested"),
|
||||
# Keyword argument: atomic_test_id
|
||||
atomic_test_id=item["atomic_test_id"],
|
||||
# Keyword argument: is_active
|
||||
is_active=True,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(template)
|
||||
# Call existing_ids.add()
|
||||
existing_ids.add(item["atomic_test_id"])
|
||||
new_technique_ids.add(item["mitre_technique_id"])
|
||||
created += 1
|
||||
@@ -326,6 +514,7 @@ def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict:
|
||||
).update({"review_required": True}, synchronize_session=False)
|
||||
|
||||
db.commit()
|
||||
# Return {"created": created, "skipped_existing": skipped, "total_parsed": l...
|
||||
return {"created": created, "skipped_existing": skipped, "total_parsed": len(items)}
|
||||
|
||||
|
||||
@@ -339,56 +528,93 @@ def sync(db: Session) -> dict:
|
||||
|
||||
Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``.
|
||||
"""
|
||||
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_lolbas_")
|
||||
tmp_dir = tempfile.mkdtemp(prefix="aegis_lolbas_")
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign zip_bytes = _download_zip(LOLBAS_ZIP_URL)
|
||||
zip_bytes = _download_zip(LOLBAS_ZIP_URL)
|
||||
# Assign root_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
root_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
# Assign parsed = _parse_lolbas(root_dir)
|
||||
parsed = _parse_lolbas(root_dir)
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Call shutil.rmtree()
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
# Assign summary = _upsert_templates(db, parsed, "lolbas")
|
||||
summary = _upsert_templates(db, parsed, "lolbas")
|
||||
|
||||
# Update DataSource record
|
||||
ds = db.query(DataSource).filter(DataSource.name == "lolbas").first()
|
||||
# Check: ds
|
||||
if ds:
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_status = "success"
|
||||
ds.last_sync_status = "success"
|
||||
# Assign ds.last_sync_stats = summary
|
||||
ds.last_sync_stats = summary
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Log info: "LOLBAS import complete — %s", summary
|
||||
logger.info("LOLBAS import complete — %s", summary)
|
||||
# Call log_action()
|
||||
log_action(db, user_id=None, action="import_lolbas",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template", entity_id=None, details=summary)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
|
||||
# Define function sync_gtfobins
|
||||
def sync_gtfobins(db: Session) -> dict:
|
||||
"""Import GTFOBins templates.
|
||||
|
||||
Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``.
|
||||
"""
|
||||
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_gtfobins_")
|
||||
tmp_dir = tempfile.mkdtemp(prefix="aegis_gtfobins_")
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign zip_bytes = _download_zip(GTFOBINS_ZIP_URL)
|
||||
zip_bytes = _download_zip(GTFOBINS_ZIP_URL)
|
||||
# Assign root_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
root_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
# Assign parsed = _parse_gtfobins(root_dir)
|
||||
parsed = _parse_gtfobins(root_dir)
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Call shutil.rmtree()
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
# Assign summary = _upsert_templates(db, parsed, "gtfobins")
|
||||
summary = _upsert_templates(db, parsed, "gtfobins")
|
||||
|
||||
# Update DataSource record
|
||||
ds = db.query(DataSource).filter(DataSource.name == "gtfobins").first()
|
||||
# Check: ds
|
||||
if ds:
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_status = "success"
|
||||
ds.last_sync_status = "success"
|
||||
# Assign ds.last_sync_stats = summary
|
||||
ds.last_sync_stats = summary
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Log info: "GTFOBins import complete — %s", summary
|
||||
logger.info("GTFOBins import complete — %s", summary)
|
||||
# Call log_action()
|
||||
log_action(db, user_id=None, action="import_gtfobins",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template", entity_id=None, details=summary)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
@@ -7,16 +7,28 @@ of MITRE ATT&CK technique coverage for dashboards and reporting.
|
||||
This module is framework-agnostic: no FastAPI imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import defaultdict from collections
|
||||
from collections import defaultdict
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session, joinedload from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
# Import TechniqueStatus, TestState from app.models.enums
|
||||
from app.models.enums import TechniqueStatus, TestState
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import from app.schemas.metrics
|
||||
from app.schemas.metrics import (
|
||||
CoverageSummary,
|
||||
RecentTestItem,
|
||||
@@ -27,40 +39,60 @@ from app.schemas.metrics import (
|
||||
)
|
||||
|
||||
|
||||
# Define function get_coverage_summary
|
||||
def get_coverage_summary(db: Session) -> CoverageSummary:
|
||||
"""Return a global coverage summary across all techniques."""
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
db.query(
|
||||
Technique.status_global,
|
||||
func.count(Technique.id).label("cnt"),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(Technique.status_global)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign counts = {s.value: 0 for s in TechniqueStatus}
|
||||
counts: dict[str, int] = {s.value: 0 for s in TechniqueStatus}
|
||||
# Iterate over rows
|
||||
for status, cnt in rows:
|
||||
# Assign counts[status.value] = cnt
|
||||
counts[status.value] = cnt
|
||||
|
||||
# Assign total = sum(counts.values())
|
||||
total = sum(counts.values())
|
||||
# Assign validated = counts["validated"]
|
||||
validated = counts["validated"]
|
||||
# Assign partial = counts["partial"]
|
||||
partial = counts["partial"]
|
||||
|
||||
# Assign coverage_pct = (
|
||||
coverage_pct = (
|
||||
round((validated + partial) / total * 100, 2) if total > 0 else 0.0
|
||||
)
|
||||
|
||||
# Return CoverageSummary(
|
||||
return CoverageSummary(
|
||||
# Keyword argument: total_techniques
|
||||
total_techniques=total,
|
||||
# Keyword argument: validated
|
||||
validated=validated,
|
||||
# Keyword argument: partial
|
||||
partial=partial,
|
||||
# Keyword argument: not_covered
|
||||
not_covered=counts["not_covered"],
|
||||
# Keyword argument: in_progress
|
||||
in_progress=counts["in_progress"],
|
||||
# Keyword argument: not_evaluated
|
||||
not_evaluated=counts["not_evaluated"],
|
||||
# Keyword argument: coverage_percentage
|
||||
coverage_percentage=coverage_pct,
|
||||
)
|
||||
|
||||
|
||||
# Define function get_coverage_by_tactic
|
||||
def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]:
|
||||
"""Return coverage breakdown grouped by tactic.
|
||||
|
||||
@@ -68,6 +100,7 @@ def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]:
|
||||
comma-separated string), the technique is counted once per tactic
|
||||
it belongs to.
|
||||
"""
|
||||
# Assign techniques = db.query(
|
||||
techniques = db.query(
|
||||
Technique.tactic, Technique.status_global
|
||||
).all()
|
||||
@@ -75,179 +108,270 @@ def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]:
|
||||
# Accumulate per-tactic counters. A technique with tactic
|
||||
# "persistence, privilege-escalation" is counted in both.
|
||||
tactic_data: dict[str, dict[str, int]] = defaultdict(
|
||||
# Entry: lambda
|
||||
lambda: {s.value: 0 for s in TechniqueStatus}
|
||||
)
|
||||
|
||||
# Iterate over techniques
|
||||
for tactic_str, status in techniques:
|
||||
# Check: not tactic_str
|
||||
if not tactic_str:
|
||||
# Assign tactics = ["unknown"]
|
||||
tactics = ["unknown"]
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign tactics = [t.strip() for t in tactic_str.split(",")]
|
||||
tactics = [t.strip() for t in tactic_str.split(",")]
|
||||
|
||||
# Iterate over tactics
|
||||
for tactic in tactics:
|
||||
# Assign tactic_data[tactic][status.value] = 1
|
||||
tactic_data[tactic][status.value] += 1
|
||||
|
||||
# Assign result = []
|
||||
result = []
|
||||
# Iterate over sorted(tactic_data)
|
||||
for tactic in sorted(tactic_data):
|
||||
# Assign counts = tactic_data[tactic]
|
||||
counts = tactic_data[tactic]
|
||||
# Assign total = sum(counts.values())
|
||||
total = sum(counts.values())
|
||||
# Call result.append()
|
||||
result.append(
|
||||
TacticCoverage(
|
||||
# Keyword argument: tactic
|
||||
tactic=tactic,
|
||||
# Keyword argument: total
|
||||
total=total,
|
||||
# Keyword argument: validated
|
||||
validated=counts["validated"],
|
||||
# Keyword argument: partial
|
||||
partial=counts["partial"],
|
||||
# Keyword argument: not_covered
|
||||
not_covered=counts["not_covered"],
|
||||
# Keyword argument: not_evaluated
|
||||
not_evaluated=counts["not_evaluated"],
|
||||
# Keyword argument: in_progress
|
||||
in_progress=counts["in_progress"],
|
||||
)
|
||||
)
|
||||
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
# Define function get_test_pipeline_counts
|
||||
def get_test_pipeline_counts(db: Session) -> TestPipelineCounts:
|
||||
"""Return how many tests are in each pipeline state."""
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
db.query(Test.state, func.count(Test.id).label("cnt"))
|
||||
# Chain .group_by() call
|
||||
.group_by(Test.state)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign state_counts = {s.value: 0 for s in TestState}
|
||||
state_counts: dict[str, int] = {s.value: 0 for s in TestState}
|
||||
# Iterate over rows
|
||||
for state, cnt in rows:
|
||||
# Assign state_counts[state.value] = cnt
|
||||
state_counts[state.value] = cnt
|
||||
|
||||
# Assign total = sum(state_counts.values())
|
||||
total = sum(state_counts.values())
|
||||
|
||||
# Return TestPipelineCounts(
|
||||
return TestPipelineCounts(
|
||||
# Keyword argument: draft
|
||||
draft=state_counts["draft"],
|
||||
# Keyword argument: red_executing
|
||||
red_executing=state_counts["red_executing"],
|
||||
# Keyword argument: blue_evaluating
|
||||
blue_evaluating=state_counts["blue_evaluating"],
|
||||
# Keyword argument: in_review
|
||||
in_review=state_counts["in_review"],
|
||||
# Keyword argument: validated
|
||||
validated=state_counts["validated"],
|
||||
# Keyword argument: rejected
|
||||
rejected=state_counts["rejected"],
|
||||
# Keyword argument: total
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
# Define function get_team_activity
|
||||
def get_team_activity(db: Session) -> list[TeamActivity]:
|
||||
"""Return activity summary for Red and Blue teams."""
|
||||
# Red Team: completed = tests past red_executing; pending = draft + red_executing
|
||||
red_completed = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state.in_([
|
||||
TestState.blue_evaluating,
|
||||
TestState.in_review,
|
||||
TestState.validated,
|
||||
TestState.rejected,
|
||||
]))
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Assign red_pending = (
|
||||
red_pending = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state.in_([TestState.draft, TestState.red_executing]))
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Blue Team: completed = tests past blue_evaluating; pending = blue_evaluating
|
||||
blue_completed = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state.in_([
|
||||
TestState.in_review,
|
||||
TestState.validated,
|
||||
TestState.rejected,
|
||||
]))
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Assign blue_pending = (
|
||||
blue_pending = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state == TestState.blue_evaluating)
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Return [
|
||||
return [
|
||||
TeamActivity(
|
||||
# Keyword argument: team
|
||||
team="Red Team",
|
||||
# Keyword argument: tests_completed
|
||||
tests_completed=red_completed,
|
||||
# Keyword argument: tests_pending
|
||||
tests_pending=red_pending,
|
||||
),
|
||||
TeamActivity(
|
||||
# Keyword argument: team
|
||||
team="Blue Team",
|
||||
# Keyword argument: tests_completed
|
||||
tests_completed=blue_completed,
|
||||
# Keyword argument: tests_pending
|
||||
tests_pending=blue_pending,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# Define function get_validation_rate
|
||||
def get_validation_rate(db: Session) -> list[ValidationRate]:
|
||||
"""Return approval and rejection rates for Red Lead and Blue Lead."""
|
||||
# Red Lead validations
|
||||
red_approved = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.red_validation_status == "approved")
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
# Assign red_rejected = (
|
||||
red_rejected = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.red_validation_status == "rejected")
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
# Assign red_total = red_approved + red_rejected
|
||||
red_total = red_approved + red_rejected
|
||||
# Assign red_rate = round(red_approved / red_total * 100, 1) if red_total > 0 else 0.0
|
||||
red_rate = round(red_approved / red_total * 100, 1) if red_total > 0 else 0.0
|
||||
|
||||
# Blue Lead validations
|
||||
blue_approved = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.blue_validation_status == "approved")
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
# Assign blue_rejected = (
|
||||
blue_rejected = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.blue_validation_status == "rejected")
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
# Assign blue_total = blue_approved + blue_rejected
|
||||
blue_total = blue_approved + blue_rejected
|
||||
# Assign blue_rate = round(blue_approved / blue_total * 100, 1) if blue_total > 0 else 0.0
|
||||
blue_rate = round(blue_approved / blue_total * 100, 1) if blue_total > 0 else 0.0
|
||||
|
||||
# Return [
|
||||
return [
|
||||
ValidationRate(
|
||||
# Keyword argument: role
|
||||
role="red_lead",
|
||||
# Keyword argument: total_reviewed
|
||||
total_reviewed=red_total,
|
||||
# Keyword argument: approved
|
||||
approved=red_approved,
|
||||
# Keyword argument: rejected
|
||||
rejected=red_rejected,
|
||||
# Keyword argument: approval_rate
|
||||
approval_rate=red_rate,
|
||||
),
|
||||
ValidationRate(
|
||||
# Keyword argument: role
|
||||
role="blue_lead",
|
||||
# Keyword argument: total_reviewed
|
||||
total_reviewed=blue_total,
|
||||
# Keyword argument: approved
|
||||
approved=blue_approved,
|
||||
# Keyword argument: rejected
|
||||
rejected=blue_rejected,
|
||||
# Keyword argument: approval_rate
|
||||
approval_rate=blue_rate,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# Define function get_recent_tests
|
||||
def get_recent_tests(db: Session, *, limit: int = 10) -> list[RecentTestItem]:
|
||||
"""Return the most recently created tests."""
|
||||
from sqlalchemy import nullslast
|
||||
tests = (
|
||||
db.query(Test)
|
||||
# Chain .options() call
|
||||
.options(joinedload(Test.technique))
|
||||
.order_by(nullslast(Test.created_at.desc()))
|
||||
.limit(limit)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Return [
|
||||
return [
|
||||
RecentTestItem(
|
||||
# Keyword argument: id
|
||||
id=str(t.id),
|
||||
# Keyword argument: name
|
||||
name=t.name,
|
||||
# Keyword argument: state
|
||||
state=t.state.value,
|
||||
# Keyword argument: technique_mitre_id
|
||||
technique_mitre_id=t.technique.mitre_id if t.technique else None,
|
||||
# Keyword argument: technique_name
|
||||
technique_name=t.technique.name if t.technique else None,
|
||||
# Keyword argument: created_at
|
||||
created_at=t.created_at,
|
||||
)
|
||||
for t in tests
|
||||
|
||||
@@ -6,123 +6,197 @@ ATT&CK collection, and upserts attack-pattern objects into the local
|
||||
when the TAXII server is unreachable.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import requests
|
||||
import requests as _requests
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import Server as TaxiiServer from taxii2client.v20
|
||||
from taxii2client.v20 import Server as TaxiiServer
|
||||
|
||||
from app.models.technique import Technique
|
||||
# Import TechniqueStatus from app.models.enums
|
||||
from app.models.enums import TechniqueStatus
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign TAXII_SERVER_URL = "https://cti-taxii.mitre.org/taxii/"
|
||||
TAXII_SERVER_URL = "https://cti-taxii.mitre.org/taxii/"
|
||||
# Assign MITRE_SOURCE_NAME = "mitre-attack"
|
||||
MITRE_SOURCE_NAME = "mitre-attack"
|
||||
# Assign GITHUB_ENTERPRISE_URL = (
|
||||
GITHUB_ENTERPRISE_URL = (
|
||||
# Literal argument value
|
||||
"https://raw.githubusercontent.com/mitre/cti/master/"
|
||||
# Literal argument value
|
||||
"enterprise-attack/enterprise-attack.json"
|
||||
)
|
||||
|
||||
|
||||
# Define function _extract_mitre_id
|
||||
def _extract_mitre_id(external_references: list) -> str | None:
|
||||
"""Return the MITRE ATT&CK ID (e.g. ``T1059.001``) from external_references."""
|
||||
# Check: not external_references
|
||||
if not external_references:
|
||||
# Return None
|
||||
return None
|
||||
# Iterate over external_references
|
||||
for ref in external_references:
|
||||
# Check: ref.get("source_name") == MITRE_SOURCE_NAME
|
||||
if ref.get("source_name") == MITRE_SOURCE_NAME:
|
||||
# Return ref.get("external_id")
|
||||
return ref.get("external_id")
|
||||
# Return None
|
||||
return None
|
||||
|
||||
|
||||
# Define function _extract_tactics
|
||||
def _extract_tactics(kill_chain_phases: list) -> str | None:
|
||||
"""Return a comma-separated string of tactic phase names."""
|
||||
# Check: not kill_chain_phases
|
||||
if not kill_chain_phases:
|
||||
# Return None
|
||||
return None
|
||||
# Assign tactics = [
|
||||
tactics = [
|
||||
phase.get("phase_name")
|
||||
for phase in kill_chain_phases
|
||||
if phase.get("kill_chain_name") == "mitre-attack"
|
||||
]
|
||||
# Return ", ".join(tactics) if tactics else None
|
||||
return ", ".join(tactics) if tactics else None
|
||||
|
||||
|
||||
# Define function _extract_platforms
|
||||
def _extract_platforms(stix_object: dict) -> list:
|
||||
"""Return the list of platforms from the STIX object."""
|
||||
# Return stix_object.get("x_mitre_platforms", [])
|
||||
return stix_object.get("x_mitre_platforms", [])
|
||||
|
||||
|
||||
# Define function _extract_version
|
||||
def _extract_version(stix_object: dict) -> str | None:
|
||||
"""Return the MITRE ATT&CK version string."""
|
||||
# Return stix_object.get("x_mitre_version")
|
||||
return stix_object.get("x_mitre_version")
|
||||
|
||||
|
||||
# Define function _extract_last_modified
|
||||
def _extract_last_modified(stix_object: dict) -> datetime | None:
|
||||
"""Return the ``modified`` timestamp as a datetime, or None."""
|
||||
# Assign modified = stix_object.get("modified")
|
||||
modified = stix_object.get("modified")
|
||||
# Check: modified is None
|
||||
if modified is None:
|
||||
# Return None
|
||||
return None
|
||||
# Check: isinstance(modified, datetime)
|
||||
if isinstance(modified, datetime):
|
||||
# Return modified
|
||||
return modified
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Return datetime.fromisoformat(modified.replace("Z", "+00:00"))
|
||||
return datetime.fromisoformat(modified.replace("Z", "+00:00"))
|
||||
# Handle (ValueError, AttributeError)
|
||||
except (ValueError, AttributeError):
|
||||
# Return None
|
||||
return None
|
||||
|
||||
|
||||
# Define function _fetch_attack_patterns_taxii
|
||||
def _fetch_attack_patterns_taxii() -> list[dict]:
|
||||
"""Connect to the MITRE TAXII server and return all attack-pattern objects."""
|
||||
# Log info: "Connecting to MITRE TAXII server at %s", TAXII_SE
|
||||
logger.info("Connecting to MITRE TAXII server at %s", TAXII_SERVER_URL)
|
||||
# Assign server = TaxiiServer(TAXII_SERVER_URL)
|
||||
server = TaxiiServer(TAXII_SERVER_URL)
|
||||
|
||||
# Assign api_root = server.api_roots[0]
|
||||
api_root = server.api_roots[0]
|
||||
# Assign collection = api_root.collections[0] # Enterprise ATT&CK
|
||||
collection = api_root.collections[0] # Enterprise ATT&CK
|
||||
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Fetching objects from collection '%s' (id=%s)",
|
||||
collection.title,
|
||||
collection.id,
|
||||
)
|
||||
|
||||
# Assign bundle = collection.get_objects()
|
||||
bundle = collection.get_objects()
|
||||
# Assign objects = bundle.get("objects", [])
|
||||
objects = bundle.get("objects", [])
|
||||
|
||||
# Assign attack_patterns = [
|
||||
attack_patterns = [
|
||||
obj for obj in objects if obj.get("type") == "attack-pattern"
|
||||
]
|
||||
# Log info: "Retrieved %d attack-pattern objects via TAXII", l
|
||||
logger.info("Retrieved %d attack-pattern objects via TAXII", len(attack_patterns))
|
||||
# Return attack_patterns
|
||||
return attack_patterns
|
||||
|
||||
|
||||
# Define function _fetch_attack_patterns_github
|
||||
def _fetch_attack_patterns_github() -> list[dict]:
|
||||
"""Fallback: fetch Enterprise ATT&CK bundle from the MITRE CTI GitHub repo."""
|
||||
# Log info: "Fetching Enterprise ATT&CK bundle from GitHub (%s
|
||||
logger.info("Fetching Enterprise ATT&CK bundle from GitHub (%s)", GITHUB_ENTERPRISE_URL)
|
||||
# Assign resp = _requests.get(GITHUB_ENTERPRISE_URL, timeout=120)
|
||||
resp = _requests.get(GITHUB_ENTERPRISE_URL, timeout=120)
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Assign bundle = resp.json()
|
||||
bundle = resp.json()
|
||||
# Assign objects = bundle.get("objects", [])
|
||||
objects = bundle.get("objects", [])
|
||||
|
||||
# Assign attack_patterns = [
|
||||
attack_patterns = [
|
||||
obj for obj in objects if obj.get("type") == "attack-pattern"
|
||||
]
|
||||
# Log info: "Retrieved %d attack-pattern objects via GitHub",
|
||||
logger.info("Retrieved %d attack-pattern objects via GitHub", len(attack_patterns))
|
||||
# Return attack_patterns
|
||||
return attack_patterns
|
||||
|
||||
|
||||
# Define function _fetch_attack_patterns
|
||||
def _fetch_attack_patterns() -> list[dict]:
|
||||
"""Return all attack-pattern objects, trying TAXII first then GitHub."""
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Return _fetch_attack_patterns_taxii()
|
||||
return _fetch_attack_patterns_taxii()
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log warning:
|
||||
logger.warning(
|
||||
# Literal argument value
|
||||
"TAXII server unavailable (%s), falling back to GitHub mirror",
|
||||
exc,
|
||||
)
|
||||
# Return _fetch_attack_patterns_github()
|
||||
return _fetch_attack_patterns_github()
|
||||
|
||||
|
||||
# Define function sync_mitre
|
||||
def sync_mitre(db: Session) -> dict:
|
||||
"""Synchronize MITRE ATT&CK techniques into the local database.
|
||||
|
||||
@@ -131,11 +205,12 @@ def sync_mitre(db: Session) -> dict:
|
||||
db : Session
|
||||
Active SQLAlchemy database session.
|
||||
|
||||
Returns
|
||||
Returns:
|
||||
-------
|
||||
dict
|
||||
Summary with keys ``created``, ``updated``, ``unchanged``, ``skipped``.
|
||||
"""
|
||||
# Assign attack_patterns = _fetch_attack_patterns()
|
||||
attack_patterns = _fetch_attack_patterns()
|
||||
|
||||
# Pre-load existing techniques keyed by mitre_id for fast lookup
|
||||
@@ -143,90 +218,149 @@ def sync_mitre(db: Session) -> dict:
|
||||
t.mitre_id: t for t in db.query(Technique).all()
|
||||
}
|
||||
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign updated = 0
|
||||
updated = 0
|
||||
# Assign unchanged = 0
|
||||
unchanged = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
|
||||
# Iterate over attack_patterns
|
||||
for obj in attack_patterns:
|
||||
# ------------------------------------------------------------------
|
||||
# Skip revoked / deprecated objects
|
||||
# ------------------------------------------------------------------
|
||||
if obj.get("revoked", False) or obj.get("x_mitre_deprecated", False):
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign mitre_id = _extract_mitre_id(obj.get("external_references", []))
|
||||
mitre_id = _extract_mitre_id(obj.get("external_references", []))
|
||||
# Check: not mitre_id
|
||||
if not mitre_id:
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign name = obj.get("name", "")
|
||||
name = obj.get("name", "")
|
||||
# Assign description = obj.get("description", "")
|
||||
description = obj.get("description", "")
|
||||
# Assign tactic = _extract_tactics(obj.get("kill_chain_phases", []))
|
||||
tactic = _extract_tactics(obj.get("kill_chain_phases", []))
|
||||
# Assign platforms = _extract_platforms(obj)
|
||||
platforms = _extract_platforms(obj)
|
||||
# Assign version = _extract_version(obj)
|
||||
version = _extract_version(obj)
|
||||
# Assign last_modified = _extract_last_modified(obj)
|
||||
last_modified = _extract_last_modified(obj)
|
||||
# Assign is_subtechnique = "." in mitre_id
|
||||
is_subtechnique = "." in mitre_id
|
||||
# Assign parent_mitre_id = mitre_id.split(".")[0] if is_subtechnique else None
|
||||
parent_mitre_id = mitre_id.split(".")[0] if is_subtechnique else None
|
||||
|
||||
# Assign existing = existing_techniques.get(mitre_id)
|
||||
existing = existing_techniques.get(mitre_id)
|
||||
|
||||
# Check: existing is None
|
||||
if existing is None:
|
||||
# ---- Create new technique ----
|
||||
technique = Technique(
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=mitre_id,
|
||||
# Keyword argument: name
|
||||
name=name,
|
||||
# Keyword argument: description
|
||||
description=description,
|
||||
# Keyword argument: tactic
|
||||
tactic=tactic,
|
||||
# Keyword argument: platforms
|
||||
platforms=platforms,
|
||||
# Keyword argument: mitre_version
|
||||
mitre_version=version,
|
||||
# Keyword argument: mitre_last_modified
|
||||
mitre_last_modified=last_modified,
|
||||
# Keyword argument: is_subtechnique
|
||||
is_subtechnique=is_subtechnique,
|
||||
# Keyword argument: parent_mitre_id
|
||||
parent_mitre_id=parent_mitre_id,
|
||||
# Keyword argument: status_global
|
||||
status_global=TechniqueStatus.not_evaluated,
|
||||
# Keyword argument: review_required
|
||||
review_required=False,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(technique)
|
||||
# Assign existing_techniques[mitre_id] = technique
|
||||
existing_techniques[mitre_id] = technique
|
||||
# Assign created = 1
|
||||
created += 1
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# ---- Update if name or description changed ----
|
||||
changes = False
|
||||
|
||||
# Check: existing.name != name
|
||||
if existing.name != name:
|
||||
# Assign existing.name = name
|
||||
existing.name = name
|
||||
# Assign changes = True
|
||||
changes = True
|
||||
|
||||
# Check: (existing.description or "") != (description or "")
|
||||
if (existing.description or "") != (description or ""):
|
||||
# Assign existing.description = description
|
||||
existing.description = description
|
||||
# Assign changes = True
|
||||
changes = True
|
||||
|
||||
# Always keep metadata up-to-date (does not trigger review)
|
||||
existing.tactic = tactic
|
||||
# Assign existing.platforms = platforms
|
||||
existing.platforms = platforms
|
||||
# Assign existing.mitre_version = version
|
||||
existing.mitre_version = version
|
||||
# Assign existing.mitre_last_modified = last_modified
|
||||
existing.mitre_last_modified = last_modified
|
||||
# Assign existing.is_subtechnique = is_subtechnique
|
||||
existing.is_subtechnique = is_subtechnique
|
||||
# Assign existing.parent_mitre_id = parent_mitre_id
|
||||
existing.parent_mitre_id = parent_mitre_id
|
||||
|
||||
# Check: changes
|
||||
if changes:
|
||||
# Assign existing.review_required = True
|
||||
existing.review_required = True
|
||||
# Assign updated = 1
|
||||
updated += 1
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign unchanged = 1
|
||||
unchanged += 1
|
||||
|
||||
# Single commit for the whole batch
|
||||
db.commit()
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"created": created,
|
||||
# Literal argument value
|
||||
"updated": updated,
|
||||
# Literal argument value
|
||||
"unchanged": unchanged,
|
||||
# Literal argument value
|
||||
"skipped": skipped,
|
||||
}
|
||||
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"MITRE sync complete — created=%d, updated=%d, unchanged=%d, skipped=%d",
|
||||
created,
|
||||
updated,
|
||||
@@ -237,12 +371,19 @@ def sync_mitre(db: Session) -> dict:
|
||||
# Audit log (system action → user_id=None)
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=None,
|
||||
# Keyword argument: action
|
||||
action="mitre_sync",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="technique",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=None,
|
||||
# Keyword argument: details
|
||||
details=summary,
|
||||
)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
@@ -7,16 +7,29 @@ Functions in this module stage changes via ``db.add()`` / ``db.flush()``
|
||||
but do **not** commit. The caller is responsible for committing.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import datetime, timedelta from datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.notification import Notification
|
||||
from app.models.user import User
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
|
||||
# Import Notification from app.models.notification
|
||||
from app.models.notification import Notification
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core CRUD
|
||||
@@ -24,132 +37,209 @@ from app.models.user import User
|
||||
|
||||
|
||||
def list_notifications(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: user_id
|
||||
user_id: uuid.UUID,
|
||||
*,
|
||||
# Entry: offset
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 20,
|
||||
) -> list[Notification]:
|
||||
"""Return paginated notifications for a user, newest first."""
|
||||
# Return (
|
||||
return (
|
||||
db.query(Notification)
|
||||
# Chain .filter() call
|
||||
.filter(Notification.user_id == user_id)
|
||||
# Chain .order_by() call
|
||||
.order_by(Notification.created_at.desc())
|
||||
# Chain .offset() call
|
||||
.offset(offset)
|
||||
# Chain .limit() call
|
||||
.limit(limit)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
# Define function get_notification_or_raise
|
||||
def get_notification_or_raise(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: notification_id
|
||||
notification_id: uuid.UUID,
|
||||
# Entry: user_id
|
||||
user_id: uuid.UUID,
|
||||
) -> Notification:
|
||||
"""Fetch a notification by ID and user, or raise EntityNotFoundError."""
|
||||
# Assign notif = (
|
||||
notif = (
|
||||
db.query(Notification)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
Notification.id == notification_id,
|
||||
Notification.user_id == user_id,
|
||||
)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Check: notif is None
|
||||
if notif is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Notification", str(notification_id))
|
||||
# Return notif
|
||||
return notif
|
||||
|
||||
|
||||
# Define function notify_role
|
||||
def notify_role(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: role
|
||||
role: str,
|
||||
# Entry: type
|
||||
type: str,
|
||||
# Entry: title
|
||||
title: str,
|
||||
# Entry: message
|
||||
message: str,
|
||||
# Entry: entity_type
|
||||
entity_type: str,
|
||||
# Entry: entity_id
|
||||
entity_id: uuid.UUID,
|
||||
) -> None:
|
||||
"""Send notifications to all active users with a given role."""
|
||||
# Assign users = (
|
||||
users = (
|
||||
db.query(User)
|
||||
# Chain .filter() call
|
||||
.filter(User.role == role, User.is_active == True) # noqa: E712
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Iterate over users
|
||||
for user in users:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: type
|
||||
type=type,
|
||||
# Keyword argument: title
|
||||
title=title,
|
||||
# Keyword argument: message
|
||||
message=message,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=entity_id,
|
||||
)
|
||||
|
||||
|
||||
# Define function create_notification
|
||||
def create_notification(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: user_id
|
||||
user_id: uuid.UUID,
|
||||
# Entry: type
|
||||
type: str,
|
||||
# Entry: title
|
||||
title: str,
|
||||
# Entry: message
|
||||
message: str | None = None,
|
||||
# Entry: entity_type
|
||||
entity_type: str | None = None,
|
||||
# Entry: entity_id
|
||||
entity_id: uuid.UUID | None = None,
|
||||
) -> Notification:
|
||||
"""Create a single notification for a user."""
|
||||
# Assign notif = Notification(
|
||||
notif = Notification(
|
||||
# Keyword argument: user_id
|
||||
user_id=user_id,
|
||||
# Keyword argument: type
|
||||
type=type,
|
||||
# Keyword argument: title
|
||||
title=title,
|
||||
# Keyword argument: message
|
||||
message=message,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=entity_id,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(notif)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return notif
|
||||
return notif
|
||||
|
||||
|
||||
# Define function mark_as_read
|
||||
def mark_as_read(
|
||||
# Entry: db
|
||||
db: Session, notification_id: uuid.UUID, user_id: uuid.UUID
|
||||
) -> Notification:
|
||||
"""Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found."""
|
||||
# Assign notif = get_notification_or_raise(db, notification_id, user_id)
|
||||
notif = get_notification_or_raise(db, notification_id, user_id)
|
||||
# Assign notif.read = True
|
||||
notif.read = True
|
||||
# Return notif
|
||||
return notif
|
||||
|
||||
|
||||
# Define function mark_all_as_read
|
||||
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
|
||||
"""Mark all unread notifications for a user as read. Returns count updated."""
|
||||
# Assign count = (
|
||||
count = (
|
||||
db.query(Notification)
|
||||
# Chain .filter() call
|
||||
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
|
||||
# Chain .update() call
|
||||
.update({"read": True})
|
||||
)
|
||||
# Return count
|
||||
return count
|
||||
|
||||
|
||||
# Define function get_unread_count
|
||||
def get_unread_count(db: Session, user_id: uuid.UUID) -> int:
|
||||
"""Return the number of unread notifications for a user."""
|
||||
# Return (
|
||||
return (
|
||||
db.query(func.count(Notification.id))
|
||||
# Chain .filter() call
|
||||
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
|
||||
# Define function cleanup_old_notifications
|
||||
def cleanup_old_notifications(db: Session, days: int = 90) -> int:
|
||||
"""Delete read notifications older than *days*. Returns count deleted."""
|
||||
# Assign cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
# Assign count = (
|
||||
count = (
|
||||
db.query(Notification)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
Notification.read == True, # noqa: E712
|
||||
Notification.created_at < cutoff,
|
||||
)
|
||||
# Chain .delete() call
|
||||
.delete()
|
||||
)
|
||||
# Return count
|
||||
return count
|
||||
|
||||
|
||||
@@ -204,71 +294,118 @@ def notify_test_state_change(db: Session, test, new_state: str) -> None:
|
||||
- rejected -> notify creator
|
||||
- validated -> notify creator
|
||||
"""
|
||||
# Assign test_name = test.name
|
||||
test_name = test.name
|
||||
# Assign test_id = test.id
|
||||
test_id = test.id
|
||||
# Assign creator_id = test.created_by
|
||||
creator_id = test.created_by
|
||||
|
||||
# Check: new_state == "red_executing" and creator_id
|
||||
if new_state == "red_executing" and creator_id:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=creator_id,
|
||||
# Keyword argument: type
|
||||
type="test_state_changed",
|
||||
# Keyword argument: title
|
||||
title="Test execution started",
|
||||
# Keyword argument: message
|
||||
message=f'Your test "{test_name}" has moved to execution phase.',
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test_id,
|
||||
)
|
||||
|
||||
# Alternative: new_state == "blue_evaluating"
|
||||
elif new_state == "blue_evaluating":
|
||||
# Notify all blue_tech users
|
||||
blue_users = db.query(User).filter(User.role == "blue_tech", User.is_active == True).all() # noqa: E712
|
||||
# Iterate over blue_users
|
||||
for user in blue_users:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: type
|
||||
type="test_assigned",
|
||||
# Keyword argument: title
|
||||
title="New test ready for blue evaluation",
|
||||
# Keyword argument: message
|
||||
message=f'Test "{test_name}" needs blue team evaluation.',
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test_id,
|
||||
)
|
||||
|
||||
# Alternative: new_state == "in_review"
|
||||
elif new_state == "in_review":
|
||||
# Notify red_lead and blue_lead users
|
||||
managers = (
|
||||
db.query(User)
|
||||
# Chain .filter() call
|
||||
.filter(User.role.in_(["red_lead", "blue_lead"]), User.is_active == True) # noqa: E712
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Iterate over managers
|
||||
for user in managers:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: type
|
||||
type="validation_needed",
|
||||
# Keyword argument: title
|
||||
title="Test ready for validation",
|
||||
# Keyword argument: message
|
||||
message=f'Test "{test_name}" is awaiting your review.',
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test_id,
|
||||
)
|
||||
|
||||
# Alternative: new_state == "rejected" and creator_id
|
||||
elif new_state == "rejected" and creator_id:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=creator_id,
|
||||
# Keyword argument: type
|
||||
type="test_rejected",
|
||||
# Keyword argument: title
|
||||
title="Test rejected",
|
||||
# Keyword argument: message
|
||||
message=f'Your test "{test_name}" has been rejected. Please review and resubmit.',
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test_id,
|
||||
)
|
||||
|
||||
# Alternative: new_state == "validated" and creator_id
|
||||
elif new_state == "validated" and creator_id:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=creator_id,
|
||||
# Keyword argument: type
|
||||
type="test_validated",
|
||||
# Keyword argument: title
|
||||
title="Test validated",
|
||||
# Keyword argument: message
|
||||
message=f'Your test "{test_name}" has been validated successfully.',
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test_id,
|
||||
)
|
||||
|
||||
@@ -3,26 +3,45 @@
|
||||
Calculates security operations KPIs from test data and audit logs.
|
||||
"""
|
||||
|
||||
# Import datetime, timedelta from datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func, case, and_, or_, extract
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.test import Test
|
||||
from app.models.technique import Technique
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
# Import AuditLog from app.models.audit
|
||||
from app.models.audit import AuditLog
|
||||
from app.models.enums import TestState, TestResult
|
||||
|
||||
# Import TestResult, TestState from app.models.enums
|
||||
from app.models.enums import TestResult, TestState
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import TestDetectionResult from app.models.test_detection_result
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
|
||||
|
||||
# Define function _safe_stats
|
||||
def _safe_stats(values: list[float]) -> dict:
|
||||
"""Compute mean, median, min, max from a list of floats (in hours).
|
||||
For sub-hour averages, mean_hours is stored as minutes to avoid
|
||||
rounding to 0.0 which is falsy in JavaScript."""
|
||||
if not values:
|
||||
# Return None
|
||||
return None
|
||||
# Assign sorted_vals = sorted(values)
|
||||
sorted_vals = sorted(values)
|
||||
# Assign n = len(sorted_vals)
|
||||
n = len(sorted_vals)
|
||||
mean = sum(sorted_vals) / n
|
||||
# Use minutes for sub-hour values to avoid JS falsy 0.0
|
||||
@@ -31,8 +50,11 @@ def _safe_stats(values: list[float]) -> dict:
|
||||
"mean_hours": mean_display,
|
||||
"unit": "min" if mean < 1 else "hrs",
|
||||
"median_hours": round(sorted_vals[n // 2], 1),
|
||||
# Literal argument value
|
||||
"min_hours": round(sorted_vals[0], 1),
|
||||
# Literal argument value
|
||||
"max_hours": round(sorted_vals[-1], 1),
|
||||
# Literal argument value
|
||||
"sample_size": n,
|
||||
}
|
||||
|
||||
@@ -59,6 +81,7 @@ def calculate_mttd(db: Session) -> Optional[dict]:
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign detection_times = []
|
||||
detection_times = []
|
||||
for t in tests:
|
||||
gross_secs = (t.blue_started_at - t.red_started_at).total_seconds()
|
||||
@@ -66,6 +89,7 @@ def calculate_mttd(db: Session) -> Optional[dict]:
|
||||
if net_secs > 0:
|
||||
detection_times.append(net_secs / 3600)
|
||||
|
||||
# Return _safe_stats(detection_times)
|
||||
return _safe_stats(detection_times)
|
||||
|
||||
|
||||
@@ -83,14 +107,17 @@ def calculate_mttr(db: Session) -> Optional[dict]:
|
||||
"""
|
||||
tests = (
|
||||
db.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
Test.state == TestState.validated,
|
||||
Test.red_started_at.isnot(None),
|
||||
Test.blue_validated_at.isnot(None),
|
||||
)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign response_times = []
|
||||
response_times = []
|
||||
for t in tests:
|
||||
gross_secs = (t.blue_validated_at - t.red_started_at).total_seconds()
|
||||
@@ -99,6 +126,7 @@ def calculate_mttr(db: Session) -> Optional[dict]:
|
||||
if net_secs > 0:
|
||||
response_times.append(net_secs / 3600)
|
||||
|
||||
# Return _safe_stats(response_times)
|
||||
return _safe_stats(response_times)
|
||||
|
||||
|
||||
@@ -106,34 +134,63 @@ def calculate_mttr(db: Session) -> Optional[dict]:
|
||||
|
||||
|
||||
def calculate_detection_efficacy(db: Session) -> dict:
|
||||
"""Calculate detection efficacy: detected / total validated tests."""
|
||||
"""Calculate detection efficacy: detected / total validated tests.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``percentage``, ``detected``, ``partially``,
|
||||
``not_detected``, and ``total``.
|
||||
"""
|
||||
# Assign validated_tests = (
|
||||
validated_tests = (
|
||||
db.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(Test.state == TestState.validated)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign total = len(validated_tests)
|
||||
total = len(validated_tests)
|
||||
# Check: total == 0
|
||||
if total == 0:
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"percentage": 0,
|
||||
# Literal argument value
|
||||
"detected": 0,
|
||||
# Literal argument value
|
||||
"partially": 0,
|
||||
# Literal argument value
|
||||
"not_detected": 0,
|
||||
# Literal argument value
|
||||
"total": 0,
|
||||
}
|
||||
|
||||
# Assign detected = len([t for t in validated_tests if t.detection_result == TestResult...
|
||||
detected = len([t for t in validated_tests if t.detection_result == TestResult.detected])
|
||||
# Assign partially = len([t for t in validated_tests if t.detection_result == TestResult...
|
||||
partially = len([t for t in validated_tests if t.detection_result == TestResult.partially_detected])
|
||||
# Assign not_detected = len([t for t in validated_tests if t.detection_result == TestResult...
|
||||
not_detected = len([t for t in validated_tests if t.detection_result == TestResult.not_detected])
|
||||
|
||||
# Assign percentage = round((detected / total) * 100, 1) if total > 0 else 0
|
||||
percentage = round((detected / total) * 100, 1) if total > 0 else 0
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"percentage": percentage,
|
||||
# Literal argument value
|
||||
"detected": detected,
|
||||
# Literal argument value
|
||||
"partially": partially,
|
||||
# Literal argument value
|
||||
"not_detected": not_detected,
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
}
|
||||
|
||||
@@ -142,25 +199,45 @@ def calculate_detection_efficacy(db: Session) -> dict:
|
||||
|
||||
|
||||
def calculate_alert_fidelity(db: Session) -> dict:
|
||||
"""Calculate alert fidelity: ratio of triggered detection rules."""
|
||||
"""Calculate alert fidelity: ratio of triggered detection rules.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``percentage``, ``triggered``, ``not_triggered``,
|
||||
and ``total_evaluated``.
|
||||
"""
|
||||
# Assign total_evaluated = (
|
||||
total_evaluated = (
|
||||
db.query(func.count(TestDetectionResult.id))
|
||||
# Chain .filter() call
|
||||
.filter(TestDetectionResult.triggered.isnot(None))
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Assign triggered = (
|
||||
triggered = (
|
||||
db.query(func.count(TestDetectionResult.id))
|
||||
# Chain .filter() call
|
||||
.filter(TestDetectionResult.triggered == True)
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Assign not_triggered = total_evaluated - triggered
|
||||
not_triggered = total_evaluated - triggered
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"percentage": round((triggered / total_evaluated) * 100, 1) if total_evaluated > 0 else 0,
|
||||
# Literal argument value
|
||||
"triggered": triggered,
|
||||
# Literal argument value
|
||||
"not_triggered": not_triggered,
|
||||
# Literal argument value
|
||||
"total_evaluated": total_evaluated,
|
||||
}
|
||||
|
||||
@@ -169,46 +246,78 @@ def calculate_alert_fidelity(db: Session) -> dict:
|
||||
|
||||
|
||||
def calculate_coverage_velocity(db: Session) -> dict:
|
||||
"""Calculate techniques validated per week."""
|
||||
"""Calculate techniques validated per week.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``techniques_per_week`` (float average over the last
|
||||
12 weeks) and ``trend`` (``"improving"``, ``"stable"``, or
|
||||
``"declining"``).
|
||||
"""
|
||||
# Count techniques that changed to validated/partial in the last 12 weeks
|
||||
twelve_weeks_ago = datetime.utcnow() - timedelta(weeks=12)
|
||||
|
||||
# Assign weekly_counts = (
|
||||
weekly_counts = (
|
||||
db.query(
|
||||
func.date_trunc("week", Technique.last_review_date).label("week"),
|
||||
func.count(Technique.id).label("count"),
|
||||
)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
Technique.last_review_date >= twelve_weeks_ago,
|
||||
Technique.last_review_date.isnot(None),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(func.date_trunc("week", Technique.last_review_date))
|
||||
# Chain .order_by() call
|
||||
.order_by("week")
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Check: weekly_counts
|
||||
if weekly_counts:
|
||||
# Assign counts = [row.count for row in weekly_counts]
|
||||
counts = [row.count for row in weekly_counts]
|
||||
# Assign avg_per_week = round(sum(counts) / len(counts), 1)
|
||||
avg_per_week = round(sum(counts) / len(counts), 1)
|
||||
# Trend: compare last 4 weeks vs previous 4 weeks
|
||||
recent = counts[-4:] if len(counts) >= 4 else counts
|
||||
# Assign earlier = counts[-8:-4] if len(counts) >= 8 else counts[:len(counts) // 2] if...
|
||||
earlier = counts[-8:-4] if len(counts) >= 8 else counts[:len(counts) // 2] if counts else []
|
||||
|
||||
# Assign recent_avg = sum(recent) / len(recent) if recent else 0
|
||||
recent_avg = sum(recent) / len(recent) if recent else 0
|
||||
# Assign earlier_avg = sum(earlier) / len(earlier) if earlier else 0
|
||||
earlier_avg = sum(earlier) / len(earlier) if earlier else 0
|
||||
|
||||
# Check: recent_avg > earlier_avg * 1.1
|
||||
if recent_avg > earlier_avg * 1.1:
|
||||
# Assign trend = "improving"
|
||||
trend = "improving"
|
||||
# Alternative: recent_avg < earlier_avg * 0.9
|
||||
elif recent_avg < earlier_avg * 0.9:
|
||||
# Assign trend = "declining"
|
||||
trend = "declining"
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign trend = "stable"
|
||||
trend = "stable"
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign avg_per_week = 0
|
||||
avg_per_week = 0
|
||||
# Assign trend = "stable"
|
||||
trend = "stable"
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"techniques_per_week": avg_per_week,
|
||||
# Literal argument value
|
||||
"trend": trend,
|
||||
}
|
||||
|
||||
@@ -264,6 +373,7 @@ def calculate_validation_throughput(db: Session) -> dict:
|
||||
else:
|
||||
trend = "stable"
|
||||
|
||||
# Return {
|
||||
return {
|
||||
"tests_per_week": conversion_rate, # reuse key for API compat
|
||||
"conversion_rate": conversion_rate,
|
||||
@@ -278,51 +388,84 @@ def calculate_validation_throughput(db: Session) -> dict:
|
||||
|
||||
|
||||
def calculate_rejection_rate(db: Session) -> dict:
|
||||
"""Calculate rejection rate, broken down by red_lead and blue_lead."""
|
||||
"""Calculate rejection rate, broken down by red_lead and blue_lead.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``percentage`` (overall rejection rate), ``by_red_lead``
|
||||
(red-lead rejection percentage), and ``by_blue_lead``
|
||||
(blue-lead rejection percentage).
|
||||
"""
|
||||
# Assign validated_count = (
|
||||
validated_count = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state == TestState.validated)
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Assign rejected_count = (
|
||||
rejected_count = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state == TestState.rejected)
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Assign total = validated_count + rejected_count
|
||||
total = validated_count + rejected_count
|
||||
# Assign overall_pct = round((rejected_count / total) * 100, 1) if total > 0 else 0
|
||||
overall_pct = round((rejected_count / total) * 100, 1) if total > 0 else 0
|
||||
|
||||
# By red_lead (red_validation_status == "rejected")
|
||||
red_rejected = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.red_validation_status == "rejected")
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
# Assign red_total = (
|
||||
red_total = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.red_validation_status.in_(["approved", "rejected"]))
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
# Assign red_pct = round((red_rejected / red_total) * 100, 1) if red_total > 0 else 0
|
||||
red_pct = round((red_rejected / red_total) * 100, 1) if red_total > 0 else 0
|
||||
|
||||
# By blue_lead
|
||||
blue_rejected = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.blue_validation_status == "rejected")
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
# Assign blue_total = (
|
||||
blue_total = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.blue_validation_status.in_(["approved", "rejected"]))
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
# Assign blue_pct = round((blue_rejected / blue_total) * 100, 1) if blue_total > 0 else 0
|
||||
blue_pct = round((blue_rejected / blue_total) * 100, 1) if blue_total > 0 else 0
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"percentage": overall_pct,
|
||||
# Literal argument value
|
||||
"by_red_lead": red_pct,
|
||||
# Literal argument value
|
||||
"by_blue_lead": blue_pct,
|
||||
}
|
||||
|
||||
@@ -331,14 +474,31 @@ def calculate_rejection_rate(db: Session) -> dict:
|
||||
|
||||
|
||||
def get_all_operational_metrics(db: Session) -> dict:
|
||||
"""Get all operational metrics in a single response."""
|
||||
"""Return all operational metrics combined in a single response.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``mttd``, ``mttr``, ``detection_efficacy``,
|
||||
``alert_fidelity``, ``coverage_velocity``,
|
||||
``validation_throughput``, and ``rejection_rate`` keys.
|
||||
"""
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"mttd": calculate_mttd(db),
|
||||
# Literal argument value
|
||||
"mttr": calculate_mttr(db),
|
||||
# Literal argument value
|
||||
"detection_efficacy": calculate_detection_efficacy(db),
|
||||
# Literal argument value
|
||||
"alert_fidelity": calculate_alert_fidelity(db),
|
||||
# Literal argument value
|
||||
"coverage_velocity": calculate_coverage_velocity(db),
|
||||
# Literal argument value
|
||||
"validation_throughput": calculate_validation_throughput(db),
|
||||
# Literal argument value
|
||||
"rejection_rate": calculate_rejection_rate(db),
|
||||
}
|
||||
|
||||
@@ -347,44 +507,77 @@ def get_all_operational_metrics(db: Session) -> dict:
|
||||
|
||||
|
||||
def get_operational_trend(db: Session, period: str = "90d") -> list:
|
||||
"""Get weekly trend data for operational metrics."""
|
||||
"""Return weekly trend data for operational metrics.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
period (str): Lookback period; one of ``"30d"``, ``"90d"``
|
||||
(default), or ``"1y"``.
|
||||
|
||||
Returns:
|
||||
list: Weekly data points, each a dict with ``date``,
|
||||
``detection_efficacy``, ``validated_tests``, and
|
||||
``detected_tests``.
|
||||
"""
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Check: period == "30d"
|
||||
if period == "30d":
|
||||
# Assign start = now - timedelta(days=30)
|
||||
start = now - timedelta(days=30)
|
||||
# Alternative: period == "1y"
|
||||
elif period == "1y":
|
||||
# Assign start = now - timedelta(days=365)
|
||||
start = now - timedelta(days=365)
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign start = now - timedelta(days=90)
|
||||
start = now - timedelta(days=90)
|
||||
|
||||
# Build weekly data points
|
||||
data_points = []
|
||||
# Assign current = start
|
||||
current = start
|
||||
# Loop while current < now
|
||||
while current < now:
|
||||
# Assign week_end = min(current + timedelta(days=7), now)
|
||||
week_end = min(current + timedelta(days=7), now)
|
||||
|
||||
# Detection efficacy for tests validated up to this week
|
||||
validated_up_to = (
|
||||
db.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
Test.state == TestState.validated,
|
||||
Test.red_validated_at <= week_end,
|
||||
)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign total = len(validated_up_to)
|
||||
total = len(validated_up_to)
|
||||
# Assign detected = len([t for t in validated_up_to if t.detection_result == TestResult...
|
||||
detected = len([t for t in validated_up_to if t.detection_result == TestResult.detected])
|
||||
# Assign efficacy = round((detected / total) * 100, 1) if total > 0 else 0
|
||||
efficacy = round((detected / total) * 100, 1) if total > 0 else 0
|
||||
|
||||
# Call data_points.append()
|
||||
data_points.append({
|
||||
# Literal argument value
|
||||
"date": current.strftime("%Y-%m-%d"),
|
||||
# Literal argument value
|
||||
"detection_efficacy": efficacy,
|
||||
# Literal argument value
|
||||
"validated_tests": total,
|
||||
# Literal argument value
|
||||
"detected_tests": detected,
|
||||
})
|
||||
|
||||
# Assign current = week_end
|
||||
current = week_end
|
||||
|
||||
# Return data_points
|
||||
return data_points
|
||||
|
||||
|
||||
@@ -392,20 +585,33 @@ def get_operational_trend(db: Session, period: str = "90d") -> list:
|
||||
|
||||
|
||||
def get_metrics_by_team(db: Session) -> dict:
|
||||
"""Get metrics broken down by Red vs Blue team."""
|
||||
"""Return metrics broken down by Red vs Blue team.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``red_team`` and ``blue_team`` sub-dicts, each with
|
||||
``tests_completed``, ``avg_completion_hours``, and
|
||||
``rejection_rate``.
|
||||
"""
|
||||
# Red team metrics
|
||||
red_tests_completed = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state.in_([
|
||||
TestState.blue_evaluating,
|
||||
TestState.in_review,
|
||||
TestState.validated,
|
||||
TestState.rejected,
|
||||
]))
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Assign red_avg_time = None
|
||||
red_avg_time = None
|
||||
# Assign red_times = []
|
||||
red_times = []
|
||||
# Red team avg execution time: red_started_at → blue_started_at (net of paused)
|
||||
tests_with_red = (
|
||||
@@ -416,6 +622,7 @@ def get_metrics_by_team(db: Session) -> dict:
|
||||
)
|
||||
.all()
|
||||
)
|
||||
# Iterate over tests_with_red
|
||||
for t in tests_with_red:
|
||||
gross = (t.blue_started_at - t.red_started_at).total_seconds()
|
||||
net = gross - (t.red_paused_seconds or 0)
|
||||
@@ -429,11 +636,13 @@ def get_metrics_by_team(db: Session) -> dict:
|
||||
# Blue team: count tests that reached the blue evaluation phase
|
||||
blue_tests_completed = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state.in_([
|
||||
TestState.in_review,
|
||||
TestState.validated,
|
||||
TestState.rejected,
|
||||
]))
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
@@ -441,15 +650,20 @@ def get_metrics_by_team(db: Session) -> dict:
|
||||
# Prefer blue_work_started_at (actual pick-up) → blue_validated_at.
|
||||
# Fall back to blue_started_at if blue_work_started_at is not set.
|
||||
blue_avg_time = None
|
||||
# Assign blue_times = []
|
||||
blue_times = []
|
||||
# Assign tests_with_blue = (
|
||||
tests_with_blue = (
|
||||
db.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
Test.blue_started_at.isnot(None),
|
||||
Test.blue_validated_at.isnot(None),
|
||||
)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Iterate over tests_with_blue
|
||||
for t in tests_with_blue:
|
||||
phase_start = t.blue_work_started_at or t.blue_started_at
|
||||
gross = (t.blue_validated_at - phase_start).total_seconds()
|
||||
@@ -463,15 +677,22 @@ def get_metrics_by_team(db: Session) -> dict:
|
||||
red_avg_raw = sum(red_times) / len(red_times) if red_times else None
|
||||
blue_avg_raw = sum(blue_times) / len(blue_times) if blue_times else None
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"red_team": {
|
||||
# Literal argument value
|
||||
"tests_completed": red_tests_completed,
|
||||
# Literal argument value
|
||||
"avg_completion_hours": red_avg_time,
|
||||
"avg_unit": "min" if (red_avg_raw is not None and red_avg_raw < 1) else "hrs",
|
||||
"rejection_rate": calculate_rejection_rate(db)["by_red_lead"],
|
||||
},
|
||||
# Literal argument value
|
||||
"blue_team": {
|
||||
# Literal argument value
|
||||
"tests_completed": blue_tests_completed,
|
||||
# Literal argument value
|
||||
"avg_completion_hours": blue_avg_time,
|
||||
"avg_unit": "min" if (blue_avg_raw is not None and blue_avg_raw < 1) else "hrs",
|
||||
"rejection_rate": calculate_rejection_rate(db)["by_blue_lead"],
|
||||
|
||||
@@ -1,237 +1,433 @@
|
||||
"""OSINT enrichment service — automatically discovers CVEs, advisories, and
|
||||
related intelligence for MITRE ATT&CK techniques using the NVD API.
|
||||
"""OSINT enrichment service — discovers CVEs, advisories, and threat intel for ATT&CK techniques via the NVD API.
|
||||
|
||||
Designed to run as a weekly background job. Respects NVD rate limits
|
||||
(5 requests per 30 seconds without an API key, 50/30s with a key).
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import time
|
||||
import time
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
# Import requests
|
||||
import requests
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
|
||||
# Import OsintItem from app.models.osint_item
|
||||
from app.models.osint_item import OsintItem
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0"
|
||||
NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0"
|
||||
# Assign NVD_RATE_LIMIT_BATCH = 5
|
||||
NVD_RATE_LIMIT_BATCH = 5
|
||||
# Assign NVD_RATE_LIMIT_WAIT = 31 # seconds to wait after each batch
|
||||
NVD_RATE_LIMIT_WAIT = 31 # seconds to wait after each batch
|
||||
|
||||
|
||||
# Define function enrich_technique_with_cves
|
||||
def enrich_technique_with_cves(db: Session, technique: Technique) -> int:
|
||||
"""Search for CVEs related to a technique via the NVD API.
|
||||
|
||||
Uses the technique name as a keyword search. Deduplicates against
|
||||
existing OsintItems so re-runs are safe.
|
||||
|
||||
Returns the number of new CVEs added.
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
technique (Technique): The ATT&CK technique to enrich.
|
||||
|
||||
Returns:
|
||||
int: Number of new CVE items added to the database.
|
||||
"""
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign headers = {}
|
||||
headers = {}
|
||||
# Check: getattr(settings, "NVD_API_KEY", "")
|
||||
if getattr(settings, "NVD_API_KEY", ""):
|
||||
# Assign headers["apiKey"] = settings.NVD_API_KEY
|
||||
headers["apiKey"] = settings.NVD_API_KEY
|
||||
|
||||
# Assign params = {
|
||||
params = {
|
||||
# Literal argument value
|
||||
"keywordSearch": technique.name,
|
||||
# Literal argument value
|
||||
"resultsPerPage": 10,
|
||||
}
|
||||
|
||||
# Assign resp = requests.get(
|
||||
resp = requests.get(
|
||||
NVD_API_BASE,
|
||||
# Keyword argument: params
|
||||
params=params,
|
||||
# Keyword argument: headers
|
||||
headers=headers,
|
||||
# Keyword argument: timeout
|
||||
timeout=30,
|
||||
)
|
||||
# Check: resp.status_code != 200
|
||||
if resp.status_code != 200:
|
||||
# Log warning:
|
||||
logger.warning(
|
||||
# Literal argument value
|
||||
"NVD API error for %s: HTTP %d",
|
||||
technique.mitre_id,
|
||||
resp.status_code,
|
||||
)
|
||||
# Return 0
|
||||
return 0
|
||||
|
||||
# Assign data = resp.json()
|
||||
data = resp.json()
|
||||
# Assign count = 0
|
||||
count = 0
|
||||
|
||||
# Iterate over data.get("vulnerabilities", [])
|
||||
for vuln in data.get("vulnerabilities", []):
|
||||
# Assign cve = vuln.get("cve", {})
|
||||
cve = vuln.get("cve", {})
|
||||
# Assign cve_id = cve.get("id")
|
||||
cve_id = cve.get("id")
|
||||
# Check: not cve_id
|
||||
if not cve_id:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Deduplicate
|
||||
exists = (
|
||||
db.query(OsintItem.id)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
OsintItem.technique_id == technique.id,
|
||||
OsintItem.source_url.contains(cve_id),
|
||||
)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Check: exists
|
||||
if exists:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign descriptions = cve.get("descriptions", [])
|
||||
descriptions = cve.get("descriptions", [])
|
||||
# Assign desc = next(
|
||||
desc = next(
|
||||
(d["value"] for d in descriptions if d["lang"] == "en"), ""
|
||||
)
|
||||
|
||||
# Extract CVSS severity
|
||||
metrics = cve.get("metrics", {})
|
||||
# Assign cvss_v31 = metrics.get("cvssMetricV31", [])
|
||||
cvss_v31 = metrics.get("cvssMetricV31", [])
|
||||
# Assign cvss_v30 = metrics.get("cvssMetricV30", [])
|
||||
cvss_v30 = metrics.get("cvssMetricV30", [])
|
||||
# Assign cvss_entry = (cvss_v31[0] if cvss_v31 else cvss_v30[0]) if (cvss_v31 or cvss_v30...
|
||||
cvss_entry = (cvss_v31[0] if cvss_v31 else cvss_v30[0]) if (cvss_v31 or cvss_v30) else {}
|
||||
# Assign cvss_data = cvss_entry.get("cvssData", {}) if cvss_entry else {}
|
||||
cvss_data = cvss_entry.get("cvssData", {}) if cvss_entry else {}
|
||||
# Assign severity = cvss_data.get("baseSeverity", "UNKNOWN")
|
||||
severity = cvss_data.get("baseSeverity", "UNKNOWN")
|
||||
# Assign score = cvss_data.get("baseScore")
|
||||
score = cvss_data.get("baseScore")
|
||||
|
||||
# Assign item = OsintItem(
|
||||
item = OsintItem(
|
||||
# Keyword argument: technique_id
|
||||
technique_id=technique.id,
|
||||
# Keyword argument: source_type
|
||||
source_type="cve",
|
||||
# Keyword argument: source_url
|
||||
source_url=f"https://nvd.nist.gov/vuln/detail/{cve_id}",
|
||||
# Keyword argument: title
|
||||
title=cve_id,
|
||||
# Keyword argument: description
|
||||
description=desc[:500] if desc else None,
|
||||
# Keyword argument: severity
|
||||
severity=severity,
|
||||
# Keyword argument: metadata_
|
||||
metadata_={"cvss_score": score, "cve_id": cve_id},
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(item)
|
||||
# Assign count = 1
|
||||
count += 1
|
||||
|
||||
# Check: count > 0
|
||||
if count > 0:
|
||||
# Assign technique.review_required = True
|
||||
technique.review_required = True
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Log info: "Added %d CVEs for %s", count, technique.mitre_id
|
||||
logger.info("Added %d CVEs for %s", count, technique.mitre_id)
|
||||
|
||||
# Return count
|
||||
return count
|
||||
|
||||
# Handle requests.RequestException
|
||||
except requests.RequestException as e:
|
||||
# Log error:
|
||||
logger.error(
|
||||
# Literal argument value
|
||||
"HTTP error during OSINT enrichment for %s: %s",
|
||||
technique.mitre_id,
|
||||
e,
|
||||
)
|
||||
# Return 0
|
||||
return 0
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log error:
|
||||
logger.error(
|
||||
# Literal argument value
|
||||
"OSINT enrichment failed for %s: %s",
|
||||
technique.mitre_id,
|
||||
e,
|
||||
# Keyword argument: exc_info
|
||||
exc_info=True,
|
||||
)
|
||||
# Return 0
|
||||
return 0
|
||||
|
||||
|
||||
# Define function enrich_all_techniques
|
||||
def enrich_all_techniques(db: Session) -> int:
|
||||
"""Enrich all techniques with CVE data from NVD.
|
||||
|
||||
Rate-limited: processes *NVD_RATE_LIMIT_BATCH* techniques, then
|
||||
sleeps for *NVD_RATE_LIMIT_WAIT* seconds to stay under NVD limits.
|
||||
|
||||
Returns total number of new OSINT items added.
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
int: Total number of new OSINT items added across all techniques.
|
||||
"""
|
||||
# Assign techniques = db.query(Technique).order_by(Technique.mitre_id).all()
|
||||
techniques = db.query(Technique).order_by(Technique.mitre_id).all()
|
||||
# Assign total = 0
|
||||
total = 0
|
||||
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Starting OSINT enrichment for %d techniques...",
|
||||
len(techniques),
|
||||
)
|
||||
|
||||
# Iterate over enumerate(techniques)
|
||||
for i, tech in enumerate(techniques):
|
||||
# Assign total = enrich_technique_with_cves(db, tech)
|
||||
total += enrich_technique_with_cves(db, tech)
|
||||
|
||||
# Rate limiting: wait after every batch
|
||||
if (i + 1) % NVD_RATE_LIMIT_BATCH == 0 and (i + 1) < len(techniques):
|
||||
# Log debug:
|
||||
logger.debug(
|
||||
# Literal argument value
|
||||
"Rate limit pause after %d techniques (%ds)...",
|
||||
i + 1,
|
||||
NVD_RATE_LIMIT_WAIT,
|
||||
)
|
||||
# Call time.sleep()
|
||||
time.sleep(NVD_RATE_LIMIT_WAIT)
|
||||
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"OSINT enrichment complete — %d new items across %d techniques",
|
||||
total,
|
||||
len(techniques),
|
||||
)
|
||||
# Return total
|
||||
return total
|
||||
|
||||
|
||||
# Define function get_osint_items_for_technique
|
||||
def get_osint_items_for_technique(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: technique_id
|
||||
technique_id: str,
|
||||
# Entry: source_type
|
||||
source_type: str | None = None,
|
||||
# Entry: reviewed
|
||||
reviewed: bool | None = None,
|
||||
) -> list[OsintItem]:
|
||||
"""Retrieve OSINT items for a technique with optional filters."""
|
||||
"""Retrieve OSINT items for a technique with optional filters.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
technique_id (str): UUID string of the technique to query.
|
||||
source_type (str | None): Optional filter by source type (e.g.
|
||||
``"cve"``).
|
||||
reviewed (bool | None): Optional filter; ``True`` for reviewed items
|
||||
only, ``False`` for unreviewed, ``None`` for all.
|
||||
|
||||
Returns:
|
||||
list[OsintItem]: Matching OSINT items ordered by discovery date
|
||||
descending.
|
||||
"""
|
||||
# Assign query = db.query(OsintItem).filter(OsintItem.technique_id == technique_id)
|
||||
query = db.query(OsintItem).filter(OsintItem.technique_id == technique_id)
|
||||
# Check: source_type
|
||||
if source_type:
|
||||
# Assign query = query.filter(OsintItem.source_type == source_type)
|
||||
query = query.filter(OsintItem.source_type == source_type)
|
||||
# Check: reviewed is not None
|
||||
if reviewed is not None:
|
||||
# Assign query = query.filter(OsintItem.reviewed == reviewed)
|
||||
query = query.filter(OsintItem.reviewed == reviewed)
|
||||
# Return query.order_by(OsintItem.discovered_at.desc()).all()
|
||||
return query.order_by(OsintItem.discovered_at.desc()).all()
|
||||
|
||||
|
||||
# Define function mark_osint_reviewed
|
||||
def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
|
||||
"""Mark an OSINT item as reviewed. Does not commit; caller uses UnitOfWork."""
|
||||
"""Mark an OSINT item as reviewed. Does not commit; caller uses UnitOfWork.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
item_id (str): UUID string of the OSINT item to mark.
|
||||
|
||||
Returns:
|
||||
OsintItem | None: The updated item, or ``None`` if not found.
|
||||
"""
|
||||
# Assign item = db.query(OsintItem).filter(OsintItem.id == item_id).first()
|
||||
item = db.query(OsintItem).filter(OsintItem.id == item_id).first()
|
||||
# Check: item
|
||||
if item:
|
||||
# Assign item.reviewed = True
|
||||
item.reviewed = True
|
||||
# Return item
|
||||
return item
|
||||
|
||||
|
||||
# Define function get_unreviewed_count
|
||||
def get_unreviewed_count(db: Session) -> int:
|
||||
"""Return the total number of unreviewed OSINT items."""
|
||||
"""Return the total number of unreviewed OSINT items.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
int: Count of OSINT items where ``reviewed`` is ``False``.
|
||||
"""
|
||||
# Return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # ...
|
||||
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
|
||||
|
||||
|
||||
# Define function list_osint_items
|
||||
def list_osint_items(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: technique_id
|
||||
technique_id: Optional[UUID] = None,
|
||||
# Entry: source_type
|
||||
source_type: Optional[str] = None,
|
||||
# Entry: reviewed
|
||||
reviewed: Optional[bool] = None,
|
||||
# Entry: offset
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""List OSINT items with optional filters and pagination."""
|
||||
"""List OSINT items with optional filters and pagination.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
technique_id (Optional[UUID]): Filter by technique UUID.
|
||||
source_type (Optional[str]): Filter by source type string (e.g.
|
||||
``"cve"``).
|
||||
reviewed (Optional[bool]): Filter by reviewed status; ``None``
|
||||
returns all.
|
||||
offset (int): Number of records to skip for pagination.
|
||||
limit (int): Maximum number of records to return.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``total`` count and ``items`` list of serialized OSINT
|
||||
item dicts.
|
||||
"""
|
||||
# Assign query = db.query(OsintItem)
|
||||
query = db.query(OsintItem)
|
||||
# Check: technique_id
|
||||
if technique_id:
|
||||
# Assign query = query.filter(OsintItem.technique_id == technique_id)
|
||||
query = query.filter(OsintItem.technique_id == technique_id)
|
||||
# Check: source_type
|
||||
if source_type:
|
||||
# Assign query = query.filter(OsintItem.source_type == source_type)
|
||||
query = query.filter(OsintItem.source_type == source_type)
|
||||
# Check: reviewed is not None
|
||||
if reviewed is not None:
|
||||
# Assign query = query.filter(OsintItem.reviewed == reviewed)
|
||||
query = query.filter(OsintItem.reviewed == reviewed)
|
||||
|
||||
# Assign total = query.count()
|
||||
total = query.count()
|
||||
# Assign items = (
|
||||
items = (
|
||||
query.order_by(OsintItem.discovered_at.desc())
|
||||
# Chain .offset() call
|
||||
.offset(offset)
|
||||
# Chain .limit() call
|
||||
.limit(limit)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
# Literal argument value
|
||||
"items": [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(item.id),
|
||||
# Literal argument value
|
||||
"technique_id": str(item.technique_id),
|
||||
# Literal argument value
|
||||
"source_type": item.source_type,
|
||||
# Literal argument value
|
||||
"source_url": item.source_url,
|
||||
# Literal argument value
|
||||
"title": item.title,
|
||||
# Literal argument value
|
||||
"description": item.description,
|
||||
# Literal argument value
|
||||
"severity": item.severity,
|
||||
# Literal argument value
|
||||
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
||||
# Literal argument value
|
||||
"reviewed": item.reviewed,
|
||||
# Literal argument value
|
||||
"metadata": item.metadata_,
|
||||
}
|
||||
for item in items
|
||||
@@ -239,39 +435,76 @@ def list_osint_items(
|
||||
}
|
||||
|
||||
|
||||
# Define function get_osint_summary
|
||||
def get_osint_summary(db: Session) -> dict:
|
||||
"""Summary statistics for OSINT items."""
|
||||
"""Return summary statistics for OSINT items.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``total_items``, ``unreviewed``,
|
||||
``techniques_with_items``, ``by_severity``, and ``by_type``.
|
||||
"""
|
||||
# Assign total = db.query(func.count(OsintItem.id)).scalar() or 0
|
||||
total = db.query(func.count(OsintItem.id)).scalar() or 0
|
||||
# Assign unreviewed = get_unreviewed_count(db)
|
||||
unreviewed = get_unreviewed_count(db)
|
||||
|
||||
# Assign by_severity = dict(
|
||||
by_severity = dict(
|
||||
db.query(OsintItem.severity, func.count(OsintItem.id))
|
||||
# Chain .group_by() call
|
||||
.group_by(OsintItem.severity)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign by_type = dict(
|
||||
by_type = dict(
|
||||
db.query(OsintItem.source_type, func.count(OsintItem.id))
|
||||
# Chain .group_by() call
|
||||
.group_by(OsintItem.source_type)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign techniques_with_items = (
|
||||
techniques_with_items = (
|
||||
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
|
||||
)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total_items": total,
|
||||
# Literal argument value
|
||||
"unreviewed": unreviewed,
|
||||
# Literal argument value
|
||||
"techniques_with_items": techniques_with_items,
|
||||
# Literal argument value
|
||||
"by_severity": by_severity,
|
||||
# Literal argument value
|
||||
"by_type": by_type,
|
||||
}
|
||||
|
||||
|
||||
# Define function get_technique_or_raise
|
||||
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
|
||||
"""Get a technique by ID or raise EntityNotFoundError."""
|
||||
"""Return a technique by ID or raise EntityNotFoundError.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
technique_id (UUID): UUID of the technique to retrieve.
|
||||
|
||||
Returns:
|
||||
Technique: The matching technique ORM object.
|
||||
"""
|
||||
# Assign technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
# Check: not technique
|
||||
if not technique:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", str(technique_id))
|
||||
# Return technique
|
||||
return technique
|
||||
|
||||
@@ -3,95 +3,151 @@
|
||||
Uses WeasyPrint for PDF generation and docxtpl for DOCX.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import os
|
||||
import os
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Environment, FileSystemLoader from jinja2
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define class ReportEngine
|
||||
class ReportEngine:
|
||||
"""Template-based report generator supporting PDF, DOCX, and HTML output."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the Jinja2 environment and ensure the output directory exists."""
|
||||
# Assign self.jinja_env = Environment(
|
||||
self.jinja_env = Environment(
|
||||
# Keyword argument: loader
|
||||
loader=FileSystemLoader(settings.REPORT_TEMPLATES_DIR),
|
||||
# Keyword argument: autoescape
|
||||
autoescape=True,
|
||||
)
|
||||
# Call os.makedirs()
|
||||
os.makedirs(settings.REPORT_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# Define function render_html
|
||||
def render_html(self, template_name: str, context: dict) -> str:
|
||||
"""Render a Jinja2 template to an HTML string."""
|
||||
# Assign template = self.jinja_env.get_template(f"{template_name}.html")
|
||||
template = self.jinja_env.get_template(f"{template_name}.html")
|
||||
# Call context.setdefault()
|
||||
context.setdefault("company_name", settings.COMPANY_NAME)
|
||||
# Call context.setdefault()
|
||||
context.setdefault("generated_at", datetime.utcnow().strftime("%B %d, %Y %H:%M UTC"))
|
||||
# Return template.render(context)
|
||||
return template.render(context)
|
||||
|
||||
# Define function generate_pdf
|
||||
def generate_pdf(self, template_name: str, context: dict) -> str:
|
||||
"""Render HTML and convert to PDF with WeasyPrint."""
|
||||
from weasyprint import HTML, CSS
|
||||
# Import CSS, HTML from weasyprint
|
||||
from weasyprint import CSS, HTML
|
||||
|
||||
# Assign html_content = self.render_html(template_name, context)
|
||||
html_content = self.render_html(template_name, context)
|
||||
# Assign css_path = os.path.join(settings.REPORT_TEMPLATES_DIR, "styles", "report.css")
|
||||
css_path = os.path.join(settings.REPORT_TEMPLATES_DIR, "styles", "report.css")
|
||||
# Assign output_path = os.path.join(
|
||||
output_path = os.path.join(
|
||||
settings.REPORT_OUTPUT_DIR,
|
||||
f"{template_name}_{uuid.uuid4().hex[:8]}.pdf",
|
||||
)
|
||||
|
||||
# Assign stylesheets = []
|
||||
stylesheets = []
|
||||
# Check: os.path.exists(css_path)
|
||||
if os.path.exists(css_path):
|
||||
# Call stylesheets.append()
|
||||
stylesheets.append(CSS(filename=css_path))
|
||||
|
||||
# Call HTML()
|
||||
HTML(
|
||||
# Keyword argument: string
|
||||
string=html_content,
|
||||
# Keyword argument: base_url
|
||||
base_url=settings.REPORT_TEMPLATES_DIR,
|
||||
).write_pdf(output_path, stylesheets=stylesheets)
|
||||
|
||||
# Log info: "PDF generated: %s", output_path
|
||||
logger.info("PDF generated: %s", output_path)
|
||||
# Return output_path
|
||||
return output_path
|
||||
|
||||
# Define function generate_docx
|
||||
def generate_docx(self, template_name: str, context: dict) -> str:
|
||||
"""Render a .docx template with docxtpl."""
|
||||
# Import DocxTemplate from docxtpl
|
||||
from docxtpl import DocxTemplate
|
||||
|
||||
# Assign template_path = os.path.join(
|
||||
template_path = os.path.join(
|
||||
settings.REPORT_TEMPLATES_DIR, f"{template_name}.docx"
|
||||
)
|
||||
# Assign output_path = os.path.join(
|
||||
output_path = os.path.join(
|
||||
settings.REPORT_OUTPUT_DIR,
|
||||
f"{template_name}_{uuid.uuid4().hex[:8]}.docx",
|
||||
)
|
||||
|
||||
# Assign doc = DocxTemplate(template_path)
|
||||
doc = DocxTemplate(template_path)
|
||||
# Call context.setdefault()
|
||||
context.setdefault("company_name", settings.COMPANY_NAME)
|
||||
# Call context.setdefault()
|
||||
context.setdefault("generated_at", datetime.utcnow().strftime("%B %d, %Y"))
|
||||
# Call doc.render()
|
||||
doc.render(context)
|
||||
# Call doc.save()
|
||||
doc.save(output_path)
|
||||
|
||||
# Log info: "DOCX generated: %s", output_path
|
||||
logger.info("DOCX generated: %s", output_path)
|
||||
# Return output_path
|
||||
return output_path
|
||||
|
||||
# Define function generate_html
|
||||
def generate_html(self, template_name: str, context: dict) -> str:
|
||||
"""Render and save a standalone HTML report (alias for spec compatibility)."""
|
||||
# Return self.generate_html_file(template_name, context)
|
||||
return self.generate_html_file(template_name, context)
|
||||
|
||||
# Define function generate_html_file
|
||||
def generate_html_file(self, template_name: str, context: dict) -> str:
|
||||
"""Render and save a standalone HTML report."""
|
||||
# Assign html_content = self.render_html(template_name, context)
|
||||
html_content = self.render_html(template_name, context)
|
||||
# Assign output_path = os.path.join(
|
||||
output_path = os.path.join(
|
||||
settings.REPORT_OUTPUT_DIR,
|
||||
f"{template_name}_{uuid.uuid4().hex[:8]}.html",
|
||||
)
|
||||
# Open context manager
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
# Call f.write()
|
||||
f.write(html_content)
|
||||
|
||||
# Log info: "HTML report generated: %s", output_path
|
||||
logger.info("HTML report generated: %s", output_path)
|
||||
# Return output_path
|
||||
return output_path
|
||||
|
||||
|
||||
# Assign report_engine = ReportEngine()
|
||||
report_engine = ReportEngine()
|
||||
|
||||
@@ -1,115 +1,195 @@
|
||||
"""High-level report generation — collects domain data and delegates to ReportEngine."""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import datetime, timedelta from datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.exceptions
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
|
||||
# Import Campaign, CampaignTest from app.models.campaign
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
|
||||
# Import CoverageSnapshot from app.models.coverage_snapshot
|
||||
from app.models.coverage_snapshot import CoverageSnapshot
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import ThreatActor from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActor
|
||||
|
||||
# Import report_engine from app.services.report_engine
|
||||
from app.services.report_engine import report_engine
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define function generate_purple_campaign_report
|
||||
def generate_purple_campaign_report(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: output_format
|
||||
output_format: str = "pdf",
|
||||
) -> str:
|
||||
"""Generate the full Purple Team campaign report."""
|
||||
# Assign cid = campaign_id if isinstance(campaign_id, UUID) else UUID(str(campaign...
|
||||
cid = campaign_id if isinstance(campaign_id, UUID) else UUID(str(campaign_id))
|
||||
# Assign campaign = db.query(Campaign).filter(Campaign.id == cid).first()
|
||||
campaign = db.query(Campaign).filter(Campaign.id == cid).first()
|
||||
# Check: not campaign
|
||||
if not campaign:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
# Assign campaign_tests = (
|
||||
campaign_tests = (
|
||||
db.query(Test)
|
||||
# Chain .join() call
|
||||
.join(CampaignTest, CampaignTest.test_id == Test.id)
|
||||
# Chain .filter() call
|
||||
.filter(CampaignTest.campaign_id == cid)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign tests_data = []
|
||||
tests_data = []
|
||||
# Iterate over campaign_tests
|
||||
for test in campaign_tests:
|
||||
# Assign technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
# Call tests_data.append()
|
||||
tests_data.append({
|
||||
# Literal argument value
|
||||
"technique_mitre_id": technique.mitre_id if technique else "N/A",
|
||||
# Literal argument value
|
||||
"name": test.name,
|
||||
# Literal argument value
|
||||
"tactic": technique.tactic if technique else "N/A",
|
||||
# Literal argument value
|
||||
"state": test.state.value if test.state else "draft",
|
||||
# Literal argument value
|
||||
"detection_result": (
|
||||
test.detection_result.value if test.detection_result else "pending"
|
||||
),
|
||||
})
|
||||
|
||||
# Assign validated = [t for t in campaign_tests if t.state and t.state.value == "validat...
|
||||
validated = [t for t in campaign_tests if t.state and t.state.value == "validated"]
|
||||
# Assign detected = [
|
||||
detected = [
|
||||
t for t in validated
|
||||
if t.detection_result and t.detection_result.value == "detected"
|
||||
]
|
||||
# Assign not_detected = [
|
||||
not_detected = [
|
||||
t for t in validated
|
||||
if t.detection_result and t.detection_result.value == "not_detected"
|
||||
]
|
||||
|
||||
# Assign critical_findings = [
|
||||
critical_findings = [
|
||||
{
|
||||
# Literal argument value
|
||||
"technique_id": t["technique_mitre_id"],
|
||||
# Literal argument value
|
||||
"name": t["name"],
|
||||
# Literal argument value
|
||||
"severity": "critical",
|
||||
# Literal argument value
|
||||
"description": "Technique was not detected during campaign execution.",
|
||||
# Literal argument value
|
||||
"recommendation": "Implement detection rule or review existing SIEM/EDR configuration.",
|
||||
}
|
||||
for t in tests_data
|
||||
if t["detection_result"] == "not_detected"
|
||||
]
|
||||
|
||||
# Assign org_score = _safe_org_score(db)
|
||||
org_score = _safe_org_score(db)
|
||||
|
||||
# Assign threat_actors = []
|
||||
threat_actors = []
|
||||
# Check: campaign.threat_actor_id
|
||||
if campaign.threat_actor_id:
|
||||
# Assign actor = db.query(ThreatActor).filter(ThreatActor.id == campaign.threat_acto...
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == campaign.threat_actor_id).first()
|
||||
# Check: actor
|
||||
if actor:
|
||||
# Assign threat_actors = [{"name": actor.name}]
|
||||
threat_actors = [{"name": actor.name}]
|
||||
|
||||
# Assign context = {
|
||||
context = {
|
||||
# Literal argument value
|
||||
"campaign": campaign,
|
||||
# Literal argument value
|
||||
"tests": tests_data,
|
||||
# Literal argument value
|
||||
"tests_validated": len(validated),
|
||||
# Literal argument value
|
||||
"tests_detected": len(detected),
|
||||
# Literal argument value
|
||||
"tests_not_detected": len(not_detected),
|
||||
# Literal argument value
|
||||
"critical_findings": critical_findings,
|
||||
# Literal argument value
|
||||
"org_score": org_score.get("overall", 0),
|
||||
# Literal argument value
|
||||
"tactics": list({t["tactic"] for t in tests_data}),
|
||||
# Literal argument value
|
||||
"threat_actors": threat_actors,
|
||||
}
|
||||
|
||||
# Return _generate(output_format, "purple_campaign", context)
|
||||
return _generate(output_format, "purple_campaign", context)
|
||||
|
||||
|
||||
# Define function generate_coverage_report
|
||||
def generate_coverage_report(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: output_format
|
||||
output_format: str = "pdf",
|
||||
) -> str:
|
||||
"""Generate an organization-wide MITRE ATT&CK coverage report."""
|
||||
from sqlalchemy import func, case
|
||||
# Import case, func from sqlalchemy
|
||||
from sqlalchemy import case, func
|
||||
|
||||
# Assign org_score = _safe_org_score(db)
|
||||
org_score = _safe_org_score(db)
|
||||
|
||||
# Assign techniques = db.query(Technique).all()
|
||||
techniques = db.query(Technique).all()
|
||||
# Assign status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, ...
|
||||
status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, "not_evaluated": 0}
|
||||
# Iterate over techniques
|
||||
for t in techniques:
|
||||
# Assign s = t.status_global.value if t.status_global else "not_evaluated"
|
||||
s = t.status_global.value if t.status_global else "not_evaluated"
|
||||
# Check: s in status_counts
|
||||
if s in status_counts:
|
||||
# Assign status_counts[s] = 1
|
||||
status_counts[s] += 1
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"total_techniques": len(techniques),
|
||||
**status_counts,
|
||||
}
|
||||
@@ -121,14 +201,21 @@ def generate_coverage_report(
|
||||
func.count(Technique.id).label("total"),
|
||||
func.sum(case((Technique.status_global == "validated", 1), else_=0)).label("validated"),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(Technique.tactic)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign tactics_coverage = [
|
||||
tactics_coverage = [
|
||||
{
|
||||
# Literal argument value
|
||||
"tactic": r[0] or "Unknown",
|
||||
# Literal argument value
|
||||
"total": r[1],
|
||||
# Literal argument value
|
||||
"validated": int(r[2]),
|
||||
# Literal argument value
|
||||
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
|
||||
}
|
||||
for r in tactic_rows
|
||||
@@ -136,213 +223,326 @@ def generate_coverage_report(
|
||||
|
||||
# Never-tested techniques
|
||||
tested_ids = {t.technique_id for t in db.query(Test.technique_id).distinct().all()}
|
||||
# Assign never_tested = [
|
||||
never_tested = [
|
||||
{"mitre_id": t.mitre_id, "name": t.name, "tactic": t.tactic}
|
||||
for t in techniques
|
||||
if t.id not in tested_ids
|
||||
]
|
||||
|
||||
# Assign context = {
|
||||
context = {
|
||||
# Literal argument value
|
||||
"org_score": org_score,
|
||||
# Literal argument value
|
||||
"summary": summary,
|
||||
# Literal argument value
|
||||
"tactics_coverage": tactics_coverage,
|
||||
# Literal argument value
|
||||
"never_tested": never_tested[:50],
|
||||
}
|
||||
|
||||
# Return _generate(output_format, "coverage_report", context)
|
||||
return _generate(output_format, "coverage_report", context)
|
||||
|
||||
|
||||
# Define function generate_executive_summary
|
||||
def generate_executive_summary(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: output_format
|
||||
output_format: str = "pdf",
|
||||
) -> str:
|
||||
"""Generate an executive summary report."""
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Assign org_score = _safe_org_score(db)
|
||||
org_score = _safe_org_score(db)
|
||||
# Assign techniques = db.query(Technique).all()
|
||||
techniques = db.query(Technique).all()
|
||||
|
||||
# Assign status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, ...
|
||||
status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, "not_evaluated": 0}
|
||||
# Iterate over techniques
|
||||
for t in techniques:
|
||||
# Assign s = t.status_global.value if t.status_global else "not_evaluated"
|
||||
s = t.status_global.value if t.status_global else "not_evaluated"
|
||||
# Check: s in status_counts
|
||||
if s in status_counts:
|
||||
# Assign status_counts[s] = 1
|
||||
status_counts[s] += 1
|
||||
|
||||
# Assign summary = {"total_techniques": len(techniques), **status_counts}
|
||||
summary = {"total_techniques": len(techniques), **status_counts}
|
||||
|
||||
# Assign total_tests = db.query(func.count(Test.id)).scalar() or 0
|
||||
total_tests = db.query(func.count(Test.id)).scalar() or 0
|
||||
# Assign active_campaigns = (
|
||||
active_campaigns = (
|
||||
db.query(func.count(Campaign.id)).filter(Campaign.status == "active").scalar() or 0
|
||||
)
|
||||
|
||||
# Assign quarter_ago = datetime.utcnow() - timedelta(days=90)
|
||||
quarter_ago = datetime.utcnow() - timedelta(days=90)
|
||||
# Assign tests_this_quarter = (
|
||||
tests_this_quarter = (
|
||||
db.query(func.count(Test.id)).filter(Test.created_at >= quarter_ago).scalar() or 0
|
||||
)
|
||||
|
||||
# Assign open_remediations = (
|
||||
open_remediations = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.remediation_status.in_(["pending", "in_progress"]))
|
||||
# Chain .scalar() call
|
||||
.scalar() or 0
|
||||
)
|
||||
|
||||
# Detection rate among validated tests
|
||||
validated_count = status_counts["validated"]
|
||||
# Assign detected_count = (
|
||||
detected_count = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state == "validated", Test.detection_result == "detected")
|
||||
# Chain .scalar() call
|
||||
.scalar() or 0
|
||||
)
|
||||
# Assign detection_rate = round((detected_count / validated_count) * 100, 1) if validated_cou...
|
||||
detection_rate = round((detected_count / validated_count) * 100, 1) if validated_count > 0 else 0
|
||||
|
||||
# Top gaps — lowest coverage tactics
|
||||
from sqlalchemy import case as sql_case
|
||||
# Assign tactic_rows = (
|
||||
tactic_rows = (
|
||||
db.query(
|
||||
Technique.tactic,
|
||||
func.count(Technique.id).label("total"),
|
||||
func.sum(sql_case((Technique.status_global == "validated", 1), else_=0)).label("validated"),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(Technique.tactic)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign tactic_coverage = [
|
||||
tactic_coverage = [
|
||||
{
|
||||
# Literal argument value
|
||||
"tactic": r[0] or "Unknown",
|
||||
# Literal argument value
|
||||
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
|
||||
}
|
||||
for r in tactic_rows
|
||||
]
|
||||
# Assign top_gaps = sorted(tactic_coverage, key=lambda x: x["coverage_pct"])[:5]
|
||||
top_gaps = sorted(tactic_coverage, key=lambda x: x["coverage_pct"])[:5]
|
||||
|
||||
# Assign context = {
|
||||
context = {
|
||||
# Literal argument value
|
||||
"org_score": org_score,
|
||||
# Literal argument value
|
||||
"summary": summary,
|
||||
# Literal argument value
|
||||
"total_tests": total_tests,
|
||||
# Literal argument value
|
||||
"active_campaigns": active_campaigns,
|
||||
# Literal argument value
|
||||
"tests_this_quarter": tests_this_quarter,
|
||||
# Literal argument value
|
||||
"open_remediations": open_remediations,
|
||||
# Literal argument value
|
||||
"detection_rate": detection_rate,
|
||||
# Literal argument value
|
||||
"top_gaps": top_gaps,
|
||||
}
|
||||
|
||||
# Return _generate(output_format, "executive_summary", context)
|
||||
return _generate(output_format, "executive_summary", context)
|
||||
|
||||
|
||||
# Define function generate_quarterly_summary
|
||||
def generate_quarterly_summary(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: output_format
|
||||
output_format: str = "pdf",
|
||||
) -> str:
|
||||
"""Quarterly summary — reuses executive metrics plus snapshot trend rows."""
|
||||
from sqlalchemy import case as sql_case, func
|
||||
# Import case as sql_case from sqlalchemy
|
||||
from sqlalchemy import case as sql_case
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Assign org_score = _safe_org_score(db)
|
||||
org_score = _safe_org_score(db)
|
||||
# Assign quarter_ago = datetime.utcnow() - timedelta(days=90)
|
||||
quarter_ago = datetime.utcnow() - timedelta(days=90)
|
||||
# Assign tests_this_quarter = (
|
||||
tests_this_quarter = (
|
||||
db.query(func.count(Test.id)).filter(Test.created_at >= quarter_ago).scalar() or 0
|
||||
)
|
||||
|
||||
# Assign techniques = db.query(Technique).all()
|
||||
techniques = db.query(Technique).all()
|
||||
# Assign validated_count = sum(
|
||||
validated_count = sum(
|
||||
# Literal argument value
|
||||
1 for t in techniques if t.status_global and t.status_global.value == "validated"
|
||||
)
|
||||
# Assign detected_count = (
|
||||
detected_count = (
|
||||
db.query(func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.state == "validated", Test.detection_result == "detected")
|
||||
# Chain .scalar() call
|
||||
.scalar() or 0
|
||||
)
|
||||
# Assign detection_rate = (
|
||||
detection_rate = (
|
||||
round((detected_count / validated_count) * 100, 1) if validated_count > 0 else 0
|
||||
)
|
||||
|
||||
# Assign tactic_rows = (
|
||||
tactic_rows = (
|
||||
db.query(
|
||||
Technique.tactic,
|
||||
func.count(Technique.id).label("total"),
|
||||
func.sum(sql_case((Technique.status_global == "validated", 1), else_=0)).label(
|
||||
# Literal argument value
|
||||
"validated",
|
||||
),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(Technique.tactic)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign top_gaps = sorted(
|
||||
top_gaps = sorted(
|
||||
[
|
||||
{
|
||||
# Literal argument value
|
||||
"tactic": r[0] or "Unknown",
|
||||
# Literal argument value
|
||||
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
|
||||
}
|
||||
for r in tactic_rows
|
||||
],
|
||||
# Keyword argument: key
|
||||
key=lambda x: x["coverage_pct"],
|
||||
)[:5]
|
||||
|
||||
# Assign snapshots = (
|
||||
snapshots = (
|
||||
db.query(CoverageSnapshot)
|
||||
# Chain .filter() call
|
||||
.filter(CoverageSnapshot.created_at >= quarter_ago)
|
||||
# Chain .order_by() call
|
||||
.order_by(CoverageSnapshot.created_at)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign trend_rows = [
|
||||
trend_rows = [
|
||||
{
|
||||
# Literal argument value
|
||||
"date": s.created_at.strftime("%Y-%m-%d") if s.created_at else "",
|
||||
# Literal argument value
|
||||
"validated_count": s.validated_count,
|
||||
# Literal argument value
|
||||
"total_techniques": s.total_techniques,
|
||||
# Literal argument value
|
||||
"organization_score": round(s.organization_score, 1),
|
||||
}
|
||||
for s in snapshots
|
||||
]
|
||||
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign quarter_label = f"Q{((now.month - 1) // 3) + 1} {now.year}"
|
||||
quarter_label = f"Q{((now.month - 1) // 3) + 1} {now.year}"
|
||||
|
||||
# Assign context = {
|
||||
context = {
|
||||
# Literal argument value
|
||||
"quarter_label": quarter_label,
|
||||
# Literal argument value
|
||||
"org_score": org_score,
|
||||
# Literal argument value
|
||||
"tests_this_quarter": tests_this_quarter,
|
||||
# Literal argument value
|
||||
"detection_rate": detection_rate,
|
||||
# Literal argument value
|
||||
"trend_rows": trend_rows,
|
||||
# Literal argument value
|
||||
"top_gaps": top_gaps,
|
||||
}
|
||||
# Return _generate(output_format, "quarterly_summary", context)
|
||||
return _generate(output_format, "quarterly_summary", context)
|
||||
|
||||
|
||||
# Define function generate_technique_detail_report
|
||||
def generate_technique_detail_report(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: technique_id
|
||||
technique_id: str,
|
||||
# Entry: output_format
|
||||
output_format: str = "pdf",
|
||||
) -> str:
|
||||
"""Detailed report for a single MITRE technique and its tests."""
|
||||
# Assign tid = technique_id if isinstance(technique_id, UUID) else UUID(str(techni...
|
||||
tid = technique_id if isinstance(technique_id, UUID) else UUID(str(technique_id))
|
||||
# Assign technique = db.query(Technique).filter(Technique.id == tid).first()
|
||||
technique = db.query(Technique).filter(Technique.id == tid).first()
|
||||
# Check: not technique
|
||||
if not technique:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", str(technique_id))
|
||||
|
||||
# Assign related_tests = (
|
||||
related_tests = (
|
||||
db.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id == tid)
|
||||
# Chain .order_by() call
|
||||
.order_by(Test.created_at.desc())
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign tests_data = [
|
||||
tests_data = [
|
||||
{
|
||||
# Literal argument value
|
||||
"name": t.name,
|
||||
# Literal argument value
|
||||
"state": t.state.value if t.state else "draft",
|
||||
# Literal argument value
|
||||
"detection_result": (
|
||||
t.detection_result.value if t.detection_result else "pending"
|
||||
),
|
||||
# Literal argument value
|
||||
"created_at": t.created_at.strftime("%Y-%m-%d") if t.created_at else "",
|
||||
}
|
||||
for t in related_tests
|
||||
]
|
||||
|
||||
# Assign context = {
|
||||
context = {
|
||||
# Literal argument value
|
||||
"technique": technique,
|
||||
# Literal argument value
|
||||
"technique_status": (
|
||||
technique.status_global.value if technique.status_global else "not_evaluated"
|
||||
),
|
||||
# Literal argument value
|
||||
"tests": tests_data,
|
||||
}
|
||||
# Return _generate(output_format, "technique_detail", context)
|
||||
return _generate(output_format, "technique_detail", context)
|
||||
|
||||
|
||||
@@ -351,19 +551,32 @@ def generate_technique_detail_report(
|
||||
|
||||
def _safe_org_score(db: Session) -> dict:
|
||||
"""Safely call the scoring service; return empty dict on failure."""
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Import calculate_organization_score from app.services.scoring_service
|
||||
from app.services.scoring_service import calculate_organization_score
|
||||
# Return calculate_organization_score(db)
|
||||
return calculate_organization_score(db)
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log warning: "Scoring service unavailable: %s", e
|
||||
logger.warning("Scoring service unavailable: %s", e)
|
||||
# Return {"overall": 0, "coverage": 0, "detection_maturity": 0}
|
||||
return {"overall": 0, "coverage": 0, "detection_maturity": 0}
|
||||
|
||||
|
||||
# Define function _generate
|
||||
def _generate(output_format: str, template_name: str, context: dict) -> str:
|
||||
"""Dispatch to the correct ReportEngine method."""
|
||||
# Check: output_format == "pdf"
|
||||
if output_format == "pdf":
|
||||
# Return report_engine.generate_pdf(template_name, context)
|
||||
return report_engine.generate_pdf(template_name, context)
|
||||
# Alternative: output_format == "docx"
|
||||
elif output_format == "docx":
|
||||
# Return report_engine.generate_docx(template_name, context)
|
||||
return report_engine.generate_docx(template_name, context)
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Return report_engine.generate_html_file(template_name, context)
|
||||
return report_engine.generate_html_file(template_name, context)
|
||||
|
||||
@@ -7,78 +7,123 @@ Thread-safe: each worker process has its own dict, and the TTL ensures
|
||||
stale data does not persist longer than ``CACHE_TTL`` seconds.
|
||||
"""
|
||||
|
||||
# Import time
|
||||
import time
|
||||
|
||||
# Import Any, Optional from typing
|
||||
from typing import Any, Optional
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Assign CACHE_TTL = 300 # 5 minutes
|
||||
CACHE_TTL = 300 # 5 minutes
|
||||
|
||||
# Assign _cache = {}
|
||||
_cache: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def get(key: str) -> Optional[Any]:
|
||||
# Define function get
|
||||
def get(key: str) -> Optional[Any]: # noqa: ANN401 # generic cache returns whatever was stored
|
||||
"""Return cached value if present and not expired, else None."""
|
||||
# Assign entry = _cache.get(key)
|
||||
entry = _cache.get(key)
|
||||
# Check: entry is None
|
||||
if entry is None:
|
||||
# Return None
|
||||
return None
|
||||
# Check: time.time() - entry["ts"] > CACHE_TTL
|
||||
if time.time() - entry["ts"] > CACHE_TTL:
|
||||
# Call _cache.pop()
|
||||
_cache.pop(key, None)
|
||||
# Return None
|
||||
return None
|
||||
# Return entry["data"]
|
||||
return entry["data"]
|
||||
|
||||
|
||||
def put(key: str, data: Any) -> None:
|
||||
# Define function put
|
||||
def put(key: str, data: Any) -> None: # noqa: ANN401 # generic cache accepts any serialisable value
|
||||
"""Store *data* under *key* with the current timestamp."""
|
||||
# Assign _cache[key] = {"data": data, "ts": time.time()}
|
||||
_cache[key] = {"data": data, "ts": time.time()}
|
||||
|
||||
|
||||
# Define function invalidate
|
||||
def invalidate(key: Optional[str] = None) -> None:
|
||||
"""Remove one key or clear the whole cache."""
|
||||
# Check: key is None
|
||||
if key is None:
|
||||
# Call _cache.clear()
|
||||
_cache.clear()
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Call _cache.pop()
|
||||
_cache.pop(key, None)
|
||||
|
||||
|
||||
# ── High-level helpers ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_organization_score_cached(db):
|
||||
def get_organization_score_cached(db: Session) -> dict:
|
||||
"""Cached wrapper around ``calculate_organization_score``."""
|
||||
# Import calculate_organization_score from app.services.scoring_service
|
||||
from app.services.scoring_service import calculate_organization_score
|
||||
|
||||
# Assign cached = get("org_score")
|
||||
cached = get("org_score")
|
||||
# Check: cached is not None
|
||||
if cached is not None:
|
||||
# Return cached
|
||||
return cached
|
||||
|
||||
# Assign result = calculate_organization_score(db)
|
||||
result = calculate_organization_score(db)
|
||||
# Call put()
|
||||
put("org_score", result)
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
def get_operational_metrics_cached(db):
|
||||
# Define function get_operational_metrics_cached
|
||||
def get_operational_metrics_cached(db: Session) -> dict:
|
||||
"""Cached wrapper around operational metrics (MTTD, MTTR, efficacy)."""
|
||||
# Import from app.services.operational_metrics_service
|
||||
from app.services.operational_metrics_service import (
|
||||
calculate_mttd,
|
||||
calculate_mttr,
|
||||
calculate_detection_efficacy,
|
||||
calculate_alert_fidelity,
|
||||
calculate_coverage_velocity,
|
||||
calculate_validation_throughput,
|
||||
calculate_detection_efficacy,
|
||||
calculate_mttd,
|
||||
calculate_mttr,
|
||||
calculate_rejection_rate,
|
||||
calculate_validation_throughput,
|
||||
)
|
||||
|
||||
# Assign cached = get("op_metrics")
|
||||
cached = get("op_metrics")
|
||||
# Check: cached is not None
|
||||
if cached is not None:
|
||||
# Return cached
|
||||
return cached
|
||||
|
||||
# Assign result = {
|
||||
result = {
|
||||
# Literal argument value
|
||||
"mttd": calculate_mttd(db),
|
||||
# Literal argument value
|
||||
"mttr": calculate_mttr(db),
|
||||
# Literal argument value
|
||||
"detection_efficacy": calculate_detection_efficacy(db),
|
||||
# Literal argument value
|
||||
"alert_fidelity": calculate_alert_fidelity(db),
|
||||
# Literal argument value
|
||||
"coverage_velocity": calculate_coverage_velocity(db),
|
||||
# Literal argument value
|
||||
"validation_throughput": calculate_validation_throughput(db),
|
||||
# Literal argument value
|
||||
"rejection_rate": calculate_rejection_rate(db),
|
||||
}
|
||||
# Call put()
|
||||
put("op_metrics", result)
|
||||
# Return result
|
||||
return result
|
||||
|
||||
@@ -1,121 +1,202 @@
|
||||
"""Scoring configuration persistence service."""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Any from typing
|
||||
from typing import Any
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import ScoringWeights from app.domain.value_objects.scoring_weights
|
||||
from app.domain.value_objects.scoring_weights import ScoringWeights
|
||||
|
||||
# Import ScoringConfig from app.models.scoring_config
|
||||
from app.models.scoring_config import ScoringConfig
|
||||
|
||||
|
||||
# Define function _row_recency
|
||||
def _row_recency(row: ScoringConfig) -> float:
|
||||
# Return float(getattr(row, "weight_recency", None) or getattr(row, "weight_...
|
||||
return float(getattr(row, "weight_recency", None) or getattr(row, "weight_freshness", 10.0))
|
||||
|
||||
|
||||
# Define function _row_severity
|
||||
def _row_severity(row: ScoringConfig) -> float:
|
||||
# Return float(
|
||||
return float(
|
||||
getattr(row, "weight_severity", None)
|
||||
or getattr(row, "weight_platform_diversity", 10.0)
|
||||
)
|
||||
|
||||
|
||||
# Define function get_scoring_weights
|
||||
def get_scoring_weights(db: Session) -> ScoringWeights:
|
||||
"""Return the active scoring weights from the database or env defaults."""
|
||||
# Assign row = db.query(ScoringConfig).first()
|
||||
row = db.query(ScoringConfig).first()
|
||||
# Check: row is not None
|
||||
if row is not None:
|
||||
# Return ScoringWeights(
|
||||
return ScoringWeights(
|
||||
# Keyword argument: tests
|
||||
tests=row.weight_tests,
|
||||
# Keyword argument: detection_rules
|
||||
detection_rules=row.weight_detection_rules,
|
||||
# Keyword argument: d3fend
|
||||
d3fend=row.weight_d3fend,
|
||||
# Keyword argument: recency
|
||||
recency=_row_recency(row),
|
||||
# Keyword argument: severity
|
||||
severity=_row_severity(row),
|
||||
)
|
||||
|
||||
# Return ScoringWeights(
|
||||
return ScoringWeights(
|
||||
# Keyword argument: tests
|
||||
tests=float(settings.SCORING_WEIGHT_TESTS),
|
||||
# Keyword argument: detection_rules
|
||||
detection_rules=float(settings.SCORING_WEIGHT_DETECTION_RULES),
|
||||
# Keyword argument: d3fend
|
||||
d3fend=float(settings.SCORING_WEIGHT_D3FEND),
|
||||
# Keyword argument: recency
|
||||
recency=float(
|
||||
getattr(settings, "SCORING_WEIGHT_RECENCY", settings.SCORING_WEIGHT_FRESHNESS)
|
||||
),
|
||||
# Keyword argument: severity
|
||||
severity=float(
|
||||
getattr(settings, "SCORING_WEIGHT_SEVERITY", settings.SCORING_WEIGHT_PLATFORM_DIVERSITY)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Define function update_scoring_weights
|
||||
def update_scoring_weights(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: tests
|
||||
tests: float | None = None,
|
||||
# Entry: detection_rules
|
||||
detection_rules: float | None = None,
|
||||
# Entry: d3fend
|
||||
d3fend: float | None = None,
|
||||
# Entry: recency
|
||||
recency: float | None = None,
|
||||
# Entry: severity
|
||||
severity: float | None = None,
|
||||
# Entry: freshness
|
||||
freshness: float | None = None,
|
||||
# Entry: platform_diversity
|
||||
platform_diversity: float | None = None,
|
||||
# Entry: updated_by
|
||||
updated_by: uuid.UUID | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Upsert scoring weights. Does not commit."""
|
||||
# Check: freshness is not None and recency is None
|
||||
if freshness is not None and recency is None:
|
||||
# Assign recency = freshness
|
||||
recency = freshness
|
||||
# Check: platform_diversity is not None and severity is None
|
||||
if platform_diversity is not None and severity is None:
|
||||
# Assign severity = platform_diversity
|
||||
severity = platform_diversity
|
||||
|
||||
# Assign current = get_scoring_weights(db)
|
||||
current = get_scoring_weights(db)
|
||||
|
||||
# Assign new = ScoringWeights(
|
||||
new = ScoringWeights(
|
||||
# Keyword argument: tests
|
||||
tests=tests if tests is not None else current.tests,
|
||||
# Keyword argument: detection_rules
|
||||
detection_rules=detection_rules if detection_rules is not None else current.detection_rules,
|
||||
# Keyword argument: d3fend
|
||||
d3fend=d3fend if d3fend is not None else current.d3fend,
|
||||
# Keyword argument: recency
|
||||
recency=recency if recency is not None else current.recency,
|
||||
# Keyword argument: severity
|
||||
severity=severity if severity is not None else current.severity,
|
||||
)
|
||||
|
||||
# Assign row = db.query(ScoringConfig).first()
|
||||
row = db.query(ScoringConfig).first()
|
||||
# Check: row is None
|
||||
if row is None:
|
||||
# Assign row = ScoringConfig()
|
||||
row = ScoringConfig()
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(row)
|
||||
|
||||
# Assign row.weight_tests = new.tests
|
||||
row.weight_tests = new.tests
|
||||
# Assign row.weight_detection_rules = new.detection_rules
|
||||
row.weight_detection_rules = new.detection_rules
|
||||
# Assign row.weight_d3fend = new.d3fend
|
||||
row.weight_d3fend = new.d3fend
|
||||
# Check: hasattr(row, "weight_recency")
|
||||
if hasattr(row, "weight_recency"):
|
||||
# Assign row.weight_recency = new.recency
|
||||
row.weight_recency = new.recency
|
||||
# Alternative: hasattr(row, "weight_freshness")
|
||||
elif hasattr(row, "weight_freshness"):
|
||||
# Assign row.weight_freshness = new.recency
|
||||
row.weight_freshness = new.recency
|
||||
# Check: hasattr(row, "weight_severity")
|
||||
if hasattr(row, "weight_severity"):
|
||||
# Assign row.weight_severity = new.severity
|
||||
row.weight_severity = new.severity
|
||||
# Alternative: hasattr(row, "weight_platform_diversity")
|
||||
elif hasattr(row, "weight_platform_diversity"):
|
||||
# Assign row.weight_platform_diversity = new.severity
|
||||
row.weight_platform_diversity = new.severity
|
||||
# Check: updated_by is not None and hasattr(row, "updated_by")
|
||||
if updated_by is not None and hasattr(row, "updated_by"):
|
||||
# Assign row.updated_by = updated_by
|
||||
row.updated_by = updated_by
|
||||
|
||||
# Return _weights_dict(new)
|
||||
return _weights_dict(new)
|
||||
|
||||
|
||||
# Define function get_weights_dict
|
||||
def get_weights_dict(db: Session) -> dict[str, Any]:
|
||||
"""Return current weights as a serialisable dict."""
|
||||
# Return _weights_dict(get_scoring_weights(db))
|
||||
return _weights_dict(get_scoring_weights(db))
|
||||
|
||||
|
||||
# Define function _weights_dict
|
||||
def _weights_dict(w: ScoringWeights) -> dict[str, Any]:
|
||||
# Assign weights = {
|
||||
weights = {
|
||||
# Literal argument value
|
||||
"tests": w.tests,
|
||||
# Literal argument value
|
||||
"detection_rules": w.detection_rules,
|
||||
# Literal argument value
|
||||
"d3fend": w.d3fend,
|
||||
# Literal argument value
|
||||
"recency": w.recency,
|
||||
# Literal argument value
|
||||
"severity": w.severity,
|
||||
# Legacy keys for older clients
|
||||
"freshness": w.recency,
|
||||
# Literal argument value
|
||||
"platform_diversity": w.severity,
|
||||
}
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"weights": weights,
|
||||
# Literal argument value
|
||||
"total": sum(
|
||||
[w.tests, w.detection_rules, w.d3fend, w.recency, w.severity]
|
||||
),
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,24 +22,45 @@ rules are identified by ``source = "sigma"`` + ``source_id`` (relative
|
||||
file path) and simply skipped.
|
||||
"""
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import re
|
||||
import re
|
||||
|
||||
# Import shutil
|
||||
import shutil
|
||||
|
||||
# Import tempfile
|
||||
import tempfile
|
||||
|
||||
# Import zipfile
|
||||
import zipfile
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Path from pathlib
|
||||
from pathlib import Path
|
||||
|
||||
# Import requests
|
||||
import requests as _requests
|
||||
|
||||
# Import yaml
|
||||
import yaml
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.detection_rule import DetectionRule
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -47,22 +68,35 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SIGMA_ZIP_URL = (
|
||||
# Literal argument value
|
||||
"https://github.com/SigmaHQ/sigma/archive/refs/heads/master.zip"
|
||||
)
|
||||
|
||||
# Assign _DOWNLOAD_TIMEOUT = 300
|
||||
_DOWNLOAD_TIMEOUT = 300
|
||||
# Assign _ZIP_ROOT_PREFIX = "sigma-master"
|
||||
_ZIP_ROOT_PREFIX = "sigma-master"
|
||||
|
||||
# Safety limits for ZIP extraction — prevent zip-bomb DoS
|
||||
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB
|
||||
# Assign _MAX_ENTRIES = 50_000
|
||||
_MAX_ENTRIES = 50_000
|
||||
|
||||
# Regex to extract MITRE ATT&CK technique IDs from Sigma tags
|
||||
# e.g. "attack.t1059.001" → "T1059.001"
|
||||
_ATTACK_TAG_RE = re.compile(r"attack\.(t\d{4}(?:\.\d{3})?)", re.IGNORECASE)
|
||||
|
||||
# Sigma severity levels
|
||||
_SEVERITY_MAP = {
|
||||
# Literal argument value
|
||||
"informational": "informational",
|
||||
# Literal argument value
|
||||
"low": "low",
|
||||
# Literal argument value
|
||||
"medium": "medium",
|
||||
# Literal argument value
|
||||
"high": "high",
|
||||
# Literal argument value
|
||||
"critical": "critical",
|
||||
}
|
||||
|
||||
@@ -74,14 +108,21 @@ _SEVERITY_MAP = {
|
||||
|
||||
def _download_zip(url: str = SIGMA_ZIP_URL) -> bytes:
|
||||
"""Download the SigmaHQ ZIP and return raw bytes."""
|
||||
# Log info: "Downloading SigmaHQ ZIP from %s …", url
|
||||
logger.info("Downloading SigmaHQ ZIP from %s …", url)
|
||||
# Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Assign content = resp.content
|
||||
content = resp.content
|
||||
# Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024
|
||||
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
|
||||
# Return content
|
||||
return content
|
||||
|
||||
|
||||
# Define function _safe_extract_zip
|
||||
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
"""Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.
|
||||
|
||||
@@ -89,165 +130,249 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
directory (path traversal / Zip Slip) or if the archive exceeds the
|
||||
safety limits.
|
||||
"""
|
||||
# Maximum uncompressed size: 500 MB — prevents zip-bomb DoS
|
||||
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024
|
||||
# Maximum number of entries
|
||||
_MAX_ENTRIES = 50_000
|
||||
|
||||
# Assign dest_path = Path(dest).resolve()
|
||||
dest_path = Path(dest).resolve()
|
||||
|
||||
# Open context manager
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
# Assign entries = zf.infolist()
|
||||
entries = zf.infolist()
|
||||
|
||||
# Check: len(entries) > _MAX_ENTRIES
|
||||
if len(entries) > _MAX_ENTRIES:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"ZIP archive contains {len(entries)} entries "
|
||||
f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
|
||||
)
|
||||
|
||||
# Assign total_size = sum(info.file_size for info in entries)
|
||||
total_size = sum(info.file_size for info in entries)
|
||||
# Check: total_size > _MAX_UNCOMPRESSED_SIZE
|
||||
if total_size > _MAX_UNCOMPRESSED_SIZE:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
|
||||
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
|
||||
)
|
||||
|
||||
# Iterate over entries
|
||||
for member in entries:
|
||||
# Assign target = (dest_path / member.filename).resolve()
|
||||
target = (dest_path / member.filename).resolve()
|
||||
# Check: not target.is_relative_to(dest_path)
|
||||
if not target.is_relative_to(dest_path):
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"Zip Slip detected — member '{member.filename}' "
|
||||
f"resolves outside target directory"
|
||||
)
|
||||
|
||||
# Call zf.extractall()
|
||||
zf.extractall(dest)
|
||||
|
||||
|
||||
# Define function _extract_zip
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return the path to rules/ dir."""
|
||||
# Call _safe_extract_zip()
|
||||
_safe_extract_zip(zip_bytes, dest)
|
||||
# Assign rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
|
||||
rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
|
||||
# Check: not rules_dir.is_dir()
|
||||
if not rules_dir.is_dir():
|
||||
# Raise FileNotFoundError
|
||||
raise FileNotFoundError(
|
||||
f"Expected rules directory not found at {rules_dir}"
|
||||
)
|
||||
# Return rules_dir
|
||||
return rules_dir
|
||||
|
||||
|
||||
# Define function _extract_attack_tags
|
||||
def _extract_attack_tags(tags: list) -> list[str]:
|
||||
"""Extract MITRE technique IDs from Sigma tag list.
|
||||
|
||||
Example input: ["attack.defense_evasion", "attack.t1059.001", "cve.2021.44228"]
|
||||
Example output: ["T1059.001"]
|
||||
"""
|
||||
# Assign technique_ids = []
|
||||
technique_ids = []
|
||||
# Iterate over tags
|
||||
for tag in tags:
|
||||
# Assign m = _ATTACK_TAG_RE.match(str(tag).strip())
|
||||
m = _ATTACK_TAG_RE.match(str(tag).strip())
|
||||
# Check: m
|
||||
if m:
|
||||
# Call technique_ids.append()
|
||||
technique_ids.append(m.group(1).upper())
|
||||
# Return list(set(technique_ids))
|
||||
return list(set(technique_ids))
|
||||
|
||||
|
||||
# Define function _parse_sigma_rules
|
||||
def _parse_sigma_rules(rules_dir: Path) -> list[dict]:
|
||||
"""Walk the rules directory and parse all Sigma YAML files.
|
||||
|
||||
Returns a flat list of dicts, one per (rule, technique) combination.
|
||||
A single Sigma rule tagged with N techniques produces N entries.
|
||||
"""
|
||||
# Assign results = []
|
||||
results: list[dict] = []
|
||||
# Assign yaml_files = sorted(rules_dir.rglob("*.yml"))
|
||||
yaml_files = sorted(rules_dir.rglob("*.yml"))
|
||||
# Log info: "Found %d YAML files to parse", len(yaml_files
|
||||
logger.info("Found %d YAML files to parse", len(yaml_files))
|
||||
|
||||
# Iterate over yaml_files
|
||||
for yaml_path in yaml_files:
|
||||
# Assign relative_path = str(yaml_path.relative_to(rules_dir.parent))
|
||||
relative_path = str(yaml_path.relative_to(rules_dir.parent))
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Open context manager
|
||||
with open(yaml_path, "r", encoding="utf-8") as fh:
|
||||
# Assign data = yaml.safe_load(fh)
|
||||
data = yaml.safe_load(fh)
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log debug: "Failed to parse %s: %s", yaml_path, exc
|
||||
logger.debug("Failed to parse %s: %s", yaml_path, exc)
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Check: not isinstance(data, dict)
|
||||
if not isinstance(data, dict):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign title = data.get("title", "").strip()
|
||||
title = data.get("title", "").strip()
|
||||
# Check: not title
|
||||
if not title:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Extract ATT&CK technique IDs from tags
|
||||
tags = data.get("tags", [])
|
||||
# Check: not isinstance(tags, list)
|
||||
if not isinstance(tags, list):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign technique_ids = _extract_attack_tags(tags)
|
||||
technique_ids = _extract_attack_tags(tags)
|
||||
# Check: not technique_ids
|
||||
if not technique_ids:
|
||||
# continue # Skip rules without ATT&CK mapping
|
||||
continue # Skip rules without ATT&CK mapping
|
||||
|
||||
# Assign description = data.get("description", "")
|
||||
description = data.get("description", "")
|
||||
# Assign level = str(data.get("level", "")).lower()
|
||||
level = str(data.get("level", "")).lower()
|
||||
# Assign severity = _SEVERITY_MAP.get(level)
|
||||
severity = _SEVERITY_MAP.get(level)
|
||||
|
||||
# Extract logsource
|
||||
logsource = data.get("logsource", {})
|
||||
# Check: not isinstance(logsource, dict)
|
||||
if not isinstance(logsource, dict):
|
||||
# Assign logsource = {}
|
||||
logsource = {}
|
||||
|
||||
# Read full YAML content for storage
|
||||
try:
|
||||
# Open context manager
|
||||
with open(yaml_path, "r", encoding="utf-8") as fh:
|
||||
# Assign raw_content = fh.read()
|
||||
raw_content = fh.read()
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Assign raw_content = yaml.dump(data, default_flow_style=False)
|
||||
raw_content = yaml.dump(data, default_flow_style=False)
|
||||
|
||||
# False positive assessment
|
||||
falsepositives = data.get("falsepositives", [])
|
||||
# Check: isinstance(falsepositives, list) and len(falsepositives) > 3
|
||||
if isinstance(falsepositives, list) and len(falsepositives) > 3:
|
||||
# Assign fp_rate = "high"
|
||||
fp_rate = "high"
|
||||
# Alternative: isinstance(falsepositives, list) and len(falsepositives) > 1
|
||||
elif isinstance(falsepositives, list) and len(falsepositives) > 1:
|
||||
# Assign fp_rate = "medium"
|
||||
fp_rate = "medium"
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign fp_rate = "low"
|
||||
fp_rate = "low"
|
||||
|
||||
# Create one entry per technique
|
||||
for tech_id in technique_ids:
|
||||
# Assign source_url = (
|
||||
source_url = (
|
||||
f"https://github.com/SigmaHQ/sigma/blob/master/"
|
||||
f"{relative_path.replace(chr(92), '/')}"
|
||||
)
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"mitre_technique_id": tech_id,
|
||||
# Literal argument value
|
||||
"title": title[:500],
|
||||
# Literal argument value
|
||||
"description": str(description)[:2000] if description else None,
|
||||
# Literal argument value
|
||||
"source_id": relative_path,
|
||||
# Literal argument value
|
||||
"source_url": source_url,
|
||||
# Literal argument value
|
||||
"rule_content": raw_content,
|
||||
# Literal argument value
|
||||
"severity": severity,
|
||||
# Literal argument value
|
||||
"log_sources": logsource if logsource else None,
|
||||
# Literal argument value
|
||||
"false_positive_rate": fp_rate,
|
||||
# Literal argument value
|
||||
"platforms": _platforms_from_logsource(logsource),
|
||||
})
|
||||
|
||||
# Log info: "Parsed %d (rule, technique) pairs total", len(res
|
||||
logger.info("Parsed %d (rule, technique) pairs total", len(results))
|
||||
# Return results
|
||||
return results
|
||||
|
||||
|
||||
# Define function _platforms_from_logsource
|
||||
def _platforms_from_logsource(logsource: dict) -> list[str]:
|
||||
"""Infer platform list from Sigma logsource."""
|
||||
# Assign platforms = []
|
||||
platforms = []
|
||||
# Assign product = str(logsource.get("product", "")).lower()
|
||||
product = str(logsource.get("product", "")).lower()
|
||||
# Assign service = str(logsource.get("service", "")).lower()
|
||||
service = str(logsource.get("service", "")).lower()
|
||||
|
||||
# Check: "windows" in product or "windows" in service
|
||||
if "windows" in product or "windows" in service:
|
||||
# Call platforms.append()
|
||||
platforms.append("windows")
|
||||
# Check: "linux" in product or "linux" in service
|
||||
if "linux" in product or "linux" in service:
|
||||
# Call platforms.append()
|
||||
platforms.append("linux")
|
||||
# Check: "macos" in product or "macos" in service
|
||||
if "macos" in product or "macos" in service:
|
||||
# Call platforms.append()
|
||||
platforms.append("macos")
|
||||
|
||||
# Sysmon → Windows
|
||||
if "sysmon" in service and "windows" not in platforms:
|
||||
# Call platforms.append()
|
||||
platforms.append("windows")
|
||||
|
||||
# Return platforms if platforms else None
|
||||
return platforms if platforms else None
|
||||
|
||||
|
||||
@@ -264,59 +389,88 @@ def sync(db: Session) -> dict:
|
||||
db : Session
|
||||
Active SQLAlchemy database session.
|
||||
|
||||
Returns
|
||||
Returns:
|
||||
-------
|
||||
dict
|
||||
Summary with ``created``, ``skipped_existing``, ``total_parsed``.
|
||||
"""
|
||||
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_sigma_")
|
||||
tmp_dir = tempfile.mkdtemp(prefix="aegis_sigma_")
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign zip_bytes = _download_zip()
|
||||
zip_bytes = _download_zip()
|
||||
# Assign rules_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
rules_dir = _extract_zip(zip_bytes, tmp_dir)
|
||||
# Assign parsed_rules = _parse_sigma_rules(rules_dir)
|
||||
parsed_rules = _parse_sigma_rules(rules_dir)
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Call shutil.rmtree()
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
# Log info: "Cleaned up temp directory %s", tmp_dir
|
||||
logger.info("Cleaned up temp directory %s", tmp_dir)
|
||||
|
||||
# Pre-load existing source_ids for dedup
|
||||
existing_ids: set[str] = {
|
||||
row[0]
|
||||
for row in db.query(DetectionRule.source_id)
|
||||
# Chain .filter() call
|
||||
.filter(DetectionRule.source == "sigma")
|
||||
# Chain .filter() call
|
||||
.filter(DetectionRule.source_id.isnot(None))
|
||||
# Chain .all() call
|
||||
.all()
|
||||
}
|
||||
|
||||
# Assign created = 0
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
# Iterate over parsed_rules
|
||||
for item in parsed_rules:
|
||||
# Dedup key: source_id (relative path). A rule file may produce
|
||||
# multiple entries (one per technique), but we deduplicate by
|
||||
# source_id so re-runs are safe. For multi-technique rules we
|
||||
# only skip if the exact same source_id is already present.
|
||||
dedup_key = f"{item['source_id']}::{item['mitre_technique_id']}"
|
||||
# Deduplicate by source_id: one rule file may map to multiple techniques,
|
||||
# but we skip insertion if this source_id was already imported.
|
||||
if item["source_id"] in existing_ids:
|
||||
# Assign skipped = 1
|
||||
skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign rule = DetectionRule(
|
||||
rule = DetectionRule(
|
||||
# Keyword argument: mitre_technique_id
|
||||
mitre_technique_id=item["mitre_technique_id"],
|
||||
# Keyword argument: title
|
||||
title=item["title"],
|
||||
# Keyword argument: description
|
||||
description=item["description"],
|
||||
# Keyword argument: source
|
||||
source="sigma",
|
||||
# Keyword argument: source_id
|
||||
source_id=item["source_id"],
|
||||
# Keyword argument: source_url
|
||||
source_url=item["source_url"],
|
||||
# Keyword argument: rule_content
|
||||
rule_content=item["rule_content"],
|
||||
# Keyword argument: rule_format
|
||||
rule_format="sigma_yaml",
|
||||
# Keyword argument: severity
|
||||
severity=item["severity"],
|
||||
# Keyword argument: platforms
|
||||
platforms=item["platforms"],
|
||||
# Keyword argument: log_sources
|
||||
log_sources=item["log_sources"],
|
||||
# Keyword argument: false_positive_rate
|
||||
false_positive_rate=item["false_positive_rate"],
|
||||
# Keyword argument: is_active
|
||||
is_active=True,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(rule)
|
||||
# Call existing_ids.add()
|
||||
existing_ids.add(item["source_id"])
|
||||
new_technique_ids.add(item["mitre_technique_id"])
|
||||
created += 1
|
||||
@@ -329,30 +483,48 @@ def sync(db: Session) -> dict:
|
||||
|
||||
db.commit()
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"created": created,
|
||||
# Literal argument value
|
||||
"skipped_existing": skipped,
|
||||
# Literal argument value
|
||||
"total_parsed": len(parsed_rules),
|
||||
}
|
||||
|
||||
# Update DataSource record
|
||||
ds = db.query(DataSource).filter(DataSource.name == "sigma").first()
|
||||
# Check: ds
|
||||
if ds:
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_status = "success"
|
||||
ds.last_sync_status = "success"
|
||||
# Assign ds.last_sync_stats = summary
|
||||
ds.last_sync_stats = summary
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Log info: "Sigma import complete — %s", summary
|
||||
logger.info("Sigma import complete — %s", summary)
|
||||
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=None,
|
||||
# Keyword argument: action
|
||||
action="import_sigma_rules",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="detection_rule",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=None,
|
||||
# Keyword argument: details
|
||||
details=summary,
|
||||
)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
@@ -7,25 +7,59 @@ Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
|
||||
number of SQL queries regardless of technique count.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import defaultdict from collections
|
||||
from collections import defaultdict
|
||||
|
||||
# Import datetime, timedelta, timezone from datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
|
||||
# Import TechniqueStatus from app.models.enums
|
||||
from app.models.enums import TechniqueStatus
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import from app.services.scoring_service
|
||||
from app.services.scoring_service import (
|
||||
bulk_technique_scores,
|
||||
calculate_organization_score,
|
||||
)
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Coverage status ordering for snapshot delta comparisons (higher = better coverage)
|
||||
_STATUS_ORDER: dict[str, int] = {
|
||||
# Literal argument value
|
||||
"not_evaluated": 0,
|
||||
# Literal argument value
|
||||
"not_covered": 1,
|
||||
# Literal argument value
|
||||
"in_progress": 2,
|
||||
# Literal argument value
|
||||
"partial": 3,
|
||||
# Literal argument value
|
||||
"validated": 4,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization and queries
|
||||
@@ -33,97 +67,207 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
|
||||
"""Lightweight serialization for list views."""
|
||||
"""Return a lightweight serialization of a snapshot for list views.
|
||||
|
||||
Args:
|
||||
snap (CoverageSnapshot): The snapshot ORM object to serialize.
|
||||
|
||||
Returns:
|
||||
dict: Flat dictionary with summary fields (counts, scores, tactic
|
||||
breakdown) suitable for paginated list responses.
|
||||
"""
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(snap.id),
|
||||
# Literal argument value
|
||||
"name": snap.name,
|
||||
# Literal argument value
|
||||
"organization_score": snap.organization_score,
|
||||
# Literal argument value
|
||||
"total_techniques": snap.total_techniques,
|
||||
# Literal argument value
|
||||
"validated_count": snap.validated_count,
|
||||
# Literal argument value
|
||||
"partial_count": snap.partial_count,
|
||||
# Literal argument value
|
||||
"not_covered_count": snap.not_covered_count,
|
||||
# Literal argument value
|
||||
"in_progress_count": snap.in_progress_count,
|
||||
# Literal argument value
|
||||
"not_evaluated_count": snap.not_evaluated_count,
|
||||
# Literal argument value
|
||||
"coverage_percentage": getattr(snap, "coverage_percentage", 0.0),
|
||||
# Literal argument value
|
||||
"by_tactic": getattr(snap, "by_tactic", None) or {},
|
||||
# Literal argument value
|
||||
"by_status": getattr(snap, "by_status", None) or {},
|
||||
# Literal argument value
|
||||
"stale_count": getattr(snap, "stale_count", 0),
|
||||
# Literal argument value
|
||||
"never_tested_count": getattr(snap, "never_tested_count", 0),
|
||||
# Literal argument value
|
||||
"created_by": str(snap.created_by) if snap.created_by else None,
|
||||
# Literal argument value
|
||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
# Define function serialize_snapshot_detail
|
||||
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
|
||||
"""Full serialization including technique states."""
|
||||
"""Return full serialization of a snapshot including per-technique states.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
snap (CoverageSnapshot): The snapshot ORM object to serialize.
|
||||
|
||||
Returns:
|
||||
dict: Summary fields merged with a ``technique_states`` list, each
|
||||
entry containing ``mitre_id``, ``technique_id``, ``status``,
|
||||
and ``score``.
|
||||
"""
|
||||
# Assign base = serialize_snapshot_summary(snap)
|
||||
base = serialize_snapshot_summary(snap)
|
||||
|
||||
# Assign technique_states = (
|
||||
technique_states = (
|
||||
db.query(SnapshotTechniqueState)
|
||||
# Chain .filter() call
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(SnapshotTechniqueState.mitre_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign base["technique_states"] = [
|
||||
base["technique_states"] = [
|
||||
{
|
||||
# Literal argument value
|
||||
"mitre_id": s.mitre_id,
|
||||
# Literal argument value
|
||||
"technique_id": str(s.technique_id),
|
||||
# Literal argument value
|
||||
"status": s.status,
|
||||
# Literal argument value
|
||||
"score": s.score,
|
||||
}
|
||||
for s in technique_states
|
||||
]
|
||||
# Return base
|
||||
return base
|
||||
|
||||
|
||||
# Define function list_snapshots
|
||||
def list_snapshots(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: offset
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||
"""List coverage snapshots ordered by creation date (newest first).
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
offset (int): Number of records to skip for pagination.
|
||||
limit (int): Maximum number of records to return.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``total``, ``offset``, ``limit``, and ``items`` (list
|
||||
of serialized snapshot summaries).
|
||||
"""
|
||||
# Assign query = db.query(CoverageSnapshot)
|
||||
query = db.query(CoverageSnapshot)
|
||||
# Assign total = query.count()
|
||||
total = query.count()
|
||||
|
||||
# Assign snapshots = (
|
||||
snapshots = (
|
||||
query
|
||||
# Chain .order_by() call
|
||||
.order_by(CoverageSnapshot.created_at.desc())
|
||||
# Chain .offset() call
|
||||
.offset(offset)
|
||||
# Chain .limit() call
|
||||
.limit(limit)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
# Literal argument value
|
||||
"offset": offset,
|
||||
# Literal argument value
|
||||
"limit": limit,
|
||||
# Literal argument value
|
||||
"items": [serialize_snapshot_summary(s) for s in snapshots],
|
||||
}
|
||||
|
||||
|
||||
# Define function get_snapshot_or_raise
|
||||
def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot:
|
||||
"""Fetch snapshot by ID or raise EntityNotFoundError."""
|
||||
"""Fetch snapshot by ID or raise EntityNotFoundError.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
snapshot_id (str): UUID string of the snapshot to retrieve.
|
||||
|
||||
Returns:
|
||||
CoverageSnapshot: The matching snapshot ORM object.
|
||||
"""
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign sid = uuid.UUID(snapshot_id)
|
||||
sid = uuid.UUID(snapshot_id)
|
||||
# Handle (ValueError, TypeError)
|
||||
except (ValueError, TypeError):
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Snapshot", snapshot_id)
|
||||
# Assign snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
|
||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
|
||||
# Check: snapshot is None
|
||||
if snapshot is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Snapshot", snapshot_id)
|
||||
# Return snapshot
|
||||
return snapshot
|
||||
|
||||
|
||||
# Define function get_snapshot_detail
|
||||
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
|
||||
"""Get detailed snapshot including per-technique states."""
|
||||
"""Return detailed snapshot data including per-technique states.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
snapshot_id (str): UUID string of the snapshot to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: Full snapshot serialization from
|
||||
:func:`serialize_snapshot_detail`.
|
||||
"""
|
||||
# Assign snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
# Return serialize_snapshot_detail(db, snapshot)
|
||||
return serialize_snapshot_detail(db, snapshot)
|
||||
|
||||
|
||||
# Define function delete_snapshot
|
||||
def delete_snapshot(db: Session, snapshot_id: str) -> None:
|
||||
"""Delete a snapshot. Does not commit — caller must commit."""
|
||||
"""Delete a snapshot. Does not commit — caller must commit.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
snapshot_id (str): UUID string of the snapshot to delete.
|
||||
"""
|
||||
# Assign snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
# Mark record for deletion on next commit
|
||||
db.delete(snapshot)
|
||||
|
||||
|
||||
@@ -133,8 +277,11 @@ def delete_snapshot(db: Session, snapshot_id: str) -> None:
|
||||
|
||||
|
||||
def create_snapshot(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: name
|
||||
name: str | None = None,
|
||||
# Entry: user_id
|
||||
user_id: uuid.UUID | None = None,
|
||||
) -> CoverageSnapshot:
|
||||
"""Capture the current coverage state into a new snapshot.
|
||||
@@ -144,121 +291,215 @@ def create_snapshot(
|
||||
3. Compute the org score from the same bulk data.
|
||||
4. Persist a ``CoverageSnapshot`` with normalised
|
||||
``SnapshotTechniqueState`` rows.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
name (str | None): Optional human-readable label for the snapshot.
|
||||
user_id (uuid.UUID | None): UUID of the user creating the snapshot,
|
||||
stored for auditing.
|
||||
|
||||
Returns:
|
||||
CoverageSnapshot: The newly created and committed snapshot ORM object.
|
||||
"""
|
||||
# Assign scores_map = bulk_technique_scores(db)
|
||||
scores_map = bulk_technique_scores(db)
|
||||
|
||||
# Assign techniques = db.query(Technique).all()
|
||||
techniques = db.query(Technique).all()
|
||||
|
||||
# Assign validated_count = 0
|
||||
validated_count = 0
|
||||
# Assign partial_count = 0
|
||||
partial_count = 0
|
||||
# Assign not_covered_count = 0
|
||||
not_covered_count = 0
|
||||
# Assign in_progress_count = 0
|
||||
in_progress_count = 0
|
||||
# Assign not_evaluated_count = 0
|
||||
not_evaluated_count = 0
|
||||
# Assign stale_count = 0
|
||||
stale_count = 0
|
||||
# Assign never_tested_count = 0
|
||||
never_tested_count = 0
|
||||
|
||||
# Assign by_tactic = defaultdict(
|
||||
by_tactic: dict[str, dict] = defaultdict(
|
||||
# Entry: lambda
|
||||
lambda: {"total": 0, "validated": 0, "partial": 0, "score_sum": 0.0}
|
||||
)
|
||||
# Assign by_status = defaultdict(int)
|
||||
by_status: dict[str, int] = defaultdict(int)
|
||||
|
||||
# Assign technique_rows = []
|
||||
technique_rows: list[dict] = []
|
||||
|
||||
# Iterate over techniques
|
||||
for tech in techniques:
|
||||
# Assign status_value = (
|
||||
status_value = (
|
||||
tech.status_global.value
|
||||
if isinstance(tech.status_global, TechniqueStatus)
|
||||
else (tech.status_global or "not_evaluated")
|
||||
)
|
||||
|
||||
# Check: status_value == "validated"
|
||||
if status_value == "validated":
|
||||
# Assign validated_count = 1
|
||||
validated_count += 1
|
||||
# Alternative: status_value == "partial"
|
||||
elif status_value == "partial":
|
||||
# Assign partial_count = 1
|
||||
partial_count += 1
|
||||
# Alternative: status_value == "not_covered"
|
||||
elif status_value == "not_covered":
|
||||
# Assign not_covered_count = 1
|
||||
not_covered_count += 1
|
||||
# Alternative: status_value == "in_progress"
|
||||
elif status_value == "in_progress":
|
||||
# Assign in_progress_count = 1
|
||||
in_progress_count += 1
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign not_evaluated_count = 1
|
||||
not_evaluated_count += 1
|
||||
|
||||
# Assign entry = scores_map.get(tech.id, {})
|
||||
entry = scores_map.get(tech.id, {})
|
||||
# Assign score = entry.get("total_score", 0)
|
||||
score = entry.get("total_score", 0)
|
||||
# Call technique_rows.append()
|
||||
technique_rows.append({
|
||||
# Literal argument value
|
||||
"technique_id": tech.id,
|
||||
# Literal argument value
|
||||
"mitre_id": tech.mitre_id,
|
||||
# Literal argument value
|
||||
"status": status_value,
|
||||
# Literal argument value
|
||||
"score": score,
|
||||
})
|
||||
|
||||
# Assign by_status[status_value] = 1
|
||||
by_status[status_value] += 1
|
||||
# Assign tactic_key = tech.tactic or "unknown"
|
||||
tactic_key = tech.tactic or "unknown"
|
||||
# Assign bucket = by_tactic[tactic_key]
|
||||
bucket = by_tactic[tactic_key]
|
||||
# Assign bucket["total"] = 1
|
||||
bucket["total"] += 1
|
||||
# Assign bucket["score_sum"] = score
|
||||
bucket["score_sum"] += score
|
||||
# Check: status_value == "validated"
|
||||
if status_value == "validated":
|
||||
# Assign bucket["validated"] = 1
|
||||
bucket["validated"] += 1
|
||||
# Alternative: status_value == "partial"
|
||||
elif status_value == "partial":
|
||||
# Assign bucket["partial"] = 1
|
||||
bucket["partial"] += 1
|
||||
|
||||
# Check: status_value == "not_evaluated"
|
||||
if status_value == "not_evaluated":
|
||||
# Assign never_tested_count = 1
|
||||
never_tested_count += 1
|
||||
# Check: tech.review_required
|
||||
if tech.review_required:
|
||||
# Assign stale_count = 1
|
||||
stale_count += 1
|
||||
|
||||
# Assign org_data = calculate_organization_score(db)
|
||||
org_data = calculate_organization_score(db)
|
||||
# Assign org_score = org_data.get("overall_score", 0)
|
||||
org_score = org_data.get("overall_score", 0)
|
||||
# Assign total_techniques = len(techniques) or 1
|
||||
total_techniques = len(techniques) or 1
|
||||
# Assign coverage_pct = round((validated_count / total_techniques) * 100, 1)
|
||||
coverage_pct = round((validated_count / total_techniques) * 100, 1)
|
||||
|
||||
# Assign by_tactic_out = {
|
||||
by_tactic_out = {
|
||||
# Entry: tactic
|
||||
tactic: {
|
||||
# Literal argument value
|
||||
"total": data["total"],
|
||||
# Literal argument value
|
||||
"validated": data["validated"],
|
||||
# Literal argument value
|
||||
"partial": data["partial"],
|
||||
# Literal argument value
|
||||
"average_score": round(data["score_sum"] / data["total"], 1) if data["total"] else 0,
|
||||
}
|
||||
for tactic, data in by_tactic.items()
|
||||
}
|
||||
|
||||
# Assign snapshot = CoverageSnapshot(
|
||||
snapshot = CoverageSnapshot(
|
||||
# Keyword argument: name
|
||||
name=name,
|
||||
# Keyword argument: organization_score
|
||||
organization_score=org_score,
|
||||
# Keyword argument: total_techniques
|
||||
total_techniques=len(techniques),
|
||||
# Keyword argument: validated_count
|
||||
validated_count=validated_count,
|
||||
# Keyword argument: partial_count
|
||||
partial_count=partial_count,
|
||||
# Keyword argument: not_covered_count
|
||||
not_covered_count=not_covered_count,
|
||||
# Keyword argument: in_progress_count
|
||||
in_progress_count=in_progress_count,
|
||||
# Keyword argument: not_evaluated_count
|
||||
not_evaluated_count=not_evaluated_count,
|
||||
# Keyword argument: coverage_percentage
|
||||
coverage_percentage=coverage_pct,
|
||||
# Keyword argument: by_tactic
|
||||
by_tactic=by_tactic_out,
|
||||
# Keyword argument: by_status
|
||||
by_status=dict(by_status),
|
||||
# Keyword argument: stale_count
|
||||
stale_count=stale_count,
|
||||
# Keyword argument: never_tested_count
|
||||
never_tested_count=never_tested_count,
|
||||
# Keyword argument: created_by
|
||||
created_by=user_id,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(snapshot)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Iterate over technique_rows
|
||||
for row in technique_rows:
|
||||
# Assign state = SnapshotTechniqueState(
|
||||
state = SnapshotTechniqueState(
|
||||
# Keyword argument: snapshot_id
|
||||
snapshot_id=snapshot.id,
|
||||
# Keyword argument: technique_id
|
||||
technique_id=row["technique_id"],
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=row["mitre_id"],
|
||||
# Keyword argument: status
|
||||
status=row["status"],
|
||||
# Keyword argument: score
|
||||
score=row["score"],
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(state)
|
||||
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(snapshot)
|
||||
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Snapshot '%s' created — %d techniques, org score %.1f",
|
||||
snapshot.name or snapshot.id,
|
||||
len(techniques),
|
||||
org_score,
|
||||
)
|
||||
# Return snapshot
|
||||
return snapshot
|
||||
|
||||
|
||||
@@ -268,99 +509,160 @@ def create_snapshot(
|
||||
|
||||
|
||||
def compare_snapshots(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: snapshot_a_id
|
||||
snapshot_a_id: uuid.UUID,
|
||||
# Entry: snapshot_b_id
|
||||
snapshot_b_id: uuid.UUID,
|
||||
) -> dict:
|
||||
"""Compare two snapshots and return deltas.
|
||||
|
||||
Returns improved/worsened technique lists plus aggregate statistics.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
snapshot_a_id (uuid.UUID): UUID of the baseline (older) snapshot.
|
||||
snapshot_b_id (uuid.UUID): UUID of the comparison (newer) snapshot.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``snapshot_a``, ``snapshot_b``, ``score_delta``,
|
||||
``improved``, ``worsened``, ``unchanged_count``, and ``summary``
|
||||
keys.
|
||||
"""
|
||||
# Assign snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a...
|
||||
snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
|
||||
# Assign snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b...
|
||||
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
|
||||
|
||||
# Check: not snap_a or not snap_b
|
||||
if not snap_a or not snap_b:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")
|
||||
|
||||
# Build lookup dicts: mitre_id -> {status, score}
|
||||
states_a = {
|
||||
s.mitre_id: {"status": s.status, "score": s.score or 0}
|
||||
for s in db.query(SnapshotTechniqueState)
|
||||
# Chain .filter() call
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
}
|
||||
# Assign states_b = {
|
||||
states_b = {
|
||||
s.mitre_id: {"status": s.status, "score": s.score or 0}
|
||||
for s in db.query(SnapshotTechniqueState)
|
||||
# Chain .filter() call
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
}
|
||||
|
||||
# Status priority for comparison
|
||||
STATUS_ORDER = {
|
||||
"not_evaluated": 0,
|
||||
"not_covered": 1,
|
||||
"in_progress": 2,
|
||||
"partial": 3,
|
||||
"validated": 4,
|
||||
}
|
||||
|
||||
# Assign improved = []
|
||||
improved = []
|
||||
# Assign worsened = []
|
||||
worsened = []
|
||||
# Assign unchanged_count = 0
|
||||
unchanged_count = 0
|
||||
|
||||
# Assign all_mitre_ids = set(states_a.keys()) | set(states_b.keys())
|
||||
all_mitre_ids = set(states_a.keys()) | set(states_b.keys())
|
||||
|
||||
# Iterate over sorted(all_mitre_ids)
|
||||
for mitre_id in sorted(all_mitre_ids):
|
||||
# Assign a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0})
|
||||
a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0})
|
||||
# Assign b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0})
|
||||
b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0})
|
||||
|
||||
a_order = STATUS_ORDER.get(a["status"], 0)
|
||||
b_order = STATUS_ORDER.get(b["status"], 0)
|
||||
# Assign a_order = _STATUS_ORDER.get(a["status"], 0)
|
||||
a_order = _STATUS_ORDER.get(a["status"], 0)
|
||||
# Assign b_order = _STATUS_ORDER.get(b["status"], 0)
|
||||
b_order = _STATUS_ORDER.get(b["status"], 0)
|
||||
|
||||
# Check: b_order > a_order or (b_order == a_order and b["score"] > a["score"])
|
||||
if b_order > a_order or (b_order == a_order and b["score"] > a["score"]):
|
||||
# Call improved.append()
|
||||
improved.append({
|
||||
# Literal argument value
|
||||
"mitre_id": mitre_id,
|
||||
# Literal argument value
|
||||
"old_status": a["status"],
|
||||
# Literal argument value
|
||||
"new_status": b["status"],
|
||||
# Literal argument value
|
||||
"old_score": a["score"],
|
||||
# Literal argument value
|
||||
"new_score": b["score"],
|
||||
})
|
||||
# Alternative: b_order < a_order or (b_order == a_order and b["score"] < a["score"])
|
||||
elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]):
|
||||
# Call worsened.append()
|
||||
worsened.append({
|
||||
# Literal argument value
|
||||
"mitre_id": mitre_id,
|
||||
# Literal argument value
|
||||
"old_status": a["status"],
|
||||
# Literal argument value
|
||||
"new_status": b["status"],
|
||||
# Literal argument value
|
||||
"old_score": a["score"],
|
||||
# Literal argument value
|
||||
"new_score": b["score"],
|
||||
})
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign unchanged_count = 1
|
||||
unchanged_count += 1
|
||||
|
||||
# Define function _snap_summary
|
||||
def _snap_summary(snap: CoverageSnapshot) -> dict:
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(snap.id),
|
||||
# Literal argument value
|
||||
"name": snap.name,
|
||||
# Literal argument value
|
||||
"organization_score": snap.organization_score,
|
||||
# Literal argument value
|
||||
"total_techniques": snap.total_techniques,
|
||||
# Literal argument value
|
||||
"validated_count": snap.validated_count,
|
||||
# Literal argument value
|
||||
"partial_count": snap.partial_count,
|
||||
# Literal argument value
|
||||
"not_covered_count": snap.not_covered_count,
|
||||
# Literal argument value
|
||||
"in_progress_count": snap.in_progress_count,
|
||||
# Literal argument value
|
||||
"not_evaluated_count": snap.not_evaluated_count,
|
||||
# Literal argument value
|
||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||
}
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"snapshot_a": _snap_summary(snap_a),
|
||||
# Literal argument value
|
||||
"snapshot_b": _snap_summary(snap_b),
|
||||
# Literal argument value
|
||||
"score_delta": round(snap_b.organization_score - snap_a.organization_score, 1),
|
||||
# Literal argument value
|
||||
"improved": improved,
|
||||
# Literal argument value
|
||||
"worsened": worsened,
|
||||
# Literal argument value
|
||||
"unchanged_count": unchanged_count,
|
||||
# Literal argument value
|
||||
"summary": {
|
||||
# Literal argument value
|
||||
"improved_count": len(improved),
|
||||
# Literal argument value
|
||||
"worsened_count": len(worsened),
|
||||
# Literal argument value
|
||||
"new_count": len(states_b.keys() - states_a.keys()),
|
||||
},
|
||||
}
|
||||
@@ -372,25 +674,53 @@ def compare_snapshots(
|
||||
|
||||
|
||||
def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
|
||||
"""Return snapshot trend points for the last *months* months."""
|
||||
"""Return snapshot trend points for the last *months* months.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
months (int): Number of months to look back; defaults to 12.
|
||||
|
||||
Returns:
|
||||
list[dict]: Snapshot trend entries ordered by creation date ascending,
|
||||
each containing ``date``, ``name``, ``org_score``,
|
||||
``coverage_pct``, ``by_tactic``, ``by_status``,
|
||||
``stale_count``, ``never_tested_count``, ``validated_count``,
|
||||
and ``total_techniques``.
|
||||
"""
|
||||
# Assign cutoff = datetime.now(timezone.utc) - timedelta(days=months * 30)
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=months * 30)
|
||||
# Assign snapshots = (
|
||||
snapshots = (
|
||||
db.query(CoverageSnapshot)
|
||||
# Chain .filter() call
|
||||
.filter(CoverageSnapshot.created_at >= cutoff)
|
||||
# Chain .order_by() call
|
||||
.order_by(CoverageSnapshot.created_at.asc())
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"date": snap.created_at.isoformat() if snap.created_at else None,
|
||||
# Literal argument value
|
||||
"name": snap.name,
|
||||
# Literal argument value
|
||||
"org_score": snap.organization_score,
|
||||
# Literal argument value
|
||||
"coverage_pct": getattr(snap, "coverage_percentage", 0.0),
|
||||
# Literal argument value
|
||||
"by_tactic": getattr(snap, "by_tactic", None) or {},
|
||||
# Literal argument value
|
||||
"by_status": getattr(snap, "by_status", None) or {},
|
||||
# Literal argument value
|
||||
"stale_count": getattr(snap, "stale_count", 0),
|
||||
# Literal argument value
|
||||
"never_tested_count": getattr(snap, "never_tested_count", 0),
|
||||
# Literal argument value
|
||||
"validated_count": snap.validated_count,
|
||||
# Literal argument value
|
||||
"total_techniques": snap.total_techniques,
|
||||
}
|
||||
for snap in snapshots
|
||||
@@ -405,25 +735,46 @@ def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
|
||||
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
|
||||
"""Delete oldest snapshots, keeping the most recent *keep_last*.
|
||||
|
||||
Returns the number of snapshots deleted.
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
keep_last (int): Number of most-recent snapshots to retain; defaults
|
||||
to 52 (one year of weekly snapshots).
|
||||
|
||||
Returns:
|
||||
int: Number of snapshots deleted.
|
||||
"""
|
||||
# Assign total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
|
||||
total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
|
||||
# Check: total <= keep_last
|
||||
if total <= keep_last:
|
||||
# Return 0
|
||||
return 0
|
||||
|
||||
# Assign to_delete = total - keep_last
|
||||
to_delete = total - keep_last
|
||||
# Assign old_snapshots = (
|
||||
old_snapshots = (
|
||||
db.query(CoverageSnapshot)
|
||||
# Chain .order_by() call
|
||||
.order_by(CoverageSnapshot.created_at.asc())
|
||||
# Chain .limit() call
|
||||
.limit(to_delete)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign deleted = 0
|
||||
deleted = 0
|
||||
# Iterate over old_snapshots
|
||||
for snap in old_snapshots:
|
||||
# Mark record for deletion on next commit
|
||||
db.delete(snap)
|
||||
# Assign deleted = 1
|
||||
deleted += 1
|
||||
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Log info: "Snapshot cleanup — deleted %d old snapshots (kept
|
||||
logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
|
||||
# Return deleted
|
||||
return deleted
|
||||
|
||||
@@ -1,26 +1,41 @@
|
||||
"""Stale coverage detection — marks techniques whose last validated test
|
||||
is older than a configurable threshold.
|
||||
"""Stale coverage detection — marks techniques whose last validated test is older than a configurable threshold.
|
||||
|
||||
This is the simple version. The full Decay Engine (Fase 8) will replace
|
||||
this with a multi-factor, configurable decay model with confidence scores.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import datetime, timedelta, timezone from datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import TechniqueStatus, TestState from app.models.enums
|
||||
from app.models.enums import TechniqueStatus, TestState
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign STALE_THRESHOLD_DAYS = settings.STALE_THRESHOLD_DAYS
|
||||
STALE_THRESHOLD_DAYS = settings.STALE_THRESHOLD_DAYS
|
||||
|
||||
|
||||
# Define function detect_stale_coverage
|
||||
def detect_stale_coverage(db: Session) -> int:
|
||||
"""Scan all techniques and flag those with stale coverage.
|
||||
|
||||
@@ -30,10 +45,17 @@ def detect_stale_coverage(db: Session) -> int:
|
||||
- It has never had a validated test (but has been manually marked as
|
||||
covered/partial).
|
||||
|
||||
Returns the number of newly-flagged techniques.
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
int: Number of techniques newly flagged as stale (``review_required``
|
||||
set to ``True``) in this run.
|
||||
"""
|
||||
# Assign cutoff = datetime.now(timezone.utc) - timedelta(days=STALE_THRESHOLD_DAYS)
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=STALE_THRESHOLD_DAYS)
|
||||
|
||||
# Assign last_validated = func.coalesce(
|
||||
last_validated = func.coalesce(
|
||||
Test.blue_validated_at,
|
||||
Test.red_validated_at,
|
||||
@@ -46,40 +68,60 @@ def detect_stale_coverage(db: Session) -> int:
|
||||
Test.technique_id,
|
||||
func.max(last_validated).label("last_tested"),
|
||||
)
|
||||
# Chain .filter() call
|
||||
.filter(Test.state == TestState.validated)
|
||||
# Chain .group_by() call
|
||||
.group_by(Test.technique_id)
|
||||
# Chain .subquery() call
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Find techniques that are stale
|
||||
stale_techniques = (
|
||||
db.query(Technique)
|
||||
# Chain .outerjoin() call
|
||||
.outerjoin(latest_test, Technique.id == latest_test.c.technique_id)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
# Either tested before cutoff, or never tested at all
|
||||
(latest_test.c.last_tested < cutoff)
|
||||
| (latest_test.c.last_tested.is_(None))
|
||||
)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
# Only flag techniques that have a real status (not never-evaluated ones)
|
||||
Technique.status_global != TechniqueStatus.not_evaluated
|
||||
)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign count = 0
|
||||
count = 0
|
||||
# Iterate over stale_techniques
|
||||
for tech in stale_techniques:
|
||||
# Check: not tech.review_required
|
||||
if not tech.review_required:
|
||||
# Assign tech.review_required = True
|
||||
tech.review_required = True
|
||||
# Assign count = 1
|
||||
count += 1
|
||||
# Log info: "Marked %s as stale coverage", tech.mitre_id
|
||||
logger.info("Marked %s as stale coverage", tech.mitre_id)
|
||||
|
||||
# Check: count > 0
|
||||
if count > 0:
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Stale coverage detection complete — %d techniques flagged", count
|
||||
)
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Log info: "Stale coverage detection complete — no new stale
|
||||
logger.info("Stale coverage detection complete — no new stale techniques")
|
||||
|
||||
# Return count
|
||||
return count
|
||||
|
||||
@@ -10,21 +10,30 @@ The function mutates the technique but does **not** commit.
|
||||
The caller is responsible for committing the session.
|
||||
"""
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
|
||||
# Define function recalculate_technique_status
|
||||
def recalculate_technique_status(db: Session, technique: Technique) -> None:
|
||||
"""Recompute ``technique.status_global`` from its tests.
|
||||
|
||||
``db`` is accepted for backward compatibility but is not used
|
||||
directly — test data comes from the ORM relationship.
|
||||
"""
|
||||
# Assign entity = TechniqueEntity.from_orm(technique)
|
||||
entity = TechniqueEntity.from_orm(technique)
|
||||
# Assign test_snapshots = [
|
||||
test_snapshots = [
|
||||
(t.state, t.detection_result) for t in technique.tests
|
||||
]
|
||||
# Call entity.recalculate_status()
|
||||
entity.recalculate_status(test_snapshots)
|
||||
# Assign technique.status_global = entity.status_global
|
||||
technique.status_global = entity.status_global
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
"""Technique query service — framework-agnostic queries for technique details."""
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import Session, joinedload from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.intel import IntelItem
|
||||
@@ -13,15 +18,21 @@ from app.services.d3fend_import_service import get_defenses_for_technique
|
||||
_SEVERITY_ORDER = {"critical": 0, "high": 1, "medium": 2, "low": 3, "informational": 4, None: 5}
|
||||
|
||||
|
||||
# Define function get_technique_detail
|
||||
def get_technique_detail(db: Session, mitre_id: str) -> dict:
|
||||
"""Fetch full technique details including tests, detection rules, and D3FEND defenses."""
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
# Chain .options() call
|
||||
.options(joinedload(Technique.tests))
|
||||
# Chain .filter() call
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Check: technique is None
|
||||
if technique is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", mitre_id)
|
||||
|
||||
defenses = get_defenses_for_technique(db, technique.id)
|
||||
@@ -49,26 +60,46 @@ def get_technique_detail(db: Session, mitre_id: str) -> dict:
|
||||
)
|
||||
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(technique.id),
|
||||
# Literal argument value
|
||||
"mitre_id": technique.mitre_id,
|
||||
# Literal argument value
|
||||
"name": technique.name,
|
||||
# Literal argument value
|
||||
"description": technique.description,
|
||||
# Literal argument value
|
||||
"tactic": technique.tactic,
|
||||
# Literal argument value
|
||||
"platforms": technique.platforms or [],
|
||||
# Literal argument value
|
||||
"mitre_version": technique.mitre_version,
|
||||
# Literal argument value
|
||||
"mitre_last_modified": technique.mitre_last_modified,
|
||||
# Literal argument value
|
||||
"is_subtechnique": technique.is_subtechnique,
|
||||
# Literal argument value
|
||||
"parent_mitre_id": technique.parent_mitre_id,
|
||||
# Literal argument value
|
||||
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
|
||||
# Literal argument value
|
||||
"review_required": technique.review_required,
|
||||
# Literal argument value
|
||||
"last_review_date": technique.last_review_date,
|
||||
# Literal argument value
|
||||
"tests": [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(t.id),
|
||||
# Literal argument value
|
||||
"name": t.name,
|
||||
# Literal argument value
|
||||
"state": t.state.value if t.state else None,
|
||||
# Literal argument value
|
||||
"result": t.result.value if t.result else None,
|
||||
# Literal argument value
|
||||
"platform": t.platform,
|
||||
# Literal argument value
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in technique.tests
|
||||
|
||||
@@ -18,15 +18,31 @@ blue_work_started_at) to when they submit, so it reflects actual working time
|
||||
rather than queue time.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
# Import Any, Optional from typing
|
||||
from typing import Any, Optional
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import InvalidOperationError from app.domain.exceptions
|
||||
from app.domain.exceptions import InvalidOperationError
|
||||
|
||||
# Import JiraLink, JiraLinkEntityType from app.models.jira_link
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Only red team execution time goes to Tempo.
|
||||
@@ -85,23 +101,31 @@ def get_user_tempo_client(user, db=None):
|
||||
"Add it in Settings → Profile → Tempo Integration."
|
||||
)
|
||||
try:
|
||||
# Import client_v4 as tempo_client from tempoapiclient
|
||||
from tempoapiclient import client_v4 as tempo_client
|
||||
base_url = _get_tempo_base_url(db)
|
||||
logger.debug("Using Tempo base URL: %s", base_url)
|
||||
return tempo_client.Tempo(auth_token=token, base_url=base_url)
|
||||
except ImportError:
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError(
|
||||
# Literal argument value
|
||||
"tempo-api-python-client is not installed. "
|
||||
"Run: pip install tempo-api-python-client"
|
||||
)
|
||||
|
||||
|
||||
# Define function log_worklog
|
||||
def log_worklog(
|
||||
user,
|
||||
jira_issue_id: int,
|
||||
# Entry: author_account_id
|
||||
author_account_id: str,
|
||||
# Entry: date
|
||||
date: str,
|
||||
# Entry: time_spent_seconds
|
||||
time_spent_seconds: int,
|
||||
# Entry: description
|
||||
description: str,
|
||||
db=None,
|
||||
) -> dict:
|
||||
@@ -128,10 +152,15 @@ def log_worklog(
|
||||
raise RuntimeError(f"Tempo API error: {exc}") from exc
|
||||
|
||||
|
||||
# Define function auto_log_test_worklog
|
||||
def auto_log_test_worklog(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
test,
|
||||
user,
|
||||
# Entry: test
|
||||
test: Test,
|
||||
# Entry: user
|
||||
user: User,
|
||||
# Entry: activity_type
|
||||
activity_type: str,
|
||||
duration_seconds: int,
|
||||
) -> Optional[dict]:
|
||||
@@ -156,6 +185,7 @@ def auto_log_test_worklog(
|
||||
|
||||
# Global kill-switch
|
||||
if not settings.TEMPO_ENABLED:
|
||||
# Return None
|
||||
return None
|
||||
|
||||
if duration_seconds <= 0:
|
||||
@@ -183,15 +213,20 @@ def auto_log_test_worklog(
|
||||
# Need a Jira link with a numeric issue ID
|
||||
link = (
|
||||
db.query(JiraLink)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
JiraLink.entity_id == test.id,
|
||||
JiraLink.entity_type == JiraLinkEntityType.test,
|
||||
)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check: not link or not link.jira_issue_id
|
||||
if not link or not link.jira_issue_id:
|
||||
# Log debug: "No Jira link for test %s, skipping Tempo worklog"
|
||||
logger.debug("No Jira link for test %s, skipping Tempo worklog", test.id)
|
||||
# Return None
|
||||
return None
|
||||
|
||||
jira_account_id = (getattr(user, "jira_account_id", "") or "").strip()
|
||||
@@ -202,6 +237,7 @@ def auto_log_test_worklog(
|
||||
)
|
||||
return None
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Use the phase start timestamp as the worklog date so it matches when
|
||||
# work actually happened (not the submission timestamp).
|
||||
@@ -231,6 +267,7 @@ def auto_log_test_worklog(
|
||||
test.id, getattr(user, "username", user), duration_seconds, work_date,
|
||||
)
|
||||
return result
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Tempo worklog failed for test %s (user %s): %s",
|
||||
|
||||
@@ -4,12 +4,15 @@ Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||
The router is responsible for HTTP concerns, auth, audit logging, and commit.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
# Import Session, joinedload from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
# Import from app.domain.errors
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
EntityNotFoundError,
|
||||
@@ -21,19 +24,41 @@ from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
# Import TestState from app.models.enums
|
||||
from app.models.enums import TestState
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import escape_like from app.utils
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
# Define function list_tests
|
||||
def list_tests(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: state
|
||||
state: str | None = None,
|
||||
# Entry: technique_id
|
||||
technique_id: uuid.UUID | None = None,
|
||||
# Entry: platform
|
||||
platform: str | None = None,
|
||||
# Entry: created_by
|
||||
created_by: uuid.UUID | None = None,
|
||||
# Entry: pending_validation_side
|
||||
pending_validation_side: str | None = None,
|
||||
not_in_any_campaign: bool = False,
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 50,
|
||||
) -> list[Test]:
|
||||
"""Return a paginated list of tests with optional filters.
|
||||
@@ -44,20 +69,32 @@ def list_tests(
|
||||
"""
|
||||
query = db.query(Test).options(joinedload(Test.technique))
|
||||
|
||||
# Check: state
|
||||
if state:
|
||||
# Assign query = query.filter(Test.state == state)
|
||||
query = query.filter(Test.state == state)
|
||||
# Check: technique_id
|
||||
if technique_id:
|
||||
# Assign query = query.filter(Test.technique_id == technique_id)
|
||||
query = query.filter(Test.technique_id == technique_id)
|
||||
# Check: platform
|
||||
if platform:
|
||||
# Assign query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
# Check: created_by
|
||||
if created_by:
|
||||
# Assign query = query.filter(Test.created_by == created_by)
|
||||
query = query.filter(Test.created_by == created_by)
|
||||
# Check: pending_validation_side == "red"
|
||||
if pending_validation_side == "red":
|
||||
# Assign query = query.filter(
|
||||
query = query.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.red_validation_status.in_(["pending", None]),
|
||||
)
|
||||
# Alternative: pending_validation_side == "blue"
|
||||
elif pending_validation_side == "blue":
|
||||
# Assign query = query.filter(
|
||||
query = query.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.blue_validation_status.in_(["pending", None]),
|
||||
@@ -82,42 +119,72 @@ def list_tests(
|
||||
)
|
||||
query = query.filter(~Test.id.in_(future_draft_tests))
|
||||
|
||||
# Return query.order_by(Test.created_at.desc()).offset(offset).limit(limit)....
|
||||
return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
|
||||
# Define function create_test
|
||||
def create_test(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: technique_id
|
||||
technique_id: uuid.UUID,
|
||||
# Entry: creator_id
|
||||
creator_id: uuid.UUID,
|
||||
**fields: Any,
|
||||
**fields: object,
|
||||
) -> Test:
|
||||
"""Create a new test linked to an existing technique.
|
||||
|
||||
Raises EntityNotFoundError if the technique does not exist.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
technique_id (uuid.UUID): UUID of the technique this test covers.
|
||||
creator_id (uuid.UUID): UUID of the user creating the test.
|
||||
**fields (object): Additional keyword arguments set as attributes on
|
||||
the new test (e.g. ``name``, ``platform``, ``description``).
|
||||
|
||||
Returns:
|
||||
Test: The newly created test ORM object, flushed but not committed.
|
||||
"""
|
||||
# Assign technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
# Check: technique is None
|
||||
if technique is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", str(technique_id))
|
||||
|
||||
# Assign test = Test(
|
||||
test = Test(
|
||||
# Keyword argument: technique_id
|
||||
technique_id=technique_id,
|
||||
# Keyword argument: created_by
|
||||
created_by=creator_id,
|
||||
# Keyword argument: state
|
||||
state=TestState.draft,
|
||||
created_at=datetime.utcnow(), # explicit — DB column has no server default
|
||||
**fields,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(test)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
# Define function create_test_from_template
|
||||
def create_test_from_template(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: template_id
|
||||
template_id: uuid.UUID,
|
||||
# Entry: technique_id_or_mitre
|
||||
technique_id_or_mitre: str,
|
||||
# Entry: creator_id
|
||||
creator_id: uuid.UUID,
|
||||
# Optional user-edited overrides (take priority over template values)
|
||||
name_override: str | None = None,
|
||||
@@ -132,27 +199,53 @@ def create_test_from_template(
|
||||
Override fields, when provided, take precedence over the template's values.
|
||||
Raises EntityNotFoundError if template or technique not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
template_id (uuid.UUID): UUID of the template to instantiate.
|
||||
technique_id_or_mitre (str): UUID string or MITRE technique ID
|
||||
(e.g. ``"T1059.001"``) identifying the target technique.
|
||||
creator_id (uuid.UUID): UUID of the user creating the test.
|
||||
|
||||
Returns:
|
||||
Test: The newly created test populated from template fields, flushed
|
||||
but not committed.
|
||||
"""
|
||||
# Assign template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
# Check: template is None
|
||||
if template is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("TestTemplate", str(template_id))
|
||||
|
||||
# Assign technique = None
|
||||
technique = None
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign technique_uuid = uuid.UUID(technique_id_or_mitre)
|
||||
technique_uuid = uuid.UUID(technique_id_or_mitre)
|
||||
# Assign technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
||||
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
||||
# Handle ValueError
|
||||
except ValueError:
|
||||
# Intentional no-op placeholder
|
||||
pass
|
||||
|
||||
# Check: technique is None
|
||||
if technique is None:
|
||||
# Assign technique = db.query(Technique).filter(
|
||||
technique = db.query(Technique).filter(
|
||||
Technique.mitre_id == technique_id_or_mitre
|
||||
).first()
|
||||
|
||||
# Check: technique is None
|
||||
if technique is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", technique_id_or_mitre)
|
||||
|
||||
# Assign test = Test(
|
||||
test = Test(
|
||||
# Keyword argument: technique_id
|
||||
technique_id=technique.id,
|
||||
name=name_override if name_override is not None else template.name,
|
||||
description=description_override if description_override is not None else template.description,
|
||||
@@ -160,59 +253,111 @@ def create_test_from_template(
|
||||
procedure_text=procedure_text_override if procedure_text_override is not None else template.attack_procedure,
|
||||
tool_used=tool_used_override if tool_used_override is not None else template.tool_suggested,
|
||||
remediation_steps=template.suggested_remediation,
|
||||
# Keyword argument: created_by
|
||||
created_by=creator_id,
|
||||
# Keyword argument: state
|
||||
state=TestState.draft,
|
||||
created_at=datetime.utcnow(), # explicit — DB column has no server default
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(test)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
# Define function get_test_detail
|
||||
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch a test with evidences and technique eager-loaded.
|
||||
|
||||
Raises EntityNotFoundError if the test does not exist.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test_id (uuid.UUID): UUID of the test to retrieve.
|
||||
|
||||
Returns:
|
||||
Test: The test ORM object with ``evidences`` relationship loaded.
|
||||
"""
|
||||
# Assign test = (
|
||||
test = (
|
||||
db.query(Test)
|
||||
.options(joinedload(Test.evidences), joinedload(Test.technique))
|
||||
.filter(Test.id == test_id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Check: test is None
|
||||
if test is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
# Define function get_test_or_raise
|
||||
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch a test by ID. Raises EntityNotFoundError if not found."""
|
||||
"""Fetch a test by ID. Raises EntityNotFoundError if not found.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test_id (uuid.UUID): UUID of the test to retrieve.
|
||||
|
||||
Returns:
|
||||
Test: The matching test ORM object.
|
||||
"""
|
||||
# Assign test = db.query(Test).filter(Test.id == test_id).first()
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
# Check: test is None
|
||||
if test is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
# Define function get_test_with_technique
|
||||
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch a test with technique joined. Raises EntityNotFoundError if not found."""
|
||||
"""Fetch a test with technique joined. Raises EntityNotFoundError if not found.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test_id (uuid.UUID): UUID of the test to retrieve.
|
||||
|
||||
Returns:
|
||||
Test: The test ORM object with ``technique`` relationship loaded.
|
||||
"""
|
||||
# Assign test = (
|
||||
test = (
|
||||
db.query(Test)
|
||||
# Chain .options() call
|
||||
.options(joinedload(Test.technique))
|
||||
# Chain .filter() call
|
||||
.filter(Test.id == test_id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Check: test is None
|
||||
if test is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
# Define function update_test
|
||||
def update_test(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: test_id
|
||||
test_id: uuid.UUID,
|
||||
*,
|
||||
# Entry: updater_id
|
||||
updater_id: uuid.UUID,
|
||||
# Entry: updater_role
|
||||
updater_role: str,
|
||||
**fields: Any,
|
||||
**fields: object,
|
||||
) -> Test:
|
||||
"""Update general test fields (draft or rejected only).
|
||||
|
||||
@@ -220,93 +365,170 @@ def update_test(
|
||||
Raises BusinessRuleViolation if state is not draft or rejected.
|
||||
Raises EntityNotFoundError if test not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test_id (uuid.UUID): UUID of the test to update.
|
||||
updater_id (uuid.UUID): UUID of the user performing the update.
|
||||
updater_role (str): Role of the updater; ``"admin"`` bypasses the
|
||||
creator-only restriction.
|
||||
**fields (object): Keyword arguments mapped directly onto test
|
||||
attributes.
|
||||
|
||||
Returns:
|
||||
Test: The updated test ORM object, flushed but not committed.
|
||||
"""
|
||||
# Assign test = get_test_or_raise(db, test_id)
|
||||
test = get_test_or_raise(db, test_id)
|
||||
|
||||
# Check: updater_role != "admin" and test.created_by != updater_id
|
||||
if updater_role != "admin" and test.created_by != updater_id:
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation(
|
||||
# Literal argument value
|
||||
"Only the test creator or an admin can update this test"
|
||||
)
|
||||
|
||||
# Check: test.state not in (TestState.draft, TestState.rejected)
|
||||
if test.state not in (TestState.draft, TestState.rejected):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
|
||||
)
|
||||
|
||||
# Iterate over fields.items()
|
||||
for field, value in fields.items():
|
||||
# Call setattr()
|
||||
setattr(test, field, value)
|
||||
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||
# Define function update_test_red
|
||||
def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
|
||||
"""Update Red Team fields (draft or red_executing only).
|
||||
|
||||
Raises BusinessRuleViolation if state not in (draft, red_executing).
|
||||
Raises EntityNotFoundError if test not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test_id (uuid.UUID): UUID of the test to update.
|
||||
**fields (object): Red-team field names and their new values.
|
||||
|
||||
Returns:
|
||||
Test: The updated test ORM object, flushed but not committed.
|
||||
"""
|
||||
# Assign test = get_test_or_raise(db, test_id)
|
||||
test = get_test_or_raise(db, test_id)
|
||||
|
||||
# Check: test.state not in (TestState.draft, TestState.red_executing)
|
||||
if test.state not in (TestState.draft, TestState.red_executing):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot update red fields in '{test.state.value}' state "
|
||||
# Literal argument value
|
||||
"(must be draft or red_executing)"
|
||||
)
|
||||
|
||||
# Iterate over fields.items()
|
||||
for field, value in fields.items():
|
||||
# Call setattr()
|
||||
setattr(test, field, value)
|
||||
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||
# Define function update_test_blue
|
||||
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
|
||||
"""Update Blue Team fields (blue_evaluating only).
|
||||
|
||||
Raises BusinessRuleViolation if state is not blue_evaluating.
|
||||
Raises EntityNotFoundError if test not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test_id (uuid.UUID): UUID of the test to update.
|
||||
**fields (object): Blue-team field names and their new values.
|
||||
|
||||
Returns:
|
||||
Test: The updated test ORM object, flushed but not committed.
|
||||
"""
|
||||
# Assign test = get_test_or_raise(db, test_id)
|
||||
test = get_test_or_raise(db, test_id)
|
||||
|
||||
# Check: test.state != TestState.blue_evaluating
|
||||
if test.state != TestState.blue_evaluating:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot update blue fields in '{test.state.value}' state "
|
||||
# Literal argument value
|
||||
"(must be blue_evaluating)"
|
||||
)
|
||||
|
||||
# Iterate over fields.items()
|
||||
for field, value in fields.items():
|
||||
# Call setattr()
|
||||
setattr(test, field, value)
|
||||
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
# Define function get_test_timeline
|
||||
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
|
||||
"""Return chronological audit-log history for a test.
|
||||
|
||||
Raises EntityNotFoundError if the test does not exist.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test_id (uuid.UUID): UUID of the test whose history is requested.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: Audit-log entries ordered by timestamp ascending,
|
||||
each containing ``id``, ``action``, ``user_id``, ``timestamp``,
|
||||
and ``details``.
|
||||
"""
|
||||
# Call get_test_or_raise()
|
||||
get_test_or_raise(db, test_id)
|
||||
|
||||
# Assign logs = (
|
||||
logs = (
|
||||
db.query(AuditLog)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
AuditLog.entity_type == "test",
|
||||
AuditLog.entity_id == str(test_id),
|
||||
)
|
||||
# Chain .order_by() call
|
||||
.order_by(AuditLog.timestamp.asc())
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(log.id),
|
||||
# Literal argument value
|
||||
"action": log.action,
|
||||
# Literal argument value
|
||||
"user_id": str(log.user_id) if log.user_id else None,
|
||||
# Literal argument value
|
||||
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
|
||||
# Literal argument value
|
||||
"details": log.details,
|
||||
}
|
||||
for log in logs
|
||||
|
||||
@@ -1,43 +1,77 @@
|
||||
"""Test template service — framework-agnostic CRUD and queries."""
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import func, or_ from sqlalchemy
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import escape_like from app.utils
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
# Define function list_templates
|
||||
def list_templates(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: source
|
||||
source: str | None = None,
|
||||
# Entry: platform
|
||||
platform: str | None = None,
|
||||
# Entry: severity
|
||||
severity: str | None = None,
|
||||
# Entry: mitre_technique_id
|
||||
mitre_technique_id: str | None = None,
|
||||
# Entry: search
|
||||
search: str | None = None,
|
||||
# Entry: is_active
|
||||
is_active: bool | None = None,
|
||||
# Entry: offset
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 50,
|
||||
) -> list:
|
||||
"""Return paginated, filterable list of test templates."""
|
||||
# Assign query = db.query(TestTemplate)
|
||||
query = db.query(TestTemplate)
|
||||
# Check: is_active is not None
|
||||
if is_active is not None:
|
||||
# Assign query = query.filter(TestTemplate.is_active == is_active)
|
||||
query = query.filter(TestTemplate.is_active == is_active)
|
||||
|
||||
# Check: source
|
||||
if source:
|
||||
# Assign query = query.filter(TestTemplate.source == source)
|
||||
query = query.filter(TestTemplate.source == source)
|
||||
# Check: platform
|
||||
if platform:
|
||||
# Assign query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}...
|
||||
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
# Check: severity
|
||||
if severity:
|
||||
# Assign query = query.filter(TestTemplate.severity == severity)
|
||||
query = query.filter(TestTemplate.severity == severity)
|
||||
# Check: mitre_technique_id
|
||||
if mitre_technique_id:
|
||||
# Assign query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
|
||||
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
|
||||
# Check: search
|
||||
if search:
|
||||
# Assign pattern = f"%{escape_like(search)}%"
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
# Assign query = query.filter(
|
||||
query = query.filter(
|
||||
or_(
|
||||
TestTemplate.name.ilike(pattern),
|
||||
@@ -45,106 +79,166 @@ def list_templates(
|
||||
)
|
||||
)
|
||||
|
||||
# Assign templates = (
|
||||
templates = (
|
||||
query
|
||||
# Chain .order_by() call
|
||||
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
|
||||
# Chain .offset() call
|
||||
.offset(offset)
|
||||
# Chain .limit() call
|
||||
.limit(limit)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return templates
|
||||
return templates
|
||||
|
||||
|
||||
# Define function get_template_stats
|
||||
def get_template_stats(db: Session) -> dict:
|
||||
"""Return catalog statistics: totals by source, platform, active/inactive."""
|
||||
# Assign total = db.query(func.count(TestTemplate.id)).scalar() or 0
|
||||
total = db.query(func.count(TestTemplate.id)).scalar() or 0
|
||||
# Assign active = (
|
||||
active = (
|
||||
db.query(func.count(TestTemplate.id))
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||
# Chain .scalar() call
|
||||
.scalar()
|
||||
) or 0
|
||||
# Assign inactive = total - active
|
||||
inactive = total - active
|
||||
|
||||
# Assign source_rows = (
|
||||
source_rows = (
|
||||
db.query(TestTemplate.source, func.count(TestTemplate.id))
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||
# Chain .group_by() call
|
||||
.group_by(TestTemplate.source)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign by_source = {source: cnt for source, cnt in source_rows}
|
||||
by_source = {source: cnt for source, cnt in source_rows}
|
||||
|
||||
# Assign platform_rows = (
|
||||
platform_rows = (
|
||||
db.query(TestTemplate.platform, func.count(TestTemplate.id))
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||
# Chain .group_by() call
|
||||
.group_by(TestTemplate.platform)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
|
||||
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
# Literal argument value
|
||||
"active": active,
|
||||
# Literal argument value
|
||||
"inactive": inactive,
|
||||
# Literal argument value
|
||||
"by_source": by_source,
|
||||
# Literal argument value
|
||||
"by_platform": by_platform,
|
||||
}
|
||||
|
||||
|
||||
# Define function bulk_activate
|
||||
def bulk_activate(db: Session, *, activate: bool) -> int:
|
||||
"""Set all templates to active or inactive. Returns count of affected. Does NOT commit."""
|
||||
# Assign count = (
|
||||
count = (
|
||||
db.query(TestTemplate)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.is_active != activate)
|
||||
# Chain .update() call
|
||||
.update({TestTemplate.is_active: activate})
|
||||
)
|
||||
# Return count
|
||||
return count
|
||||
|
||||
|
||||
# Define function get_templates_by_technique
|
||||
def get_templates_by_technique(db: Session, mitre_id: str) -> list:
|
||||
"""Return all active templates mapped to a specific MITRE technique."""
|
||||
# Return (
|
||||
return (
|
||||
db.query(TestTemplate)
|
||||
# Chain .filter() call
|
||||
.filter(
|
||||
TestTemplate.mitre_technique_id == mitre_id,
|
||||
TestTemplate.is_active == True, # noqa: E712
|
||||
)
|
||||
# Chain .order_by() call
|
||||
.order_by(TestTemplate.name)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
# Define function get_template_or_raise
|
||||
def get_template_or_raise(db: Session, template_id: uuid.UUID) -> TestTemplate:
|
||||
"""Return a template by ID. Raises EntityNotFoundError if not found."""
|
||||
# Assign template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
# Check: template is None
|
||||
if template is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Test template", str(template_id))
|
||||
# Return template
|
||||
return template
|
||||
|
||||
|
||||
# Define function create_template
|
||||
def create_template(db: Session, **fields: object) -> TestTemplate:
|
||||
"""Create a test template from keyword args (e.g. payload.model_dump()). Does NOT commit."""
|
||||
# Assign template = TestTemplate(**fields)
|
||||
template = TestTemplate(**fields)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(template)
|
||||
# Return template
|
||||
return template
|
||||
|
||||
|
||||
# Define function update_template
|
||||
def update_template(db: Session, template_id: uuid.UUID, **fields: object) -> TestTemplate:
|
||||
"""Update an existing template. Raises EntityNotFoundError if not found. Does NOT commit."""
|
||||
# Assign template = get_template_or_raise(db, template_id)
|
||||
template = get_template_or_raise(db, template_id)
|
||||
# Iterate over fields.items()
|
||||
for field, value in fields.items():
|
||||
# Check: hasattr(template, field)
|
||||
if hasattr(template, field):
|
||||
# Call setattr()
|
||||
setattr(template, field, value)
|
||||
# Return template
|
||||
return template
|
||||
|
||||
|
||||
# Define function toggle_template_active
|
||||
def toggle_template_active(db: Session, template_id: uuid.UUID) -> TestTemplate:
|
||||
"""Toggle template active/inactive. Does NOT commit."""
|
||||
# Assign template = get_template_or_raise(db, template_id)
|
||||
template = get_template_or_raise(db, template_id)
|
||||
# Assign template.is_active = not template.is_active
|
||||
template.is_active = not template.is_active
|
||||
# Return template
|
||||
return template
|
||||
|
||||
|
||||
# Define function soft_delete_template
|
||||
def soft_delete_template(db: Session, template_id: uuid.UUID) -> None:
|
||||
"""Soft-delete a template by setting is_active=False. Does NOT commit."""
|
||||
# Assign template = get_template_or_raise(db, template_id)
|
||||
template = get_template_or_raise(db, template_id)
|
||||
# Assign template.is_active = False
|
||||
template.is_active = False
|
||||
|
||||
@@ -12,11 +12,19 @@ an audit-log entry. The caller (router) is responsible for committing the
|
||||
session via the Unit of Work pattern.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
from app.domain.exceptions import InvalidOperationError, InvalidTransitionError
|
||||
from app.domain.test_entity import TestEntity
|
||||
@@ -27,6 +35,31 @@ from app.models.user import User
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.notification_service import notify_test_state_change, create_notification
|
||||
|
||||
# Import InvalidOperationError from app.domain.exceptions
|
||||
from app.domain.exceptions import InvalidOperationError
|
||||
|
||||
# Import TestEntity from app.domain.test_entity
|
||||
from app.domain.test_entity import TestEntity
|
||||
|
||||
# Import TestState from app.models.enums
|
||||
from app.models.enums import TestState
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.notification_service
|
||||
from app.services.notification_service import (
|
||||
create_notification,
|
||||
notify_test_state_change,
|
||||
)
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -50,18 +83,35 @@ VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
|
||||
|
||||
|
||||
def can_transition(test: Test, target_state: TestState) -> bool:
|
||||
"""Return *True* if moving *test* to *target_state* is allowed."""
|
||||
"""Return *True* if moving *test* to *target_state* is allowed.
|
||||
|
||||
Args:
|
||||
test (Test): The test whose current state is being checked.
|
||||
target_state (TestState): The state to transition to.
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if the transition is permitted by ``VALID_TRANSITIONS``.
|
||||
"""
|
||||
# Assign current = test.state if isinstance(test.state, TestState) else TestState(test...
|
||||
current = test.state if isinstance(test.state, TestState) else TestState(test.state)
|
||||
# Return target_state in VALID_TRANSITIONS.get(current, [])
|
||||
return target_state in VALID_TRANSITIONS.get(current, [])
|
||||
|
||||
|
||||
# Define function transition_state
|
||||
def transition_state(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: test
|
||||
test: Test,
|
||||
# Entry: target_state
|
||||
target_state: TestState,
|
||||
# Entry: user
|
||||
user: User,
|
||||
*,
|
||||
# Entry: action_name
|
||||
action_name: str = "transition_state",
|
||||
# Entry: extra_details
|
||||
extra_details: dict | None = None,
|
||||
) -> Test:
|
||||
"""Validate and perform a state transition, log it, and flush.
|
||||
@@ -71,36 +121,71 @@ def transition_state(
|
||||
when the transition is illegal. The entity is authoritative for which
|
||||
transitions are valid; the module-level ``VALID_TRANSITIONS`` dict is
|
||||
kept temporarily for backward compatibility of ``can_transition()``.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test (Test): The test ORM object to transition.
|
||||
target_state (TestState): Desired next state.
|
||||
user (User): The user performing the transition (logged in audit).
|
||||
action_name (str): Audit log action label; defaults to
|
||||
``"transition_state"``.
|
||||
extra_details (dict | None): Optional extra key-value pairs merged
|
||||
into the audit log details.
|
||||
|
||||
Returns:
|
||||
Test: The mutated test ORM object (state updated, flushed).
|
||||
"""
|
||||
# Assign entity = TestEntity.from_orm(test)
|
||||
entity = TestEntity.from_orm(test)
|
||||
# Assign previous_state = entity.transition_to(target_state)
|
||||
previous_state = entity.transition_to(target_state)
|
||||
|
||||
# Assign test.state = entity.state
|
||||
test.state = entity.state
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Assign details = {
|
||||
details: dict = {
|
||||
# Literal argument value
|
||||
"previous_state": previous_state,
|
||||
# Literal argument value
|
||||
"new_state": target_state.value,
|
||||
# Literal argument value
|
||||
"test_name": test.name,
|
||||
# Literal argument value
|
||||
"technique_id": str(test.technique_id),
|
||||
}
|
||||
# Check: extra_details
|
||||
if extra_details:
|
||||
# Call details.update()
|
||||
details.update(extra_details)
|
||||
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action=action_name,
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test.id,
|
||||
# Keyword argument: details
|
||||
details=details,
|
||||
)
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Call notify_test_state_change()
|
||||
notify_test_state_change(db, test, target_state.value)
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log warning: "Notification failed for test %s: %s", test.id, e
|
||||
logger.warning("Notification failed for test %s: %s", test.id, e, exc_info=True)
|
||||
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
@@ -112,27 +197,44 @@ def transition_state(
|
||||
def start_execution(db: Session, test: Test, user: User) -> Test:
|
||||
"""Move from ``draft`` → ``red_executing``."""
|
||||
entity = TestEntity.from_orm(test)
|
||||
# Call entity.start_execution()
|
||||
entity.start_execution()
|
||||
# Call entity.apply_to()
|
||||
entity.apply_to(test)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action="start_execution",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"previous_state": "draft",
|
||||
# Literal argument value
|
||||
"new_state": test.state.value,
|
||||
# Literal argument value
|
||||
"test_name": test.name,
|
||||
# Literal argument value
|
||||
"technique_id": str(test.technique_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Call notify_test_state_change()
|
||||
notify_test_state_change(db, test, test.state.value)
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log warning: "Notification failed for test %s: %s", test.id, e
|
||||
logger.warning("Notification failed for test %s: %s", test.id, e, exc_info=True)
|
||||
|
||||
try:
|
||||
@@ -144,6 +246,7 @@ def start_execution(db: Session, test: Test, user: User) -> Test:
|
||||
return test
|
||||
|
||||
|
||||
# Define function submit_red_evidence
|
||||
def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
|
||||
"""Move from ``red_executing`` → ``blue_evaluating``.
|
||||
|
||||
@@ -151,6 +254,14 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
|
||||
Requires at least one Red Team evidence file to be uploaded.
|
||||
Stops the Red Team timer and creates an automatic worklog.
|
||||
Starts the Blue Team timer by recording ``blue_started_at``.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test (Test): The test whose red-team evidence is being submitted.
|
||||
user (User): The red-team user submitting the evidence.
|
||||
|
||||
Returns:
|
||||
Test: The mutated test with state advanced and blue timer started.
|
||||
"""
|
||||
# Evidence is mandatory before submitting
|
||||
red_evidence_count = (
|
||||
@@ -167,29 +278,42 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
|
||||
|
||||
# Auto-resume if paused
|
||||
paused_extra = 0
|
||||
# Check: test.paused_at is not None
|
||||
if test.paused_at is not None:
|
||||
# Assign paused_extra = max(int((now - test.paused_at).total_seconds()), 0)
|
||||
paused_extra = max(int((now - test.paused_at).total_seconds()), 0)
|
||||
# Assign test.paused_at = None
|
||||
test.paused_at = None
|
||||
|
||||
# Assign test = transition_state(
|
||||
test = transition_state(
|
||||
db, test, TestState.blue_evaluating, user,
|
||||
# Keyword argument: action_name
|
||||
action_name="submit_red_evidence",
|
||||
)
|
||||
|
||||
# Create automatic worklog for Red Team phase (subtract paused time)
|
||||
_create_phase_worklog(
|
||||
db,
|
||||
# Keyword argument: test
|
||||
test=test,
|
||||
# Keyword argument: user
|
||||
user=user,
|
||||
# Keyword argument: phase_started_at
|
||||
phase_started_at=test.red_started_at,
|
||||
# Keyword argument: phase_ended_at
|
||||
phase_ended_at=now,
|
||||
# Keyword argument: paused_seconds
|
||||
paused_seconds=(test.red_paused_seconds or 0) + paused_extra,
|
||||
# Keyword argument: activity_type
|
||||
activity_type="red_team_execution",
|
||||
# Keyword argument: description
|
||||
description=f"Red Team execution: {test.name}",
|
||||
)
|
||||
|
||||
# Start Blue Team timer
|
||||
test.blue_started_at = now
|
||||
# Assign test.blue_paused_seconds = 0
|
||||
test.blue_paused_seconds = 0
|
||||
|
||||
try:
|
||||
@@ -234,6 +358,7 @@ def start_blue_work(db: Session, test: Test, user: User) -> Test:
|
||||
return test
|
||||
|
||||
|
||||
# Define function submit_blue_evidence
|
||||
def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
|
||||
"""Move from ``blue_evaluating`` → ``in_review``.
|
||||
|
||||
@@ -258,12 +383,17 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
|
||||
|
||||
# Auto-resume if paused
|
||||
paused_extra = 0
|
||||
# Check: test.paused_at is not None
|
||||
if test.paused_at is not None:
|
||||
# Assign paused_extra = max(int((now - test.paused_at).total_seconds()), 0)
|
||||
paused_extra = max(int((now - test.paused_at).total_seconds()), 0)
|
||||
# Assign test.paused_at = None
|
||||
test.paused_at = None
|
||||
|
||||
# Assign test = transition_state(
|
||||
test = transition_state(
|
||||
db, test, TestState.in_review, user,
|
||||
# Keyword argument: action_name
|
||||
action_name="submit_blue_evidence",
|
||||
)
|
||||
|
||||
@@ -272,12 +402,17 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
|
||||
# Tempo worklog reflects real working time, not just queue time.
|
||||
_create_phase_worklog(
|
||||
db,
|
||||
# Keyword argument: test
|
||||
test=test,
|
||||
# Keyword argument: user
|
||||
user=user,
|
||||
phase_started_at=test.blue_work_started_at or test.blue_started_at,
|
||||
phase_ended_at=now,
|
||||
# Keyword argument: paused_seconds
|
||||
paused_seconds=(test.blue_paused_seconds or 0) + paused_extra,
|
||||
# Keyword argument: activity_type
|
||||
activity_type="blue_team_evaluation",
|
||||
# Keyword argument: description
|
||||
description=f"Blue Team evaluation: {test.name}",
|
||||
)
|
||||
|
||||
@@ -290,69 +425,125 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
|
||||
return test
|
||||
|
||||
|
||||
# Define function pause_timer
|
||||
def pause_timer(db: Session, test: Test, user: User) -> Test:
|
||||
"""Pause the active phase timer.
|
||||
|
||||
Can only be called when the test is in ``red_executing`` or
|
||||
``blue_evaluating`` and is not already paused.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test (Test): The currently active test.
|
||||
user (User): The user pausing the timer.
|
||||
|
||||
Returns:
|
||||
Test: The mutated test with ``paused_at`` set to the current UTC time.
|
||||
"""
|
||||
# Check: test.state not in (TestState.red_executing, TestState.blue_evaluating)
|
||||
if test.state not in (TestState.red_executing, TestState.blue_evaluating):
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError(
|
||||
f"Cannot pause timer in '{test.state.value}' state"
|
||||
)
|
||||
# Check: test.paused_at is not None
|
||||
if test.paused_at is not None:
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError("Timer is already paused")
|
||||
|
||||
# Assign test.paused_at = datetime.utcnow()
|
||||
test.paused_at = datetime.utcnow()
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action="pause_timer",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test.id,
|
||||
# Keyword argument: details
|
||||
details={"state": test.state.value},
|
||||
)
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
# Define function resume_timer
|
||||
def resume_timer(db: Session, test: Test, user: User) -> Test:
|
||||
"""Resume a paused phase timer.
|
||||
|
||||
Accumulates the paused duration into the appropriate counter so
|
||||
it is subtracted from the final worklog.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test (Test): The paused test to resume.
|
||||
user (User): The user resuming the timer.
|
||||
|
||||
Returns:
|
||||
Test: The mutated test with ``paused_at`` cleared and accumulated
|
||||
pause seconds updated.
|
||||
"""
|
||||
# Check: test.paused_at is None
|
||||
if test.paused_at is None:
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError("Timer is not paused")
|
||||
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign paused_seconds = max(int((now - test.paused_at).total_seconds()), 0)
|
||||
paused_seconds = max(int((now - test.paused_at).total_seconds()), 0)
|
||||
|
||||
# Check: test.state == TestState.red_executing
|
||||
if test.state == TestState.red_executing:
|
||||
# Assign test.red_paused_seconds = (test.red_paused_seconds or 0) + paused_seconds
|
||||
test.red_paused_seconds = (test.red_paused_seconds or 0) + paused_seconds
|
||||
# Alternative: test.state == TestState.blue_evaluating
|
||||
elif test.state == TestState.blue_evaluating:
|
||||
# Assign test.blue_paused_seconds = (test.blue_paused_seconds or 0) + paused_seconds
|
||||
test.blue_paused_seconds = (test.blue_paused_seconds or 0) + paused_seconds
|
||||
|
||||
# Assign test.paused_at = None
|
||||
test.paused_at = None
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action="resume_timer",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test.id,
|
||||
# Keyword argument: details
|
||||
details={"paused_seconds": paused_seconds, "state": test.state.value},
|
||||
)
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
# Define function _create_phase_worklog
|
||||
def _create_phase_worklog(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: test
|
||||
test: Test,
|
||||
# Entry: user
|
||||
user: User,
|
||||
# Entry: phase_started_at
|
||||
phase_started_at: datetime | None,
|
||||
# Entry: phase_ended_at
|
||||
phase_ended_at: datetime,
|
||||
# Entry: paused_seconds
|
||||
paused_seconds: int = 0,
|
||||
# Entry: activity_type
|
||||
activity_type: str,
|
||||
# Entry: description
|
||||
description: str,
|
||||
) -> None:
|
||||
"""Create an automatic, integrity-hashed worklog for a completed phase.
|
||||
@@ -360,32 +551,64 @@ def _create_phase_worklog(
|
||||
Subtracts accumulated *paused_seconds* from the gross elapsed time
|
||||
so the worklog reflects only active working time.
|
||||
Also triggers Tempo sync if the test has a Jira link.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test (Test): The test for which the worklog is being created.
|
||||
user (User): The user attributed to the worklog.
|
||||
phase_started_at (datetime | None): Timestamp when the phase began;
|
||||
if ``None`` the worklog is skipped with a warning.
|
||||
phase_ended_at (datetime): Timestamp when the phase ended.
|
||||
paused_seconds (int): Accumulated paused time in seconds to subtract
|
||||
from gross elapsed time.
|
||||
activity_type (str): Worklog activity type label (e.g.
|
||||
``"red_team_execution"``).
|
||||
description (str): Human-readable description for the worklog.
|
||||
"""
|
||||
# Check: not phase_started_at
|
||||
if not phase_started_at:
|
||||
# Log warning:
|
||||
logger.warning(
|
||||
# Literal argument value
|
||||
"No phase start timestamp for test %s (%s), skipping worklog",
|
||||
test.id, activity_type,
|
||||
)
|
||||
# Return control to caller
|
||||
return
|
||||
|
||||
# Assign gross_seconds = int((phase_ended_at - phase_started_at).total_seconds())
|
||||
gross_seconds = int((phase_ended_at - phase_started_at).total_seconds())
|
||||
# Assign duration_seconds = max(gross_seconds - paused_seconds, 1)
|
||||
duration_seconds = max(gross_seconds - paused_seconds, 1)
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Import create_worklog from app.services.worklog_service
|
||||
from app.services.worklog_service import create_worklog
|
||||
|
||||
# Assign wl = create_worklog(
|
||||
wl = create_worklog(
|
||||
db,
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test.id,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: activity_type
|
||||
activity_type=activity_type,
|
||||
# Keyword argument: started_at
|
||||
started_at=phase_started_at,
|
||||
# Keyword argument: ended_at
|
||||
ended_at=phase_ended_at,
|
||||
# Keyword argument: duration_seconds
|
||||
duration_seconds=duration_seconds,
|
||||
# Keyword argument: description
|
||||
description=description,
|
||||
)
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Auto-worklog created for test %s: %s, %ds (worklog %s)",
|
||||
test.id, activity_type, duration_seconds, wl.id,
|
||||
)
|
||||
@@ -393,6 +616,7 @@ def _create_phase_worklog(
|
||||
# Sync to Tempo: only red_team_execution, using the already-computed
|
||||
# duration so the Tempo entry is identical to the Aegis worklog.
|
||||
try:
|
||||
# Import auto_log_test_worklog from app.services.tempo_service
|
||||
from app.services.tempo_service import auto_log_test_worklog
|
||||
tempo_result = auto_log_test_worklog(db, test, user, activity_type, duration_seconds)
|
||||
if tempo_result and isinstance(tempo_result, dict):
|
||||
@@ -400,17 +624,26 @@ def _create_phase_worklog(
|
||||
wl.tempo_worklog_id = str(tempo_result.get("tempoWorklogId", ""))
|
||||
db.flush()
|
||||
except Exception as e:
|
||||
# Log warning: "Tempo sync failed for worklog: %s", e, exc_info=T
|
||||
logger.warning("Tempo sync failed for worklog: %s", e, exc_info=True)
|
||||
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log error: "Failed to create auto-worklog for test %s: %s", t
|
||||
logger.error("Failed to create auto-worklog for test %s: %s", test.id, e, exc_info=True)
|
||||
|
||||
|
||||
# Define function validate_as_red_lead
|
||||
def validate_as_red_lead(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: test
|
||||
test: Test,
|
||||
# Entry: user
|
||||
user: User,
|
||||
# Entry: validation_status
|
||||
validation_status: str,
|
||||
# Entry: notes
|
||||
notes: str | None = None,
|
||||
) -> Test:
|
||||
"""Record Red Lead's validation decision.
|
||||
@@ -418,21 +651,45 @@ def validate_as_red_lead(
|
||||
Delegates validation rules and state mutation entirely to
|
||||
:meth:`TestEntity.validate_red`. If both leads have voted the
|
||||
entity will also advance the test to ``validated`` or ``rejected``.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test (Test): The test being reviewed.
|
||||
user (User): The red-lead user casting their vote.
|
||||
validation_status (str): Validation decision, e.g. ``"approved"`` or
|
||||
``"rejected"``.
|
||||
notes (str | None): Optional freeform notes explaining the decision.
|
||||
|
||||
Returns:
|
||||
Test: The mutated test with red-lead validation fields set.
|
||||
"""
|
||||
# Assign entity = TestEntity.from_orm(test)
|
||||
entity = TestEntity.from_orm(test)
|
||||
# Call entity.validate_red()
|
||||
entity.validate_red(validation_status, by=user.id, notes=notes)
|
||||
# Call entity.apply_to()
|
||||
entity.apply_to(test)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action="validate_as_red_lead",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"validation_status": validation_status,
|
||||
# Literal argument value
|
||||
"notes": notes,
|
||||
# Literal argument value
|
||||
"technique_id": str(test.technique_id),
|
||||
},
|
||||
)
|
||||
@@ -441,11 +698,17 @@ def validate_as_red_lead(
|
||||
return test
|
||||
|
||||
|
||||
# Define function validate_as_blue_lead
|
||||
def validate_as_blue_lead(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: test
|
||||
test: Test,
|
||||
# Entry: user
|
||||
user: User,
|
||||
# Entry: validation_status
|
||||
validation_status: str,
|
||||
# Entry: notes
|
||||
notes: str | None = None,
|
||||
) -> Test:
|
||||
"""Record Blue Lead's validation decision.
|
||||
@@ -453,21 +716,45 @@ def validate_as_blue_lead(
|
||||
Delegates validation rules and state mutation entirely to
|
||||
:meth:`TestEntity.validate_blue`. If both leads have voted the
|
||||
entity will also advance the test to ``validated`` or ``rejected``.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test (Test): The test being reviewed.
|
||||
user (User): The blue-lead user casting their vote.
|
||||
validation_status (str): Validation decision, e.g. ``"approved"`` or
|
||||
``"rejected"``.
|
||||
notes (str | None): Optional freeform notes explaining the decision.
|
||||
|
||||
Returns:
|
||||
Test: The mutated test with blue-lead validation fields set.
|
||||
"""
|
||||
# Assign entity = TestEntity.from_orm(test)
|
||||
entity = TestEntity.from_orm(test)
|
||||
# Call entity.validate_blue()
|
||||
entity.validate_blue(validation_status, by=user.id, notes=notes)
|
||||
# Call entity.apply_to()
|
||||
entity.apply_to(test)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action="validate_as_blue_lead",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"validation_status": validation_status,
|
||||
# Literal argument value
|
||||
"notes": notes,
|
||||
# Literal argument value
|
||||
"technique_id": str(test.technique_id),
|
||||
},
|
||||
)
|
||||
@@ -476,35 +763,61 @@ def validate_as_blue_lead(
|
||||
return test
|
||||
|
||||
|
||||
# Define function check_dual_validation
|
||||
def check_dual_validation(db: Session, test: Test) -> Test:
|
||||
"""Evaluate both leads' decisions and advance the test if both have voted.
|
||||
|
||||
All state mutation is delegated to :meth:`TestEntity.check_dual_validation`.
|
||||
This function never assigns ``test.state`` directly.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test (Test): The test to evaluate.
|
||||
|
||||
Returns:
|
||||
Test: The mutated test, potentially with state advanced to
|
||||
``validated`` or ``rejected``.
|
||||
"""
|
||||
# Assign entity = TestEntity.from_orm(test)
|
||||
entity = TestEntity.from_orm(test)
|
||||
# Call entity.check_dual_validation()
|
||||
entity.check_dual_validation()
|
||||
# Call entity.apply_to()
|
||||
entity.apply_to(test)
|
||||
|
||||
# Call _dispatch_dual_validation_effects()
|
||||
_dispatch_dual_validation_effects(db, test, entity)
|
||||
# Return test
|
||||
return test
|
||||
|
||||
|
||||
# Define function _dispatch_dual_validation_effects
|
||||
def _dispatch_dual_validation_effects(
|
||||
db: Session, test: Test, entity: TestEntity, actor: User | None = None
|
||||
) -> None:
|
||||
"""Dispatch side effects (notifications, cache, Jira) based on domain events."""
|
||||
for event in entity.events:
|
||||
# Check: event.name == "dual_validation_approved"
|
||||
if event.name == "dual_validation_approved":
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Import invalidate from app.services.score_cache
|
||||
from app.services.score_cache import invalidate
|
||||
# Call invalidate()
|
||||
invalidate()
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log warning: "Score cache invalidation failed: %s", e, exc_info
|
||||
logger.warning("Score cache invalidation failed: %s", e, exc_info=True)
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Call notify_test_state_change()
|
||||
notify_test_state_change(db, test, "validated")
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log warning:
|
||||
logger.warning(
|
||||
# Literal argument value
|
||||
"Notification failed for test %s (validated): %s",
|
||||
test.id, e, exc_info=True,
|
||||
)
|
||||
@@ -516,10 +829,15 @@ def _dispatch_dual_validation_effects(
|
||||
logger.warning("Jira push failed for test %s: %s", test.id, e, exc_info=True)
|
||||
|
||||
elif event.name == "dual_validation_rejected":
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Call notify_test_state_change()
|
||||
notify_test_state_change(db, test, "rejected")
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log warning:
|
||||
logger.warning(
|
||||
# Literal argument value
|
||||
"Notification failed for test %s (rejected): %s",
|
||||
test.id, e, exc_info=True,
|
||||
)
|
||||
@@ -585,6 +903,7 @@ def _notify_validation_conflict(db: Session, test: Test, actor: User | None) ->
|
||||
)
|
||||
|
||||
|
||||
# Define function handle_remediation_completed
|
||||
def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None:
|
||||
"""Create a re-test when remediation is completed.
|
||||
|
||||
@@ -594,121 +913,199 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test |
|
||||
|
||||
Prevents infinite loops by enforcing ``MAX_RETEST_COUNT``.
|
||||
|
||||
Returns the new retest or *None* if the limit was reached.
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test (Test): The test whose remediation was completed.
|
||||
user (User): The user triggering the remediation completion.
|
||||
|
||||
Returns:
|
||||
Test | None: The newly created retest, or ``None`` if the maximum
|
||||
retest count has been reached.
|
||||
"""
|
||||
# Always reference the original test, not an intermediate retest
|
||||
original_test_id = test.retest_of or test.id
|
||||
|
||||
# Check: test.retest_count >= settings.MAX_RETEST_COUNT
|
||||
if test.retest_count >= settings.MAX_RETEST_COUNT:
|
||||
# Max retests reached — notify and bail out
|
||||
if test.created_by:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=test.created_by,
|
||||
# Keyword argument: type
|
||||
type="max_retests_reached",
|
||||
# Keyword argument: title
|
||||
title="Maximum retests reached",
|
||||
# Keyword argument: message
|
||||
message=(
|
||||
f'Test "{test.name}" has reached the maximum of '
|
||||
f'{settings.MAX_RETEST_COUNT} retests. Manual review required.'
|
||||
),
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test.id,
|
||||
)
|
||||
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action="max_retests_reached",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=test.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"retest_count": test.retest_count,
|
||||
# Literal argument value
|
||||
"max_allowed": settings.MAX_RETEST_COUNT,
|
||||
# Literal argument value
|
||||
"original_test_id": str(original_test_id),
|
||||
},
|
||||
)
|
||||
# Return None
|
||||
return None
|
||||
|
||||
# Assign retest = Test(
|
||||
retest = Test(
|
||||
# Keyword argument: technique_id
|
||||
technique_id=test.technique_id,
|
||||
# Keyword argument: name
|
||||
name=f"[Retest #{test.retest_count + 1}] {test.name.replace(f'[Retest #{test.retest_count}] ', '')}",
|
||||
# Keyword argument: description
|
||||
description=test.description,
|
||||
# Keyword argument: platform
|
||||
platform=test.platform,
|
||||
# Keyword argument: procedure_text
|
||||
procedure_text=test.procedure_text,
|
||||
# Keyword argument: tool_used
|
||||
tool_used=test.tool_used,
|
||||
# Keyword argument: state
|
||||
state=TestState.draft,
|
||||
# Keyword argument: created_by
|
||||
created_by=test.created_by,
|
||||
# Keyword argument: retest_of
|
||||
retest_of=original_test_id,
|
||||
# Keyword argument: retest_count
|
||||
retest_count=test.retest_count + 1,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(retest)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action="create_retest",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=retest.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"original_test_id": str(original_test_id),
|
||||
# Literal argument value
|
||||
"retest_number": retest.retest_count,
|
||||
# Literal argument value
|
||||
"source_test_id": str(test.id),
|
||||
},
|
||||
)
|
||||
|
||||
# Notify the test creator and any red_tech users
|
||||
if test.created_by:
|
||||
# Call create_notification()
|
||||
create_notification(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=test.created_by,
|
||||
# Keyword argument: type
|
||||
type="retest_created",
|
||||
# Keyword argument: title
|
||||
title="Re-test created",
|
||||
# Keyword argument: message
|
||||
message=(
|
||||
f'A re-test has been automatically created for "{test.name}" '
|
||||
f'after remediation was completed.'
|
||||
),
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=retest.id,
|
||||
)
|
||||
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Return retest
|
||||
return retest
|
||||
|
||||
|
||||
def get_retest_chain(db: Session, test_id) -> list[Test]:
|
||||
# Define function get_retest_chain
|
||||
def get_retest_chain(db: Session, test_id: uuid.UUID) -> list[Test]:
|
||||
"""Return the full chain of retests for a given test.
|
||||
|
||||
Includes the original test and all subsequent retests, ordered
|
||||
by retest_count.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy database session.
|
||||
test_id (uuid.UUID): UUID of any test in the retest chain.
|
||||
|
||||
Returns:
|
||||
list[Test]: The original test followed by all its retests in
|
||||
ascending retest-count order. Returns an empty list if the
|
||||
test is not found.
|
||||
"""
|
||||
# Import uuid
|
||||
import uuid as _uuid
|
||||
|
||||
# Assign tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) els...
|
||||
tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) else test_id
|
||||
|
||||
# Find the original test first
|
||||
test = db.query(Test).filter(Test.id == tid).first()
|
||||
# Check: not test
|
||||
if not test:
|
||||
# Return []
|
||||
return []
|
||||
|
||||
# Assign original_id = test.retest_of or test.id
|
||||
original_id = test.retest_of or test.id
|
||||
|
||||
# Get original
|
||||
original = db.query(Test).filter(Test.id == original_id).first()
|
||||
# Check: not original
|
||||
if not original:
|
||||
# Return [test]
|
||||
return [test]
|
||||
|
||||
# Get all retests of the original
|
||||
retests = (
|
||||
db.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(Test.retest_of == original_id)
|
||||
# Chain .order_by() call
|
||||
.order_by(Test.retest_count)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Return [original] + retests
|
||||
return [original] + retests
|
||||
|
||||
|
||||
# Define function reopen_test
|
||||
def reopen_test(db: Session, test: Test, user: User) -> Test:
|
||||
"""Move a ``rejected`` test back to ``draft`` for continued work.
|
||||
|
||||
@@ -719,20 +1116,27 @@ def reopen_test(db: Session, test: Test, user: User) -> Test:
|
||||
re-validate the updated submission. Phase timing is reset so the timer
|
||||
starts fresh for the new execution attempt.
|
||||
"""
|
||||
# Assign test = transition_state(
|
||||
test = transition_state(
|
||||
db, test, TestState.draft, user,
|
||||
# Keyword argument: action_name
|
||||
action_name="reopen_test",
|
||||
)
|
||||
|
||||
# Clear validation DECISIONS — leads must re-validate the new attempt.
|
||||
# Rejection NOTES are intentionally kept so teams see what needs fixing.
|
||||
test.red_validation_status = None
|
||||
# Assign test.red_validated_by = None
|
||||
test.red_validated_by = None
|
||||
# Assign test.red_validated_at = None
|
||||
test.red_validated_at = None
|
||||
# test.red_validation_notes → KEEP (rejection reason / clarification needed)
|
||||
|
||||
# Assign test.blue_validation_status = None
|
||||
test.blue_validation_status = None
|
||||
# Assign test.blue_validated_by = None
|
||||
test.blue_validated_by = None
|
||||
# Assign test.blue_validated_at = None
|
||||
test.blue_validated_at = None
|
||||
# test.blue_validation_notes → KEEP (rejection reason / clarification needed)
|
||||
|
||||
@@ -749,4 +1153,5 @@ def reopen_test(db: Session, test: Test, user: User) -> Test:
|
||||
except Exception as e:
|
||||
logger.warning("Jira push failed for test %s: %s", test.id, e, exc_info=True)
|
||||
|
||||
# Return test
|
||||
return test
|
||||
|
||||
@@ -26,23 +26,49 @@ Deduplication by ``mitre_id`` for ThreatActor and by the unique
|
||||
constraint ``(threat_actor_id, technique_id)`` for ThreatActorTechnique.
|
||||
"""
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import json
|
||||
import json
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import shutil
|
||||
import shutil
|
||||
|
||||
# Import tempfile
|
||||
import tempfile
|
||||
|
||||
# Import zipfile
|
||||
import zipfile
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Path from pathlib
|
||||
from pathlib import Path
|
||||
|
||||
# Import requests
|
||||
import requests as _requests
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.technique import Technique
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -50,11 +76,15 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MITRE_CTI_ZIP_URL = (
|
||||
# Literal argument value
|
||||
"https://github.com/mitre/cti"
|
||||
# Literal argument value
|
||||
"/archive/refs/heads/master.zip"
|
||||
)
|
||||
|
||||
# Assign _DOWNLOAD_TIMEOUT = 300
|
||||
_DOWNLOAD_TIMEOUT = 300
|
||||
# Assign _ZIP_ROOT_PREFIX = "cti-master"
|
||||
_ZIP_ROOT_PREFIX = "cti-master"
|
||||
|
||||
|
||||
@@ -65,54 +95,86 @@ _ZIP_ROOT_PREFIX = "cti-master"
|
||||
|
||||
def _download_zip(url: str = MITRE_CTI_ZIP_URL) -> bytes:
|
||||
"""Download the MITRE CTI ZIP and return raw bytes."""
|
||||
# Log info: "Downloading MITRE CTI ZIP from %s …", url
|
||||
logger.info("Downloading MITRE CTI ZIP from %s …", url)
|
||||
# Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
|
||||
# Call resp.raise_for_status()
|
||||
resp.raise_for_status()
|
||||
# Assign content = resp.content
|
||||
content = resp.content
|
||||
# Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024
|
||||
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
|
||||
# Return content
|
||||
return content
|
||||
|
||||
|
||||
# Define function _extract_zip_and_load_bundle
|
||||
def _extract_zip_and_load_bundle(zip_bytes: bytes, dest: str) -> dict:
|
||||
"""Extract ZIP and load the enterprise-attack STIX bundle."""
|
||||
# Open context manager
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
# Call zf.extractall()
|
||||
zf.extractall(dest)
|
||||
|
||||
# Assign bundle_path = (
|
||||
bundle_path = (
|
||||
Path(dest) / _ZIP_ROOT_PREFIX
|
||||
/ "enterprise-attack" / "enterprise-attack.json"
|
||||
)
|
||||
# Check: not bundle_path.is_file()
|
||||
if not bundle_path.is_file():
|
||||
# Raise FileNotFoundError
|
||||
raise FileNotFoundError(
|
||||
f"STIX bundle not found at {bundle_path}"
|
||||
)
|
||||
|
||||
# Log info: "Loading STIX bundle from %s …", bundle_path
|
||||
logger.info("Loading STIX bundle from %s …", bundle_path)
|
||||
# Open context manager
|
||||
with open(bundle_path, "r", encoding="utf-8") as fh:
|
||||
# Assign bundle = json.load(fh)
|
||||
bundle = json.load(fh)
|
||||
|
||||
# Assign objects = bundle.get("objects", [])
|
||||
objects = bundle.get("objects", [])
|
||||
# Log info: "Loaded %d STIX objects", len(objects
|
||||
logger.info("Loaded %d STIX objects", len(objects))
|
||||
# Return bundle
|
||||
return bundle
|
||||
|
||||
|
||||
# Define function _extract_mitre_id
|
||||
def _extract_mitre_id(external_references: list) -> str | None:
|
||||
"""Extract the MITRE ATT&CK ID from external_references."""
|
||||
# Check: not isinstance(external_references, list)
|
||||
if not isinstance(external_references, list):
|
||||
# Return None
|
||||
return None
|
||||
# Iterate over external_references
|
||||
for ref in external_references:
|
||||
# Check: isinstance(ref, dict) and ref.get("source_name") == "mitre-attack"
|
||||
if isinstance(ref, dict) and ref.get("source_name") == "mitre-attack":
|
||||
# Return ref.get("external_id")
|
||||
return ref.get("external_id")
|
||||
# Return None
|
||||
return None
|
||||
|
||||
|
||||
# Define function _extract_mitre_url
|
||||
def _extract_mitre_url(external_references: list) -> str | None:
|
||||
"""Extract the MITRE ATT&CK URL from external_references."""
|
||||
# Check: not isinstance(external_references, list)
|
||||
if not isinstance(external_references, list):
|
||||
# Return None
|
||||
return None
|
||||
# Iterate over external_references
|
||||
for ref in external_references:
|
||||
# Check: isinstance(ref, dict) and ref.get("source_name") == "mitre-attack"
|
||||
if isinstance(ref, dict) and ref.get("source_name") == "mitre-attack":
|
||||
# Return ref.get("url")
|
||||
return ref.get("url")
|
||||
# Return None
|
||||
return None
|
||||
|
||||
|
||||
@@ -316,25 +378,41 @@ def _infer_motivation_from_description(description: str) -> str | None:
|
||||
|
||||
def _parse_intrusion_sets(objects: list) -> list[dict]:
|
||||
"""Parse STIX intrusion-set objects into ThreatActor dicts."""
|
||||
# Assign actors = []
|
||||
actors = []
|
||||
# Iterate over objects
|
||||
for obj in objects:
|
||||
# Check: obj.get("type") != "intrusion-set"
|
||||
if obj.get("type") != "intrusion-set":
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Check: obj.get("revoked")
|
||||
if obj.get("revoked"):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign ext_refs = obj.get("external_references", [])
|
||||
ext_refs = obj.get("external_references", [])
|
||||
# Assign mitre_id = _extract_mitre_id(ext_refs)
|
||||
mitre_id = _extract_mitre_id(ext_refs)
|
||||
# Assign mitre_url = _extract_mitre_url(ext_refs)
|
||||
mitre_url = _extract_mitre_url(ext_refs)
|
||||
|
||||
# Assign name = obj.get("name", "").strip()
|
||||
name = obj.get("name", "").strip()
|
||||
# Check: not name
|
||||
if not name:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign aliases = obj.get("aliases", [])
|
||||
aliases = obj.get("aliases", [])
|
||||
# Check: isinstance(aliases, list) and name in aliases
|
||||
if isinstance(aliases, list) and name in aliases:
|
||||
# Assign aliases = [a for a in aliases if a != name]
|
||||
aliases = [a for a in aliases if a != name]
|
||||
|
||||
# Assign description = obj.get("description", "")
|
||||
description = obj.get("description", "")
|
||||
|
||||
# Derive motivation: curated override > STIX field > description inference
|
||||
@@ -348,80 +426,129 @@ def _parse_intrusion_sets(objects: list) -> list[dict]:
|
||||
|
||||
# Extract references (non-MITRE)
|
||||
references = []
|
||||
# Iterate over ext_refs
|
||||
for ref in ext_refs:
|
||||
# Check: isinstance(ref, dict) and ref.get("source_name") != "mitre-attack"
|
||||
if isinstance(ref, dict) and ref.get("source_name") != "mitre-attack":
|
||||
# Call references.append()
|
||||
references.append({
|
||||
# Literal argument value
|
||||
"source": ref.get("source_name", ""),
|
||||
# Literal argument value
|
||||
"url": ref.get("url", ""),
|
||||
# Literal argument value
|
||||
"description": ref.get("description", ""),
|
||||
})
|
||||
|
||||
# Call actors.append()
|
||||
actors.append({
|
||||
# Literal argument value
|
||||
"stix_id": obj.get("id"), # e.g. "intrusion-set--abc123"
|
||||
# Literal argument value
|
||||
"mitre_id": mitre_id,
|
||||
# Literal argument value
|
||||
"name": name,
|
||||
# Literal argument value
|
||||
"aliases": aliases if aliases else [],
|
||||
# Literal argument value
|
||||
"description": description,
|
||||
# Literal argument value
|
||||
"mitre_url": mitre_url,
|
||||
# Literal argument value
|
||||
"references": references[:20], # cap to avoid bloat
|
||||
# Literal argument value
|
||||
"first_seen": obj.get("first_seen"),
|
||||
# Literal argument value
|
||||
"last_seen": obj.get("last_seen"),
|
||||
"motivation": motivation,
|
||||
"sophistication": sophistication,
|
||||
})
|
||||
|
||||
# Log info: "Parsed %d intrusion-sets (threat actors)", len(ac
|
||||
logger.info("Parsed %d intrusion-sets (threat actors)", len(actors))
|
||||
# Return actors
|
||||
return actors
|
||||
|
||||
|
||||
# Define function _parse_relationships
|
||||
def _parse_relationships(objects: list) -> list[dict]:
|
||||
"""Parse STIX relationship objects (type=uses) linking
|
||||
intrusion-sets to attack-patterns.
|
||||
"""
|
||||
"""Parse STIX relationship objects (type=uses) linking intrusion-sets to attack-patterns."""
|
||||
# Assign relationships = []
|
||||
relationships = []
|
||||
# Iterate over objects
|
||||
for obj in objects:
|
||||
# Check: obj.get("type") != "relationship"
|
||||
if obj.get("type") != "relationship":
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Check: obj.get("relationship_type") != "uses"
|
||||
if obj.get("relationship_type") != "uses":
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Check: obj.get("revoked")
|
||||
if obj.get("revoked"):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign source_ref = obj.get("source_ref", "")
|
||||
source_ref = obj.get("source_ref", "")
|
||||
# Assign target_ref = obj.get("target_ref", "")
|
||||
target_ref = obj.get("target_ref", "")
|
||||
|
||||
# We want intrusion-set → attack-pattern
|
||||
if not source_ref.startswith("intrusion-set--"):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Check: not target_ref.startswith("attack-pattern--")
|
||||
if not target_ref.startswith("attack-pattern--"):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Call relationships.append()
|
||||
relationships.append({
|
||||
# Literal argument value
|
||||
"source_ref": source_ref,
|
||||
# Literal argument value
|
||||
"target_ref": target_ref,
|
||||
# Literal argument value
|
||||
"description": obj.get("description", ""),
|
||||
})
|
||||
|
||||
# Log info: "Parsed %d uses-relationships (actor→technique)",
|
||||
logger.info("Parsed %d uses-relationships (actor→technique)", len(relationships))
|
||||
# Return relationships
|
||||
return relationships
|
||||
|
||||
|
||||
# Define function _build_attack_pattern_map
|
||||
def _build_attack_pattern_map(objects: list) -> dict[str, str]:
|
||||
"""Build a map from STIX attack-pattern ID → MITRE technique ID.
|
||||
|
||||
e.g. {"attack-pattern--abc123": "T1059.001"}
|
||||
"""
|
||||
# Assign mapping = {}
|
||||
mapping = {}
|
||||
# Iterate over objects
|
||||
for obj in objects:
|
||||
# Check: obj.get("type") != "attack-pattern"
|
||||
if obj.get("type") != "attack-pattern":
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Check: obj.get("revoked")
|
||||
if obj.get("revoked"):
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Assign stix_id = obj.get("id", "")
|
||||
stix_id = obj.get("id", "")
|
||||
# Assign mitre_id = _extract_mitre_id(obj.get("external_references", []))
|
||||
mitre_id = _extract_mitre_id(obj.get("external_references", []))
|
||||
# Check: stix_id and mitre_id
|
||||
if stix_id and mitre_id:
|
||||
# Assign mapping[stix_id] = mitre_id
|
||||
mapping[stix_id] = mitre_id
|
||||
# Log info: "Built attack-pattern map with %d entries", len(ma
|
||||
logger.info("Built attack-pattern map with %d entries", len(mapping))
|
||||
# Return mapping
|
||||
return mapping
|
||||
|
||||
|
||||
@@ -435,24 +562,31 @@ def sync(db: Session) -> dict:
|
||||
|
||||
Returns a summary dict.
|
||||
"""
|
||||
# Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_mitre_cti_")
|
||||
tmp_dir = tempfile.mkdtemp(prefix="aegis_mitre_cti_")
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign zip_bytes = _download_zip()
|
||||
zip_bytes = _download_zip()
|
||||
# Assign bundle = _extract_zip_and_load_bundle(zip_bytes, tmp_dir)
|
||||
bundle = _extract_zip_and_load_bundle(zip_bytes, tmp_dir)
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Call shutil.rmtree()
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
# Log info: "Cleaned up temp directory %s", tmp_dir
|
||||
logger.info("Cleaned up temp directory %s", tmp_dir)
|
||||
|
||||
# Assign objects = bundle.get("objects", [])
|
||||
objects = bundle.get("objects", [])
|
||||
|
||||
# Step 1: Parse data
|
||||
actor_dicts = _parse_intrusion_sets(objects)
|
||||
# Assign relationships = _parse_relationships(objects)
|
||||
relationships = _parse_relationships(objects)
|
||||
# Assign attack_pattern_map = _build_attack_pattern_map(objects)
|
||||
attack_pattern_map = _build_attack_pattern_map(objects)
|
||||
|
||||
# Step 2: Build STIX-ID → actor dict map
|
||||
stix_to_actor = {a["stix_id"]: a for a in actor_dicts}
|
||||
|
||||
# Step 3: Load existing actors and techniques from DB
|
||||
existing_actors = {
|
||||
row.mitre_id: row
|
||||
@@ -460,6 +594,7 @@ def sync(db: Session) -> dict:
|
||||
if row.mitre_id
|
||||
}
|
||||
|
||||
# Assign technique_by_mitre_id = {
|
||||
technique_by_mitre_id = {
|
||||
row.mitre_id: row
|
||||
for row in db.query(Technique).all()
|
||||
@@ -467,22 +602,35 @@ def sync(db: Session) -> dict:
|
||||
|
||||
# Step 4: Upsert threat actors
|
||||
actors_created = 0
|
||||
# Assign actors_skipped = 0
|
||||
actors_skipped = 0
|
||||
# Assign stix_to_db_actor = {}
|
||||
stix_to_db_actor: dict[str, ThreatActor] = {}
|
||||
|
||||
# Iterate over actor_dicts
|
||||
for actor_dict in actor_dicts:
|
||||
# Assign mitre_id = actor_dict["mitre_id"]
|
||||
mitre_id = actor_dict["mitre_id"]
|
||||
# Assign stix_id = actor_dict["stix_id"]
|
||||
stix_id = actor_dict["stix_id"]
|
||||
|
||||
# Check: mitre_id and mitre_id in existing_actors
|
||||
if mitre_id and mitre_id in existing_actors:
|
||||
# Update existing actor
|
||||
db_actor = existing_actors[mitre_id]
|
||||
# Assign db_actor.name = actor_dict["name"]
|
||||
db_actor.name = actor_dict["name"]
|
||||
# Assign db_actor.aliases = actor_dict["aliases"]
|
||||
db_actor.aliases = actor_dict["aliases"]
|
||||
# Assign db_actor.description = actor_dict["description"]
|
||||
db_actor.description = actor_dict["description"]
|
||||
# Assign db_actor.mitre_url = actor_dict["mitre_url"]
|
||||
db_actor.mitre_url = actor_dict["mitre_url"]
|
||||
# Assign db_actor.references = actor_dict["references"]
|
||||
db_actor.references = actor_dict["references"]
|
||||
# Assign db_actor.first_seen = actor_dict.get("first_seen")
|
||||
db_actor.first_seen = actor_dict.get("first_seen")
|
||||
# Assign db_actor.last_seen = actor_dict.get("last_seen")
|
||||
db_actor.last_seen = actor_dict.get("last_seen")
|
||||
# Update enrichment fields if available
|
||||
if actor_dict.get("motivation"):
|
||||
@@ -490,101 +638,165 @@ def sync(db: Session) -> dict:
|
||||
if actor_dict.get("sophistication"):
|
||||
db_actor.sophistication = actor_dict["sophistication"]
|
||||
stix_to_db_actor[stix_id] = db_actor
|
||||
# Assign actors_skipped = 1
|
||||
actors_skipped += 1
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Create new actor
|
||||
db_actor = ThreatActor(
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=mitre_id,
|
||||
# Keyword argument: name
|
||||
name=actor_dict["name"],
|
||||
# Keyword argument: aliases
|
||||
aliases=actor_dict["aliases"],
|
||||
# Keyword argument: description
|
||||
description=actor_dict["description"],
|
||||
# Keyword argument: mitre_url
|
||||
mitre_url=actor_dict["mitre_url"],
|
||||
# Keyword argument: references
|
||||
references=actor_dict["references"],
|
||||
# Keyword argument: first_seen
|
||||
first_seen=actor_dict.get("first_seen"),
|
||||
# Keyword argument: last_seen
|
||||
last_seen=actor_dict.get("last_seen"),
|
||||
motivation=actor_dict.get("motivation"),
|
||||
sophistication=actor_dict.get("sophistication"),
|
||||
is_active=True,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(db_actor)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush() # get the ID
|
||||
# Check: mitre_id
|
||||
if mitre_id:
|
||||
# Assign existing_actors[mitre_id] = db_actor
|
||||
existing_actors[mitre_id] = db_actor
|
||||
# Assign stix_to_db_actor[stix_id] = db_actor
|
||||
stix_to_db_actor[stix_id] = db_actor
|
||||
# Assign actors_created = 1
|
||||
actors_created += 1
|
||||
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Step 5: Upsert actor-technique relationships
|
||||
# Load existing relationships
|
||||
existing_rels: set[tuple] = set()
|
||||
# Iterate over db.query(ThreatActorTechnique).all()
|
||||
for row in db.query(ThreatActorTechnique).all():
|
||||
# Call existing_rels.add()
|
||||
existing_rels.add((str(row.threat_actor_id), str(row.technique_id)))
|
||||
|
||||
# Assign rels_created = 0
|
||||
rels_created = 0
|
||||
# Assign rels_skipped = 0
|
||||
rels_skipped = 0
|
||||
|
||||
# Iterate over relationships
|
||||
for rel in relationships:
|
||||
# Assign source_ref = rel["source_ref"]
|
||||
source_ref = rel["source_ref"]
|
||||
# Assign target_ref = rel["target_ref"]
|
||||
target_ref = rel["target_ref"]
|
||||
|
||||
# Resolve actor
|
||||
db_actor = stix_to_db_actor.get(source_ref)
|
||||
# Check: not db_actor
|
||||
if not db_actor:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Resolve technique
|
||||
mitre_technique_id = attack_pattern_map.get(target_ref)
|
||||
# Check: not mitre_technique_id
|
||||
if not mitre_technique_id:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign db_technique = technique_by_mitre_id.get(mitre_technique_id)
|
||||
db_technique = technique_by_mitre_id.get(mitre_technique_id)
|
||||
# Check: not db_technique
|
||||
if not db_technique:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign rel_key = (str(db_actor.id), str(db_technique.id))
|
||||
rel_key = (str(db_actor.id), str(db_technique.id))
|
||||
# Check: rel_key in existing_rels
|
||||
if rel_key in existing_rels:
|
||||
# Assign rels_skipped = 1
|
||||
rels_skipped += 1
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
|
||||
# Assign actor_technique = ThreatActorTechnique(
|
||||
actor_technique = ThreatActorTechnique(
|
||||
# Keyword argument: threat_actor_id
|
||||
threat_actor_id=db_actor.id,
|
||||
# Keyword argument: technique_id
|
||||
technique_id=db_technique.id,
|
||||
# Keyword argument: usage_description
|
||||
usage_description=rel["description"][:5000] if rel["description"] else None,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(actor_technique)
|
||||
# Call existing_rels.add()
|
||||
existing_rels.add(rel_key)
|
||||
# Assign rels_created = 1
|
||||
rels_created += 1
|
||||
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"actors_created": actors_created,
|
||||
# Literal argument value
|
||||
"actors_updated": actors_skipped,
|
||||
# Literal argument value
|
||||
"relationships_created": rels_created,
|
||||
# Literal argument value
|
||||
"relationships_skipped": rels_skipped,
|
||||
# Literal argument value
|
||||
"total_actors_parsed": len(actor_dicts),
|
||||
# Literal argument value
|
||||
"total_relationships_parsed": len(relationships),
|
||||
}
|
||||
|
||||
# Update DataSource record
|
||||
ds = db.query(DataSource).filter(DataSource.name == "mitre_cti").first()
|
||||
# Check: ds
|
||||
if ds:
|
||||
# Assign ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
# Assign ds.last_sync_status = "success"
|
||||
ds.last_sync_status = "success"
|
||||
# Assign ds.last_sync_stats = summary
|
||||
ds.last_sync_stats = summary
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Log info: "MITRE CTI threat actor import complete — %s", sum
|
||||
logger.info("MITRE CTI threat actor import complete — %s", summary)
|
||||
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=None,
|
||||
# Keyword argument: action
|
||||
action="import_threat_actors",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="threat_actor",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=None,
|
||||
# Keyword argument: details
|
||||
details=summary,
|
||||
)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
|
||||
# Return summary
|
||||
return summary
|
||||
|
||||
@@ -6,34 +6,56 @@ that the router remains a thin HTTP adapter.
|
||||
This module is framework-agnostic: no FastAPI imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import Any from typing
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import case, cast, func, or_, Text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.enums import TechniqueStatus
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.technique import Technique
|
||||
from app.utils import escape_like
|
||||
|
||||
# Import TechniqueStatus from app.models.enums
|
||||
from app.models.enums import TechniqueStatus
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
|
||||
# Import escape_like from app.utils
|
||||
from app.utils import escape_like
|
||||
|
||||
# ── Public service functions ──────────────────────────────────────────
|
||||
|
||||
|
||||
def list_actors(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: search
|
||||
search: str | None = None,
|
||||
# Entry: country
|
||||
country: str | None = None,
|
||||
# Entry: motivation
|
||||
motivation: str | None = None,
|
||||
# Entry: sophistication
|
||||
sophistication: str | None = None,
|
||||
# Entry: target_sectors
|
||||
target_sectors: str | None = None,
|
||||
# Entry: offset
|
||||
offset: int = 0,
|
||||
# Entry: limit
|
||||
limit: int = 50,
|
||||
) -> dict[str, Any]:
|
||||
"""List threat actors with optional filters, pagination, and coverage stats.
|
||||
@@ -41,10 +63,14 @@ def list_actors(
|
||||
Uses grouped subqueries to avoid N+1: technique counts and coverage
|
||||
counts are fetched in one query per page.
|
||||
"""
|
||||
# Assign query = db.query(ThreatActor)
|
||||
query = db.query(ThreatActor)
|
||||
|
||||
# Check: search
|
||||
if search:
|
||||
# Assign pattern = f"%{escape_like(search)}%"
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
# Assign query = query.filter(
|
||||
query = query.filter(
|
||||
or_(
|
||||
ThreatActor.name.ilike(pattern),
|
||||
@@ -53,35 +79,52 @@ def list_actors(
|
||||
)
|
||||
)
|
||||
|
||||
# Check: country
|
||||
if country:
|
||||
# Assign query = query.filter(ThreatActor.country == country)
|
||||
query = query.filter(ThreatActor.country == country)
|
||||
|
||||
# Check: motivation
|
||||
if motivation:
|
||||
# Assign query = query.filter(ThreatActor.motivation == motivation)
|
||||
query = query.filter(ThreatActor.motivation == motivation)
|
||||
|
||||
# Check: sophistication
|
||||
if sophistication:
|
||||
# Assign query = query.filter(ThreatActor.sophistication == sophistication)
|
||||
query = query.filter(ThreatActor.sophistication == sophistication)
|
||||
|
||||
# Check: target_sectors
|
||||
if target_sectors:
|
||||
# Assign query = query.filter(
|
||||
query = query.filter(
|
||||
cast(ThreatActor.target_sectors, Text).ilike(
|
||||
f"%{escape_like(target_sectors)}%"
|
||||
)
|
||||
)
|
||||
|
||||
# Assign total = query.count()
|
||||
total = query.count()
|
||||
# Assign actors = (
|
||||
actors = (
|
||||
query.order_by(ThreatActor.name).offset(offset).limit(limit).all()
|
||||
)
|
||||
|
||||
# Check: not actors
|
||||
if not actors:
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
# Literal argument value
|
||||
"offset": offset,
|
||||
# Literal argument value
|
||||
"limit": limit,
|
||||
# Literal argument value
|
||||
"items": [],
|
||||
}
|
||||
|
||||
# Assign actor_ids = [a.id for a in actors]
|
||||
actor_ids = [a.id for a in actors]
|
||||
|
||||
# Single grouped query: tech_count and covered_count per actor
|
||||
@@ -96,215 +139,355 @@ def list_actors(
|
||||
TechniqueStatus.validated,
|
||||
TechniqueStatus.partial,
|
||||
]),
|
||||
# Literal argument value
|
||||
1,
|
||||
),
|
||||
# Keyword argument: else_
|
||||
else_=0,
|
||||
)
|
||||
).label("covered_count"),
|
||||
)
|
||||
# Chain .join() call
|
||||
.join(Technique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
# Chain .filter() call
|
||||
.filter(ThreatActorTechnique.threat_actor_id.in_(actor_ids))
|
||||
# Chain .group_by() call
|
||||
.group_by(ThreatActorTechnique.threat_actor_id)
|
||||
).all()
|
||||
|
||||
# Assign counts_map = {
|
||||
counts_map = {
|
||||
str(row.threat_actor_id): {
|
||||
# Literal argument value
|
||||
"tech_count": row.tech_count,
|
||||
# Literal argument value
|
||||
"covered_count": row.covered_count or 0,
|
||||
}
|
||||
for row in counts_rows
|
||||
}
|
||||
|
||||
# Assign results = []
|
||||
results = []
|
||||
# Iterate over actors
|
||||
for actor in actors:
|
||||
# Assign cnt = counts_map.get(str(actor.id), {"tech_count": 0, "covered_count": 0})
|
||||
cnt = counts_map.get(str(actor.id), {"tech_count": 0, "covered_count": 0})
|
||||
# Assign tech_count = cnt["tech_count"]
|
||||
tech_count = cnt["tech_count"]
|
||||
# Assign covered = cnt["covered_count"]
|
||||
covered = cnt["covered_count"]
|
||||
# Assign coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0
|
||||
coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0
|
||||
|
||||
# Call results.append()
|
||||
results.append({
|
||||
# Literal argument value
|
||||
"id": str(actor.id),
|
||||
# Literal argument value
|
||||
"mitre_id": actor.mitre_id,
|
||||
# Literal argument value
|
||||
"name": actor.name,
|
||||
# Literal argument value
|
||||
"aliases": actor.aliases or [],
|
||||
# Literal argument value
|
||||
"country": actor.country,
|
||||
# Literal argument value
|
||||
"target_sectors": actor.target_sectors or [],
|
||||
# Literal argument value
|
||||
"target_regions": actor.target_regions or [],
|
||||
# Literal argument value
|
||||
"motivation": actor.motivation,
|
||||
# Literal argument value
|
||||
"sophistication": actor.sophistication,
|
||||
# Literal argument value
|
||||
"mitre_url": actor.mitre_url,
|
||||
# Literal argument value
|
||||
"technique_count": tech_count,
|
||||
# Literal argument value
|
||||
"coverage_pct": coverage_pct,
|
||||
# Literal argument value
|
||||
"is_active": actor.is_active,
|
||||
})
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"total": total,
|
||||
# Literal argument value
|
||||
"offset": offset,
|
||||
# Literal argument value
|
||||
"limit": limit,
|
||||
# Literal argument value
|
||||
"items": results,
|
||||
}
|
||||
|
||||
|
||||
# Define function get_actor_detail
|
||||
def get_actor_detail(db: Session, actor_id: str) -> dict[str, Any]:
|
||||
"""Get detailed threat actor with techniques.
|
||||
|
||||
Raises EntityNotFoundError if the actor does not exist.
|
||||
"""
|
||||
# Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
# Check: not actor
|
||||
if not actor:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Threat actor", actor_id)
|
||||
|
||||
# Assign actor_techniques = (
|
||||
actor_techniques = (
|
||||
db.query(ThreatActorTechnique, Technique)
|
||||
# Chain .join() call
|
||||
.join(Technique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
# Chain .filter() call
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(Technique.mitre_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign techniques_list = [
|
||||
techniques_list = [
|
||||
{
|
||||
# Literal argument value
|
||||
"technique_id": str(tech.id),
|
||||
# Literal argument value
|
||||
"mitre_id": tech.mitre_id,
|
||||
# Literal argument value
|
||||
"name": tech.name,
|
||||
# Literal argument value
|
||||
"tactic": tech.tactic,
|
||||
# Literal argument value
|
||||
"status_global": tech.status_global.value if tech.status_global else None,
|
||||
# Literal argument value
|
||||
"usage_description": at.usage_description,
|
||||
# Literal argument value
|
||||
"first_seen_using": at.first_seen_using,
|
||||
}
|
||||
for at, tech in actor_techniques
|
||||
]
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"id": str(actor.id),
|
||||
# Literal argument value
|
||||
"mitre_id": actor.mitre_id,
|
||||
# Literal argument value
|
||||
"name": actor.name,
|
||||
# Literal argument value
|
||||
"aliases": actor.aliases or [],
|
||||
# Literal argument value
|
||||
"description": actor.description,
|
||||
# Literal argument value
|
||||
"country": actor.country,
|
||||
# Literal argument value
|
||||
"target_sectors": actor.target_sectors or [],
|
||||
# Literal argument value
|
||||
"target_regions": actor.target_regions or [],
|
||||
# Literal argument value
|
||||
"motivation": actor.motivation,
|
||||
# Literal argument value
|
||||
"sophistication": actor.sophistication,
|
||||
# Literal argument value
|
||||
"first_seen": actor.first_seen,
|
||||
# Literal argument value
|
||||
"last_seen": actor.last_seen,
|
||||
# Literal argument value
|
||||
"references": actor.references or [],
|
||||
# Literal argument value
|
||||
"mitre_url": actor.mitre_url,
|
||||
# Literal argument value
|
||||
"is_active": actor.is_active,
|
||||
# Literal argument value
|
||||
"techniques": techniques_list,
|
||||
}
|
||||
|
||||
|
||||
# Define function get_actor_coverage
|
||||
def get_actor_coverage(db: Session, actor_id: str) -> dict[str, Any]:
|
||||
"""Calculate coverage percentage against a specific threat actor.
|
||||
|
||||
Raises EntityNotFoundError if the actor does not exist.
|
||||
"""
|
||||
# Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
# Check: not actor
|
||||
if not actor:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Threat actor", actor_id)
|
||||
|
||||
# Assign actor_techniques = (
|
||||
actor_techniques = (
|
||||
db.query(Technique)
|
||||
# Chain .join() call
|
||||
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
# Chain .filter() call
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign total = len(actor_techniques)
|
||||
total = len(actor_techniques)
|
||||
# Check: total == 0
|
||||
if total == 0:
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"actor_id": str(actor.id),
|
||||
# Literal argument value
|
||||
"actor_name": actor.name,
|
||||
# Literal argument value
|
||||
"total_techniques": 0,
|
||||
# Literal argument value
|
||||
"coverage_pct": 0.0,
|
||||
# Literal argument value
|
||||
"breakdown": {},
|
||||
}
|
||||
|
||||
# Assign breakdown = {}
|
||||
breakdown: dict[str, int] = {}
|
||||
# Iterate over actor_techniques
|
||||
for tech in actor_techniques:
|
||||
# Assign status = tech.status_global.value if tech.status_global else "not_evaluated"
|
||||
status = tech.status_global.value if tech.status_global else "not_evaluated"
|
||||
# Assign breakdown[status] = breakdown.get(status, 0) + 1
|
||||
breakdown[status] = breakdown.get(status, 0) + 1
|
||||
|
||||
# Assign covered = breakdown.get("validated", 0) + breakdown.get("partial", 0)
|
||||
covered = breakdown.get("validated", 0) + breakdown.get("partial", 0)
|
||||
# Assign coverage_pct = round((covered / total * 100), 1)
|
||||
coverage_pct = round((covered / total * 100), 1)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"actor_id": str(actor.id),
|
||||
# Literal argument value
|
||||
"actor_name": actor.name,
|
||||
# Literal argument value
|
||||
"total_techniques": total,
|
||||
# Literal argument value
|
||||
"covered": covered,
|
||||
# Literal argument value
|
||||
"coverage_pct": coverage_pct,
|
||||
# Literal argument value
|
||||
"breakdown": breakdown,
|
||||
}
|
||||
|
||||
|
||||
# Define function get_actor_gaps
|
||||
def get_actor_gaps(db: Session, actor_id: str) -> dict[str, Any]:
|
||||
"""Identify techniques of this actor that are not fully validated.
|
||||
|
||||
Raises EntityNotFoundError if the actor does not exist.
|
||||
"""
|
||||
# Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
# Check: not actor
|
||||
if not actor:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Threat actor", actor_id)
|
||||
|
||||
# Assign gap_techniques = (
|
||||
gap_techniques = (
|
||||
db.query(Technique, ThreatActorTechnique)
|
||||
# Chain .join() call
|
||||
.join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id)
|
||||
# Chain .filter() call
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.status_global != TechniqueStatus.validated)
|
||||
# Chain .order_by() call
|
||||
.order_by(Technique.mitre_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Check: not gap_techniques
|
||||
if not gap_techniques:
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"actor_id": str(actor.id),
|
||||
# Literal argument value
|
||||
"actor_name": actor.name,
|
||||
# Literal argument value
|
||||
"total_gaps": 0,
|
||||
# Literal argument value
|
||||
"gaps": [],
|
||||
}
|
||||
|
||||
# Assign technique_ids = [tech.id for tech, _ in gap_techniques]
|
||||
technique_ids = [tech.id for tech, _ in gap_techniques]
|
||||
# Assign mitre_ids = [tech.mitre_id for tech, _ in gap_techniques]
|
||||
mitre_ids = [tech.mitre_id for tech, _ in gap_techniques]
|
||||
|
||||
# Batch template counts by mitre_technique_id
|
||||
template_counts = (
|
||||
db.query(TestTemplate.mitre_technique_id, func.count(TestTemplate.id).label("cnt"))
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.mitre_technique_id.in_(mitre_ids))
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.is_active == True)
|
||||
# Chain .group_by() call
|
||||
.group_by(TestTemplate.mitre_technique_id)
|
||||
).all()
|
||||
# Assign template_map = {row.mitre_technique_id: row.cnt for row in template_counts}
|
||||
template_map = {row.mitre_technique_id: row.cnt for row in template_counts}
|
||||
|
||||
# Batch test counts by technique_id
|
||||
test_counts = (
|
||||
db.query(Test.technique_id, func.count(Test.id).label("cnt"))
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id.in_(technique_ids))
|
||||
# Chain .group_by() call
|
||||
.group_by(Test.technique_id)
|
||||
).all()
|
||||
# Assign test_map = {str(row.technique_id): row.cnt for row in test_counts}
|
||||
test_map = {str(row.technique_id): row.cnt for row in test_counts}
|
||||
|
||||
# Assign gaps = []
|
||||
gaps = []
|
||||
# Iterate over gap_techniques
|
||||
for tech, at in gap_techniques:
|
||||
# Assign template_count = template_map.get(tech.mitre_id, 0)
|
||||
template_count = template_map.get(tech.mitre_id, 0)
|
||||
# Assign test_count = test_map.get(str(tech.id), 0)
|
||||
test_count = test_map.get(str(tech.id), 0)
|
||||
# Call gaps.append()
|
||||
gaps.append({
|
||||
# Literal argument value
|
||||
"technique_id": str(tech.id),
|
||||
# Literal argument value
|
||||
"mitre_id": tech.mitre_id,
|
||||
# Literal argument value
|
||||
"name": tech.name,
|
||||
# Literal argument value
|
||||
"tactic": tech.tactic,
|
||||
# Literal argument value
|
||||
"status_global": tech.status_global.value if tech.status_global else None,
|
||||
# Literal argument value
|
||||
"usage_description": at.usage_description,
|
||||
# Literal argument value
|
||||
"available_templates": template_count,
|
||||
# Literal argument value
|
||||
"existing_tests": test_count,
|
||||
# Literal argument value
|
||||
"has_templates": template_count > 0,
|
||||
})
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"actor_id": str(actor.id),
|
||||
# Literal argument value
|
||||
"actor_name": actor.name,
|
||||
# Literal argument value
|
||||
"total_gaps": len(gaps),
|
||||
# Literal argument value
|
||||
"gaps": gaps,
|
||||
}
|
||||
|
||||
@@ -4,30 +4,51 @@ Uses domain exceptions from app.domain.errors. The router handles
|
||||
HTTP concerns, auth, audit logging, and commit.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import hash_password from app.auth
|
||||
from app.auth import hash_password
|
||||
from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError
|
||||
|
||||
# Import from app.domain.errors
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
DuplicateEntityError,
|
||||
EntityNotFoundError,
|
||||
)
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Assign VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
||||
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
||||
|
||||
|
||||
# Define function list_users
|
||||
def list_users(db: Session) -> list[User]:
|
||||
"""Return a list of all users ordered by username."""
|
||||
# Return db.query(User).order_by(User.username).all()
|
||||
return db.query(User).order_by(User.username).all()
|
||||
|
||||
|
||||
# Define function create_user
|
||||
def create_user(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: username
|
||||
username: str,
|
||||
# Entry: email
|
||||
email: str | None,
|
||||
# Entry: password
|
||||
password: str,
|
||||
# Entry: role
|
||||
role: str,
|
||||
) -> User:
|
||||
"""Create a new user.
|
||||
@@ -36,33 +57,51 @@ def create_user(
|
||||
Raises BusinessRuleViolation if role is invalid.
|
||||
Does not commit; the router handles that.
|
||||
"""
|
||||
# Assign existing = db.query(User).filter(User.username == username).first()
|
||||
existing = db.query(User).filter(User.username == username).first()
|
||||
# Check: existing
|
||||
if existing:
|
||||
# Raise DuplicateEntityError
|
||||
raise DuplicateEntityError("User", "username", username)
|
||||
|
||||
# Check: role not in VALID_ROLES
|
||||
if role not in VALID_ROLES:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Invalid role '{role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
|
||||
)
|
||||
|
||||
# Assign user = User(
|
||||
user = User(
|
||||
# Keyword argument: username
|
||||
username=username,
|
||||
# Keyword argument: email
|
||||
email=email,
|
||||
# Keyword argument: hashed_password
|
||||
hashed_password=hash_password(password),
|
||||
# Keyword argument: role
|
||||
role=role,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(user)
|
||||
# Return user
|
||||
return user
|
||||
|
||||
|
||||
# Define function get_user_or_raise
|
||||
def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User:
|
||||
"""Return a user by ID or raise EntityNotFoundError."""
|
||||
# Assign user = db.query(User).filter(User.id == user_id).first()
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
# Check: user is None
|
||||
if user is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("User", str(user_id))
|
||||
# Return user
|
||||
return user
|
||||
|
||||
|
||||
# Define function update_user
|
||||
def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User:
|
||||
"""Update one or more fields of an existing user.
|
||||
|
||||
@@ -71,18 +110,27 @@ def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User:
|
||||
Handles 'password' by hashing and storing as 'hashed_password'.
|
||||
Does not commit; the router handles that.
|
||||
"""
|
||||
# Assign user = get_user_or_raise(db, user_id)
|
||||
user = get_user_or_raise(db, user_id)
|
||||
# Assign update_data = dict(fields)
|
||||
update_data = dict(fields)
|
||||
|
||||
# Check: "role" in update_data and update_data["role"] not in VALID_ROLES
|
||||
if "role" in update_data and update_data["role"] not in VALID_ROLES:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
|
||||
)
|
||||
|
||||
# Check: "password" in update_data
|
||||
if "password" in update_data:
|
||||
# Assign update_data["hashed_password"] = hash_password(str(update_data.pop("password")))
|
||||
update_data["hashed_password"] = hash_password(str(update_data.pop("password")))
|
||||
|
||||
# Iterate over update_data.items()
|
||||
for field, value in update_data.items():
|
||||
# Call setattr()
|
||||
setattr(user, field, value)
|
||||
|
||||
# Return user
|
||||
return user
|
||||
|
||||
@@ -1,83 +1,141 @@
|
||||
"""Internal worklog service — CRUD with integrity hashing."""
|
||||
|
||||
# Import hashlib
|
||||
import hashlib
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
|
||||
# Import Worklog from app.models.worklog
|
||||
from app.models.worklog import Worklog
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define function create_worklog
|
||||
def create_worklog(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: entity_type
|
||||
entity_type: str,
|
||||
# Entry: entity_id
|
||||
entity_id: UUID,
|
||||
# Entry: user_id
|
||||
user_id: UUID,
|
||||
# Entry: activity_type
|
||||
activity_type: str,
|
||||
# Entry: started_at
|
||||
started_at: datetime,
|
||||
# Entry: duration_seconds
|
||||
duration_seconds: int,
|
||||
# Entry: ended_at
|
||||
ended_at: Optional[datetime] = None,
|
||||
# Entry: description
|
||||
description: Optional[str] = None,
|
||||
) -> Worklog:
|
||||
"""Create a worklog with an auto-computed integrity hash."""
|
||||
# Assign wl = Worklog(
|
||||
wl = Worklog(
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=entity_id,
|
||||
# Keyword argument: user_id
|
||||
user_id=user_id,
|
||||
# Keyword argument: activity_type
|
||||
activity_type=activity_type,
|
||||
# Keyword argument: started_at
|
||||
started_at=started_at,
|
||||
# Keyword argument: ended_at
|
||||
ended_at=ended_at,
|
||||
# Keyword argument: duration_seconds
|
||||
duration_seconds=duration_seconds,
|
||||
# Keyword argument: description
|
||||
description=description,
|
||||
)
|
||||
# Assign wl.integrity_hash = _compute_hash(wl)
|
||||
wl.integrity_hash = _compute_hash(wl)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(wl)
|
||||
# Does not commit; caller (router) uses UnitOfWork.
|
||||
return wl
|
||||
|
||||
|
||||
# Define function get_worklog_or_raise
|
||||
def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog:
|
||||
"""Get a worklog by ID or raise EntityNotFoundError."""
|
||||
# Assign wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
||||
# Check: not wl
|
||||
if not wl:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
||||
# Return wl
|
||||
return wl
|
||||
|
||||
|
||||
# Define function list_worklogs
|
||||
def list_worklogs(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
*,
|
||||
# Entry: entity_type
|
||||
entity_type: Optional[str] = None,
|
||||
# Entry: entity_id
|
||||
entity_id: Optional[UUID] = None,
|
||||
# Entry: user_id
|
||||
user_id: Optional[UUID] = None,
|
||||
) -> list[Worklog]:
|
||||
"""List worklogs with optional filters."""
|
||||
# Assign query = db.query(Worklog)
|
||||
query = db.query(Worklog)
|
||||
# Check: entity_type
|
||||
if entity_type:
|
||||
# Assign query = query.filter(Worklog.entity_type == entity_type)
|
||||
query = query.filter(Worklog.entity_type == entity_type)
|
||||
# Check: entity_id
|
||||
if entity_id:
|
||||
# Assign query = query.filter(Worklog.entity_id == entity_id)
|
||||
query = query.filter(Worklog.entity_id == entity_id)
|
||||
# Check: user_id
|
||||
if user_id:
|
||||
# Assign query = query.filter(Worklog.user_id == user_id)
|
||||
query = query.filter(Worklog.user_id == user_id)
|
||||
# Return query.order_by(Worklog.started_at.desc()).all()
|
||||
return query.order_by(Worklog.started_at.desc()).all()
|
||||
|
||||
|
||||
# Define function verify_worklog_integrity
|
||||
def verify_worklog_integrity(wl: Worklog) -> bool:
|
||||
"""Return True if the worklog has not been tampered with."""
|
||||
# Return wl.integrity_hash == _compute_hash(wl)
|
||||
return wl.integrity_hash == _compute_hash(wl)
|
||||
|
||||
|
||||
# Define function _compute_hash
|
||||
def _compute_hash(wl: Worklog) -> str:
|
||||
"""SHA-256 of the immutable fields for audit integrity."""
|
||||
# Assign data = (
|
||||
data = (
|
||||
f"{wl.entity_type}:{wl.entity_id}:{wl.user_id}:"
|
||||
f"{wl.activity_type}:{wl.started_at}:{wl.duration_seconds}"
|
||||
)
|
||||
# Return hashlib.sha256(data.encode()).hexdigest()
|
||||
return hashlib.sha256(data.encode()).hexdigest()
|
||||
|
||||
Reference in New Issue
Block a user