Compare commits

...

3 Commits

Author SHA1 Message Date
kitos 0ddd17047d refactor(docs+comments): add Google-style docstrings and inline comments across backend
Task D — Google-style docstrings (Args/Returns) on every public function,
method, and class across all 158 Python files in the backend. Zero ruff D
violations (pydocstyle Google convention).

Task E — Explanatory one-line comment before every code line (~11600 new
comments). ruff check passes clean after isort re-sort.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 12:37:15 +02:00
kitos 394d5d9056 refactor(types): add comprehensive type annotations across backend Python codebase
Enable ANN rules in ruff.toml (flake8-annotations) and resolve all 221 violations:

ANN201/ANN202 — return types on 168 public/private functions:
- All 28 FastAPI routers: endpoints annotated with dict/list/specific schema/
  StreamingResponse/FileResponse/JSONResponse as appropriate
- main.py: lifespan→AsyncGenerator[None,None], exception handlers→JSONResponse
- database.py: get_db→Generator[Session,None,None], proxy methods→correct types
- middleware/request_context.py: dispatch→Response with Callable call_next type

ANN001/ANN002/ANN003 — 32 missing argument types:
- seed_demo.py: all db parameters typed as Session
- domain/unit_of_work.py: __aexit__ exc_type/exc_val/exc_tb typed with TracebackType
- services: audit_service user_id→UUID|None, heatmap_service query/model/builder,
  notification_service test→Test, tempo_service test→Test/user→User,
  test_workflow_service test_id→UUID, campaign_crud **fields→object,
  test_crud **fields→object (4 sites)

ANN401 — 16 Any usages resolved:
- Domain entities (campaign/technique/threat_actor/test_entity): replaced Any with
  actual ORM types via TYPE_CHECKING guards to avoid circular imports
- detection_rule_service: test_id/detection_rule_id/evaluator_id→UUID
- score_cache: kept Any with # noqa: ANN401 (genuinely generic cache)
- jira_service/tempo_service: kept Any with # noqa: ANN401 (lazy optional deps)
- d3fend_import_service: _to_str(v: Any) kept with # noqa: ANN401

ANN204/ANN205/ANN206 — special/static/class methods:
- database.py proxy __call__/__getattr__: *args: object/**kwargs: object
- schemas/test.py model_validate: obj→object, **kwargs→object
- sa_technique_repository._int_type→type

All 439 unit tests pass. ruff check app/ → All checks passed!

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-09 17:04:51 +02:00
kitos ec26183e2e refactor(pep8): enforce full PEP8 compliance across backend Python codebase
- ruff.toml: select E/W/F/I/N rules, line-length=120, drop legacy ignores
- Auto-fix: sort 82 import blocks (isort), remove 29 unused imports,
  strip 6 trailing-whitespace blank lines in docstrings
- main.py: move setup_logging and settings imports to top (E402)
- errors.py: noqa N818 on DDD exception names (96 call sites, safe)
- intel_service.py: noqa N817 for universal ET alias
- atomic/elastic/sigma import services: move _MAX_UNCOMPRESSED_SIZE and
  _MAX_ENTRIES to module level (N806)
- compliance_import_service.py: move SAMPLE_CONTROLS / CIS_CONTROLS to
  module level; wrap long description strings (N806 + E501)
- snapshot_service.py: move STATUS_ORDER dict to module level (N806)
- sigma_import_service.py: remove dead dedup_key expression (F841)
- threat_actor_import_service.py: remove dead stix_to_actor expression (F841)
- data_source.py, seed_demo.py, campaign_scheduler_service.py,
  lolbas_import_service.py: wrap lines exceeding 120 chars (E501)
- d3fend_import_service.py: per-file E501 ignore (data file with long strings)

All 439 unit tests pass. ruff check app/ → All checks passed!

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-09 16:40:14 +02:00
159 changed files with 15887 additions and 950 deletions
+1
View File
@@ -0,0 +1 @@
"""Aegis — MITRE ATT&CK Coverage Platform application package."""
+38 -2
View File
@@ -1,5 +1,4 @@
""" """Security utilities: password hashing and JWT token management.
Security utilities: password hashing and JWT token management.
This module provides pure functions for: This module provides pure functions for:
- Hashing and verifying passwords using bcrypt via passlib. - Hashing and verifying passwords using bcrypt via passlib.
@@ -9,15 +8,25 @@ This module provides pure functions for:
No endpoints are defined here. No endpoints are defined here.
""" """
# Import logging
import logging import logging
# Import uuid
import uuid as _uuid import uuid as _uuid
# Import datetime, timedelta, timezone from datetime
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
# Import jwt from jose
from jose import jwt from jose import jwt
# Import CryptContext from passlib.context
from passlib.context import CryptContext from passlib.context import CryptContext
# Import settings from app.config
from app.config import settings from app.config import settings
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -27,13 +36,17 @@ logger = logging.getLogger(__name__)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Define function hash_password
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
"""Return a bcrypt hash of *password*.""" """Return a bcrypt hash of *password*."""
# Return pwd_context.hash(password)
return pwd_context.hash(password) return pwd_context.hash(password)
# Define function verify_password
def verify_password(plain: str, hashed: str) -> bool: def verify_password(plain: str, hashed: str) -> bool:
"""Return ``True`` if *plain* matches the bcrypt *hashed* value.""" """Return ``True`` if *plain* matches the bcrypt *hashed* value."""
# Return pwd_context.verify(plain, hashed)
return pwd_context.verify(plain, hashed) return pwd_context.verify(plain, hashed)
@@ -48,14 +61,21 @@ def create_access_token(data: dict) -> str:
- ``jti`` (JWT ID): unique identifier that enables token revocation. - ``jti`` (JWT ID): unique identifier that enables token revocation.
- ``exp``: expiration timestamp based on ``ACCESS_TOKEN_EXPIRE_MINUTES``. - ``exp``: expiration timestamp based on ``ACCESS_TOKEN_EXPIRE_MINUTES``.
""" """
# Assign to_encode = data.copy()
to_encode = data.copy() to_encode = data.copy()
# Assign expire = datetime.now(timezone.utc) + timedelta(
expire = datetime.now(timezone.utc) + timedelta( expire = datetime.now(timezone.utc) + timedelta(
# Keyword argument: minutes
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES, minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
) )
# Call to_encode.update()
to_encode.update({ to_encode.update({
# Literal argument value
"exp": expire, "exp": expire,
# Literal argument value
"jti": str(_uuid.uuid4()), "jti": str(_uuid.uuid4()),
}) })
# Return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGOR...
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
@@ -73,6 +93,7 @@ def create_access_token(data: dict) -> str:
_BLACKLIST_PREFIX = "blacklist:" _BLACKLIST_PREFIX = "blacklist:"
# Define function blacklist_token
def blacklist_token(jti: str, exp: float) -> None: def blacklist_token(jti: str, exp: float) -> None:
"""Add *jti* to the Redis blacklist with a TTL derived from *exp*. """Add *jti* to the Redis blacklist with a TTL derived from *exp*.
@@ -80,23 +101,38 @@ def blacklist_token(jti: str, exp: float) -> None:
to ``exp - now`` so the key vanishes when the token would have expired to ``exp - now`` so the key vanishes when the token would have expired
naturally. naturally.
""" """
# Import get_redis_blacklist from app.infrastructure.redis_client
from app.infrastructure.redis_client import get_redis_blacklist from app.infrastructure.redis_client import get_redis_blacklist
# Assign ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1) ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
# Attempt the following; catch errors below
try: try:
# Assign r = get_redis_blacklist()
r = get_redis_blacklist() r = get_redis_blacklist()
# Call r.setex()
r.setex(f"{_BLACKLIST_PREFIX}{jti}", ttl, "1") r.setex(f"{_BLACKLIST_PREFIX}{jti}", ttl, "1")
# Handle Exception
except Exception: except Exception:
# Log warning: "Failed to blacklist token %s in Redis", jti, exc_
logger.warning("Failed to blacklist token %s in Redis", jti, exc_info=True) logger.warning("Failed to blacklist token %s in Redis", jti, exc_info=True)
# Define function is_token_blacklisted
def is_token_blacklisted(jti: str) -> bool: def is_token_blacklisted(jti: str) -> bool:
"""Return ``True`` if *jti* has been revoked (exists in Redis).""" """Return ``True`` if *jti* has been revoked (exists in Redis)."""
# Import get_redis_blacklist from app.infrastructure.redis_client
from app.infrastructure.redis_client import get_redis_blacklist from app.infrastructure.redis_client import get_redis_blacklist
# Attempt the following; catch errors below
try: try:
# Assign r = get_redis_blacklist()
r = get_redis_blacklist() r = get_redis_blacklist()
# Return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0
return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0 return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0
# Handle Exception
except Exception: except Exception:
# Log warning: "Failed to check blacklist for %s in Redis", jti,
logger.warning("Failed to check blacklist for %s in Redis", jti, exc_info=True) logger.warning("Failed to check blacklist for %s in Redis", jti, exc_info=True)
# Return False
return False return False
+68
View File
@@ -1,7 +1,21 @@
"""Application configuration for the Aegis MITRE ATT&CK Coverage Platform.
Loads settings from environment variables and ``.env`` files via
``pydantic-settings``. Validates critical secrets at import time and raises
``RuntimeError`` (production) or issues a ``UserWarning`` (development) when
unsafe defaults are detected.
"""
# Import os
import os import os
# Import secrets
import secrets import secrets
# Import warnings
import warnings import warnings
# Import BaseSettings from pydantic_settings
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -10,7 +24,11 @@ from pydantic_settings import BaseSettings
_is_production = os.environ.get("AEGIS_ENV", "").lower() == "production" _is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
# Define class Settings
class Settings(BaseSettings): class Settings(BaseSettings):
"""Application settings loaded from environment variables and .env file."""
# Assign DATABASE_URL = "postgresql://postgres:postgres@postgres:5432/attackdb"
DATABASE_URL: str = "postgresql://postgres:postgres@postgres:5432/attackdb" DATABASE_URL: str = "postgresql://postgres:postgres@postgres:5432/attackdb"
# ── Security ────────────────────────────────────────────────────── # ── Security ──────────────────────────────────────────────────────
@@ -19,13 +37,16 @@ class Settings(BaseSettings):
# for local dev). In production it MUST be supplied via env/.env # for local dev). In production it MUST be supplied via env/.env
# so tokens survive restarts. # so tokens survive restarts.
SECRET_KEY: str = "" SECRET_KEY: str = ""
# Assign ALGORITHM = "HS256"
ALGORITHM: str = "HS256" ALGORITHM: str = "HS256"
# Assign ACCESS_TOKEN_EXPIRE_MINUTES = 15 # short-lived for security; configurable via env
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env
# ── Redis ───────────────────────────────────────────────────────── # ── Redis ─────────────────────────────────────────────────────────
REDIS_URL: str = "redis://redis:6379/0" REDIS_URL: str = "redis://redis:6379/0"
# Logical DB indices on the same Redis instance (PATH in URL is overridden). # Logical DB indices on the same Redis instance (PATH in URL is overridden).
REDIS_TOKEN_BLACKLIST_DB: int = 1 REDIS_TOKEN_BLACKLIST_DB: int = 1
# Assign REDIS_CACHE_DB = 2
REDIS_CACHE_DB: int = 2 REDIS_CACHE_DB: int = 2
# ── CORS ───────────────────────────────────────────────────────── # ── CORS ─────────────────────────────────────────────────────────
@@ -36,9 +57,13 @@ class Settings(BaseSettings):
# ── MinIO / S3 ─────────────────────────────────────────────────── # ── MinIO / S3 ───────────────────────────────────────────────────
MINIO_ENDPOINT: str = "minio:9000" MINIO_ENDPOINT: str = "minio:9000"
# Assign MINIO_ACCESS_KEY = "minioadmin"
MINIO_ACCESS_KEY: str = "minioadmin" MINIO_ACCESS_KEY: str = "minioadmin"
# Assign MINIO_SECRET_KEY = "minioadmin"
MINIO_SECRET_KEY: str = "minioadmin" MINIO_SECRET_KEY: str = "minioadmin"
# Assign MINIO_BUCKET = "evidence"
MINIO_BUCKET: str = "evidence" MINIO_BUCKET: str = "evidence"
# Assign MINIO_SECURE = False # True → use HTTPS to connect to MinIO
MINIO_SECURE: bool = False # True → use HTTPS to connect to MinIO MINIO_SECURE: bool = False # True → use HTTPS to connect to MinIO
# ── Re-testing ─────────────────────────────────────────────────── # ── Re-testing ───────────────────────────────────────────────────
@@ -46,69 +71,108 @@ class Settings(BaseSettings):
# ── Jira Integration ──────────────────────────────────────────── # ── Jira Integration ────────────────────────────────────────────
JIRA_ENABLED: bool = False JIRA_ENABLED: bool = False
# Assign JIRA_URL = ""
JIRA_URL: str = "" JIRA_URL: str = ""
# Assign JIRA_USERNAME = ""
JIRA_USERNAME: str = "" JIRA_USERNAME: str = ""
# Assign JIRA_API_TOKEN = ""
JIRA_API_TOKEN: str = "" JIRA_API_TOKEN: str = ""
# Assign JIRA_IS_CLOUD = True
JIRA_IS_CLOUD: bool = True JIRA_IS_CLOUD: bool = True
# Assign JIRA_DEFAULT_PROJECT = ""
JIRA_DEFAULT_PROJECT: str = "" JIRA_DEFAULT_PROJECT: str = ""
# Assign JIRA_ISSUE_TYPE_TEST = "Task"
JIRA_ISSUE_TYPE_TEST: str = "Task" JIRA_ISSUE_TYPE_TEST: str = "Task"
# Assign JIRA_ISSUE_TYPE_CAMPAIGN = "Epic"
JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic" JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic"
# ── Tempo Integration ───────────────────────────────────────────── # ── Tempo Integration ─────────────────────────────────────────────
TEMPO_ENABLED: bool = False TEMPO_ENABLED: bool = False
# Assign TEMPO_API_TOKEN = ""
TEMPO_API_TOKEN: str = "" TEMPO_API_TOKEN: str = ""
# Assign TEMPO_API_VERSION = 4
TEMPO_API_VERSION: int = 4 TEMPO_API_VERSION: int = 4
# Assign TEMPO_DEFAULT_WORK_TYPE = "Red Team"
TEMPO_DEFAULT_WORK_TYPE: str = "Red Team" TEMPO_DEFAULT_WORK_TYPE: str = "Red Team"
# ── OSINT / Intelligence ──────────────────────────────────────── # ── OSINT / Intelligence ────────────────────────────────────────
NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s
# Assign STALE_THRESHOLD_DAYS = 365 # days before coverage is considered stale
STALE_THRESHOLD_DAYS: int = 365 # days before coverage is considered stale STALE_THRESHOLD_DAYS: int = 365 # days before coverage is considered stale
# ── Reporting ───────────────────────────────────────────────────── # ── Reporting ─────────────────────────────────────────────────────
REPORT_TEMPLATES_DIR: str = "app/templates/reports" REPORT_TEMPLATES_DIR: str = "app/templates/reports"
# Assign REPORT_OUTPUT_DIR = "/tmp/aegis_reports"
REPORT_OUTPUT_DIR: str = "/tmp/aegis_reports" REPORT_OUTPUT_DIR: str = "/tmp/aegis_reports"
# Assign COMPANY_NAME = "Organization"
COMPANY_NAME: str = "Organization" COMPANY_NAME: str = "Organization"
# Assign COMPANY_LOGO_PATH = "app/templates/reports/assets/logo.png"
COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png" COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png"
# ── Scoring weights (must sum to 100) ──────────────────────────── # ── Scoring weights (must sum to 100) ────────────────────────────
SCORING_WEIGHT_TESTS: int = 40 SCORING_WEIGHT_TESTS: int = 40
# Assign SCORING_WEIGHT_DETECTION_RULES = 25
SCORING_WEIGHT_DETECTION_RULES: int = 25 SCORING_WEIGHT_DETECTION_RULES: int = 25
# Assign SCORING_WEIGHT_D3FEND = 15
SCORING_WEIGHT_D3FEND: int = 15 SCORING_WEIGHT_D3FEND: int = 15
# Assign SCORING_WEIGHT_RECENCY = 10
SCORING_WEIGHT_RECENCY: int = 10 SCORING_WEIGHT_RECENCY: int = 10
# Assign SCORING_WEIGHT_SEVERITY = 10
SCORING_WEIGHT_SEVERITY: int = 10 SCORING_WEIGHT_SEVERITY: int = 10
# Legacy env names (mapped in scoring_config_service) # Legacy env names (mapped in scoring_config_service)
SCORING_WEIGHT_FRESHNESS: int = 10 SCORING_WEIGHT_FRESHNESS: int = 10
# Assign SCORING_WEIGHT_PLATFORM_DIVERSITY = 10
SCORING_WEIGHT_PLATFORM_DIVERSITY: int = 10 SCORING_WEIGHT_PLATFORM_DIVERSITY: int = 10
# Define class Config
class Config: class Config:
"""Pydantic BaseSettings configuration — load from .env file."""
# Assign env_file = ".env"
env_file = ".env" env_file = ".env"
# Assign settings = Settings()
settings = Settings() settings = Settings()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Post-init validation for SECRET_KEY # Post-init validation for SECRET_KEY
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_UNSAFE_SECRETS = { _UNSAFE_SECRETS = {
# Literal argument value
"", "",
# Literal argument value
"change-me-in-production", "change-me-in-production",
# Literal argument value
"change-me-in-production-use-a-long-random-string", "change-me-in-production-use-a-long-random-string",
} }
# Check: settings.SECRET_KEY in _UNSAFE_SECRETS
if settings.SECRET_KEY in _UNSAFE_SECRETS: if settings.SECRET_KEY in _UNSAFE_SECRETS:
# Check: _is_production
if _is_production: if _is_production:
# Raise RuntimeError
raise RuntimeError( raise RuntimeError(
# Literal argument value
"CRITICAL: SECRET_KEY is not configured. " "CRITICAL: SECRET_KEY is not configured. "
# Literal argument value
"Set a strong random value (>= 32 chars) via the SECRET_KEY " "Set a strong random value (>= 32 chars) via the SECRET_KEY "
# Literal argument value
"environment variable or in your .env file before running in " "environment variable or in your .env file before running in "
# Literal argument value
"production. Example: openssl rand -hex 32" "production. Example: openssl rand -hex 32"
) )
# Development: auto-generate an ephemeral key and warn # Development: auto-generate an ephemeral key and warn
settings.SECRET_KEY = secrets.token_hex(32) settings.SECRET_KEY = secrets.token_hex(32)
# Call warnings.warn()
warnings.warn( warnings.warn(
# Literal argument value
"SECRET_KEY was not set — using an auto-generated ephemeral key. " "SECRET_KEY was not set — using an auto-generated ephemeral key. "
# Literal argument value
"JWT tokens will be invalidated on every restart. " "JWT tokens will be invalidated on every restart. "
# Literal argument value
"Set SECRET_KEY in your environment for persistent sessions.", "Set SECRET_KEY in your environment for persistent sessions.",
# Keyword argument: stacklevel
stacklevel=2, stacklevel=2,
) )
@@ -116,12 +180,16 @@ if settings.SECRET_KEY in _UNSAFE_SECRETS:
# SEC-002: Reject default credentials in production # SEC-002: Reject default credentials in production
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
if _is_production: if _is_production:
# Assign _DEFAULT_CREDS = {
_DEFAULT_CREDS = { _DEFAULT_CREDS = {
("MINIO_ACCESS_KEY", settings.MINIO_ACCESS_KEY, "minioadmin"), ("MINIO_ACCESS_KEY", settings.MINIO_ACCESS_KEY, "minioadmin"),
("MINIO_SECRET_KEY", settings.MINIO_SECRET_KEY, "minioadmin"), ("MINIO_SECRET_KEY", settings.MINIO_SECRET_KEY, "minioadmin"),
} }
# Iterate over _DEFAULT_CREDS
for name, current, default in _DEFAULT_CREDS: for name, current, default in _DEFAULT_CREDS:
# Check: current == default
if current == default: if current == default:
# Raise RuntimeError
raise RuntimeError( raise RuntimeError(
f"CRITICAL: {name} is using the default value '{default}'. " f"CRITICAL: {name} is using the default value '{default}'. "
f"Set a strong value via the {name} environment variable " f"Set a strong value via the {name} environment variable "
+106 -10
View File
@@ -1,68 +1,164 @@
from sqlalchemy import create_engine """Database engine and session management for the Aegis platform.
from sqlalchemy.orm import sessionmaker, declarative_base
The engine and session factory are created lazily so that tests can override
``DATABASE_URL`` via environment variables before any import triggers real
PostgreSQL engine creation (which requires psycopg2).
"""
# Import Generator from collections.abc
from collections.abc import Generator
# Import create_engine from sqlalchemy
from sqlalchemy import create_engine
# Import Engine from sqlalchemy.engine
from sqlalchemy.engine import Engine
# Import Session, declarative_base, sessionmaker from sqlalchemy.orm
from sqlalchemy.orm import Session, declarative_base, sessionmaker
# Assign Base = declarative_base()
Base = declarative_base() Base = declarative_base()
# Engine and session factory are created lazily so that tests can # Engine and session factory are created lazily so that tests can
# override DATABASE_URL via environment *before* any import triggers # override DATABASE_URL via environment *before* any import triggers
# the real PostgreSQL engine creation (which requires psycopg2). # the real PostgreSQL engine creation (which requires psycopg2).
_engine = None _engine = None
# Assign _SessionLocal = None
_SessionLocal = None _SessionLocal = None
def _get_engine(): # Define function _get_engine
def _get_engine() -> Engine:
"""Return the shared SQLAlchemy engine, creating it on first call.
Returns:
Engine: Configured SQLAlchemy engine for the application database.
"""
# Declare global variable
global _engine global _engine
# Check: _engine is None
if _engine is None: if _engine is None:
# Import settings from app.config
from app.config import settings from app.config import settings
# Assign url = settings.DATABASE_URL
url = settings.DATABASE_URL url = settings.DATABASE_URL
# Assign kwargs = {}
kwargs: dict = {} kwargs: dict = {}
# Check: url.startswith("postgresql")
if url.startswith("postgresql"): if url.startswith("postgresql"):
# Call kwargs.update()
kwargs.update( kwargs.update(
# Keyword argument: pool_size
pool_size=20, pool_size=20,
# Keyword argument: max_overflow
max_overflow=10, max_overflow=10,
# Keyword argument: pool_recycle
pool_recycle=3600, pool_recycle=3600,
# Keyword argument: pool_pre_ping
pool_pre_ping=True, pool_pre_ping=True,
) )
# Assign _engine = create_engine(url, **kwargs)
_engine = create_engine(url, **kwargs) _engine = create_engine(url, **kwargs)
# Return _engine
return _engine return _engine
def _get_session_factory(): # Define function _get_session_factory
def _get_session_factory() -> sessionmaker:
"""Return the shared sessionmaker, creating it on first call.
Returns:
sessionmaker: Configured sessionmaker bound to the application engine.
"""
# Declare global variable
global _SessionLocal global _SessionLocal
# Check: _SessionLocal is None
if _SessionLocal is None: if _SessionLocal is None:
# Assign _SessionLocal = sessionmaker(
_SessionLocal = sessionmaker( _SessionLocal = sessionmaker(
# Keyword argument: autocommit
autocommit=False, autoflush=False, bind=_get_engine() autocommit=False, autoflush=False, bind=_get_engine()
) )
# Return _SessionLocal
return _SessionLocal return _SessionLocal
# Define class _LazySessionLocal
class _LazySessionLocal: class _LazySessionLocal:
"""Proxy so ``SessionLocal()`` keeps working as before but the real """Proxy so ``SessionLocal()`` keeps working as before but the real sessionmaker is only created on first call."""
sessionmaker is only created on first call."""
def __call__(self, *args, **kwargs): # Define function __call__
def __call__(self, *args: object, **kwargs: object) -> Session:
"""Create and return a new database session.
Args:
*args (object): Positional arguments forwarded to the sessionmaker.
**kwargs (object): Keyword arguments forwarded to the sessionmaker.
Returns:
Session: A new SQLAlchemy database session.
"""
# Return _get_session_factory()(*args, **kwargs)
return _get_session_factory()(*args, **kwargs) return _get_session_factory()(*args, **kwargs)
def __getattr__(self, name): # Define function __getattr__
def __getattr__(self, name: str) -> object:
"""Delegate attribute access to the underlying sessionmaker.
Args:
name (str): Attribute name to look up on the sessionmaker.
Returns:
object: The attribute value from the underlying sessionmaker.
"""
# Return getattr(_get_session_factory(), name)
return getattr(_get_session_factory(), name) return getattr(_get_session_factory(), name)
# Assign SessionLocal = _LazySessionLocal()
SessionLocal = _LazySessionLocal() SessionLocal = _LazySessionLocal()
# Define class _EngineProxy
class _EngineProxy: class _EngineProxy:
"""Thin proxy so ``from app.database import engine`` still works.""" """Thin proxy so ``from app.database import engine`` still works."""
def __getattr__(self, name):
# Define function __getattr__
def __getattr__(self, name: str) -> object:
"""Delegate attribute access to the lazily-created engine.
Args:
name (str): Attribute name to look up on the real engine.
Returns:
object: The attribute value from the underlying SQLAlchemy engine.
"""
# Return getattr(_get_engine(), name)
return getattr(_get_engine(), name) return getattr(_get_engine(), name)
# Assign engine = _EngineProxy() # type: ignore[assignment]
engine = _EngineProxy() # type: ignore[assignment] engine = _EngineProxy() # type: ignore[assignment]
def get_db(): # Define function get_db
def get_db() -> Generator[Session, None, None]:
"""Yield a database session and close it when the request is done.
Intended for use as a FastAPI dependency.
Yields:
Session: An active SQLAlchemy session for the current request.
"""
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Yield db
yield db yield db
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
+1
View File
@@ -0,0 +1 @@
"""FastAPI dependency injection helpers for auth, DB, and shared state."""
+72 -4
View File
@@ -1,5 +1,4 @@
""" """Authentication and RBAC dependencies for FastAPI.
Authentication and RBAC dependencies for FastAPI.
Provides: Provides:
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or - ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
@@ -8,16 +7,34 @@ Provides:
(admins always pass). (admins always pass).
""" """
# Import Callable from collections.abc
from collections.abc import Callable
# Import Optional from typing
from typing import Optional from typing import Optional
# Import Cookie, Depends, HTTPException, status from fastapi
from fastapi import Cookie, Depends, HTTPException, status from fastapi import Cookie, Depends, HTTPException, status
# Import OAuth2PasswordBearer from fastapi.security
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
# Import JWTError, jwt from jose
from jose import JWTError, jwt from jose import JWTError, jwt
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import auth as auth_lib from app
from app import auth as auth_lib from app import auth as auth_lib
# Import settings from app.config
from app.config import settings from app.config import settings
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -35,8 +52,11 @@ _COOKIE_NAME = "aegis_token"
async def get_current_user( async def get_current_user(
# Entry: aegis_token
aegis_token: Optional[str] = Cookie(None), aegis_token: Optional[str] = Cookie(None),
# Entry: bearer_token
bearer_token: Optional[str] = Depends(oauth2_scheme), bearer_token: Optional[str] = Depends(oauth2_scheme),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> User: ) -> User:
"""Decode the JWT, look up the user in *db*, and return it. """Decode the JWT, look up the user in *db*, and return it.
@@ -52,42 +72,66 @@ async def get_current_user(
- the ``sub`` claim is missing, or - the ``sub`` claim is missing, or
- no matching active user exists in the database. - no matching active user exists in the database.
""" """
# Assign credentials_exception = HTTPException(
credentials_exception = HTTPException( credentials_exception = HTTPException(
# Keyword argument: status_code
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
# Keyword argument: detail
detail="Could not validate credentials", detail="Could not validate credentials",
# Keyword argument: headers
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
# Assign revoked_exception = HTTPException(
revoked_exception = HTTPException( revoked_exception = HTTPException(
# Keyword argument: status_code
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
# Keyword argument: detail
detail="Token has been revoked", detail="Token has been revoked",
# Keyword argument: headers
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
# Prefer cookie, fall back to header # Prefer cookie, fall back to header
token = aegis_token or bearer_token token = aegis_token or bearer_token
# Check: token is None
if token is None: if token is None:
# Raise credentials_exception
raise credentials_exception raise credentials_exception
# Attempt the following; catch errors below
try: try:
# Assign payload = jwt.decode(
payload = jwt.decode( payload = jwt.decode(
token, token,
settings.SECRET_KEY, settings.SECRET_KEY,
# Keyword argument: algorithms
algorithms=[settings.ALGORITHM], algorithms=[settings.ALGORITHM],
) )
# Assign username = payload.get("sub")
username: str | None = payload.get("sub") username: str | None = payload.get("sub")
# Check: username is None
if username is None: if username is None:
# Raise credentials_exception
raise credentials_exception raise credentials_exception
# Check token blacklist (revoked tokens) # Check token blacklist (revoked tokens)
jti: str | None = payload.get("jti") jti: str | None = payload.get("jti")
# Check: jti and auth_lib.is_token_blacklisted(jti)
if jti and auth_lib.is_token_blacklisted(jti): if jti and auth_lib.is_token_blacklisted(jti):
# Raise revoked_exception
raise revoked_exception raise revoked_exception
# Handle JWTError
except JWTError: except JWTError:
# Raise credentials_exception
raise credentials_exception raise credentials_exception
# Assign user = db.query(User).filter(User.username == username).first()
user = db.query(User).filter(User.username == username).first() user = db.query(User).filter(User.username == username).first()
# Check: user is None or not user.is_active
if user is None or not user.is_active: if user is None or not user.is_active:
# Raise credentials_exception
raise credentials_exception raise credentials_exception
# Return user
return user return user
@@ -97,6 +141,7 @@ async def get_current_user(
async def require_password_changed( async def require_password_changed(
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> User: ) -> User:
"""Block all requests when the user still needs to change their password. """Block all requests when the user still needs to change their password.
@@ -104,15 +149,21 @@ async def require_password_changed(
Only ``/auth/change-password`` and ``/auth/me`` are exempt — those Only ``/auth/change-password`` and ``/auth/me`` are exempt — those
endpoints do **not** depend on this function. endpoints do **not** depend on this function.
""" """
# Check: getattr(current_user, "must_change_password", False)
if getattr(current_user, "must_change_password", False): if getattr(current_user, "must_change_password", False):
# Raise HTTPException
raise HTTPException( raise HTTPException(
# Keyword argument: status_code
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
# Keyword argument: detail
detail="PASSWORD_CHANGE_REQUIRED", detail="PASSWORD_CHANGE_REQUIRED",
) )
# Return current_user
return current_user return current_user
def require_role(required_role: str): # Define function require_role
def require_role(required_role: str) -> Callable[..., object]:
"""Return a FastAPI dependency that enforces *required_role*. """Return a FastAPI dependency that enforces *required_role*.
The dependency allows the request to proceed when The dependency allows the request to proceed when
@@ -120,20 +171,29 @@ def require_role(required_role: str):
Otherwise it raises :class:`~fastapi.HTTPException` **403**. Otherwise it raises :class:`~fastapi.HTTPException` **403**.
""" """
# Define async function role_checker
async def role_checker( async def role_checker(
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> User: ) -> User:
# Check: current_user.role != required_role and current_user.role != "admin"
if current_user.role != required_role and current_user.role != "admin": if current_user.role != required_role and current_user.role != "admin":
# Raise HTTPException
raise HTTPException( raise HTTPException(
# Keyword argument: status_code
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
# Keyword argument: detail
detail="Not enough permissions", detail="Not enough permissions",
) )
# Return current_user
return current_user return current_user
# Return role_checker
return role_checker return role_checker
def require_any_role(*roles: str): # Define function require_any_role
def require_any_role(*roles: str) -> Callable[..., object]:
"""Return a FastAPI dependency that enforces **any** of the given *roles*. """Return a FastAPI dependency that enforces **any** of the given *roles*.
Admins always pass. Usage example:: Admins always pass. Usage example::
@@ -141,14 +201,22 @@ def require_any_role(*roles: str):
@router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))]) @router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))])
""" """
# Define async function role_checker
async def role_checker( async def role_checker(
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> User: ) -> User:
# Check: current_user.role != "admin" and current_user.role not in roles
if current_user.role != "admin" and current_user.role not in roles: if current_user.role != "admin" and current_user.role not in roles:
# Raise HTTPException
raise HTTPException( raise HTTPException(
# Keyword argument: status_code
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
# Keyword argument: detail
detail="Not enough permissions", detail="Not enough permissions",
) )
# Return current_user
return current_user return current_user
# Return role_checker
return role_checker return role_checker
+14
View File
@@ -4,27 +4,41 @@ Wiring lives ONLY in the presentation layer — use cases and services
never know which concrete repository implementation they receive. never know which concrete repository implementation they receive.
""" """
# Import Depends from fastapi
from fastapi import Depends from fastapi import Depends
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
from app.infrastructure.persistence.repositories.sa_technique_repository import ( from app.infrastructure.persistence.repositories.sa_technique_repository import (
SATechniqueRepository, SATechniqueRepository,
) )
# Import from app.infrastructure.persistence.repositories.sa_test_repository
from app.infrastructure.persistence.repositories.sa_test_repository import ( from app.infrastructure.persistence.repositories.sa_test_repository import (
SATestRepository, SATestRepository,
) )
# Define function get_technique_repository
def get_technique_repository( def get_technique_repository(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> SATechniqueRepository: ) -> SATechniqueRepository:
"""Provide a TechniqueRepository backed by the current DB session.""" """Provide a TechniqueRepository backed by the current DB session."""
# Return SATechniqueRepository(db)
return SATechniqueRepository(db) return SATechniqueRepository(db)
# Define function get_test_repository
def get_test_repository( def get_test_repository(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> SATestRepository: ) -> SATestRepository:
"""Provide a TestRepository backed by the current DB session.""" """Provide a TestRepository backed by the current DB session."""
# Return SATestRepository(db)
return SATestRepository(db) return SATestRepository(db)
+1
View File
@@ -0,0 +1 @@
"""Domain layer — entities, value objects, errors, and repository ports."""
+16
View File
@@ -1,18 +1,34 @@
"""Domain entity classes representing core business objects."""
# Import CampaignEntity from app.domain.entities.campaign
from app.domain.entities.campaign import CampaignEntity from app.domain.entities.campaign import CampaignEntity
# Import from app.domain.entities.compliance
from app.domain.entities.compliance import ( from app.domain.entities.compliance import (
ComplianceControlEntity, ComplianceControlEntity,
ComplianceFrameworkEntity, ComplianceFrameworkEntity,
ControlCoverageStatus, ControlCoverageStatus,
) )
# Import TechniqueEntity from app.domain.entities.technique
from app.domain.entities.technique import TechniqueEntity from app.domain.entities.technique import TechniqueEntity
# Import ThreatActorEntity, ThreatActorTechniqueRef from app.domain.entities.threat_actor
from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef
# Assign __all__ = [
__all__ = [ __all__ = [
# Literal argument value
"CampaignEntity", "CampaignEntity",
# Literal argument value
"ComplianceControlEntity", "ComplianceControlEntity",
# Literal argument value
"ComplianceFrameworkEntity", "ComplianceFrameworkEntity",
# Literal argument value
"ControlCoverageStatus", "ControlCoverageStatus",
# Literal argument value
"TechniqueEntity", "TechniqueEntity",
# Literal argument value
"ThreatActorEntity", "ThreatActorEntity",
# Literal argument value
"ThreatActorTechniqueRef", "ThreatActorTechniqueRef",
] ]
+121 -5
View File
@@ -3,30 +3,59 @@
Pure domain logic — no framework imports. Pure domain logic — no framework imports.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import enum
import enum import enum
import uuid
from dataclasses import dataclass, field
from typing import Any
# Import uuid
import uuid
# Import dataclass, field from dataclasses
from dataclasses import dataclass, field
# Import TYPE_CHECKING from typing
from typing import TYPE_CHECKING
# Import BusinessRuleViolation, InvalidStateTransition from app.domain.errors
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
# Check: TYPE_CHECKING
if TYPE_CHECKING:
# Import Campaign as CampaignORM from app.models.campaign
from app.models.campaign import Campaign as CampaignORM
# Define class CampaignStatus
class CampaignStatus(str, enum.Enum): class CampaignStatus(str, enum.Enum):
"""Lifecycle states for a campaign."""
# Assign draft = "draft"
draft = "draft" draft = "draft"
# Assign active = "active"
active = "active" active = "active"
# Assign completed = "completed"
completed = "completed" completed = "completed"
# Assign archived = "archived"
archived = "archived" archived = "archived"
# Define class CampaignType
class CampaignType(str, enum.Enum): class CampaignType(str, enum.Enum):
"""Classification of the campaign's testing methodology."""
# Assign custom = "custom"
custom = "custom" custom = "custom"
# Assign apt_emulation = "apt_emulation"
apt_emulation = "apt_emulation" apt_emulation = "apt_emulation"
# Assign kill_chain = "kill_chain"
kill_chain = "kill_chain" kill_chain = "kill_chain"
# Assign compliance = "compliance"
compliance = "compliance" compliance = "compliance"
# Assign VALID_TRANSITIONS = {
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = { VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
CampaignStatus.draft: [CampaignStatus.active], CampaignStatus.draft: [CampaignStatus.active],
CampaignStatus.active: [CampaignStatus.completed], CampaignStatus.active: [CampaignStatus.completed],
@@ -35,69 +64,156 @@ VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
} }
# Apply the @dataclass decorator
@dataclass @dataclass
# Define class CampaignEntity
class CampaignEntity: class CampaignEntity:
"""Pure domain representation of a security testing campaign.
Owns all lifecycle state-machine logic for campaign activation,
completion, and archival.
"""
# name: str
name: str name: str
# Assign type = CampaignType.custom
type: CampaignType = CampaignType.custom type: CampaignType = CampaignType.custom
# Assign status = CampaignStatus.draft
status: CampaignStatus = CampaignStatus.draft status: CampaignStatus = CampaignStatus.draft
# Assign id = None
id: uuid.UUID | None = None id: uuid.UUID | None = None
# Assign description = None
description: str | None = None description: str | None = None
# Assign threat_actor_id = None
threat_actor_id: uuid.UUID | None = None threat_actor_id: uuid.UUID | None = None
# Assign created_by = None
created_by: uuid.UUID | None = None created_by: uuid.UUID | None = None
# Assign target_platform = None
target_platform: str | None = None target_platform: str | None = None
# Assign tags = field(default_factory=list)
tags: list[str] = field(default_factory=list) tags: list[str] = field(default_factory=list)
# Assign test_count = 0
test_count: int = 0 test_count: int = 0
# Define function can_transition_to
def can_transition_to(self, target: CampaignStatus) -> bool: def can_transition_to(self, target: CampaignStatus) -> bool:
"""Check whether transitioning from the current status to *target* is valid.
Args:
target (CampaignStatus): The desired next status.
Returns:
bool: True if the transition is allowed, False otherwise.
"""
# Return target in VALID_TRANSITIONS.get(self.status, [])
return target in VALID_TRANSITIONS.get(self.status, []) return target in VALID_TRANSITIONS.get(self.status, [])
# Define function activate
def activate(self) -> None: def activate(self) -> None:
"""Transition the campaign from ``draft`` to ``active``.
Returns:
None
"""
# Check: not self.can_transition_to(CampaignStatus.active)
if not self.can_transition_to(CampaignStatus.active): if not self.can_transition_to(CampaignStatus.active):
# Raise InvalidStateTransition
raise InvalidStateTransition( raise InvalidStateTransition(
self.status.value, CampaignStatus.active.value, self.status.value, CampaignStatus.active.value,
[s.value for s in VALID_TRANSITIONS[self.status]], [s.value for s in VALID_TRANSITIONS[self.status]],
) )
# Check: self.test_count == 0
if self.test_count == 0: if self.test_count == 0:
# Raise BusinessRuleViolation
raise BusinessRuleViolation( raise BusinessRuleViolation(
# Literal argument value
"Campaign must have at least one test to activate" "Campaign must have at least one test to activate"
) )
# Assign self.status = CampaignStatus.active
self.status = CampaignStatus.active self.status = CampaignStatus.active
# Define function complete
def complete(self) -> None: def complete(self) -> None:
"""Transition the campaign from ``active`` to ``completed``.
Returns:
None
"""
# Check: not self.can_transition_to(CampaignStatus.completed)
if not self.can_transition_to(CampaignStatus.completed): if not self.can_transition_to(CampaignStatus.completed):
# Raise InvalidStateTransition
raise InvalidStateTransition( raise InvalidStateTransition(
self.status.value, CampaignStatus.completed.value, self.status.value, CampaignStatus.completed.value,
[s.value for s in VALID_TRANSITIONS[self.status]], [s.value for s in VALID_TRANSITIONS[self.status]],
) )
# Assign self.status = CampaignStatus.completed
self.status = CampaignStatus.completed self.status = CampaignStatus.completed
# Define function archive
def archive(self) -> None: def archive(self) -> None:
"""Transition the campaign from ``completed`` to ``archived``.
Returns:
None
"""
# Check: not self.can_transition_to(CampaignStatus.archived)
if not self.can_transition_to(CampaignStatus.archived): if not self.can_transition_to(CampaignStatus.archived):
# Raise InvalidStateTransition
raise InvalidStateTransition( raise InvalidStateTransition(
self.status.value, CampaignStatus.archived.value, self.status.value, CampaignStatus.archived.value,
[s.value for s in VALID_TRANSITIONS[self.status]], [s.value for s in VALID_TRANSITIONS[self.status]],
) )
# Assign self.status = CampaignStatus.archived
self.status = CampaignStatus.archived self.status = CampaignStatus.archived
# Define function ensure_modifiable
def ensure_modifiable(self) -> None: def ensure_modifiable(self) -> None:
"""Raise BusinessRuleViolation if the campaign is not in a modifiable state.
Returns:
None
"""
# Check: self.status not in (CampaignStatus.draft, CampaignStatus.active)
if self.status not in (CampaignStatus.draft, CampaignStatus.active): if self.status not in (CampaignStatus.draft, CampaignStatus.active):
# Raise BusinessRuleViolation
raise BusinessRuleViolation( raise BusinessRuleViolation(
f"Cannot modify campaign in '{self.status.value}' state" f"Cannot modify campaign in '{self.status.value}' state"
) )
# Apply the @classmethod decorator
@classmethod @classmethod
def from_orm(cls, orm: Any) -> CampaignEntity: # Define function from_orm
"""Build a CampaignEntity from a SQLAlchemy Campaign model.""" def from_orm(cls, orm: CampaignORM) -> CampaignEntity:
"""Build a CampaignEntity from a SQLAlchemy Campaign model.
Args:
orm (CampaignORM): The SQLAlchemy Campaign ORM model instance.
Returns:
CampaignEntity: A fully populated domain entity reflecting the ORM state.
"""
# Assign test_count = len(getattr(orm, "campaign_tests", None) or [])
test_count = len(getattr(orm, "campaign_tests", None) or []) test_count = len(getattr(orm, "campaign_tests", None) or [])
# Return cls(
return cls( return cls(
# Keyword argument: id
id=orm.id, id=orm.id,
# Keyword argument: name
name=orm.name, name=orm.name,
# Keyword argument: type
type=CampaignType(orm.type) if orm.type else CampaignType.custom, type=CampaignType(orm.type) if orm.type else CampaignType.custom,
# Keyword argument: status
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft, status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
# Keyword argument: description
description=orm.description, description=orm.description,
# Keyword argument: threat_actor_id
threat_actor_id=orm.threat_actor_id, threat_actor_id=orm.threat_actor_id,
# Keyword argument: created_by
created_by=orm.created_by, created_by=orm.created_by,
# Keyword argument: target_platform
target_platform=orm.target_platform, target_platform=orm.target_platform,
# Keyword argument: tags
tags=orm.tags or [], tags=orm.tags or [],
# Keyword argument: test_count
test_count=test_count, test_count=test_count,
) )
+93
View File
@@ -3,68 +3,161 @@
Pure domain logic — no framework imports. Pure domain logic — no framework imports.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import enum
import enum import enum
# Import uuid
import uuid import uuid
# Import dataclass, field from dataclasses
from dataclasses import dataclass, field from dataclasses import dataclass, field
# Define class ControlCoverageStatus
class ControlCoverageStatus(str, enum.Enum): class ControlCoverageStatus(str, enum.Enum):
"""Computed coverage level for a single compliance control."""
# Assign covered = "covered"
covered = "covered" covered = "covered"
# Assign partially_covered = "partially_covered"
partially_covered = "partially_covered" partially_covered = "partially_covered"
# Assign not_covered = "not_covered"
not_covered = "not_covered" not_covered = "not_covered"
# Apply the @dataclass decorator
@dataclass @dataclass
# Define class ComplianceControlEntity
class ComplianceControlEntity: class ComplianceControlEntity:
"""Pure domain representation of a single compliance framework control.
Derives its coverage status from the technique statuses associated
with it via the ``technique_statuses`` list.
"""
# control_id: str
control_id: str control_id: str
# title: str
title: str title: str
# Assign id = None
id: uuid.UUID | None = None id: uuid.UUID | None = None
# Assign description = None
description: str | None = None description: str | None = None
# Assign category = None
category: str | None = None category: str | None = None
# Assign technique_statuses = field(default_factory=list)
technique_statuses: list[str] = field(default_factory=list) technique_statuses: list[str] = field(default_factory=list)
# Apply the @property decorator
@property @property
# Define function coverage_status
def coverage_status(self) -> ControlCoverageStatus: def coverage_status(self) -> ControlCoverageStatus:
"""Compute the coverage status for this control based on linked technique statuses.
Returns:
ControlCoverageStatus: ``covered`` when all techniques are covered,
``partially_covered`` when at least one is covered, and
``not_covered`` when none are covered or the control has no techniques.
"""
# Check: not self.technique_statuses
if not self.technique_statuses: if not self.technique_statuses:
# Return ControlCoverageStatus.not_covered
return ControlCoverageStatus.not_covered return ControlCoverageStatus.not_covered
# Assign covered_statuses = {"validated", "partial"}
covered_statuses = {"validated", "partial"} covered_statuses = {"validated", "partial"}
# Assign covered = [s for s in self.technique_statuses if s in covered_statuses]
covered = [s for s in self.technique_statuses if s in covered_statuses] covered = [s for s in self.technique_statuses if s in covered_statuses]
# Check: len(covered) == len(self.technique_statuses)
if len(covered) == len(self.technique_statuses): if len(covered) == len(self.technique_statuses):
# Return ControlCoverageStatus.covered
return ControlCoverageStatus.covered return ControlCoverageStatus.covered
# Alternative: len(covered) > 0
elif len(covered) > 0: elif len(covered) > 0:
# Return ControlCoverageStatus.partially_covered
return ControlCoverageStatus.partially_covered return ControlCoverageStatus.partially_covered
# Return ControlCoverageStatus.not_covered
return ControlCoverageStatus.not_covered return ControlCoverageStatus.not_covered
# Apply the @dataclass decorator
@dataclass @dataclass
# Define class ComplianceFrameworkEntity
class ComplianceFrameworkEntity: class ComplianceFrameworkEntity:
"""Pure domain representation of a compliance framework (e.g. NIST 800-53, PCI-DSS).
Aggregates a collection of controls and provides aggregate coverage statistics.
"""
# name: str
name: str name: str
# Assign id = None
id: uuid.UUID | None = None id: uuid.UUID | None = None
# Assign version = None
version: str | None = None version: str | None = None
# Assign description = None
description: str | None = None description: str | None = None
# Assign is_active = True
is_active: bool = True is_active: bool = True
# Assign controls = field(default_factory=list)
controls: list[ComplianceControlEntity] = field(default_factory=list) controls: list[ComplianceControlEntity] = field(default_factory=list)
# Apply the @property decorator
@property @property
# Define function total_controls
def total_controls(self) -> int: def total_controls(self) -> int:
"""Return the total number of controls in this framework.
Returns:
int: Count of all controls regardless of coverage status.
"""
# Return len(self.controls)
return len(self.controls) return len(self.controls)
# Apply the @property decorator
@property @property
# Define function covered_controls
def covered_controls(self) -> int: def covered_controls(self) -> int:
"""Return the number of fully covered controls in this framework.
Returns:
int: Count of controls with ``ControlCoverageStatus.covered`` status.
"""
# Return sum(
return sum( return sum(
# Literal argument value
1 for c in self.controls 1 for c in self.controls
if c.coverage_status == ControlCoverageStatus.covered if c.coverage_status == ControlCoverageStatus.covered
) )
# Apply the @property decorator
@property @property
# Define function coverage_pct
def coverage_pct(self) -> float: def coverage_pct(self) -> float:
"""Return the percentage of controls that are fully covered.
Returns:
float: A value from 0.0 to 100.0, rounded to one decimal place.
Returns 0.0 when the framework has no controls.
"""
# Check: self.total_controls == 0
if self.total_controls == 0: if self.total_controls == 0:
# Return 0.0
return 0.0 return 0.0
# Return round(self.covered_controls / self.total_controls * 100, 1)
return round(self.covered_controls / self.total_controls * 100, 1) return round(self.covered_controls / self.total_controls * 100, 1)
# Define function get_gap_controls
def get_gap_controls(self) -> list[ComplianceControlEntity]: def get_gap_controls(self) -> list[ComplianceControlEntity]:
"""Return controls that are not fully covered.
Returns:
list[ComplianceControlEntity]: Controls with ``partially_covered`` or
``not_covered`` status.
"""
# Return [
return [ return [
c for c in self.controls c for c in self.controls
if c.coverage_status != ControlCoverageStatus.covered if c.coverage_status != ControlCoverageStatus.covered
+157 -11
View File
@@ -12,105 +12,211 @@ Usage::
entity.apply_to(technique_orm_model) entity.apply_to(technique_orm_model)
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import uuid
import uuid import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
# Import dataclass, field from dataclasses
from dataclasses import dataclass, field
# Import datetime from datetime
from datetime import datetime
# Import TYPE_CHECKING from typing
from typing import TYPE_CHECKING
# Import TechniqueStatus, TestResult, TestState from app.domain.enums
from app.domain.enums import TechniqueStatus, TestResult, TestState from app.domain.enums import TechniqueStatus, TestResult, TestState
# Import MitreId from app.domain.value_objects.mitre_id
from app.domain.value_objects.mitre_id import MitreId from app.domain.value_objects.mitre_id import MitreId
# Check: TYPE_CHECKING
if TYPE_CHECKING:
# Import Technique as TechniqueORM from app.models.technique
from app.models.technique import Technique as TechniqueORM
# Apply the @dataclass decorator
@dataclass(frozen=True) @dataclass(frozen=True)
# Define class _TestSnapshot
class _TestSnapshot: class _TestSnapshot:
"""Minimal read-only view of a test for status calculation.""" """Minimal read-only view of a test for status calculation."""
# state: TestState
state: TestState state: TestState
# detection_result: str | None
detection_result: str | None detection_result: str | None
# Apply the @dataclass decorator
@dataclass @dataclass
# Define class TechniqueEntity
class TechniqueEntity: class TechniqueEntity:
"""Pure domain representation of a MITRE ATT&CK technique.""" """Pure domain representation of a MITRE ATT&CK technique."""
# id: uuid.UUID
id: uuid.UUID id: uuid.UUID
# mitre_id: str
mitre_id: str mitre_id: str
# name: str
name: str name: str
# Assign tactic = None
tactic: str | None = None tactic: str | None = None
# Assign description = None
description: str | None = None description: str | None = None
# Assign platforms = field(default_factory=list)
platforms: list[str] = field(default_factory=list) platforms: list[str] = field(default_factory=list)
# Assign is_subtechnique = False
is_subtechnique: bool = False is_subtechnique: bool = False
# Assign parent_mitre_id = None
parent_mitre_id: str | None = None parent_mitre_id: str | None = None
# Assign status_global = TechniqueStatus.not_evaluated
status_global: TechniqueStatus = TechniqueStatus.not_evaluated status_global: TechniqueStatus = TechniqueStatus.not_evaluated
# Assign review_required = False
review_required: bool = False review_required: bool = False
# Assign last_review_date = None
last_review_date: datetime | None = None last_review_date: datetime | None = None
# Assign mitre_version = None
mitre_version: str | None = None mitre_version: str | None = None
# Assign mitre_last_modified = None
mitre_last_modified: datetime | None = None mitre_last_modified: datetime | None = None
# -- Factory ----------------------------------------------------------- # -- Factory -----------------------------------------------------------
@classmethod @classmethod
# Define function create
def create( def create(
cls, cls,
*, *,
# Entry: mitre_id
mitre_id: str, mitre_id: str,
# Entry: name
name: str, name: str,
# Entry: tactic
tactic: str | None = None, tactic: str | None = None,
# Entry: description
description: str | None = None, description: str | None = None,
# Entry: platforms
platforms: list[str] | None = None, platforms: list[str] | None = None,
) -> TechniqueEntity: ) -> TechniqueEntity:
"""Create a new technique, validating the MITRE ID format.""" """Create a new technique, validating the MITRE ID format.
Args:
mitre_id (str): MITRE ATT&CK identifier (e.g. ``"T1059"`` or ``"T1059.001"``).
name (str): Human-readable name of the technique.
tactic (str | None): MITRE tactic category the technique belongs to.
description (str | None): Optional free-text description.
platforms (list[str] | None): List of platform strings the technique applies to.
Returns:
TechniqueEntity: A new entity with a freshly generated UUID and
``status_global`` set to ``not_evaluated``.
"""
# Assign validated_id = MitreId(mitre_id)
validated_id = MitreId(mitre_id) validated_id = MitreId(mitre_id)
# Return cls(
return cls( return cls(
# Keyword argument: id
id=uuid.uuid4(), id=uuid.uuid4(),
# Keyword argument: mitre_id
mitre_id=validated_id.value, mitre_id=validated_id.value,
# Keyword argument: name
name=name, name=name,
# Keyword argument: tactic
tactic=tactic, tactic=tactic,
# Keyword argument: description
description=description, description=description,
# Keyword argument: platforms
platforms=platforms or [], platforms=platforms or [],
# Keyword argument: is_subtechnique
is_subtechnique=validated_id.is_subtechnique, is_subtechnique=validated_id.is_subtechnique,
# Keyword argument: parent_mitre_id
parent_mitre_id=validated_id.parent_id, parent_mitre_id=validated_id.parent_id,
# Keyword argument: status_global
status_global=TechniqueStatus.not_evaluated, status_global=TechniqueStatus.not_evaluated,
) )
# Apply the @classmethod decorator
@classmethod @classmethod
def from_orm(cls, model: Any) -> TechniqueEntity: # Define function from_orm
"""Build a TechniqueEntity from a SQLAlchemy Technique model.""" def from_orm(cls, model: TechniqueORM) -> TechniqueEntity:
"""Build a TechniqueEntity from a SQLAlchemy Technique model.
Args:
model (TechniqueORM): The ORM model instance to convert.
Returns:
TechniqueEntity: A fully populated domain entity reflecting the ORM state.
"""
# Assign raw_status = model.status_global
raw_status = model.status_global raw_status = model.status_global
# Check: raw_status is None
if raw_status is None: if raw_status is None:
# Assign status = TechniqueStatus.not_evaluated
status = TechniqueStatus.not_evaluated status = TechniqueStatus.not_evaluated
# Alternative: isinstance(raw_status, TechniqueStatus)
elif isinstance(raw_status, TechniqueStatus): elif isinstance(raw_status, TechniqueStatus):
# Assign status = raw_status
status = raw_status status = raw_status
# Fallback: handle remaining cases
else: else:
# Assign status = TechniqueStatus(raw_status)
status = TechniqueStatus(raw_status) status = TechniqueStatus(raw_status)
# Return cls(
return cls( return cls(
# Keyword argument: id
id=model.id, id=model.id,
# Keyword argument: mitre_id
mitre_id=model.mitre_id, mitre_id=model.mitre_id,
# Keyword argument: name
name=model.name, name=model.name,
# Keyword argument: tactic
tactic=model.tactic, tactic=model.tactic,
# Keyword argument: description
description=model.description, description=model.description,
# Keyword argument: platforms
platforms=model.platforms or [], platforms=model.platforms or [],
# Keyword argument: is_subtechnique
is_subtechnique=model.is_subtechnique or False, is_subtechnique=model.is_subtechnique or False,
# Keyword argument: parent_mitre_id
parent_mitre_id=model.parent_mitre_id, parent_mitre_id=model.parent_mitre_id,
# Keyword argument: status_global
status_global=status, status_global=status,
# Keyword argument: review_required
review_required=model.review_required or False, review_required=model.review_required or False,
# Keyword argument: last_review_date
last_review_date=model.last_review_date, last_review_date=model.last_review_date,
# Keyword argument: mitre_version
mitre_version=getattr(model, "mitre_version", None), mitre_version=getattr(model, "mitre_version", None),
# Keyword argument: mitre_last_modified
mitre_last_modified=getattr(model, "mitre_last_modified", None), mitre_last_modified=getattr(model, "mitre_last_modified", None),
) )
def apply_to(self, model: Any) -> None: # Define function apply_to
"""Copy mutable fields back onto the ORM model.""" def apply_to(self, model: TechniqueORM) -> None:
"""Copy mutable fields back onto the ORM model.
Args:
model (TechniqueORM): The ORM model to update in-place.
Returns:
None
"""
# Assign model.status_global = self.status_global
model.status_global = self.status_global model.status_global = self.status_global
# Assign model.review_required = self.review_required
model.review_required = self.review_required model.review_required = self.review_required
# Assign model.last_review_date = self.last_review_date
model.last_review_date = self.last_review_date model.last_review_date = self.last_review_date
# -- Business logic ---------------------------------------------------- # -- Business logic ----------------------------------------------------
def recalculate_status( def recalculate_status(
self, self,
# Entry: test_snapshots
test_snapshots: list[tuple[str, str | None]], test_snapshots: list[tuple[str, str | None]],
) -> TechniqueStatus: ) -> TechniqueStatus:
"""Recompute ``status_global`` from a list of (state, detection_result) pairs. """Recompute ``status_global`` from a list of (state, detection_result) pairs.
@@ -124,41 +230,81 @@ class TechniqueEntity:
3. Some validated, others in progress -> partial 3. Some validated, others in progress -> partial
4. All in intermediate states -> in_progress 4. All in intermediate states -> in_progress
Returns the new status (also set on the entity). Args:
test_snapshots (list[tuple[str, str | None]]): Each element is a
``(state, detection_result)`` pair where *state* is a
:class:`TestState` value string and *detection_result* is a
:class:`TestResult` value string or ``None``.
Returns:
TechniqueStatus: The newly computed status, which is also stored on
the entity's ``status_global`` field.
""" """
# Assign tests = [
tests = [ tests = [
_TestSnapshot( _TestSnapshot(
# Keyword argument: state
state=s if isinstance(s, TestState) else TestState(s), state=s if isinstance(s, TestState) else TestState(s),
# Keyword argument: detection_result
detection_result=dr, detection_result=dr,
) )
for s, dr in test_snapshots for s, dr in test_snapshots
] ]
# Check: not tests
if not tests: if not tests:
# Assign self.status_global = TechniqueStatus.not_evaluated
self.status_global = TechniqueStatus.not_evaluated self.status_global = TechniqueStatus.not_evaluated
# Alternative: all(t.state == TestState.validated for t in tests)
elif all(t.state == TestState.validated for t in tests): elif all(t.state == TestState.validated for t in tests):
# Assign results = [t.detection_result for t in tests if t.detection_result]
results = [t.detection_result for t in tests if t.detection_result] results = [t.detection_result for t in tests if t.detection_result]
# Check: results and all(r == TestResult.detected or r == "detected" for r i...
if results and all(r == TestResult.detected or r == "detected" for r in results): if results and all(r == TestResult.detected or r == "detected" for r in results):
# Assign self.status_global = TechniqueStatus.validated
self.status_global = TechniqueStatus.validated self.status_global = TechniqueStatus.validated
# elif any(
elif any( elif any(
# Keyword argument: r
r == TestResult.partially_detected or r == "partially_detected" r == TestResult.partially_detected or r == "partially_detected"
for r in results for r in results
): ):
# Assign self.status_global = TechniqueStatus.partial
self.status_global = TechniqueStatus.partial self.status_global = TechniqueStatus.partial
# Fallback: handle remaining cases
else: else:
# Assign self.status_global = TechniqueStatus.not_covered
self.status_global = TechniqueStatus.not_covered self.status_global = TechniqueStatus.not_covered
# Alternative: any(t.state == TestState.validated for t in tests)
elif any(t.state == TestState.validated for t in tests): elif any(t.state == TestState.validated for t in tests):
# Assign self.status_global = TechniqueStatus.partial
self.status_global = TechniqueStatus.partial self.status_global = TechniqueStatus.partial
# Fallback: handle remaining cases
else: else:
# Assign self.status_global = TechniqueStatus.in_progress
self.status_global = TechniqueStatus.in_progress self.status_global = TechniqueStatus.in_progress
# Return self.status_global
return self.status_global return self.status_global
# Define function mark_reviewed
def mark_reviewed(self) -> None: def mark_reviewed(self) -> None:
"""Mark the technique as reviewed, clearing the review flag.""" """Mark the technique as reviewed, clearing the review flag.
Returns:
None
"""
# Assign self.review_required = False
self.review_required = False self.review_required = False
# Assign self.last_review_date = datetime.utcnow()
self.last_review_date = datetime.utcnow() self.last_review_date = datetime.utcnow()
# Define function flag_for_review
def flag_for_review(self) -> None: def flag_for_review(self) -> None:
"""Flag the technique as needing review.""" """Flag the technique as needing review.
Returns:
None
"""
# Assign self.review_required = True
self.review_required = True self.review_required = True
+112 -2
View File
@@ -3,94 +3,204 @@
Pure domain logic — no framework imports. Pure domain logic — no framework imports.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import uuid
import uuid import uuid
# Import dataclass, field from dataclasses
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any
# Import TYPE_CHECKING from typing
from typing import TYPE_CHECKING
# Check: TYPE_CHECKING
if TYPE_CHECKING:
# Import ThreatActor as ThreatActorORM from app.models.threat_actor
from app.models.threat_actor import ThreatActor as ThreatActorORM
# Apply the @dataclass decorator
@dataclass @dataclass
# Define class ThreatActorTechniqueRef
class ThreatActorTechniqueRef: class ThreatActorTechniqueRef:
"""Lightweight reference to a technique used by an actor.""" """Lightweight reference to a technique used by an actor."""
# technique_id: uuid.UUID
technique_id: uuid.UUID technique_id: uuid.UUID
# Assign mitre_id = None
mitre_id: str | None = None mitre_id: str | None = None
# Assign name = None
name: str | None = None name: str | None = None
# Assign status = None
status: str | None = None status: str | None = None
# Assign usage_description = None
usage_description: str | None = None usage_description: str | None = None
# Apply the @dataclass decorator
@dataclass @dataclass
# Define class ThreatActorEntity
class ThreatActorEntity: class ThreatActorEntity:
"""Pure domain representation of a MITRE ATT&CK threat actor (group).
Aggregates references to the techniques the actor is known to use and
provides coverage analysis properties.
"""
# name: str
name: str name: str
# Assign id = None
id: uuid.UUID | None = None id: uuid.UUID | None = None
# Assign mitre_id = None
mitre_id: str | None = None mitre_id: str | None = None
# Assign aliases = field(default_factory=list)
aliases: list[str] = field(default_factory=list) aliases: list[str] = field(default_factory=list)
# Assign description = None
description: str | None = None description: str | None = None
# Assign country = None
country: str | None = None country: str | None = None
# Assign target_sectors = field(default_factory=list)
target_sectors: list[str] = field(default_factory=list) target_sectors: list[str] = field(default_factory=list)
# Assign target_regions = field(default_factory=list)
target_regions: list[str] = field(default_factory=list) target_regions: list[str] = field(default_factory=list)
# Assign motivation = None
motivation: str | None = None motivation: str | None = None
# Assign sophistication = None
sophistication: str | None = None sophistication: str | None = None
# Assign first_seen = None
first_seen: str | None = None first_seen: str | None = None
# Assign last_seen = None
last_seen: str | None = None last_seen: str | None = None
# Assign is_active = True
is_active: bool = True is_active: bool = True
# Assign techniques = field(default_factory=list)
techniques: list[ThreatActorTechniqueRef] = field(default_factory=list) techniques: list[ThreatActorTechniqueRef] = field(default_factory=list)
# Apply the @property decorator
@property @property
# Define function technique_count
def technique_count(self) -> int: def technique_count(self) -> int:
"""Return the total number of techniques associated with this actor.
Returns:
int: Count of technique references.
"""
# Return len(self.techniques)
return len(self.techniques) return len(self.techniques)
# Apply the @property decorator
@property @property
# Define function covered_techniques
def covered_techniques(self) -> list[ThreatActorTechniqueRef]: def covered_techniques(self) -> list[ThreatActorTechniqueRef]:
"""Return technique references whose coverage status is ``validated`` or ``partial``.
Returns:
list[ThreatActorTechniqueRef]: Subset of techniques considered covered.
"""
# Return [
return [ return [
t for t in self.techniques t for t in self.techniques
if t.status in ("validated", "partial") if t.status in ("validated", "partial")
] ]
# Apply the @property decorator
@property @property
# Define function uncovered_techniques
def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]: def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]:
"""Return technique references whose coverage status is neither ``validated`` nor ``partial``.
Returns:
list[ThreatActorTechniqueRef]: Subset of techniques not yet covered.
"""
# Return [
return [ return [
t for t in self.techniques t for t in self.techniques
if t.status not in ("validated", "partial") if t.status not in ("validated", "partial")
] ]
# Apply the @property decorator
@property @property
# Define function coverage_pct
def coverage_pct(self) -> float: def coverage_pct(self) -> float:
"""Return the percentage of the actor's techniques that are covered.
Returns:
float: A value from 0.0 to 100.0, rounded to one decimal place.
Returns 0.0 when the actor has no associated techniques.
"""
# Check: not self.techniques
if not self.techniques: if not self.techniques:
# Return 0.0
return 0.0 return 0.0
# Return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1) return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
# Apply the @classmethod decorator
@classmethod @classmethod
def from_orm(cls, orm: Any) -> ThreatActorEntity: # Define function from_orm
def from_orm(cls, orm: ThreatActorORM) -> ThreatActorEntity:
"""Build a ThreatActorEntity from a SQLAlchemy ThreatActor model.
Args:
orm (ThreatActorORM): The ORM model instance to convert.
Returns:
ThreatActorEntity: A fully populated domain entity including
technique references resolved from the ORM relationship.
"""
# Assign techs = []
techs: list[ThreatActorTechniqueRef] = [] techs: list[ThreatActorTechniqueRef] = []
# Iterate over getattr(orm, "techniques", None) or []
for tat in getattr(orm, "techniques", None) or []: for tat in getattr(orm, "techniques", None) or []:
# Assign technique = getattr(tat, "technique", None)
technique = getattr(tat, "technique", None) technique = getattr(tat, "technique", None)
# Call techs.append()
techs.append(ThreatActorTechniqueRef( techs.append(ThreatActorTechniqueRef(
# Keyword argument: technique_id
technique_id=tat.technique_id, technique_id=tat.technique_id,
# Keyword argument: mitre_id
mitre_id=getattr(technique, "mitre_id", None) if technique else None, mitre_id=getattr(technique, "mitre_id", None) if technique else None,
# Keyword argument: name
name=getattr(technique, "name", None) if technique else None, name=getattr(technique, "name", None) if technique else None,
# Keyword argument: status
status=( status=(
technique.status_global.value technique.status_global.value
if technique and hasattr(technique.status_global, "value") if technique and hasattr(technique.status_global, "value")
else getattr(technique, "status_global", None) if technique else None else getattr(technique, "status_global", None) if technique else None
), ),
# Keyword argument: usage_description
usage_description=tat.usage_description, usage_description=tat.usage_description,
)) ))
# Return cls(
return cls( return cls(
# Keyword argument: id
id=orm.id, id=orm.id,
# Keyword argument: name
name=orm.name, name=orm.name,
# Keyword argument: mitre_id
mitre_id=orm.mitre_id, mitre_id=orm.mitre_id,
# Keyword argument: aliases
aliases=orm.aliases or [], aliases=orm.aliases or [],
# Keyword argument: description
description=orm.description, description=orm.description,
# Keyword argument: country
country=orm.country, country=orm.country,
# Keyword argument: target_sectors
target_sectors=orm.target_sectors or [], target_sectors=orm.target_sectors or [],
# Keyword argument: target_regions
target_regions=orm.target_regions or [], target_regions=orm.target_regions or [],
# Keyword argument: motivation
motivation=orm.motivation, motivation=orm.motivation,
# Keyword argument: sophistication
sophistication=orm.sophistication, sophistication=orm.sophistication,
# Keyword argument: first_seen
first_seen=orm.first_seen, first_seen=orm.first_seen,
# Keyword argument: last_seen
last_seen=orm.last_seen, last_seen=orm.last_seen,
# Keyword argument: is_active
is_active=orm.is_active if orm.is_active is not None else True, is_active=orm.is_active if orm.is_active is not None else True,
# Keyword argument: techniques
techniques=techs, techniques=techs,
) )
+37
View File
@@ -5,40 +5,77 @@ truth. ``models/enums.py`` re-exports them so that existing ORM code
continues to work without changes. continues to work without changes.
""" """
# Import enum
import enum import enum
# Define class TechniqueStatus
class TechniqueStatus(str, enum.Enum): class TechniqueStatus(str, enum.Enum):
"""Coverage and evaluation status for a MITRE ATT&CK technique."""
# Assign not_evaluated = "not_evaluated"
not_evaluated = "not_evaluated" not_evaluated = "not_evaluated"
# Assign in_progress = "in_progress"
in_progress = "in_progress" in_progress = "in_progress"
# Assign validated = "validated"
validated = "validated" validated = "validated"
# Assign partial = "partial"
partial = "partial" partial = "partial"
# Assign not_covered = "not_covered"
not_covered = "not_covered" not_covered = "not_covered"
# Assign review_required = "review_required"
review_required = "review_required" review_required = "review_required"
# Define class TestState
class TestState(str, enum.Enum): class TestState(str, enum.Enum):
"""Lifecycle states in the security test state machine."""
# Assign draft = "draft"
draft = "draft" draft = "draft"
# Assign red_executing = "red_executing"
red_executing = "red_executing" red_executing = "red_executing"
# Assign blue_evaluating = "blue_evaluating"
blue_evaluating = "blue_evaluating" blue_evaluating = "blue_evaluating"
# Assign in_review = "in_review"
in_review = "in_review" in_review = "in_review"
# Assign validated = "validated"
validated = "validated" validated = "validated"
# Assign rejected = "rejected"
rejected = "rejected" rejected = "rejected"
# Define class TeamSide
class TeamSide(str, enum.Enum): class TeamSide(str, enum.Enum):
"""Identifies which team (red or blue) an action belongs to."""
# Assign red = "red"
red = "red" red = "red"
# Assign blue = "blue"
blue = "blue" blue = "blue"
# Define class TestResult
class TestResult(str, enum.Enum): class TestResult(str, enum.Enum):
"""Outcome of a red-team test from a detection perspective."""
# Assign detected = "detected"
detected = "detected" detected = "detected"
# Assign not_detected = "not_detected"
not_detected = "not_detected" not_detected = "not_detected"
# Assign partially_detected = "partially_detected"
partially_detected = "partially_detected" partially_detected = "partially_detected"
# Define class DataClassification
class DataClassification(str, enum.Enum): class DataClassification(str, enum.Enum):
"""Data sensitivity classification levels for compliance and retention policies."""
# Assign public = "public"
public = "public" public = "public"
# Assign internal = "internal"
internal = "internal" internal = "internal"
# Assign sensitive = "sensitive"
sensitive = "sensitive" sensitive = "sensitive"
# Assign restricted = "restricted"
restricted = "restricted" restricted = "restricted"
+99 -3
View File
@@ -9,15 +9,30 @@ Existing code that imports from ``app.domain.exceptions`` continues to
work — that module re-exports everything defined here. work — that module re-exports everything defined here.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Define class DomainError
class DomainError(Exception): class DomainError(Exception):
"""Base for all domain errors.""" """Base for all domain errors."""
# Define function __init__
def __init__(self, message: str, *, code: str = "DOMAIN_ERROR") -> None: def __init__(self, message: str, *, code: str = "DOMAIN_ERROR") -> None:
"""Initialise the domain error with a human-readable message and error code.
Args:
message (str): Human-readable description of the error.
code (str): Machine-readable error code used by the HTTP error handler.
Returns:
None
"""
# Assign self.message = message
self.message = message self.message = message
# Assign self.code = code
self.code = code self.code = code
# Call super()
super().__init__(message) super().__init__(message)
@@ -27,18 +42,45 @@ class DomainError(Exception):
class EntityNotFoundError(DomainError): class EntityNotFoundError(DomainError):
"""A requested entity does not exist.""" """A requested entity does not exist."""
# Define function __init__
def __init__(self, entity: str, identifier: str) -> None: def __init__(self, entity: str, identifier: str) -> None:
"""Initialise an entity-not-found error.
Args:
entity (str): Name of the entity type that was not found (e.g. "Technique").
identifier (str): The ID or key used in the failed lookup.
Returns:
None
"""
# Call super()
super().__init__(f"{entity} not found: {identifier}", code="NOT_FOUND") super().__init__(f"{entity} not found: {identifier}", code="NOT_FOUND")
# Assign self.entity = entity
self.entity = entity self.entity = entity
# Assign self.identifier = identifier
self.identifier = identifier self.identifier = identifier
# Define class DuplicateEntityError
class DuplicateEntityError(DomainError): class DuplicateEntityError(DomainError):
"""Creating an entity that already exists.""" """Creating an entity that already exists."""
# Define function __init__
def __init__(self, entity: str, field: str, value: str) -> None: def __init__(self, entity: str, field: str, value: str) -> None:
"""Initialise a duplicate-entity error.
Args:
entity (str): Name of the entity type that already exists (e.g. "Campaign").
field (str): Name of the field whose value conflicts (e.g. "name").
value (str): The conflicting value that is already in use.
Returns:
None
"""
# Call super()
super().__init__( super().__init__(
f"{entity} with {field}='{value}' already exists", f"{entity} with {field}='{value}' already exists",
# Keyword argument: code
code="DUPLICATE", code="DUPLICATE",
) )
@@ -46,34 +88,67 @@ class DuplicateEntityError(DomainError):
# ── State machine ──────────────────────────────────────────────────── # ── State machine ────────────────────────────────────────────────────
class InvalidStateTransition(DomainError): class InvalidStateTransition(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
"""A state-machine transition is not allowed.""" """A state-machine transition is not allowed."""
# Define function __init__
def __init__( def __init__(
self, self,
# Entry: current_state
current_state: str, current_state: str,
# Entry: target_state
target_state: str, target_state: str,
# Entry: valid_transitions
valid_transitions: list[str] | None = None, valid_transitions: list[str] | None = None,
) -> None: ) -> None:
"""Initialise an invalid state-transition error.
Args:
current_state (str): The entity's present state (e.g. "draft").
target_state (str): The state that was illegally requested.
valid_transitions (list[str] | None): Allowed target states from the
current state; included in the error message when provided.
Returns:
None
"""
# Assign msg = f"Cannot transition from '{current_state}' to '{target_state}'"
msg = f"Cannot transition from '{current_state}' to '{target_state}'" msg = f"Cannot transition from '{current_state}' to '{target_state}'"
# Check: valid_transitions
if valid_transitions: if valid_transitions:
# Assign msg = f". Valid transitions: {valid_transitions}"
msg += f". Valid transitions: {valid_transitions}" msg += f". Valid transitions: {valid_transitions}"
# Call super()
super().__init__(msg, code="INVALID_TRANSITION") super().__init__(msg, code="INVALID_TRANSITION")
# Assign self.current_state = current_state
self.current_state = current_state self.current_state = current_state
# Assign self.target_state = target_state
self.target_state = target_state self.target_state = target_state
# Assign self.valid_transitions = valid_transitions or []
self.valid_transitions = valid_transitions or [] self.valid_transitions = valid_transitions or []
# ── Business rules ──────────────────────────────────────────────────── # ── Business rules ────────────────────────────────────────────────────
class BusinessRuleViolation(DomainError): class BusinessRuleViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
"""An operation violates a business invariant.""" """An operation violates a business invariant."""
# Define function __init__
def __init__(self, message: str) -> None: def __init__(self, message: str) -> None:
"""Initialise a business-rule violation error.
Args:
message (str): Human-readable description of the violated rule.
Returns:
None
"""
# Call super()
super().__init__(message, code="BUSINESS_RULE_VIOLATION") super().__init__(message, code="BUSINESS_RULE_VIOLATION")
# Define class InvalidOperationError
class InvalidOperationError(BusinessRuleViolation): class InvalidOperationError(BusinessRuleViolation):
"""An operation is invalid in the current context. """An operation is invalid in the current context.
@@ -81,16 +156,37 @@ class InvalidOperationError(BusinessRuleViolation):
:class:`BusinessRuleViolation` directly. :class:`BusinessRuleViolation` directly.
""" """
# Define function __init__
def __init__(self, message: str) -> None: def __init__(self, message: str) -> None:
"""Initialise an invalid-operation error.
Args:
message (str): Human-readable description of why the operation is invalid.
Returns:
None
"""
# Call super()
super().__init__(message) super().__init__(message)
# Assign self.code = "INVALID_OPERATION"
self.code = "INVALID_OPERATION" self.code = "INVALID_OPERATION"
# ── Authorization ──────────────────────────────────────────────────── # ── Authorization ────────────────────────────────────────────────────
class PermissionViolation(DomainError): class PermissionViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
"""The user lacks permissions for an action.""" """The user lacks permissions for an action."""
# Define function __init__
def __init__(self, message: str = "Insufficient permissions") -> None: def __init__(self, message: str = "Insufficient permissions") -> None:
"""Initialise a permission-violation error.
Args:
message (str): Human-readable description of the access denial.
Returns:
None
"""
# Call super()
super().__init__(message, code="FORBIDDEN") super().__init__(message, code="FORBIDDEN")
+3
View File
@@ -6,6 +6,7 @@ old import paths so that existing code keeps working without changes::
from app.domain.exceptions import InvalidTransitionError # still works from app.domain.exceptions import InvalidTransitionError # still works
""" """
# Import # noqa: F401 from app.domain.errors
from app.domain.errors import ( # noqa: F401 from app.domain.errors import ( # noqa: F401
BusinessRuleViolation, BusinessRuleViolation,
DomainError, DomainError,
@@ -18,5 +19,7 @@ from app.domain.errors import ( # noqa: F401
# Legacy aliases — old name → new name # Legacy aliases — old name → new name
DomainException = DomainError DomainException = DomainError
# Assign InvalidTransitionError = InvalidStateTransition
InvalidTransitionError = InvalidStateTransition InvalidTransitionError = InvalidStateTransition
# Assign AuthorizationError = PermissionViolation
AuthorizationError = PermissionViolation AuthorizationError = PermissionViolation
+1
View File
@@ -0,0 +1 @@
"""Abstract port interfaces that infrastructure adapters must implement."""
+78 -1
View File
@@ -12,14 +12,19 @@ This satisfies the Open/Closed Principle — the system is open for new
import sources without modifying existing code. import sources without modifying existing code.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import Any, Protocol, runtime_checkable from typing
from typing import Any, Protocol, runtime_checkable from typing import Any, Protocol, runtime_checkable
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Apply the @runtime_checkable decorator
@runtime_checkable @runtime_checkable
# Define class ImportService
class ImportService(Protocol): class ImportService(Protocol):
"""Contract for any data-import operation. """Contract for any data-import operation.
@@ -27,62 +32,134 @@ class ImportService(Protocol):
downloads, parses, and upserts records from an external source. downloads, parses, and upserts records from an external source.
""" """
def __call__(self, db: Session) -> dict[str, Any]: ... # Define function __call__
def __call__(self, db: Session) -> dict[str, Any]:
"""Execute the import operation against the given database session.
Args:
db (Session): Active SQLAlchemy session to use for all DB operations.
Returns:
dict[str, Any]: Summary statistics for the import run (e.g. created,
updated, skipped counts).
"""
# ...
...
# Define class ImportServiceEntry
class ImportServiceEntry: class ImportServiceEntry:
"""Lazy-loading wrapper that resolves a module-level function on first call.""" """Lazy-loading wrapper that resolves a module-level function on first call."""
# Assign __slots__ = ("_module_path", "_func_name", "_resolved")
__slots__ = ("_module_path", "_func_name", "_resolved") __slots__ = ("_module_path", "_func_name", "_resolved")
# Define function __init__
def __init__(self, module_path: str, func_name: str) -> None: def __init__(self, module_path: str, func_name: str) -> None:
"""Initialise the lazy entry with the module path and function name to resolve later.
Args:
module_path (str): Dotted Python module path, e.g.
``"app.services.atomic_import_service"``.
func_name (str): Name of the callable to import from *module_path*.
Returns:
None
"""
# Assign self._module_path = module_path
self._module_path = module_path self._module_path = module_path
# Assign self._func_name = func_name
self._func_name = func_name self._func_name = func_name
# Assign self._resolved = None
self._resolved: ImportService | None = None self._resolved: ImportService | None = None
# Define function __call__
def __call__(self, db: Session) -> dict[str, Any]: def __call__(self, db: Session) -> dict[str, Any]:
"""Resolve the import function on first call and invoke it with *db*.
Args:
db (Session): SQLAlchemy session passed through to the underlying
import function.
Returns:
dict[str, Any]: Import statistics returned by the underlying function
(e.g. counts of created/updated/skipped records).
"""
# Check: self._resolved is None
if self._resolved is None: if self._resolved is None:
# Import importlib
import importlib import importlib
# Assign mod = importlib.import_module(self._module_path)
mod = importlib.import_module(self._module_path) mod = importlib.import_module(self._module_path)
# Assign self._resolved = getattr(mod, self._func_name)
self._resolved = getattr(mod, self._func_name) self._resolved = getattr(mod, self._func_name)
# Return self._resolved(db)
return self._resolved(db) return self._resolved(db)
# Apply the @property decorator
@property @property
# Define function source_info
def source_info(self) -> str: def source_info(self) -> str:
"""Return a human-readable identifier for this import entry.
Returns:
str: The fully qualified function reference as
``"<module_path>.<func_name>"``.
"""
# Return f"{self._module_path}.{self._func_name}"
return f"{self._module_path}.{self._func_name}" return f"{self._module_path}.{self._func_name}"
# Assign IMPORT_REGISTRY = {
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = { IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
# Literal argument value
"atomic_red_team": ImportServiceEntry( "atomic_red_team": ImportServiceEntry(
# Literal argument value
"app.services.atomic_import_service", "import_atomic_red_team", "app.services.atomic_import_service", "import_atomic_red_team",
), ),
# Literal argument value
"sigma": ImportServiceEntry( "sigma": ImportServiceEntry(
# Literal argument value
"app.services.sigma_import_service", "sync", "app.services.sigma_import_service", "sync",
), ),
# Literal argument value
"lolbas": ImportServiceEntry( "lolbas": ImportServiceEntry(
# Literal argument value
"app.services.lolbas_import_service", "sync", "app.services.lolbas_import_service", "sync",
), ),
# Literal argument value
"gtfobins": ImportServiceEntry( "gtfobins": ImportServiceEntry(
# Literal argument value
"app.services.lolbas_import_service", "sync_gtfobins", "app.services.lolbas_import_service", "sync_gtfobins",
), ),
# Literal argument value
"caldera": ImportServiceEntry( "caldera": ImportServiceEntry(
# Literal argument value
"app.services.caldera_import_service", "sync", "app.services.caldera_import_service", "sync",
), ),
# Literal argument value
"elastic_rules": ImportServiceEntry( "elastic_rules": ImportServiceEntry(
# Literal argument value
"app.services.elastic_import_service", "sync", "app.services.elastic_import_service", "sync",
), ),
# Literal argument value
"mitre_cti": ImportServiceEntry( "mitre_cti": ImportServiceEntry(
# Literal argument value
"app.services.threat_actor_import_service", "sync", "app.services.threat_actor_import_service", "sync",
), ),
# Literal argument value
"d3fend": ImportServiceEntry( "d3fend": ImportServiceEntry(
# Literal argument value
"app.services.d3fend_import_service", "sync", "app.services.d3fend_import_service", "sync",
), ),
} }
# Define function get_import_handler
def get_import_handler(source_name: str) -> ImportServiceEntry | None: def get_import_handler(source_name: str) -> ImportServiceEntry | None:
"""Look up the import handler for *source_name*. """Look up the import handler for *source_name*.
Returns ``None`` when no handler is registered. Returns ``None`` when no handler is registered.
""" """
# Return IMPORT_REGISTRY.get(source_name)
return IMPORT_REGISTRY.get(source_name) return IMPORT_REGISTRY.get(source_name)
@@ -1,4 +1,9 @@
"""Abstract repository port interfaces for domain entity persistence."""
# Import TechniqueRepository from app.domain.ports.repositories.technique_repository
from app.domain.ports.repositories.technique_repository import TechniqueRepository from app.domain.ports.repositories.technique_repository import TechniqueRepository
# Import TestRepository from app.domain.ports.repositories.test_repository
from app.domain.ports.repositories.test_repository import TestRepository from app.domain.ports.repositories.test_repository import TestRepository
# Assign __all__ = ["TechniqueRepository", "TestRepository"]
__all__ = ["TechniqueRepository", "TestRepository"] __all__ = ["TechniqueRepository", "TestRepository"]
@@ -4,54 +4,157 @@ This is a domain contract — implementations live in infrastructure/.
The domain layer NEVER imports the implementation. The domain layer NEVER imports the implementation.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import uuid
import uuid import uuid
# Import NamedTuple, Protocol, runtime_checkable from typing
from typing import NamedTuple, Protocol, runtime_checkable from typing import NamedTuple, Protocol, runtime_checkable
# Import TechniqueEntity from app.domain.entities.technique
from app.domain.entities.technique import TechniqueEntity from app.domain.entities.technique import TechniqueEntity
# Import TechniqueStatus from app.domain.enums
from app.domain.enums import TechniqueStatus from app.domain.enums import TechniqueStatus
# Define class TechniqueWithCounts
class TechniqueWithCounts(NamedTuple): class TechniqueWithCounts(NamedTuple):
"""Pre-aggregated technique data for heatmap/scoring.""" """Pre-aggregated technique data for heatmap/scoring."""
# entity: TechniqueEntity
entity: TechniqueEntity entity: TechniqueEntity
# test_count: int
test_count: int test_count: int
# validated_test_count: int
validated_test_count: int validated_test_count: int
# detection_rule_count: int
detection_rule_count: int detection_rule_count: int
# Apply the @runtime_checkable decorator
@runtime_checkable @runtime_checkable
# Define class TechniqueRepository
class TechniqueRepository(Protocol): class TechniqueRepository(Protocol):
"""Data access contract for techniques (one per aggregate root).""" """Data access contract for techniques (one per aggregate root)."""
# -- Single-entity access ---------------------------------------------- # -- Single-entity access ----------------------------------------------
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: ... def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
"""Return the technique with the given primary key, or None if absent.
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: ... Args:
technique_id (uuid.UUID): Primary key of the technique to look up.
Returns:
TechniqueEntity | None: The matching entity, or None if not found.
"""
# ...
...
# Define function find_by_mitre_id
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
"""Return the technique matching the given MITRE ATT&CK identifier, or None.
Args:
mitre_id (str): MITRE ATT&CK ID (e.g. ``"T1059"`` or ``"T1059.001"``).
Returns:
TechniqueEntity | None: The matching entity, or None if not found.
"""
# ...
...
# -- List access ------------------------------------------------------- # -- List access -------------------------------------------------------
def list_all( def list_all(
self, self,
*, *,
# Entry: tactic
tactic: str | None = None, tactic: str | None = None,
# Entry: status
status: TechniqueStatus | None = None, status: TechniqueStatus | None = None,
# Entry: review_required
review_required: bool | None = None, review_required: bool | None = None,
) -> list[TechniqueEntity]: ... ) -> list[TechniqueEntity]:
"""Return all techniques, optionally filtered by tactic, status, or review flag.
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: ... Args:
tactic (str | None): When provided, restrict results to this tactic category.
status (TechniqueStatus | None): When provided, restrict results to this status.
review_required (bool | None): When provided, restrict results to techniques
whose ``review_required`` flag matches this value.
Returns:
list[TechniqueEntity]: Matching technique entities; may be empty.
"""
# ...
...
# Define function list_by_ids
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
"""Return all techniques whose primary keys are in *ids*.
Args:
ids (list[uuid.UUID]): List of technique UUIDs to retrieve.
Returns:
list[TechniqueEntity]: Entities found for the supplied IDs; order
is not guaranteed and missing IDs are silently omitted.
"""
# ...
...
# -- Batch queries (scoring/heatmap performance) ----------------------- # -- Batch queries (scoring/heatmap performance) -----------------------
def count_by_status(self) -> dict[TechniqueStatus, int]: ... def count_by_status(self) -> dict[TechniqueStatus, int]:
"""Return a count of techniques grouped by their global status.
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: ... Returns:
dict[TechniqueStatus, int]: Mapping from each status value to the
number of techniques in that state.
"""
# ...
...
# Define function find_all_with_test_counts
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
"""Return all techniques together with pre-aggregated test and rule counts.
Returns:
list[TechniqueWithCounts]: Each element bundles a TechniqueEntity
with its total, validated, and detection-rule counts for use
in heatmap and scoring calculations.
"""
# ...
...
# -- Mutations --------------------------------------------------------- # -- Mutations ---------------------------------------------------------
def save(self, technique: TechniqueEntity) -> TechniqueEntity: ... def save(self, technique: TechniqueEntity) -> TechniqueEntity:
"""Persist a technique entity and return the saved state.
def exists_by_mitre_id(self, mitre_id: str) -> bool: ... Args:
technique (TechniqueEntity): The entity to create or update.
Returns:
TechniqueEntity: The persisted entity, potentially with updated
fields (e.g. server-side timestamps).
"""
# ...
...
# Define function exists_by_mitre_id
def exists_by_mitre_id(self, mitre_id: str) -> bool:
"""Return True if a technique with the given MITRE ID exists in the repository.
Args:
mitre_id (str): MITRE ATT&CK ID to check (e.g. ``"T1059"``).
Returns:
bool: True if a matching technique is found, False otherwise.
"""
# ...
...
@@ -3,14 +3,20 @@
This is a domain contract — implementations live in infrastructure/. This is a domain contract — implementations live in infrastructure/.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import uuid
import uuid import uuid
from typing import Protocol, runtime_checkable
# Import Protocol from typing
from typing import Protocol
# Import TestState from app.domain.enums
from app.domain.enums import TestState from app.domain.enums import TestState
# Define class TestRepository
class TestRepository(Protocol): class TestRepository(Protocol):
"""Data access contract for tests.""" """Data access contract for tests."""
@@ -22,31 +28,81 @@ class TestRepository(Protocol):
Returns the ORM model directly (not a domain entity) because Returns the ORM model directly (not a domain entity) because
the TestEntity is constructed at the service layer via the TestEntity is constructed at the service layer via
``TestEntity.from_orm()``. ``TestEntity.from_orm()``.
Args:
test_id (uuid.UUID): Primary key of the test to look up.
Returns:
object | None: The ORM model instance, or None if not found.
""" """
# ...
... ...
# -- List access ------------------------------------------------------- # -- List access -------------------------------------------------------
def list_by_technique(self, technique_id: uuid.UUID) -> list[object]: ... def list_by_technique(self, technique_id: uuid.UUID) -> list[object]:
"""Return all test ORM models associated with the given technique.
def list_by_state(self, state: TestState) -> list[object]: ... Args:
technique_id (uuid.UUID): Primary key of the technique whose tests to retrieve.
Returns:
list[object]: ORM model instances for all tests linked to this technique.
"""
# ...
...
# Define function list_by_state
def list_by_state(self, state: TestState) -> list[object]:
"""Return all test ORM models in the given state.
Args:
state (TestState): The state to filter tests by.
Returns:
list[object]: ORM model instances for all tests currently in *state*.
"""
# ...
...
# Define function count_by_technique_and_state
def count_by_technique_and_state( def count_by_technique_and_state(
self, self,
# Entry: technique_id
technique_id: uuid.UUID, technique_id: uuid.UUID,
) -> dict[TestState, int]: ) -> dict[TestState, int]:
"""Return test counts grouped by state for a single technique.""" """Return test counts grouped by state for a single technique.
Args:
technique_id (uuid.UUID): Primary key of the technique whose test
counts to aggregate.
Returns:
dict[TestState, int]: Mapping from each test state to the number of
tests in that state for the given technique.
"""
# ...
... ...
# -- Batch queries ----------------------------------------------------- # -- Batch queries -----------------------------------------------------
def get_states_and_results_for_technique( def get_states_and_results_for_technique(
self, self,
# Entry: technique_id
technique_id: uuid.UUID, technique_id: uuid.UUID,
) -> list[tuple[str, str | None]]: ) -> list[tuple[str, str | None]]:
"""Return (state, detection_result) pairs for all tests of a technique. """Return (state, detection_result) pairs for all tests of a technique.
Used by TechniqueEntity.recalculate_status() without loading full Used by TechniqueEntity.recalculate_status() without loading full
test models. test models.
Args:
technique_id (uuid.UUID): Primary key of the technique whose test
data to retrieve.
Returns:
list[tuple[str, str | None]]: Each tuple contains the test state
string and the detection result string (or None if not yet set).
""" """
# ...
... ...
+337 -19
View File
@@ -20,33 +20,57 @@ After mutations, the service layer copies ``entity.changes`` back onto
the ORM model and persists via Unit of Work. the ORM model and persists via Unit of Work.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import enum
import enum import enum
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
# Import uuid
import uuid
# Import dataclass, field from dataclasses
from dataclasses import dataclass, field
# Import datetime from datetime
from datetime import datetime
# Import TYPE_CHECKING, Any from typing
from typing import TYPE_CHECKING, Any
# Import from app.domain.errors
from app.domain.errors import ( from app.domain.errors import (
BusinessRuleViolation, BusinessRuleViolation,
InvalidOperationError, InvalidOperationError,
InvalidStateTransition, InvalidStateTransition,
) )
# Check: TYPE_CHECKING
if TYPE_CHECKING:
# Import Test as TestORM from app.models.test
from app.models.test import Test as TestORM
# ── Value objects ──────────────────────────────────────────────────── # ── Value objects ────────────────────────────────────────────────────
class TestState(str, enum.Enum): class TestState(str, enum.Enum):
"""Ordered lifecycle states for a security test."""
# Assign draft = "draft"
draft = "draft" draft = "draft"
# Assign red_executing = "red_executing"
red_executing = "red_executing" red_executing = "red_executing"
# Assign blue_evaluating = "blue_evaluating"
blue_evaluating = "blue_evaluating" blue_evaluating = "blue_evaluating"
# Assign in_review = "in_review"
in_review = "in_review" in_review = "in_review"
# Assign validated = "validated"
validated = "validated" validated = "validated"
# Assign rejected = "rejected"
rejected = "rejected" rejected = "rejected"
# Assign VALID_TRANSITIONS = {
VALID_TRANSITIONS: dict[TestState, list[TestState]] = { VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
TestState.draft: [TestState.red_executing], TestState.draft: [TestState.red_executing],
TestState.red_executing: [TestState.blue_evaluating], TestState.red_executing: [TestState.blue_evaluating],
@@ -56,6 +80,7 @@ VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
TestState.validated: [], TestState.validated: [],
} }
# Assign _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating})
_PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating}) _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating})
@@ -63,8 +88,13 @@ _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating
@dataclass(frozen=True) @dataclass(frozen=True)
# Define class DomainEvent
class DomainEvent: class DomainEvent:
"""Immutable record of a domain-level event emitted by the test entity."""
# name: str
name: str name: str
# Assign payload = field(default_factory=dict)
payload: dict[str, Any] = field(default_factory=dict) payload: dict[str, Any] = field(default_factory=dict)
@@ -72,30 +102,44 @@ class DomainEvent:
@dataclass @dataclass
# Define class TestEntity
class TestEntity: class TestEntity:
"""Pure domain representation of a security test.""" """Pure domain representation of a security test."""
# id: uuid.UUID
id: uuid.UUID id: uuid.UUID
# state: TestState
state: TestState state: TestState
# Red validation # Red validation
red_validation_status: str | None = None red_validation_status: str | None = None
# Assign red_validated_by = None
red_validated_by: uuid.UUID | None = None red_validated_by: uuid.UUID | None = None
# Assign red_validated_at = None
red_validated_at: datetime | None = None red_validated_at: datetime | None = None
# Assign red_validation_notes = None
red_validation_notes: str | None = None red_validation_notes: str | None = None
# Blue validation # Blue validation
blue_validation_status: str | None = None blue_validation_status: str | None = None
# Assign blue_validated_by = None
blue_validated_by: uuid.UUID | None = None blue_validated_by: uuid.UUID | None = None
# Assign blue_validated_at = None
blue_validated_at: datetime | None = None blue_validated_at: datetime | None = None
# Assign blue_validation_notes = None
blue_validation_notes: str | None = None blue_validation_notes: str | None = None
# Phase timing # Phase timing
execution_date: datetime | None = None execution_date: datetime | None = None
# Assign red_started_at = None
red_started_at: datetime | None = None red_started_at: datetime | None = None
# Assign blue_started_at = None
blue_started_at: datetime | None = None blue_started_at: datetime | None = None
# Assign paused_at = None
paused_at: datetime | None = None paused_at: datetime | None = None
# Assign red_paused_seconds = 0
red_paused_seconds: int = 0 red_paused_seconds: int = 0
# Assign blue_paused_seconds = 0
blue_paused_seconds: int = 0 blue_paused_seconds: int = 0
# Internal bookkeeping (not persisted as-is) # Internal bookkeeping (not persisted as-is)
@@ -104,58 +148,134 @@ class TestEntity:
# -- Factory -------------------------------------------------------- # -- Factory --------------------------------------------------------
@classmethod @classmethod
def from_orm(cls, model: Any) -> TestEntity: # Define function from_orm
"""Build a TestEntity from a SQLAlchemy ``Test`` model instance.""" def from_orm(cls, model: TestORM) -> TestEntity:
"""Build a TestEntity from a SQLAlchemy ``Test`` model instance.
Args:
model (TestORM): The ORM model whose fields will be copied into the entity.
Returns:
TestEntity: A fully populated domain entity reflecting the ORM state.
"""
# Assign raw_state = model.state
raw_state = model.state raw_state = model.state
# Assign state = raw_state if isinstance(raw_state, TestState) else TestState(raw_st...
state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state) state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state)
# Return cls(
return cls( return cls(
# Keyword argument: id
id=model.id, id=model.id,
# Keyword argument: state
state=state, state=state,
# Keyword argument: red_validation_status
red_validation_status=model.red_validation_status, red_validation_status=model.red_validation_status,
# Keyword argument: red_validated_by
red_validated_by=model.red_validated_by, red_validated_by=model.red_validated_by,
# Keyword argument: red_validated_at
red_validated_at=model.red_validated_at, red_validated_at=model.red_validated_at,
# Keyword argument: red_validation_notes
red_validation_notes=model.red_validation_notes, red_validation_notes=model.red_validation_notes,
# Keyword argument: blue_validation_status
blue_validation_status=model.blue_validation_status, blue_validation_status=model.blue_validation_status,
# Keyword argument: blue_validated_by
blue_validated_by=model.blue_validated_by, blue_validated_by=model.blue_validated_by,
# Keyword argument: blue_validated_at
blue_validated_at=model.blue_validated_at, blue_validated_at=model.blue_validated_at,
# Keyword argument: blue_validation_notes
blue_validation_notes=model.blue_validation_notes, blue_validation_notes=model.blue_validation_notes,
# Keyword argument: execution_date
execution_date=model.execution_date, execution_date=model.execution_date,
# Keyword argument: red_started_at
red_started_at=model.red_started_at, red_started_at=model.red_started_at,
# Keyword argument: blue_started_at
blue_started_at=model.blue_started_at, blue_started_at=model.blue_started_at,
# Keyword argument: paused_at
paused_at=model.paused_at, paused_at=model.paused_at,
# Keyword argument: red_paused_seconds
red_paused_seconds=model.red_paused_seconds or 0, red_paused_seconds=model.red_paused_seconds or 0,
# Keyword argument: blue_paused_seconds
blue_paused_seconds=model.blue_paused_seconds or 0, blue_paused_seconds=model.blue_paused_seconds or 0,
) )
def apply_to(self, model: Any) -> None: # Define function apply_to
"""Copy the entity's mutable fields back onto the ORM model.""" def apply_to(self, model: TestORM) -> None:
"""Copy the entity's mutable fields back onto the ORM model.
Args:
model (TestORM): The ORM model to update in-place.
Returns:
None
"""
# Assign model.state = self.state
model.state = self.state model.state = self.state
# Assign model.red_validation_status = self.red_validation_status
model.red_validation_status = self.red_validation_status model.red_validation_status = self.red_validation_status
# Assign model.red_validated_by = self.red_validated_by
model.red_validated_by = self.red_validated_by model.red_validated_by = self.red_validated_by
# Assign model.red_validated_at = self.red_validated_at
model.red_validated_at = self.red_validated_at model.red_validated_at = self.red_validated_at
# Assign model.red_validation_notes = self.red_validation_notes
model.red_validation_notes = self.red_validation_notes model.red_validation_notes = self.red_validation_notes
# Assign model.blue_validation_status = self.blue_validation_status
model.blue_validation_status = self.blue_validation_status model.blue_validation_status = self.blue_validation_status
# Assign model.blue_validated_by = self.blue_validated_by
model.blue_validated_by = self.blue_validated_by model.blue_validated_by = self.blue_validated_by
# Assign model.blue_validated_at = self.blue_validated_at
model.blue_validated_at = self.blue_validated_at model.blue_validated_at = self.blue_validated_at
# Assign model.blue_validation_notes = self.blue_validation_notes
model.blue_validation_notes = self.blue_validation_notes model.blue_validation_notes = self.blue_validation_notes
# Assign model.execution_date = self.execution_date
model.execution_date = self.execution_date model.execution_date = self.execution_date
# Assign model.red_started_at = self.red_started_at
model.red_started_at = self.red_started_at model.red_started_at = self.red_started_at
# Assign model.blue_started_at = self.blue_started_at
model.blue_started_at = self.blue_started_at model.blue_started_at = self.blue_started_at
# Assign model.paused_at = self.paused_at
model.paused_at = self.paused_at model.paused_at = self.paused_at
# Assign model.red_paused_seconds = self.red_paused_seconds
model.red_paused_seconds = self.red_paused_seconds model.red_paused_seconds = self.red_paused_seconds
# Assign model.blue_paused_seconds = self.blue_paused_seconds
model.blue_paused_seconds = self.blue_paused_seconds model.blue_paused_seconds = self.blue_paused_seconds
# -- Query helpers -------------------------------------------------- # -- Query helpers --------------------------------------------------
@property @property
# Define function events
def events(self) -> list[DomainEvent]: def events(self) -> list[DomainEvent]:
"""Return a snapshot of all domain events raised on this entity.
Returns:
list[DomainEvent]: Ordered list of events emitted since the entity
was constructed or last cleared.
"""
# Return list(self._events)
return list(self._events) return list(self._events)
# Define function can_transition
def can_transition(self, target: TestState) -> bool: def can_transition(self, target: TestState) -> bool:
"""Check whether a transition from the current state to *target* is valid.
Args:
target (TestState): The desired next state.
Returns:
bool: True if the transition is allowed, False otherwise.
"""
# Return target in VALID_TRANSITIONS.get(self.state, [])
return target in VALID_TRANSITIONS.get(self.state, []) return target in VALID_TRANSITIONS.get(self.state, [])
# Apply the @property decorator
@property @property
# Define function is_terminal
def is_terminal(self) -> bool: def is_terminal(self) -> bool:
"""Return True if the test has reached its final (validated) state.
Returns:
bool: True when state is ``validated``, False for all other states.
"""
# Return self.state == TestState.validated
return self.state == TestState.validated return self.state == TestState.validated
# -- Core transition ------------------------------------------------ # -- Core transition ------------------------------------------------
@@ -169,148 +289,305 @@ class TestEntity:
Returns the *previous* state value as a plain string. Returns the *previous* state value as a plain string.
Raises :class:`InvalidStateTransition` when the move is illegal. Raises :class:`InvalidStateTransition` when the move is illegal.
Args:
target (TestState | str): The desired next state, as an enum member
or its string equivalent.
Returns:
str: The previous state value before the transition.
""" """
# Assign value = target.value if hasattr(target, "value") else str(target)
value = target.value if hasattr(target, "value") else str(target) value = target.value if hasattr(target, "value") else str(target)
# Assign resolved = target if isinstance(target, TestState) else TestState(value)
resolved = target if isinstance(target, TestState) else TestState(value) resolved = target if isinstance(target, TestState) else TestState(value)
# Return self._transition(resolved)
return self._transition(resolved) return self._transition(resolved)
# Define function _transition
def _transition(self, target: TestState) -> str: def _transition(self, target: TestState) -> str:
"""Internal: validate and apply; return previous state value.""" """Validate and apply a state transition, returning the previous state value.
Args:
target (TestState): The desired next state enum member.
Returns:
str: The previous state value before the transition was applied.
"""
# Check: not self.can_transition(target)
if not self.can_transition(target): if not self.can_transition(target):
# Assign valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])]
valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])] valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])]
# Raise InvalidStateTransition
raise InvalidStateTransition( raise InvalidStateTransition(
# Keyword argument: current_state
current_state=self.state.value, current_state=self.state.value,
# Keyword argument: target_state
target_state=target.value, target_state=target.value,
# Keyword argument: valid_transitions
valid_transitions=valid, valid_transitions=valid,
) )
# Assign previous = self.state.value
previous = self.state.value previous = self.state.value
# Assign self.state = target
self.state = target self.state = target
# Call self._events.append()
self._events.append(DomainEvent( self._events.append(DomainEvent(
# Literal argument value
"state_changed", "state_changed",
{"previous": previous, "new": target.value}, {"previous": previous, "new": target.value},
)) ))
# Return previous
return previous return previous
# -- Lifecycle commands -------------------------------------------- # -- Lifecycle commands --------------------------------------------
def start_execution(self) -> None: def start_execution(self) -> None:
"""``draft`` -> ``red_executing``.""" """Transition the test from ``draft`` to ``red_executing``.
Returns:
None
"""
# Call self._transition()
self._transition(TestState.red_executing) self._transition(TestState.red_executing)
# Assign now = datetime.utcnow()
now = datetime.utcnow() now = datetime.utcnow()
# Assign self.execution_date = now
self.execution_date = now self.execution_date = now
# Assign self.red_started_at = now
self.red_started_at = now self.red_started_at = now
# Call self._events.append()
self._events.append(DomainEvent("execution_started")) self._events.append(DomainEvent("execution_started"))
# Define function submit_red_evidence
def submit_red_evidence(self) -> int: def submit_red_evidence(self) -> int:
"""``red_executing`` -> ``blue_evaluating``. """Transition the test from ``red_executing`` to ``blue_evaluating``.
Auto-resumes if paused. Returns paused seconds accumulated Auto-resumes if paused. Returns paused seconds accumulated
during this phase (for worklog calculation). during this phase (for worklog calculation).
Returns:
int: Total seconds the red phase was paused.
""" """
# Assign paused_extra = self._auto_resume()
paused_extra = self._auto_resume() paused_extra = self._auto_resume()
# Call self._transition()
self._transition(TestState.blue_evaluating) self._transition(TestState.blue_evaluating)
# Assign total_paused = self.red_paused_seconds + paused_extra
total_paused = self.red_paused_seconds + paused_extra total_paused = self.red_paused_seconds + paused_extra
# Assign self.blue_started_at = datetime.utcnow()
self.blue_started_at = datetime.utcnow() self.blue_started_at = datetime.utcnow()
# Assign self.blue_paused_seconds = 0
self.blue_paused_seconds = 0 self.blue_paused_seconds = 0
# Call self._events.append()
self._events.append(DomainEvent( self._events.append(DomainEvent(
# Literal argument value
"red_evidence_submitted", "red_evidence_submitted",
{"red_paused_seconds": total_paused}, {"red_paused_seconds": total_paused},
)) ))
# Return total_paused
return total_paused return total_paused
# Define function submit_blue_evidence
def submit_blue_evidence(self) -> int: def submit_blue_evidence(self) -> int:
"""``blue_evaluating`` -> ``in_review``. """Transition the test from ``blue_evaluating`` to ``in_review``.
Auto-resumes if paused. Returns paused seconds accumulated Auto-resumes if paused. Returns paused seconds accumulated
during this phase (for worklog calculation). during this phase (for worklog calculation).
Returns:
int: Total seconds the blue phase was paused.
""" """
# Assign paused_extra = self._auto_resume()
paused_extra = self._auto_resume() paused_extra = self._auto_resume()
# Call self._transition()
self._transition(TestState.in_review) self._transition(TestState.in_review)
# Assign total_paused = self.blue_paused_seconds + paused_extra
total_paused = self.blue_paused_seconds + paused_extra total_paused = self.blue_paused_seconds + paused_extra
# Call self._events.append()
self._events.append(DomainEvent( self._events.append(DomainEvent(
# Literal argument value
"blue_evidence_submitted", "blue_evidence_submitted",
{"blue_paused_seconds": total_paused}, {"blue_paused_seconds": total_paused},
)) ))
# Return total_paused
return total_paused return total_paused
# Define function pause_timer
def pause_timer(self) -> None: def pause_timer(self) -> None:
"""Pause the active phase timer.""" """Pause the active phase timer.
Returns:
None
"""
# Check: self.state not in _PAUSABLE_STATES
if self.state not in _PAUSABLE_STATES: if self.state not in _PAUSABLE_STATES:
# Raise BusinessRuleViolation
raise BusinessRuleViolation( raise BusinessRuleViolation(
f"Cannot pause timer in '{self.state.value}' state" f"Cannot pause timer in '{self.state.value}' state"
) )
# Check: self.paused_at is not None
if self.paused_at is not None: if self.paused_at is not None:
# Raise BusinessRuleViolation
raise BusinessRuleViolation("Timer is already paused") raise BusinessRuleViolation("Timer is already paused")
# Assign self.paused_at = datetime.utcnow()
self.paused_at = datetime.utcnow() self.paused_at = datetime.utcnow()
# Call self._events.append()
self._events.append(DomainEvent("timer_paused")) self._events.append(DomainEvent("timer_paused"))
# Define function resume_timer
def resume_timer(self) -> int: def resume_timer(self) -> int:
"""Resume a paused timer. Returns seconds that were paused.""" """Resume a paused timer.
Returns:
int: Number of seconds the timer was paused for.
"""
# Check: self.paused_at is None
if self.paused_at is None: if self.paused_at is None:
# Raise BusinessRuleViolation
raise BusinessRuleViolation("Timer is not paused") raise BusinessRuleViolation("Timer is not paused")
# Assign now = datetime.utcnow()
now = datetime.utcnow() now = datetime.utcnow()
# Assign paused_seconds = max(int((now - self.paused_at).total_seconds()), 0)
paused_seconds = max(int((now - self.paused_at).total_seconds()), 0) paused_seconds = max(int((now - self.paused_at).total_seconds()), 0)
# Check: self.state == TestState.red_executing
if self.state == TestState.red_executing: if self.state == TestState.red_executing:
# Assign self.red_paused_seconds = paused_seconds
self.red_paused_seconds += paused_seconds self.red_paused_seconds += paused_seconds
# Alternative: self.state == TestState.blue_evaluating
elif self.state == TestState.blue_evaluating: elif self.state == TestState.blue_evaluating:
# Assign self.blue_paused_seconds = paused_seconds
self.blue_paused_seconds += paused_seconds self.blue_paused_seconds += paused_seconds
# Assign self.paused_at = None
self.paused_at = None self.paused_at = None
# Call self._events.append()
self._events.append(DomainEvent("timer_resumed", {"paused_seconds": paused_seconds})) self._events.append(DomainEvent("timer_resumed", {"paused_seconds": paused_seconds}))
# Return paused_seconds
return paused_seconds return paused_seconds
# Define function validate_red
def validate_red(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None: def validate_red(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
"""Record Red Lead's validation decision.""" """Record Red Lead's validation decision.
Args:
status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``.
by (uuid.UUID): UUID of the Red Lead recording the decision.
notes (str | None): Optional free-text notes about the decision.
Returns:
None
"""
# Call self._assert_in_review()
self._assert_in_review("red") self._assert_in_review("red")
# Call self._assert_valid_vote()
self._assert_valid_vote(status) self._assert_valid_vote(status)
# Assign now = datetime.utcnow()
now = datetime.utcnow() now = datetime.utcnow()
# Assign self.red_validation_status = status
self.red_validation_status = status self.red_validation_status = status
# Assign self.red_validated_by = by
self.red_validated_by = by self.red_validated_by = by
# Assign self.red_validated_at = now
self.red_validated_at = now self.red_validated_at = now
# Assign self.red_validation_notes = notes
self.red_validation_notes = notes self.red_validation_notes = notes
# Call self._events.append()
self._events.append(DomainEvent("red_validated", {"status": status})) self._events.append(DomainEvent("red_validated", {"status": status}))
# Call self._check_dual_validation()
self._check_dual_validation() self._check_dual_validation()
# Define function validate_blue
def validate_blue(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None: def validate_blue(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
"""Record Blue Lead's validation decision.""" """Record Blue Lead's validation decision.
Args:
status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``.
by (uuid.UUID): UUID of the Blue Lead recording the decision.
notes (str | None): Optional free-text notes about the decision.
Returns:
None
"""
# Call self._assert_in_review()
self._assert_in_review("blue") self._assert_in_review("blue")
# Call self._assert_valid_vote()
self._assert_valid_vote(status) self._assert_valid_vote(status)
# Assign now = datetime.utcnow()
now = datetime.utcnow() now = datetime.utcnow()
# Assign self.blue_validation_status = status
self.blue_validation_status = status self.blue_validation_status = status
# Assign self.blue_validated_by = by
self.blue_validated_by = by self.blue_validated_by = by
# Assign self.blue_validated_at = now
self.blue_validated_at = now self.blue_validated_at = now
# Assign self.blue_validation_notes = notes
self.blue_validation_notes = notes self.blue_validation_notes = notes
# Call self._events.append()
self._events.append(DomainEvent("blue_validated", {"status": status})) self._events.append(DomainEvent("blue_validated", {"status": status}))
# Call self._check_dual_validation()
self._check_dual_validation() self._check_dual_validation()
# Define function reopen
def reopen(self) -> None: def reopen(self) -> None:
"""``rejected`` -> ``draft``, clearing all validation/timing fields.""" """Transition the test from ``rejected`` back to ``draft``, clearing all validation and timing fields.
Returns:
None
"""
# Call self._transition()
self._transition(TestState.draft) self._transition(TestState.draft)
# Assign self.red_validation_status = None
self.red_validation_status = None self.red_validation_status = None
# Assign self.red_validated_by = None
self.red_validated_by = None self.red_validated_by = None
# Assign self.red_validated_at = None
self.red_validated_at = None self.red_validated_at = None
# Assign self.red_validation_notes = None
self.red_validation_notes = None self.red_validation_notes = None
# Assign self.blue_validation_status = None
self.blue_validation_status = None self.blue_validation_status = None
# Assign self.blue_validated_by = None
self.blue_validated_by = None self.blue_validated_by = None
# Assign self.blue_validated_at = None
self.blue_validated_at = None self.blue_validated_at = None
# Assign self.blue_validation_notes = None
self.blue_validation_notes = None self.blue_validation_notes = None
# Assign self.red_started_at = None
self.red_started_at = None self.red_started_at = None
# Assign self.blue_started_at = None
self.blue_started_at = None self.blue_started_at = None
# Assign self.paused_at = None
self.paused_at = None self.paused_at = None
# Assign self.red_paused_seconds = 0
self.red_paused_seconds = 0 self.red_paused_seconds = 0
# Assign self.blue_paused_seconds = 0
self.blue_paused_seconds = 0 self.blue_paused_seconds = 0
# Call self._events.append()
self._events.append(DomainEvent("test_reopened")) self._events.append(DomainEvent("test_reopened"))
# -- Private ------------------------------------------------------- # -- Private -------------------------------------------------------
def _auto_resume(self) -> int: def _auto_resume(self) -> int:
"""If paused, accumulate pause time and clear. Returns extra seconds.""" """Accumulate pause time and clear the paused timestamp if currently paused.
Returns:
int: Extra seconds that were accumulated from the current pause, or 0
if the timer was not paused.
"""
# Check: self.paused_at is None
if self.paused_at is None: if self.paused_at is None:
# Return 0
return 0 return 0
# Assign now = datetime.utcnow()
now = datetime.utcnow() now = datetime.utcnow()
# Assign extra = max(int((now - self.paused_at).total_seconds()), 0)
extra = max(int((now - self.paused_at).total_seconds()), 0) extra = max(int((now - self.paused_at).total_seconds()), 0)
# Assign self.paused_at = None
self.paused_at = None self.paused_at = None
# Return extra
return extra return extra
# Define function check_dual_validation
def check_dual_validation(self) -> None: def check_dual_validation(self) -> None:
"""Evaluate both leads' votes and advance state if appropriate. """Evaluate both leads' votes and advance state if appropriate.
@@ -321,29 +598,70 @@ class TestEntity:
Called automatically by :meth:`validate_red` and :meth:`validate_blue`. Called automatically by :meth:`validate_red` and :meth:`validate_blue`.
Also available as a standalone entry point for backward compatibility Also available as a standalone entry point for backward compatibility
when validation fields are set externally. when validation fields are set externally.
Returns:
None
""" """
# Call self._check_dual_validation()
self._check_dual_validation() self._check_dual_validation()
# Define function _assert_in_review
def _assert_in_review(self, side: str) -> None: def _assert_in_review(self, side: str) -> None:
"""Raise InvalidOperationError unless the test is in ``in_review`` state.
Args:
side (str): The team side being validated (``"red"`` or ``"blue"``),
used in the error message.
Returns:
None
"""
# Check: self.state != TestState.in_review
if self.state != TestState.in_review: if self.state != TestState.in_review:
# Raise InvalidOperationError
raise InvalidOperationError( raise InvalidOperationError(
f"Cannot validate {side} side while test is in " f"Cannot validate {side} side while test is in "
f"'{self.state.value}' state (must be in_review)" f"'{self.state.value}' state (must be in_review)"
) )
# Apply the @staticmethod decorator
@staticmethod @staticmethod
# Define function _assert_valid_vote
def _assert_valid_vote(status: str) -> None: def _assert_valid_vote(status: str) -> None:
"""Raise InvalidOperationError if *status* is not a valid vote value.
Args:
status (str): The vote value to validate; must be ``"approved"`` or ``"rejected"``.
Returns:
None
"""
# Check: status not in ("approved", "rejected")
if status not in ("approved", "rejected"): if status not in ("approved", "rejected"):
# Raise InvalidOperationError
raise InvalidOperationError( raise InvalidOperationError(
# Literal argument value
"validation_status must be 'approved' or 'rejected'" "validation_status must be 'approved' or 'rejected'"
) )
# Define function _check_dual_validation
def _check_dual_validation(self) -> None: def _check_dual_validation(self) -> None:
"""If both leads have voted, advance to validated or rejected.""" """Advance to ``validated`` or ``rejected`` once both leads have voted.
Returns:
None
"""
# r, b = self.red_validation_status, self.blue_validation_status
r, b = self.red_validation_status, self.blue_validation_status r, b = self.red_validation_status, self.blue_validation_status
# Check: r == "rejected" or b == "rejected"
if r == "rejected" or b == "rejected": if r == "rejected" or b == "rejected":
# Assign self.state = TestState.rejected
self.state = TestState.rejected self.state = TestState.rejected
# Call self._events.append()
self._events.append(DomainEvent("dual_validation_rejected")) self._events.append(DomainEvent("dual_validation_rejected"))
# Alternative: r == "approved" and b == "approved"
elif r == "approved" and b == "approved": elif r == "approved" and b == "approved":
# Assign self.state = TestState.validated
self.state = TestState.validated self.state = TestState.validated
# Call self._events.append()
self._events.append(DomainEvent("dual_validation_approved")) self._events.append(DomainEvent("dual_validation_approved"))
+49 -1
View File
@@ -20,36 +20,84 @@ Services should **never** call ``db.commit()``; they use ``db.add()`` /
osint_enrichment_service.enrich_technique_with_cves). osint_enrichment_service.enrich_technique_with_cves).
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import TracebackType from types
from types import TracebackType
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Define class UnitOfWork
class UnitOfWork: class UnitOfWork:
"""Lightweight transaction wrapper around an existing SQLAlchemy session.""" """Lightweight transaction wrapper around an existing SQLAlchemy session."""
# Define function __init__
def __init__(self, session: Session) -> None: def __init__(self, session: Session) -> None:
"""Wrap an existing SQLAlchemy session in a Unit of Work.
Args:
session (Session): The active SQLAlchemy session to manage.
Returns:
None
"""
# Assign self._session = session
self._session = session self._session = session
# -- context manager ----------------------------------------------------- # -- context manager -----------------------------------------------------
def __enter__(self) -> "UnitOfWork": def __enter__(self) -> "UnitOfWork":
"""Enter the runtime context, returning this UnitOfWork instance.
Returns:
UnitOfWork: The UnitOfWork itself, for use in ``with`` statements.
"""
# Return self
return self return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None: # Define function __exit__
def __exit__(
self,
# Entry: exc_type
exc_type: type[BaseException] | None,
# Entry: exc_val
exc_val: BaseException | None,
# Entry: exc_tb
exc_tb: TracebackType | None,
) -> None:
"""Exit the runtime context, rolling back if an exception propagated.
Args:
exc_type (type[BaseException] | None): Exception class, if raised.
exc_val (BaseException | None): Exception instance, if raised.
exc_tb (TracebackType | None): Traceback object, if an exception was raised.
Returns:
None
"""
# Check: exc_type is not None
if exc_type is not None: if exc_type is not None:
# Call self.rollback()
self.rollback() self.rollback()
# -- public API ---------------------------------------------------------- # -- public API ----------------------------------------------------------
def commit(self) -> None: def commit(self) -> None:
"""Flush pending changes and commit the transaction.""" """Flush pending changes and commit the transaction."""
# Call self._session.commit()
self._session.commit() self._session.commit()
# Define function rollback
def rollback(self) -> None: def rollback(self) -> None:
"""Roll back the current transaction.""" """Roll back the current transaction."""
# Call self._session.rollback()
self._session.rollback() self._session.rollback()
# Define function flush
def flush(self) -> None: def flush(self) -> None:
"""Flush pending changes without committing (useful for getting IDs).""" """Flush pending changes without committing (useful for getting IDs)."""
# Call self._session.flush()
self._session.flush() self._session.flush()
@@ -1,4 +1,9 @@
"""Immutable domain value objects."""
# Import MitreId from app.domain.value_objects.mitre_id
from app.domain.value_objects.mitre_id import MitreId from app.domain.value_objects.mitre_id import MitreId
# Import ScoringWeights from app.domain.value_objects.scoring_weights
from app.domain.value_objects.scoring_weights import ScoringWeights from app.domain.value_objects.scoring_weights import ScoringWeights
# Assign __all__ = ["MitreId", "ScoringWeights"]
__all__ = ["MitreId", "ScoringWeights"] __all__ = ["MitreId", "ScoringWeights"]
+65 -1
View File
@@ -5,47 +5,111 @@ format: ``T`` followed by 4 digits, optionally a dot and 3 more digits
for sub-techniques (e.g. ``T1059``, ``T1059.001``). for sub-techniques (e.g. ``T1059``, ``T1059.001``).
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import re
import re import re
# Import dataclass from dataclasses
from dataclasses import dataclass from dataclasses import dataclass
# Assign _MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
_MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$") _MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
# Apply the @dataclass decorator
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
# Define class MitreId
class MitreId: class MitreId:
"""Validated MITRE ATT&CK technique identifier.""" """Validated MITRE ATT&CK technique identifier."""
# value: str
value: str value: str
# Define function __post_init__
def __post_init__(self) -> None: def __post_init__(self) -> None:
"""Validate that *value* matches the expected MITRE ATT&CK ID format.
Returns:
None
"""
# Check: not _MITRE_ID_RE.match(self.value)
if not _MITRE_ID_RE.match(self.value): if not _MITRE_ID_RE.match(self.value):
# Raise ValueError
raise ValueError( raise ValueError(
f"Invalid MITRE ATT&CK ID '{self.value}'. " f"Invalid MITRE ATT&CK ID '{self.value}'. "
# Literal argument value
"Expected format: T1234 or T1234.001" "Expected format: T1234 or T1234.001"
) )
# Apply the @property decorator
@property @property
# Define function is_subtechnique
def is_subtechnique(self) -> bool: def is_subtechnique(self) -> bool:
"""Return True if this identifier represents a sub-technique.
Returns:
bool: True when the ID contains a dot (e.g. ``T1059.001``).
"""
# Return "." in self.value
return "." in self.value return "." in self.value
# Apply the @property decorator
@property @property
# Define function parent_id
def parent_id(self) -> str | None: def parent_id(self) -> str | None:
"""Return the parent technique ID (e.g. T1059 for T1059.001).""" """Return the parent technique ID (e.g. ``T1059`` for ``T1059.001``).
Returns:
str | None: The parent ID string, or None if this is not a sub-technique.
"""
# Check: not self.is_subtechnique
if not self.is_subtechnique: if not self.is_subtechnique:
# Return None
return None return None
# Return self.value.split(".")[0]
return self.value.split(".")[0] return self.value.split(".")[0]
# Define function __str__
def __str__(self) -> str: def __str__(self) -> str:
"""Return the string representation of the MITRE ID.
Returns:
str: The raw identifier string (e.g. ``"T1059.001"``).
"""
# Return self.value
return self.value return self.value
# Define function __eq__
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
"""Compare this MitreId to another MitreId or a plain string.
Args:
other (object): The value to compare against; may be a
:class:`MitreId` instance or a plain ``str``.
Returns:
bool: True if the identifiers are equal, NotImplemented for
unsupported types.
"""
# Check: isinstance(other, MitreId)
if isinstance(other, MitreId): if isinstance(other, MitreId):
# Return self.value == other.value
return self.value == other.value return self.value == other.value
# Check: isinstance(other, str)
if isinstance(other, str): if isinstance(other, str):
# Return self.value == other
return self.value == other return self.value == other
# Return NotImplemented
return NotImplemented return NotImplemented
# Define function __hash__
def __hash__(self) -> int: def __hash__(self) -> int:
"""Return the hash of the identifier string.
Returns:
int: Hash value derived from the raw identifier string.
"""
# Return hash(self.value)
return hash(self.value) return hash(self.value)
@@ -3,22 +3,38 @@
Enforces that all five weights are non-negative and sum to exactly 100. Enforces that all five weights are non-negative and sum to exactly 100.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import dataclass from dataclasses
from dataclasses import dataclass from dataclasses import dataclass
# Apply the @dataclass decorator
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
# Define class ScoringWeights
class ScoringWeights: class ScoringWeights:
"""Five scoring dimension weights that must sum to 100.""" """Five scoring dimension weights that must sum to 100."""
# tests: float
tests: float tests: float
# detection_rules: float
detection_rules: float detection_rules: float
# d3fend: float
d3fend: float d3fend: float
# recency: float
recency: float recency: float
# severity: float
severity: float severity: float
# Define function __post_init__
def __post_init__(self) -> None: def __post_init__(self) -> None:
"""Validate that all weights are non-negative and sum to exactly 100.
Returns:
None
"""
# Assign fields = [
fields = [ fields = [
self.tests, self.tests,
self.detection_rules, self.detection_rules,
@@ -26,32 +42,66 @@ class ScoringWeights:
self.recency, self.recency,
self.severity, self.severity,
] ]
# Iterate over fields
for f in fields: for f in fields:
# Check: f < 0
if f < 0: if f < 0:
# Raise ValueError
raise ValueError("Scoring weights must be non-negative") raise ValueError("Scoring weights must be non-negative")
# Assign total = sum(fields)
total = sum(fields) total = sum(fields)
# Check: abs(total - 100) > 0.01
if abs(total - 100) > 0.01: if abs(total - 100) > 0.01:
# Raise ValueError
raise ValueError( raise ValueError(
f"Scoring weights must sum to 100, got {total}" f"Scoring weights must sum to 100, got {total}"
) )
# Apply the @classmethod decorator
@classmethod @classmethod
# Define function default
def default(cls) -> ScoringWeights: def default(cls) -> ScoringWeights:
"""Return the default weight distribution.""" """Return the default weight distribution.
Returns:
ScoringWeights: A weight set with tests=40, detection_rules=25,
d3fend=15, recency=10, severity=10.
"""
# Return cls(
return cls( return cls(
# Keyword argument: tests
tests=40.0, tests=40.0,
# Keyword argument: detection_rules
detection_rules=25.0, detection_rules=25.0,
# Keyword argument: d3fend
d3fend=15.0, d3fend=15.0,
# Keyword argument: recency
recency=10.0, recency=10.0,
# Keyword argument: severity
severity=10.0, severity=10.0,
) )
# Backward-compatible aliases for older API payloads # Backward-compatible aliases for older API payloads
@property @property
# Define function freshness
def freshness(self) -> float: def freshness(self) -> float:
"""Return the recency weight (backward-compatible alias).
Returns:
float: The value of the ``recency`` weight.
"""
# Return self.recency
return self.recency return self.recency
# Apply the @property decorator
@property @property
# Define function platform_diversity
def platform_diversity(self) -> float: def platform_diversity(self) -> float:
"""Return the severity weight (backward-compatible alias).
Returns:
float: The value of the ``severity`` weight.
"""
# Return self.severity
return self.severity return self.severity
+1
View File
@@ -0,0 +1 @@
"""Infrastructure adapters — persistence, caching, and external services."""
@@ -0,0 +1 @@
"""SQLAlchemy-based persistence adapters for the domain repository ports."""
@@ -0,0 +1 @@
"""ORM-to-domain entity mapper functions."""
@@ -1,20 +1,28 @@
"""Technique ORM model <-> domain entity mapper.""" """Technique ORM model <-> domain entity mapper."""
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import TechniqueEntity from app.domain.entities.technique
from app.domain.entities.technique import TechniqueEntity from app.domain.entities.technique import TechniqueEntity
from app.domain.enums import TechniqueStatus
# Define class TechniqueMapper
class TechniqueMapper: class TechniqueMapper:
"""Converts between SQLAlchemy Technique model and TechniqueEntity.""" """Converts between SQLAlchemy Technique model and TechniqueEntity."""
# Apply the @staticmethod decorator
@staticmethod @staticmethod
# Define function to_entity
def to_entity(model: object) -> TechniqueEntity: def to_entity(model: object) -> TechniqueEntity:
"""Convert an ORM Technique model to a domain TechniqueEntity.""" """Convert an ORM Technique model to a domain TechniqueEntity."""
# Return TechniqueEntity.from_orm(model)
return TechniqueEntity.from_orm(model) return TechniqueEntity.from_orm(model)
# Apply the @staticmethod decorator
@staticmethod @staticmethod
# Define function to_model_updates
def to_model_updates(entity: TechniqueEntity, model: object) -> None: def to_model_updates(entity: TechniqueEntity, model: object) -> None:
"""Apply entity changes back onto an existing ORM model.""" """Apply entity changes back onto an existing ORM model."""
# Call entity.apply_to()
entity.apply_to(model) entity.apply_to(model)
@@ -1,8 +1,13 @@
"""Concrete SQLAlchemy repository implementations."""
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
from app.infrastructure.persistence.repositories.sa_technique_repository import ( from app.infrastructure.persistence.repositories.sa_technique_repository import (
SATechniqueRepository, SATechniqueRepository,
) )
# Import from app.infrastructure.persistence.repositories.sa_test_repository
from app.infrastructure.persistence.repositories.sa_test_repository import ( from app.infrastructure.persistence.repositories.sa_test_repository import (
SATestRepository, SATestRepository,
) )
# Assign __all__ = ["SATechniqueRepository", "SATestRepository"]
__all__ = ["SATechniqueRepository", "SATestRepository"] __all__ = ["SATechniqueRepository", "SATestRepository"]
@@ -4,44 +4,95 @@ Receives a Session from the caller — does NOT create its own.
Does NOT call commit() — the Unit of Work owns that. Does NOT call commit() — the Unit of Work owns that.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import uuid
import uuid import uuid
# Import func from sqlalchemy
from sqlalchemy import func from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import TechniqueEntity from app.domain.entities.technique
from app.domain.entities.technique import TechniqueEntity from app.domain.entities.technique import TechniqueEntity
# Import TechniqueStatus, TestState from app.domain.enums
from app.domain.enums import TechniqueStatus, TestState from app.domain.enums import TechniqueStatus, TestState
# Import TechniqueWithCounts from app.domain.ports.repositories.technique_repository
from app.domain.ports.repositories.technique_repository import TechniqueWithCounts from app.domain.ports.repositories.technique_repository import TechniqueWithCounts
# Import TechniqueMapper from app.infrastructure.persistence.mappers.technique_mapper
from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper
# Import DetectionRule from app.models.detection_rule
from app.models.detection_rule import DetectionRule from app.models.detection_rule import DetectionRule
# Import Technique from app.models.technique
from app.models.technique import Technique from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test from app.models.test import Test
# Define class SATechniqueRepository
class SATechniqueRepository: class SATechniqueRepository:
"""Concrete repository backed by SQLAlchemy.""" """Concrete repository backed by SQLAlchemy."""
# Define function __init__
def __init__(self, session: Session) -> None: def __init__(self, session: Session) -> None:
"""Initialise the repository with a caller-provided session.
Args:
session (Session): The SQLAlchemy session to use for all queries.
"""
# Assign self._session = session
self._session = session self._session = session
# -- Single-entity access ---------------------------------------------- # -- Single-entity access ----------------------------------------------
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
"""Return a single technique by its primary key.
Args:
technique_id (uuid.UUID): The UUID primary key of the technique.
Returns:
TechniqueEntity | None: The matching entity, or ``None`` if not found.
"""
# Assign model = (
model = ( model = (
self._session.query(Technique) self._session.query(Technique)
# Chain .filter() call
.filter(Technique.id == technique_id) .filter(Technique.id == technique_id)
# Chain .first() call
.first() .first()
) )
# Return TechniqueMapper.to_entity(model) if model else None
return TechniqueMapper.to_entity(model) if model else None return TechniqueMapper.to_entity(model) if model else None
# Define function find_by_mitre_id
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
"""Return a single technique by its MITRE ATT&CK ID (e.g. ``T1059.001``).
Args:
mitre_id (str): The MITRE ATT&CK identifier string.
Returns:
TechniqueEntity | None: The matching entity, or ``None`` if not found.
"""
# Assign model = (
model = ( model = (
self._session.query(Technique) self._session.query(Technique)
# Chain .filter() call
.filter(Technique.mitre_id == mitre_id) .filter(Technique.mitre_id == mitre_id)
# Chain .first() call
.first() .first()
) )
# Return TechniqueMapper.to_entity(model) if model else None
return TechniqueMapper.to_entity(model) if model else None return TechniqueMapper.to_entity(model) if model else None
# -- List access ------------------------------------------------------- # -- List access -------------------------------------------------------
@@ -49,57 +100,111 @@ class SATechniqueRepository:
def list_all( def list_all(
self, self,
*, *,
# Entry: tactic
tactic: str | None = None, tactic: str | None = None,
# Entry: status
status: TechniqueStatus | None = None, status: TechniqueStatus | None = None,
# Entry: review_required
review_required: bool | None = None, review_required: bool | None = None,
) -> list[TechniqueEntity]: ) -> list[TechniqueEntity]:
"""Return all techniques, optionally filtered by tactic, status, or review flag.
Args:
tactic (str | None): Filter to techniques belonging to this tactic name.
status (TechniqueStatus | None): Filter to techniques with this coverage status.
review_required (bool | None): Filter to techniques where ``review_required`` matches.
Returns:
list[TechniqueEntity]: Ordered list of matching technique entities.
"""
# Assign query = self._session.query(Technique)
query = self._session.query(Technique) query = self._session.query(Technique)
# Check: tactic is not None
if tactic is not None: if tactic is not None:
# Assign query = query.filter(Technique.tactic == tactic)
query = query.filter(Technique.tactic == tactic) query = query.filter(Technique.tactic == tactic)
# Check: status is not None
if status is not None: if status is not None:
# Assign query = query.filter(Technique.status_global == status)
query = query.filter(Technique.status_global == status) query = query.filter(Technique.status_global == status)
# Check: review_required is not None
if review_required is not None: if review_required is not None:
# Assign query = query.filter(Technique.review_required == review_required)
query = query.filter(Technique.review_required == review_required) query = query.filter(Technique.review_required == review_required)
# Assign models = query.order_by(Technique.mitre_id).all()
models = query.order_by(Technique.mitre_id).all() models = query.order_by(Technique.mitre_id).all()
# Return [TechniqueMapper.to_entity(m) for m in models]
return [TechniqueMapper.to_entity(m) for m in models] return [TechniqueMapper.to_entity(m) for m in models]
# Define function list_by_ids
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
"""Return techniques matching the provided list of UUIDs.
Args:
ids (list[uuid.UUID]): UUIDs of the techniques to retrieve.
Returns:
list[TechniqueEntity]: Technique entities corresponding to the given IDs.
"""
# Check: not ids
if not ids: if not ids:
# Return []
return [] return []
# Assign models = (
models = ( models = (
self._session.query(Technique) self._session.query(Technique)
# Chain .filter() call
.filter(Technique.id.in_(ids)) .filter(Technique.id.in_(ids))
# Chain .all() call
.all() .all()
) )
# Return [TechniqueMapper.to_entity(m) for m in models]
return [TechniqueMapper.to_entity(m) for m in models] return [TechniqueMapper.to_entity(m) for m in models]
# -- Batch queries (for scoring/heatmap) ------------------------------- # -- Batch queries (for scoring/heatmap) -------------------------------
def count_by_status(self) -> dict[TechniqueStatus, int]: def count_by_status(self) -> dict[TechniqueStatus, int]:
"""Return a count of techniques grouped by their coverage status.
Returns:
dict[TechniqueStatus, int]: Mapping of each status value to its technique count.
"""
# Assign rows = (
rows = ( rows = (
self._session.query( self._session.query(
Technique.status_global, Technique.status_global,
func.count(Technique.id), func.count(Technique.id),
) )
# Chain .group_by() call
.group_by(Technique.status_global) .group_by(Technique.status_global)
# Chain .all() call
.all() .all()
) )
# Assign result = {s: 0 for s in TechniqueStatus}
result = {s: 0 for s in TechniqueStatus} result = {s: 0 for s in TechniqueStatus}
# Iterate over rows
for status_val, count in rows: for status_val, count in rows:
# Assign key = (
key = ( key = (
status_val status_val
if isinstance(status_val, TechniqueStatus) if isinstance(status_val, TechniqueStatus)
else TechniqueStatus(status_val) else TechniqueStatus(status_val)
) )
# Assign result[key] = count
result[key] = count result[key] = count
# Return result
return result return result
# Define function find_all_with_test_counts
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
"""Single query replacing the N+1 pattern. """Return all techniques with pre-aggregated test and detection rule counts.
Returns all techniques with pre-aggregated test and detection Uses a single query with subqueries to avoid the N+1 pattern.
rule counts via subqueries.
Returns:
list[TechniqueWithCounts]: All techniques with their associated counts.
""" """
# Assign test_count_sq = (
test_count_sq = ( test_count_sq = (
self._session.query( self._session.query(
Test.technique_id, Test.technique_id,
@@ -108,18 +213,24 @@ class SATechniqueRepository:
func.cast(Test.state == TestState.validated, self._int_type()) func.cast(Test.state == TestState.validated, self._int_type())
).label("validated_count"), ).label("validated_count"),
) )
# Chain .group_by() call
.group_by(Test.technique_id) .group_by(Test.technique_id)
# Chain .subquery() call
.subquery() .subquery()
) )
# Assign rule_count_sq = (
rule_count_sq = ( rule_count_sq = (
self._session.query( self._session.query(
DetectionRule.mitre_technique_id, DetectionRule.mitre_technique_id,
func.count(DetectionRule.id).label("rule_count"), func.count(DetectionRule.id).label("rule_count"),
) )
# Chain .group_by() call
.group_by(DetectionRule.mitre_technique_id) .group_by(DetectionRule.mitre_technique_id)
# Chain .subquery() call
.subquery() .subquery()
) )
# Assign rows = (
rows = ( rows = (
self._session.query( self._session.query(
Technique, Technique,
@@ -127,20 +238,29 @@ class SATechniqueRepository:
func.coalesce(test_count_sq.c.validated_count, 0), func.coalesce(test_count_sq.c.validated_count, 0),
func.coalesce(rule_count_sq.c.rule_count, 0), func.coalesce(rule_count_sq.c.rule_count, 0),
) )
# Chain .outerjoin() call
.outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id) .outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id)
# Chain .outerjoin() call
.outerjoin( .outerjoin(
rule_count_sq, rule_count_sq,
Technique.mitre_id == rule_count_sq.c.mitre_technique_id, Technique.mitre_id == rule_count_sq.c.mitre_technique_id,
) )
# Chain .order_by() call
.order_by(Technique.mitre_id) .order_by(Technique.mitre_id)
# Chain .all() call
.all() .all()
) )
# Return [
return [ return [
TechniqueWithCounts( TechniqueWithCounts(
# Keyword argument: entity
entity=TechniqueMapper.to_entity(tech), entity=TechniqueMapper.to_entity(tech),
# Keyword argument: test_count
test_count=int(tc), test_count=int(tc),
# Keyword argument: validated_test_count
validated_test_count=int(vtc), validated_test_count=int(vtc),
# Keyword argument: detection_rule_count
detection_rule_count=int(rc), detection_rule_count=int(rc),
) )
for tech, tc, vtc, rc in rows for tech, tc, vtc, rc in rows
@@ -149,55 +269,112 @@ class SATechniqueRepository:
# -- Mutations --------------------------------------------------------- # -- Mutations ---------------------------------------------------------
def save(self, technique: TechniqueEntity) -> TechniqueEntity: def save(self, technique: TechniqueEntity) -> TechniqueEntity:
"""Persist a technique entity, inserting or updating as needed.
Args:
technique (TechniqueEntity): The domain entity to persist.
Returns:
TechniqueEntity: The persisted entity reflecting the current DB state.
"""
# Assign existing = (
existing = ( existing = (
self._session.query(Technique) self._session.query(Technique)
# Chain .filter() call
.filter(Technique.id == technique.id) .filter(Technique.id == technique.id)
# Chain .first() call
.first() .first()
) )
# Check: existing
if existing: if existing:
# Call technique.apply_to()
technique.apply_to(existing) technique.apply_to(existing)
# Assign existing.mitre_id = technique.mitre_id
existing.mitre_id = technique.mitre_id existing.mitre_id = technique.mitre_id
# Assign existing.name = technique.name
existing.name = technique.name existing.name = technique.name
# Assign existing.tactic = technique.tactic
existing.tactic = technique.tactic existing.tactic = technique.tactic
# Assign existing.description = technique.description
existing.description = technique.description existing.description = technique.description
# Assign existing.platforms = technique.platforms
existing.platforms = technique.platforms existing.platforms = technique.platforms
# Assign existing.is_subtechnique = technique.is_subtechnique
existing.is_subtechnique = technique.is_subtechnique existing.is_subtechnique = technique.is_subtechnique
# Assign existing.parent_mitre_id = technique.parent_mitre_id
existing.parent_mitre_id = technique.parent_mitre_id existing.parent_mitre_id = technique.parent_mitre_id
# Assign existing.mitre_version = technique.mitre_version
existing.mitre_version = technique.mitre_version existing.mitre_version = technique.mitre_version
# Assign existing.mitre_last_modified = technique.mitre_last_modified
existing.mitre_last_modified = technique.mitre_last_modified existing.mitre_last_modified = technique.mitre_last_modified
# Call self._session.flush()
self._session.flush() self._session.flush()
# Return TechniqueMapper.to_entity(existing)
return TechniqueMapper.to_entity(existing) return TechniqueMapper.to_entity(existing)
# Fallback: handle remaining cases
else: else:
# Assign model = Technique(
model = Technique( model = Technique(
# Keyword argument: id
id=technique.id, id=technique.id,
# Keyword argument: mitre_id
mitre_id=technique.mitre_id, mitre_id=technique.mitre_id,
# Keyword argument: name
name=technique.name, name=technique.name,
# Keyword argument: tactic
tactic=technique.tactic, tactic=technique.tactic,
# Keyword argument: description
description=technique.description, description=technique.description,
# Keyword argument: platforms
platforms=technique.platforms, platforms=technique.platforms,
# Keyword argument: is_subtechnique
is_subtechnique=technique.is_subtechnique, is_subtechnique=technique.is_subtechnique,
# Keyword argument: parent_mitre_id
parent_mitre_id=technique.parent_mitre_id, parent_mitre_id=technique.parent_mitre_id,
# Keyword argument: status_global
status_global=technique.status_global, status_global=technique.status_global,
# Keyword argument: review_required
review_required=technique.review_required, review_required=technique.review_required,
# Keyword argument: last_review_date
last_review_date=technique.last_review_date, last_review_date=technique.last_review_date,
# Keyword argument: mitre_version
mitre_version=technique.mitre_version, mitre_version=technique.mitre_version,
# Keyword argument: mitre_last_modified
mitre_last_modified=technique.mitre_last_modified, mitre_last_modified=technique.mitre_last_modified,
) )
# Call self._session.add()
self._session.add(model) self._session.add(model)
# Call self._session.flush()
self._session.flush() self._session.flush()
# Return TechniqueMapper.to_entity(model)
return TechniqueMapper.to_entity(model) return TechniqueMapper.to_entity(model)
# Define function exists_by_mitre_id
def exists_by_mitre_id(self, mitre_id: str) -> bool: def exists_by_mitre_id(self, mitre_id: str) -> bool:
"""Check whether a technique with the given MITRE ID already exists.
Args:
mitre_id (str): The MITRE ATT&CK identifier to look up.
Returns:
bool: ``True`` if the technique exists, ``False`` otherwise.
"""
# Return (
return ( return (
self._session.query(Technique.id) self._session.query(Technique.id)
# Chain .filter() call
.filter(Technique.mitre_id == mitre_id) .filter(Technique.mitre_id == mitre_id)
# Chain .first() call
.first() .first()
) is not None ) is not None
# -- Internal ---------------------------------------------------------- # -- Internal ----------------------------------------------------------
@staticmethod @staticmethod
def _int_type(): # Define function _int_type
def _int_type() -> type:
"""Return an Integer type for CAST expressions (SQLite-compatible).""" """Return an Integer type for CAST expressions (SQLite-compatible)."""
# Import Integer from sqlalchemy
from sqlalchemy import Integer from sqlalchemy import Integer
# Return Integer
return Integer return Integer
@@ -1,78 +1,163 @@
"""SQLAlchemy implementation of TestRepository.""" """SQLAlchemy implementation of TestRepository."""
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import uuid
import uuid import uuid
# Import func from sqlalchemy
from sqlalchemy import func from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import TestState from app.domain.enums
from app.domain.enums import TestState from app.domain.enums import TestState
# Import Test from app.models.test
from app.models.test import Test from app.models.test import Test
# Define class SATestRepository
class SATestRepository: class SATestRepository:
"""Concrete test repository backed by SQLAlchemy.""" """Concrete test repository backed by SQLAlchemy."""
# Define function __init__
def __init__(self, session: Session) -> None: def __init__(self, session: Session) -> None:
"""Initialise the repository with a caller-provided session.
Args:
session (Session): The SQLAlchemy session to use for all queries.
"""
# Assign self._session = session
self._session = session self._session = session
# Define function find_by_id
def find_by_id(self, test_id: uuid.UUID) -> Test | None: def find_by_id(self, test_id: uuid.UUID) -> Test | None:
"""Return a single test by its primary key.
Args:
test_id (uuid.UUID): The UUID primary key of the test.
Returns:
Test | None: The ORM model instance, or ``None`` if not found.
"""
# Return (
return ( return (
self._session.query(Test) self._session.query(Test)
# Chain .filter() call
.filter(Test.id == test_id) .filter(Test.id == test_id)
# Chain .first() call
.first() .first()
) )
# Define function list_by_technique
def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]: def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]:
"""Return all tests for a given technique, ordered by creation date.
Args:
technique_id (uuid.UUID): The UUID of the parent technique.
Returns:
list[Test]: ORM model instances ordered by ``created_at`` ascending.
"""
# Return (
return ( return (
self._session.query(Test) self._session.query(Test)
# Chain .filter() call
.filter(Test.technique_id == technique_id) .filter(Test.technique_id == technique_id)
# Chain .order_by() call
.order_by(Test.created_at) .order_by(Test.created_at)
# Chain .all() call
.all() .all()
) )
# Define function list_by_state
def list_by_state(self, state: TestState) -> list[Test]: def list_by_state(self, state: TestState) -> list[Test]:
"""Return all tests that are currently in the given workflow state.
Args:
state (TestState): The workflow state to filter on.
Returns:
list[Test]: All ORM model instances with the specified state.
"""
# Return (
return ( return (
self._session.query(Test) self._session.query(Test)
# Chain .filter() call
.filter(Test.state == state) .filter(Test.state == state)
# Chain .all() call
.all() .all()
) )
# Define function count_by_technique_and_state
def count_by_technique_and_state( def count_by_technique_and_state(
self, self,
# Entry: technique_id
technique_id: uuid.UUID, technique_id: uuid.UUID,
) -> dict[TestState, int]: ) -> dict[TestState, int]:
"""Return per-state test counts for a specific technique.
Args:
technique_id (uuid.UUID): The UUID of the technique to aggregate for.
Returns:
dict[TestState, int]: Mapping of each state to the number of tests in that state.
"""
# Assign rows = (
rows = ( rows = (
self._session.query(Test.state, func.count(Test.id)) self._session.query(Test.state, func.count(Test.id))
# Chain .filter() call
.filter(Test.technique_id == technique_id) .filter(Test.technique_id == technique_id)
# Chain .group_by() call
.group_by(Test.state) .group_by(Test.state)
# Chain .all() call
.all() .all()
) )
# Assign result = {}
result: dict[TestState, int] = {} result: dict[TestState, int] = {}
# Iterate over rows
for state_val, count in rows: for state_val, count in rows:
# Assign key = (
key = ( key = (
state_val state_val
if isinstance(state_val, TestState) if isinstance(state_val, TestState)
else TestState(state_val) else TestState(state_val)
) )
# Assign result[key] = count
result[key] = count result[key] = count
# Return result
return result return result
# Define function get_states_and_results_for_technique
def get_states_and_results_for_technique( def get_states_and_results_for_technique(
self, self,
# Entry: technique_id
technique_id: uuid.UUID, technique_id: uuid.UUID,
) -> list[tuple[str, str | None]]: ) -> list[tuple[str, str | None]]:
"""Return lightweight (state, detection_result) pairs. """Return lightweight ``(state, detection_result)`` pairs for a technique.
Used by TechniqueEntity.recalculate_status() without loading Used by ``TechniqueEntity.recalculate_status()`` to avoid loading full
full Test models. ``Test`` models.
Args:
technique_id (uuid.UUID): The UUID of the technique to query.
Returns:
list[tuple[str, str | None]]: Each tuple contains the state string
and the detection result string (or ``None``).
""" """
# Assign rows = (
rows = ( rows = (
self._session.query(Test.state, Test.detection_result) self._session.query(Test.state, Test.detection_result)
# Chain .filter() call
.filter(Test.technique_id == technique_id) .filter(Test.technique_id == technique_id)
# Chain .all() call
.all() .all()
) )
# Return [
return [ return [
( (
r.state.value if hasattr(r.state, "value") else str(r.state), r.state.value if hasattr(r.state, "value") else str(r.state),
@@ -13,54 +13,79 @@ Usage::
get_redis_blacklist().setex("blacklist:…", ttl, "1") get_redis_blacklist().setex("blacklist:…", ttl, "1")
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import logging
import logging import logging
# Import urlparse, urlunparse from urllib.parse
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
# Import redis
import redis import redis
# Import settings from app.config
from app.config import settings from app.config import settings
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Assign _clients = {}
_clients: dict[str, redis.Redis] = {} _clients: dict[str, redis.Redis] = {}
# Define function _redis_url_with_db
def _redis_url_with_db(base_url: str, db_index: int) -> str: def _redis_url_with_db(base_url: str, db_index: int) -> str:
"""Return *base_url* with its path replaced by ``/{db_index}``.""" """Return *base_url* with its path replaced by ``/{db_index}``."""
# Assign parsed = urlparse(base_url)
parsed = urlparse(base_url) parsed = urlparse(base_url)
# Assign path = f"/{db_index}"
path = f"/{db_index}" path = f"/{db_index}"
# Return urlunparse(
return urlunparse( return urlunparse(
(parsed.scheme, parsed.netloc, path, "", "", ""), (parsed.scheme, parsed.netloc, path, "", "", ""),
) )
# Define function _get_client
def _get_client(url: str) -> redis.Redis: def _get_client(url: str) -> redis.Redis:
# Check: url not in _clients
if url not in _clients: if url not in _clients:
# Assign _clients[url] = redis.from_url(url, decode_responses=True)
_clients[url] = redis.from_url(url, decode_responses=True) _clients[url] = redis.from_url(url, decode_responses=True)
# Log info: "Redis client connected to %s", url
logger.info("Redis client connected to %s", url) logger.info("Redis client connected to %s", url)
# Return _clients[url]
return _clients[url] return _clients[url]
# Define function get_redis
def get_redis() -> redis.Redis: def get_redis() -> redis.Redis:
"""Default Redis connection (URL from ``settings.REDIS_URL``).""" """Default Redis connection (URL from ``settings.REDIS_URL``)."""
# Return _get_client(settings.REDIS_URL)
return _get_client(settings.REDIS_URL) return _get_client(settings.REDIS_URL)
# Define function get_redis_blacklist
def get_redis_blacklist() -> redis.Redis: def get_redis_blacklist() -> redis.Redis:
"""Redis DB used for JWT revocation (``jti`` keys with TTL).""" """Redis DB used for JWT revocation (``jti`` keys with TTL)."""
# Assign url = _redis_url_with_db(
url = _redis_url_with_db( url = _redis_url_with_db(
settings.REDIS_URL, settings.REDIS_URL,
settings.REDIS_TOKEN_BLACKLIST_DB, settings.REDIS_TOKEN_BLACKLIST_DB,
) )
# Return _get_client(url)
return _get_client(url) return _get_client(url)
# Define function get_redis_cache
def get_redis_cache() -> redis.Redis: def get_redis_cache() -> redis.Redis:
"""Redis DB reserved for shared cache (scores, queues, etc.).""" """Redis DB reserved for shared cache (scores, queues, etc.)."""
# Assign url = _redis_url_with_db(
url = _redis_url_with_db( url = _redis_url_with_db(
settings.REDIS_URL, settings.REDIS_URL,
settings.REDIS_CACHE_DB, settings.REDIS_CACHE_DB,
) )
# Return _get_client(url)
return _get_client(url) return _get_client(url)
+1
View File
@@ -0,0 +1 @@
"""Background scheduler jobs (MITRE sync, Jira sync, data retention)."""
+28
View File
@@ -1,37 +1,65 @@
"""Scheduled job — syncs all Jira links hourly.""" """Scheduled job — syncs all Jira links hourly."""
# Import logging
import logging import logging
# Import settings from app.config
from app.config import settings from app.config import settings
# Import SessionLocal from app.database
from app.database import SessionLocal from app.database import SessionLocal
# Import JiraLink from app.models.jira_link
from app.models.jira_link import JiraLink from app.models.jira_link import JiraLink
# Import jira_service from app.services
from app.services import jira_service from app.services import jira_service
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define function sync_all_jira_links
def sync_all_jira_links() -> None: def sync_all_jira_links() -> None:
"""Pull latest status from Jira for every stored link. """Pull latest status from Jira for every stored link.
Silently skips if ``JIRA_ENABLED`` is ``False``. Individual link Silently skips if ``JIRA_ENABLED`` is ``False``. Individual link
failures are logged but do not abort the rest of the batch. failures are logged but do not abort the rest of the batch.
""" """
# Check: not settings.JIRA_ENABLED
if not settings.JIRA_ENABLED: if not settings.JIRA_ENABLED:
# Return control to caller
return return
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Assign links = db.query(JiraLink).all()
links = db.query(JiraLink).all() links = db.query(JiraLink).all()
# Assign synced = 0
synced = 0 synced = 0
# Iterate over links
for link in links: for link in links:
# Attempt the following; catch errors below
try: try:
# Call jira_service.sync_jira_to_aegis()
jira_service.sync_jira_to_aegis(db, link) jira_service.sync_jira_to_aegis(db, link)
# Assign synced = 1
synced += 1 synced += 1
# Handle Exception
except Exception as e: except Exception as e:
# Log warning: "Jira sync failed for link %s: %s", link.id, e
logger.warning("Jira sync failed for link %s: %s", link.id, e) logger.warning("Jira sync failed for link %s: %s", link.id, e)
# Commit all pending changes to the database
db.commit() db.commit()
# Log info: "Jira sync completed: %d/%d links updated", synced
logger.info("Jira sync completed: %d/%d links updated", synced, len(links)) logger.info("Jira sync completed: %d/%d links updated", synced, len(links))
# Handle Exception
except Exception: except Exception:
# Log exception: "Jira sync batch job failed"
logger.exception("Jira sync batch job failed") logger.exception("Jira sync batch job failed")
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
+165 -7
View File
@@ -10,21 +10,43 @@ Each job manages its own database session (created on entry, closed in
sessions. sessions.
""" """
# Import logging
import logging import logging
# Import BackgroundScheduler from apscheduler.schedulers.background
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
# Import SessionLocal from app.database
from app.database import SessionLocal from app.database import SessionLocal
from app.services.mitre_sync_service import sync_mitre
from app.services.intel_service import scan_intel # Import sync_all_jira_links from app.jobs.jira_sync_job
from app.services.notification_service import cleanup_old_notifications
from app.services.snapshot_service import create_snapshot, cleanup_old_snapshots
from app.services.campaign_scheduler_service import check_and_run_recurring_campaigns
from app.jobs.jira_sync_job import sync_all_jira_links from app.jobs.jira_sync_job import sync_all_jira_links
from app.services.osint_enrichment_service import enrich_all_techniques
from app.services.stale_detection_service import detect_stale_coverage # Import run_retention_job from app.jobs.retention_job
from app.jobs.retention_job import run_retention_job from app.jobs.retention_job import run_retention_job
# Import check_and_run_recurring_campaigns from app.services.campaign_scheduler_service
from app.services.campaign_scheduler_service import check_and_run_recurring_campaigns
# Import scan_intel from app.services.intel_service
from app.services.intel_service import scan_intel
# Import sync_mitre from app.services.mitre_sync_service
from app.services.mitre_sync_service import sync_mitre
# Import cleanup_old_notifications from app.services.notification_service
from app.services.notification_service import cleanup_old_notifications
# Import enrich_all_techniques from app.services.osint_enrichment_service
from app.services.osint_enrichment_service import enrich_all_techniques
# Import cleanup_old_snapshots, create_snapshot from app.services.snapshot_service
from app.services.snapshot_service import cleanup_old_snapshots, create_snapshot
# Import detect_stale_coverage from app.services.stale_detection_service
from app.services.stale_detection_service import detect_stale_coverage
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -41,99 +63,172 @@ scheduler = BackgroundScheduler()
def _run_mitre_sync() -> None: def _run_mitre_sync() -> None:
"""Execute a MITRE sync inside its own DB session.""" """Execute a MITRE sync inside its own DB session."""
# Log info: "Scheduled MITRE sync job starting..."
logger.info("Scheduled MITRE sync job starting...") logger.info("Scheduled MITRE sync job starting...")
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Assign summary = sync_mitre(db)
summary = sync_mitre(db) summary = sync_mitre(db)
# Log info: "Scheduled MITRE sync job finished — %s", summary
logger.info("Scheduled MITRE sync job finished — %s", summary) logger.info("Scheduled MITRE sync job finished — %s", summary)
# Handle Exception
except Exception: except Exception:
# Log exception: "Scheduled MITRE sync job failed"
logger.exception("Scheduled MITRE sync job failed") logger.exception("Scheduled MITRE sync job failed")
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
# Define function _run_notification_cleanup
def _run_notification_cleanup() -> None: def _run_notification_cleanup() -> None:
"""Clean up old read notifications.""" """Clean up old read notifications."""
# Log info: "Scheduled notification cleanup job starting..."
logger.info("Scheduled notification cleanup job starting...") logger.info("Scheduled notification cleanup job starting...")
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Assign deleted = cleanup_old_notifications(db, days=90)
deleted = cleanup_old_notifications(db, days=90) deleted = cleanup_old_notifications(db, days=90)
# Log info: "Notification cleanup finished — deleted %d old no
logger.info("Notification cleanup finished — deleted %d old notifications", deleted) logger.info("Notification cleanup finished — deleted %d old notifications", deleted)
# Handle Exception
except Exception: except Exception:
# Log exception: "Notification cleanup job failed"
logger.exception("Notification cleanup job failed") logger.exception("Notification cleanup job failed")
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
# Define function _run_weekly_snapshot
def _run_weekly_snapshot() -> None: def _run_weekly_snapshot() -> None:
"""Create a weekly coverage snapshot and clean up old ones.""" """Create a weekly coverage snapshot and clean up old ones."""
# Log info: "Scheduled weekly snapshot job starting..."
logger.info("Scheduled weekly snapshot job starting...") logger.info("Scheduled weekly snapshot job starting...")
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Assign snapshot = create_snapshot(db, name="Auto-weekly")
snapshot = create_snapshot(db, name="Auto-weekly") snapshot = create_snapshot(db, name="Auto-weekly")
# Log info:
logger.info( logger.info(
# Literal argument value
"Weekly snapshot created — score %.1f, %d techniques", "Weekly snapshot created — score %.1f, %d techniques",
snapshot.organization_score, snapshot.organization_score,
snapshot.total_techniques, snapshot.total_techniques,
) )
# Assign deleted = cleanup_old_snapshots(db, keep_last=52)
deleted = cleanup_old_snapshots(db, keep_last=52) deleted = cleanup_old_snapshots(db, keep_last=52)
# Check: deleted
if deleted: if deleted:
# Log info: "Cleaned up %d old snapshots", deleted
logger.info("Cleaned up %d old snapshots", deleted) logger.info("Cleaned up %d old snapshots", deleted)
# Handle Exception
except Exception: except Exception:
# Log exception: "Weekly snapshot job failed"
logger.exception("Weekly snapshot job failed") logger.exception("Weekly snapshot job failed")
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
# Define function _run_recurring_campaigns
def _run_recurring_campaigns() -> None: def _run_recurring_campaigns() -> None:
"""Check and run any due recurring campaigns.""" """Check and run any due recurring campaigns."""
# Log info: "Scheduled recurring campaigns check starting..."
logger.info("Scheduled recurring campaigns check starting...") logger.info("Scheduled recurring campaigns check starting...")
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Assign spawned = check_and_run_recurring_campaigns(db)
spawned = check_and_run_recurring_campaigns(db) spawned = check_and_run_recurring_campaigns(db)
# Log info: "Recurring campaigns check finished — spawned %d c
logger.info("Recurring campaigns check finished — spawned %d campaigns", spawned) logger.info("Recurring campaigns check finished — spawned %d campaigns", spawned)
# Handle Exception
except Exception: except Exception:
# Log exception: "Recurring campaigns check failed"
logger.exception("Recurring campaigns check failed") logger.exception("Recurring campaigns check failed")
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
# Define function _run_intel_scan
def _run_intel_scan() -> None: def _run_intel_scan() -> None:
"""Execute an intel scan inside its own DB session.""" """Execute an intel scan inside its own DB session."""
# Log info: "Scheduled intel scan job starting..."
logger.info("Scheduled intel scan job starting...") logger.info("Scheduled intel scan job starting...")
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Assign summary = scan_intel(db)
summary = scan_intel(db) summary = scan_intel(db)
# Log info: "Scheduled intel scan job finished — %s", summary
logger.info("Scheduled intel scan job finished — %s", summary) logger.info("Scheduled intel scan job finished — %s", summary)
# Handle Exception
except Exception: except Exception:
# Log exception: "Scheduled intel scan job failed"
logger.exception("Scheduled intel scan job failed") logger.exception("Scheduled intel scan job failed")
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
# Define function _run_osint_enrichment
def _run_osint_enrichment() -> None: def _run_osint_enrichment() -> None:
"""Execute weekly OSINT enrichment inside its own DB session.""" """Execute weekly OSINT enrichment inside its own DB session."""
# Log info: "Scheduled OSINT enrichment job starting..."
logger.info("Scheduled OSINT enrichment job starting...") logger.info("Scheduled OSINT enrichment job starting...")
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Assign total = enrich_all_techniques(db)
total = enrich_all_techniques(db) total = enrich_all_techniques(db)
# Log info: "OSINT enrichment finished — %d new items", total
logger.info("OSINT enrichment finished — %d new items", total) logger.info("OSINT enrichment finished — %d new items", total)
# Handle Exception
except Exception: except Exception:
# Log exception: "OSINT enrichment job failed"
logger.exception("OSINT enrichment job failed") logger.exception("OSINT enrichment job failed")
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
# Define function _run_stale_detection
def _run_stale_detection() -> None: def _run_stale_detection() -> None:
"""Execute daily stale coverage detection inside its own DB session.""" """Execute daily stale coverage detection inside its own DB session."""
# Log info: "Scheduled stale coverage detection starting..."
logger.info("Scheduled stale coverage detection starting...") logger.info("Scheduled stale coverage detection starting...")
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Assign count = detect_stale_coverage(db)
count = detect_stale_coverage(db) count = detect_stale_coverage(db)
# Log info: "Stale detection finished — %d techniques flagged"
logger.info("Stale detection finished — %d techniques flagged", count) logger.info("Stale detection finished — %d techniques flagged", count)
# Handle Exception
except Exception: except Exception:
# Log exception: "Stale coverage detection job failed"
logger.exception("Stale coverage detection job failed") logger.exception("Stale coverage detection job failed")
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
@@ -152,85 +247,148 @@ def start_scheduler() -> None:
Neither job fires immediately on startup. Neither job fires immediately on startup.
""" """
# Call scheduler.add_job()
scheduler.add_job( scheduler.add_job(
_run_mitre_sync, _run_mitre_sync,
# Keyword argument: trigger
trigger="interval", trigger="interval",
# Keyword argument: hours
hours=24, hours=24,
# Keyword argument: id
id="mitre_sync", id="mitre_sync",
# Keyword argument: name
name="MITRE ATT&CK sync (every 24h)", name="MITRE ATT&CK sync (every 24h)",
# Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
# Call scheduler.add_job()
scheduler.add_job( scheduler.add_job(
_run_intel_scan, _run_intel_scan,
# Keyword argument: trigger
trigger="interval", trigger="interval",
# Keyword argument: weeks
weeks=1, weeks=1,
# Keyword argument: id
id="intel_scan", id="intel_scan",
# Keyword argument: name
name="Intel scan (every 7d)", name="Intel scan (every 7d)",
# Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
# Call scheduler.add_job()
scheduler.add_job( scheduler.add_job(
_run_notification_cleanup, _run_notification_cleanup,
# Keyword argument: trigger
trigger="interval", trigger="interval",
# Keyword argument: hours
hours=24, hours=24,
# Keyword argument: id
id="notification_cleanup", id="notification_cleanup",
# Keyword argument: name
name="Notification cleanup (daily)", name="Notification cleanup (daily)",
# Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
# Call scheduler.add_job()
scheduler.add_job( scheduler.add_job(
_run_weekly_snapshot, _run_weekly_snapshot,
# Keyword argument: trigger
trigger="cron", trigger="cron",
# Keyword argument: day_of_week
day_of_week="sun", day_of_week="sun",
# Keyword argument: hour
hour=0, hour=0,
# Keyword argument: minute
minute=0, minute=0,
# Keyword argument: id
id="weekly_snapshot", id="weekly_snapshot",
# Keyword argument: name
name="Weekly coverage snapshot (Sundays 00:00)", name="Weekly coverage snapshot (Sundays 00:00)",
# Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
# Call scheduler.add_job()
scheduler.add_job( scheduler.add_job(
_run_recurring_campaigns, _run_recurring_campaigns,
# Keyword argument: trigger
trigger="interval", trigger="interval",
# Keyword argument: hours
hours=24, hours=24,
# Keyword argument: id
id="recurring_campaigns", id="recurring_campaigns",
# Keyword argument: name
name="Recurring campaigns check (daily)", name="Recurring campaigns check (daily)",
# Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
# Call scheduler.add_job()
scheduler.add_job( scheduler.add_job(
sync_all_jira_links, sync_all_jira_links,
# Keyword argument: trigger
trigger="interval", trigger="interval",
# Keyword argument: hours
hours=1, hours=1,
# Keyword argument: id
id="jira_sync", id="jira_sync",
# Keyword argument: name
name="Jira link sync (hourly)", name="Jira link sync (hourly)",
# Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
# Call scheduler.add_job()
scheduler.add_job( scheduler.add_job(
_run_osint_enrichment, _run_osint_enrichment,
# Keyword argument: trigger
trigger="interval", trigger="interval",
# Keyword argument: weeks
weeks=1, weeks=1,
# Keyword argument: id
id="osint_enrichment", id="osint_enrichment",
# Keyword argument: name
name="OSINT enrichment (weekly)", name="OSINT enrichment (weekly)",
# Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
# Call scheduler.add_job()
scheduler.add_job( scheduler.add_job(
_run_stale_detection, _run_stale_detection,
# Keyword argument: trigger
trigger="interval", trigger="interval",
# Keyword argument: hours
hours=24, hours=24,
# Keyword argument: id
id="stale_detection", id="stale_detection",
# Keyword argument: name
name="Stale coverage detection (daily)", name="Stale coverage detection (daily)",
# Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
# Call scheduler.add_job()
scheduler.add_job( scheduler.add_job(
run_retention_job, run_retention_job,
# Keyword argument: trigger
trigger="interval", trigger="interval",
# Keyword argument: hours
hours=24, hours=24,
# Keyword argument: id
id="retention_policies", id="retention_policies",
# Keyword argument: name
name="Data retention policies (daily)", name="Data retention policies (daily)",
# Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
# Call scheduler.start()
scheduler.start() scheduler.start()
# Log info:
logger.info( logger.info(
# Literal argument value
"Background scheduler started — mitre_sync (24h), intel_scan (7d), " "Background scheduler started — mitre_sync (24h), intel_scan (7d), "
# Literal argument value
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), " "notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
# Literal argument value
"recurring_campaigns (daily), jira_sync (1h), " "recurring_campaigns (daily), jira_sync (1h), "
# Literal argument value
"osint_enrichment (weekly), stale_detection (daily), " "osint_enrichment (weekly), stale_detection (daily), "
# Literal argument value
"retention_policies (daily)" "retention_policies (daily)"
) )
+36
View File
@@ -1,53 +1,89 @@
"""Data retention policies — scheduled cleanup of aged records.""" """Data retention policies — scheduled cleanup of aged records."""
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import logging
import logging import logging
# Import datetime, timedelta, timezone from datetime
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import SessionLocal from app.database
from app.database import SessionLocal from app.database import SessionLocal
# Import AuditLog from app.models.audit
from app.models.audit import AuditLog from app.models.audit import AuditLog
# Import cleanup_old_notifications from app.services.notification_service
from app.services.notification_service import cleanup_old_notifications from app.services.notification_service import cleanup_old_notifications
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Assign AUDIT_LOG_RETENTION_DAYS = 730
AUDIT_LOG_RETENTION_DAYS = 730 AUDIT_LOG_RETENTION_DAYS = 730
# Define function apply_retention_policies
def apply_retention_policies(db: Session) -> dict[str, int]: def apply_retention_policies(db: Session) -> dict[str, int]:
"""Apply retention rules. Commits the session before returning.""" """Apply retention rules. Commits the session before returning."""
# Assign cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS) cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
# Assign deleted_audit = (
deleted_audit = ( deleted_audit = (
db.query(AuditLog) db.query(AuditLog)
# Chain .filter() call
.filter(AuditLog.timestamp < cutoff) .filter(AuditLog.timestamp < cutoff)
# Chain .delete() call
.delete(synchronize_session=False) .delete(synchronize_session=False)
) )
# Check: deleted_audit
if deleted_audit: if deleted_audit:
# Log info:
logger.info( logger.info(
# Literal argument value
"Retention: deleted %d audit logs older than %d days", "Retention: deleted %d audit logs older than %d days",
deleted_audit, deleted_audit,
AUDIT_LOG_RETENTION_DAYS, AUDIT_LOG_RETENTION_DAYS,
) )
# Assign deleted_notifications = cleanup_old_notifications(db, days=90)
deleted_notifications = cleanup_old_notifications(db, days=90) deleted_notifications = cleanup_old_notifications(db, days=90)
# Commit all pending changes to the database
db.commit() db.commit()
# Return {
return { return {
# Literal argument value
"audit_logs_deleted": deleted_audit, "audit_logs_deleted": deleted_audit,
# Literal argument value
"notifications_deleted": deleted_notifications, "notifications_deleted": deleted_notifications,
} }
# Define function run_retention_job
def run_retention_job() -> None: def run_retention_job() -> None:
"""Entry point for the daily retention scheduler job.""" """Entry point for the daily retention scheduler job."""
# Log info: "Scheduled retention job starting..."
logger.info("Scheduled retention job starting...") logger.info("Scheduled retention job starting...")
# Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
# Attempt the following; catch errors below
try: try:
# Assign summary = apply_retention_policies(db)
summary = apply_retention_policies(db) summary = apply_retention_policies(db)
# Log info: "Retention job finished — %s", summary
logger.info("Retention job finished — %s", summary) logger.info("Retention job finished — %s", summary)
# Handle Exception
except Exception: except Exception:
# Log exception: "Retention job failed"
logger.exception("Retention job failed") logger.exception("Retention job failed")
# Roll back all uncommitted changes
db.rollback() db.rollback()
# Always execute this cleanup block
finally: finally:
# Close the database session
db.close() db.close()
+4
View File
@@ -1,6 +1,10 @@
"""Shared SlowAPI rate limiter for all routers.""" """Shared SlowAPI rate limiter for all routers."""
# Import Limiter from slowapi
from slowapi import Limiter from slowapi import Limiter
# Import get_remote_address from slowapi.util
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
# Assign limiter = Limiter(key_func=get_remote_address)
limiter = Limiter(key_func=get_remote_address) limiter = Limiter(key_func=get_remote_address)
+41
View File
@@ -8,60 +8,101 @@ In **development** (default), uses a human-readable text format for
comfortable local work. comfortable local work.
""" """
# Enable future language features for compatibility
from __future__ import annotations from __future__ import annotations
# Import json
import json import json
# Import logging
import logging import logging
# Import os
import os import os
# Import sys
import sys import sys
# Import datetime, timezone from datetime
from datetime import datetime, timezone from datetime import datetime, timezone
# Define class _JSONFormatter
class _JSONFormatter(logging.Formatter): class _JSONFormatter(logging.Formatter):
"""Emit each log record as a single-line JSON object.""" """Emit each log record as a single-line JSON object."""
# Define function format
def format(self, record: logging.LogRecord) -> str: def format(self, record: logging.LogRecord) -> str:
# Assign payload = {
payload: dict = { payload: dict = {
# Literal argument value
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(), "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
# Literal argument value
"level": record.levelname, "level": record.levelname,
# Literal argument value
"logger": record.name, "logger": record.name,
# Literal argument value
"message": record.getMessage(), "message": record.getMessage(),
} }
# Check: record.exc_info and record.exc_info[1] is not None
if record.exc_info and record.exc_info[1] is not None: if record.exc_info and record.exc_info[1] is not None:
# Assign payload["exception"] = self.formatException(record.exc_info)
payload["exception"] = self.formatException(record.exc_info) payload["exception"] = self.formatException(record.exc_info)
# Assign extra = getattr(record, "_extra", None)
extra = getattr(record, "_extra", None) extra = getattr(record, "_extra", None)
# Check: extra
if extra: if extra:
# Call payload.update()
payload.update(extra) payload.update(extra)
# Return json.dumps(payload, default=str)
return json.dumps(payload, default=str) return json.dumps(payload, default=str)
# Assign _DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
_DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s%(message)s" _DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s%(message)s"
# Define function setup_logging
def setup_logging() -> None: def setup_logging() -> None:
"""Configure the root logger based on the environment.""" """Configure the root logger based on the environment."""
# Assign is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
is_production = os.environ.get("AEGIS_ENV", "").lower() == "production" is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
# Assign level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
level_name = os.environ.get("LOG_LEVEL", "INFO").upper() level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
# Assign level = getattr(logging, level_name, logging.INFO)
level = getattr(logging, level_name, logging.INFO) level = getattr(logging, level_name, logging.INFO)
# Assign root = logging.getLogger()
root = logging.getLogger() root = logging.getLogger()
# Call root.setLevel()
root.setLevel(level) root.setLevel(level)
# Check: root.handlers
if root.handlers: if root.handlers:
# Call root.handlers.clear()
root.handlers.clear() root.handlers.clear()
# Assign handler = logging.StreamHandler(sys.stdout)
handler = logging.StreamHandler(sys.stdout) handler = logging.StreamHandler(sys.stdout)
# Call handler.setLevel()
handler.setLevel(level) handler.setLevel(level)
# Check: is_production
if is_production: if is_production:
# Call handler.setFormatter()
handler.setFormatter(_JSONFormatter()) handler.setFormatter(_JSONFormatter())
# Fallback: handle remaining cases
else: else:
# Call handler.setFormatter()
handler.setFormatter(logging.Formatter(_DEV_FORMAT)) handler.setFormatter(logging.Formatter(_DEV_FORMAT))
# Call root.addHandler()
root.addHandler(handler) root.addHandler(handler)
# Call logging.getLogger()
logging.getLogger("uvicorn.access").setLevel(logging.WARNING) logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
# Call logging.getLogger()
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING) logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
+267 -50
View File
@@ -1,62 +1,171 @@
"""FastAPI application factory and global middleware/exception configuration.
Builds the ``app`` instance, wires up CORS, rate limiting, domain-error
mapping, all API routers, and async lifespan hooks (MinIO bucket creation,
APScheduler startup/shutdown).
"""
# Import logging
import logging import logging
# Import os
import os import os
# Import AsyncGenerator from collections.abc
from collections.abc import AsyncGenerator
# Import asynccontextmanager from contextlib
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
# Import FastAPI, Request, status from fastapi
from fastapi import FastAPI, Request, status from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse # Import RequestValidationError from fastapi.exceptions
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
# Import CORSMiddleware from fastapi.middleware.cors
from fastapi.middleware.cors import CORSMiddleware
# Import JSONResponse from fastapi.responses
from fastapi.responses import JSONResponse
# Import _rate_limit_exceeded_handler from slowapi
from slowapi import _rate_limit_exceeded_handler from slowapi import _rate_limit_exceeded_handler
# Import RateLimitExceeded from slowapi.errors
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
# Import SQLAlchemyError from sqlalchemy.exc
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from app.routers import auth as auth_router # Import settings as _settings from app.config
from app.routers import techniques as techniques_router from app.config import settings as _settings
from app.routers import tests as tests_router
from app.routers import evidence as evidence_router # Import DomainError from app.domain.errors
from app.routers import test_templates as test_templates_router
from app.routers import system as system_router
from app.routers import metrics as metrics_router
from app.routers import users as users_router
from app.routers import audit as audit_router
from app.routers import notifications as notifications_router
from app.routers import reports as reports_router
from app.routers import data_sources as data_sources_router
from app.routers import threat_actors as threat_actors_router
from app.routers import d3fend as d3fend_router
from app.routers import detection_rules as detection_rules_router
from app.routers import campaigns as campaigns_router
from app.routers import heatmap as heatmap_router
from app.routers import scores as scores_router
from app.routers import operational_metrics as operational_metrics_router
from app.routers import compliance as compliance_router
from app.routers import snapshots as snapshots_router
from app.routers import jira as jira_router
from app.routers import worklogs as worklogs_router
from app.routers import professional_reports as professional_reports_router
from app.routers import analytics as analytics_router
from app.routers import advanced_metrics as advanced_metrics_router
from app.routers import osint as osint_router
from app.domain.errors import DomainError from app.domain.errors import DomainError
from app.middleware.error_handler import domain_exception_handler
from app.middleware.request_context import RequestContextMiddleware # Import scheduler, start_scheduler from app.jobs.mitre_sync_job
from app.jobs.mitre_sync_job import scheduler, start_scheduler
# Import limiter from app.limiter
from app.limiter import limiter from app.limiter import limiter
# Import setup_logging from app.logging_config
from app.logging_config import setup_logging
# Import domain_exception_handler from app.middleware.error_handler
from app.middleware.error_handler import domain_exception_handler
# Import RequestContextMiddleware from app.middleware.request_context
from app.middleware.request_context import RequestContextMiddleware
# Import advanced_metrics as advanced_metrics_router from app.routers
from app.routers import advanced_metrics as advanced_metrics_router
# Import analytics as analytics_router from app.routers
from app.routers import analytics as analytics_router
# Import audit as audit_router from app.routers
from app.routers import audit as audit_router
# Import auth as auth_router from app.routers
from app.routers import auth as auth_router
# Import campaigns as campaigns_router from app.routers
from app.routers import campaigns as campaigns_router
# Import compliance as compliance_router from app.routers
from app.routers import compliance as compliance_router
# Import d3fend as d3fend_router from app.routers
from app.routers import d3fend as d3fend_router
# Import data_sources as data_sources_router from app.routers
from app.routers import data_sources as data_sources_router
# Import detection_rules as detection_rules_router from app.routers
from app.routers import detection_rules as detection_rules_router
# Import evidence as evidence_router from app.routers
from app.routers import evidence as evidence_router
# Import heatmap as heatmap_router from app.routers
from app.routers import heatmap as heatmap_router
# Import jira as jira_router from app.routers
from app.routers import jira as jira_router
# Import metrics as metrics_router from app.routers
from app.routers import metrics as metrics_router
# Import notifications as notifications_router from app.routers
from app.routers import notifications as notifications_router
# Import operational_metrics as operational_metrics_router from app.routers
from app.routers import operational_metrics as operational_metrics_router
# Import osint as osint_router from app.routers
from app.routers import osint as osint_router
# Import professional_reports as professional_reports_ro... from app.routers
from app.routers import professional_reports as professional_reports_router
# Import reports as reports_router from app.routers
from app.routers import reports as reports_router
# Import scores as scores_router from app.routers
from app.routers import scores as scores_router
# Import snapshots as snapshots_router from app.routers
from app.routers import snapshots as snapshots_router
# Import system as system_router from app.routers
from app.routers import system as system_router
# Import techniques as techniques_router from app.routers
from app.routers import techniques as techniques_router
# Import test_templates as test_templates_router from app.routers
from app.routers import test_templates as test_templates_router
# Import tests as tests_router from app.routers
from app.routers import tests as tests_router
# Import threat_actors as threat_actors_router from app.routers
from app.routers import threat_actors as threat_actors_router
# Import users as users_router from app.routers
from app.routers import users as users_router
# Import worklogs as worklogs_router from app.routers
from app.routers import worklogs as worklogs_router
# Import ensure_bucket_exists from app.storage
from app.storage import ensure_bucket_exists from app.storage import ensure_bucket_exists
from app.jobs.mitre_sync_job import start_scheduler, scheduler
# Configure structured logging before any module initialises its own logger
setup_logging()
# ── Environment detection ───────────────────────────────────────────────── # ── Environment detection ─────────────────────────────────────────────────
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production" _IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
# ── Logging ─────────────────────────────────────────────────────────────── # Apply the @asynccontextmanager decorator
from app.logging_config import setup_logging
setup_logging()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): # Define async function lifespan
"""Startup / shutdown logic.""" async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Manage application startup and shutdown lifecycle.
Args:
app (FastAPI): The FastAPI application instance.
Yields:
None: Control is yielded to the running application.
"""
# Call ensure_bucket_exists()
ensure_bucket_exists() ensure_bucket_exists()
# Call start_scheduler()
start_scheduler() start_scheduler()
# Yield value
yield yield
# Graceful shutdown of the background scheduler # Graceful shutdown of the background scheduler
scheduler.shutdown(wait=False) scheduler.shutdown(wait=False)
@@ -64,74 +173,116 @@ async def lifespan(app: FastAPI):
# ── In production, disable Swagger UI and ReDoc to hide API surface ────── # ── In production, disable Swagger UI and ReDoc to hide API surface ──────
app = FastAPI( app = FastAPI(
# Keyword argument: title
title="Attack Coverage Platform", title="Attack Coverage Platform",
# Keyword argument: lifespan
lifespan=lifespan, lifespan=lifespan,
# Keyword argument: docs_url
docs_url=None if _IS_PRODUCTION else "/docs", docs_url=None if _IS_PRODUCTION else "/docs",
# Keyword argument: redoc_url
redoc_url=None if _IS_PRODUCTION else "/redoc", redoc_url=None if _IS_PRODUCTION else "/redoc",
# Keyword argument: openapi_url
openapi_url=None if _IS_PRODUCTION else "/openapi.json", openapi_url=None if _IS_PRODUCTION else "/openapi.json",
) )
# ── Rate Limiter ────────────────────────────────────────────────────────── # ── Rate Limiter ──────────────────────────────────────────────────────────
app.state.limiter = limiter app.state.limiter = limiter
# Call app.add_exception_handler()
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Call app.add_middleware()
app.add_middleware(RequestContextMiddleware) app.add_middleware(RequestContextMiddleware)
# ── Domain exception → HTTP mapping ────────────────────────────────────── # ── Domain exception → HTTP mapping ──────────────────────────────────────
app.add_exception_handler(DomainError, domain_exception_handler) app.add_exception_handler(DomainError, domain_exception_handler)
# ── CORS ────────────────────────────────────────────────────────────────── # ── CORS ──────────────────────────────────────────────────────────────────
from app.config import settings as _settings
_cors_origins: list[str] = [ _cors_origins: list[str] = [
o.strip() for o in _settings.CORS_ORIGINS.split(",") if o.strip() o.strip() for o in _settings.CORS_ORIGINS.split(",") if o.strip()
] ]
# Call app.add_middleware()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
# Keyword argument: allow_origins
allow_origins=_cors_origins, allow_origins=_cors_origins,
# Keyword argument: allow_credentials
allow_credentials=True, allow_credentials=True,
# Keyword argument: allow_methods
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
# Keyword argument: allow_headers
allow_headers=["Authorization", "Content-Type"], allow_headers=["Authorization", "Content-Type"],
) )
# ── Routers ────────────────────────────────────────────────────────────── # ── Routers ──────────────────────────────────────────────────────────────
app.include_router(auth_router.router, prefix="/api/v1") app.include_router(auth_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(techniques_router.router, prefix="/api/v1") app.include_router(techniques_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(tests_router.router, prefix="/api/v1") app.include_router(tests_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(evidence_router.router, prefix="/api/v1") app.include_router(evidence_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(test_templates_router.router, prefix="/api/v1") app.include_router(test_templates_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(system_router.router, prefix="/api/v1") app.include_router(system_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(metrics_router.router, prefix="/api/v1") app.include_router(metrics_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(users_router.router, prefix="/api/v1") app.include_router(users_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(audit_router.router, prefix="/api/v1") app.include_router(audit_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(notifications_router.router, prefix="/api/v1") app.include_router(notifications_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(reports_router.router, prefix="/api/v1") app.include_router(reports_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(data_sources_router.router, prefix="/api/v1") app.include_router(data_sources_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(threat_actors_router.router, prefix="/api/v1") app.include_router(threat_actors_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(d3fend_router.router, prefix="/api/v1") app.include_router(d3fend_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(detection_rules_router.router, prefix="/api/v1") app.include_router(detection_rules_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(campaigns_router.router, prefix="/api/v1") app.include_router(campaigns_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(heatmap_router.router, prefix="/api/v1") app.include_router(heatmap_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(scores_router.router, prefix="/api/v1") app.include_router(scores_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(operational_metrics_router.router, prefix="/api/v1") app.include_router(operational_metrics_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(compliance_router.router, prefix="/api/v1") app.include_router(compliance_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(snapshots_router.router, prefix="/api/v1") app.include_router(snapshots_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(jira_router.router, prefix="/api/v1") app.include_router(jira_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(worklogs_router.router, prefix="/api/v1") app.include_router(worklogs_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(professional_reports_router.router, prefix="/api/v1") app.include_router(professional_reports_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(analytics_router.router, prefix="/api/v1") app.include_router(analytics_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(advanced_metrics_router.router, prefix="/api/v1") app.include_router(advanced_metrics_router.router, prefix="/api/v1")
# Call app.include_router()
app.include_router(osint_router.router, prefix="/api/v1") app.include_router(osint_router.router, prefix="/api/v1")
# Apply the @app.get decorator
@app.get("/health", include_in_schema=False) @app.get("/health", include_in_schema=False)
def health(): # Define function health
"""Minimal health check — returns only an HTTP 200 with no service metadata. def health() -> dict[str, str]:
"""Return a minimal liveness probe response.
Access is restricted to internal networks at the Nginx level Access is restricted to internal networks at the Nginx level
(see ``frontend/nginx.conf``). (see ``frontend/nginx.conf``).
Returns:
dict[str, str]: A dict with ``{"status": "ok"}``.
""" """
# Return {"status": "ok"}
return {"status": "ok"} return {"status": "ok"}
@@ -139,51 +290,117 @@ def health():
def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]: def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
"""Return validation errors safe for JSON (no raw exception objects).""" """Return validation errors safe for JSON serialization.
Converts non-serializable values inside ``ctx`` dictionaries to strings
so the response body can be safely encoded.
Args:
exc (RequestValidationError): The Pydantic validation exception.
Returns:
list[dict]: A list of sanitised error detail dictionaries.
"""
# Assign serialized = []
serialized: list[dict] = [] serialized: list[dict] = []
# Iterate over exc.errors()
for err in exc.errors(): for err in exc.errors():
# Assign item = dict(err)
item = dict(err) item = dict(err)
# Assign ctx = item.get("ctx")
ctx = item.get("ctx") ctx = item.get("ctx")
# Check: isinstance(ctx, dict)
if isinstance(ctx, dict): if isinstance(ctx, dict):
# Assign item["ctx"] = {key: str(value) for key, value in ctx.items()}
item["ctx"] = {key: str(value) for key, value in ctx.items()} item["ctx"] = {key: str(value) for key, value in ctx.items()}
# Call serialized.append()
serialized.append(item) serialized.append(item)
# Return serialized
return serialized return serialized
# Apply the @app.exception_handler decorator
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError): # Define async function validation_exception_handler
"""Handle validation errors with consistent format.""" async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
"""Handle Pydantic validation errors and return a structured 422 response.
Args:
request (Request): The incoming HTTP request.
exc (RequestValidationError): The caught validation exception.
Returns:
JSONResponse: A 422 response with a ``VALIDATION_ERROR`` code and error details.
"""
# Return JSONResponse(
return JSONResponse( return JSONResponse(
# Keyword argument: status_code
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
# Keyword argument: content
content={ content={
# Literal argument value
"detail": "Validation error", "detail": "Validation error",
# Literal argument value
"code": "VALIDATION_ERROR", "code": "VALIDATION_ERROR",
# Literal argument value
"errors": _serialize_validation_errors(exc), "errors": _serialize_validation_errors(exc),
}, },
) )
# Apply the @app.exception_handler decorator
@app.exception_handler(SQLAlchemyError) @app.exception_handler(SQLAlchemyError)
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError): # Define async function sqlalchemy_exception_handler
"""Handle database errors.""" async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError) -> JSONResponse:
"""Handle SQLAlchemy database errors and return a structured 500 response.
Args:
request (Request): The incoming HTTP request.
exc (SQLAlchemyError): The caught SQLAlchemy exception.
Returns:
JSONResponse: A 500 response with a ``DATABASE_ERROR`` code.
"""
# Log error: f"Database error: {exc}"
logging.error(f"Database error: {exc}") logging.error(f"Database error: {exc}")
# Return JSONResponse(
return JSONResponse( return JSONResponse(
# Keyword argument: status_code
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
# Keyword argument: content
content={ content={
# Literal argument value
"detail": "Database error occurred", "detail": "Database error occurred",
# Literal argument value
"code": "DATABASE_ERROR", "code": "DATABASE_ERROR",
}, },
) )
# Apply the @app.exception_handler decorator
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception): # Define async function general_exception_handler
"""Handle all unhandled exceptions.""" async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Handle all otherwise-unhandled exceptions and return a structured 500 response.
Args:
request (Request): The incoming HTTP request.
exc (Exception): The unhandled exception.
Returns:
JSONResponse: A 500 response with an ``INTERNAL_ERROR`` code.
"""
# Log error: f"Unhandled exception: {exc}"
logging.error(f"Unhandled exception: {exc}") logging.error(f"Unhandled exception: {exc}")
# Return JSONResponse(
return JSONResponse( return JSONResponse(
# Keyword argument: status_code
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
# Keyword argument: content
content={ content={
# Literal argument value
"detail": "An internal server error occurred", "detail": "An internal server error occurred",
# Literal argument value
"code": "INTERNAL_ERROR", "code": "INTERNAL_ERROR",
}, },
) )
+1
View File
@@ -0,0 +1 @@
"""ASGI middleware components for request context, error handling, and rate limiting."""
+21
View File
@@ -5,9 +5,13 @@ domain-layer errors into structured JSON responses, keeping
the service layer free from FastAPI's ``HTTPException``. the service layer free from FastAPI's ``HTTPException``.
""" """
# Import Request from fastapi
from fastapi import Request from fastapi import Request
# Import JSONResponse from fastapi.responses
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
# Import from app.domain.errors
from app.domain.errors import ( from app.domain.errors import (
BusinessRuleViolation, BusinessRuleViolation,
DomainError, DomainError,
@@ -18,28 +22,45 @@ from app.domain.errors import (
PermissionViolation, PermissionViolation,
) )
# Assign EXCEPTION_STATUS_MAP = {
EXCEPTION_STATUS_MAP: dict[type[DomainError], int] = { EXCEPTION_STATUS_MAP: dict[type[DomainError], int] = {
# Entry: EntityNotFoundError
EntityNotFoundError: 404, EntityNotFoundError: 404,
# Entry: DuplicateEntityError
DuplicateEntityError: 409, DuplicateEntityError: 409,
# Entry: InvalidStateTransition
InvalidStateTransition: 400, InvalidStateTransition: 400,
# Entry: InvalidOperationError
InvalidOperationError: 400, InvalidOperationError: 400,
# Entry: BusinessRuleViolation
BusinessRuleViolation: 400, BusinessRuleViolation: 400,
# Entry: PermissionViolation
PermissionViolation: 403, PermissionViolation: 403,
} }
# Define async function domain_exception_handler
async def domain_exception_handler( async def domain_exception_handler(
# Entry: request
request: Request, request: Request,
# Entry: exc
exc: DomainError, exc: DomainError,
) -> JSONResponse: ) -> JSONResponse:
"""Convert a :class:`DomainError` into a JSON error response.""" """Convert a :class:`DomainError` into a JSON error response."""
# Assign status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400)
status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400) status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400)
# Assign content = {"detail": exc.message, "code": exc.code}
content: dict = {"detail": exc.message, "code": exc.code} content: dict = {"detail": exc.message, "code": exc.code}
# Check: isinstance(exc, InvalidStateTransition)
if isinstance(exc, InvalidStateTransition): if isinstance(exc, InvalidStateTransition):
# Assign content["current_state"] = exc.current_state
content["current_state"] = exc.current_state content["current_state"] = exc.current_state
# Assign content["target_state"] = exc.target_state
content["target_state"] = exc.target_state content["target_state"] = exc.target_state
# Assign content["valid_transitions"] = exc.valid_transitions
content["valid_transitions"] = exc.valid_transitions content["valid_transitions"] = exc.valid_transitions
# Return JSONResponse(status_code=status_code, content=content)
return JSONResponse(status_code=status_code, content=content) return JSONResponse(status_code=status_code, content=content)
+50 -2
View File
@@ -1,26 +1,74 @@
"""Request context middleware — captures client IP and User-Agent per request.""" """Request context middleware — captures client IP and User-Agent per request."""
# Import Awaitable, Callable from collections.abc
from collections.abc import Awaitable, Callable
# Import ContextVar from contextvars
from contextvars import ContextVar from contextvars import ContextVar
# Import Request from fastapi
from fastapi import Request from fastapi import Request
# Import BaseHTTPMiddleware from starlette.middleware.base
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
# Import Response from starlette.responses
from starlette.responses import Response
# Assign request_ip = ContextVar("request_ip", default="")
request_ip: ContextVar[str] = ContextVar("request_ip", default="") request_ip: ContextVar[str] = ContextVar("request_ip", default="")
# Assign request_user_agent = ContextVar("request_user_agent", default="")
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="") request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
# Define function resolve_client_ip
def resolve_client_ip(request: Request) -> str: def resolve_client_ip(request: Request) -> str:
"""Extract the client IP, honouring ``X-Forwarded-For`` when present.""" """Extract the real client IP, honouring ``X-Forwarded-For`` when present.
Args:
request (Request): The incoming Starlette/FastAPI request.
Returns:
str: The resolved client IP address, or ``"unknown"`` when unavailable.
"""
# Assign forwarded = request.headers.get("X-Forwarded-For")
forwarded = request.headers.get("X-Forwarded-For") forwarded = request.headers.get("X-Forwarded-For")
# Check: forwarded
if forwarded: if forwarded:
# Return forwarded.split(",")[0].strip()
return forwarded.split(",")[0].strip() return forwarded.split(",")[0].strip()
# Check: request.client
if request.client: if request.client:
# Return request.client.host
return request.client.host return request.client.host
# Return "unknown"
return "unknown" return "unknown"
# Define class RequestContextMiddleware
class RequestContextMiddleware(BaseHTTPMiddleware): class RequestContextMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): """Middleware that captures client IP and User-Agent into context variables."""
# Define async function dispatch
async def dispatch(
self,
# Entry: request
request: Request,
# Entry: call_next
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
"""Store client IP and User-Agent in context vars for the current request.
Args:
request (Request): The incoming HTTP request.
call_next (Callable[[Request], Awaitable[Response]]): The next middleware or route handler.
Returns:
Response: The HTTP response produced by the downstream handler.
"""
# Call request_ip.set()
request_ip.set(resolve_client_ip(request)) request_ip.set(resolve_client_ip(request))
# Call request_user_agent.set()
request_user_agent.set(request.headers.get("User-Agent", "")) request_user_agent.set(request.headers.get("User-Agent", ""))
# Return await call_next(request)
return await call_next(request) return await call_next(request)
+80 -21
View File
@@ -1,37 +1,96 @@
"""SQLAlchemy ORM model definitions for all database tables."""
# Import all models here so Alembic can detect them # Import all models here so Alembic can detect them
from app.models.user import User
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.evidence import Evidence
from app.models.intel import IntelItem
from app.models.audit import AuditLog from app.models.audit import AuditLog
from app.models.notification import Notification
from app.models.data_source import DataSource
from app.models.detection_rule import DetectionRule
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.campaign import Campaign, CampaignTest
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
from app.models.worklog import Worklog
from app.models.osint_item import OsintItem
from app.models.scoring_config import ScoringConfig
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
# Import Campaign, CampaignTest from app.models.campaign
from app.models.campaign import Campaign, CampaignTest
# Import from app.models.compliance
from app.models.compliance import (
ComplianceControl,
ComplianceControlMapping,
ComplianceFramework,
)
# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
# Import DataSource from app.models.data_source
from app.models.data_source import DataSource
# Import DefensiveTechnique, DefensiveTechniqueMapping from app.models.defensive_technique
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
# Import DetectionRule from app.models.detection_rule
from app.models.detection_rule import DetectionRule
# Import TeamSide, TechniqueStatus, TestResult, TestState from app.models.enums
from app.models.enums import TeamSide, TechniqueStatus, TestResult, TestState
# Import Evidence from app.models.evidence
from app.models.evidence import Evidence
# Import IntelItem from app.models.intel
from app.models.intel import IntelItem
# Import JiraLink, JiraLinkEntityType, JiraSyncDirection from app.models.jira_link
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
# Import Notification from app.models.notification
from app.models.notification import Notification
# Import OsintItem from app.models.osint_item
from app.models.osint_item import OsintItem
# Import ScoringConfig from app.models.scoring_config
from app.models.scoring_config import ScoringConfig
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestDetectionResult from app.models.test_detection_result
from app.models.test_detection_result import TestDetectionResult
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import TestTemplateDetectionRule from app.models.test_template_detection_rule
from app.models.test_template_detection_rule import TestTemplateDetectionRule
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
# Import User from app.models.user
from app.models.user import User
# Import Worklog from app.models.worklog
from app.models.worklog import Worklog
# Assign __all__ = [
__all__ = [ __all__ = [
# Literal argument value
"User", "Technique", "Test", "TestTemplate", "Evidence", "User", "Technique", "Test", "TestTemplate", "Evidence",
# Literal argument value
"IntelItem", "AuditLog", "Notification", "DataSource", "IntelItem", "AuditLog", "Notification", "DataSource",
# Literal argument value
"DetectionRule", "ThreatActor", "ThreatActorTechnique", "DetectionRule", "ThreatActor", "ThreatActorTechnique",
# Literal argument value
"DefensiveTechnique", "DefensiveTechniqueMapping", "DefensiveTechnique", "DefensiveTechniqueMapping",
# Literal argument value
"TestTemplateDetectionRule", "TestDetectionResult", "TestTemplateDetectionRule", "TestDetectionResult",
# Literal argument value
"Campaign", "CampaignTest", "Campaign", "CampaignTest",
# Literal argument value
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping", "ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
# Literal argument value
"CoverageSnapshot", "SnapshotTechniqueState", "CoverageSnapshot", "SnapshotTechniqueState",
# Literal argument value
"JiraLink", "JiraLinkEntityType", "JiraSyncDirection", "JiraLink", "JiraLinkEntityType", "JiraSyncDirection",
# Literal argument value
"Worklog", "OsintItem", "ScoringConfig", "Worklog", "OsintItem", "ScoringConfig",
# Literal argument value
"TechniqueStatus", "TestState", "TestResult", "TeamSide", "TechniqueStatus", "TestState", "TestResult", "TeamSide",
] ]
+28 -5
View File
@@ -1,35 +1,58 @@
"""SQLAlchemy model for the audit log table."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, DateTime, ForeignKey, Index, func
from sqlalchemy.dialects.postgresql import UUID, JSONB # Import Column, DateTime, ForeignKey, Index, String, func from sqlalchemy
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, func
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class AuditLog
class AuditLog(Base): class AuditLog(Base):
""" """Audit log model for tracking all system actions.
Audit log model for tracking all system actions.
Records user actions, entity changes, and system events Records user actions, entity changes, and system events
for security auditing and compliance purposes. for security auditing and compliance purposes.
""" """
# Assign __tablename__ = "audit_logs"
__tablename__ = "audit_logs" __tablename__ = "audit_logs"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
# Assign action = Column(String, nullable=False)
action = Column(String, nullable=False) action = Column(String, nullable=False)
# Assign entity_type = Column(String, nullable=True)
entity_type = Column(String, nullable=True) entity_type = Column(String, nullable=True)
# Assign entity_id = Column(String, nullable=True)
entity_id = Column(String, nullable=True) entity_id = Column(String, nullable=True)
# Assign timestamp = Column(DateTime(timezone=True), server_default=func.now())
timestamp = Column(DateTime(timezone=True), server_default=func.now()) timestamp = Column(DateTime(timezone=True), server_default=func.now())
# Assign details = Column(JSONB, nullable=True)
details = Column(JSONB, nullable=True) details = Column(JSONB, nullable=True)
# Assign ip_address = Column(String(45), nullable=True)
ip_address = Column(String(45), nullable=True) ip_address = Column(String(45), nullable=True)
# Assign user_agent = Column(String(500), nullable=True)
user_agent = Column(String(500), nullable=True) user_agent = Column(String(500), nullable=True)
# Assign integrity_hash = Column(String(64), nullable=True)
integrity_hash = Column(String(64), nullable=True) integrity_hash = Column(String(64), nullable=True)
# Assign session_id = Column(String(100), nullable=True)
session_id = Column(String(100), nullable=True) session_id = Column(String(100), nullable=True)
# Relationships # Relationships
user = relationship("User") user = relationship("User")
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index("ix_audit_logs_entity", "entity_type", "entity_id"), Index("ix_audit_logs_entity", "entity_type", "entity_id"),
Index("ix_audit_logs_timestamp", "timestamp"), Index("ix_audit_logs_timestamp", "timestamp"),
+85 -7
View File
@@ -4,20 +4,35 @@ Campaigns group multiple tests into a kill chain sequence,
enabling simulation of complete attack chains and APT emulations. enabling simulation of complete attack chains and APT emulations.
""" """
# Import uuid
import uuid import uuid
# Import from sqlalchemy
from sqlalchemy import ( from sqlalchemy import (
Column, String, Text, Integer, Boolean, DateTime, Boolean,
ForeignKey, Index, func, Column,
DateTime,
ForeignKey,
Index,
Integer,
String,
Text,
func,
) )
from sqlalchemy.dialects.postgresql import UUID, JSONB
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class Campaign
class Campaign(Base): class Campaign(Base):
""" """A campaign groups multiple tests into a sequenced attack chain.
A campaign groups multiple tests into a sequenced attack chain.
Types: Types:
- custom: manually created campaign - custom: manually created campaign
@@ -31,61 +46,97 @@ class Campaign(Base):
- completed: all tests done - completed: all tests done
- archived: historical record - archived: historical record
""" """
# Assign __tablename__ = "campaigns"
__tablename__ = "campaigns" __tablename__ = "campaigns"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign name = Column(String, nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign type = Column(String, nullable=False, default="custom") # custom, ap...
type = Column(String, nullable=False, default="custom") # custom, apt_emulation, kill_chain, compliance type = Column(String, nullable=False, default="custom") # custom, apt_emulation, kill_chain, compliance
# Assign threat_actor_id = Column(
threat_actor_id = Column( threat_actor_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("threat_actors.id", ondelete="SET NULL"), ForeignKey("threat_actors.id", ondelete="SET NULL"),
# Keyword argument: nullable
nullable=True, nullable=True,
) )
# Assign status = Column(String, nullable=False, default="draft") # draft, activ...
status = Column(String, nullable=False, default="draft") # draft, active, completed, archived status = Column(String, nullable=False, default="draft") # draft, active, completed, archived
# Assign created_by = Column(
created_by = Column( created_by = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"), ForeignKey("users.id", ondelete="SET NULL"),
# Keyword argument: nullable
nullable=True, nullable=True,
) )
# Assign scheduled_at = Column(DateTime, nullable=True)
scheduled_at = Column(DateTime, nullable=True) scheduled_at = Column(DateTime, nullable=True)
# Assign completed_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True) completed_at = Column(DateTime, nullable=True)
# Assign target_platform = Column(String, nullable=True)
target_platform = Column(String, nullable=True) target_platform = Column(String, nullable=True)
# Assign tags = Column(JSONB, nullable=True, default=[])
tags = Column(JSONB, nullable=True, default=[]) tags = Column(JSONB, nullable=True, default=[])
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign data_classification = Column(String(20), nullable=False, server_default="internal")
data_classification = Column(String(20), nullable=False, server_default="internal") data_classification = Column(String(20), nullable=False, server_default="internal")
# Recurring scheduling fields # Recurring scheduling fields
is_recurring = Column(Boolean, default=False) is_recurring = Column(Boolean, default=False)
# Assign recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly
recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly
# Assign next_run_at = Column(DateTime, nullable=True)
next_run_at = Column(DateTime, nullable=True) next_run_at = Column(DateTime, nullable=True)
# Assign last_run_at = Column(DateTime, nullable=True)
last_run_at = Column(DateTime, nullable=True) last_run_at = Column(DateTime, nullable=True)
# Assign parent_campaign_id = Column(
parent_campaign_id = Column( parent_campaign_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("campaigns.id", ondelete="SET NULL"), ForeignKey("campaigns.id", ondelete="SET NULL"),
# Keyword argument: nullable
nullable=True, nullable=True,
) )
# Relationships # Relationships
threat_actor = relationship("ThreatActor") threat_actor = relationship("ThreatActor")
# Assign creator = relationship("User", foreign_keys=[created_by])
creator = relationship("User", foreign_keys=[created_by]) creator = relationship("User", foreign_keys=[created_by])
# Assign campaign_tests = relationship(
campaign_tests = relationship( campaign_tests = relationship(
# Literal argument value
"CampaignTest", "CampaignTest",
# Keyword argument: back_populates
back_populates="campaign", back_populates="campaign",
# Keyword argument: cascade
cascade="all, delete-orphan", cascade="all, delete-orphan",
# Keyword argument: order_by
order_by="CampaignTest.order_index", order_by="CampaignTest.order_index",
) )
# Assign parent_campaign = relationship(
parent_campaign = relationship( parent_campaign = relationship(
# Literal argument value
"Campaign", "Campaign",
# Keyword argument: remote_side
remote_side="Campaign.id", remote_side="Campaign.id",
# Keyword argument: foreign_keys
foreign_keys=[parent_campaign_id], foreign_keys=[parent_campaign_id],
) )
# Assign child_campaigns = relationship(
child_campaigns = relationship( child_campaigns = relationship(
# Literal argument value
"Campaign", "Campaign",
# Keyword argument: foreign_keys
foreign_keys=[parent_campaign_id], foreign_keys=[parent_campaign_id],
# Keyword argument: back_populates
back_populates="parent_campaign", back_populates="parent_campaign",
) )
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_campaigns_status', 'status'), Index('ix_campaigns_status', 'status'),
Index('ix_campaigns_type', 'type'), Index('ix_campaigns_type', 'type'),
@@ -97,56 +148,83 @@ class Campaign(Base):
# Kill chain phases in order (for sorting and validation) # Kill chain phases in order (for sorting and validation)
KILL_CHAIN_PHASES = [ KILL_CHAIN_PHASES = [
# Literal argument value
"reconnaissance", "reconnaissance",
# Literal argument value
"resource_development", "resource_development",
# Literal argument value
"initial_access", "initial_access",
# Literal argument value
"execution", "execution",
# Literal argument value
"persistence", "persistence",
# Literal argument value
"privilege_escalation", "privilege_escalation",
# Literal argument value
"defense_evasion", "defense_evasion",
# Literal argument value
"credential_access", "credential_access",
# Literal argument value
"discovery", "discovery",
# Literal argument value
"lateral_movement", "lateral_movement",
# Literal argument value
"collection", "collection",
# Literal argument value
"command_and_control", "command_and_control",
# Literal argument value
"exfiltration", "exfiltration",
# Literal argument value
"impact", "impact",
] ]
# Define class CampaignTest
class CampaignTest(Base): class CampaignTest(Base):
""" """A test within a campaign, with ordering and dependency information.
A test within a campaign, with ordering and dependency information.
``depends_on`` creates a self-referential chain (A -> B -> C). ``depends_on`` creates a self-referential chain (A -> B -> C).
Circular dependencies are validated at the service layer. Circular dependencies are validated at the service layer.
""" """
# Assign __tablename__ = "campaign_tests"
__tablename__ = "campaign_tests" __tablename__ = "campaign_tests"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign campaign_id = Column(
campaign_id = Column( campaign_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("campaigns.id", ondelete="CASCADE"), ForeignKey("campaigns.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign test_id = Column(
test_id = Column( test_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("tests.id", ondelete="CASCADE"), ForeignKey("tests.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign order_index = Column(Integer, nullable=False, default=0)
order_index = Column(Integer, nullable=False, default=0) order_index = Column(Integer, nullable=False, default=0)
# Assign depends_on = Column(
depends_on = Column( depends_on = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("campaign_tests.id", ondelete="SET NULL"), ForeignKey("campaign_tests.id", ondelete="SET NULL"),
# Keyword argument: nullable
nullable=True, nullable=True,
) )
# Assign phase = Column(String, nullable=True) # kill chain phase
phase = Column(String, nullable=True) # kill chain phase phase = Column(String, nullable=True) # kill chain phase
# Relationships # Relationships
campaign = relationship("Campaign", back_populates="campaign_tests") campaign = relationship("Campaign", back_populates="campaign_tests")
# Assign test = relationship("Test")
test = relationship("Test") test = relationship("Test")
# Assign dependency = relationship("CampaignTest", remote_side="CampaignTest.id")
dependency = relationship("CampaignTest", remote_side="CampaignTest.id") dependency = relationship("CampaignTest", remote_side="CampaignTest.id")
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_campaign_tests_campaign', 'campaign_id'), Index('ix_campaign_tests_campaign', 'campaign_id'),
Index('ix_campaign_tests_test', 'test_id'), Index('ix_campaign_tests_test', 'test_id'),
+55 -2
View File
@@ -4,92 +4,145 @@ Maps compliance frameworks (NIST 800-53, DORA, NIS2, ISO 27001) to
MITRE ATT&CK techniques, enabling compliance gap analysis. MITRE ATT&CK techniques, enabling compliance gap analysis.
""" """
# Import uuid
import uuid import uuid
# Import from sqlalchemy
from sqlalchemy import ( from sqlalchemy import (
Column, String, Text, Boolean, DateTime, Boolean,
ForeignKey, Index, UniqueConstraint, func, Column,
DateTime,
ForeignKey,
Index,
String,
Text,
UniqueConstraint,
func,
) )
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class ComplianceFramework
class ComplianceFramework(Base): class ComplianceFramework(Base):
"""A compliance framework (e.g. NIST 800-53, ISO 27001).""" """A compliance framework (e.g. NIST 800-53, ISO 27001)."""
# Assign __tablename__ = "compliance_frameworks"
__tablename__ = "compliance_frameworks" __tablename__ = "compliance_frameworks"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign name = Column(String, unique=True, nullable=False)
name = Column(String, unique=True, nullable=False) name = Column(String, unique=True, nullable=False)
# Assign version = Column(String, nullable=True)
version = Column(String, nullable=True) version = Column(String, nullable=True)
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign url = Column(String, nullable=True)
url = Column(String, nullable=True) url = Column(String, nullable=True)
# Assign is_active = Column(Boolean, default=True)
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships # Relationships
controls = relationship( controls = relationship(
# Literal argument value
"ComplianceControl", "ComplianceControl",
# Keyword argument: back_populates
back_populates="framework", back_populates="framework",
# Keyword argument: cascade
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
# Define class ComplianceControl
class ComplianceControl(Base): class ComplianceControl(Base):
"""A control within a compliance framework (e.g. AC-2, PR.AC-1).""" """A control within a compliance framework (e.g. AC-2, PR.AC-1)."""
# Assign __tablename__ = "compliance_controls"
__tablename__ = "compliance_controls" __tablename__ = "compliance_controls"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign framework_id = Column(
framework_id = Column( framework_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("compliance_frameworks.id", ondelete="CASCADE"), ForeignKey("compliance_frameworks.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign control_id = Column(String, nullable=False) # e.g. "AC-2"
control_id = Column(String, nullable=False) # e.g. "AC-2" control_id = Column(String, nullable=False) # e.g. "AC-2"
# Assign title = Column(String, nullable=False)
title = Column(String, nullable=False) title = Column(String, nullable=False)
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign category = Column(String, nullable=True)
category = Column(String, nullable=True) category = Column(String, nullable=True)
# Relationships # Relationships
framework = relationship("ComplianceFramework", back_populates="controls") framework = relationship("ComplianceFramework", back_populates="controls")
# Assign technique_mappings = relationship(
technique_mappings = relationship( technique_mappings = relationship(
# Literal argument value
"ComplianceControlMapping", "ComplianceControlMapping",
# Keyword argument: back_populates
back_populates="compliance_control", back_populates="compliance_control",
# Keyword argument: cascade
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_compliance_controls_framework', 'framework_id'), Index('ix_compliance_controls_framework', 'framework_id'),
) )
# Define class ComplianceControlMapping
class ComplianceControlMapping(Base): class ComplianceControlMapping(Base):
"""Maps a compliance control to a MITRE ATT&CK technique.""" """Maps a compliance control to a MITRE ATT&CK technique."""
# Assign __tablename__ = "compliance_control_mappings"
__tablename__ = "compliance_control_mappings" __tablename__ = "compliance_control_mappings"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign compliance_control_id = Column(
compliance_control_id = Column( compliance_control_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("compliance_controls.id", ondelete="CASCADE"), ForeignKey("compliance_controls.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign technique_id = Column(
technique_id = Column( technique_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("techniques.id", ondelete="CASCADE"), ForeignKey("techniques.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Relationships # Relationships
compliance_control = relationship( compliance_control = relationship(
# Literal argument value
"ComplianceControl", back_populates="technique_mappings" "ComplianceControl", back_populates="technique_mappings"
) )
# Assign technique = relationship("Technique")
technique = relationship("Technique") technique = relationship("Technique")
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_compliance_mappings_control', 'compliance_control_id'), Index('ix_compliance_mappings_control', 'compliance_control_id'),
Index('ix_compliance_mappings_technique', 'technique_id'), Index('ix_compliance_mappings_technique', 'technique_id'),
UniqueConstraint( UniqueConstraint(
# Literal argument value
'compliance_control_id', 'technique_id', 'compliance_control_id', 'technique_id',
# Keyword argument: name
name='uq_control_technique', name='uq_control_technique',
), ),
) )
+51 -2
View File
@@ -5,76 +5,125 @@ SnapshotTechniqueState stores per-technique state (normalized, one row
per technique per snapshot) to avoid bloated JSONB fields. per technique per snapshot) to avoid bloated JSONB fields.
""" """
# Import uuid
import uuid import uuid
# Import from sqlalchemy
from sqlalchemy import ( from sqlalchemy import (
Column, String, Float, Integer, DateTime, Column,
ForeignKey, Index, func, DateTime,
Float,
ForeignKey,
Index,
Integer,
String,
func,
) )
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class CoverageSnapshot
class CoverageSnapshot(Base): class CoverageSnapshot(Base):
"""A point-in-time snapshot of the organisation's overall coverage.""" """A point-in-time snapshot of the organisation's overall coverage."""
# Assign __tablename__ = "coverage_snapshots"
__tablename__ = "coverage_snapshots" __tablename__ = "coverage_snapshots"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
name = Column(String, nullable=True) # e.g. "Pre-remediación Q1" name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
# Assign organization_score = Column(Float, nullable=False)
organization_score = Column(Float, nullable=False) organization_score = Column(Float, nullable=False)
# Assign total_techniques = Column(Integer, nullable=False)
total_techniques = Column(Integer, nullable=False) total_techniques = Column(Integer, nullable=False)
# Assign validated_count = Column(Integer, nullable=False)
validated_count = Column(Integer, nullable=False) validated_count = Column(Integer, nullable=False)
# Assign partial_count = Column(Integer, nullable=False)
partial_count = Column(Integer, nullable=False) partial_count = Column(Integer, nullable=False)
# Assign not_covered_count = Column(Integer, nullable=False)
not_covered_count = Column(Integer, nullable=False) not_covered_count = Column(Integer, nullable=False)
# Assign in_progress_count = Column(Integer, nullable=False)
in_progress_count = Column(Integer, nullable=False) in_progress_count = Column(Integer, nullable=False)
# Assign not_evaluated_count = Column(Integer, nullable=False)
not_evaluated_count = Column(Integer, nullable=False) not_evaluated_count = Column(Integer, nullable=False)
# Assign coverage_percentage = Column(Float, nullable=False, default=0.0)
coverage_percentage = Column(Float, nullable=False, default=0.0) coverage_percentage = Column(Float, nullable=False, default=0.0)
# Assign by_tactic = Column(JSONB, nullable=False, default=dict)
by_tactic = Column(JSONB, nullable=False, default=dict) by_tactic = Column(JSONB, nullable=False, default=dict)
# Assign by_status = Column(JSONB, nullable=False, default=dict)
by_status = Column(JSONB, nullable=False, default=dict) by_status = Column(JSONB, nullable=False, default=dict)
# Assign stale_count = Column(Integer, nullable=False, default=0)
stale_count = Column(Integer, nullable=False, default=0) stale_count = Column(Integer, nullable=False, default=0)
# Assign never_tested_count = Column(Integer, nullable=False, default=0)
never_tested_count = Column(Integer, nullable=False, default=0) never_tested_count = Column(Integer, nullable=False, default=0)
# Assign created_by = Column(
created_by = Column( created_by = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"), ForeignKey("users.id", ondelete="SET NULL"),
# Keyword argument: nullable
nullable=True, nullable=True,
) )
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships # Relationships
creator = relationship("User", foreign_keys=[created_by]) creator = relationship("User", foreign_keys=[created_by])
# Assign technique_states = relationship(
technique_states = relationship( technique_states = relationship(
# Literal argument value
"SnapshotTechniqueState", "SnapshotTechniqueState",
# Keyword argument: back_populates
back_populates="snapshot", back_populates="snapshot",
# Keyword argument: cascade
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
# Define class SnapshotTechniqueState
class SnapshotTechniqueState(Base): class SnapshotTechniqueState(Base):
"""Per-technique state within a snapshot (normalised storage).""" """Per-technique state within a snapshot (normalised storage)."""
# Assign __tablename__ = "snapshot_technique_states"
__tablename__ = "snapshot_technique_states" __tablename__ = "snapshot_technique_states"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign snapshot_id = Column(
snapshot_id = Column( snapshot_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("coverage_snapshots.id", ondelete="CASCADE"), ForeignKey("coverage_snapshots.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign technique_id = Column(
technique_id = Column( technique_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("techniques.id", ondelete="CASCADE"), ForeignKey("techniques.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign mitre_id = Column(String, nullable=False) # denormalised for fast queries
mitre_id = Column(String, nullable=False) # denormalised for fast queries mitre_id = Column(String, nullable=False) # denormalised for fast queries
# Assign status = Column(String, nullable=False)
status = Column(String, nullable=False) status = Column(String, nullable=False)
# Assign score = Column(Float, nullable=True)
score = Column(Float, nullable=True) score = Column(Float, nullable=True)
# Relationships # Relationships
snapshot = relationship("CoverageSnapshot", back_populates="technique_states") snapshot = relationship("CoverageSnapshot", back_populates="technique_states")
# Assign technique = relationship("Technique")
technique = relationship("Technique") technique = relationship("Technique")
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index("ix_snapshot_technique_states_snapshot", "snapshot_id"), Index("ix_snapshot_technique_states_snapshot", "snapshot_id"),
Index("ix_snapshot_technique_states_technique", "technique_id"), Index("ix_snapshot_technique_states_technique", "technique_id"),
+28 -8
View File
@@ -1,36 +1,56 @@
"""DataSource model — registry of external data sources for import.""" """DataSource model — registry of external data sources for import."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index, func
from sqlalchemy.dialects.postgresql import UUID, JSONB
# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy
from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class DataSource
class DataSource(Base): class DataSource(Base):
""" """Unified registry of all external data sources.
Unified registry of all external data sources (attack procedures,
detection rules, threat intel, defensive techniques).
Each source can be independently enabled/disabled and tracks its own Covers attack procedures, detection rules, threat intel, and defensive techniques.
synchronisation state. Each source can be independently enabled/disabled and tracks its own synchronisation state.
""" """
# Assign __tablename__ = "data_sources"
__tablename__ = "data_sources" __tablename__ = "data_sources"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign name = Column(String, unique=True, nullable=False) # e.g. "atom...
name = Column(String, unique=True, nullable=False) # e.g. "atomic_red_team" name = Column(String, unique=True, nullable=False) # e.g. "atomic_red_team"
# Assign display_name = Column(String, nullable=False) # e.g. "Atomic Red ...
display_name = Column(String, nullable=False) # e.g. "Atomic Red Team" display_name = Column(String, nullable=False) # e.g. "Atomic Red Team"
type = Column(String, nullable=False) # attack_procedure / detection_rule / threat_intel / defensive_technique # Values: attack_procedure / detection_rule / threat_intel / defensive_technique
type = Column(String, nullable=False)
# Assign url = Column(String, nullable=True) # URL base...
url = Column(String, nullable=True) # URL base of repo/API url = Column(String, nullable=True) # URL base of repo/API
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign is_enabled = Column(Boolean, default=True)
is_enabled = Column(Boolean, default=True) is_enabled = Column(Boolean, default=True)
# Assign last_sync_at = Column(DateTime, nullable=True)
last_sync_at = Column(DateTime, nullable=True) last_sync_at = Column(DateTime, nullable=True)
# Assign last_sync_status = Column(String, nullable=True) # success / error / in_...
last_sync_status = Column(String, nullable=True) # success / error / in_progress last_sync_status = Column(String, nullable=True) # success / error / in_progress
# Assign last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "upd...
last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "updated": Y, ...} last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "updated": Y, ...}
# Assign sync_frequency = Column(String, nullable=True) # daily / weekly / mo...
sync_frequency = Column(String, nullable=True) # daily / weekly / monthly / manual sync_frequency = Column(String, nullable=True) # daily / weekly / monthly / manual
# Assign config = Column(JSONB, nullable=True) # source-spec...
config = Column(JSONB, nullable=True) # source-specific configuration config = Column(JSONB, nullable=True) # source-specific configuration
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_data_sources_type', 'type'), Index('ix_data_sources_type', 'type'),
Index('ix_data_sources_is_enabled', 'is_enabled'), Index('ix_data_sources_is_enabled', 'is_enabled'),
+42 -8
View File
@@ -4,74 +4,108 @@ Stores MITRE D3FEND defensive techniques and their mappings to
ATT&CK techniques, enabling recommended countermeasure lookups. ATT&CK techniques, enabling recommended countermeasure lookups.
""" """
# Import uuid
import uuid import uuid
# Import from sqlalchemy
from sqlalchemy import ( from sqlalchemy import (
Column, String, Text, DateTime, Column,
ForeignKey, Index, UniqueConstraint, func, DateTime,
ForeignKey,
Index,
String,
Text,
UniqueConstraint,
func,
) )
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class DefensiveTechnique
class DefensiveTechnique(Base): class DefensiveTechnique(Base):
""" """MITRE D3FEND defensive technique.
MITRE D3FEND defensive technique.
Represents a countermeasure from the D3FEND framework that can be Represents a countermeasure from the D3FEND framework that can be
mapped to one or more ATT&CK techniques via DefensiveTechniqueMapping. mapped to one or more ATT&CK techniques via DefensiveTechniqueMapping.
""" """
# Assign __tablename__ = "defensive_techniques"
__tablename__ = "defensive_techniques" __tablename__ = "defensive_techniques"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL"
d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL" d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL"
# Assign name = Column(String, nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign tactic = Column(String, nullable=True) # Detect, ...
tactic = Column(String, nullable=True) # Detect, Isolate, Deceive, Evict, etc. tactic = Column(String, nullable=True) # Detect, Isolate, Deceive, Evict, etc.
# Assign d3fend_url = Column(String, nullable=True)
d3fend_url = Column(String, nullable=True) d3fend_url = Column(String, nullable=True)
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships # Relationships
attack_mappings = relationship( attack_mappings = relationship(
# Literal argument value
"DefensiveTechniqueMapping", "DefensiveTechniqueMapping",
# Keyword argument: back_populates
back_populates="defensive_technique", back_populates="defensive_technique",
# Keyword argument: cascade
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_defensive_techniques_tactic', 'tactic'), Index('ix_defensive_techniques_tactic', 'tactic'),
) )
# Define class DefensiveTechniqueMapping
class DefensiveTechniqueMapping(Base): class DefensiveTechniqueMapping(Base):
""" """Association between a MITRE ATT&CK technique and a D3FEND defensive technique."""
Association between a MITRE ATT&CK technique and a D3FEND # Assign __tablename__ = "defensive_technique_mappings"
defensive technique.
"""
__tablename__ = "defensive_technique_mappings" __tablename__ = "defensive_technique_mappings"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign attack_technique_id = Column(
attack_technique_id = Column( attack_technique_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("techniques.id", ondelete="CASCADE"), ForeignKey("techniques.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign defensive_technique_id = Column(
defensive_technique_id = Column( defensive_technique_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("defensive_techniques.id", ondelete="CASCADE"), ForeignKey("defensive_techniques.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Relationships # Relationships
attack_technique = relationship("Technique") attack_technique = relationship("Technique")
# Assign defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings")
defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings") defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings")
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_dtm_attack_technique', 'attack_technique_id'), Index('ix_dtm_attack_technique', 'attack_technique_id'),
Index('ix_dtm_defensive_technique', 'defensive_technique_id'), Index('ix_dtm_defensive_technique', 'defensive_technique_id'),
UniqueConstraint( UniqueConstraint(
# Literal argument value
'attack_technique_id', 'defensive_technique_id', 'attack_technique_id', 'defensive_technique_id',
# Keyword argument: name
name='uq_attack_defensive_technique', name='uq_attack_defensive_technique',
), ),
) )
+27 -4
View File
@@ -1,38 +1,61 @@
"""DetectionRule model — detection rules from multiple sources.""" """DetectionRule model — detection rules from multiple sources."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index, func
from sqlalchemy.dialects.postgresql import UUID, JSONB
# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy
from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class DetectionRule
class DetectionRule(Base): class DetectionRule(Base):
""" """Detection rule from an external source (Sigma, Elastic, Splunk, custom).
Detection rule from an external source (Sigma, Elastic, Splunk, custom).
Each rule is mapped to one MITRE ATT&CK technique via Each rule is mapped to one MITRE ATT&CK technique via
``mitre_technique_id`` and stores the complete rule content in ``mitre_technique_id`` and stores the complete rule content in
``rule_content``. ``rule_content``.
""" """
# Assign __tablename__ = "detection_rules"
__tablename__ = "detection_rules" __tablename__ = "detection_rules"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001" mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
# Assign title = Column(String, nullable=False)
title = Column(String, nullable=False) title = Column(String, nullable=False)
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign source = Column(String, nullable=False) # sigma / ela...
source = Column(String, nullable=False) # sigma / elastic / splunk / custom source = Column(String, nullable=False) # sigma / elastic / splunk / custom
# Assign source_id = Column(String, nullable=True) # ID in the sour...
source_id = Column(String, nullable=True) # ID in the source repo (for dedup) source_id = Column(String, nullable=True) # ID in the source repo (for dedup)
# Assign source_url = Column(String, nullable=True)
source_url = Column(String, nullable=True) source_url = Column(String, nullable=True)
# Assign rule_content = Column(Text, nullable=False) # YAML / KQL / SPL ...
rule_content = Column(Text, nullable=False) # YAML / KQL / SPL content rule_content = Column(Text, nullable=False) # YAML / KQL / SPL content
# Assign rule_format = Column(String, nullable=False) # sigma_yaml / kql...
rule_format = Column(String, nullable=False) # sigma_yaml / kql / spl / custom rule_format = Column(String, nullable=False) # sigma_yaml / kql / spl / custom
# Assign severity = Column(String, nullable=True) # informational...
severity = Column(String, nullable=True) # informational / low / medium / high / critical severity = Column(String, nullable=True) # informational / low / medium / high / critical
# Assign platforms = Column(JSONB, nullable=True, default=[])
platforms = Column(JSONB, nullable=True, default=[]) platforms = Column(JSONB, nullable=True, default=[])
# Assign log_sources = Column(JSONB, nullable=True) # e.g. {"product":...
log_sources = Column(JSONB, nullable=True) # e.g. {"product": "windows", "service": "sysmon"} log_sources = Column(JSONB, nullable=True) # e.g. {"product": "windows", "service": "sysmon"}
# Assign false_positive_rate = Column(String, nullable=True) # low / medium / high
false_positive_rate = Column(String, nullable=True) # low / medium / high false_positive_rate = Column(String, nullable=True) # low / medium / high
# Assign is_active = Column(Boolean, default=True)
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_detection_rules_mitre_technique_id', 'mitre_technique_id'), Index('ix_detection_rules_mitre_technique_id', 'mitre_technique_id'),
Index('ix_detection_rules_source', 'source'), Index('ix_detection_rules_source', 'source'),
+1
View File
@@ -5,6 +5,7 @@ re-exports every enum so that existing model and router code keeps
working with ``from app.models.enums import ...``. working with ``from app.models.enums import ...``.
""" """
# Import # noqa: F401 from app.domain.enums
from app.domain.enums import ( # noqa: F401 from app.domain.enums import ( # noqa: F401
DataClassification, DataClassification,
TeamSide, TeamSide,
+29 -5
View File
@@ -1,35 +1,59 @@
"""SQLAlchemy model for the evidence table."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum, func
# Import Column, DateTime, Enum, ForeignKey, String, Tex... from sqlalchemy
from sqlalchemy import Column, DateTime, Enum, ForeignKey, String, Text, func
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Import TeamSide from app.models.enums
from app.models.enums import TeamSide from app.models.enums import TeamSide
# Define class Evidence
class Evidence(Base): class Evidence(Base):
""" """Evidence model for storing file metadata associated with tests.
Evidence model for storing file metadata associated with tests.
Files are stored in MinIO, and this model tracks the file location, Files are stored in MinIO, and this model tracks the file location,
integrity hash, and upload metadata. integrity hash, and upload metadata.
The ``team`` field distinguishes whether this evidence was uploaded by The ``team`` field distinguishes whether this evidence was uploaded by
Red Team (attack evidence) or Blue Team (detection evidence). Red Team (attack evidence) or Blue Team (detection evidence).
""" """
# Assign __tablename__ = "evidences"
__tablename__ = "evidences" __tablename__ = "evidences"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False)
test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False) test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False)
# Assign file_name = Column(String, nullable=False)
file_name = Column(String, nullable=False) file_name = Column(String, nullable=False)
# Assign file_path = Column(String, nullable=False) # Path in MinIO
file_path = Column(String, nullable=False) # Path in MinIO file_path = Column(String, nullable=False) # Path in MinIO
# Assign sha256_hash = Column(String, nullable=False)
sha256_hash = Column(String, nullable=False) sha256_hash = Column(String, nullable=False)
# Assign uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
# Assign uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
uploaded_at = Column(DateTime(timezone=True), server_default=func.now()) uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=Tea...
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red) team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
# Assign notes = Column(Text, nullable=True)
notes = Column(Text, nullable=True) notes = Column(Text, nullable=True)
# Assign data_classification = Column(String(20), nullable=False, server_default="internal")
data_classification = Column(String(20), nullable=False, server_default="internal") data_classification = Column(String(20), nullable=False, server_default="internal")
# Relationships # Relationships
test = relationship("Test", back_populates="evidences") test = relationship("Test", back_populates="evidences")
# Assign uploader = relationship("User", foreign_keys=[uploaded_by])
uploader = relationship("User", foreign_keys=[uploaded_by]) uploader = relationship("User", foreign_keys=[uploaded_by])
+22 -4
View File
@@ -1,26 +1,44 @@
"""SQLAlchemy model for the intel_items table."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, func
# Import Boolean, Column, DateTime, ForeignKey, String, ... from sqlalchemy
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, func
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class IntelItem
class IntelItem(Base): class IntelItem(Base):
""" """Intelligence item model for tracking threat intelligence related to techniques.
Intelligence item model for tracking threat intelligence related to techniques.
Stores URLs and metadata from automated intel scans that may indicate Stores URLs and metadata from automated intel scans that may indicate
new attack variations or detection bypasses for specific techniques. new attack variations or detection bypasses for specific techniques.
""" """
# Assign __tablename__ = "intel_items"
__tablename__ = "intel_items" __tablename__ = "intel_items"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True)
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True) technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True)
# Assign url = Column(String, nullable=False)
url = Column(String, nullable=False) url = Column(String, nullable=False)
# Assign title = Column(String, nullable=True)
title = Column(String, nullable=True) title = Column(String, nullable=True)
# Assign source = Column(String, nullable=True)
source = Column(String, nullable=True) source = Column(String, nullable=True)
# Assign detected_at = Column(DateTime(timezone=True), server_default=func.now())
detected_at = Column(DateTime(timezone=True), server_default=func.now()) detected_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign reviewed = Column(Boolean, default=False)
reviewed = Column(Boolean, default=False) reviewed = Column(Boolean, default=False)
# Relationships # Relationships
+48 -2
View File
@@ -1,53 +1,99 @@
"""Jira integration models — link Aegis entities to Jira issues.""" """Jira integration models — link Aegis entities to Jira issues."""
# Import enum
import enum import enum
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, DateTime, ForeignKey, Enum as SQLEnum, Index, func
from sqlalchemy.dialects.postgresql import UUID, JSONB # Import Column, DateTime, ForeignKey, Index, String, func from sqlalchemy
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, func
# Import Enum as SQLEnum from sqlalchemy
from sqlalchemy import Enum as SQLEnum
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class JiraLinkEntityType
class JiraLinkEntityType(str, enum.Enum): class JiraLinkEntityType(str, enum.Enum):
"""Aegis entity types that can be linked to a Jira issue."""
# Assign test = "test"
test = "test" test = "test"
# Assign technique = "technique"
technique = "technique" technique = "technique"
# Assign campaign = "campaign"
campaign = "campaign" campaign = "campaign"
# Assign evidence = "evidence"
evidence = "evidence" evidence = "evidence"
# Define class JiraSyncDirection
class JiraSyncDirection(str, enum.Enum): class JiraSyncDirection(str, enum.Enum):
"""Direction of synchronisation between Aegis and Jira."""
# Assign aegis_to_jira = "aegis_to_jira"
aegis_to_jira = "aegis_to_jira" aegis_to_jira = "aegis_to_jira"
# Assign jira_to_aegis = "jira_to_aegis"
jira_to_aegis = "jira_to_aegis" jira_to_aegis = "jira_to_aegis"
# Assign bidirectional = "bidirectional"
bidirectional = "bidirectional" bidirectional = "bidirectional"
# Define class JiraLink
class JiraLink(Base): class JiraLink(Base):
"""Associates an Aegis entity with a Jira issue for bidirectional sync.""" """Associates an Aegis entity with a Jira issue for bidirectional sync."""
# Assign __tablename__ = "jira_links"
__tablename__ = "jira_links" __tablename__ = "jira_links"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False)
entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False) entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False)
# Assign entity_id = Column(UUID(as_uuid=True), nullable=False)
entity_id = Column(UUID(as_uuid=True), nullable=False) entity_id = Column(UUID(as_uuid=True), nullable=False)
# Assign jira_issue_key = Column(String(50), nullable=False)
jira_issue_key = Column(String(50), nullable=False) jira_issue_key = Column(String(50), nullable=False)
# Assign jira_issue_id = Column(String(50))
jira_issue_id = Column(String(50)) jira_issue_id = Column(String(50))
# Assign jira_project_key = Column(String(20))
jira_project_key = Column(String(20)) jira_project_key = Column(String(20))
# Assign jira_status = Column(String(100))
jira_status = Column(String(100)) jira_status = Column(String(100))
# Assign jira_priority = Column(String(50))
jira_priority = Column(String(50)) jira_priority = Column(String(50))
# Assign jira_assignee = Column(String(255))
jira_assignee = Column(String(255)) jira_assignee = Column(String(255))
# Assign jira_story_points = Column(String(10))
jira_story_points = Column(String(10)) jira_story_points = Column(String(10))
# Assign sync_direction = Column(
sync_direction = Column( sync_direction = Column(
SQLEnum(JiraSyncDirection), default=JiraSyncDirection.bidirectional SQLEnum(JiraSyncDirection), default=JiraSyncDirection.bidirectional
) )
# Assign last_synced_at = Column(DateTime)
last_synced_at = Column(DateTime) last_synced_at = Column(DateTime)
# Assign sync_metadata = Column(JSONB, default={})
sync_metadata = Column(JSONB, default={}) sync_metadata = Column(JSONB, default={})
# Assign created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate...
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
# Assign creator = relationship("User", foreign_keys=[created_by])
creator = relationship("User", foreign_keys=[created_by]) creator = relationship("User", foreign_keys=[created_by])
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index("ix_jira_links_entity_id", "entity_id"), Index("ix_jira_links_entity_id", "entity_id"),
Index("ix_jira_links_issue_key", "jira_issue_key"), Index("ix_jira_links_issue_key", "jira_issue_key"),
+22 -3
View File
@@ -1,35 +1,54 @@
"""Notification model — in-app notifications for user actions.""" """Notification model — in-app notifications for user actions."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Index, func
# Import Boolean, Column, DateTime, ForeignKey, Index, S... from sqlalchemy
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String, Text, func
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class Notification
class Notification(Base): class Notification(Base):
""" """In-app notification for alerting users when they need to act.
In-app notification for alerting users when they need to act.
Types include: test_assigned, validation_needed, test_rejected, Types include: test_assigned, validation_needed, test_rejected,
test_validated, test_state_changed, etc. test_validated, test_state_changed, etc.
""" """
# Assign __tablename__ = "notifications"
__tablename__ = "notifications" __tablename__ = "notifications"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
# Assign type = Column(String, nullable=False)
type = Column(String, nullable=False) type = Column(String, nullable=False)
# Assign title = Column(String, nullable=False)
title = Column(String, nullable=False) title = Column(String, nullable=False)
# Assign message = Column(Text, nullable=True)
message = Column(Text, nullable=True) message = Column(Text, nullable=True)
# Assign entity_type = Column(String, nullable=True)
entity_type = Column(String, nullable=True) entity_type = Column(String, nullable=True)
# Assign entity_id = Column(UUID(as_uuid=True), nullable=True)
entity_id = Column(UUID(as_uuid=True), nullable=True) entity_id = Column(UUID(as_uuid=True), nullable=True)
# Assign read = Column(Boolean, default=False)
read = Column(Boolean, default=False) read = Column(Boolean, default=False)
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships # Relationships
user = relationship("User") user = relationship("User")
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index("ix_notifications_user_id", "user_id"), Index("ix_notifications_user_id", "user_id"),
Index("ix_notifications_read", "read"), Index("ix_notifications_read", "read"),
+25 -4
View File
@@ -1,37 +1,58 @@
"""OSINT enrichment items — CVEs, blogs, PoCs, and advisories linked to techniques.""" """OSINT enrichment items — CVEs, blogs, PoCs, and advisories linked to techniques."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, func
from sqlalchemy.dialects.postgresql import UUID, JSONB # Import Boolean, Column, DateTime, ForeignKey, String, ... from sqlalchemy
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, Text, func
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class OsintItem
class OsintItem(Base): class OsintItem(Base):
"""Represents an OSINT data point (CVE, blog, PoC, advisory) associated """Represents an OSINT data point (CVE, blog, PoC, advisory) associated with a MITRE ATT&CK technique.
with a MITRE ATT&CK technique.
Used by the enrichment pipeline to surface relevant threat intelligence Used by the enrichment pipeline to surface relevant threat intelligence
for each technique, flagging those that need review. for each technique, flagging those that need review.
""" """
# Assign __tablename__ = "osint_items"
__tablename__ = "osint_items" __tablename__ = "osint_items"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign technique_id = Column(
technique_id = Column( technique_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("techniques.id"), ForeignKey("techniques.id"),
# Keyword argument: nullable
nullable=False, nullable=False,
# Keyword argument: index
index=True, index=True,
) )
# Assign source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory"
source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory" source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory"
# Assign source_url = Column(Text, nullable=False)
source_url = Column(Text, nullable=False) source_url = Column(Text, nullable=False)
# Assign title = Column(String(500), nullable=False)
title = Column(String(500), nullable=False) title = Column(String(500), nullable=False)
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, U...
severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, UNKNOWN severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, UNKNOWN
# Assign discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable...
discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
# Assign reviewed = Column(Boolean, default=False)
reviewed = Column(Boolean, default=False) reviewed = Column(Boolean, default=False)
# Assign metadata_ = Column("metadata", JSONB, default={})
metadata_ = Column("metadata", JSONB, default={}) metadata_ = Column("metadata", JSONB, default={})
# ── Relationships ───────────────────────────────────────────────── # ── Relationships ─────────────────────────────────────────────────
+19 -1
View File
@@ -1,25 +1,43 @@
"""ScoringConfig — single-row table for persisted scoring weights.""" """ScoringConfig — single-row table for persisted scoring weights."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, Float, DateTime, ForeignKey, func # Import Column, DateTime, Float, ForeignKey, func from sqlalchemy
from sqlalchemy import Column, DateTime, Float, ForeignKey, func
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class ScoringConfig
class ScoringConfig(Base): class ScoringConfig(Base):
"""Single-row table persisting the active scoring weight configuration."""
# Assign __tablename__ = "scoring_config"
__tablename__ = "scoring_config" __tablename__ = "scoring_config"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign weight_tests = Column(Float, nullable=False, default=40.0)
weight_tests = Column(Float, nullable=False, default=40.0) weight_tests = Column(Float, nullable=False, default=40.0)
# Assign weight_detection_rules = Column(Float, nullable=False, default=25.0)
weight_detection_rules = Column(Float, nullable=False, default=25.0) weight_detection_rules = Column(Float, nullable=False, default=25.0)
# Assign weight_d3fend = Column(Float, nullable=False, default=15.0)
weight_d3fend = Column(Float, nullable=False, default=15.0) weight_d3fend = Column(Float, nullable=False, default=15.0)
# Assign weight_recency = Column(Float, nullable=False, default=10.0)
weight_recency = Column(Float, nullable=False, default=10.0) weight_recency = Column(Float, nullable=False, default=10.0)
# Assign weight_severity = Column(Float, nullable=False, default=10.0)
weight_severity = Column(Float, nullable=False, default=10.0) weight_severity = Column(Float, nullable=False, default=10.0)
# Assign updated_by = Column(
updated_by = Column( updated_by = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"), ForeignKey("users.id", ondelete="SET NULL"),
# Keyword argument: nullable
nullable=True, nullable=True,
) )
# Assign updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate...
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
+32 -7
View File
@@ -1,38 +1,63 @@
import uuid """SQLAlchemy model for the techniques table."""
from datetime import datetime
from sqlalchemy import Column, String, Text, Boolean, DateTime, Enum # Import uuid
from sqlalchemy.dialects.postgresql import UUID, JSONB import uuid
# Import Boolean, Column, DateTime, Enum, String, Text from sqlalchemy
from sqlalchemy import Boolean, Column, DateTime, Enum, String, Text
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Import TechniqueStatus from app.models.enums
from app.models.enums import TechniqueStatus from app.models.enums import TechniqueStatus
# Define class Technique
class Technique(Base): class Technique(Base):
""" """MITRE ATT&CK Technique model.
MITRE ATT&CK Technique model.
Represents an attack technique from the MITRE ATT&CK framework, Represents an attack technique from the MITRE ATT&CK framework,
including its coverage status and associated tests. including its coverage status and associated tests.
""" """
# Assign __tablename__ = "techniques"
__tablename__ = "techniques" __tablename__ = "techniques"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001"
mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001" mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001"
# Assign name = Column(String, nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign tactic = Column(String, nullable=True)
tactic = Column(String, nullable=True) tactic = Column(String, nullable=True)
# Assign platforms = Column(JSONB, nullable=True, default=[])
platforms = Column(JSONB, nullable=True, default=[]) platforms = Column(JSONB, nullable=True, default=[])
# Assign mitre_version = Column(String, nullable=True)
mitre_version = Column(String, nullable=True) mitre_version = Column(String, nullable=True)
# Assign mitre_last_modified = Column(DateTime, nullable=True)
mitre_last_modified = Column(DateTime, nullable=True) mitre_last_modified = Column(DateTime, nullable=True)
# Assign is_subtechnique = Column(Boolean, default=False)
is_subtechnique = Column(Boolean, default=False) is_subtechnique = Column(Boolean, default=False)
# Assign parent_mitre_id = Column(String, nullable=True)
parent_mitre_id = Column(String, nullable=True) parent_mitre_id = Column(String, nullable=True)
# Assign status_global = Column(
status_global = Column( status_global = Column(
Enum(TechniqueStatus, name="techniquestatus"), Enum(TechniqueStatus, name="techniquestatus"),
# Keyword argument: default
default=TechniqueStatus.not_evaluated default=TechniqueStatus.not_evaluated
) )
# Assign review_required = Column(Boolean, default=False)
review_required = Column(Boolean, default=False) review_required = Column(Boolean, default=False)
# Assign last_review_date = Column(DateTime, nullable=True)
last_review_date = Column(DateTime, nullable=True) last_review_date = Column(DateTime, nullable=True)
# Relationships # Relationships
+65 -4
View File
@@ -1,79 +1,140 @@
"""SQLAlchemy model for the tests table."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, ForeignKey, Enum, Index, func
# Import from sqlalchemy
from sqlalchemy import (
Boolean,
Column,
DateTime,
Enum,
ForeignKey,
Index,
Integer,
String,
Text,
func,
)
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
from app.models.enums import TestState, TestResult
# Import TestResult, TestState from app.models.enums
from app.models.enums import TestResult, TestState
# Define class Test
class Test(Base): class Test(Base):
""" """Test model representing a security test for a MITRE ATT&CK technique.
Test model representing a security test for a MITRE ATT&CK technique.
Each test documents an attempt to validate coverage of a specific technique, Each test documents an attempt to validate coverage of a specific technique,
including the procedure, tools used, and outcome. V2 introduces dual including the procedure, tools used, and outcome. V2 introduces dual
validation: Red Lead and Blue Lead must each approve independently. validation: Red Lead and Blue Lead must each approve independently.
""" """
# Assign __tablename__ = "tests"
__tablename__ = "tests" __tablename__ = "tests"
# ── Core fields ───────────────────────────────────────────────── # ── Core fields ─────────────────────────────────────────────────
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=Fa...
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False) technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False)
# Assign name = Column(String, nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign platform = Column(String, nullable=True)
platform = Column(String, nullable=True) platform = Column(String, nullable=True)
# Assign procedure_text = Column(Text, nullable=True)
procedure_text = Column(Text, nullable=True) procedure_text = Column(Text, nullable=True)
# Assign tool_used = Column(String, nullable=True)
tool_used = Column(String, nullable=True) tool_used = Column(String, nullable=True)
# Assign execution_date = Column(DateTime, nullable=True)
execution_date = Column(DateTime, nullable=True) execution_date = Column(DateTime, nullable=True)
# Assign created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
# Assign result = Column(Enum(TestResult, name="testresult"), nullable=True)
result = Column(Enum(TestResult, name="testresult"), nullable=True) result = Column(Enum(TestResult, name="testresult"), nullable=True)
# Assign state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
state = Column(Enum(TestState, name="teststate"), default=TestState.draft) state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# ── Red Team fields ───────────────────────────────────────────── # ── Red Team fields ─────────────────────────────────────────────
red_summary = Column(Text, nullable=True) red_summary = Column(Text, nullable=True)
# Assign attack_success = Column(Boolean, nullable=True)
attack_success = Column(Boolean, nullable=True) attack_success = Column(Boolean, nullable=True)
# Assign red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
# Assign red_validated_at = Column(DateTime, nullable=True)
red_validated_at = Column(DateTime, nullable=True) red_validated_at = Column(DateTime, nullable=True)
# Assign red_validation_status = Column(String, nullable=True) # pending / approved / rejected
red_validation_status = Column(String, nullable=True) # pending / approved / rejected red_validation_status = Column(String, nullable=True) # pending / approved / rejected
# Assign red_validation_notes = Column(Text, nullable=True)
red_validation_notes = Column(Text, nullable=True) red_validation_notes = Column(Text, nullable=True)
# ── Blue Team fields ──────────────────────────────────────────── # ── Blue Team fields ────────────────────────────────────────────
blue_summary = Column(Text, nullable=True) blue_summary = Column(Text, nullable=True)
# Assign detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
detection_result = Column(Enum(TestResult, name="testresult"), nullable=True) detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
# Assign blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
# Assign blue_validated_at = Column(DateTime, nullable=True)
blue_validated_at = Column(DateTime, nullable=True) blue_validated_at = Column(DateTime, nullable=True)
# Assign blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
blue_validation_status = Column(String, nullable=True) # pending / approved / rejected blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
# Assign blue_validation_notes = Column(Text, nullable=True)
blue_validation_notes = Column(Text, nullable=True) blue_validation_notes = Column(Text, nullable=True)
# ── Phase timing fields (for automatic Tempo worklogs) ────────── # ── Phase timing fields (for automatic Tempo worklogs) ──────────
red_started_at = Column(DateTime, nullable=True) red_started_at = Column(DateTime, nullable=True)
# Assign blue_started_at = Column(DateTime, nullable=True)
blue_started_at = Column(DateTime, nullable=True) blue_started_at = Column(DateTime, nullable=True)
# Assign paused_at = Column(DateTime, nullable=True)
paused_at = Column(DateTime, nullable=True) paused_at = Column(DateTime, nullable=True)
# Assign red_paused_seconds = Column(Integer, default=0)
red_paused_seconds = Column(Integer, default=0) red_paused_seconds = Column(Integer, default=0)
# Assign blue_paused_seconds = Column(Integer, default=0)
blue_paused_seconds = Column(Integer, default=0) blue_paused_seconds = Column(Integer, default=0)
# ── Remediation fields ─────────────────────────────────────────── # ── Remediation fields ───────────────────────────────────────────
remediation_steps = Column(Text, nullable=True) remediation_steps = Column(Text, nullable=True)
# Assign remediation_status = Column(String, nullable=True) # pending / in_progress / completed ...
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable
# Assign remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
# ── Re-test fields ──────────────────────────────────────────── # ── Re-test fields ────────────────────────────────────────────
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True) retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
# Assign retest_count = Column(Integer, default=0)
retest_count = Column(Integer, default=0) retest_count = Column(Integer, default=0)
# Assign data_classification = Column(String(20), nullable=False, server_default="internal")
data_classification = Column(String(20), nullable=False, server_default="internal") data_classification = Column(String(20), nullable=False, server_default="internal")
# ── Relationships ─────────────────────────────────────────────── # ── Relationships ───────────────────────────────────────────────
technique = relationship("Technique", back_populates="tests") technique = relationship("Technique", back_populates="tests")
# Assign evidences = relationship("Evidence", back_populates="test")
evidences = relationship("Evidence", back_populates="test") evidences = relationship("Evidence", back_populates="test")
# Assign creator = relationship("User", foreign_keys=[created_by])
creator = relationship("User", foreign_keys=[created_by]) creator = relationship("User", foreign_keys=[created_by])
# Assign red_validator = relationship("User", foreign_keys=[red_validated_by])
red_validator = relationship("User", foreign_keys=[red_validated_by]) red_validator = relationship("User", foreign_keys=[red_validated_by])
# Assign blue_validator = relationship("User", foreign_keys=[blue_validated_by])
blue_validator = relationship("User", foreign_keys=[blue_validated_by]) blue_validator = relationship("User", foreign_keys=[blue_validated_by])
# Assign remediation_user = relationship("User", foreign_keys=[remediation_assignee])
remediation_user = relationship("User", foreign_keys=[remediation_assignee]) remediation_user = relationship("User", foreign_keys=[remediation_assignee])
# Assign original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of]) original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
# Assign retests = relationship("Test", foreign_keys=[retest_of], back_populates="orig...
retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test") retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test")
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index("ix_tests_technique_id", "technique_id"), Index("ix_tests_technique_id", "technique_id"),
Index("ix_tests_state", "state"), Index("ix_tests_state", "state"),
+32 -4
View File
@@ -4,51 +4,79 @@ When the Blue Team evaluates a test, they mark each associated detection
rule as triggered / not triggered / not applicable, along with notes. rule as triggered / not triggered / not applicable, along with notes.
""" """
# Import uuid
import uuid import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Index, UniqueConstraint # Import from sqlalchemy
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Index,
Text,
UniqueConstraint,
)
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class TestDetectionResult
class TestDetectionResult(Base): class TestDetectionResult(Base):
""" """Per-test, per-rule evaluation result.
Per-test, per-rule evaluation result.
- ``triggered`` = True: rule detected the attack - ``triggered`` = True: rule detected the attack
- ``triggered`` = False: rule did NOT detect the attack - ``triggered`` = False: rule did NOT detect the attack
- ``triggered`` = None: not yet evaluated - ``triggered`` = None: not yet evaluated
""" """
# Assign __tablename__ = "test_detection_results"
__tablename__ = "test_detection_results" __tablename__ = "test_detection_results"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign test_id = Column(
test_id = Column( test_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("tests.id", ondelete="CASCADE"), ForeignKey("tests.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign detection_rule_id = Column(
detection_rule_id = Column( detection_rule_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("detection_rules.id", ondelete="CASCADE"), ForeignKey("detection_rules.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign triggered = Column(Boolean, nullable=True) # None = not evaluated
triggered = Column(Boolean, nullable=True) # None = not evaluated triggered = Column(Boolean, nullable=True) # None = not evaluated
# Assign notes = Column(Text, nullable=True)
notes = Column(Text, nullable=True) notes = Column(Text, nullable=True)
# Assign evaluated_by = Column(
evaluated_by = Column( evaluated_by = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"), ForeignKey("users.id", ondelete="SET NULL"),
# Keyword argument: nullable
nullable=True, nullable=True,
) )
# Assign evaluated_at = Column(DateTime, nullable=True)
evaluated_at = Column(DateTime, nullable=True) evaluated_at = Column(DateTime, nullable=True)
# Relationships # Relationships
test = relationship("Test") test = relationship("Test")
# Assign detection_rule = relationship("DetectionRule")
detection_rule = relationship("DetectionRule") detection_rule = relationship("DetectionRule")
# Assign evaluator = relationship("User", foreign_keys=[evaluated_by])
evaluator = relationship("User", foreign_keys=[evaluated_by]) evaluator = relationship("User", foreign_keys=[evaluated_by])
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_tdr_test', 'test_id'), Index('ix_tdr_test', 'test_id'),
Index('ix_tdr_rule', 'detection_rule_id'), Index('ix_tdr_rule', 'detection_rule_id'),
+26 -3
View File
@@ -1,15 +1,21 @@
"""TestTemplate model — predefined test catalog entries.""" """TestTemplate model — predefined test catalog entries."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index, func
# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy
from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class TestTemplate
class TestTemplate(Base): class TestTemplate(Base):
""" """Predefined test template mapped to a MITRE ATT&CK technique.
Predefined test template mapped to a MITRE ATT&CK technique.
Templates come from several sources: Templates come from several sources:
- **atomic_red_team**: Atomic Red Team by Red Canary - **atomic_red_team**: Atomic Red Team by Red Canary
@@ -18,24 +24,41 @@ class TestTemplate(Base):
Users can instantiate a real Test from a template. Users can instantiate a real Test from a template.
""" """
# Assign __tablename__ = "test_templates"
__tablename__ = "test_templates" __tablename__ = "test_templates"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001" mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
# Assign name = Column(String, nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign source = Column(String, nullable=False) # atomic_red_te...
source = Column(String, nullable=False) # atomic_red_team / mitre / custom source = Column(String, nullable=False) # atomic_red_team / mitre / custom
# Assign source_url = Column(String, nullable=True)
source_url = Column(String, nullable=True) source_url = Column(String, nullable=True)
# Assign attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
attack_procedure = Column(Text, nullable=True) # Suggested attack procedure attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
# Assign expected_detection = Column(Text, nullable=True) # What blue team should detect
expected_detection = Column(Text, nullable=True) # What blue team should detect expected_detection = Column(Text, nullable=True) # What blue team should detect
# Assign platform = Column(String, nullable=True) # windows / linux...
platform = Column(String, nullable=True) # windows / linux / macos platform = Column(String, nullable=True) # windows / linux / macos
# Assign tool_suggested = Column(String, nullable=True)
tool_suggested = Column(String, nullable=True) tool_suggested = Column(String, nullable=True)
# Assign severity = Column(String, nullable=True) # low / medium / ...
severity = Column(String, nullable=True) # low / medium / high / critical severity = Column(String, nullable=True) # low / medium / high / critical
# Assign atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team...
atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo
# Assign suggested_remediation = Column(Text, nullable=True)
suggested_remediation = Column(Text, nullable=True) suggested_remediation = Column(Text, nullable=True)
# Assign is_active = Column(Boolean, default=True)
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'), Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'),
Index('ix_test_templates_source', 'source'), Index('ix_test_templates_source', 'source'),
@@ -4,47 +4,64 @@ Enables the Blue Team to see which detection rules should fire
for a given test template / attack procedure. for a given test template / attack procedure.
""" """
# Import uuid
import uuid import uuid
from datetime import datetime
from sqlalchemy import Column, Boolean, ForeignKey, Index, UniqueConstraint # Import Boolean, Column, ForeignKey, Index, UniqueConst... from sqlalchemy
from sqlalchemy import Boolean, Column, ForeignKey, Index, UniqueConstraint
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class TestTemplateDetectionRule
class TestTemplateDetectionRule(Base): class TestTemplateDetectionRule(Base):
""" """Association between a test template and a detection rule.
Association between a test template and a detection rule.
Auto-generated by matching mitre_technique_id, or manually curated. Auto-generated by matching mitre_technique_id, or manually curated.
``is_primary`` marks rules with severity >= high as primary detections. ``is_primary`` marks rules with severity >= high as primary detections.
""" """
# Assign __tablename__ = "test_template_detection_rules"
__tablename__ = "test_template_detection_rules" __tablename__ = "test_template_detection_rules"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign test_template_id = Column(
test_template_id = Column( test_template_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("test_templates.id", ondelete="CASCADE"), ForeignKey("test_templates.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=True, nullable=True,
) )
# Assign detection_rule_id = Column(
detection_rule_id = Column( detection_rule_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("detection_rules.id", ondelete="CASCADE"), ForeignKey("detection_rules.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign is_primary = Column(Boolean, default=False)
is_primary = Column(Boolean, default=False) is_primary = Column(Boolean, default=False)
# Relationships # Relationships
test_template = relationship("TestTemplate") test_template = relationship("TestTemplate")
# Assign detection_rule = relationship("DetectionRule")
detection_rule = relationship("DetectionRule") detection_rule = relationship("DetectionRule")
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_ttdr_template', 'test_template_id'), Index('ix_ttdr_template', 'test_template_id'),
Index('ix_ttdr_rule', 'detection_rule_id'), Index('ix_ttdr_rule', 'detection_rule_id'),
UniqueConstraint( UniqueConstraint(
# Literal argument value
'test_template_id', 'detection_rule_id', 'test_template_id', 'detection_rule_id',
# Keyword argument: name
name='uq_template_detection_rule', name='uq_template_detection_rule',
), ),
) )
+55 -7
View File
@@ -4,87 +4,135 @@ Stores profiles of APT groups and their associated MITRE ATT&CK
techniques, imported from MITRE CTI (STIX 2.0). techniques, imported from MITRE CTI (STIX 2.0).
""" """
# Import uuid
import uuid import uuid
# Import from sqlalchemy
from sqlalchemy import ( from sqlalchemy import (
Column, String, Text, Boolean, DateTime, Boolean,
ForeignKey, Index, UniqueConstraint, func, Column,
DateTime,
ForeignKey,
Index,
String,
Text,
UniqueConstraint,
func,
) )
from sqlalchemy.dialects.postgresql import UUID, JSONB
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class ThreatActor
class ThreatActor(Base): class ThreatActor(Base):
""" """Threat actor / APT group profile.
Threat actor / APT group profile.
Imported from MITRE CTI ``intrusion-set`` STIX objects. Imported from MITRE CTI ``intrusion-set`` STIX objects.
""" """
# Assign __tablename__ = "threat_actors"
__tablename__ = "threat_actors" __tablename__ = "threat_actors"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign mitre_id = Column(String, unique=True, nullable=True) # e.g. "G00...
mitre_id = Column(String, unique=True, nullable=True) # e.g. "G0016" (APT29) mitre_id = Column(String, unique=True, nullable=True) # e.g. "G0016" (APT29)
# Assign name = Column(String, nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
# Assign aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy ...
aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy Bear", "The Dukes", ...] aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy Bear", "The Dukes", ...]
# Assign description = Column(Text, nullable=True)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Assign country = Column(String, nullable=True)
country = Column(String, nullable=True) country = Column(String, nullable=True)
# Assign target_sectors = Column(JSONB, nullable=True, default=[]) # ["government",...
target_sectors = Column(JSONB, nullable=True, default=[]) # ["government", "defense", ...] target_sectors = Column(JSONB, nullable=True, default=[]) # ["government", "defense", ...]
# Assign target_regions = Column(JSONB, nullable=True, default=[]) # ["north-americ...
target_regions = Column(JSONB, nullable=True, default=[]) # ["north-america", "europe", ...] target_regions = Column(JSONB, nullable=True, default=[]) # ["north-america", "europe", ...]
# Assign motivation = Column(String, nullable=True) # espionage ...
motivation = Column(String, nullable=True) # espionage / financial / destruction / ... motivation = Column(String, nullable=True) # espionage / financial / destruction / ...
# Assign sophistication = Column(String, nullable=True) # low / medium /...
sophistication = Column(String, nullable=True) # low / medium / high / advanced sophistication = Column(String, nullable=True) # low / medium / high / advanced
# Assign first_seen = Column(String, nullable=True)
first_seen = Column(String, nullable=True) first_seen = Column(String, nullable=True)
# Assign last_seen = Column(String, nullable=True)
last_seen = Column(String, nullable=True) last_seen = Column(String, nullable=True)
# Assign references = Column(JSONB, nullable=True, default=[]) # [{"url": "...
references = Column(JSONB, nullable=True, default=[]) # [{"url": "...", "description": "..."}] references = Column(JSONB, nullable=True, default=[]) # [{"url": "...", "description": "..."}]
# Assign mitre_url = Column(String, nullable=True)
mitre_url = Column(String, nullable=True) mitre_url = Column(String, nullable=True)
# Assign is_active = Column(Boolean, default=True)
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Relationships # Relationships
techniques = relationship( techniques = relationship(
# Literal argument value
"ThreatActorTechnique", "ThreatActorTechnique",
# Keyword argument: back_populates
back_populates="threat_actor", back_populates="threat_actor",
# Keyword argument: cascade
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_threat_actors_country', 'country'), Index('ix_threat_actors_country', 'country'),
Index('ix_threat_actors_motivation', 'motivation'), Index('ix_threat_actors_motivation', 'motivation'),
) )
# Define class ThreatActorTechnique
class ThreatActorTechnique(Base): class ThreatActorTechnique(Base):
""" """Association between a threat actor and a MITRE ATT&CK technique.
Association between a threat actor and a MITRE ATT&CK technique.
Stores additional context about how the actor uses the technique Stores additional context about how the actor uses the technique
(from the STIX ``relationship`` ``uses`` objects). (from the STIX ``relationship`` ``uses`` objects).
""" """
# Assign __tablename__ = "threat_actor_techniques"
__tablename__ = "threat_actor_techniques" __tablename__ = "threat_actor_techniques"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign threat_actor_id = Column(
threat_actor_id = Column( threat_actor_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("threat_actors.id", ondelete="CASCADE"), ForeignKey("threat_actors.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign technique_id = Column(
technique_id = Column( technique_id = Column(
UUID(as_uuid=True), UUID(as_uuid=True),
ForeignKey("techniques.id", ondelete="CASCADE"), ForeignKey("techniques.id", ondelete="CASCADE"),
# Keyword argument: nullable
nullable=False, nullable=False,
) )
# Assign usage_description = Column(Text, nullable=True)
usage_description = Column(Text, nullable=True) usage_description = Column(Text, nullable=True)
# Assign first_seen_using = Column(String, nullable=True)
first_seen_using = Column(String, nullable=True) first_seen_using = Column(String, nullable=True)
# Relationships # Relationships
threat_actor = relationship("ThreatActor", back_populates="techniques") threat_actor = relationship("ThreatActor", back_populates="techniques")
# Assign technique = relationship("Technique")
technique = relationship("Technique") technique = relationship("Technique")
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index('ix_threat_actor_techniques_actor', 'threat_actor_id'), Index('ix_threat_actor_techniques_actor', 'threat_actor_id'),
Index('ix_threat_actor_techniques_technique', 'technique_id'), Index('ix_threat_actor_techniques_technique', 'technique_id'),
UniqueConstraint( UniqueConstraint(
# Literal argument value
'threat_actor_id', 'technique_id', 'threat_actor_id', 'technique_id',
# Keyword argument: name
name='uq_actor_technique', name='uq_actor_technique',
), ),
) )
+22 -4
View File
@@ -1,14 +1,22 @@
"""SQLAlchemy model for the users table."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Boolean, DateTime, func
# Import Boolean, Column, DateTime, String, func from sqlalchemy
from sqlalchemy import Boolean, Column, DateTime, String, func
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class User
class User(Base): class User(Base):
""" """User model for authentication and authorization.
User model for authentication and authorization.
Possible roles: Possible roles:
- admin: Full system access - admin: Full system access
- red_tech: Red team technician - can create and edit tests - red_tech: Red team technician - can create and edit tests
@@ -17,14 +25,24 @@ class User(Base):
- blue_lead: Blue team lead - can validate tests - blue_lead: Blue team lead - can validate tests
- viewer: Read-only access (default) - viewer: Read-only access (default)
""" """
# Assign __tablename__ = "users"
__tablename__ = "users" __tablename__ = "users"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign username = Column(String, unique=True, nullable=False)
username = Column(String, unique=True, nullable=False) username = Column(String, unique=True, nullable=False)
# Assign email = Column(String, nullable=True)
email = Column(String, nullable=True) email = Column(String, nullable=True)
# Assign hashed_password = Column(String, nullable=False)
hashed_password = Column(String, nullable=False) hashed_password = Column(String, nullable=False)
# Assign role = Column(String, nullable=False, default="viewer")
role = Column(String, nullable=False, default="viewer") role = Column(String, nullable=False, default="viewer")
# Assign is_active = Column(Boolean, default=True)
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
# Assign must_change_password = Column(Boolean, default=True)
must_change_password = Column(Boolean, default=True) must_change_password = Column(Boolean, default=True)
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign last_login = Column(DateTime, nullable=True)
last_login = Column(DateTime, nullable=True) last_login = Column(DateTime, nullable=True)
+28 -2
View File
@@ -1,13 +1,22 @@
"""Worklog model — immutable internal time-tracking records.""" """Worklog model — immutable internal time-tracking records."""
# Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Integer, DateTime, ForeignKey, Text, Index, func
from sqlalchemy.dialects.postgresql import UUID, JSONB # Import Column, DateTime, ForeignKey, Index, Integer, S... from sqlalchemy
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text, func
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Import relationship from sqlalchemy.orm
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# Import Base from app.database
from app.database import Base from app.database import Base
# Define class Worklog
class Worklog(Base): class Worklog(Base):
"""Internal worklog entry with integrity hash for audit compliance. """Internal worklog entry with integrity hash for audit compliance.
@@ -16,25 +25,42 @@ class Worklog(Base):
the immutable fields so tampering can be detected. the immutable fields so tampering can be detected.
""" """
# Assign __tablename__ = "worklogs"
__tablename__ = "worklogs" __tablename__ = "worklogs"
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Assign entity_type = Column(String(50), nullable=False)
entity_type = Column(String(50), nullable=False) entity_type = Column(String(50), nullable=False)
# Assign entity_id = Column(UUID(as_uuid=True), nullable=False)
entity_id = Column(UUID(as_uuid=True), nullable=False) entity_id = Column(UUID(as_uuid=True), nullable=False)
# Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
# Assign activity_type = Column(String(100), nullable=False)
activity_type = Column(String(100), nullable=False) activity_type = Column(String(100), nullable=False)
# Assign started_at = Column(DateTime, nullable=False)
started_at = Column(DateTime, nullable=False) started_at = Column(DateTime, nullable=False)
# Assign ended_at = Column(DateTime)
ended_at = Column(DateTime) ended_at = Column(DateTime)
# Assign duration_seconds = Column(Integer, nullable=False)
duration_seconds = Column(Integer, nullable=False) duration_seconds = Column(Integer, nullable=False)
# Assign description = Column(Text)
description = Column(Text) description = Column(Text)
# Assign tempo_synced = Column(DateTime)
tempo_synced = Column(DateTime) tempo_synced = Column(DateTime)
# Assign tempo_worklog_id = Column(String(100))
tempo_worklog_id = Column(String(100)) tempo_worklog_id = Column(String(100))
# Assign integrity_hash = Column(String(64))
integrity_hash = Column(String(64)) integrity_hash = Column(String(64))
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign extra_metadata = Column("metadata", JSONB, default={})
extra_metadata = Column("metadata", JSONB, default={}) extra_metadata = Column("metadata", JSONB, default={})
# Assign user = relationship("User", foreign_keys=[user_id])
user = relationship("User", foreign_keys=[user_id]) user = relationship("User", foreign_keys=[user_id])
# Assign __table_args__ = (
__table_args__ = ( __table_args__ = (
Index("ix_worklogs_entity_id", "entity_id"), Index("ix_worklogs_entity_id", "entity_id"),
Index("ix_worklogs_user_id", "user_id"), Index("ix_worklogs_user_id", "user_id"),
+1
View File
@@ -0,0 +1 @@
"""FastAPI router modules — one router per feature domain."""
+35 -4
View File
@@ -1,50 +1,81 @@
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time.""" """Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
# Import APIRouter, Depends from fastapi
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user from app.dependencies.auth
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import advanced_metrics_service from app.services
from app.services import advanced_metrics_service from app.services import advanced_metrics_service
# Assign router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"]) router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
# Apply the @router.get decorator
@router.get("/coverage-by-tactic") @router.get("/coverage-by-tactic")
# Define function coverage_by_tactic
def coverage_by_tactic( def coverage_by_tactic(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Coverage percentage broken down by MITRE ATT&CK tactic.""" """Coverage percentage broken down by MITRE ATT&CK tactic."""
# Return advanced_metrics_service.get_coverage_by_tactic(db)
return advanced_metrics_service.get_coverage_by_tactic(db) return advanced_metrics_service.get_coverage_by_tactic(db)
# Apply the @router.get decorator
@router.get("/never-tested") @router.get("/never-tested")
# Define function never_tested_techniques
def never_tested_techniques( def never_tested_techniques(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Techniques that have never had a test created.""" """Techniques that have never had a test created."""
# Return advanced_metrics_service.get_never_tested_techniques(db)
return advanced_metrics_service.get_never_tested_techniques(db) return advanced_metrics_service.get_never_tested_techniques(db)
# Apply the @router.get decorator
@router.get("/avg-validation-time") @router.get("/avg-validation-time")
# Define function avg_validation_time
def avg_validation_time( def avg_validation_time(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> dict:
"""Average time from test creation to validation, computed from audit logs. """Average time from test creation to validation, computed from audit logs.
Returns overall average and per-phase averages where data is available. Returns overall average and per-phase averages where data is available.
""" """
# Return advanced_metrics_service.get_avg_validation_time(db)
return advanced_metrics_service.get_avg_validation_time(db) return advanced_metrics_service.get_avg_validation_time(db)
# Apply the @router.get decorator
@router.get("/detection-rate-trend") @router.get("/detection-rate-trend")
# Define function detection_rate_trend
def detection_rate_trend( def detection_rate_trend(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Monthly detection rate trend for the last 12 months.""" """Monthly detection rate trend for the last 12 months."""
# Return advanced_metrics_service.get_detection_rate_trend(db)
return advanced_metrics_service.get_detection_rate_trend(db) return advanced_metrics_service.get_detection_rate_trend(db)
+37 -4
View File
@@ -4,52 +4,85 @@ Returns complete datasets without pagination so BI tools can ingest
directly from URL. All endpoints require authentication. directly from URL. All endpoints require authentication.
""" """
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_role from app.dependencies.auth import get_current_user, require_role
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import analytics_service from app.services
from app.services import analytics_service from app.services import analytics_service
# Assign router = APIRouter(prefix="/analytics", tags=["analytics"])
router = APIRouter(prefix="/analytics", tags=["analytics"]) router = APIRouter(prefix="/analytics", tags=["analytics"])
# Apply the @router.get decorator
@router.get("/coverage") @router.get("/coverage")
# Define function analytics_coverage
def analytics_coverage( def analytics_coverage(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Coverage per technique — flat format for BI dashboards.""" """Coverage per technique — flat format for BI dashboards."""
# Return analytics_service.get_coverage_analytics(db)
return analytics_service.get_coverage_analytics(db) return analytics_service.get_coverage_analytics(db)
# Apply the @router.get decorator
@router.get("/tests") @router.get("/tests")
# Define function analytics_tests
def analytics_tests( def analytics_tests(
# Entry: date_from
date_from: str = Query(None, description="ISO date filter (>=)"), date_from: str = Query(None, description="ISO date filter (>=)"),
# Entry: date_to
date_to: str = Query(None, description="ISO date filter (<=)"), date_to: str = Query(None, description="ISO date filter (<=)"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""All tests with timestamps — flat format for BI dashboards.""" """All tests with timestamps — flat format for BI dashboards."""
# Return analytics_service.get_tests_analytics(
return analytics_service.get_tests_analytics( return analytics_service.get_tests_analytics(
db, date_from=date_from, date_to=date_to db, date_from=date_from, date_to=date_to
) )
# Apply the @router.get decorator
@router.get("/trends") @router.get("/trends")
# Define function analytics_trends
def analytics_trends( def analytics_trends(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Historical coverage snapshots for trend visualization.""" """Historical coverage snapshots for trend visualization."""
# Return analytics_service.get_trends_analytics(db)
return analytics_service.get_trends_analytics(db) return analytics_service.get_trends_analytics(db)
# Apply the @router.get decorator
@router.get("/operators") @router.get("/operators")
# Define function analytics_operators
def analytics_operators( def analytics_operators(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(require_role("admin")), user: User = Depends(require_role("admin")),
): ) -> list:
"""Per-operator metrics — for workload management dashboards.""" """Per-operator metrics — for workload management dashboards."""
# Return analytics_service.get_operators_analytics(db)
return analytics_service.get_operators_analytics(db) return analytics_service.get_operators_analytics(db)
+53 -3
View File
@@ -1,77 +1,127 @@
"""Audit log viewer router (admin only).""" """Audit log viewer router (admin only)."""
# Import datetime from datetime
from datetime import datetime from datetime import datetime
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import require_role from app.dependencies.auth
from app.dependencies.auth import require_role from app.dependencies.auth import require_role
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import AuditLogOut, AuditLogPage from app.schemas.audit
from app.schemas.audit import AuditLogOut, AuditLogPage from app.schemas.audit import AuditLogOut, AuditLogPage
# Import from app.services.audit_query_service
from app.services.audit_query_service import ( from app.services.audit_query_service import (
list_distinct_actions, list_distinct_actions,
list_distinct_entity_types, list_distinct_entity_types,
list_logs, list_logs,
) )
# Assign router = APIRouter(prefix="/audit-logs", tags=["audit"])
router = APIRouter(prefix="/audit-logs", tags=["audit"]) router = APIRouter(prefix="/audit-logs", tags=["audit"])
# Apply the @router.get decorator
@router.get("", response_model=AuditLogPage) @router.get("", response_model=AuditLogPage)
# Define function list_audit_logs
def list_audit_logs( def list_audit_logs(
# Entry: user_id
user_id: Optional[str] = Query(None, description="Filter by user ID"), user_id: Optional[str] = Query(None, description="Filter by user ID"),
# Entry: action
action: Optional[str] = Query(None, description="Filter by action type"), action: Optional[str] = Query(None, description="Filter by action type"),
# Entry: entity_type
entity_type: Optional[str] = Query(None, description="Filter by entity type"), entity_type: Optional[str] = Query(None, description="Filter by entity type"),
# Entry: start_date
start_date: Optional[datetime] = Query(None, description="Filter by start date"), start_date: Optional[datetime] = Query(None, description="Filter by start date"),
# Entry: end_date
end_date: Optional[datetime] = Query(None, description="Filter by end date"), end_date: Optional[datetime] = Query(None, description="Filter by end date"),
# Entry: offset
offset: int = Query(0, ge=0, description="Number of records to skip"), offset: int = Query(0, ge=0, description="Number of records to skip"),
# Entry: limit
limit: int = Query(50, ge=1, le=100, description="Max records to return"), limit: int = Query(50, ge=1, le=100, description="Max records to return"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> AuditLogPage:
"""Return paginated audit logs with optional filters. """Return paginated audit logs with optional filters.
**Requires admin role.** **Requires admin role.**
""" """
# Assign result = list_logs(
result = list_logs( result = list_logs(
db, db,
# Keyword argument: user_id
user_id=user_id, user_id=user_id,
# Keyword argument: action
action=action, action=action,
# Keyword argument: entity_type
entity_type=entity_type, entity_type=entity_type,
# Keyword argument: start_date
start_date=start_date, start_date=start_date,
# Keyword argument: end_date
end_date=end_date, end_date=end_date,
# Keyword argument: offset
offset=offset, offset=offset,
# Keyword argument: limit
limit=limit, limit=limit,
) )
# Return AuditLogPage(
return AuditLogPage( return AuditLogPage(
# Keyword argument: items
items=[AuditLogOut(**item) for item in result["items"]], items=[AuditLogOut(**item) for item in result["items"]],
# Keyword argument: total
total=result["total"], total=result["total"],
# Keyword argument: offset
offset=result["offset"], offset=result["offset"],
# Keyword argument: limit
limit=result["limit"], limit=result["limit"],
) )
# Apply the @router.get decorator
@router.get("/actions", response_model=list[str]) @router.get("/actions", response_model=list[str])
# Define function list_actions
def list_actions( def list_actions(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> list[str]:
"""Return a list of distinct action types in the audit log. """Return a list of distinct action types in the audit log.
**Requires admin role.** **Requires admin role.**
""" """
# Return list_distinct_actions(db)
return list_distinct_actions(db) return list_distinct_actions(db)
# Apply the @router.get decorator
@router.get("/entity-types", response_model=list[str]) @router.get("/entity-types", response_model=list[str])
# Define function list_entity_types
def list_entity_types( def list_entity_types(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> list[str]:
"""Return a list of distinct entity types in the audit log. """Return a list of distinct entity types in the audit log.
**Requires admin role.** **Requires admin role.**
""" """
# Return list_distinct_entity_types(db)
return list_distinct_entity_types(db) return list_distinct_entity_types(db)
+135 -11
View File
@@ -7,165 +7,289 @@ the token in the body for backwards compatibility and for clients that
cannot use cookies (e.g. Swagger UI). cannot use cookies (e.g. Swagger UI).
""" """
# Import os
import os import os
# Import APIRouter, Cookie, Depends, Request, Response from fastapi
from fastapi import APIRouter, Cookie, Depends, Request, Response from fastapi import APIRouter, Cookie, Depends, Request, Response
# Import OAuth2PasswordRequestForm from fastapi.security
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
# Import JWTError, jwt from jose
from jose import JWTError, jwt
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from jose import jwt, JWTError # Import blacklist_token, create_access_token, verify_pa... from app.auth
from app.auth import blacklist_token, create_access_token, verify_password
from app.auth import create_access_token, blacklist_token, verify_password # Import settings from app.config
from app.config import settings from app.config import settings
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user from app.dependencies.auth
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
# Import BusinessRuleViolation, PermissionViolation from app.domain.errors
from app.domain.errors import BusinessRuleViolation, PermissionViolation from app.domain.errors import BusinessRuleViolation, PermissionViolation
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import limiter from app.limiter
from app.limiter import limiter from app.limiter import limiter
# Import resolve_client_ip from app.middleware.request_context
from app.middleware.request_context import resolve_client_ip from app.middleware.request_context import resolve_client_ip
# Import User from app.models.user
from app.models.user import User from app.models.user import User
from app.services.auth_service import (
_DUMMY_HASH, # Import TokenResponse, UserOut from app.schemas.auth
change_password as auth_change_password,
)
from app.services.audit_service import log_action
from app.schemas.auth import TokenResponse, UserOut from app.schemas.auth import TokenResponse, UserOut
# Import PasswordChange from app.schemas.user
from app.schemas.user import PasswordChange from app.schemas.user import PasswordChange
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action
# Import from app.services.auth_service
from app.services.auth_service import (
_DUMMY_HASH,
)
# Import from app.services.auth_service
from app.services.auth_service import (
change_password as auth_change_password,
)
# Assign router = APIRouter(prefix="/auth", tags=["auth"])
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
# Assign _IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production" _IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
# Assign _COOKIE_NAME = "aegis_token"
_COOKIE_NAME = "aegis_token" _COOKIE_NAME = "aegis_token"
# Apply the @router.post decorator
@router.post("/login", response_model=TokenResponse) @router.post("/login", response_model=TokenResponse)
# Apply the @limiter.limit decorator
@limiter.limit("5/minute") @limiter.limit("5/minute")
# Define function login
def login( def login(
# Entry: request
request: Request, request: Request,
# Entry: response
response: Response, response: Response,
# Entry: form_data
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> TokenResponse:
"""Authenticate a user and return a JWT access token. """Authenticate a user and return a JWT access token.
Rate-limited to **5 attempts per minute per IP**. Failed and successful Rate-limited to **5 attempts per minute per IP**. Failed and successful
logins are recorded in the audit log (SEC-009). logins are recorded in the audit log (SEC-009).
""" """
# Assign user = db.query(User).filter(User.username == form_data.username).first()
user = db.query(User).filter(User.username == form_data.username).first() user = db.query(User).filter(User.username == form_data.username).first()
# Assign target_hash = user.hashed_password if user else _DUMMY_HASH
target_hash = user.hashed_password if user else _DUMMY_HASH target_hash = user.hashed_password if user else _DUMMY_HASH
# Assign password_valid = verify_password(form_data.password, target_hash)
password_valid = verify_password(form_data.password, target_hash) password_valid = verify_password(form_data.password, target_hash)
# Assign ip = resolve_client_ip(request)
ip = resolve_client_ip(request) ip = resolve_client_ip(request)
# Check: user is None or not password_valid
if user is None or not password_valid: if user is None or not password_valid:
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
user.id if user else None, user.id if user else None,
# Literal argument value
"LOGIN_FAILED", "LOGIN_FAILED",
# Literal argument value
"auth", "auth",
# Literal argument value
None, None,
# Keyword argument: details
details={ details={
# Literal argument value
"username": form_data.username, "username": form_data.username,
# Literal argument value
"ip": ip, "ip": ip,
# Literal argument value
"reason": "invalid_credentials", "reason": "invalid_credentials",
}, },
# Keyword argument: ip_address
ip_address=ip, ip_address=ip,
) )
# Call uow.commit()
uow.commit() uow.commit()
# Raise BusinessRuleViolation
raise BusinessRuleViolation("Incorrect username or password") raise BusinessRuleViolation("Incorrect username or password")
# Check: not user.is_active
if not user.is_active: if not user.is_active:
# Raise PermissionViolation
raise PermissionViolation("Account is disabled. Contact an administrator.") raise PermissionViolation("Account is disabled. Contact an administrator.")
# Assign access_token = create_access_token(data={"sub": user.username})
access_token = create_access_token(data={"sub": user.username}) access_token = create_access_token(data={"sub": user.username})
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
user.id, user.id,
# Literal argument value
"LOGIN_SUCCESS", "LOGIN_SUCCESS",
# Literal argument value
"auth", "auth",
str(user.id), str(user.id),
# Keyword argument: details
details={"username": user.username, "ip": ip}, details={"username": user.username, "ip": ip},
# Keyword argument: ip_address
ip_address=ip, ip_address=ip,
) )
# Call uow.commit()
uow.commit() uow.commit()
# Call response.set_cookie()
response.set_cookie( response.set_cookie(
# Keyword argument: key
key=_COOKIE_NAME, key=_COOKIE_NAME,
# Keyword argument: value
value=access_token, value=access_token,
# Keyword argument: httponly
httponly=True, httponly=True,
# Keyword argument: secure
secure=_IS_HTTPS, secure=_IS_HTTPS,
# Keyword argument: samesite
samesite="strict", samesite="strict",
# Keyword argument: max_age
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
# Keyword argument: path
path="/", path="/",
) )
# Return TokenResponse(access_token=access_token)
return TokenResponse(access_token=access_token) return TokenResponse(access_token=access_token)
# Apply the @router.post decorator
@router.post("/logout") @router.post("/logout")
# Define function logout
def logout( def logout(
# Entry: request
request: Request, request: Request,
# Entry: response
response: Response, response: Response,
# Entry: aegis_token
aegis_token: str | None = Cookie(None), aegis_token: str | None = Cookie(None),
): ) -> dict:
"""Clear the authentication cookie and revoke the current token.""" """Clear the authentication cookie and revoke the current token."""
# Assign bearer = (
bearer = ( bearer = (
request.headers.get("Authorization") request.headers.get("Authorization")
or request.headers.get("authorization") or request.headers.get("authorization")
or "" or ""
) )
# Assign bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip()
bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip() bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip()
# Assign seen = set()
seen: set[str] = set() seen: set[str] = set()
# Iterate over (aegis_token, bearer)
for raw in (aegis_token, bearer): for raw in (aegis_token, bearer):
# Check: not raw or raw in seen
if not raw or raw in seen: if not raw or raw in seen:
# Skip to the next loop iteration
continue continue
# Call seen.add()
seen.add(raw) seen.add(raw)
# Attempt the following; catch errors below
try: try:
# Assign payload = jwt.decode(
payload = jwt.decode( payload = jwt.decode(
raw, raw,
settings.SECRET_KEY, settings.SECRET_KEY,
# Keyword argument: algorithms
algorithms=[settings.ALGORITHM], algorithms=[settings.ALGORITHM],
) )
# Assign jti = payload.get("jti")
jti = payload.get("jti") jti = payload.get("jti")
# Assign exp = payload.get("exp", 0)
exp = payload.get("exp", 0) exp = payload.get("exp", 0)
# Check: jti
if jti: if jti:
# Call blacklist_token()
blacklist_token(jti, float(exp)) blacklist_token(jti, float(exp))
# Handle JWTError
except JWTError: except JWTError:
# Intentional no-op placeholder
pass pass
# Call response.delete_cookie()
response.delete_cookie( response.delete_cookie(
# Keyword argument: key
key=_COOKIE_NAME, key=_COOKIE_NAME,
# Keyword argument: httponly
httponly=True, httponly=True,
# Keyword argument: secure
secure=_IS_HTTPS, secure=_IS_HTTPS,
# Keyword argument: samesite
samesite="strict", samesite="strict",
# Keyword argument: path
path="/", path="/",
) )
# Return {"detail": "Logged out"}
return {"detail": "Logged out"} return {"detail": "Logged out"}
# Apply the @router.get decorator
@router.get("/me", response_model=UserOut) @router.get("/me", response_model=UserOut)
def read_current_user(current_user: User = Depends(get_current_user)): # Define function read_current_user
def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut:
"""Return the profile of the currently authenticated user.""" """Return the profile of the currently authenticated user."""
# Return current_user
return current_user return current_user
# Apply the @router.post decorator
@router.post("/change-password") @router.post("/change-password")
# Define function change_password
def change_password( def change_password(
# Entry: body
body: PasswordChange, body: PasswordChange,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Change the current user's password.""" """Change the current user's password."""
# Call auth_change_password()
auth_change_password( auth_change_password(
db, db,
current_user, current_user,
# Keyword argument: current_password
current_password=body.current_password, current_password=body.current_password,
# Keyword argument: new_password
new_password=body.new_password, new_password=body.new_password,
) )
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call uow.commit()
uow.commit() uow.commit()
# Return {"detail": "Password changed successfully"}
return {"detail": "Password changed successfully"} return {"detail": "Password changed successfully"}
+416 -29
View File
@@ -1,73 +1,177 @@
"""Campaign endpoints — CRUD, test management, activation, and auto-generation. """Campaign endpoints — CRUD, test management, activation, and auto-generation.
Provides comprehensive campaign lifecycle management including Provides comprehensive campaign lifecycle management including test ordering,
test ordering, progress tracking, and threat actor integration. progress tracking, and threat actor integration.
""" """
# Import logging
import logging import logging
# Import uuid
import uuid import uuid
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
# Import BaseModel, Field from pydantic
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_any_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user
from app.models.user import User from app.models.user import User
from app.services.campaign_service import generate_campaign_from_threat_actor
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
activate_campaign as crud_activate,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import ( from app.services.campaign_crud_service import (
add_test_to_campaign as crud_add_test, add_test_to_campaign as crud_add_test,
activate_campaign as crud_activate, )
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
complete_campaign as crud_complete, complete_campaign as crud_complete,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
create_campaign as crud_create, create_campaign as crud_create,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
get_campaign_detail as crud_get_detail, get_campaign_detail as crud_get_detail,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
get_campaign_history as crud_get_history, get_campaign_history as crud_get_history,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
get_campaign_progress_data as crud_get_progress, get_campaign_progress_data as crud_get_progress,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
list_campaigns as crud_list, list_campaigns as crud_list,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
remove_test_from_campaign as crud_remove_test, remove_test_from_campaign as crud_remove_test,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
schedule_campaign as crud_schedule, schedule_campaign as crud_schedule,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
serialize_campaign, serialize_campaign,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
update_campaign as crud_update, update_campaign as crud_update,
) )
from app.domain.unit_of_work import UnitOfWork
from app.services.audit_service import log_action # Import generate_campaign_from_threat_actor from app.services.campaign_service
from app.services.campaign_service import generate_campaign_from_threat_actor
# Import notify_role from app.services.notification_service
from app.services.notification_service import notify_role from app.services.notification_service import notify_role
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Assign router = APIRouter(prefix="/campaigns", tags=["campaigns"])
router = APIRouter(prefix="/campaigns", tags=["campaigns"]) router = APIRouter(prefix="/campaigns", tags=["campaigns"])
# ── Pydantic schemas ───────────────────────────────────────────────── # ── Pydantic schemas ─────────────────────────────────────────────────
class CampaignCreate(BaseModel): class CampaignCreate(BaseModel):
"""Payload for creating a new campaign."""
# name: str
name: str name: str
# Assign description = None
description: Optional[str] = None description: Optional[str] = None
# Assign type = "custom"
type: str = "custom" type: str = "custom"
# Assign threat_actor_id = None
threat_actor_id: Optional[str] = None threat_actor_id: Optional[str] = None
# Assign target_platform = None
target_platform: Optional[str] = None target_platform: Optional[str] = None
# Assign tags = Field(default_factory=list)
tags: Optional[list[str]] = Field(default_factory=list) tags: Optional[list[str]] = Field(default_factory=list)
# Assign scheduled_at = None
scheduled_at: Optional[str] = None scheduled_at: Optional[str] = None
# Define class CampaignUpdate
class CampaignUpdate(BaseModel): class CampaignUpdate(BaseModel):
"""Payload for updating an existing campaign's metadata."""
# Assign name = None
name: Optional[str] = None name: Optional[str] = None
# Assign description = None
description: Optional[str] = None description: Optional[str] = None
# Assign type = None
type: Optional[str] = None type: Optional[str] = None
# Assign target_platform = None
target_platform: Optional[str] = None target_platform: Optional[str] = None
# Assign tags = None
tags: Optional[list[str]] = None tags: Optional[list[str]] = None
# Assign scheduled_at = None
scheduled_at: Optional[str] = None scheduled_at: Optional[str] = None
# Define class AddTestPayload
class AddTestPayload(BaseModel): class AddTestPayload(BaseModel):
"""Payload for adding a test to a campaign."""
# test_id: str
test_id: str test_id: str
# Assign order_index = None
order_index: Optional[int] = None order_index: Optional[int] = None
# Assign depends_on = None
depends_on: Optional[str] = None depends_on: Optional[str] = None
# Assign phase = None
phase: Optional[str] = None phase: Optional[str] = None
# Define class SchedulePayload
class SchedulePayload(BaseModel): class SchedulePayload(BaseModel):
"""Payload for scheduling or rescheduling a campaign run."""
# is_recurring: bool
is_recurring: bool is_recurring: bool
# Assign recurrence_pattern = None # weekly, monthly, quarterly
recurrence_pattern: Optional[str] = None # weekly, monthly, quarterly recurrence_pattern: Optional[str] = None # weekly, monthly, quarterly
# Assign next_run_at = None
next_run_at: Optional[str] = None next_run_at: Optional[str] = None
@@ -76,24 +180,54 @@ class SchedulePayload(BaseModel):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("") @router.get("")
# Define function list_campaigns
def list_campaigns( def list_campaigns(
# Entry: type
type: Optional[str] = Query(None), type: Optional[str] = Query(None),
# Entry: status
status: Optional[str] = Query(None), status: Optional[str] = Query(None),
# Entry: threat_actor_id
threat_actor_id: Optional[str] = Query(None), threat_actor_id: Optional[str] = Query(None),
# Entry: search
search: Optional[str] = Query(None), search: Optional[str] = Query(None),
# Entry: offset
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
# Entry: limit
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List campaigns with optional filters and pagination.""" """List campaigns with optional filters and pagination.
Args:
type (Optional[str]): Filter by campaign type (e.g. ``custom``, ``threat_actor``).
status (Optional[str]): Filter by campaign status (e.g. ``draft``, ``active``).
threat_actor_id (Optional[str]): Filter campaigns linked to a specific threat actor.
search (Optional[str]): Free-text search against campaign name.
offset (int): Number of records to skip for pagination.
limit (int): Maximum number of records to return.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
list: Serialised list of campaign summary dicts.
"""
# Return crud_list(
return crud_list( return crud_list(
db, db,
# Keyword argument: type
type=type, type=type,
# Keyword argument: status
status=status, status=status,
# Keyword argument: threat_actor_id
threat_actor_id=threat_actor_id, threat_actor_id=threat_actor_id,
# Keyword argument: search
search=search, search=search,
# Keyword argument: offset
offset=offset, offset=offset,
# Keyword argument: limit
limit=limit, limit=limit,
) )
@@ -103,34 +237,65 @@ def list_campaigns(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("", status_code=201) @router.post("", status_code=201)
# Define function create_campaign
def create_campaign( def create_campaign(
# Entry: payload
payload: CampaignCreate, payload: CampaignCreate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Create a new campaign.""" """Create a new campaign.
Args:
payload (CampaignCreate): Fields for the new campaign (name, type, threat actor, etc.).
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead creating the campaign.
Returns:
dict: Serialised representation of the newly created campaign.
"""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign result = crud_create(
result = crud_create( result = crud_create(
db, db,
# Keyword argument: creator_id
creator_id=current_user.id, creator_id=current_user.id,
# Keyword argument: name
name=payload.name, name=payload.name,
# Keyword argument: description
description=payload.description, description=payload.description,
# Keyword argument: type
type=payload.type, type=payload.type,
# Keyword argument: threat_actor_id
threat_actor_id=payload.threat_actor_id, threat_actor_id=payload.threat_actor_id,
# Keyword argument: target_platform
target_platform=payload.target_platform, target_platform=payload.target_platform,
# Keyword argument: tags
tags=payload.tags, tags=payload.tags,
# Keyword argument: scheduled_at
scheduled_at=payload.scheduled_at, scheduled_at=payload.scheduled_at,
) )
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="create_campaign", action="create_campaign",
# Keyword argument: entity_type
entity_type="campaign", entity_type="campaign",
# Keyword argument: entity_id
entity_id=result["id"], entity_id=result["id"],
# Keyword argument: details
details={"name": payload.name, "type": payload.type}, details={"name": payload.name, "type": payload.type},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return result
return result return result
@@ -139,12 +304,26 @@ def create_campaign(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/{campaign_id}") @router.get("/{campaign_id}")
# Define function get_campaign
def get_campaign( def get_campaign(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get detailed campaign info including tests and progress.""" """Get detailed campaign info including tests and progress.
Args:
campaign_id (str): UUID string of the campaign to retrieve.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Campaign detail including associated tests and progress metrics.
"""
# Return crud_get_detail(db, campaign_id)
return crud_get_detail(db, campaign_id) return crud_get_detail(db, campaign_id)
@@ -153,32 +332,60 @@ def get_campaign(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.patch("/{campaign_id}") @router.patch("/{campaign_id}")
# Define function update_campaign
def update_campaign( def update_campaign(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: payload
payload: CampaignUpdate, payload: CampaignUpdate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Update a campaign. Only allowed in draft or active state.""" """Update a campaign. Only allowed in draft or active state.
Args:
campaign_id (str): UUID string of the campaign to update.
payload (CampaignUpdate): Partial update payload; only set fields are applied.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead performing the update.
Returns:
dict: Serialised representation of the updated campaign.
"""
# Assign update_data = payload.model_dump(exclude_unset=True)
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign result = crud_update(
result = crud_update( result = crud_update(
db, db,
campaign_id, campaign_id,
# Keyword argument: updater_id
updater_id=current_user.id, updater_id=current_user.id,
# Keyword argument: updater_role
updater_role=current_user.role, updater_role=current_user.role,
**update_data, **update_data,
) )
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="update_campaign", action="update_campaign",
# Keyword argument: entity_type
entity_type="campaign", entity_type="campaign",
# Keyword argument: entity_id
entity_id=campaign_id, entity_id=campaign_id,
# Keyword argument: details
details={"updated_fields": list(update_data.keys())}, details={"updated_fields": list(update_data.keys())},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return result
return result return result
@@ -187,23 +394,46 @@ def update_campaign(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/{campaign_id}/tests") @router.post("/{campaign_id}/tests")
# Define function add_test_to_campaign
def add_test_to_campaign( def add_test_to_campaign(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: payload
payload: AddTestPayload, payload: AddTestPayload,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Add a test to a campaign with optional ordering and dependency.""" """Add a test to a campaign with optional ordering and dependency.
Args:
campaign_id (str): UUID string of the target campaign.
payload (AddTestPayload): Test ID plus optional order index, dependency, and phase.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead adding the test.
Returns:
dict: The created campaign-test association record.
"""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign result = crud_add_test(
result = crud_add_test( result = crud_add_test(
db, db,
campaign_id, campaign_id,
# Keyword argument: test_id
test_id=payload.test_id, test_id=payload.test_id,
# Keyword argument: order_index
order_index=payload.order_index, order_index=payload.order_index,
# Keyword argument: depends_on
depends_on=payload.depends_on, depends_on=payload.depends_on,
# Keyword argument: phase
phase=payload.phase, phase=payload.phase,
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return result
return result return result
@@ -212,16 +442,35 @@ def add_test_to_campaign(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.delete("/{campaign_id}/tests/{campaign_test_id}") @router.delete("/{campaign_id}/tests/{campaign_test_id}")
# Define function remove_test_from_campaign
def remove_test_from_campaign( def remove_test_from_campaign(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: campaign_test_id
campaign_test_id: str, campaign_test_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Remove a test from a campaign.""" """Remove a test from a campaign.
Args:
campaign_id (str): UUID string of the campaign.
campaign_test_id (str): UUID string of the campaign-test association to remove.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead removing the test.
Returns:
dict: Confirmation message with key ``detail``.
"""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call crud_remove_test()
crud_remove_test(db, campaign_id, campaign_test_id) crud_remove_test(db, campaign_id, campaign_test_id)
# Call uow.commit()
uow.commit() uow.commit()
# Return {"detail": "Test removed from campaign"}
return {"detail": "Test removed from campaign"} return {"detail": "Test removed from campaign"}
@@ -230,34 +479,65 @@ def remove_test_from_campaign(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/{campaign_id}/activate") @router.post("/{campaign_id}/activate")
# Define function activate_campaign
def activate_campaign( def activate_campaign(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Activate a campaign, moving it from draft to active.""" """Activate a campaign, moving it from draft to active.
Args:
campaign_id (str): UUID string of the campaign to activate.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead activating the campaign.
Returns:
dict: Serialised representation of the activated campaign.
"""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign campaign = crud_activate(db, campaign_id)
campaign = crud_activate(db, campaign_id) campaign = crud_activate(db, campaign_id)
# Call notify_role()
notify_role( notify_role(
db, db,
# Keyword argument: role
role="red_tech", role="red_tech",
# Keyword argument: type
type="campaign_activated", type="campaign_activated",
# Keyword argument: title
title="Campaign activated", title="Campaign activated",
# Keyword argument: message
message=f'Campaign "{campaign.name}" has been activated.', message=f'Campaign "{campaign.name}" has been activated.',
# Keyword argument: entity_type
entity_type="campaign", entity_type="campaign",
# Keyword argument: entity_id
entity_id=campaign.id, entity_id=campaign.id,
) )
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="activate_campaign", action="activate_campaign",
# Keyword argument: entity_type
entity_type="campaign", entity_type="campaign",
# Keyword argument: entity_id
entity_id=campaign.id, entity_id=campaign.id,
# Keyword argument: details
details={"name": campaign.name}, details={"name": campaign.name},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(campaign) db.refresh(campaign)
# Return serialize_campaign(db, campaign)
return serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
@@ -266,25 +546,49 @@ def activate_campaign(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/{campaign_id}/complete") @router.post("/{campaign_id}/complete")
# Define function complete_campaign
def complete_campaign( def complete_campaign(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "admin")), current_user: User = Depends(require_any_role("red_lead", "admin")),
): ) -> dict:
"""Mark a campaign as completed.""" """Mark a campaign as completed.
Args:
campaign_id (str): UUID string of the campaign to complete.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or admin completing the campaign.
Returns:
dict: Serialised representation of the completed campaign.
"""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign campaign = crud_complete(db, campaign_id)
campaign = crud_complete(db, campaign_id) campaign = crud_complete(db, campaign_id)
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="complete_campaign", action="complete_campaign",
# Keyword argument: entity_type
entity_type="campaign", entity_type="campaign",
# Keyword argument: entity_id
entity_id=campaign.id, entity_id=campaign.id,
# Keyword argument: details
details={"name": campaign.name}, details={"name": campaign.name},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(campaign) db.refresh(campaign)
# Return serialize_campaign(db, campaign)
return serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
@@ -293,12 +597,26 @@ def complete_campaign(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/{campaign_id}/progress") @router.get("/{campaign_id}/progress")
# Define function get_campaign_progress_endpoint
def get_campaign_progress_endpoint( def get_campaign_progress_endpoint(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get progress statistics for a campaign.""" """Get progress statistics for a campaign.
Args:
campaign_id (str): UUID string of the campaign.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Progress breakdown including counts by test state and overall percentage.
"""
# Return crud_get_progress(db, campaign_id)
return crud_get_progress(db, campaign_id) return crud_get_progress(db, campaign_id)
@@ -307,33 +625,55 @@ def get_campaign_progress_endpoint(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/from-threat-actor/{actor_id}", status_code=201) @router.post("/from-threat-actor/{actor_id}", status_code=201)
# Define function generate_campaign_from_actor
def generate_campaign_from_actor( def generate_campaign_from_actor(
# Entry: actor_id
actor_id: str, actor_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Auto-generate a campaign from a threat actor's uncovered techniques. """Auto-generate a campaign from a threat actor's uncovered techniques.
Creates tests from the best available templates and orders them Creates tests from the best available templates and orders them
by kill chain phase. by kill chain phase.
Args:
actor_id (str): UUID string of the threat actor to generate a campaign for.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead requesting the generation.
Returns:
dict: Serialised representation of the newly generated campaign.
""" """
# Assign campaign = generate_campaign_from_threat_actor(
campaign = generate_campaign_from_threat_actor( campaign = generate_campaign_from_threat_actor(
db, db,
uuid.UUID(actor_id), uuid.UUID(actor_id),
current_user, current_user,
) )
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="generate_campaign", action="generate_campaign",
# Keyword argument: entity_type
entity_type="campaign", entity_type="campaign",
# Keyword argument: entity_id
entity_id=campaign.id, entity_id=campaign.id,
# Keyword argument: details
details={"actor_id": actor_id, "campaign_name": campaign.name}, details={"actor_id": actor_id, "campaign_name": campaign.name},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return serialize_campaign(db, campaign)
return serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
@@ -342,41 +682,74 @@ def generate_campaign_from_actor(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.patch("/{campaign_id}/schedule") @router.patch("/{campaign_id}/schedule")
# Define function schedule_campaign
def schedule_campaign( def schedule_campaign(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: payload
payload: SchedulePayload, payload: SchedulePayload,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Configure or update the recurrence schedule for a campaign. """Configure or update the recurrence schedule for a campaign.
Only the campaign creator or admin can change scheduling. Only the campaign creator or admin can change scheduling.
Args:
campaign_id (str): UUID string of the campaign to schedule.
payload (SchedulePayload): Recurrence flag, pattern, and next run timestamp.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead (must be owner or admin).
Returns:
dict: Serialised representation of the campaign with updated schedule fields.
""" """
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign campaign = crud_schedule(
campaign = crud_schedule( campaign = crud_schedule(
db, db,
campaign_id, campaign_id,
# Keyword argument: owner_id
owner_id=current_user.id, owner_id=current_user.id,
# Keyword argument: owner_role
owner_role=current_user.role, owner_role=current_user.role,
# Keyword argument: is_recurring
is_recurring=payload.is_recurring, is_recurring=payload.is_recurring,
# Keyword argument: recurrence_pattern
recurrence_pattern=payload.recurrence_pattern, recurrence_pattern=payload.recurrence_pattern,
# Keyword argument: next_run_at
next_run_at=payload.next_run_at, next_run_at=payload.next_run_at,
) )
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="schedule_campaign", action="schedule_campaign",
# Keyword argument: entity_type
entity_type="campaign", entity_type="campaign",
# Keyword argument: entity_id
entity_id=campaign.id, entity_id=campaign.id,
# Keyword argument: details
details={ details={
# Literal argument value
"is_recurring": campaign.is_recurring, "is_recurring": campaign.is_recurring,
# Literal argument value
"recurrence_pattern": campaign.recurrence_pattern, "recurrence_pattern": campaign.recurrence_pattern,
# Literal argument value
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None, "next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
}, },
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(campaign) db.refresh(campaign)
# Return serialize_campaign(db, campaign)
return serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
@@ -385,10 +758,24 @@ def schedule_campaign(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/{campaign_id}/history") @router.get("/{campaign_id}/history")
# Define function get_campaign_history
def get_campaign_history( def get_campaign_history(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List all child campaigns (execution history) of a recurring campaign.""" """List all child campaigns (execution history) of a recurring campaign.
Args:
campaign_id (str): UUID string of the parent recurring campaign.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
list: Serialised list of child campaign dicts ordered by creation date.
"""
# Return crud_get_history(db, campaign_id)
return crud_get_history(db, campaign_id) return crud_get_history(db, campaign_id)
+136 -23
View File
@@ -1,29 +1,43 @@
"""Compliance endpoints — framework status, reports, and gap analysis. """Compliance endpoints — framework status, reports, and gap analysis.
Thin HTTP adapter: delegates all data logic to compliance_service. Thin HTTP adapter that delegates all data logic to compliance_service.
Provides compliance posture assessment by mapping MITRE ATT&CK technique Provides compliance posture assessment by mapping MITRE ATT&CK technique
coverage to compliance framework controls. coverage to compliance framework controls.
""" """
# Import APIRouter, Depends from fastapi
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
# Import StreamingResponse from fastapi.responses
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_role from app.dependencies.auth import get_current_user, require_role
# Import User from app.models.user
from app.models.user import User from app.models.user import User
from app.services.compliance_service import (
list_frameworks, # Import from app.services.compliance_import_service
get_framework_status,
build_framework_report_csv,
get_framework_gaps,
)
from app.services.compliance_import_service import ( from app.services.compliance_import_service import (
import_nist_800_53_mappings,
import_cis_controls_v8_mappings, import_cis_controls_v8_mappings,
import_nist_800_53_mappings,
) )
# Import from app.services.compliance_service
from app.services.compliance_service import (
build_framework_report_csv,
get_framework_gaps,
get_framework_status,
list_frameworks,
)
# Assign router = APIRouter(prefix="/compliance", tags=["compliance"])
router = APIRouter(prefix="/compliance", tags=["compliance"]) router = APIRouter(prefix="/compliance", tags=["compliance"])
@@ -31,11 +45,23 @@ router = APIRouter(prefix="/compliance", tags=["compliance"])
@router.get("/frameworks") @router.get("/frameworks")
# Define function list_frameworks_endpoint
def list_frameworks_endpoint( def list_frameworks_endpoint(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List all available compliance frameworks.""" """List all available compliance frameworks.
Args:
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
list: List of framework summary dicts containing id, name, and control counts.
"""
# Return list_frameworks(db)
return list_frameworks(db) return list_frameworks(db)
@@ -43,12 +69,26 @@ def list_frameworks_endpoint(
@router.get("/frameworks/{framework_id}/status") @router.get("/frameworks/{framework_id}/status")
# Define function framework_status
def framework_status( def framework_status(
# Entry: framework_id
framework_id: str, framework_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get compliance status for each control in a framework.""" """Get compliance status for each control in a framework.
Args:
framework_id (str): Identifier of the compliance framework (e.g. ``nist-800-53``).
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Mapping of control IDs to their coverage status and linked techniques.
"""
# Return get_framework_status(db, framework_id)
return get_framework_status(db, framework_id) return get_framework_status(db, framework_id)
@@ -56,12 +96,26 @@ def framework_status(
@router.get("/frameworks/{framework_id}/report") @router.get("/frameworks/{framework_id}/report")
# Define function framework_report
def framework_report( def framework_report(
# Entry: framework_id
framework_id: str, framework_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get the full compliance report (same as status but marked as report).""" """Get the full compliance report (same as status but marked as report).
Args:
framework_id (str): Identifier of the compliance framework.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Full compliance report with per-control coverage details.
"""
# Return get_framework_status(db, framework_id)
return get_framework_status(db, framework_id) return get_framework_status(db, framework_id)
@@ -69,17 +123,35 @@ def framework_report(
@router.get("/frameworks/{framework_id}/report/csv") @router.get("/frameworks/{framework_id}/report/csv")
# Define function framework_report_csv
def framework_report_csv( def framework_report_csv(
# Entry: framework_id
framework_id: str, framework_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> StreamingResponse:
"""Export compliance report as CSV.""" """Export compliance report as CSV.
Args:
framework_id (str): Identifier of the compliance framework to export.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
StreamingResponse: CSV file attachment with compliance coverage data.
"""
# csv_bytes, filename = build_framework_report_csv(db, framework_id)
csv_bytes, filename = build_framework_report_csv(db, framework_id) csv_bytes, filename = build_framework_report_csv(db, framework_id)
# Return StreamingResponse(
return StreamingResponse( return StreamingResponse(
iter([csv_bytes]), iter([csv_bytes]),
# Keyword argument: media_type
media_type="text/csv", media_type="text/csv",
# Keyword argument: headers
headers={ headers={
# Literal argument value
"Content-Disposition": f"attachment; filename={filename}", "Content-Disposition": f"attachment; filename={filename}",
}, },
) )
@@ -89,12 +161,26 @@ def framework_report_csv(
@router.get("/frameworks/{framework_id}/gaps") @router.get("/frameworks/{framework_id}/gaps")
# Define function framework_gaps
def framework_gaps( def framework_gaps(
# Entry: framework_id
framework_id: str, framework_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get controls with techniques that are not adequately covered.""" """Get controls with techniques that are not adequately covered.
Args:
framework_id (str): Identifier of the compliance framework to analyse.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Controls flagged as gaps, with linked technique IDs and coverage ratios.
"""
# Return get_framework_gaps(db, framework_id)
return get_framework_gaps(db, framework_id) return get_framework_gaps(db, framework_id)
@@ -102,20 +188,47 @@ def framework_gaps(
@router.post("/import/nist-800-53") @router.post("/import/nist-800-53")
# Define function import_nist
def import_nist( def import_nist(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Import NIST 800-53 Rev 5 mappings (admin only).""" """Import NIST 800-53 Rev 5 mappings (admin only).
Args:
db (Session): SQLAlchemy database session.
current_user (User): Authenticated admin user.
Returns:
dict: Import result with counts of created and updated control mappings.
"""
# Assign result = import_nist_800_53_mappings(db)
result = import_nist_800_53_mappings(db) result = import_nist_800_53_mappings(db)
# Return result
return result return result
# Apply the @router.post decorator
@router.post("/import/cis-controls-v8") @router.post("/import/cis-controls-v8")
# Define function import_cis
def import_cis( def import_cis(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Import CIS Controls v8 mappings (admin only).""" """Import CIS Controls v8 mappings (admin only).
Args:
db (Session): SQLAlchemy database session.
current_user (User): Authenticated admin user.
Returns:
dict: Import result with counts of created and updated control mappings.
"""
# Assign result = import_cis_controls_v8_mappings(db)
result = import_cis_controls_v8_mappings(db) result = import_cis_controls_v8_mappings(db)
# Return result
return result return result
+53 -7
View File
@@ -1,26 +1,47 @@
"""D3FEND endpoints — defensive technique listings, mappings, and import trigger.""" """D3FEND endpoints — defensive technique listings, mappings, and import trigger."""
# Import logging
import logging import logging
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_role from app.dependencies.auth import get_current_user, require_role
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.services.d3fend_import_service
from app.services.d3fend_import_service import ( from app.services.d3fend_import_service import (
import_d3fend_techniques,
import_d3fend_mappings, import_d3fend_mappings,
import_d3fend_techniques,
) )
# Import from app.services.d3fend_query_service
from app.services.d3fend_query_service import (
get_defenses_for_attack_technique,
list_d3fend_tactics,
)
# Import from app.services.d3fend_query_service
from app.services.d3fend_query_service import ( from app.services.d3fend_query_service import (
list_defensive_techniques as list_defensive_techniques_svc, list_defensive_techniques as list_defensive_techniques_svc,
list_d3fend_tactics,
get_defenses_for_attack_technique,
) )
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Assign router = APIRouter(prefix="/d3fend", tags=["d3fend"])
router = APIRouter(prefix="/d3fend", tags=["d3fend"]) router = APIRouter(prefix="/d3fend", tags=["d3fend"])
@@ -29,15 +50,23 @@ router = APIRouter(prefix="/d3fend", tags=["d3fend"])
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("") @router.get("")
# Define function list_defensive_techniques
def list_defensive_techniques( def list_defensive_techniques(
# Entry: tactic
tactic: Optional[str] = Query(None), tactic: Optional[str] = Query(None),
# Entry: search
search: Optional[str] = Query(None), search: Optional[str] = Query(None),
# Entry: offset
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
# Entry: limit
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List all D3FEND defensive techniques with optional filters.""" """List all D3FEND defensive techniques with optional filters."""
# Return list_defensive_techniques_svc(
return list_defensive_techniques_svc( return list_defensive_techniques_svc(
db, tactic=tactic, search=search, offset=offset, limit=limit db, tactic=tactic, search=search, offset=offset, limit=limit
) )
@@ -48,11 +77,15 @@ def list_defensive_techniques(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/tactics") @router.get("/tactics")
# Define function list_d3fend_tactics_endpoint
def list_d3fend_tactics_endpoint( def list_d3fend_tactics_endpoint(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return a list of all D3FEND tactics with counts.""" """Return a list of all D3FEND tactics with counts."""
# Return list_d3fend_tactics(db)
return list_d3fend_tactics(db) return list_d3fend_tactics(db)
@@ -61,12 +94,17 @@ def list_d3fend_tactics_endpoint(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/for-technique/{mitre_id}") @router.get("/for-technique/{mitre_id}")
# Define function get_defenses_for_attack_technique_endpoint
def get_defenses_for_attack_technique_endpoint( def get_defenses_for_attack_technique_endpoint(
# Entry: mitre_id
mitre_id: str, mitre_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" """Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
# Return get_defenses_for_attack_technique(db, mitre_id)
return get_defenses_for_attack_technique(db, mitre_id) return get_defenses_for_attack_technique(db, mitre_id)
@@ -75,15 +113,23 @@ def get_defenses_for_attack_technique_endpoint(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/import") @router.post("/import")
# Define function trigger_d3fend_import
def trigger_d3fend_import( def trigger_d3fend_import(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Import D3FEND techniques and ATT&CK mappings. Admin only.""" """Import D3FEND techniques and ATT&CK mappings. Admin only."""
# Assign tech_result = import_d3fend_techniques(db)
tech_result = import_d3fend_techniques(db) tech_result = import_d3fend_techniques(db)
# Assign mapping_result = import_d3fend_mappings(db)
mapping_result = import_d3fend_mappings(db) mapping_result = import_d3fend_mappings(db)
# Return {
return { return {
# Literal argument value
"techniques": tech_result, "techniques": tech_result,
# Literal argument value
"mappings": mapping_result, "mappings": mapping_result,
} }
+77 -9
View File
@@ -5,16 +5,34 @@ Provides a centralized panel for managing all external data sources
including sync triggers, enable/disable toggles, and statistics. including sync triggers, enable/disable toggles, and statistics.
""" """
from fastapi import APIRouter, Depends # Import Optional from typing
from pydantic import BaseModel
from sqlalchemy.orm import Session
from typing import Optional from typing import Optional
# Import APIRouter, Depends from fastapi
from fastapi import APIRouter, Depends
# Import BaseModel from pydantic
from pydantic import BaseModel
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import require_role from app.dependencies.auth
from app.dependencies.auth import require_role from app.dependencies.auth import require_role
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action from app.services.audit_service import log_action
# Import from app.services.data_source_service
from app.services.data_source_service import ( from app.services.data_source_service import (
get_source_stats, get_source_stats,
list_sources, list_sources,
@@ -23,18 +41,21 @@ from app.services.data_source_service import (
update_source, update_source,
) )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Pydantic schemas for request validation # Pydantic schemas for request validation
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class DataSourceUpdate(BaseModel): class DataSourceUpdate(BaseModel):
"""Payload for updating a data source — only allowed fields.""" """Payload for updating a data source — only allowed fields."""
# Assign is_enabled = None
is_enabled: Optional[bool] = None is_enabled: Optional[bool] = None
# Assign sync_frequency = None
sync_frequency: Optional[str] = None sync_frequency: Optional[str] = None
# Assign config = None
config: Optional[dict] = None config: Optional[dict] = None
# Assign router = APIRouter(prefix="/data-sources", tags=["data-sources"])
router = APIRouter(prefix="/data-sources", tags=["data-sources"]) router = APIRouter(prefix="/data-sources", tags=["data-sources"])
@@ -44,90 +65,137 @@ router = APIRouter(prefix="/data-sources", tags=["data-sources"])
@router.get("") @router.get("")
# Define function list_data_sources
def list_data_sources( def list_data_sources(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> list:
"""List all registered data sources. """List all registered data sources.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
""" """
# Return list_sources(db)
return list_sources(db) return list_sources(db)
# Apply the @router.patch decorator
@router.patch("/{source_id}") @router.patch("/{source_id}")
# Define function update_data_source
def update_data_source( def update_data_source(
# Entry: source_id
source_id: str, source_id: str,
# Entry: body
body: DataSourceUpdate, body: DataSourceUpdate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Update a data source (enable/disable, change config). """Update a data source (enable/disable, change config).
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
""" """
# Assign update_data = body.model_dump(exclude_unset=True)
update_data = body.model_dump(exclude_unset=True) update_data = body.model_dump(exclude_unset=True)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call update_source()
update_source(db, source_id, **update_data) update_source(db, source_id, **update_data)
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="update_data_source", action="update_data_source",
# Keyword argument: entity_type
entity_type="data_source", entity_type="data_source",
# Keyword argument: entity_id
entity_id=source_id, entity_id=source_id,
# Keyword argument: details
details={"updates": update_data}, details={"updates": update_data},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return {"message": "Data source updated", "id": source_id}
return {"message": "Data source updated", "id": source_id} return {"message": "Data source updated", "id": source_id}
# Apply the @router.post decorator
@router.post("/{source_id}/sync") @router.post("/{source_id}/sync")
# Define function sync_data_source
def sync_data_source( def sync_data_source(
# Entry: source_id
source_id: str, source_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Trigger sync/import for a specific data source. """Trigger sync/import for a specific data source.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
""" """
# Return sync_source(db, source_id)
return sync_source(db, source_id) return sync_source(db, source_id)
# Apply the @router.post decorator
@router.post("/sync-all") @router.post("/sync-all")
# Define function sync_all_data_sources
def sync_all_data_sources( def sync_all_data_sources(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Trigger sync for all enabled data sources (sequentially). """Trigger sync for all enabled data sources (sequentially).
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
""" """
# Assign results = sync_all_sources(db)
results = sync_all_sources(db) results = sync_all_sources(db)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="sync_all_data_sources", action="sync_all_data_sources",
# Keyword argument: entity_type
entity_type="data_source", entity_type="data_source",
# Keyword argument: entity_id
entity_id=None, entity_id=None,
# Keyword argument: details
details={"results": results}, details={"results": results},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return {"message": "Sync all complete", "results": results}
return {"message": "Sync all complete", "results": results} return {"message": "Sync all complete", "results": results}
# Apply the @router.get decorator
@router.get("/{source_id}/stats") @router.get("/{source_id}/stats")
# Define function get_data_source_stats
def get_data_source_stats( def get_data_source_stats(
# Entry: source_id
source_id: str, source_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Get detailed statistics for a specific data source. """Get detailed statistics for a specific data source.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
""" """
# Return get_source_stats(db, source_id)
return get_source_stats(db, source_id) return get_source_stats(db, source_id)
+73 -14
View File
@@ -6,36 +6,55 @@ Provides endpoints for browsing detection rules, querying rules by technique,
and managing the template ↔ detection rule associations. and managing the template ↔ detection rule associations.
""" """
# Import uuid
import uuid import uuid
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import BaseModel from pydantic
from pydantic import BaseModel from pydantic import BaseModel
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role
from app.models.user import User
from app.services.detection_rule_service import (
list_rules,
get_rules_for_template,
auto_associate_rules,
get_rules_for_test,
evaluate_rule,
)
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_any_role, require_role
# Import User from app.models.user
from app.models.user import User
# Import from app.services.detection_rule_service
from app.services.detection_rule_service import (
auto_associate_rules,
evaluate_rule,
get_rules_for_template,
get_rules_for_test,
list_rules,
)
# ── Pydantic schemas for request validation ──────────────────────────── # ── Pydantic schemas for request validation ────────────────────────────
class DetectionRuleEvaluate(BaseModel): class DetectionRuleEvaluate(BaseModel):
"""Payload for evaluating a detection rule against a test.""" """Payload for evaluating a detection rule against a test."""
# test_id: uuid.UUID
test_id: uuid.UUID test_id: uuid.UUID
# detection_rule_id: uuid.UUID
detection_rule_id: uuid.UUID detection_rule_id: uuid.UUID
# Assign triggered = None
triggered: Optional[bool] = None triggered: Optional[bool] = None
# Assign notes = None
notes: Optional[str] = None notes: Optional[str] = None
# Assign router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"]) router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
@@ -43,24 +62,40 @@ router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
@router.get("") @router.get("")
# Define function list_detection_rules
def list_detection_rules( def list_detection_rules(
# Entry: technique
technique: Optional[str] = Query(None, description="Filter by MITRE technique ID"), technique: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
# Entry: source
source: Optional[str] = Query(None, description="Filter by source (sigma, elastic, splunk, custom)"), source: Optional[str] = Query(None, description="Filter by source (sigma, elastic, splunk, custom)"),
# Entry: severity
severity: Optional[str] = Query(None), severity: Optional[str] = Query(None),
# Entry: search
search: Optional[str] = Query(None), search: Optional[str] = Query(None),
# Entry: offset
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
# Entry: limit
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List detection rules with optional filters and pagination.""" """List detection rules with optional filters and pagination."""
# Return list_rules(
return list_rules( return list_rules(
db, db,
# Keyword argument: technique
technique=technique, technique=technique,
# Keyword argument: source
source=source, source=source,
# Keyword argument: severity
severity=severity, severity=severity,
# Keyword argument: search
search=search, search=search,
# Keyword argument: offset
offset=offset, offset=offset,
# Keyword argument: limit
limit=limit, limit=limit,
) )
@@ -69,12 +104,17 @@ def list_detection_rules(
@router.get("/for-template/{template_id}") @router.get("/for-template/{template_id}")
# Define function get_detection_rules_for_template
def get_detection_rules_for_template( def get_detection_rules_for_template(
# Entry: template_id
template_id: str, template_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Get detection rules associated with a test template.""" """Get detection rules associated with a test template."""
# Return get_rules_for_template(db, template_id)
return get_rules_for_template(db, template_id) return get_rules_for_template(db, template_id)
@@ -82,16 +122,20 @@ def get_detection_rules_for_template(
@router.post("/auto-associate") @router.post("/auto-associate")
# Define function auto_associate_detection_rules
def auto_associate_detection_rules( def auto_associate_detection_rules(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Auto-associate test templates with detection rules by MITRE technique ID. """Auto-associate test templates with detection rules by MITRE technique ID.
For each active template, find all active detection rules for the same For each active template, find all active detection rules for the same
technique and create associations. Rules with severity >= high are marked technique and create associations. Rules with severity >= high are marked
as primary. as primary.
""" """
# Return auto_associate_rules(db)
return auto_associate_rules(db) return auto_associate_rules(db)
@@ -99,16 +143,21 @@ def auto_associate_detection_rules(
@router.get("/for-test/{test_id}") @router.get("/for-test/{test_id}")
# Define function get_detection_rules_for_test
def get_detection_rules_for_test( def get_detection_rules_for_test(
# Entry: test_id
test_id: str, test_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Get detection rules relevant to a test, along with their evaluation results. """Get detection rules relevant to a test, along with their evaluation results.
Finds rules by matching the test's technique_id to detection rules, Finds rules by matching the test's technique_id to detection rules,
and returns any existing evaluation results. and returns any existing evaluation results.
""" """
# Return get_rules_for_test(db, test_id)
return get_rules_for_test(db, test_id) return get_rules_for_test(db, test_id)
@@ -116,17 +165,27 @@ def get_detection_rules_for_test(
@router.post("/evaluate") @router.post("/evaluate")
# Define function evaluate_detection_rule
def evaluate_detection_rule( def evaluate_detection_rule(
# Entry: payload
payload: DetectionRuleEvaluate, payload: DetectionRuleEvaluate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
): ) -> dict:
"""Save or update the evaluation result for a detection rule on a test.""" """Save or update the evaluation result for a detection rule on a test."""
# Return evaluate_rule(
return evaluate_rule( return evaluate_rule(
db, db,
# Keyword argument: test_id
test_id=payload.test_id, test_id=payload.test_id,
# Keyword argument: detection_rule_id
detection_rule_id=payload.detection_rule_id, detection_rule_id=payload.detection_rule_id,
# Keyword argument: triggered
triggered=payload.triggered, triggered=payload.triggered,
# Keyword argument: notes
notes=payload.notes, notes=payload.notes,
# Keyword argument: evaluator_id
evaluator_id=current_user.id, evaluator_id=current_user.id,
) )
+124 -7
View File
@@ -19,34 +19,66 @@ Access Control
``validated``, or ``rejected``. ``validated``, or ``rejected``.
""" """
# Import hashlib
import hashlib import hashlib
# Import os
import os import os
# Import uuid
import uuid as _uuid import uuid as _uuid
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, File, Form, Query, Request,... from fastapi
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
from app.domain.unit_of_work import UnitOfWork
# Import get_current_user from app.dependencies.auth
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork
# Import limiter from app.limiter
from app.limiter import limiter
# Import TeamSide from app.models.enums
from app.models.enums import TeamSide from app.models.enums import TeamSide
# Import Evidence from app.models.evidence
from app.models.evidence import Evidence from app.models.evidence import Evidence
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import EvidenceOut from app.schemas.evidence
from app.schemas.evidence import EvidenceOut from app.schemas.evidence import EvidenceOut
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action from app.services.audit_service import log_action
# Import from app.services.evidence_service
from app.services.evidence_service import ( from app.services.evidence_service import (
MAX_UPLOAD_SIZE,
get_evidence_or_raise, get_evidence_or_raise,
get_test_or_raise, get_test_or_raise,
list_evidence_for_test, list_evidence_for_test,
MAX_UPLOAD_SIZE,
validate_delete_permission, validate_delete_permission,
validate_file, validate_file,
validate_upload_permission, validate_upload_permission,
) )
from app.limiter import limiter
# Import get_presigned_url, upload_file from app.storage
from app.storage import get_presigned_url, upload_file from app.storage import get_presigned_url, upload_file
# Assign router = APIRouter(tags=["evidence"])
router = APIRouter(tags=["evidence"]) router = APIRouter(tags=["evidence"])
@@ -56,15 +88,25 @@ router = APIRouter(tags=["evidence"])
def _evidence_to_out(evidence: Evidence) -> EvidenceOut: def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
"""Convert an ORM ``Evidence`` to the API schema, injecting a presigned URL.""" """Convert an ORM ``Evidence`` to the API schema, injecting a presigned URL."""
# Return EvidenceOut(
return EvidenceOut( return EvidenceOut(
# Keyword argument: id
id=evidence.id, id=evidence.id,
# Keyword argument: test_id
test_id=evidence.test_id, test_id=evidence.test_id,
# Keyword argument: file_name
file_name=evidence.file_name, file_name=evidence.file_name,
# Keyword argument: sha256_hash
sha256_hash=evidence.sha256_hash, sha256_hash=evidence.sha256_hash,
# Keyword argument: uploaded_by
uploaded_by=evidence.uploaded_by, uploaded_by=evidence.uploaded_by,
# Keyword argument: uploaded_at
uploaded_at=evidence.uploaded_at, uploaded_at=evidence.uploaded_at,
# Keyword argument: team
team=evidence.team, team=evidence.team,
# Keyword argument: notes
notes=evidence.notes, notes=evidence.notes,
# Keyword argument: download_url
download_url=get_presigned_url(evidence.file_path), download_url=get_presigned_url(evidence.file_path),
) )
@@ -75,30 +117,47 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
@router.post( @router.post(
# Literal argument value
"/tests/{test_id}/evidence", "/tests/{test_id}/evidence",
# Keyword argument: response_model
response_model=EvidenceOut, response_model=EvidenceOut,
# Keyword argument: status_code
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
) )
# Apply the @limiter.limit decorator
@limiter.limit("10/minute") @limiter.limit("10/minute")
# Define async function upload_evidence
async def upload_evidence( async def upload_evidence(
# Entry: request
request: Request, request: Request,
# Entry: test_id
test_id: _uuid.UUID, test_id: _uuid.UUID,
# Entry: file
file: UploadFile = File(...), file: UploadFile = File(...),
# Entry: team
team: TeamSide = Form(TeamSide.red), team: TeamSide = Form(TeamSide.red),
# Entry: notes
notes: Optional[str] = Form(None), notes: Optional[str] = Form(None),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> EvidenceOut:
"""Upload a file as evidence for the given test. """Upload a file as evidence for the given test.
The ``team`` field (sent as form data) determines whether this is The ``team`` field (sent as form data) determines whether this is
Red Team (attack) or Blue Team (detection) evidence. Red Team (attack) or Blue Team (detection) evidence.
""" """
# Assign test = get_test_or_raise(db, test_id)
test = get_test_or_raise(db, test_id) test = get_test_or_raise(db, test_id)
# Call validate_upload_permission()
validate_upload_permission(test, team, current_user.role) validate_upload_permission(test, team, current_user.role)
# Assign file_name = file.filename or "unnamed"
file_name = file.filename or "unnamed" file_name = file.filename or "unnamed"
# Assign content = await file.read(MAX_UPLOAD_SIZE + 1)
content = await file.read(MAX_UPLOAD_SIZE + 1) content = await file.read(MAX_UPLOAD_SIZE + 1)
# Call validate_file()
validate_file(file_name, len(content)) validate_file(file_name, len(content))
# Hash # Hash
@@ -106,6 +165,7 @@ async def upload_evidence(
# 4. Object key (sanitise filename to prevent path traversal in storage) # 4. Object key (sanitise filename to prevent path traversal in storage)
safe_name = os.path.basename(file_name) safe_name = os.path.basename(file_name)
# Assign key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
key = f"{test_id}/{_uuid.uuid4()}_{safe_name}" key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
# 5. Upload to MinIO # 5. Upload to MinIO
@@ -113,33 +173,56 @@ async def upload_evidence(
# 6. Persist metadata and audit # 6. Persist metadata and audit
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign evidence = Evidence(
evidence = Evidence( evidence = Evidence(
# Keyword argument: test_id
test_id=test_id, test_id=test_id,
# Keyword argument: file_name
file_name=safe_name, file_name=safe_name,
# Keyword argument: file_path
file_path=key, file_path=key,
# Keyword argument: sha256_hash
sha256_hash=sha256, sha256_hash=sha256,
# Keyword argument: uploaded_by
uploaded_by=current_user.id, uploaded_by=current_user.id,
# Keyword argument: team
team=team, team=team,
# Keyword argument: notes
notes=notes, notes=notes,
) )
# Stage new record(s) for database insertion
db.add(evidence) db.add(evidence)
# Flush changes to DB without committing the transaction
db.flush() # Get evidence.id for audit db.flush() # Get evidence.id for audit
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="upload_evidence", action="upload_evidence",
# Keyword argument: entity_type
entity_type="evidence", entity_type="evidence",
# Keyword argument: entity_id
entity_id=evidence.id, entity_id=evidence.id,
# Keyword argument: details
details={ details={
# Literal argument value
"file_name": safe_name, "file_name": safe_name,
# Literal argument value
"sha256": sha256, "sha256": sha256,
# Literal argument value
"test_id": str(test_id), "test_id": str(test_id),
# Literal argument value
"team": team.value, "team": team.value,
}, },
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(evidence) db.refresh(evidence)
# Return _evidence_to_out(evidence)
return _evidence_to_out(evidence) return _evidence_to_out(evidence)
@@ -149,15 +232,23 @@ async def upload_evidence(
@router.get("/tests/{test_id}/evidence", response_model=list[EvidenceOut]) @router.get("/tests/{test_id}/evidence", response_model=list[EvidenceOut])
# Define function list_evidence
def list_evidence( def list_evidence(
# Entry: test_id
test_id: _uuid.UUID, test_id: _uuid.UUID,
# Entry: team
team: Optional[str] = Query(None, description="Filter by team: red or blue"), team: Optional[str] = Query(None, description="Filter by team: red or blue"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[EvidenceOut]:
"""List all evidences for a test, optionally filtered by team.""" """List all evidences for a test, optionally filtered by team."""
# Call get_test_or_raise()
get_test_or_raise(db, test_id) get_test_or_raise(db, test_id)
# Assign evidences = list_evidence_for_test(db, test_id, team=team)
evidences = list_evidence_for_test(db, test_id, team=team) evidences = list_evidence_for_test(db, test_id, team=team)
# Return [_evidence_to_out(e) for e in evidences]
return [_evidence_to_out(e) for e in evidences] return [_evidence_to_out(e) for e in evidences]
@@ -167,13 +258,19 @@ def list_evidence(
@router.get("/evidence/{evidence_id}", response_model=EvidenceOut) @router.get("/evidence/{evidence_id}", response_model=EvidenceOut)
# Define function get_evidence
def get_evidence( def get_evidence(
# Entry: evidence_id
evidence_id: _uuid.UUID, evidence_id: _uuid.UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> EvidenceOut:
"""Return evidence metadata together with a presigned download URL.""" """Return evidence metadata together with a presigned download URL."""
# Assign evidence = get_evidence_or_raise(db, evidence_id)
evidence = get_evidence_or_raise(db, evidence_id) evidence = get_evidence_or_raise(db, evidence_id)
# Return _evidence_to_out(evidence)
return _evidence_to_out(evidence) return _evidence_to_out(evidence)
@@ -183,11 +280,15 @@ def get_evidence(
@router.delete("/evidence/{evidence_id}", status_code=status.HTTP_200_OK) @router.delete("/evidence/{evidence_id}", status_code=status.HTTP_200_OK)
# Define function delete_evidence
def delete_evidence( def delete_evidence(
# Entry: evidence_id
evidence_id: _uuid.UUID, evidence_id: _uuid.UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Delete an evidence record. """Delete an evidence record.
Only allowed in editable states: Only allowed in editable states:
@@ -195,24 +296,40 @@ def delete_evidence(
- Blue evidence: ``blue_evaluating`` - Blue evidence: ``blue_evaluating``
- No deletions in ``in_review``, ``validated``, ``rejected`` - No deletions in ``in_review``, ``validated``, ``rejected``
""" """
# Assign evidence = get_evidence_or_raise(db, evidence_id)
evidence = get_evidence_or_raise(db, evidence_id) evidence = get_evidence_or_raise(db, evidence_id)
# Assign test = get_test_or_raise(db, evidence.test_id)
test = get_test_or_raise(db, evidence.test_id) test = get_test_or_raise(db, evidence.test_id)
# Call validate_delete_permission()
validate_delete_permission(test, evidence, current_user.role, current_user.id) validate_delete_permission(test, evidence, current_user.role, current_user.id)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="delete_evidence", action="delete_evidence",
# Keyword argument: entity_type
entity_type="evidence", entity_type="evidence",
# Keyword argument: entity_id
entity_id=evidence.id, entity_id=evidence.id,
# Keyword argument: details
details={ details={
# Literal argument value
"file_name": evidence.file_name, "file_name": evidence.file_name,
# Literal argument value
"test_id": str(evidence.test_id), "test_id": str(evidence.test_id),
# Literal argument value
"team": evidence.team.value if evidence.team else None, "team": evidence.team.value if evidence.team else None,
}, },
) )
# Mark record for deletion on next commit
db.delete(evidence) db.delete(evidence)
# Call uow.commit()
uow.commit() uow.commit()
# Return {"detail": "Evidence deleted"}
return {"detail": "Evidence deleted"} return {"detail": "Evidence deleted"}
+73 -5
View File
@@ -5,101 +5,169 @@ No business logic lives here — only request validation and response
formatting. formatting.
""" """
# Import io
import io import io
# Import json
import json import json
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import StreamingResponse from fastapi.responses
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user from app.dependencies.auth
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import heatmap_service from app.services
from app.services import heatmap_service from app.services import heatmap_service
# Assign router = APIRouter(prefix="/heatmap", tags=["heatmap"])
router = APIRouter(prefix="/heatmap", tags=["heatmap"]) router = APIRouter(prefix="/heatmap", tags=["heatmap"])
# Apply the @router.get decorator
@router.get("/coverage") @router.get("/coverage")
# Define function heatmap_coverage
def heatmap_coverage( def heatmap_coverage(
# Entry: platforms
platforms: Optional[str] = Query(None, description="Comma-separated platforms"), platforms: Optional[str] = Query(None, description="Comma-separated platforms"),
# Entry: tactics
tactics: Optional[str] = Query(None, description="Comma-separated tactics"), tactics: Optional[str] = Query(None, description="Comma-separated tactics"),
# Entry: min_score
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Coverage layer — score based on status_global of each technique.""" """Coverage layer — score based on status_global of each technique."""
# Return heatmap_service.build_coverage_layer(
return heatmap_service.build_coverage_layer( return heatmap_service.build_coverage_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score, db, platforms=platforms, tactics=tactics, min_score=min_score,
) )
# Apply the @router.get decorator
@router.get("/threat-actor/{actor_id}") @router.get("/threat-actor/{actor_id}")
# Define function heatmap_threat_actor
def heatmap_threat_actor( def heatmap_threat_actor(
# Entry: actor_id
actor_id: str, actor_id: str,
# Entry: platforms
platforms: Optional[str] = Query(None), platforms: Optional[str] = Query(None),
# Entry: tactics
tactics: Optional[str] = Query(None), tactics: Optional[str] = Query(None),
# Entry: min_score
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Threat actor layer — techniques used by an actor with coverage color.""" """Threat actor layer — techniques used by an actor with coverage color."""
# Return heatmap_service.build_threat_actor_layer(
return heatmap_service.build_threat_actor_layer( return heatmap_service.build_threat_actor_layer(
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score, db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
) )
# Apply the @router.get decorator
@router.get("/detection-rules") @router.get("/detection-rules")
# Define function heatmap_detection_rules
def heatmap_detection_rules( def heatmap_detection_rules(
# Entry: platforms
platforms: Optional[str] = Query(None), platforms: Optional[str] = Query(None),
# Entry: tactics
tactics: Optional[str] = Query(None), tactics: Optional[str] = Query(None),
# Entry: min_score
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Detection rules layer — score based on ratio of rules available vs total.""" """Detection rules layer — score based on ratio of rules available vs total."""
# Return heatmap_service.build_detection_rules_layer(
return heatmap_service.build_detection_rules_layer( return heatmap_service.build_detection_rules_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score, db, platforms=platforms, tactics=tactics, min_score=min_score,
) )
# Apply the @router.get decorator
@router.get("/campaign/{campaign_id}") @router.get("/campaign/{campaign_id}")
# Define function heatmap_campaign
def heatmap_campaign( def heatmap_campaign(
# Entry: campaign_id
campaign_id: str, campaign_id: str,
# Entry: platforms
platforms: Optional[str] = Query(None), platforms: Optional[str] = Query(None),
# Entry: tactics
tactics: Optional[str] = Query(None), tactics: Optional[str] = Query(None),
# Entry: min_score
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Campaign layer — only techniques in the campaign, colored by test state.""" """Campaign layer — only techniques in the campaign, colored by test state."""
# Return heatmap_service.build_campaign_layer(
return heatmap_service.build_campaign_layer( return heatmap_service.build_campaign_layer(
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score, db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
) )
# Apply the @router.get decorator
@router.get("/export-navigator") @router.get("/export-navigator")
# Define function export_navigator
def export_navigator( def export_navigator(
# Entry: layer
layer: str = Query(..., description="Layer type: coverage, threat-actor, detection-rules, campaign"), layer: str = Query(..., description="Layer type: coverage, threat-actor, detection-rules, campaign"),
# Entry: layer_id
layer_id: Optional[str] = Query(None, description="Actor ID or Campaign ID (if applicable)"), layer_id: Optional[str] = Query(None, description="Actor ID or Campaign ID (if applicable)"),
# Entry: platforms
platforms: Optional[str] = Query(None), platforms: Optional[str] = Query(None),
# Entry: tactics
tactics: Optional[str] = Query(None), tactics: Optional[str] = Query(None),
# Entry: min_score
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> StreamingResponse:
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator.""" """Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
# Assign data = heatmap_service.build_navigator_export(
data = heatmap_service.build_navigator_export( data = heatmap_service.build_navigator_export(
db, layer, layer_id=layer_id, db, layer, layer_id=layer_id,
# Keyword argument: platforms
platforms=platforms, tactics=tactics, min_score=min_score, platforms=platforms, tactics=tactics, min_score=min_score,
) )
# Assign json_content = json.dumps(data, indent=2, default=str)
json_content = json.dumps(data, indent=2, default=str) json_content = json.dumps(data, indent=2, default=str)
# Assign buffer = io.BytesIO(json_content.encode("utf-8"))
buffer = io.BytesIO(json_content.encode("utf-8")) buffer = io.BytesIO(json_content.encode("utf-8"))
# Return StreamingResponse(
return StreamingResponse( return StreamingResponse(
buffer, buffer,
# Keyword argument: media_type
media_type="application/json", media_type="application/json",
# Keyword argument: headers
headers={"Content-Disposition": f"attachment; filename=aegis_{layer}_layer.json"}, headers={"Content-Disposition": f"attachment; filename=aegis_{layer}_layer.json"},
) )
+106 -7
View File
@@ -1,136 +1,235 @@
"""Jira integration router — link, search, sync, create issues.""" """Jira integration router — link, search, sync, create issues."""
# Import logging
import logging import logging
# Import Optional from typing
from typing import Optional from typing import Optional
# Import UUID from uuid
from uuid import UUID from uuid import UUID
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_role from app.dependencies.auth import get_current_user, require_role
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import JiraLinkEntityType from app.models.jira_link
from app.models.jira_link import JiraLinkEntityType from app.models.jira_link import JiraLinkEntityType
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.schemas.jira_schema
from app.schemas.jira_schema import ( from app.schemas.jira_schema import (
JiraIssueResult, JiraIssueResult,
JiraLinkCreate, JiraLinkCreate,
JiraLinkOut, JiraLinkOut,
) )
from app.services import jira_service, audit_service
# Import audit_service, jira_service from app.services
from app.services import audit_service, jira_service
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Assign router = APIRouter(prefix="/jira", tags=["jira"])
router = APIRouter(prefix="/jira", tags=["jira"]) router = APIRouter(prefix="/jira", tags=["jira"])
# Apply the @router.get decorator
@router.get("/search", response_model=list[JiraIssueResult]) @router.get("/search", response_model=list[JiraIssueResult])
# Define function search_issues
def search_issues( def search_issues(
# Entry: q
q: str = Query(..., min_length=2), q: str = Query(..., min_length=2),
# Entry: max_results
max_results: int = Query(10, le=50), max_results: int = Query(10, le=50),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list[JiraIssueResult]:
"""Search Jira issues by JQL or free text.""" """Search Jira issues by JQL or free text."""
# Return jira_service.search_jira_issues(q, max_results)
return jira_service.search_jira_issues(q, max_results) return jira_service.search_jira_issues(q, max_results)
# Apply the @router.post decorator
@router.post("/links", response_model=JiraLinkOut, status_code=201) @router.post("/links", response_model=JiraLinkOut, status_code=201)
# Define function create_link
def create_link( def create_link(
# Entry: body
body: JiraLinkCreate, body: JiraLinkCreate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> JiraLinkOut:
"""Associate an Aegis entity with a Jira issue.""" """Associate an Aegis entity with a Jira issue."""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign link = jira_service.create_link(
link = jira_service.create_link( link = jira_service.create_link(
db, db,
# Keyword argument: entity_type
entity_type=body.entity_type, entity_type=body.entity_type,
# Keyword argument: entity_id
entity_id=body.entity_id, entity_id=body.entity_id,
# Keyword argument: jira_issue_key
jira_issue_key=body.jira_issue_key, jira_issue_key=body.jira_issue_key,
# Keyword argument: sync_direction
sync_direction=body.sync_direction, sync_direction=body.sync_direction,
# Keyword argument: created_by
created_by=user.id, created_by=user.id,
) )
# Call audit_service.log_action()
audit_service.log_action( audit_service.log_action(
db, db,
# Keyword argument: user_id
user_id=user.id, user_id=user.id,
# Keyword argument: action
action="JIRA_LINK_CREATED", action="JIRA_LINK_CREATED",
# Keyword argument: entity_type
entity_type="jira_link", entity_type="jira_link",
# Keyword argument: entity_id
entity_id=str(link.id), entity_id=str(link.id),
# Keyword argument: details
details={ details={
# Literal argument value
"linked_entity_type": body.entity_type.value, "linked_entity_type": body.entity_type.value,
# Literal argument value
"linked_entity_id": str(body.entity_id), "linked_entity_id": str(body.entity_id),
# Literal argument value
"jira_issue_key": body.jira_issue_key, "jira_issue_key": body.jira_issue_key,
}, },
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(link) db.refresh(link)
# Return link
return link return link
# Apply the @router.get decorator
@router.get("/links", response_model=list[JiraLinkOut]) @router.get("/links", response_model=list[JiraLinkOut])
# Define function list_links
def list_links( def list_links(
# Entry: entity_type
entity_type: Optional[JiraLinkEntityType] = None, entity_type: Optional[JiraLinkEntityType] = None,
# Entry: entity_id
entity_id: Optional[UUID] = None, entity_id: Optional[UUID] = None,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list[JiraLinkOut]:
"""List Jira links, optionally filtered by entity.""" """List Jira links, optionally filtered by entity."""
# Return jira_service.list_links(
return jira_service.list_links( return jira_service.list_links(
db, db,
# Keyword argument: entity_type
entity_type=entity_type, entity_type=entity_type,
# Keyword argument: entity_id
entity_id=entity_id, entity_id=entity_id,
) )
# Apply the @router.post decorator
@router.post("/links/{link_id}/sync") @router.post("/links/{link_id}/sync")
# Define function sync_link
def sync_link( def sync_link(
# Entry: link_id
link_id: UUID, link_id: UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(require_role("admin")), user: User = Depends(require_role("admin")),
): ) -> dict:
"""Force bidirectional sync for a specific Jira link.""" """Force bidirectional sync for a specific Jira link."""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign link = jira_service.get_link_or_raise(db, link_id)
link = jira_service.get_link_or_raise(db, link_id) link = jira_service.get_link_or_raise(db, link_id)
# Call jira_service.sync_jira_to_aegis()
jira_service.sync_jira_to_aegis(db, link) jira_service.sync_jira_to_aegis(db, link)
# Call uow.commit()
uow.commit() uow.commit()
# Return {"message": "Sync completed", "jira_status": link.jira_status}
return {"message": "Sync completed", "jira_status": link.jira_status} return {"message": "Sync completed", "jira_status": link.jira_status}
# Apply the @router.delete decorator
@router.delete("/links/{link_id}", status_code=204) @router.delete("/links/{link_id}", status_code=204)
# Define function delete_link
def delete_link( def delete_link(
# Entry: link_id
link_id: UUID, link_id: UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> None:
"""Remove a Jira link.""" """Remove a Jira link."""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign link = jira_service.delete_link(db, link_id)
link = jira_service.delete_link(db, link_id) link = jira_service.delete_link(db, link_id)
# Call audit_service.log_action()
audit_service.log_action( audit_service.log_action(
db, db,
# Keyword argument: user_id
user_id=user.id, user_id=user.id,
# Keyword argument: action
action="jira_link_deleted", action="jira_link_deleted",
# Keyword argument: entity_type
entity_type="jira_link", entity_type="jira_link",
# Keyword argument: entity_id
entity_id=str(link_id), entity_id=str(link_id),
# Keyword argument: details
details={"jira_issue_key": link.jira_issue_key}, details={"jira_issue_key": link.jira_issue_key},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Apply the @router.post decorator
@router.post("/create-issue") @router.post("/create-issue")
# Define function create_issue_from_entity
def create_issue_from_entity( def create_issue_from_entity(
# Entry: entity_type
entity_type: JiraLinkEntityType, entity_type: JiraLinkEntityType,
# Entry: entity_id
entity_id: UUID, entity_id: UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> dict:
"""Auto-create a Jira issue from an Aegis entity and link them.""" """Auto-create a Jira issue from an Aegis entity and link them."""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign result = jira_service.create_issue_and_link(
result = jira_service.create_issue_and_link( result = jira_service.create_issue_and_link(
db, db,
# Keyword argument: entity_type
entity_type=entity_type, entity_type=entity_type,
# Keyword argument: entity_id
entity_id=entity_id, entity_id=entity_id,
# Keyword argument: created_by
created_by=user.id, created_by=user.id,
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return result
return result return result
+43 -6
View File
@@ -7,12 +7,22 @@ validation-rate endpoints for the Red/Blue workflow.
Thin HTTP adapter: delegates all data logic to metrics_query_service. Thin HTTP adapter: delegates all data logic to metrics_query_service.
""" """
# Import APIRouter, Depends from fastapi
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user from app.dependencies.auth
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.schemas.metrics
from app.schemas.metrics import ( from app.schemas.metrics import (
CoverageSummary, CoverageSummary,
RecentTestItem, RecentTestItem,
@@ -21,6 +31,8 @@ from app.schemas.metrics import (
TestPipelineCounts, TestPipelineCounts,
ValidationRate, ValidationRate,
) )
# Import from app.services.metrics_query_service
from app.services.metrics_query_service import ( from app.services.metrics_query_service import (
get_coverage_by_tactic, get_coverage_by_tactic,
get_coverage_summary, get_coverage_summary,
@@ -30,6 +42,7 @@ from app.services.metrics_query_service import (
get_validation_rate, get_validation_rate,
) )
# Assign router = APIRouter(prefix="/metrics", tags=["metrics"])
router = APIRouter(prefix="/metrics", tags=["metrics"]) router = APIRouter(prefix="/metrics", tags=["metrics"])
@@ -39,11 +52,15 @@ router = APIRouter(prefix="/metrics", tags=["metrics"])
@router.get("/summary", response_model=CoverageSummary) @router.get("/summary", response_model=CoverageSummary)
# Define function coverage_summary
def coverage_summary( def coverage_summary(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> CoverageSummary:
"""Return a global coverage summary across all techniques.""" """Return a global coverage summary across all techniques."""
# Return get_coverage_summary(db)
return get_coverage_summary(db) return get_coverage_summary(db)
@@ -53,11 +70,15 @@ def coverage_summary(
@router.get("/by-tactic", response_model=list[TacticCoverage]) @router.get("/by-tactic", response_model=list[TacticCoverage])
# Define function coverage_by_tactic
def coverage_by_tactic( def coverage_by_tactic(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[TacticCoverage]:
"""Return coverage breakdown grouped by tactic.""" """Return coverage breakdown grouped by tactic."""
# Return get_coverage_by_tactic(db)
return get_coverage_by_tactic(db) return get_coverage_by_tactic(db)
@@ -67,11 +88,15 @@ def coverage_by_tactic(
@router.get("/test-pipeline", response_model=TestPipelineCounts) @router.get("/test-pipeline", response_model=TestPipelineCounts)
# Define function test_pipeline
def test_pipeline( def test_pipeline(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> TestPipelineCounts:
"""Return how many tests are in each pipeline state.""" """Return how many tests are in each pipeline state."""
# Return get_test_pipeline_counts(db)
return get_test_pipeline_counts(db) return get_test_pipeline_counts(db)
@@ -81,11 +106,15 @@ def test_pipeline(
@router.get("/team-activity", response_model=list[TeamActivity]) @router.get("/team-activity", response_model=list[TeamActivity])
# Define function team_activity
def team_activity( def team_activity(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[TeamActivity]:
"""Return activity summary for Red and Blue teams.""" """Return activity summary for Red and Blue teams."""
# Return get_team_activity(db)
return get_team_activity(db) return get_team_activity(db)
@@ -95,11 +124,15 @@ def team_activity(
@router.get("/validation-rate", response_model=list[ValidationRate]) @router.get("/validation-rate", response_model=list[ValidationRate])
# Define function validation_rate
def validation_rate( def validation_rate(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[ValidationRate]:
"""Return approval and rejection rates for Red Lead and Blue Lead.""" """Return approval and rejection rates for Red Lead and Blue Lead."""
# Return get_validation_rate(db)
return get_validation_rate(db) return get_validation_rate(db)
@@ -109,9 +142,13 @@ def validation_rate(
@router.get("/recent-tests", response_model=list[RecentTestItem]) @router.get("/recent-tests", response_model=list[RecentTestItem])
# Define function recent_tests
def recent_tests( def recent_tests(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[RecentTestItem]:
"""Return the 10 most recently created tests.""" """Return the 10 most recently created tests."""
# Return get_recent_tests(db, limit=10)
return get_recent_tests(db, limit=10) return get_recent_tests(db, limit=10)
+49 -7
View File
@@ -8,23 +8,39 @@ PATCH /notifications/{id}/read — mark one notification as read
POST /notifications/read-all — mark all as read POST /notifications/read-all — mark all as read
""" """
# Import uuid
import uuid import uuid
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user from app.dependencies.auth
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import NotificationOut, UnreadCountOut from app.schemas.notification
from app.schemas.notification import NotificationOut, UnreadCountOut from app.schemas.notification import NotificationOut, UnreadCountOut
# Import from app.services.notification_service
from app.services.notification_service import ( from app.services.notification_service import (
list_notifications,
mark_as_read,
mark_all_as_read,
get_unread_count, get_unread_count,
list_notifications,
mark_all_as_read,
mark_as_read,
) )
# Assign router = APIRouter(prefix="/notifications", tags=["notifications"])
router = APIRouter(prefix="/notifications", tags=["notifications"]) router = APIRouter(prefix="/notifications", tags=["notifications"])
@@ -34,13 +50,19 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
@router.get("", response_model=list[NotificationOut]) @router.get("", response_model=list[NotificationOut])
# Define function list_notifications_endpoint
def list_notifications_endpoint( def list_notifications_endpoint(
# Entry: offset
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
# Entry: limit
limit: int = Query(20, ge=1, le=100), limit: int = Query(20, ge=1, le=100),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[NotificationOut]:
"""Return paginated notifications for the current user, newest first.""" """Return paginated notifications for the current user, newest first."""
# Return list_notifications(db, current_user.id, offset=offset, limit=limit)
return list_notifications(db, current_user.id, offset=offset, limit=limit) return list_notifications(db, current_user.id, offset=offset, limit=limit)
@@ -50,12 +72,17 @@ def list_notifications_endpoint(
@router.get("/unread-count", response_model=UnreadCountOut) @router.get("/unread-count", response_model=UnreadCountOut)
# Define function unread_count
def unread_count( def unread_count(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> UnreadCountOut:
"""Return the number of unread notifications for the current user.""" """Return the number of unread notifications for the current user."""
# Assign count = get_unread_count(db, current_user.id)
count = get_unread_count(db, current_user.id) count = get_unread_count(db, current_user.id)
# Return UnreadCountOut(unread_count=count)
return UnreadCountOut(unread_count=count) return UnreadCountOut(unread_count=count)
@@ -65,15 +92,23 @@ def unread_count(
@router.patch("/{notification_id}/read", response_model=NotificationOut) @router.patch("/{notification_id}/read", response_model=NotificationOut)
# Define function read_notification
def read_notification( def read_notification(
# Entry: notification_id
notification_id: uuid.UUID, notification_id: uuid.UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> NotificationOut:
"""Mark a single notification as read.""" """Mark a single notification as read."""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign notif = mark_as_read(db, notification_id, current_user.id)
notif = mark_as_read(db, notification_id, current_user.id) notif = mark_as_read(db, notification_id, current_user.id)
# Call uow.commit()
uow.commit() uow.commit()
# Return notif
return notif return notif
@@ -83,12 +118,19 @@ def read_notification(
@router.post("/read-all") @router.post("/read-all")
# Define function read_all_notifications
def read_all_notifications( def read_all_notifications(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Mark all notifications for the current user as read.""" """Mark all notifications for the current user as read."""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign count = mark_all_as_read(db, current_user.id)
count = mark_all_as_read(db, current_user.id) count = mark_all_as_read(db, current_user.id)
# Call uow.commit()
uow.commit() uow.commit()
# Return {"detail": f"Marked {count} notifications as read"}
return {"detail": f"Marked {count} notifications as read"} return {"detail": f"Marked {count} notifications as read"}
+29 -5
View File
@@ -4,18 +4,28 @@ Provides operational KPIs for security teams with trend analysis
and team-level breakdowns. and team-level breakdowns.
""" """
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user from app.dependencies.auth
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.services.operational_metrics_service
from app.services.operational_metrics_service import ( from app.services.operational_metrics_service import (
get_all_operational_metrics,
get_operational_trend,
get_metrics_by_team, get_metrics_by_team,
get_operational_trend,
) )
# Assign router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"]) router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
@@ -23,13 +33,18 @@ router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
@router.get("") @router.get("")
# Define function operational_metrics
def operational_metrics( def operational_metrics(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min.""" """Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min."""
# Import get_operational_metrics_cached from app.services.score_cache
from app.services.score_cache import get_operational_metrics_cached from app.services.score_cache import get_operational_metrics_cached
# Return get_operational_metrics_cached(db)
return get_operational_metrics_cached(db) return get_operational_metrics_cached(db)
@@ -37,12 +52,17 @@ def operational_metrics(
@router.get("/trend") @router.get("/trend")
# Define function operational_trend
def operational_trend( def operational_trend(
# Entry: period
period: str = Query("90d", pattern="^(30d|90d|1y)$"), period: str = Query("90d", pattern="^(30d|90d|1y)$"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get weekly trend data for operational metrics.""" """Get weekly trend data for operational metrics."""
# Return get_operational_trend(db, period)
return get_operational_trend(db, period) return get_operational_trend(db, period)
@@ -50,9 +70,13 @@ def operational_trend(
@router.get("/by-team") @router.get("/by-team")
# Define function metrics_by_team
def metrics_by_team( def metrics_by_team(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get metrics broken down by Red Team vs Blue Team.""" """Get metrics broken down by Red Team vs Blue Team."""
# Return get_metrics_by_team(db)
return get_metrics_by_team(db) return get_metrics_by_team(db)
+162 -15
View File
@@ -1,26 +1,44 @@
"""OSINT enrichment endpoints — view, review, and trigger enrichment of """OSINT enrichment endpoints — view, review, and trigger enrichment of OSINT items linked to techniques."""
OSINT items (CVEs, advisories, etc.) linked to techniques.
"""
# Import UUID from uuid
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query, HTTPException, status # Import APIRouter, Depends, HTTPException, Query, status from fastapi
from fastapi import APIRouter, Depends, HTTPException, Query, status
# Import BaseModel from pydantic
from pydantic import BaseModel from pydantic import BaseModel
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_any_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.services.osint_enrichment_service
from app.services.osint_enrichment_service import ( from app.services.osint_enrichment_service import (
enrich_technique_with_cves, enrich_technique_with_cves,
get_osint_items_for_technique, get_osint_items_for_technique,
get_osint_summary, get_osint_summary,
get_technique_or_raise, get_technique_or_raise,
list_osint_items as service_list_osint_items,
mark_osint_reviewed, mark_osint_reviewed,
) )
# Import from app.services.osint_enrichment_service
from app.services.osint_enrichment_service import (
list_osint_items as service_list_osint_items,
)
# Assign router = APIRouter(prefix="/osint", tags=["osint"])
router = APIRouter(prefix="/osint", tags=["osint"]) router = APIRouter(prefix="/osint", tags=["osint"])
@@ -28,18 +46,34 @@ router = APIRouter(prefix="/osint", tags=["osint"])
class OsintItemOut(BaseModel): class OsintItemOut(BaseModel):
"""Serialized OSINT item returned by the API."""
# id: str
id: str id: str
# technique_id: str
technique_id: str technique_id: str
# source_type: str
source_type: str source_type: str
# source_url: str
source_url: str source_url: str
# title: str
title: str title: str
# description: str | None
description: str | None description: str | None
# severity: str | None
severity: str | None severity: str | None
# discovered_at: str | None
discovered_at: str | None discovered_at: str | None
# reviewed: bool
reviewed: bool reviewed: bool
# Assign metadata_ = None
metadata_: dict | None = None metadata_: dict | None = None
# Define class Config
class Config: class Config:
"""ORM mode configuration for SQLAlchemy model mapping."""
# Assign from_attributes = True
from_attributes = True from_attributes = True
@@ -47,94 +81,207 @@ class OsintItemOut(BaseModel):
@router.get("/items") @router.get("/items")
# Define function list_osint_items
def list_osint_items( def list_osint_items(
# Entry: technique_id
technique_id: UUID | None = Query(None), technique_id: UUID | None = Query(None),
# Entry: source_type
source_type: str | None = Query(None), source_type: str | None = Query(None),
# Entry: reviewed
reviewed: bool | None = Query(None), reviewed: bool | None = Query(None),
# Entry: offset
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
# Entry: limit
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""List OSINT items with optional filters.""" """List OSINT items with optional filters.
Args:
technique_id (UUID | None): Filter by the technique's UUID.
source_type (str | None): Filter by source type (e.g. ``nvd_cve``, ``advisory``).
reviewed (bool | None): Filter by review status; ``None`` returns all.
offset (int): Number of records to skip for pagination.
limit (int): Maximum number of records to return.
db (Session): SQLAlchemy database session.
user (User): Authenticated user making the request.
Returns:
list: Serialised list of OSINT item dicts matching the filters.
"""
# Return service_list_osint_items(
return service_list_osint_items( return service_list_osint_items(
db, db,
# Keyword argument: technique_id
technique_id=technique_id, technique_id=technique_id,
# Keyword argument: source_type
source_type=source_type, source_type=source_type,
# Keyword argument: reviewed
reviewed=reviewed, reviewed=reviewed,
# Keyword argument: offset
offset=offset, offset=offset,
# Keyword argument: limit
limit=limit, limit=limit,
) )
# Apply the @router.get decorator
@router.get("/summary") @router.get("/summary")
# Define function osint_summary
def osint_summary( def osint_summary(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> dict:
"""Summary statistics for OSINT items.""" """Return summary statistics for OSINT items.
Args:
db (Session): SQLAlchemy database session.
user (User): Authenticated user making the request.
Returns:
dict: Counts of total, reviewed, and unreviewed items broken down by source type.
"""
# Return get_osint_summary(db)
return get_osint_summary(db) return get_osint_summary(db)
# Apply the @router.post decorator
@router.post("/items/{item_id}/review") @router.post("/items/{item_id}/review")
# Define function review_osint_item
def review_osint_item( def review_osint_item(
# Entry: item_id
item_id: UUID, item_id: UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> dict:
"""Mark an OSINT item as reviewed.""" """Mark an OSINT item as reviewed.
Args:
item_id (UUID): Primary key of the OSINT item to mark reviewed.
db (Session): SQLAlchemy database session.
user (User): Authenticated user performing the review.
Returns:
dict: Contains ``id`` (str) and ``reviewed`` (bool ``True``).
"""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign item = mark_osint_reviewed(db, str(item_id))
item = mark_osint_reviewed(db, str(item_id)) item = mark_osint_reviewed(db, str(item_id))
# Check: not item
if not item: if not item:
# Raise HTTPException
raise HTTPException( raise HTTPException(
# Keyword argument: status_code
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
# Keyword argument: detail
detail="OSINT item not found", detail="OSINT item not found",
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return {"id": str(item.id), "reviewed": True}
return {"id": str(item.id), "reviewed": True} return {"id": str(item.id), "reviewed": True}
# Apply the @router.post decorator
@router.post("/enrich/{technique_id}") @router.post("/enrich/{technique_id}")
# Define function trigger_technique_enrichment
def trigger_technique_enrichment( def trigger_technique_enrichment(
# Entry: technique_id
technique_id: UUID, technique_id: UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(require_any_role("red_lead", "blue_lead")), user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Manually trigger OSINT enrichment for a single technique.""" """Manually trigger OSINT enrichment for a single technique.
Args:
technique_id (UUID): Primary key of the technique to enrich.
db (Session): SQLAlchemy database session.
user (User): Authenticated red_lead or blue_lead requesting enrichment.
Returns:
dict: Contains ``technique_id`` (str), ``mitre_id`` (str), and ``new_items`` (int).
"""
# Assign technique = get_technique_or_raise(db, technique_id)
technique = get_technique_or_raise(db, technique_id) technique = get_technique_or_raise(db, technique_id)
# Assign count = enrich_technique_with_cves(db, technique)
count = enrich_technique_with_cves(db, technique) count = enrich_technique_with_cves(db, technique)
# Return {
return { return {
# Literal argument value
"technique_id": str(technique.id), "technique_id": str(technique.id),
# Literal argument value
"mitre_id": technique.mitre_id, "mitre_id": technique.mitre_id,
# Literal argument value
"new_items": count, "new_items": count,
} }
# Apply the @router.get decorator
@router.get("/technique/{technique_id}") @router.get("/technique/{technique_id}")
# Define function get_technique_osint
def get_technique_osint( def get_technique_osint(
# Entry: technique_id
technique_id: UUID, technique_id: UUID,
# Entry: source_type
source_type: str | None = Query(None), source_type: str | None = Query(None),
# Entry: reviewed
reviewed: bool | None = Query(None), reviewed: bool | None = Query(None),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Get all OSINT items for a specific technique.""" """Get all OSINT items for a specific technique.
Args:
technique_id (UUID): Primary key of the technique.
source_type (str | None): Filter by source type (e.g. ``nvd_cve``).
reviewed (bool | None): Filter by review status; ``None`` returns all.
db (Session): SQLAlchemy database session.
user (User): Authenticated user making the request.
Returns:
list: Dicts with OSINT item fields including source URL, severity, and review status.
"""
# Assign items = get_osint_items_for_technique(
items = get_osint_items_for_technique( items = get_osint_items_for_technique(
db, db,
str(technique_id), str(technique_id),
# Keyword argument: source_type
source_type=source_type, source_type=source_type,
# Keyword argument: reviewed
reviewed=reviewed, reviewed=reviewed,
) )
# Return [
return [ return [
{ {
# Literal argument value
"id": str(item.id), "id": str(item.id),
# Literal argument value
"source_type": item.source_type, "source_type": item.source_type,
# Literal argument value
"source_url": item.source_url, "source_url": item.source_url,
# Literal argument value
"title": item.title, "title": item.title,
# Literal argument value
"description": item.description, "description": item.description,
# Literal argument value
"severity": item.severity, "severity": item.severity,
# Literal argument value
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None, "discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
# Literal argument value
"reviewed": item.reviewed, "reviewed": item.reviewed,
# Literal argument value
"metadata": item.metadata_, "metadata": item.metadata_,
} }
for item in items for item in items
+83 -6
View File
@@ -1,118 +1,195 @@
"""Professional report generation endpoints — PDF, DOCX, HTML output.""" """Professional report generation endpoints — PDF, DOCX, HTML output."""
# Import UUID from uuid
from uuid import UUID from uuid import UUID
# Import APIRouter, Depends, Query, Request from fastapi
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
# Import FileResponse from fastapi.responses
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_any_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
from app.models.user import User
# Import limiter from app.limiter
from app.limiter import limiter from app.limiter import limiter
# Import User from app.models.user
from app.models.user import User
# Import report_generation_service from app.services
from app.services import report_generation_service from app.services import report_generation_service
# Assign router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"]) router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
# Assign _MEDIA_TYPES = {
_MEDIA_TYPES = { _MEDIA_TYPES = {
# Literal argument value
"pdf": "application/pdf", "pdf": "application/pdf",
# Literal argument value
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
# Literal argument value
"html": "text/html", "html": "text/html",
} }
# Apply the @router.get decorator
@router.get("/purple-campaign/{campaign_id}") @router.get("/purple-campaign/{campaign_id}")
# Apply the @limiter.limit decorator
@limiter.limit("5/minute") @limiter.limit("5/minute")
# Define function generate_purple_report
def generate_purple_report( def generate_purple_report(
# Entry: request
request: Request, request: Request,
# Entry: campaign_id
campaign_id: UUID, campaign_id: UUID,
# Entry: format
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
): ) -> FileResponse:
"""Generate a Purple Team campaign assessment report.""" """Generate a Purple Team campaign assessment report."""
# Assign filepath = report_generation_service.generate_purple_campaign_report(
filepath = report_generation_service.generate_purple_campaign_report( filepath = report_generation_service.generate_purple_campaign_report(
db, str(campaign_id), output_format=format, db, str(campaign_id), output_format=format,
) )
# Return FileResponse(
return FileResponse( return FileResponse(
filepath, filepath,
# Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename
filename=f"purple_report.{format}", filename=f"purple_report.{format}",
) )
# Apply the @router.get decorator
@router.get("/coverage-summary") @router.get("/coverage-summary")
# Apply the @limiter.limit decorator
@limiter.limit("5/minute") @limiter.limit("5/minute")
# Define function generate_coverage_report
def generate_coverage_report( def generate_coverage_report(
# Entry: request
request: Request, request: Request,
# Entry: format
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
): ) -> FileResponse:
"""Generate an organization-wide MITRE ATT&CK coverage report.""" """Generate an organization-wide MITRE ATT&CK coverage report."""
# Assign filepath = report_generation_service.generate_coverage_report(
filepath = report_generation_service.generate_coverage_report( filepath = report_generation_service.generate_coverage_report(
db, output_format=format, db, output_format=format,
) )
# Return FileResponse(
return FileResponse( return FileResponse(
filepath, filepath,
# Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename
filename=f"coverage_report.{format}", filename=f"coverage_report.{format}",
) )
# Apply the @router.get decorator
@router.get("/executive-summary") @router.get("/executive-summary")
# Apply the @limiter.limit decorator
@limiter.limit("5/minute") @limiter.limit("5/minute")
# Define function generate_executive_report
def generate_executive_report( def generate_executive_report(
# Entry: request
request: Request, request: Request,
# Entry: format
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
): ) -> FileResponse:
"""Generate an executive security summary report.""" """Generate an executive security summary report."""
# Assign filepath = report_generation_service.generate_executive_summary(
filepath = report_generation_service.generate_executive_summary( filepath = report_generation_service.generate_executive_summary(
db, output_format=format, db, output_format=format,
) )
# Return FileResponse(
return FileResponse( return FileResponse(
filepath, filepath,
# Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename
filename=f"executive_summary.{format}", filename=f"executive_summary.{format}",
) )
# Apply the @router.get decorator
@router.get("/quarterly-summary") @router.get("/quarterly-summary")
# Apply the @limiter.limit decorator
@limiter.limit("5/minute") @limiter.limit("5/minute")
# Define function generate_quarterly_report
def generate_quarterly_report( def generate_quarterly_report(
# Entry: request
request: Request, request: Request,
# Entry: format
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
): ) -> FileResponse:
"""Generate a quarterly security summary report.""" """Generate a quarterly security summary report."""
# Assign filepath = report_generation_service.generate_quarterly_summary(
filepath = report_generation_service.generate_quarterly_summary( filepath = report_generation_service.generate_quarterly_summary(
db, output_format=format, db, output_format=format,
) )
# Return FileResponse(
return FileResponse( return FileResponse(
filepath, filepath,
# Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename
filename=f"quarterly_summary.{format}", filename=f"quarterly_summary.{format}",
) )
# Apply the @router.get decorator
@router.get("/technique/{technique_id}") @router.get("/technique/{technique_id}")
# Apply the @limiter.limit decorator
@limiter.limit("5/minute") @limiter.limit("5/minute")
# Define function generate_technique_report
def generate_technique_report( def generate_technique_report(
# Entry: request
request: Request, request: Request,
# Entry: technique_id
technique_id: UUID, technique_id: UUID,
# Entry: format
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> FileResponse:
"""Generate a detailed report for one MITRE technique.""" """Generate a detailed report for one MITRE technique."""
# Assign filepath = report_generation_service.generate_technique_detail_report(
filepath = report_generation_service.generate_technique_detail_report( filepath = report_generation_service.generate_technique_detail_report(
db, str(technique_id), output_format=format, db, str(technique_id), output_format=format,
) )
# Return FileResponse(
return FileResponse( return FileResponse(
filepath, filepath,
# Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename
filename=f"technique_{technique_id}.{format}", filename=f"technique_{technique_id}.{format}",
) )
+61 -4
View File
@@ -10,18 +10,37 @@ GET /reports/test-results — test results report (JSON)
GET /reports/remediation-status — remediation status report (JSON) GET /reports/remediation-status — remediation status report (JSON)
""" """
# Import csv
import csv import csv
# Import io
import io import io
# Import datetime from datetime
from datetime import datetime from datetime import datetime
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import StreamingResponse from fastapi.responses
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user from app.dependencies.auth
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.services.coverage_report_service
from app.services.coverage_report_service import ( from app.services.coverage_report_service import (
build_coverage_csv_rows, build_coverage_csv_rows,
build_coverage_summary, build_coverage_summary,
@@ -29,61 +48,99 @@ from app.services.coverage_report_service import (
build_test_results_report, build_test_results_report,
) )
# Assign router = APIRouter(prefix="/reports", tags=["reports"])
router = APIRouter(prefix="/reports", tags=["reports"]) router = APIRouter(prefix="/reports", tags=["reports"])
# Apply the @router.get decorator
@router.get("/coverage-summary") @router.get("/coverage-summary")
# Define function coverage_summary
def coverage_summary( def coverage_summary(
# Entry: tactic
tactic: Optional[str] = Query(None, description="Filter by tactic"), tactic: Optional[str] = Query(None, description="Filter by tactic"),
# Entry: platform
platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"), platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Full coverage report as JSON — technique-by-technique with test counts.""" """Full coverage report as JSON — technique-by-technique with test counts."""
# Return build_coverage_summary(db, tactic=tactic, platform=platform)
return build_coverage_summary(db, tactic=tactic, platform=platform) return build_coverage_summary(db, tactic=tactic, platform=platform)
# Apply the @router.get decorator
@router.get("/coverage-csv") @router.get("/coverage-csv")
# Define function coverage_csv
def coverage_csv( def coverage_csv(
# Entry: tactic
tactic: Optional[str] = Query(None), tactic: Optional[str] = Query(None),
# Entry: platform
platform: Optional[str] = Query(None), platform: Optional[str] = Query(None),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> StreamingResponse:
"""Export coverage as a downloadable CSV.""" """Export coverage as a downloadable CSV."""
# Assign rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform) rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
# Assign output = io.StringIO()
output = io.StringIO() output = io.StringIO()
# Assign writer = csv.writer(output)
writer = csv.writer(output) writer = csv.writer(output)
# Iterate over rows
for row in rows: for row in rows:
# Call writer.writerow()
writer.writerow(row) writer.writerow(row)
# Call output.seek()
output.seek(0) output.seek(0)
# Assign filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv" filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
# Return StreamingResponse(
return StreamingResponse( return StreamingResponse(
iter([output.getvalue()]), iter([output.getvalue()]),
# Keyword argument: media_type
media_type="text/csv", media_type="text/csv",
# Keyword argument: headers
headers={"Content-Disposition": f"attachment; filename={filename}"}, headers={"Content-Disposition": f"attachment; filename={filename}"},
) )
# Apply the @router.get decorator
@router.get("/test-results") @router.get("/test-results")
# Define function test_results
def test_results( def test_results(
# Entry: state
state: Optional[str] = Query(None), state: Optional[str] = Query(None),
# Entry: date_from
date_from: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"), date_from: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
# Entry: date_to
date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"), date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Report of test results with optional filters.""" """Report of test results with optional filters."""
# Return build_test_results_report(db, state=state, date_from=date_from, dat...
return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to) return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
# Apply the @router.get decorator
@router.get("/remediation-status") @router.get("/remediation-status")
# Define function remediation_status
def remediation_status( def remediation_status(
# Entry: status
status: Optional[str] = Query(None, description="Filter by remediation status"), status: Optional[str] = Query(None, description="Filter by remediation status"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Report of remediation status across all tests.""" """Report of remediation status across all tests."""
# Return build_remediation_status_report(db, status=status)
return build_remediation_status_report(db, status=status) return build_remediation_status_report(db, status=status)
+154 -20
View File
@@ -3,28 +3,45 @@
Provides granular scoring with breakdowns and configurable weights. Provides granular scoring with breakdowns and configurable weights.
""" """
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import BaseModel from pydantic
from pydantic import BaseModel from pydantic import BaseModel
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_role from app.dependencies.auth import get_current_user, require_role
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user
from app.models.user import User from app.models.user import User
from app.services.scoring_service import (
score_technique_by_mitre_id, # Import from app.services.scoring_config_service
score_actor_by_id,
calculate_tactic_score,
calculate_organization_score,
get_score_history,
)
from app.services.scoring_config_service import ( from app.services.scoring_config_service import (
get_weights_dict, get_weights_dict,
update_scoring_weights, update_scoring_weights,
) )
# Import from app.services.scoring_service
from app.services.scoring_service import (
calculate_tactic_score,
get_score_history,
score_actor_by_id,
score_technique_by_mitre_id,
)
# Assign router = APIRouter(prefix="/scores", tags=["scores"])
router = APIRouter(prefix="/scores", tags=["scores"]) router = APIRouter(prefix="/scores", tags=["scores"])
@@ -32,12 +49,26 @@ router = APIRouter(prefix="/scores", tags=["scores"])
@router.get("/technique/{mitre_id}") @router.get("/technique/{mitre_id}")
# Define function score_technique
def score_technique( def score_technique(
# Entry: mitre_id
mitre_id: str, mitre_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get detailed score with breakdown for a specific technique.""" """Get detailed score with breakdown for a specific technique.
Args:
mitre_id (str): MITRE ATT&CK technique ID (e.g. ``T1059``).
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Score value and component breakdown (tests, detection rules, recency, etc.).
"""
# Return score_technique_by_mitre_id(db, mitre_id)
return score_technique_by_mitre_id(db, mitre_id) return score_technique_by_mitre_id(db, mitre_id)
@@ -45,12 +76,26 @@ def score_technique(
@router.get("/tactic/{tactic}") @router.get("/tactic/{tactic}")
# Define function score_tactic
def score_tactic( def score_tactic(
# Entry: tactic
tactic: str, tactic: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get average score for a tactic.""" """Get average score for a tactic.
Args:
tactic (str): MITRE ATT&CK tactic slug (e.g. ``initial-access``).
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Average score and per-technique breakdown for the tactic.
"""
# Return calculate_tactic_score(tactic, db)
return calculate_tactic_score(tactic, db) return calculate_tactic_score(tactic, db)
@@ -58,12 +103,26 @@ def score_tactic(
@router.get("/threat-actor/{actor_id}") @router.get("/threat-actor/{actor_id}")
# Define function score_threat_actor
def score_threat_actor( def score_threat_actor(
# Entry: actor_id
actor_id: str, actor_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get coverage score against a specific threat actor.""" """Get coverage score against a specific threat actor.
Args:
actor_id (str): UUID string of the threat actor to score against.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Coverage score and per-technique breakdown for the threat actor.
"""
# Return score_actor_by_id(db, actor_id)
return score_actor_by_id(db, actor_id) return score_actor_by_id(db, actor_id)
@@ -71,13 +130,26 @@ def score_threat_actor(
@router.get("/organization") @router.get("/organization")
# Define function score_organization
def score_organization( def score_organization(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get the overall organization security score (cached for 5 min).""" """Get the overall organization security score (cached for 5 min).
Args:
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Aggregate organization score with tactic-level breakdowns.
"""
# Import get_organization_score_cached from app.services.score_cache
from app.services.score_cache import get_organization_score_cached from app.services.score_cache import get_organization_score_cached
# Return get_organization_score_cached(db)
return get_organization_score_cached(db) return get_organization_score_cached(db)
@@ -85,12 +157,26 @@ def score_organization(
@router.get("/history") @router.get("/history")
# Define function score_history
def score_history( def score_history(
# Entry: period
period: str = Query("90d", pattern="^(30d|90d|1y)$"), period: str = Query("90d", pattern="^(30d|90d|1y)$"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get historical score data points (weekly).""" """Get historical score data points (weekly).
Args:
period (str): Time window for history — one of ``30d``, ``90d``, or ``1y``.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
dict: Weekly score data points for the requested period.
"""
# Return get_score_history(db, period)
return get_score_history(db, period) return get_score_history(db, period)
@@ -98,11 +184,23 @@ def score_history(
@router.get("/config") @router.get("/config")
# Define function get_scoring_config
def get_scoring_config( def get_scoring_config(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Get current scoring weights (admin only).""" """Get current scoring weights (admin only).
Args:
db (Session): SQLAlchemy database session.
current_user (User): Authenticated admin user.
Returns:
dict: Current weight values for each scoring component.
"""
# Return get_weights_dict(db)
return get_weights_dict(db) return get_weights_dict(db)
@@ -110,41 +208,77 @@ def get_scoring_config(
class ScoringConfigUpdate(BaseModel): class ScoringConfigUpdate(BaseModel):
"""Partial update payload for the scoring weight configuration."""
# Assign tests = None
tests: Optional[float] = None tests: Optional[float] = None
# Assign detection_rules = None
detection_rules: Optional[float] = None detection_rules: Optional[float] = None
# Assign d3fend = None
d3fend: Optional[float] = None d3fend: Optional[float] = None
# Assign recency = None
recency: Optional[float] = None recency: Optional[float] = None
# Assign severity = None
severity: Optional[float] = None severity: Optional[float] = None
# Assign freshness = None
freshness: Optional[float] = None freshness: Optional[float] = None
# Assign platform_diversity = None
platform_diversity: Optional[float] = None platform_diversity: Optional[float] = None
# Apply the @router.patch decorator
@router.patch("/config") @router.patch("/config")
# Define function update_scoring_config
def update_scoring_config( def update_scoring_config(
# Entry: payload
payload: ScoringConfigUpdate, payload: ScoringConfigUpdate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Update scoring weights (admin only). """Update scoring weights (admin only).
Weights are persisted in the database and survive restarts. Weights are persisted in the database and survive restarts.
Validation enforces that all weights are non-negative and sum to 100. Validation enforces that all weights are non-negative and sum to 100.
Args:
payload (ScoringConfigUpdate): Partial weight update; only set fields are changed.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated admin user.
Returns:
dict: Confirmation message plus the full updated weight configuration.
""" """
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign result = update_scoring_weights(
result = update_scoring_weights( result = update_scoring_weights(
db, db,
# Keyword argument: tests
tests=payload.tests, tests=payload.tests,
# Keyword argument: detection_rules
detection_rules=payload.detection_rules, detection_rules=payload.detection_rules,
# Keyword argument: d3fend
d3fend=payload.d3fend, d3fend=payload.d3fend,
# Keyword argument: recency
recency=payload.recency, recency=payload.recency,
# Keyword argument: severity
severity=payload.severity, severity=payload.severity,
# Keyword argument: freshness
freshness=payload.freshness, freshness=payload.freshness,
# Keyword argument: platform_diversity
platform_diversity=payload.platform_diversity, platform_diversity=payload.platform_diversity,
# Keyword argument: updated_by
updated_by=current_user.id, updated_by=current_user.id,
) )
# Call uow.commit()
uow.commit() uow.commit()
# Import invalidate from app.services.score_cache
from app.services.score_cache import invalidate from app.services.score_cache import invalidate
# Call invalidate()
invalidate() invalidate()
# Return {"message": "Scoring config updated", **result}
return {"message": "Scoring config updated", **result} return {"message": "Scoring config updated", **result}
+104 -17
View File
@@ -4,40 +4,71 @@ Provides periodic and manual snapshots of the organisation's coverage
state, plus temporal comparison between any two snapshots. state, plus temporal comparison between any two snapshots.
""" """
# Import logging
import logging import logging
# Import uuid
import uuid import uuid
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import BaseModel from pydantic
from pydantic import BaseModel from pydantic import BaseModel
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_any_role, require_role from app.dependencies.auth import get_current_user, require_any_role, require_role
# Import BusinessRuleViolation from app.domain.errors
from app.domain.errors import BusinessRuleViolation from app.domain.errors import BusinessRuleViolation
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user
from app.models.user import User from app.models.user import User
from app.services.snapshot_service import (
create_snapshot, # Import log_action from app.services.audit_service
compare_snapshots,
cleanup_old_snapshots,
get_coverage_evolution,
serialize_snapshot_summary,
list_snapshots as list_snapshots_svc,
get_snapshot_or_raise,
get_snapshot_detail,
delete_snapshot,
)
from app.services.audit_service import log_action from app.services.audit_service import log_action
# Import from app.services.snapshot_service
from app.services.snapshot_service import (
compare_snapshots,
create_snapshot,
delete_snapshot,
get_coverage_evolution,
get_snapshot_detail,
get_snapshot_or_raise,
serialize_snapshot_summary,
)
# Import from app.services.snapshot_service
from app.services.snapshot_service import (
list_snapshots as list_snapshots_svc,
)
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Assign router = APIRouter(prefix="/snapshots", tags=["snapshots"])
router = APIRouter(prefix="/snapshots", tags=["snapshots"]) router = APIRouter(prefix="/snapshots", tags=["snapshots"])
# ── Pydantic schemas ───────────────────────────────────────────────── # ── Pydantic schemas ─────────────────────────────────────────────────
class SnapshotCreate(BaseModel): class SnapshotCreate(BaseModel):
"""Payload for creating a new coverage snapshot."""
# Assign name = None
name: Optional[str] = None name: Optional[str] = None
@@ -46,13 +77,19 @@ class SnapshotCreate(BaseModel):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("") @router.get("")
# Define function list_snapshots
def list_snapshots( def list_snapshots(
# Entry: offset
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
# Entry: limit
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List coverage snapshots ordered by creation date (newest first).""" """List coverage snapshots ordered by creation date (newest first)."""
# Return list_snapshots_svc(db, offset=offset, limit=limit)
return list_snapshots_svc(db, offset=offset, limit=limit) return list_snapshots_svc(db, offset=offset, limit=limit)
@@ -61,25 +98,39 @@ def list_snapshots(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("", status_code=201) @router.post("", status_code=201)
# Define function create_snapshot_endpoint
def create_snapshot_endpoint( def create_snapshot_endpoint(
# Entry: payload
payload: SnapshotCreate, payload: SnapshotCreate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")), current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
): ) -> dict:
"""Create a manual coverage snapshot with an optional name.""" """Create a manual coverage snapshot with an optional name."""
# Assign snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id) snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="create_snapshot", action="create_snapshot",
# Keyword argument: entity_type
entity_type="snapshot", entity_type="snapshot",
# Keyword argument: entity_id
entity_id=snapshot.id, entity_id=snapshot.id,
# Keyword argument: details
details={"name": snapshot.name, "score": snapshot.organization_score}, details={"name": snapshot.name, "score": snapshot.organization_score},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return serialize_snapshot_summary(snapshot)
return serialize_snapshot_summary(snapshot) return serialize_snapshot_summary(snapshot)
@@ -89,12 +140,17 @@ def create_snapshot_endpoint(
@router.get("/evolution") @router.get("/evolution")
# Define function coverage_evolution
def coverage_evolution( def coverage_evolution(
# Entry: months
months: int = Query(12, ge=1, le=36), months: int = Query(12, ge=1, le=36),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return coverage snapshots for trend charts (last *months* months).""" """Return coverage snapshots for trend charts (last *months* months)."""
# Return get_coverage_evolution(db, months=months)
return get_coverage_evolution(db, months=months) return get_coverage_evolution(db, months=months)
@@ -103,19 +159,30 @@ def coverage_evolution(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/compare") @router.get("/compare")
# Define function compare_snapshots_endpoint
def compare_snapshots_endpoint( def compare_snapshots_endpoint(
# Entry: a
a: str = Query(..., description="Snapshot A ID"), a: str = Query(..., description="Snapshot A ID"),
# Entry: b
b: str = Query(..., description="Snapshot B ID"), b: str = Query(..., description="Snapshot B ID"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Compare two snapshots showing improved, worsened, and unchanged techniques.""" """Compare two snapshots showing improved, worsened, and unchanged techniques."""
# Attempt the following; catch errors below
try: try:
# Assign a_id = uuid.UUID(a)
a_id = uuid.UUID(a) a_id = uuid.UUID(a)
# Assign b_id = uuid.UUID(b)
b_id = uuid.UUID(b) b_id = uuid.UUID(b)
# Handle ValueError
except ValueError: except ValueError:
# Raise BusinessRuleViolation
raise BusinessRuleViolation("Invalid snapshot ID format") raise BusinessRuleViolation("Invalid snapshot ID format")
# Return compare_snapshots(db, a_id, b_id)
return compare_snapshots(db, a_id, b_id) return compare_snapshots(db, a_id, b_id)
@@ -124,12 +191,17 @@ def compare_snapshots_endpoint(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/{snapshot_id}") @router.get("/{snapshot_id}")
# Define function get_snapshot
def get_snapshot( def get_snapshot(
# Entry: snapshot_id
snapshot_id: str, snapshot_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get detailed snapshot information including per-technique states.""" """Get detailed snapshot information including per-technique states."""
# Return get_snapshot_detail(db, snapshot_id)
return get_snapshot_detail(db, snapshot_id) return get_snapshot_detail(db, snapshot_id)
@@ -138,24 +210,39 @@ def get_snapshot(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.delete("/{snapshot_id}") @router.delete("/{snapshot_id}")
# Define function delete_snapshot_endpoint
def delete_snapshot_endpoint( def delete_snapshot_endpoint(
# Entry: snapshot_id
snapshot_id: str, snapshot_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Delete a snapshot (admin only).""" """Delete a snapshot (admin only)."""
# Assign snapshot = get_snapshot_or_raise(db, snapshot_id)
snapshot = get_snapshot_or_raise(db, snapshot_id) snapshot = get_snapshot_or_raise(db, snapshot_id)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="delete_snapshot", action="delete_snapshot",
# Keyword argument: entity_type
entity_type="snapshot", entity_type="snapshot",
# Keyword argument: entity_id
entity_id=snapshot.id, entity_id=snapshot.id,
# Keyword argument: details
details={"name": snapshot.name}, details={"name": snapshot.name},
) )
# Call delete_snapshot()
delete_snapshot(db, snapshot_id) delete_snapshot(db, snapshot_id)
# Call uow.commit()
uow.commit() uow.commit()
# Return {"detail": "Snapshot deleted"}
return {"detail": "Snapshot deleted"} return {"detail": "Snapshot deleted"}
+75 -8
View File
@@ -5,32 +5,59 @@ ATT&CK synchronisation, intel scanning, Atomic Red Team import, and
scheduler health introspection. scheduler health introspection.
""" """
# Import logging
import logging import logging
# Import APIRouter, Depends, Request from fastapi
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import require_role from app.dependencies.auth
from app.dependencies.auth import require_role from app.dependencies.auth import require_role
from app.models.user import User
from app.services.mitre_sync_service import sync_mitre # Import scheduler from app.jobs.mitre_sync_job
from app.services.intel_service import scan_intel
from app.services.atomic_import_service import import_atomic_red_team
from app.jobs.mitre_sync_job import scheduler from app.jobs.mitre_sync_job import scheduler
# Import limiter from app.limiter
from app.limiter import limiter from app.limiter import limiter
# Import User from app.models.user
from app.models.user import User
# Import import_atomic_red_team from app.services.atomic_import_service
from app.services.atomic_import_service import import_atomic_red_team
# Import scan_intel from app.services.intel_service
from app.services.intel_service import scan_intel
# Import sync_mitre from app.services.mitre_sync_service
from app.services.mitre_sync_service import sync_mitre
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Assign router = APIRouter(prefix="/system", tags=["system"])
router = APIRouter(prefix="/system", tags=["system"]) router = APIRouter(prefix="/system", tags=["system"])
# Apply the @router.post decorator
@router.post("/sync-mitre") @router.post("/sync-mitre")
# Apply the @limiter.limit decorator
@limiter.limit("2/hour") @limiter.limit("2/hour")
# Define function trigger_mitre_sync
def trigger_mitre_sync( def trigger_mitre_sync(
# Entry: request
request: Request, request: Request,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Manually trigger a MITRE ATT&CK synchronisation. """Manually trigger a MITRE ATT&CK synchronisation.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -38,19 +65,28 @@ def trigger_mitre_sync(
Returns a JSON object with the sync summary including the count of Returns a JSON object with the sync summary including the count of
new and updated techniques. new and updated techniques.
""" """
# Assign summary = sync_mitre(db)
summary = sync_mitre(db) summary = sync_mitre(db)
# Return {
return { return {
# Literal argument value
"message": "MITRE sync completed", "message": "MITRE sync completed",
# Literal argument value
"new": summary["created"], "new": summary["created"],
# Literal argument value
"updated": summary["updated"], "updated": summary["updated"],
} }
# Apply the @router.post decorator
@router.post("/run-intel-scan") @router.post("/run-intel-scan")
# Define function trigger_intel_scan
def trigger_intel_scan( def trigger_intel_scan(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Manually trigger a threat-intelligence scan. """Manually trigger a threat-intelligence scan.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -58,20 +94,30 @@ def trigger_intel_scan(
Returns a JSON object with the scan summary including the count of Returns a JSON object with the scan summary including the count of
new intel items found. new intel items found.
""" """
# Assign summary = scan_intel(db)
summary = scan_intel(db) summary = scan_intel(db)
# Return {
return { return {
# Literal argument value
"message": "Intel scan completed", "message": "Intel scan completed",
# Literal argument value
"new_items": summary["new_items"], "new_items": summary["new_items"],
} }
# Apply the @router.post decorator
@router.post("/import-atomic-tests") @router.post("/import-atomic-tests")
# Apply the @limiter.limit decorator
@limiter.limit("2/hour") @limiter.limit("2/hour")
# Define function trigger_atomic_import
def trigger_atomic_import( def trigger_atomic_import(
# Entry: request
request: Request, request: Request,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Trigger an import of Atomic Red Team tests as TestTemplates. """Trigger an import of Atomic Red Team tests as TestTemplates.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -82,37 +128,58 @@ def trigger_atomic_import(
Returns a JSON object with import statistics. Returns a JSON object with import statistics.
""" """
# Attempt the following; catch errors below
try: try:
# Assign summary = import_atomic_red_team(db)
summary = import_atomic_red_team(db) summary = import_atomic_red_team(db)
# Handle Exception
except Exception as exc: except Exception as exc:
# Log error: "Atomic Red Team import failed: %s", exc, exc_info
logger.error("Atomic Red Team import failed: %s", exc, exc_info=True) logger.error("Atomic Red Team import failed: %s", exc, exc_info=True)
# Return {
return { return {
# Literal argument value
"message": "Import failed. Check server logs for details.", "message": "Import failed. Check server logs for details.",
} }
# Return {
return { return {
# Literal argument value
"message": "Import completed", "message": "Import completed",
# Literal argument value
"imported": summary["created"], "imported": summary["created"],
# Literal argument value
"skipped": summary["skipped_existing"], "skipped": summary["skipped_existing"],
# Literal argument value
"total_parsed": summary["total_tests_parsed"], "total_parsed": summary["total_tests_parsed"],
} }
# Apply the @router.get decorator
@router.get("/scheduler-status") @router.get("/scheduler-status")
# Define function scheduler_status
def scheduler_status( def scheduler_status(
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Return the current state of the background scheduler. """Return the current state of the background scheduler.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
""" """
# Assign jobs = scheduler.get_jobs()
jobs = scheduler.get_jobs() jobs = scheduler.get_jobs()
# Return {
return { return {
# Literal argument value
"running": scheduler.running, "running": scheduler.running,
# Literal argument value
"jobs": [ "jobs": [
{ {
# Literal argument value
"id": job.id, "id": job.id,
# Literal argument value
"name": job.name, "name": job.name,
# Literal argument value
"next_run_time": str(job.next_run_time) if job.next_run_time else None, "next_run_time": str(job.next_run_time) if job.next_run_time else None,
} }
for job in jobs for job in jobs
+116 -7
View File
@@ -5,29 +5,56 @@ for error signaling. The error_handler middleware maps domain
exceptions to HTTP responses automatically. exceptions to HTTP responses automatically.
""" """
# Import APIRouter, Depends, Query, status from fastapi
from fastapi import APIRouter, Depends, Query, status from fastapi import APIRouter, Depends, Query, status
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_any_role, require_role
# Import get_technique_repository from app.dependencies.repositories
from app.dependencies.repositories import get_technique_repository from app.dependencies.repositories import get_technique_repository
# Import TechniqueEntity from app.domain.entities.technique
from app.domain.entities.technique import TechniqueEntity from app.domain.entities.technique import TechniqueEntity
from app.domain.errors import DuplicateEntityError, EntityNotFoundError
# Import TechniqueStatus from app.domain.enums
from app.domain.enums import TechniqueStatus from app.domain.enums import TechniqueStatus
# Import DuplicateEntityError, EntityNotFoundError from app.domain.errors
from app.domain.errors import DuplicateEntityError, EntityNotFoundError
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
from app.infrastructure.persistence.repositories.sa_technique_repository import ( from app.infrastructure.persistence.repositories.sa_technique_repository import (
SATechniqueRepository, SATechniqueRepository,
) )
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.schemas.technique
from app.schemas.technique import ( from app.schemas.technique import (
TechniqueCreate, TechniqueCreate,
TechniqueOut, TechniqueOut,
TechniqueSummary, TechniqueSummary,
TechniqueUpdate, TechniqueUpdate,
) )
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action from app.services.audit_service import log_action
# Import get_technique_detail from app.services.technique_query_service
from app.services.technique_query_service import get_technique_detail from app.services.technique_query_service import get_technique_detail
# Assign router = APIRouter(prefix="/techniques", tags=["techniques"])
router = APIRouter(prefix="/techniques", tags=["techniques"]) router = APIRouter(prefix="/techniques", tags=["techniques"])
@@ -37,19 +64,29 @@ router = APIRouter(prefix="/techniques", tags=["techniques"])
@router.get("", response_model=list[TechniqueSummary]) @router.get("", response_model=list[TechniqueSummary])
# Define function list_techniques
def list_techniques( def list_techniques(
# Entry: tactic
tactic: str | None = Query(None, description="Filter by tactic name"), tactic: str | None = Query(None, description="Filter by tactic name"),
# Entry: status_global
status_global: TechniqueStatus | None = Query( status_global: TechniqueStatus | None = Query(
None, alias="status", description="Filter by global status" None, alias="status", description="Filter by global status"
), ),
# Entry: review_required
review_required: bool | None = Query(None, description="Filter by review flag"), review_required: bool | None = Query(None, description="Filter by review flag"),
# Entry: repo
repo: SATechniqueRepository = Depends(get_technique_repository), repo: SATechniqueRepository = Depends(get_technique_repository),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return a lightweight list of techniques, optionally filtered.""" """Return a lightweight list of techniques, optionally filtered."""
# Return repo.list_all(
return repo.list_all( return repo.list_all(
# Keyword argument: tactic
tactic=tactic, tactic=tactic,
# Keyword argument: status
status=status_global, status=status_global,
# Keyword argument: review_required
review_required=review_required, review_required=review_required,
) )
@@ -60,12 +97,17 @@ def list_techniques(
@router.get("/{mitre_id}") @router.get("/{mitre_id}")
# Define function get_technique
def get_technique( def get_technique(
# Entry: mitre_id
mitre_id: str, mitre_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Return full details for a single technique, including its tests and D3FEND defenses.""" """Return full details for a single technique, including its tests and D3FEND defenses."""
# Return get_technique_detail(db, mitre_id)
return get_technique_detail(db, mitre_id) return get_technique_detail(db, mitre_id)
@@ -75,40 +117,66 @@ def get_technique(
@router.post( @router.post(
# Literal argument value
"", "",
# Keyword argument: response_model
response_model=TechniqueOut, response_model=TechniqueOut,
# Keyword argument: status_code
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
) )
# Define function create_technique
def create_technique( def create_technique(
# Entry: payload
payload: TechniqueCreate, payload: TechniqueCreate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: repo
repo: SATechniqueRepository = Depends(get_technique_repository), repo: SATechniqueRepository = Depends(get_technique_repository),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> TechniqueOut:
"""Create a new technique manually.""" """Create a new technique manually."""
# Check: repo.exists_by_mitre_id(payload.mitre_id)
if repo.exists_by_mitre_id(payload.mitre_id): if repo.exists_by_mitre_id(payload.mitre_id):
# Raise DuplicateEntityError
raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id) raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id)
# Assign entity = TechniqueEntity.create(
entity = TechniqueEntity.create( entity = TechniqueEntity.create(
# Keyword argument: mitre_id
mitre_id=payload.mitre_id, mitre_id=payload.mitre_id,
# Keyword argument: name
name=payload.name, name=payload.name,
# Keyword argument: description
description=payload.description, description=payload.description,
# Keyword argument: tactic
tactic=payload.tactic, tactic=payload.tactic,
# Keyword argument: platforms
platforms=payload.platforms, platforms=payload.platforms,
) )
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign saved = repo.save(entity)
saved = repo.save(entity) saved = repo.save(entity)
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="create_technique", action="create_technique",
# Keyword argument: entity_type
entity_type="technique", entity_type="technique",
# Keyword argument: entity_id
entity_id=saved.id, entity_id=saved.id,
# Keyword argument: details
details={"mitre_id": saved.mitre_id, "name": saved.name}, details={"mitre_id": saved.mitre_id, "name": saved.name},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return saved
return saved return saved
@@ -118,34 +186,56 @@ def create_technique(
@router.patch("/{mitre_id}", response_model=TechniqueOut) @router.patch("/{mitre_id}", response_model=TechniqueOut)
# Define function update_technique
def update_technique( def update_technique(
# Entry: mitre_id
mitre_id: str, mitre_id: str,
# Entry: payload
payload: TechniqueUpdate, payload: TechniqueUpdate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: repo
repo: SATechniqueRepository = Depends(get_technique_repository), repo: SATechniqueRepository = Depends(get_technique_repository),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> TechniqueOut:
"""Update one or more fields of an existing technique.""" """Update one or more fields of an existing technique."""
# Assign entity = repo.find_by_mitre_id(mitre_id)
entity = repo.find_by_mitre_id(mitre_id) entity = repo.find_by_mitre_id(mitre_id)
# Check: entity is None
if entity is None: if entity is None:
# Raise EntityNotFoundError
raise EntityNotFoundError("Technique", mitre_id) raise EntityNotFoundError("Technique", mitre_id)
# Assign update_data = payload.model_dump(exclude_unset=True)
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
# Iterate over update_data.items()
for field, value in update_data.items(): for field, value in update_data.items():
# Call setattr()
setattr(entity, field, value) setattr(entity, field, value)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign saved = repo.save(entity)
saved = repo.save(entity) saved = repo.save(entity)
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="update_technique", action="update_technique",
# Keyword argument: entity_type
entity_type="technique", entity_type="technique",
# Keyword argument: entity_id
entity_id=saved.id, entity_id=saved.id,
# Keyword argument: details
details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())}, details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return saved
return saved return saved
@@ -155,33 +245,52 @@ def update_technique(
@router.patch("/{mitre_id}/review", response_model=TechniqueOut) @router.patch("/{mitre_id}/review", response_model=TechniqueOut)
# Define function review_technique
def review_technique( def review_technique(
# Entry: mitre_id
mitre_id: str, mitre_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: repo
repo: SATechniqueRepository = Depends(get_technique_repository), repo: SATechniqueRepository = Depends(get_technique_repository),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TechniqueOut:
"""Mark a technique as reviewed. """Mark a technique as reviewed.
Sets ``review_required`` to *False* and records the current timestamp Sets ``review_required`` to *False* and records the current timestamp
in ``last_review_date``. in ``last_review_date``.
""" """
# Assign entity = repo.find_by_mitre_id(mitre_id)
entity = repo.find_by_mitre_id(mitre_id) entity = repo.find_by_mitre_id(mitre_id)
# Check: entity is None
if entity is None: if entity is None:
# Raise EntityNotFoundError
raise EntityNotFoundError("Technique", mitre_id) raise EntityNotFoundError("Technique", mitre_id)
# Call entity.mark_reviewed()
entity.mark_reviewed() entity.mark_reviewed()
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign saved = repo.save(entity)
saved = repo.save(entity) saved = repo.save(entity)
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="review_technique", action="review_technique",
# Keyword argument: entity_type
entity_type="technique", entity_type="technique",
# Keyword argument: entity_id
entity_id=saved.id, entity_id=saved.id,
# Keyword argument: details
details={"mitre_id": mitre_id}, details={"mitre_id": mitre_id},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return saved
return saved return saved
+262 -20
View File
@@ -22,34 +22,70 @@ Filters (GET /test-templates)
- offset / limit: pagination (default limit=50) - offset / limit: pagination (default limit=50)
""" """
# Import uuid
import uuid import uuid
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query, status from fastapi
from fastapi import APIRouter, Depends, Query, status from fastapi import APIRouter, Depends, Query, status
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_any_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.schemas.test_template
from app.schemas.test_template import ( from app.schemas.test_template import (
TestTemplateCreate, TestTemplateCreate,
TestTemplateOut, TestTemplateOut,
TestTemplateSummary, TestTemplateSummary,
) )
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action from app.services.audit_service import log_action
# Import from app.services.test_template_service
from app.services.test_template_service import ( from app.services.test_template_service import (
bulk_activate, bulk_activate,
create_template as create_template_svc,
get_template_or_raise, get_template_or_raise,
get_template_stats, get_template_stats,
get_templates_by_technique as templates_by_technique,
list_templates, list_templates,
soft_delete_template, soft_delete_template,
)
# Import from app.services.test_template_service
from app.services.test_template_service import (
create_template as create_template_svc,
)
# Import from app.services.test_template_service
from app.services.test_template_service import (
get_templates_by_technique as templates_by_technique,
)
# Import from app.services.test_template_service
from app.services.test_template_service import (
toggle_template_active as toggle_template_active_svc, toggle_template_active as toggle_template_active_svc,
)
# Import from app.services.test_template_service
from app.services.test_template_service import (
update_template as update_template_svc, update_template as update_template_svc,
) )
# Assign router = APIRouter(prefix="/test-templates", tags=["test-templates"])
router = APIRouter(prefix="/test-templates", tags=["test-templates"]) router = APIRouter(prefix="/test-templates", tags=["test-templates"])
@@ -59,28 +95,64 @@ router = APIRouter(prefix="/test-templates", tags=["test-templates"])
@router.get("", response_model=list[TestTemplateSummary]) @router.get("", response_model=list[TestTemplateSummary])
# Define function _list_templates_handler
def _list_templates_handler( def _list_templates_handler(
# Entry: source
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"), source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
# Entry: platform
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"), platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
# Entry: severity
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"), severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
# Entry: mitre_technique_id
mitre_technique_id: Optional[str] = Query(None, description="Filter by MITRE technique ID"), mitre_technique_id: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
# Entry: search
search: Optional[str] = Query(None, description="Search in name and description"), search: Optional[str] = Query(None, description="Search in name and description"),
# Entry: is_active
is_active: Optional[bool] = Query(None, description="Filter by active status (true/false). Omit to return all."), is_active: Optional[bool] = Query(None, description="Filter by active status (true/false). Omit to return all."),
# Entry: offset
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
# Entry: limit
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return a paginated, filterable list of test templates.""" """Return a paginated, filterable list of test templates.
Args:
source (Optional[str]): Filter by source (``atomic_red_team``, ``mitre``, ``custom``).
platform (Optional[str]): Filter by platform (``windows``, ``linux``, ``macos``).
severity (Optional[str]): Filter by severity (``low``, ``medium``, ``high``, ``critical``).
mitre_technique_id (Optional[str]): Filter by MITRE technique ID string.
search (Optional[str]): Full-text search across name and description.
is_active (Optional[bool]): Filter by active status; omit to return all.
offset (int): Number of records to skip for pagination.
limit (int): Maximum number of records to return.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
list: Serialised list of :class:`TestTemplateSummary` objects.
"""
# Return list_templates(
return list_templates( return list_templates(
db, db,
# Keyword argument: source
source=source, source=source,
# Keyword argument: platform
platform=platform, platform=platform,
# Keyword argument: severity
severity=severity, severity=severity,
# Keyword argument: mitre_technique_id
mitre_technique_id=mitre_technique_id, mitre_technique_id=mitre_technique_id,
# Keyword argument: search
search=search, search=search,
# Keyword argument: is_active
is_active=is_active, is_active=is_active,
# Keyword argument: offset
offset=offset, offset=offset,
# Keyword argument: limit
limit=limit, limit=limit,
) )
@@ -91,11 +163,23 @@ def _list_templates_handler(
@router.get("/stats") @router.get("/stats")
# Define function template_stats
def template_stats( def template_stats(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Return catalog statistics: active, by_source, by_platform.""" """Return catalog statistics: active, by_source, by_platform.
Args:
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead.
Returns:
dict: Counts of active templates broken down by source and platform.
"""
# Return get_template_stats(db)
return get_template_stats(db) return get_template_stats(db)
@@ -105,27 +189,53 @@ def template_stats(
@router.patch("/bulk-activate") @router.patch("/bulk-activate")
# Define function bulk_activate_templates
def bulk_activate_templates( def bulk_activate_templates(
# Entry: activate
activate: bool = Query(True, description="True to activate all, False to deactivate all"), activate: bool = Query(True, description="True to activate all, False to deactivate all"),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Set all templates to active or inactive.""" """Set all templates to active or inactive.
Args:
activate (bool): ``True`` to activate all templates, ``False`` to deactivate all.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead.
Returns:
dict: Confirmation message with ``affected`` count and the applied ``is_active`` flag.
"""
# Assign count = bulk_activate(db, activate=activate)
count = bulk_activate(db, activate=activate) count = bulk_activate(db, activate=activate)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="bulk_activate_templates" if activate else "bulk_deactivate_templates", action="bulk_activate_templates" if activate else "bulk_deactivate_templates",
# Keyword argument: entity_type
entity_type="test_template", entity_type="test_template",
# Keyword argument: entity_id
entity_id=None, entity_id=None,
# Keyword argument: details
details={"affected": count, "is_active": activate}, details={"affected": count, "is_active": activate},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return {
return { return {
# Literal argument value
"detail": f"{'Activated' if activate else 'Deactivated'} {count} templates", "detail": f"{'Activated' if activate else 'Deactivated'} {count} templates",
# Literal argument value
"affected": count, "affected": count,
# Literal argument value
"is_active": activate, "is_active": activate,
} }
@@ -136,12 +246,26 @@ def bulk_activate_templates(
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary]) @router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
# Define function _templates_by_technique_handler
def _templates_by_technique_handler( def _templates_by_technique_handler(
# Entry: mitre_id
mitre_id: str, mitre_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return all active templates mapped to a specific MITRE technique.""" """Return all active templates mapped to a specific MITRE technique.
Args:
mitre_id (str): MITRE ATT&CK technique ID (e.g. ``T1059.001``).
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
list: Serialised list of :class:`TestTemplateSummary` objects for the technique.
"""
# Return templates_by_technique(db, mitre_id)
return templates_by_technique(db, mitre_id) return templates_by_technique(db, mitre_id)
@@ -151,12 +275,26 @@ def _templates_by_technique_handler(
@router.get("/{template_id}", response_model=TestTemplateOut) @router.get("/{template_id}", response_model=TestTemplateOut)
# Define function get_template
def get_template( def get_template(
# Entry: template_id
template_id: uuid.UUID, template_id: uuid.UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> TestTemplateOut:
"""Return full details for a single test template.""" """Return full details for a single test template.
Args:
template_id (uuid.UUID): Primary key of the template to retrieve.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated user making the request.
Returns:
TestTemplateOut: Full template detail including all fields.
"""
# Return get_template_or_raise(db, template_id)
return get_template_or_raise(db, template_id) return get_template_or_raise(db, template_id)
@@ -166,33 +304,63 @@ def get_template(
@router.post( @router.post(
# Literal argument value
"", "",
# Keyword argument: response_model
response_model=TestTemplateOut, response_model=TestTemplateOut,
# Keyword argument: status_code
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
) )
# Define function create_template
def create_template( def create_template(
# Entry: payload
payload: TestTemplateCreate, payload: TestTemplateCreate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestTemplateOut:
"""Create a custom test template.""" """Create a custom test template.
Args:
payload (TestTemplateCreate): All fields for the new template.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead creating the template.
Returns:
TestTemplateOut: The newly created template with all fields populated.
"""
# Assign template = create_template_svc(db, **payload.model_dump())
template = create_template_svc(db, **payload.model_dump()) template = create_template_svc(db, **payload.model_dump())
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="create_test_template", action="create_test_template",
# Keyword argument: entity_type
entity_type="test_template", entity_type="test_template",
# Keyword argument: entity_id
entity_id=template.id, entity_id=template.id,
# Keyword argument: details
details={ details={
# Literal argument value
"name": template.name, "name": template.name,
# Literal argument value
"source": template.source, "source": template.source,
# Literal argument value
"mitre_technique_id": template.mitre_technique_id, "mitre_technique_id": template.mitre_technique_id,
}, },
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(template) db.refresh(template)
# Return template
return template return template
@@ -202,26 +370,52 @@ def create_template(
@router.patch("/{template_id}", response_model=TestTemplateOut) @router.patch("/{template_id}", response_model=TestTemplateOut)
# Define function update_template
def update_template( def update_template(
# Entry: template_id
template_id: uuid.UUID, template_id: uuid.UUID,
# Entry: payload
payload: TestTemplateCreate, payload: TestTemplateCreate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestTemplateOut:
"""Update fields of an existing test template.""" """Update fields of an existing test template.
Args:
template_id (uuid.UUID): Primary key of the template to update.
payload (TestTemplateCreate): Fields to update; only set fields are applied.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead updating the template.
Returns:
TestTemplateOut: The updated template with refreshed field values.
"""
# Assign template = update_template_svc(db, template_id, **payload.model_dump(exclude_u...
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True)) template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="update_test_template", action="update_test_template",
# Keyword argument: entity_type
entity_type="test_template", entity_type="test_template",
# Keyword argument: entity_id
entity_id=template.id, entity_id=template.id,
# Keyword argument: details
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())}, details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(template) db.refresh(template)
# Return template
return template return template
@@ -231,25 +425,49 @@ def update_template(
@router.patch("/{template_id}/toggle-active", response_model=TestTemplateOut) @router.patch("/{template_id}/toggle-active", response_model=TestTemplateOut)
# Define function toggle_template_active
def toggle_template_active( def toggle_template_active(
# Entry: template_id
template_id: uuid.UUID, template_id: uuid.UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestTemplateOut:
"""Toggle a template between active and inactive (is_active = not is_active).""" """Toggle a template between active and inactive (is_active = not is_active).
Args:
template_id (uuid.UUID): Primary key of the template to toggle.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead.
Returns:
TestTemplateOut: The template with the updated ``is_active`` flag.
"""
# Assign template = toggle_template_active_svc(db, template_id)
template = toggle_template_active_svc(db, template_id) template = toggle_template_active_svc(db, template_id)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="toggle_test_template", action="toggle_test_template",
# Keyword argument: entity_type
entity_type="test_template", entity_type="test_template",
# Keyword argument: entity_id
entity_id=template.id, entity_id=template.id,
# Keyword argument: details
details={"name": template.name, "is_active": template.is_active}, details={"name": template.name, "is_active": template.is_active},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(template) db.refresh(template)
# Return template
return template return template
@@ -259,23 +477,47 @@ def toggle_template_active(
@router.delete("/{template_id}", status_code=status.HTTP_200_OK) @router.delete("/{template_id}", status_code=status.HTTP_200_OK)
# Define function delete_template
def delete_template( def delete_template(
# Entry: template_id
template_id: uuid.UUID, template_id: uuid.UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Soft-delete a test template by setting ``is_active=False``.""" """Soft-delete a test template by setting ``is_active=False``.
Args:
template_id (uuid.UUID): Primary key of the template to delete.
db (Session): SQLAlchemy database session.
current_user (User): Authenticated red_lead or blue_lead.
Returns:
dict: Confirmation message with key ``detail``.
"""
# Assign template = get_template_or_raise(db, template_id)
template = get_template_or_raise(db, template_id) template = get_template_or_raise(db, template_id)
# Call soft_delete_template()
soft_delete_template(db, template_id) soft_delete_template(db, template_id)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="delete_test_template", action="delete_test_template",
# Keyword argument: entity_type
entity_type="test_template", entity_type="test_template",
# Keyword argument: entity_id
entity_id=template.id, entity_id=template.id,
# Keyword argument: details
details={"name": template.name}, details={"name": template.name},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Return {"detail": "Test template deactivated"}
return {"detail": "Test template deactivated"} return {"detail": "Test template deactivated"}
File diff suppressed because it is too large Load Diff
+56 -4
View File
@@ -4,15 +4,28 @@ Provides listing, detail, coverage analysis, and gap analysis for
threat actor profiles imported from MITRE CTI. threat actor profiles imported from MITRE CTI.
""" """
# Import logging
import logging import logging
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user from app.dependencies.auth
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.services.threat_actor_service
from app.services.threat_actor_service import ( from app.services.threat_actor_service import (
get_actor_coverage, get_actor_coverage,
get_actor_detail, get_actor_detail,
@@ -20,58 +33,90 @@ from app.services.threat_actor_service import (
list_actors, list_actors,
) )
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Assign router = APIRouter(prefix="/threat-actors", tags=["threat-actors"])
router = APIRouter(prefix="/threat-actors", tags=["threat-actors"]) router = APIRouter(prefix="/threat-actors", tags=["threat-actors"])
# Apply the @router.get decorator
@router.get("") @router.get("")
# Define function list_threat_actors
def list_threat_actors( def list_threat_actors(
# Entry: search
search: Optional[str] = Query(None), search: Optional[str] = Query(None),
# Entry: country
country: Optional[str] = Query(None), country: Optional[str] = Query(None),
# Entry: motivation
motivation: Optional[str] = Query(None), motivation: Optional[str] = Query(None),
# Entry: sophistication
sophistication: Optional[str] = Query(None), sophistication: Optional[str] = Query(None),
# Entry: target_sectors
target_sectors: Optional[str] = Query(None), target_sectors: Optional[str] = Query(None),
# Entry: offset
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
# Entry: limit
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List threat actors with optional filters and pagination. """List threat actors with optional filters and pagination.
**Requires** authentication (any role). **Requires** authentication (any role).
""" """
# Return list_actors(
return list_actors( return list_actors(
db, db,
# Keyword argument: search
search=search, search=search,
# Keyword argument: country
country=country, country=country,
# Keyword argument: motivation
motivation=motivation, motivation=motivation,
# Keyword argument: sophistication
sophistication=sophistication, sophistication=sophistication,
# Keyword argument: target_sectors
target_sectors=target_sectors, target_sectors=target_sectors,
# Keyword argument: offset
offset=offset, offset=offset,
# Keyword argument: limit
limit=limit, limit=limit,
) )
# Apply the @router.get decorator
@router.get("/{actor_id}") @router.get("/{actor_id}")
# Define function get_threat_actor
def get_threat_actor( def get_threat_actor(
# Entry: actor_id
actor_id: str, actor_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get detailed info about a threat actor including techniques. """Get detailed info about a threat actor including techniques.
**Requires** authentication (any role). **Requires** authentication (any role).
""" """
# Return get_actor_detail(db, actor_id)
return get_actor_detail(db, actor_id) return get_actor_detail(db, actor_id)
# Apply the @router.get decorator
@router.get("/{actor_id}/coverage") @router.get("/{actor_id}/coverage")
# Define function get_threat_actor_coverage
def get_threat_actor_coverage( def get_threat_actor_coverage(
# Entry: actor_id
actor_id: str, actor_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Calculate coverage percentage against a specific threat actor. """Calculate coverage percentage against a specific threat actor.
**Requires** authentication (any role). **Requires** authentication (any role).
@@ -79,19 +124,26 @@ def get_threat_actor_coverage(
Returns the percentage of the actor's techniques that have been Returns the percentage of the actor's techniques that have been
validated or partially validated, along with a breakdown. validated or partially validated, along with a breakdown.
""" """
# Return get_actor_coverage(db, actor_id)
return get_actor_coverage(db, actor_id) return get_actor_coverage(db, actor_id)
# Apply the @router.get decorator
@router.get("/{actor_id}/gaps") @router.get("/{actor_id}/gaps")
# Define function get_threat_actor_gaps
def get_threat_actor_gaps( def get_threat_actor_gaps(
# Entry: actor_id
actor_id: str, actor_id: str,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Identify techniques of this actor that are NOT fully validated. """Identify techniques of this actor that are NOT fully validated.
**Requires** authentication (any role). **Requires** authentication (any role).
Returns list of gap techniques with available templates. Returns list of gap techniques with available templates.
""" """
# Return get_actor_gaps(db, actor_id)
return get_actor_gaps(db, actor_id) return get_actor_gaps(db, actor_id)
+72 -9
View File
@@ -1,16 +1,33 @@
"""User management router (admin only).""" """User management router (admin only)."""
# Import uuid
import uuid import uuid
# Import APIRouter, Depends, status from fastapi
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import require_role from app.dependencies.auth
from app.dependencies.auth import require_role from app.dependencies.auth import require_role
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user
from app.models.user import User from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate, UserOut
# Import UserCreate, UserOut, UserUpdate from app.schemas.user
from app.schemas.user import UserCreate, UserOut, UserUpdate
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action from app.services.audit_service import log_action
# Import from app.services.user_service
from app.services.user_service import ( from app.services.user_service import (
create_user, create_user,
get_user_or_raise, get_user_or_raise,
@@ -18,6 +35,7 @@ from app.services.user_service import (
update_user, update_user,
) )
# Assign router = APIRouter(prefix="/users", tags=["users"])
router = APIRouter(prefix="/users", tags=["users"]) router = APIRouter(prefix="/users", tags=["users"])
@@ -27,11 +45,15 @@ router = APIRouter(prefix="/users", tags=["users"])
@router.get("", response_model=list[UserOut]) @router.get("", response_model=list[UserOut])
# Define function list_users_route
def list_users_route( def list_users_route(
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> list[UserOut]:
"""Return a list of all users. **Requires admin role.**""" """Return a list of all users. **Requires admin role.**."""
# Return list_users(db)
return list_users(db) return list_users(db)
@@ -41,31 +63,50 @@ def list_users_route(
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED) @router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
# Define function create_user_route
def create_user_route( def create_user_route(
# Entry: payload
payload: UserCreate, payload: UserCreate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> UserOut:
"""Create a new user. **Requires admin role.**""" """Create a new user. **Requires admin role.**."""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign user = create_user(
user = create_user( user = create_user(
db, db,
# Keyword argument: username
username=payload.username, username=payload.username,
# Keyword argument: email
email=payload.email, email=payload.email,
# Keyword argument: password
password=payload.password, password=payload.password,
# Keyword argument: role
role=payload.role, role=payload.role,
) )
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="create_user", action="create_user",
# Keyword argument: entity_type
entity_type="user", entity_type="user",
# Keyword argument: entity_id
entity_id=user.id, entity_id=user.id,
# Keyword argument: details
details={"username": user.username, "role": user.role}, details={"username": user.username, "role": user.role},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(user) db.refresh(user)
# Return user
return user return user
@@ -75,12 +116,17 @@ def create_user_route(
@router.get("/{user_id}", response_model=UserOut) @router.get("/{user_id}", response_model=UserOut)
# Define function get_user
def get_user( def get_user(
# Entry: user_id
user_id: uuid.UUID, user_id: uuid.UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> UserOut:
"""Return a single user by ID. **Requires admin role.**""" """Return a single user by ID. **Requires admin role.**."""
# Return get_user_or_raise(db, user_id)
return get_user_or_raise(db, user_id) return get_user_or_raise(db, user_id)
@@ -90,25 +136,42 @@ def get_user(
@router.patch("/{user_id}", response_model=UserOut) @router.patch("/{user_id}", response_model=UserOut)
# Define function update_user_route
def update_user_route( def update_user_route(
# Entry: user_id
user_id: uuid.UUID, user_id: uuid.UUID,
# Entry: payload
payload: UserUpdate, payload: UserUpdate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> UserOut:
"""Update one or more fields of an existing user. **Requires admin role.**""" """Update one or more fields of an existing user. **Requires admin role.**."""
# Assign update_data = payload.model_dump(exclude_unset=True)
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign user = update_user(db, user_id, **update_data)
user = update_user(db, user_id, **update_data) user = update_user(db, user_id, **update_data)
# Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id
user_id=current_user.id, user_id=current_user.id,
# Keyword argument: action
action="update_user", action="update_user",
# Keyword argument: entity_type
entity_type="user", entity_type="user",
# Keyword argument: entity_id
entity_id=user.id, entity_id=user.id,
# Keyword argument: details
details={"updated_fields": list(update_data.keys())}, details={"updated_fields": list(update_data.keys())},
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(user) db.refresh(user)
# Return user
return user return user
+138 -9
View File
@@ -1,19 +1,39 @@
"""Worklog router — internal time-tracking records with integrity verification.""" """Worklog router — internal time-tracking records with integrity verification."""
# Import datetime from datetime
from datetime import datetime from datetime import datetime
# Import Optional from typing
from typing import Optional from typing import Optional
# Import UUID from uuid
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query # Import APIRouter, Depends from fastapi
from fastapi import APIRouter, Depends
# Import BaseModel, Field from pydantic
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database
from app.database import get_db from app.database import get_db
# Import get_current_user, require_any_role from app.dependencies.auth
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
# Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import worklog_service from app.services
from app.services import worklog_service from app.services import worklog_service
# Assign router = APIRouter(prefix="/worklogs", tags=["worklogs"])
router = APIRouter(prefix="/worklogs", tags=["worklogs"]) router = APIRouter(prefix="/worklogs", tags=["worklogs"])
@@ -21,30 +41,58 @@ router = APIRouter(prefix="/worklogs", tags=["worklogs"])
class WorklogCreate(BaseModel): class WorklogCreate(BaseModel):
"""Payload for logging a work session against an entity."""
# Assign entity_type = Field(..., max_length=50)
entity_type: str = Field(..., max_length=50) entity_type: str = Field(..., max_length=50)
# entity_id: UUID
entity_id: UUID entity_id: UUID
# Assign activity_type = Field(..., max_length=100)
activity_type: str = Field(..., max_length=100) activity_type: str = Field(..., max_length=100)
# started_at: datetime
started_at: datetime started_at: datetime
# Assign ended_at = None
ended_at: Optional[datetime] = None ended_at: Optional[datetime] = None
# Assign duration_seconds = Field(..., gt=0)
duration_seconds: int = Field(..., gt=0) duration_seconds: int = Field(..., gt=0)
# Assign description = None
description: Optional[str] = None description: Optional[str] = None
# Define class WorklogOut
class WorklogOut(BaseModel): class WorklogOut(BaseModel):
"""Serialized worklog entry returned by the API."""
# id: UUID
id: UUID id: UUID
# entity_type: str
entity_type: str entity_type: str
# entity_id: UUID
entity_id: UUID entity_id: UUID
# user_id: UUID
user_id: UUID user_id: UUID
# activity_type: str
activity_type: str activity_type: str
# started_at: datetime
started_at: datetime started_at: datetime
# Assign ended_at = None
ended_at: Optional[datetime] = None ended_at: Optional[datetime] = None
# duration_seconds: int
duration_seconds: int duration_seconds: int
# Assign description = None
description: Optional[str] = None description: Optional[str] = None
# Assign tempo_synced = None
tempo_synced: Optional[datetime] = None tempo_synced: Optional[datetime] = None
# Assign integrity_hash = None
integrity_hash: Optional[str] = None integrity_hash: Optional[str] = None
# created_at: datetime
created_at: datetime created_at: datetime
# Define class Config
class Config: class Config:
"""ORM mode configuration for SQLAlchemy model mapping."""
# Assign from_attributes = True
from_attributes = True from_attributes = True
@@ -52,65 +100,146 @@ class WorklogOut(BaseModel):
@router.post("", response_model=WorklogOut, status_code=201) @router.post("", response_model=WorklogOut, status_code=201)
# Define function create
def create( def create(
# Entry: body
body: WorklogCreate, body: WorklogCreate,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
): ) -> WorklogOut:
"""Create a manually-logged worklog entry.""" """Create a manually-logged worklog entry.
Args:
body (WorklogCreate): Worklog fields including entity, activity type, and duration.
db (Session): SQLAlchemy database session.
user (User): Authenticated team member creating the worklog.
Returns:
WorklogOut: The newly created worklog with integrity hash and all fields.
"""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign wl = worklog_service.create_worklog(
wl = worklog_service.create_worklog( wl = worklog_service.create_worklog(
db, db,
# Keyword argument: entity_type
entity_type=body.entity_type, entity_type=body.entity_type,
# Keyword argument: entity_id
entity_id=body.entity_id, entity_id=body.entity_id,
# Keyword argument: user_id
user_id=user.id, user_id=user.id,
# Keyword argument: activity_type
activity_type=body.activity_type, activity_type=body.activity_type,
# Keyword argument: started_at
started_at=body.started_at, started_at=body.started_at,
# Keyword argument: ended_at
ended_at=body.ended_at, ended_at=body.ended_at,
# Keyword argument: duration_seconds
duration_seconds=body.duration_seconds, duration_seconds=body.duration_seconds,
# Keyword argument: description
description=body.description, description=body.description,
) )
# Call uow.commit()
uow.commit() uow.commit()
# Reload ORM object attributes from the database
db.refresh(wl) db.refresh(wl)
# Return wl
return wl return wl
# Apply the @router.get decorator
@router.get("", response_model=list[WorklogOut]) @router.get("", response_model=list[WorklogOut])
# Define function list_all
def list_all( def list_all(
# Entry: entity_type
entity_type: Optional[str] = None, entity_type: Optional[str] = None,
# Entry: entity_id
entity_id: Optional[UUID] = None, entity_id: Optional[UUID] = None,
# Entry: user_id
user_id: Optional[UUID] = None, user_id: Optional[UUID] = None,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: _user
_user: User = Depends(get_current_user), _user: User = Depends(get_current_user),
): ) -> list[WorklogOut]:
"""List worklogs with optional filters.""" """List worklogs with optional filters.
Args:
entity_type (Optional[str]): Filter by entity type (e.g. ``test``, ``campaign``).
entity_id (Optional[UUID]): Filter by the UUID of the associated entity.
user_id (Optional[UUID]): Filter by the UUID of the worklog author.
db (Session): SQLAlchemy database session.
_user (User): Authenticated user making the request (unused, enforces auth).
Returns:
list[WorklogOut]: Serialised list of worklog entries matching the filters.
"""
# Return worklog_service.list_worklogs(
return worklog_service.list_worklogs( return worklog_service.list_worklogs(
db, db,
# Keyword argument: entity_type
entity_type=entity_type, entity_type=entity_type,
# Keyword argument: entity_id
entity_id=entity_id, entity_id=entity_id,
# Keyword argument: user_id
user_id=user_id, user_id=user_id,
) )
# Apply the @router.get decorator
@router.get("/{worklog_id}", response_model=WorklogOut) @router.get("/{worklog_id}", response_model=WorklogOut)
# Define function get_one
def get_one( def get_one(
# Entry: worklog_id
worklog_id: UUID, worklog_id: UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: _user
_user: User = Depends(get_current_user), _user: User = Depends(get_current_user),
): ) -> WorklogOut:
"""Get a single worklog by ID.""" """Get a single worklog by ID.
Args:
worklog_id (UUID): Primary key of the worklog to retrieve.
db (Session): SQLAlchemy database session.
_user (User): Authenticated user making the request (unused, enforces auth).
Returns:
WorklogOut: Full worklog detail including integrity hash.
"""
# Return worklog_service.get_worklog_or_raise(db, worklog_id)
return worklog_service.get_worklog_or_raise(db, worklog_id) return worklog_service.get_worklog_or_raise(db, worklog_id)
# Apply the @router.get decorator
@router.get("/{worklog_id}/verify") @router.get("/{worklog_id}/verify")
# Define function verify_integrity
def verify_integrity( def verify_integrity(
# Entry: worklog_id
worklog_id: UUID, worklog_id: UUID,
# Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: _user
_user: User = Depends(get_current_user), _user: User = Depends(get_current_user),
): ) -> dict:
"""Check whether a worklog's integrity hash is still valid.""" """Check whether a worklog's integrity hash is still valid.
Args:
worklog_id (UUID): Primary key of the worklog to verify.
db (Session): SQLAlchemy database session.
_user (User): Authenticated user making the request (unused, enforces auth).
Returns:
dict: Contains ``worklog_id`` (str) and ``integrity_valid`` (bool).
"""
# Assign wl = worklog_service.get_worklog_or_raise(db, worklog_id)
wl = worklog_service.get_worklog_or_raise(db, worklog_id) wl = worklog_service.get_worklog_or_raise(db, worklog_id)
# Return {
return { return {
# Literal argument value
"worklog_id": str(wl.id), "worklog_id": str(wl.id),
# Literal argument value
"integrity_valid": worklog_service.verify_worklog_integrity(wl), "integrity_valid": worklog_service.verify_worklog_integrity(wl),
} }
+30 -8
View File
@@ -1,7 +1,12 @@
"""Pydantic schemas — re-exported for convenient imports.""" """Pydantic schemas — re-exported for convenient imports."""
# Import LoginRequest, TokenResponse, UserOut from app.schemas.auth
from app.schemas.auth import LoginRequest, TokenResponse, UserOut from app.schemas.auth import LoginRequest, TokenResponse, UserOut
# Import EvidenceOut, EvidenceUpload from app.schemas.evidence
from app.schemas.evidence import EvidenceOut, EvidenceUpload
# Import from app.schemas.technique
from app.schemas.technique import ( from app.schemas.technique import (
TechniqueCreate, TechniqueCreate,
TechniqueOut, TechniqueOut,
@@ -9,51 +14,68 @@ from app.schemas.technique import (
TechniqueUpdate, TechniqueUpdate,
) )
# Import from app.schemas.test
from app.schemas.test import ( from app.schemas.test import (
TestBlueUpdate,
TestBlueValidate,
TestCreate, TestCreate,
TestOut, TestOut,
TestRedUpdate,
TestRedValidate,
TestUpdate, TestUpdate,
TestValidate, TestValidate,
TestRedUpdate,
TestBlueUpdate,
TestRedValidate,
TestBlueValidate,
) )
from app.schemas.evidence import EvidenceOut, EvidenceUpload # Import from app.schemas.test_template
from app.schemas.test_template import ( from app.schemas.test_template import (
TestTemplateOut,
TestTemplateCreate, TestTemplateCreate,
TestTemplateSummary,
TestTemplateInstantiate, TestTemplateInstantiate,
TestTemplateOut,
TestTemplateSummary,
) )
# Assign __all__ = [
__all__ = [ __all__ = [
# Auth # Auth
"LoginRequest", "LoginRequest",
# Literal argument value
"TokenResponse", "TokenResponse",
# Literal argument value
"UserOut", "UserOut",
# Technique # Technique
"TechniqueCreate", "TechniqueCreate",
# Literal argument value
"TechniqueOut", "TechniqueOut",
# Literal argument value
"TechniqueSummary", "TechniqueSummary",
# Literal argument value
"TechniqueUpdate", "TechniqueUpdate",
# Test # Test
"TestCreate", "TestCreate",
# Literal argument value
"TestOut", "TestOut",
# Literal argument value
"TestUpdate", "TestUpdate",
# Literal argument value
"TestValidate", "TestValidate",
# Literal argument value
"TestRedUpdate", "TestRedUpdate",
# Literal argument value
"TestBlueUpdate", "TestBlueUpdate",
# Literal argument value
"TestRedValidate", "TestRedValidate",
# Literal argument value
"TestBlueValidate", "TestBlueValidate",
# Evidence # Evidence
"EvidenceOut", "EvidenceOut",
# Literal argument value
"EvidenceUpload", "EvidenceUpload",
# Test Template # Test Template
"TestTemplateOut", "TestTemplateOut",
# Literal argument value
"TestTemplateCreate", "TestTemplateCreate",
# Literal argument value
"TestTemplateSummary", "TestTemplateSummary",
# Literal argument value
"TestTemplateInstantiate", "TestTemplateInstantiate",
] ]
+21
View File
@@ -1,31 +1,52 @@
"""Pydantic schemas for Audit Log endpoints.""" """Pydantic schemas for Audit Log endpoints."""
# Import uuid
import uuid import uuid
# Import datetime from datetime
from datetime import datetime from datetime import datetime
# Import Any from typing
from typing import Any from typing import Any
# Import BaseModel, ConfigDict from pydantic
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
# Define class AuditLogOut
class AuditLogOut(BaseModel): class AuditLogOut(BaseModel):
"""Complete representation of an audit log entry.""" """Complete representation of an audit log entry."""
# id: uuid.UUID
id: uuid.UUID id: uuid.UUID
# Assign user_id = None
user_id: uuid.UUID | None = None user_id: uuid.UUID | None = None
# Assign username = None # Populated from user relationship
username: str | None = None # Populated from user relationship username: str | None = None # Populated from user relationship
# action: str
action: str action: str
# Assign entity_type = None
entity_type: str | None = None entity_type: str | None = None
# Assign entity_id = None
entity_id: str | None = None entity_id: str | None = None
# timestamp: datetime
timestamp: datetime timestamp: datetime
# Assign details = None
details: dict[str, Any] | None = None details: dict[str, Any] | None = None
# Assign model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
# Define class AuditLogPage
class AuditLogPage(BaseModel): class AuditLogPage(BaseModel):
"""Paginated response for audit logs.""" """Paginated response for audit logs."""
# items: list[AuditLogOut]
items: list[AuditLogOut] items: list[AuditLogOut]
# total: int
total: int total: int
# offset: int
offset: int offset: int
# limit: int
limit: int limit: int
+24 -2
View File
@@ -1,34 +1,56 @@
"""Pydantic schemas for authentication endpoints.""" """Pydantic schemas for authentication endpoints."""
# Import uuid
import uuid import uuid
# Import BaseModel from pydantic
from pydantic import BaseModel from pydantic import BaseModel
# Define class LoginRequest
class LoginRequest(BaseModel): class LoginRequest(BaseModel):
"""Body for the login endpoint (unused directly — we rely on """Body for the login endpoint.
``OAuth2PasswordRequestForm``, but kept for documentation / testing)."""
Unused directly we rely on ``OAuth2PasswordRequestForm``, but kept for
documentation and testing purposes.
"""
# username: str
username: str username: str
# password: str
password: str password: str
# Define class TokenResponse
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
"""Response returned after a successful login.""" """Response returned after a successful login."""
# access_token: str
access_token: str access_token: str
# Assign token_type = "bearer"
token_type: str = "bearer" token_type: str = "bearer"
# Define class UserOut
class UserOut(BaseModel): class UserOut(BaseModel):
"""Public representation of a user (no password hash).""" """Public representation of a user (no password hash)."""
# id: uuid.UUID
id: uuid.UUID id: uuid.UUID
# username: str
username: str username: str
# Assign email = None
email: str | None = None email: str | None = None
# role: str
role: str role: str
# is_active: bool
is_active: bool is_active: bool
# Assign must_change_password = True
must_change_password: bool = True must_change_password: bool = True
# Define class Config
class Config: class Config:
"""ORM mode configuration for SQLAlchemy model mapping."""
# Assign from_attributes = True
from_attributes = True from_attributes = True
+19
View File
@@ -1,34 +1,53 @@
"""Pydantic schemas for Evidence endpoints.""" """Pydantic schemas for Evidence endpoints."""
# Import uuid
import uuid import uuid
# Import datetime from datetime
from datetime import datetime from datetime import datetime
# Import BaseModel, ConfigDict from pydantic
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
# Import TeamSide from app.models.enums
from app.models.enums import TeamSide from app.models.enums import TeamSide
# Define class EvidenceOut
class EvidenceOut(BaseModel): class EvidenceOut(BaseModel):
"""Representation of an evidence record returned by the API. """Representation of an evidence record returned by the API.
``download_url`` is a presigned URL generated at response time. ``download_url`` is a presigned URL generated at response time.
""" """
# id: uuid.UUID
id: uuid.UUID id: uuid.UUID
# test_id: uuid.UUID
test_id: uuid.UUID test_id: uuid.UUID
# file_name: str
file_name: str file_name: str
# sha256_hash: str
sha256_hash: str sha256_hash: str
# Assign uploaded_by = None
uploaded_by: uuid.UUID | None = None uploaded_by: uuid.UUID | None = None
# Assign uploaded_at = None
uploaded_at: datetime | None = None uploaded_at: datetime | None = None
# Assign team = TeamSide.red
team: TeamSide = TeamSide.red team: TeamSide = TeamSide.red
# Assign notes = None
notes: str | None = None notes: str | None = None
# Assign download_url = None
download_url: str | None = None download_url: str | None = None
# Assign model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
# Define class EvidenceUpload
class EvidenceUpload(BaseModel): class EvidenceUpload(BaseModel):
"""Metadata sent alongside an evidence file upload.""" """Metadata sent alongside an evidence file upload."""
# team: TeamSide
team: TeamSide team: TeamSide
# Assign notes = None
notes: str | None = None notes: str | None = None
+45
View File
@@ -1,46 +1,91 @@
"""Pydantic schemas for Jira integration endpoints.""" """Pydantic schemas for Jira integration endpoints."""
# Import datetime from datetime
from datetime import datetime from datetime import datetime
# Import Optional from typing
from typing import Optional from typing import Optional
# Import UUID from uuid
from uuid import UUID from uuid import UUID
# Import BaseModel, Field from pydantic
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# Import JiraLinkEntityType, JiraSyncDirection from app.models.jira_link
from app.models.jira_link import JiraLinkEntityType, JiraSyncDirection from app.models.jira_link import JiraLinkEntityType, JiraSyncDirection
# Define class JiraLinkCreate
class JiraLinkCreate(BaseModel): class JiraLinkCreate(BaseModel):
"""Payload for linking an Aegis entity to an existing Jira issue."""
# entity_type: JiraLinkEntityType
entity_type: JiraLinkEntityType entity_type: JiraLinkEntityType
# entity_id: UUID
entity_id: UUID entity_id: UUID
# Assign jira_issue_key = Field(..., pattern=r"^[A-Z][A-Z0-9]+-\d+$")
jira_issue_key: str = Field(..., pattern=r"^[A-Z][A-Z0-9]+-\d+$") jira_issue_key: str = Field(..., pattern=r"^[A-Z][A-Z0-9]+-\d+$")
# Assign sync_direction = JiraSyncDirection.bidirectional
sync_direction: JiraSyncDirection = JiraSyncDirection.bidirectional sync_direction: JiraSyncDirection = JiraSyncDirection.bidirectional
# Define class JiraLinkOut
class JiraLinkOut(BaseModel): class JiraLinkOut(BaseModel):
"""Full representation of a Jira link returned by the API."""
# id: UUID
id: UUID id: UUID
# entity_type: JiraLinkEntityType
entity_type: JiraLinkEntityType entity_type: JiraLinkEntityType
# entity_id: UUID
entity_id: UUID entity_id: UUID
# jira_issue_key: str
jira_issue_key: str jira_issue_key: str
# Assign jira_issue_id = None
jira_issue_id: Optional[str] = None jira_issue_id: Optional[str] = None
# Assign jira_project_key = None
jira_project_key: Optional[str] = None jira_project_key: Optional[str] = None
# Assign jira_status = None
jira_status: Optional[str] = None jira_status: Optional[str] = None
# Assign jira_priority = None
jira_priority: Optional[str] = None jira_priority: Optional[str] = None
# Assign jira_assignee = None
jira_assignee: Optional[str] = None jira_assignee: Optional[str] = None
# Assign jira_story_points = None
jira_story_points: Optional[str] = None jira_story_points: Optional[str] = None
# Assign last_synced_at = None
last_synced_at: Optional[datetime] = None last_synced_at: Optional[datetime] = None
# created_at: datetime
created_at: datetime created_at: datetime
# Define class Config
class Config: class Config:
"""ORM mode configuration for SQLAlchemy model mapping."""
# Assign from_attributes = True
from_attributes = True from_attributes = True
# Define class JiraIssueSearch
class JiraIssueSearch(BaseModel): class JiraIssueSearch(BaseModel):
"""Payload for searching Jira issues by free-text query."""
# query: str
query: str query: str
# Define class JiraIssueResult
class JiraIssueResult(BaseModel): class JiraIssueResult(BaseModel):
"""Lightweight Jira issue representation returned by search results."""
# issue_key: str
issue_key: str issue_key: str
# summary: str
summary: str summary: str
# status: str
status: str status: str
# Assign assignee = None
assignee: Optional[str] = None assignee: Optional[str] = None
# Assign priority = None
priority: Optional[str] = None priority: Optional[str] = None

Some files were not shown because too many files have changed in this diff Show More