Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 64cc438bcc | |||
| 8fea0c1ada | |||
| 1f19bd8432 | |||
| d2a46feba8 | |||
| 9ff0f04ba3 | |||
| 8f98bdd273 |
@@ -0,0 +1 @@
|
|||||||
|
"""Aegis — MITRE ATT&CK Coverage Platform application package."""
|
||||||
|
|||||||
+40
-4
@@ -1,23 +1,32 @@
|
|||||||
"""
|
"""Security utilities: password hashing and JWT token management.
|
||||||
Security utilities: password hashing and JWT token management.
|
|
||||||
|
|
||||||
This module provides pure functions for:
|
This module provides pure functions for:
|
||||||
- Hashing and verifying passwords using bcrypt via passlib.
|
- Hashing and verifying passwords using bcrypt via passlib.
|
||||||
- Creating JWT access tokens using python-jose.
|
- Creating JWT access tokens using PyJWT.
|
||||||
- Managing a Redis-backed token blacklist for revocation.
|
- Managing a Redis-backed token blacklist for revocation.
|
||||||
|
|
||||||
No endpoints are defined here.
|
No endpoints are defined here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
|
|
||||||
|
# Import datetime, timedelta, timezone from datetime
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from jose import jwt
|
# Import jwt (PyJWT)
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
# Import CryptContext from passlib.context
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
# Import settings from app.config
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -27,13 +36,17 @@ logger = logging.getLogger(__name__)
|
|||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
|
# Define function hash_password
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
"""Return a bcrypt hash of *password*."""
|
"""Return a bcrypt hash of *password*."""
|
||||||
|
# Return pwd_context.hash(password)
|
||||||
return pwd_context.hash(password)
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
# Define function verify_password
|
||||||
def verify_password(plain: str, hashed: str) -> bool:
|
def verify_password(plain: str, hashed: str) -> bool:
|
||||||
"""Return ``True`` if *plain* matches the bcrypt *hashed* value."""
|
"""Return ``True`` if *plain* matches the bcrypt *hashed* value."""
|
||||||
|
# Return pwd_context.verify(plain, hashed)
|
||||||
return pwd_context.verify(plain, hashed)
|
return pwd_context.verify(plain, hashed)
|
||||||
|
|
||||||
|
|
||||||
@@ -48,14 +61,21 @@ def create_access_token(data: dict) -> str:
|
|||||||
- ``jti`` (JWT ID): unique identifier that enables token revocation.
|
- ``jti`` (JWT ID): unique identifier that enables token revocation.
|
||||||
- ``exp``: expiration timestamp based on ``ACCESS_TOKEN_EXPIRE_MINUTES``.
|
- ``exp``: expiration timestamp based on ``ACCESS_TOKEN_EXPIRE_MINUTES``.
|
||||||
"""
|
"""
|
||||||
|
# Assign to_encode = data.copy()
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
|
# Assign expire = datetime.now(timezone.utc) + timedelta(
|
||||||
expire = datetime.now(timezone.utc) + timedelta(
|
expire = datetime.now(timezone.utc) + timedelta(
|
||||||
|
# Keyword argument: minutes
|
||||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
|
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||||
)
|
)
|
||||||
|
# Call to_encode.update()
|
||||||
to_encode.update({
|
to_encode.update({
|
||||||
|
# Literal argument value
|
||||||
"exp": expire,
|
"exp": expire,
|
||||||
|
# Literal argument value
|
||||||
"jti": str(_uuid.uuid4()),
|
"jti": str(_uuid.uuid4()),
|
||||||
})
|
})
|
||||||
|
# Return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGOR...
|
||||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
@@ -73,6 +93,7 @@ def create_access_token(data: dict) -> str:
|
|||||||
_BLACKLIST_PREFIX = "blacklist:"
|
_BLACKLIST_PREFIX = "blacklist:"
|
||||||
|
|
||||||
|
|
||||||
|
# Define function blacklist_token
|
||||||
def blacklist_token(jti: str, exp: float) -> None:
|
def blacklist_token(jti: str, exp: float) -> None:
|
||||||
"""Add *jti* to the Redis blacklist with a TTL derived from *exp*.
|
"""Add *jti* to the Redis blacklist with a TTL derived from *exp*.
|
||||||
|
|
||||||
@@ -80,23 +101,38 @@ def blacklist_token(jti: str, exp: float) -> None:
|
|||||||
to ``exp - now`` so the key vanishes when the token would have expired
|
to ``exp - now`` so the key vanishes when the token would have expired
|
||||||
naturally.
|
naturally.
|
||||||
"""
|
"""
|
||||||
|
# Import get_redis_blacklist from app.infrastructure.redis_client
|
||||||
from app.infrastructure.redis_client import get_redis_blacklist
|
from app.infrastructure.redis_client import get_redis_blacklist
|
||||||
|
|
||||||
|
# Assign ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
|
||||||
ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
|
ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign r = get_redis_blacklist()
|
||||||
r = get_redis_blacklist()
|
r = get_redis_blacklist()
|
||||||
|
# Call r.setex()
|
||||||
r.setex(f"{_BLACKLIST_PREFIX}{jti}", ttl, "1")
|
r.setex(f"{_BLACKLIST_PREFIX}{jti}", ttl, "1")
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log warning: "Failed to blacklist token %s in Redis", jti, exc_
|
||||||
logger.warning("Failed to blacklist token %s in Redis", jti, exc_info=True)
|
logger.warning("Failed to blacklist token %s in Redis", jti, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Define function is_token_blacklisted
|
||||||
def is_token_blacklisted(jti: str) -> bool:
|
def is_token_blacklisted(jti: str) -> bool:
|
||||||
"""Return ``True`` if *jti* has been revoked (exists in Redis)."""
|
"""Return ``True`` if *jti* has been revoked (exists in Redis)."""
|
||||||
|
# Import get_redis_blacklist from app.infrastructure.redis_client
|
||||||
from app.infrastructure.redis_client import get_redis_blacklist
|
from app.infrastructure.redis_client import get_redis_blacklist
|
||||||
|
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign r = get_redis_blacklist()
|
||||||
r = get_redis_blacklist()
|
r = get_redis_blacklist()
|
||||||
|
# Return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0
|
||||||
return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0
|
return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log warning: "Failed to check blacklist for %s in Redis", jti,
|
||||||
logger.warning("Failed to check blacklist for %s in Redis", jti, exc_info=True)
|
logger.warning("Failed to check blacklist for %s in Redis", jti, exc_info=True)
|
||||||
|
# Return False
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,7 +1,21 @@
|
|||||||
|
"""Application configuration for the Aegis MITRE ATT&CK Coverage Platform.
|
||||||
|
|
||||||
|
Loads settings from environment variables and ``.env`` files via
|
||||||
|
``pydantic-settings``. Validates critical secrets at import time and raises
|
||||||
|
``RuntimeError`` (production) or issues a ``UserWarning`` (development) when
|
||||||
|
unsafe defaults are detected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Import os
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
# Import secrets
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
|
# Import warnings
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
# Import BaseSettings from pydantic_settings
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -10,7 +24,11 @@ from pydantic_settings import BaseSettings
|
|||||||
_is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
_is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||||
|
|
||||||
|
|
||||||
|
# Define class Settings
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
"""Application settings loaded from environment variables and .env file."""
|
||||||
|
|
||||||
|
# Assign DATABASE_URL = "postgresql://postgres:postgres@postgres:5432/attackdb"
|
||||||
DATABASE_URL: str = "postgresql://postgres:postgres@postgres:5432/attackdb"
|
DATABASE_URL: str = "postgresql://postgres:postgres@postgres:5432/attackdb"
|
||||||
|
|
||||||
# ── Security ──────────────────────────────────────────────────────
|
# ── Security ──────────────────────────────────────────────────────
|
||||||
@@ -19,6 +37,7 @@ class Settings(BaseSettings):
|
|||||||
# for local dev). In production it MUST be supplied via env/.env
|
# for local dev). In production it MUST be supplied via env/.env
|
||||||
# so tokens survive restarts.
|
# so tokens survive restarts.
|
||||||
SECRET_KEY: str = ""
|
SECRET_KEY: str = ""
|
||||||
|
# Assign ALGORITHM = "HS256"
|
||||||
ALGORITHM: str = "HS256"
|
ALGORITHM: str = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 480 # 8 hours — /auth/refresh extends active sessions
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = 480 # 8 hours — /auth/refresh extends active sessions
|
||||||
|
|
||||||
@@ -26,6 +45,7 @@ class Settings(BaseSettings):
|
|||||||
REDIS_URL: str = "redis://redis:6379/0"
|
REDIS_URL: str = "redis://redis:6379/0"
|
||||||
# Logical DB indices on the same Redis instance (PATH in URL is overridden).
|
# Logical DB indices on the same Redis instance (PATH in URL is overridden).
|
||||||
REDIS_TOKEN_BLACKLIST_DB: int = 1
|
REDIS_TOKEN_BLACKLIST_DB: int = 1
|
||||||
|
# Assign REDIS_CACHE_DB = 2
|
||||||
REDIS_CACHE_DB: int = 2
|
REDIS_CACHE_DB: int = 2
|
||||||
|
|
||||||
# ── CORS ─────────────────────────────────────────────────────────
|
# ── CORS ─────────────────────────────────────────────────────────
|
||||||
@@ -41,8 +61,11 @@ class Settings(BaseSettings):
|
|||||||
# the browser can reach MinIO directly. Defaults to MINIO_ENDPOINT.
|
# the browser can reach MinIO directly. Defaults to MINIO_ENDPOINT.
|
||||||
MINIO_PUBLIC_ENDPOINT: str = ""
|
MINIO_PUBLIC_ENDPOINT: str = ""
|
||||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||||
|
# Assign MINIO_SECRET_KEY = "minioadmin"
|
||||||
MINIO_SECRET_KEY: str = "minioadmin"
|
MINIO_SECRET_KEY: str = "minioadmin"
|
||||||
|
# Assign MINIO_BUCKET = "evidence"
|
||||||
MINIO_BUCKET: str = "evidence"
|
MINIO_BUCKET: str = "evidence"
|
||||||
|
# Assign MINIO_SECURE = False # True → use HTTPS to connect to MinIO
|
||||||
MINIO_SECURE: bool = False # True → use HTTPS to connect to MinIO
|
MINIO_SECURE: bool = False # True → use HTTPS to connect to MinIO
|
||||||
|
|
||||||
# ── Re-testing ───────────────────────────────────────────────────
|
# ── Re-testing ───────────────────────────────────────────────────
|
||||||
@@ -50,10 +73,15 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# ── Jira Integration ────────────────────────────────────────────
|
# ── Jira Integration ────────────────────────────────────────────
|
||||||
JIRA_ENABLED: bool = False
|
JIRA_ENABLED: bool = False
|
||||||
|
# Assign JIRA_URL = ""
|
||||||
JIRA_URL: str = ""
|
JIRA_URL: str = ""
|
||||||
|
# Assign JIRA_USERNAME = ""
|
||||||
JIRA_USERNAME: str = ""
|
JIRA_USERNAME: str = ""
|
||||||
|
# Assign JIRA_API_TOKEN = ""
|
||||||
JIRA_API_TOKEN: str = ""
|
JIRA_API_TOKEN: str = ""
|
||||||
|
# Assign JIRA_IS_CLOUD = True
|
||||||
JIRA_IS_CLOUD: bool = True
|
JIRA_IS_CLOUD: bool = True
|
||||||
|
# Assign JIRA_DEFAULT_PROJECT = ""
|
||||||
JIRA_DEFAULT_PROJECT: str = ""
|
JIRA_DEFAULT_PROJECT: str = ""
|
||||||
JIRA_ISSUE_TYPE_TEST: str = "Task" # tests (campaign or standalone)
|
JIRA_ISSUE_TYPE_TEST: str = "Task" # tests (campaign or standalone)
|
||||||
JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic" # campaigns (under Initiative)
|
JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic" # campaigns (under Initiative)
|
||||||
@@ -63,8 +91,11 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# ── Tempo Integration ─────────────────────────────────────────────
|
# ── Tempo Integration ─────────────────────────────────────────────
|
||||||
TEMPO_ENABLED: bool = False
|
TEMPO_ENABLED: bool = False
|
||||||
|
# Assign TEMPO_API_TOKEN = ""
|
||||||
TEMPO_API_TOKEN: str = ""
|
TEMPO_API_TOKEN: str = ""
|
||||||
|
# Assign TEMPO_API_VERSION = 4
|
||||||
TEMPO_API_VERSION: int = 4
|
TEMPO_API_VERSION: int = 4
|
||||||
|
# Assign TEMPO_DEFAULT_WORK_TYPE = "Red Team"
|
||||||
TEMPO_DEFAULT_WORK_TYPE: str = "Red Team"
|
TEMPO_DEFAULT_WORK_TYPE: str = "Red Team"
|
||||||
# Tempo API base URL — use https://api.eu.tempo.io/4 for EU workspaces.
|
# Tempo API base URL — use https://api.eu.tempo.io/4 for EU workspaces.
|
||||||
# Can also be set via system_configs key "tempo.base_url" at runtime.
|
# Can also be set via system_configs key "tempo.base_url" at runtime.
|
||||||
@@ -72,12 +103,16 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# ── OSINT / Intelligence ────────────────────────────────────────
|
# ── OSINT / Intelligence ────────────────────────────────────────
|
||||||
NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s
|
NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s
|
||||||
|
# Assign STALE_THRESHOLD_DAYS = 365 # days before coverage is considered stale
|
||||||
STALE_THRESHOLD_DAYS: int = 365 # days before coverage is considered stale
|
STALE_THRESHOLD_DAYS: int = 365 # days before coverage is considered stale
|
||||||
|
|
||||||
# ── Reporting ─────────────────────────────────────────────────────
|
# ── Reporting ─────────────────────────────────────────────────────
|
||||||
REPORT_TEMPLATES_DIR: str = "app/templates/reports"
|
REPORT_TEMPLATES_DIR: str = "app/templates/reports"
|
||||||
|
# Assign REPORT_OUTPUT_DIR = "/tmp/aegis_reports"
|
||||||
REPORT_OUTPUT_DIR: str = "/tmp/aegis_reports"
|
REPORT_OUTPUT_DIR: str = "/tmp/aegis_reports"
|
||||||
|
# Assign COMPANY_NAME = "Organization"
|
||||||
COMPANY_NAME: str = "Organization"
|
COMPANY_NAME: str = "Organization"
|
||||||
|
# Assign COMPANY_LOGO_PATH = "app/templates/reports/assets/logo.png"
|
||||||
COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png"
|
COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png"
|
||||||
|
|
||||||
# ── Email / SMTP ──────────────────────────────────────────────────
|
# ── Email / SMTP ──────────────────────────────────────────────────
|
||||||
@@ -92,43 +127,68 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# ── Scoring weights (must sum to 100) ────────────────────────────
|
# ── Scoring weights (must sum to 100) ────────────────────────────
|
||||||
SCORING_WEIGHT_TESTS: int = 40
|
SCORING_WEIGHT_TESTS: int = 40
|
||||||
|
# Assign SCORING_WEIGHT_DETECTION_RULES = 25
|
||||||
SCORING_WEIGHT_DETECTION_RULES: int = 25
|
SCORING_WEIGHT_DETECTION_RULES: int = 25
|
||||||
|
# Assign SCORING_WEIGHT_D3FEND = 15
|
||||||
SCORING_WEIGHT_D3FEND: int = 15
|
SCORING_WEIGHT_D3FEND: int = 15
|
||||||
|
# Assign SCORING_WEIGHT_RECENCY = 10
|
||||||
SCORING_WEIGHT_RECENCY: int = 10
|
SCORING_WEIGHT_RECENCY: int = 10
|
||||||
|
# Assign SCORING_WEIGHT_SEVERITY = 10
|
||||||
SCORING_WEIGHT_SEVERITY: int = 10
|
SCORING_WEIGHT_SEVERITY: int = 10
|
||||||
# Legacy env names (mapped in scoring_config_service)
|
# Legacy env names (mapped in scoring_config_service)
|
||||||
SCORING_WEIGHT_FRESHNESS: int = 10
|
SCORING_WEIGHT_FRESHNESS: int = 10
|
||||||
|
# Assign SCORING_WEIGHT_PLATFORM_DIVERSITY = 10
|
||||||
SCORING_WEIGHT_PLATFORM_DIVERSITY: int = 10
|
SCORING_WEIGHT_PLATFORM_DIVERSITY: int = 10
|
||||||
|
|
||||||
|
# Define class Config
|
||||||
class Config:
|
class Config:
|
||||||
|
"""Pydantic BaseSettings configuration — load from .env file."""
|
||||||
|
|
||||||
|
# Assign env_file = ".env"
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|
||||||
|
|
||||||
|
# Assign settings = Settings()
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Post-init validation for SECRET_KEY
|
# Post-init validation for SECRET_KEY
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
_UNSAFE_SECRETS = {
|
_UNSAFE_SECRETS = {
|
||||||
|
# Literal argument value
|
||||||
"",
|
"",
|
||||||
|
# Literal argument value
|
||||||
"change-me-in-production",
|
"change-me-in-production",
|
||||||
|
# Literal argument value
|
||||||
"change-me-in-production-use-a-long-random-string",
|
"change-me-in-production-use-a-long-random-string",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Check: settings.SECRET_KEY in _UNSAFE_SECRETS
|
||||||
if settings.SECRET_KEY in _UNSAFE_SECRETS:
|
if settings.SECRET_KEY in _UNSAFE_SECRETS:
|
||||||
|
# Check: _is_production
|
||||||
if _is_production:
|
if _is_production:
|
||||||
|
# Raise RuntimeError
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
# Literal argument value
|
||||||
"CRITICAL: SECRET_KEY is not configured. "
|
"CRITICAL: SECRET_KEY is not configured. "
|
||||||
|
# Literal argument value
|
||||||
"Set a strong random value (>= 32 chars) via the SECRET_KEY "
|
"Set a strong random value (>= 32 chars) via the SECRET_KEY "
|
||||||
|
# Literal argument value
|
||||||
"environment variable or in your .env file before running in "
|
"environment variable or in your .env file before running in "
|
||||||
|
# Literal argument value
|
||||||
"production. Example: openssl rand -hex 32"
|
"production. Example: openssl rand -hex 32"
|
||||||
)
|
)
|
||||||
# Development: auto-generate an ephemeral key and warn
|
# Development: auto-generate an ephemeral key and warn
|
||||||
settings.SECRET_KEY = secrets.token_hex(32)
|
settings.SECRET_KEY = secrets.token_hex(32)
|
||||||
|
# Call warnings.warn()
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
# Literal argument value
|
||||||
"SECRET_KEY was not set — using an auto-generated ephemeral key. "
|
"SECRET_KEY was not set — using an auto-generated ephemeral key. "
|
||||||
|
# Literal argument value
|
||||||
"JWT tokens will be invalidated on every restart. "
|
"JWT tokens will be invalidated on every restart. "
|
||||||
|
# Literal argument value
|
||||||
"Set SECRET_KEY in your environment for persistent sessions.",
|
"Set SECRET_KEY in your environment for persistent sessions.",
|
||||||
|
# Keyword argument: stacklevel
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -136,12 +196,16 @@ if settings.SECRET_KEY in _UNSAFE_SECRETS:
|
|||||||
# SEC-002: Reject default credentials in production
|
# SEC-002: Reject default credentials in production
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
if _is_production:
|
if _is_production:
|
||||||
|
# Assign _DEFAULT_CREDS = {
|
||||||
_DEFAULT_CREDS = {
|
_DEFAULT_CREDS = {
|
||||||
("MINIO_ACCESS_KEY", settings.MINIO_ACCESS_KEY, "minioadmin"),
|
("MINIO_ACCESS_KEY", settings.MINIO_ACCESS_KEY, "minioadmin"),
|
||||||
("MINIO_SECRET_KEY", settings.MINIO_SECRET_KEY, "minioadmin"),
|
("MINIO_SECRET_KEY", settings.MINIO_SECRET_KEY, "minioadmin"),
|
||||||
}
|
}
|
||||||
|
# Iterate over _DEFAULT_CREDS
|
||||||
for name, current, default in _DEFAULT_CREDS:
|
for name, current, default in _DEFAULT_CREDS:
|
||||||
|
# Check: current == default
|
||||||
if current == default:
|
if current == default:
|
||||||
|
# Raise RuntimeError
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"CRITICAL: {name} is using the default value '{default}'. "
|
f"CRITICAL: {name} is using the default value '{default}'. "
|
||||||
f"Set a strong value via the {name} environment variable "
|
f"Set a strong value via the {name} environment variable "
|
||||||
|
|||||||
+106
-10
@@ -1,68 +1,164 @@
|
|||||||
from sqlalchemy import create_engine
|
"""Database engine and session management for the Aegis platform.
|
||||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
|
||||||
|
|
||||||
|
The engine and session factory are created lazily so that tests can override
|
||||||
|
``DATABASE_URL`` via environment variables before any import triggers real
|
||||||
|
PostgreSQL engine creation (which requires psycopg2).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Import Generator from collections.abc
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
# Import create_engine from sqlalchemy
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
# Import Engine from sqlalchemy.engine
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
# Import Session, declarative_base, sessionmaker from sqlalchemy.orm
|
||||||
|
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||||
|
|
||||||
|
# Assign Base = declarative_base()
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
# Engine and session factory are created lazily so that tests can
|
# Engine and session factory are created lazily so that tests can
|
||||||
# override DATABASE_URL via environment *before* any import triggers
|
# override DATABASE_URL via environment *before* any import triggers
|
||||||
# the real PostgreSQL engine creation (which requires psycopg2).
|
# the real PostgreSQL engine creation (which requires psycopg2).
|
||||||
_engine = None
|
_engine = None
|
||||||
|
# Assign _SessionLocal = None
|
||||||
_SessionLocal = None
|
_SessionLocal = None
|
||||||
|
|
||||||
|
|
||||||
def _get_engine():
|
# Define function _get_engine
|
||||||
|
def _get_engine() -> Engine:
|
||||||
|
"""Return the shared SQLAlchemy engine, creating it on first call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Engine: Configured SQLAlchemy engine for the application database.
|
||||||
|
"""
|
||||||
|
# Declare global variable
|
||||||
global _engine
|
global _engine
|
||||||
|
# Check: _engine is None
|
||||||
if _engine is None:
|
if _engine is None:
|
||||||
|
# Import settings from app.config
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
|
# Assign url = settings.DATABASE_URL
|
||||||
url = settings.DATABASE_URL
|
url = settings.DATABASE_URL
|
||||||
|
# Assign kwargs = {}
|
||||||
kwargs: dict = {}
|
kwargs: dict = {}
|
||||||
|
# Check: url.startswith("postgresql")
|
||||||
if url.startswith("postgresql"):
|
if url.startswith("postgresql"):
|
||||||
|
# Call kwargs.update()
|
||||||
kwargs.update(
|
kwargs.update(
|
||||||
|
# Keyword argument: pool_size
|
||||||
pool_size=20,
|
pool_size=20,
|
||||||
|
# Keyword argument: max_overflow
|
||||||
max_overflow=10,
|
max_overflow=10,
|
||||||
|
# Keyword argument: pool_recycle
|
||||||
pool_recycle=3600,
|
pool_recycle=3600,
|
||||||
|
# Keyword argument: pool_pre_ping
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
)
|
)
|
||||||
|
# Assign _engine = create_engine(url, **kwargs)
|
||||||
_engine = create_engine(url, **kwargs)
|
_engine = create_engine(url, **kwargs)
|
||||||
|
# Return _engine
|
||||||
return _engine
|
return _engine
|
||||||
|
|
||||||
|
|
||||||
def _get_session_factory():
|
# Define function _get_session_factory
|
||||||
|
def _get_session_factory() -> sessionmaker:
|
||||||
|
"""Return the shared sessionmaker, creating it on first call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sessionmaker: Configured sessionmaker bound to the application engine.
|
||||||
|
"""
|
||||||
|
# Declare global variable
|
||||||
global _SessionLocal
|
global _SessionLocal
|
||||||
|
# Check: _SessionLocal is None
|
||||||
if _SessionLocal is None:
|
if _SessionLocal is None:
|
||||||
|
# Assign _SessionLocal = sessionmaker(
|
||||||
_SessionLocal = sessionmaker(
|
_SessionLocal = sessionmaker(
|
||||||
|
# Keyword argument: autocommit
|
||||||
autocommit=False, autoflush=False, bind=_get_engine()
|
autocommit=False, autoflush=False, bind=_get_engine()
|
||||||
)
|
)
|
||||||
|
# Return _SessionLocal
|
||||||
return _SessionLocal
|
return _SessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
# Define class _LazySessionLocal
|
||||||
class _LazySessionLocal:
|
class _LazySessionLocal:
|
||||||
"""Proxy so ``SessionLocal()`` keeps working as before but the real
|
"""Proxy so ``SessionLocal()`` keeps working as before but the real sessionmaker is only created on first call."""
|
||||||
sessionmaker is only created on first call."""
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
# Define function __call__
|
||||||
|
def __call__(self, *args: object, **kwargs: object) -> Session:
|
||||||
|
"""Create and return a new database session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args (object): Positional arguments forwarded to the sessionmaker.
|
||||||
|
**kwargs (object): Keyword arguments forwarded to the sessionmaker.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Session: A new SQLAlchemy database session.
|
||||||
|
"""
|
||||||
|
# Return _get_session_factory()(*args, **kwargs)
|
||||||
return _get_session_factory()(*args, **kwargs)
|
return _get_session_factory()(*args, **kwargs)
|
||||||
|
|
||||||
def __getattr__(self, name):
|
# Define function __getattr__
|
||||||
|
def __getattr__(self, name: str) -> object:
|
||||||
|
"""Delegate attribute access to the underlying sessionmaker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): Attribute name to look up on the sessionmaker.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: The attribute value from the underlying sessionmaker.
|
||||||
|
"""
|
||||||
|
# Return getattr(_get_session_factory(), name)
|
||||||
return getattr(_get_session_factory(), name)
|
return getattr(_get_session_factory(), name)
|
||||||
|
|
||||||
|
|
||||||
|
# Assign SessionLocal = _LazySessionLocal()
|
||||||
SessionLocal = _LazySessionLocal()
|
SessionLocal = _LazySessionLocal()
|
||||||
|
|
||||||
|
|
||||||
|
# Define class _EngineProxy
|
||||||
class _EngineProxy:
|
class _EngineProxy:
|
||||||
"""Thin proxy so ``from app.database import engine`` still works."""
|
"""Thin proxy so ``from app.database import engine`` still works."""
|
||||||
def __getattr__(self, name):
|
|
||||||
|
# Define function __getattr__
|
||||||
|
def __getattr__(self, name: str) -> object:
|
||||||
|
"""Delegate attribute access to the lazily-created engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): Attribute name to look up on the real engine.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: The attribute value from the underlying SQLAlchemy engine.
|
||||||
|
"""
|
||||||
|
# Return getattr(_get_engine(), name)
|
||||||
return getattr(_get_engine(), name)
|
return getattr(_get_engine(), name)
|
||||||
|
|
||||||
|
|
||||||
|
# Assign engine = _EngineProxy() # type: ignore[assignment]
|
||||||
engine = _EngineProxy() # type: ignore[assignment]
|
engine = _EngineProxy() # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
def get_db():
|
# Define function get_db
|
||||||
|
def get_db() -> Generator[Session, None, None]:
|
||||||
|
"""Yield a database session and close it when the request is done.
|
||||||
|
|
||||||
|
Intended for use as a FastAPI dependency.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Session: An active SQLAlchemy session for the current request.
|
||||||
|
"""
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Yield db
|
||||||
yield db
|
yield db
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""FastAPI dependency injection helpers for auth, DB, and shared state."""
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""
|
"""Authentication and RBAC dependencies for FastAPI.
|
||||||
Authentication and RBAC dependencies for FastAPI.
|
|
||||||
|
|
||||||
Provides:
|
Provides:
|
||||||
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
|
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
|
||||||
@@ -9,16 +8,34 @@ Provides:
|
|||||||
(admins always pass).
|
(admins always pass).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import Callable from collections.abc
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import Cookie, Depends, HTTPException, status from fastapi
|
||||||
from fastapi import Cookie, Depends, HTTPException, status
|
from fastapi import Cookie, Depends, HTTPException, status
|
||||||
|
|
||||||
|
# Import OAuth2PasswordBearer from fastapi.security
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
# Import jwt (PyJWT)
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import auth as auth_lib from app
|
||||||
from app import auth as auth_lib
|
from app import auth as auth_lib
|
||||||
|
|
||||||
|
# Import settings from app.config
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.api_key import KEY_PREFIX
|
from app.models.api_key import KEY_PREFIX
|
||||||
|
|
||||||
@@ -37,8 +54,11 @@ _COOKIE_NAME = "aegis_token"
|
|||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
|
# Entry: aegis_token
|
||||||
aegis_token: Optional[str] = Cookie(None),
|
aegis_token: Optional[str] = Cookie(None),
|
||||||
|
# Entry: bearer_token
|
||||||
bearer_token: Optional[str] = Depends(oauth2_scheme),
|
bearer_token: Optional[str] = Depends(oauth2_scheme),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> User:
|
) -> User:
|
||||||
"""Decode the JWT, look up the user in *db*, and return it.
|
"""Decode the JWT, look up the user in *db*, and return it.
|
||||||
@@ -54,20 +74,30 @@ async def get_current_user(
|
|||||||
- the ``sub`` claim is missing, or
|
- the ``sub`` claim is missing, or
|
||||||
- no matching active user exists in the database.
|
- no matching active user exists in the database.
|
||||||
"""
|
"""
|
||||||
|
# Assign credentials_exception = HTTPException(
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
# Keyword argument: detail
|
||||||
detail="Could not validate credentials",
|
detail="Could not validate credentials",
|
||||||
|
# Keyword argument: headers
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
|
# Assign revoked_exception = HTTPException(
|
||||||
revoked_exception = HTTPException(
|
revoked_exception = HTTPException(
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
# Keyword argument: detail
|
||||||
detail="Token has been revoked",
|
detail="Token has been revoked",
|
||||||
|
# Keyword argument: headers
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prefer cookie, fall back to header
|
# Prefer cookie, fall back to header
|
||||||
token = aegis_token or bearer_token
|
token = aegis_token or bearer_token
|
||||||
|
# Check: token is None
|
||||||
if token is None:
|
if token is None:
|
||||||
|
# Raise credentials_exception
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
# ── API Key path (Bearer token starts with "aegis_") ──────────────────
|
# ── API Key path (Bearer token starts with "aegis_") ──────────────────
|
||||||
@@ -80,25 +110,38 @@ async def get_current_user(
|
|||||||
|
|
||||||
# ── JWT path ──────────────────────────────────────────────────────────
|
# ── JWT path ──────────────────────────────────────────────────────────
|
||||||
try:
|
try:
|
||||||
|
# Assign payload = jwt.decode(
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token,
|
token,
|
||||||
settings.SECRET_KEY,
|
settings.SECRET_KEY,
|
||||||
|
# Keyword argument: algorithms
|
||||||
algorithms=[settings.ALGORITHM],
|
algorithms=[settings.ALGORITHM],
|
||||||
)
|
)
|
||||||
|
# Assign username = payload.get("sub")
|
||||||
username: str | None = payload.get("sub")
|
username: str | None = payload.get("sub")
|
||||||
|
# Check: username is None
|
||||||
if username is None:
|
if username is None:
|
||||||
|
# Raise credentials_exception
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
# Check token blacklist (revoked tokens)
|
# Check token blacklist (revoked tokens)
|
||||||
jti: str | None = payload.get("jti")
|
jti: str | None = payload.get("jti")
|
||||||
|
# Check: jti and auth_lib.is_token_blacklisted(jti)
|
||||||
if jti and auth_lib.is_token_blacklisted(jti):
|
if jti and auth_lib.is_token_blacklisted(jti):
|
||||||
|
# Raise revoked_exception
|
||||||
raise revoked_exception
|
raise revoked_exception
|
||||||
except JWTError:
|
# Handle any JWT validation error (expired, invalid signature, malformed)
|
||||||
|
except jwt.exceptions.InvalidTokenError:
|
||||||
|
# Raise credentials_exception
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
|
# Assign user = db.query(User).filter(User.username == username).first()
|
||||||
user = db.query(User).filter(User.username == username).first()
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
# Check: user is None or not user.is_active
|
||||||
if user is None or not user.is_active:
|
if user is None or not user.is_active:
|
||||||
|
# Raise credentials_exception
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
|
# Return user
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@@ -108,6 +151,7 @@ async def get_current_user(
|
|||||||
|
|
||||||
|
|
||||||
async def require_password_changed(
|
async def require_password_changed(
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> User:
|
) -> User:
|
||||||
"""Block all requests when the user still needs to change their password.
|
"""Block all requests when the user still needs to change their password.
|
||||||
@@ -115,11 +159,16 @@ async def require_password_changed(
|
|||||||
Only ``/auth/change-password`` and ``/auth/me`` are exempt — those
|
Only ``/auth/change-password`` and ``/auth/me`` are exempt — those
|
||||||
endpoints do **not** depend on this function.
|
endpoints do **not** depend on this function.
|
||||||
"""
|
"""
|
||||||
|
# Check: getattr(current_user, "must_change_password", False)
|
||||||
if getattr(current_user, "must_change_password", False):
|
if getattr(current_user, "must_change_password", False):
|
||||||
|
# Raise HTTPException
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
# Keyword argument: detail
|
||||||
detail="PASSWORD_CHANGE_REQUIRED",
|
detail="PASSWORD_CHANGE_REQUIRED",
|
||||||
)
|
)
|
||||||
|
# Return current_user
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
@@ -147,22 +196,30 @@ def require_role(required_role: str):
|
|||||||
Otherwise it raises :class:`~fastapi.HTTPException` **403**.
|
Otherwise it raises :class:`~fastapi.HTTPException` **403**.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Define async function role_checker
|
||||||
async def role_checker(
|
async def role_checker(
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> User:
|
) -> User:
|
||||||
|
# Check: current_user.role != required_role and current_user.role != "admin"
|
||||||
if current_user.role != required_role and current_user.role != "admin":
|
if current_user.role != required_role and current_user.role != "admin":
|
||||||
|
# Raise HTTPException
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
# Keyword argument: detail
|
||||||
detail="Not enough permissions",
|
detail="Not enough permissions",
|
||||||
)
|
)
|
||||||
scope = "admin" if required_role == "admin" else "write"
|
scope = "admin" if required_role == "admin" else "write"
|
||||||
_check_api_key_scope(current_user, scope)
|
_check_api_key_scope(current_user, scope)
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
# Return role_checker
|
||||||
return role_checker
|
return role_checker
|
||||||
|
|
||||||
|
|
||||||
def require_any_role(*roles: str):
|
# Define function require_any_role
|
||||||
|
def require_any_role(*roles: str) -> Callable[..., object]:
|
||||||
"""Return a FastAPI dependency that enforces **any** of the given *roles*.
|
"""Return a FastAPI dependency that enforces **any** of the given *roles*.
|
||||||
|
|
||||||
Admins always pass. Also enforces API key scopes: if the only accepted
|
Admins always pass. Also enforces API key scopes: if the only accepted
|
||||||
@@ -174,18 +231,25 @@ def require_any_role(*roles: str):
|
|||||||
@router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))])
|
@router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Define async function role_checker
|
||||||
async def role_checker(
|
async def role_checker(
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> User:
|
) -> User:
|
||||||
|
# Check: current_user.role != "admin" and current_user.role not in roles
|
||||||
if current_user.role != "admin" and current_user.role not in roles:
|
if current_user.role != "admin" and current_user.role not in roles:
|
||||||
|
# Raise HTTPException
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
# Keyword argument: detail
|
||||||
detail="Not enough permissions",
|
detail="Not enough permissions",
|
||||||
)
|
)
|
||||||
scope = "admin" if set(roles) == {"admin"} else "write"
|
scope = "admin" if set(roles) == {"admin"} else "write"
|
||||||
_check_api_key_scope(current_user, scope)
|
_check_api_key_scope(current_user, scope)
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
# Return role_checker
|
||||||
return role_checker
|
return role_checker
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,27 +4,41 @@ Wiring lives ONLY in the presentation layer — use cases and services
|
|||||||
never know which concrete repository implementation they receive.
|
never know which concrete repository implementation they receive.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import Depends from fastapi
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
|
||||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||||
SATechniqueRepository,
|
SATechniqueRepository,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import from app.infrastructure.persistence.repositories.sa_test_repository
|
||||||
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||||
SATestRepository,
|
SATestRepository,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Define function get_technique_repository
|
||||||
def get_technique_repository(
|
def get_technique_repository(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> SATechniqueRepository:
|
) -> SATechniqueRepository:
|
||||||
"""Provide a TechniqueRepository backed by the current DB session."""
|
"""Provide a TechniqueRepository backed by the current DB session."""
|
||||||
|
# Return SATechniqueRepository(db)
|
||||||
return SATechniqueRepository(db)
|
return SATechniqueRepository(db)
|
||||||
|
|
||||||
|
|
||||||
|
# Define function get_test_repository
|
||||||
def get_test_repository(
|
def get_test_repository(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> SATestRepository:
|
) -> SATestRepository:
|
||||||
"""Provide a TestRepository backed by the current DB session."""
|
"""Provide a TestRepository backed by the current DB session."""
|
||||||
|
# Return SATestRepository(db)
|
||||||
return SATestRepository(db)
|
return SATestRepository(db)
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Domain layer — entities, value objects, errors, and repository ports."""
|
||||||
|
|||||||
@@ -1,18 +1,34 @@
|
|||||||
|
"""Domain entity classes representing core business objects."""
|
||||||
|
# Import CampaignEntity from app.domain.entities.campaign
|
||||||
from app.domain.entities.campaign import CampaignEntity
|
from app.domain.entities.campaign import CampaignEntity
|
||||||
|
|
||||||
|
# Import from app.domain.entities.compliance
|
||||||
from app.domain.entities.compliance import (
|
from app.domain.entities.compliance import (
|
||||||
ComplianceControlEntity,
|
ComplianceControlEntity,
|
||||||
ComplianceFrameworkEntity,
|
ComplianceFrameworkEntity,
|
||||||
ControlCoverageStatus,
|
ControlCoverageStatus,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import TechniqueEntity from app.domain.entities.technique
|
||||||
from app.domain.entities.technique import TechniqueEntity
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
|
||||||
|
# Import ThreatActorEntity, ThreatActorTechniqueRef from app.domain.entities.threat_actor
|
||||||
from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef
|
from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef
|
||||||
|
|
||||||
|
# Assign __all__ = [
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Literal argument value
|
||||||
"CampaignEntity",
|
"CampaignEntity",
|
||||||
|
# Literal argument value
|
||||||
"ComplianceControlEntity",
|
"ComplianceControlEntity",
|
||||||
|
# Literal argument value
|
||||||
"ComplianceFrameworkEntity",
|
"ComplianceFrameworkEntity",
|
||||||
|
# Literal argument value
|
||||||
"ControlCoverageStatus",
|
"ControlCoverageStatus",
|
||||||
|
# Literal argument value
|
||||||
"TechniqueEntity",
|
"TechniqueEntity",
|
||||||
|
# Literal argument value
|
||||||
"ThreatActorEntity",
|
"ThreatActorEntity",
|
||||||
|
# Literal argument value
|
||||||
"ThreatActorTechniqueRef",
|
"ThreatActorTechniqueRef",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,30 +3,59 @@
|
|||||||
Pure domain logic — no framework imports.
|
Pure domain logic — no framework imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import enum
|
||||||
import enum
|
import enum
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# Import dataclass, field from dataclasses
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
# Import TYPE_CHECKING from typing
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
# Import BusinessRuleViolation, InvalidStateTransition from app.domain.errors
|
||||||
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
||||||
|
|
||||||
|
# Check: TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Import Campaign as CampaignORM from app.models.campaign
|
||||||
|
from app.models.campaign import Campaign as CampaignORM
|
||||||
|
|
||||||
|
|
||||||
|
# Define class CampaignStatus
|
||||||
class CampaignStatus(str, enum.Enum):
|
class CampaignStatus(str, enum.Enum):
|
||||||
|
"""Lifecycle states for a campaign."""
|
||||||
|
|
||||||
|
# Assign draft = "draft"
|
||||||
draft = "draft"
|
draft = "draft"
|
||||||
|
# Assign active = "active"
|
||||||
active = "active"
|
active = "active"
|
||||||
|
# Assign completed = "completed"
|
||||||
completed = "completed"
|
completed = "completed"
|
||||||
|
# Assign archived = "archived"
|
||||||
archived = "archived"
|
archived = "archived"
|
||||||
|
|
||||||
|
|
||||||
|
# Define class CampaignType
|
||||||
class CampaignType(str, enum.Enum):
|
class CampaignType(str, enum.Enum):
|
||||||
|
"""Classification of the campaign's testing methodology."""
|
||||||
|
|
||||||
|
# Assign custom = "custom"
|
||||||
custom = "custom"
|
custom = "custom"
|
||||||
|
# Assign apt_emulation = "apt_emulation"
|
||||||
apt_emulation = "apt_emulation"
|
apt_emulation = "apt_emulation"
|
||||||
|
# Assign kill_chain = "kill_chain"
|
||||||
kill_chain = "kill_chain"
|
kill_chain = "kill_chain"
|
||||||
|
# Assign compliance = "compliance"
|
||||||
compliance = "compliance"
|
compliance = "compliance"
|
||||||
|
|
||||||
|
|
||||||
|
# Assign VALID_TRANSITIONS = {
|
||||||
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
|
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
|
||||||
CampaignStatus.draft: [CampaignStatus.active],
|
CampaignStatus.draft: [CampaignStatus.active],
|
||||||
CampaignStatus.active: [CampaignStatus.completed],
|
CampaignStatus.active: [CampaignStatus.completed],
|
||||||
@@ -35,69 +64,156 @@ VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @dataclass decorator
|
||||||
@dataclass
|
@dataclass
|
||||||
|
# Define class CampaignEntity
|
||||||
class CampaignEntity:
|
class CampaignEntity:
|
||||||
|
"""Pure domain representation of a security testing campaign.
|
||||||
|
|
||||||
|
Owns all lifecycle state-machine logic for campaign activation,
|
||||||
|
completion, and archival.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# name: str
|
||||||
name: str
|
name: str
|
||||||
|
# Assign type = CampaignType.custom
|
||||||
type: CampaignType = CampaignType.custom
|
type: CampaignType = CampaignType.custom
|
||||||
|
# Assign status = CampaignStatus.draft
|
||||||
status: CampaignStatus = CampaignStatus.draft
|
status: CampaignStatus = CampaignStatus.draft
|
||||||
|
# Assign id = None
|
||||||
id: uuid.UUID | None = None
|
id: uuid.UUID | None = None
|
||||||
|
# Assign description = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
|
# Assign threat_actor_id = None
|
||||||
threat_actor_id: uuid.UUID | None = None
|
threat_actor_id: uuid.UUID | None = None
|
||||||
|
# Assign created_by = None
|
||||||
created_by: uuid.UUID | None = None
|
created_by: uuid.UUID | None = None
|
||||||
|
# Assign target_platform = None
|
||||||
target_platform: str | None = None
|
target_platform: str | None = None
|
||||||
|
# Assign tags = field(default_factory=list)
|
||||||
tags: list[str] = field(default_factory=list)
|
tags: list[str] = field(default_factory=list)
|
||||||
|
# Assign test_count = 0
|
||||||
test_count: int = 0
|
test_count: int = 0
|
||||||
|
|
||||||
|
# Define function can_transition_to
|
||||||
def can_transition_to(self, target: CampaignStatus) -> bool:
|
def can_transition_to(self, target: CampaignStatus) -> bool:
|
||||||
|
"""Check whether transitioning from the current status to *target* is valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (CampaignStatus): The desired next status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the transition is allowed, False otherwise.
|
||||||
|
"""
|
||||||
|
# Return target in VALID_TRANSITIONS.get(self.status, [])
|
||||||
return target in VALID_TRANSITIONS.get(self.status, [])
|
return target in VALID_TRANSITIONS.get(self.status, [])
|
||||||
|
|
||||||
|
# Define function activate
|
||||||
def activate(self) -> None:
|
def activate(self) -> None:
|
||||||
|
"""Transition the campaign from ``draft`` to ``active``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Check: not self.can_transition_to(CampaignStatus.active)
|
||||||
if not self.can_transition_to(CampaignStatus.active):
|
if not self.can_transition_to(CampaignStatus.active):
|
||||||
|
# Raise InvalidStateTransition
|
||||||
raise InvalidStateTransition(
|
raise InvalidStateTransition(
|
||||||
self.status.value, CampaignStatus.active.value,
|
self.status.value, CampaignStatus.active.value,
|
||||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||||
)
|
)
|
||||||
|
# Check: self.test_count == 0
|
||||||
if self.test_count == 0:
|
if self.test_count == 0:
|
||||||
|
# Raise BusinessRuleViolation
|
||||||
raise BusinessRuleViolation(
|
raise BusinessRuleViolation(
|
||||||
|
# Literal argument value
|
||||||
"Campaign must have at least one test to activate"
|
"Campaign must have at least one test to activate"
|
||||||
)
|
)
|
||||||
|
# Assign self.status = CampaignStatus.active
|
||||||
self.status = CampaignStatus.active
|
self.status = CampaignStatus.active
|
||||||
|
|
||||||
|
# Define function complete
|
||||||
def complete(self) -> None:
|
def complete(self) -> None:
|
||||||
|
"""Transition the campaign from ``active`` to ``completed``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Check: not self.can_transition_to(CampaignStatus.completed)
|
||||||
if not self.can_transition_to(CampaignStatus.completed):
|
if not self.can_transition_to(CampaignStatus.completed):
|
||||||
|
# Raise InvalidStateTransition
|
||||||
raise InvalidStateTransition(
|
raise InvalidStateTransition(
|
||||||
self.status.value, CampaignStatus.completed.value,
|
self.status.value, CampaignStatus.completed.value,
|
||||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||||
)
|
)
|
||||||
|
# Assign self.status = CampaignStatus.completed
|
||||||
self.status = CampaignStatus.completed
|
self.status = CampaignStatus.completed
|
||||||
|
|
||||||
|
# Define function archive
|
||||||
def archive(self) -> None:
|
def archive(self) -> None:
|
||||||
|
"""Transition the campaign from ``completed`` to ``archived``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Check: not self.can_transition_to(CampaignStatus.archived)
|
||||||
if not self.can_transition_to(CampaignStatus.archived):
|
if not self.can_transition_to(CampaignStatus.archived):
|
||||||
|
# Raise InvalidStateTransition
|
||||||
raise InvalidStateTransition(
|
raise InvalidStateTransition(
|
||||||
self.status.value, CampaignStatus.archived.value,
|
self.status.value, CampaignStatus.archived.value,
|
||||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||||
)
|
)
|
||||||
|
# Assign self.status = CampaignStatus.archived
|
||||||
self.status = CampaignStatus.archived
|
self.status = CampaignStatus.archived
|
||||||
|
|
||||||
|
# Define function ensure_modifiable
|
||||||
def ensure_modifiable(self) -> None:
|
def ensure_modifiable(self) -> None:
|
||||||
|
"""Raise BusinessRuleViolation if the campaign is not in a modifiable state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Check: self.status not in (CampaignStatus.draft, CampaignStatus.active)
|
||||||
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
|
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
|
||||||
|
# Raise BusinessRuleViolation
|
||||||
raise BusinessRuleViolation(
|
raise BusinessRuleViolation(
|
||||||
f"Cannot modify campaign in '{self.status.value}' state"
|
f"Cannot modify campaign in '{self.status.value}' state"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply the @classmethod decorator
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_orm(cls, orm: Any) -> CampaignEntity:
|
# Define function from_orm
|
||||||
"""Build a CampaignEntity from a SQLAlchemy Campaign model."""
|
def from_orm(cls, orm: CampaignORM) -> CampaignEntity:
|
||||||
|
"""Build a CampaignEntity from a SQLAlchemy Campaign model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
orm (CampaignORM): The SQLAlchemy Campaign ORM model instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CampaignEntity: A fully populated domain entity reflecting the ORM state.
|
||||||
|
"""
|
||||||
|
# Assign test_count = len(getattr(orm, "campaign_tests", None) or [])
|
||||||
test_count = len(getattr(orm, "campaign_tests", None) or [])
|
test_count = len(getattr(orm, "campaign_tests", None) or [])
|
||||||
|
# Return cls(
|
||||||
return cls(
|
return cls(
|
||||||
|
# Keyword argument: id
|
||||||
id=orm.id,
|
id=orm.id,
|
||||||
|
# Keyword argument: name
|
||||||
name=orm.name,
|
name=orm.name,
|
||||||
|
# Keyword argument: type
|
||||||
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
|
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
|
||||||
|
# Keyword argument: status
|
||||||
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
|
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
|
||||||
|
# Keyword argument: description
|
||||||
description=orm.description,
|
description=orm.description,
|
||||||
|
# Keyword argument: threat_actor_id
|
||||||
threat_actor_id=orm.threat_actor_id,
|
threat_actor_id=orm.threat_actor_id,
|
||||||
|
# Keyword argument: created_by
|
||||||
created_by=orm.created_by,
|
created_by=orm.created_by,
|
||||||
|
# Keyword argument: target_platform
|
||||||
target_platform=orm.target_platform,
|
target_platform=orm.target_platform,
|
||||||
|
# Keyword argument: tags
|
||||||
tags=orm.tags or [],
|
tags=orm.tags or [],
|
||||||
|
# Keyword argument: test_count
|
||||||
test_count=test_count,
|
test_count=test_count,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,68 +3,161 @@
|
|||||||
Pure domain logic — no framework imports.
|
Pure domain logic — no framework imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import enum
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import dataclass, field from dataclasses
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
# Define class ControlCoverageStatus
|
||||||
class ControlCoverageStatus(str, enum.Enum):
|
class ControlCoverageStatus(str, enum.Enum):
|
||||||
|
"""Computed coverage level for a single compliance control."""
|
||||||
|
|
||||||
|
# Assign covered = "covered"
|
||||||
covered = "covered"
|
covered = "covered"
|
||||||
|
# Assign partially_covered = "partially_covered"
|
||||||
partially_covered = "partially_covered"
|
partially_covered = "partially_covered"
|
||||||
|
# Assign not_covered = "not_covered"
|
||||||
not_covered = "not_covered"
|
not_covered = "not_covered"
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @dataclass decorator
|
||||||
@dataclass
|
@dataclass
|
||||||
|
# Define class ComplianceControlEntity
|
||||||
class ComplianceControlEntity:
|
class ComplianceControlEntity:
|
||||||
|
"""Pure domain representation of a single compliance framework control.
|
||||||
|
|
||||||
|
Derives its coverage status from the technique statuses associated
|
||||||
|
with it via the ``technique_statuses`` list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# control_id: str
|
||||||
control_id: str
|
control_id: str
|
||||||
|
# title: str
|
||||||
title: str
|
title: str
|
||||||
|
# Assign id = None
|
||||||
id: uuid.UUID | None = None
|
id: uuid.UUID | None = None
|
||||||
|
# Assign description = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
|
# Assign category = None
|
||||||
category: str | None = None
|
category: str | None = None
|
||||||
|
# Assign technique_statuses = field(default_factory=list)
|
||||||
technique_statuses: list[str] = field(default_factory=list)
|
technique_statuses: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function coverage_status
|
||||||
def coverage_status(self) -> ControlCoverageStatus:
|
def coverage_status(self) -> ControlCoverageStatus:
|
||||||
|
"""Compute the coverage status for this control based on linked technique statuses.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ControlCoverageStatus: ``covered`` when all techniques are covered,
|
||||||
|
``partially_covered`` when at least one is covered, and
|
||||||
|
``not_covered`` when none are covered or the control has no techniques.
|
||||||
|
"""
|
||||||
|
# Check: not self.technique_statuses
|
||||||
if not self.technique_statuses:
|
if not self.technique_statuses:
|
||||||
|
# Return ControlCoverageStatus.not_covered
|
||||||
return ControlCoverageStatus.not_covered
|
return ControlCoverageStatus.not_covered
|
||||||
|
# Assign covered_statuses = {"validated", "partial"}
|
||||||
covered_statuses = {"validated", "partial"}
|
covered_statuses = {"validated", "partial"}
|
||||||
|
# Assign covered = [s for s in self.technique_statuses if s in covered_statuses]
|
||||||
covered = [s for s in self.technique_statuses if s in covered_statuses]
|
covered = [s for s in self.technique_statuses if s in covered_statuses]
|
||||||
|
# Check: len(covered) == len(self.technique_statuses)
|
||||||
if len(covered) == len(self.technique_statuses):
|
if len(covered) == len(self.technique_statuses):
|
||||||
|
# Return ControlCoverageStatus.covered
|
||||||
return ControlCoverageStatus.covered
|
return ControlCoverageStatus.covered
|
||||||
|
# Alternative: len(covered) > 0
|
||||||
elif len(covered) > 0:
|
elif len(covered) > 0:
|
||||||
|
# Return ControlCoverageStatus.partially_covered
|
||||||
return ControlCoverageStatus.partially_covered
|
return ControlCoverageStatus.partially_covered
|
||||||
|
# Return ControlCoverageStatus.not_covered
|
||||||
return ControlCoverageStatus.not_covered
|
return ControlCoverageStatus.not_covered
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @dataclass decorator
|
||||||
@dataclass
|
@dataclass
|
||||||
|
# Define class ComplianceFrameworkEntity
|
||||||
class ComplianceFrameworkEntity:
|
class ComplianceFrameworkEntity:
|
||||||
|
"""Pure domain representation of a compliance framework (e.g. NIST 800-53, PCI-DSS).
|
||||||
|
|
||||||
|
Aggregates a collection of controls and provides aggregate coverage statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# name: str
|
||||||
name: str
|
name: str
|
||||||
|
# Assign id = None
|
||||||
id: uuid.UUID | None = None
|
id: uuid.UUID | None = None
|
||||||
|
# Assign version = None
|
||||||
version: str | None = None
|
version: str | None = None
|
||||||
|
# Assign description = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
|
# Assign is_active = True
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
|
# Assign controls = field(default_factory=list)
|
||||||
controls: list[ComplianceControlEntity] = field(default_factory=list)
|
controls: list[ComplianceControlEntity] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function total_controls
|
||||||
def total_controls(self) -> int:
|
def total_controls(self) -> int:
|
||||||
|
"""Return the total number of controls in this framework.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Count of all controls regardless of coverage status.
|
||||||
|
"""
|
||||||
|
# Return len(self.controls)
|
||||||
return len(self.controls)
|
return len(self.controls)
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function covered_controls
|
||||||
def covered_controls(self) -> int:
|
def covered_controls(self) -> int:
|
||||||
|
"""Return the number of fully covered controls in this framework.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Count of controls with ``ControlCoverageStatus.covered`` status.
|
||||||
|
"""
|
||||||
|
# Return sum(
|
||||||
return sum(
|
return sum(
|
||||||
|
# Literal argument value
|
||||||
1 for c in self.controls
|
1 for c in self.controls
|
||||||
if c.coverage_status == ControlCoverageStatus.covered
|
if c.coverage_status == ControlCoverageStatus.covered
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function coverage_pct
|
||||||
def coverage_pct(self) -> float:
|
def coverage_pct(self) -> float:
|
||||||
|
"""Return the percentage of controls that are fully covered.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: A value from 0.0 to 100.0, rounded to one decimal place.
|
||||||
|
Returns 0.0 when the framework has no controls.
|
||||||
|
"""
|
||||||
|
# Check: self.total_controls == 0
|
||||||
if self.total_controls == 0:
|
if self.total_controls == 0:
|
||||||
|
# Return 0.0
|
||||||
return 0.0
|
return 0.0
|
||||||
|
# Return round(self.covered_controls / self.total_controls * 100, 1)
|
||||||
return round(self.covered_controls / self.total_controls * 100, 1)
|
return round(self.covered_controls / self.total_controls * 100, 1)
|
||||||
|
|
||||||
|
# Define function get_gap_controls
|
||||||
def get_gap_controls(self) -> list[ComplianceControlEntity]:
|
def get_gap_controls(self) -> list[ComplianceControlEntity]:
|
||||||
|
"""Return controls that are not fully covered.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[ComplianceControlEntity]: Controls with ``partially_covered`` or
|
||||||
|
``not_covered`` status.
|
||||||
|
"""
|
||||||
|
# Return [
|
||||||
return [
|
return [
|
||||||
c for c in self.controls
|
c for c in self.controls
|
||||||
if c.coverage_status != ControlCoverageStatus.covered
|
if c.coverage_status != ControlCoverageStatus.covered
|
||||||
|
|||||||
@@ -12,105 +12,211 @@ Usage::
|
|||||||
entity.apply_to(technique_orm_model)
|
entity.apply_to(technique_orm_model)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
# Import dataclass, field from dataclasses
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
# Import datetime from datetime
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Import TYPE_CHECKING from typing
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
# Import TechniqueStatus, TestResult, TestState from app.domain.enums
|
||||||
from app.domain.enums import TechniqueStatus, TestResult, TestState
|
from app.domain.enums import TechniqueStatus, TestResult, TestState
|
||||||
|
|
||||||
|
# Import MitreId from app.domain.value_objects.mitre_id
|
||||||
from app.domain.value_objects.mitre_id import MitreId
|
from app.domain.value_objects.mitre_id import MitreId
|
||||||
|
|
||||||
|
# Check: TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Import Technique as TechniqueORM from app.models.technique
|
||||||
|
from app.models.technique import Technique as TechniqueORM
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @dataclass decorator
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
# Define class _TestSnapshot
|
||||||
class _TestSnapshot:
|
class _TestSnapshot:
|
||||||
"""Minimal read-only view of a test for status calculation."""
|
"""Minimal read-only view of a test for status calculation."""
|
||||||
|
|
||||||
|
# state: TestState
|
||||||
state: TestState
|
state: TestState
|
||||||
|
# detection_result: str | None
|
||||||
detection_result: str | None
|
detection_result: str | None
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @dataclass decorator
|
||||||
@dataclass
|
@dataclass
|
||||||
|
# Define class TechniqueEntity
|
||||||
class TechniqueEntity:
|
class TechniqueEntity:
|
||||||
"""Pure domain representation of a MITRE ATT&CK technique."""
|
"""Pure domain representation of a MITRE ATT&CK technique."""
|
||||||
|
|
||||||
|
# id: uuid.UUID
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
|
# mitre_id: str
|
||||||
mitre_id: str
|
mitre_id: str
|
||||||
|
# name: str
|
||||||
name: str
|
name: str
|
||||||
|
# Assign tactic = None
|
||||||
tactic: str | None = None
|
tactic: str | None = None
|
||||||
|
# Assign description = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
|
# Assign platforms = field(default_factory=list)
|
||||||
platforms: list[str] = field(default_factory=list)
|
platforms: list[str] = field(default_factory=list)
|
||||||
|
# Assign is_subtechnique = False
|
||||||
is_subtechnique: bool = False
|
is_subtechnique: bool = False
|
||||||
|
# Assign parent_mitre_id = None
|
||||||
parent_mitre_id: str | None = None
|
parent_mitre_id: str | None = None
|
||||||
|
# Assign status_global = TechniqueStatus.not_evaluated
|
||||||
status_global: TechniqueStatus = TechniqueStatus.not_evaluated
|
status_global: TechniqueStatus = TechniqueStatus.not_evaluated
|
||||||
|
# Assign review_required = False
|
||||||
review_required: bool = False
|
review_required: bool = False
|
||||||
|
# Assign last_review_date = None
|
||||||
last_review_date: datetime | None = None
|
last_review_date: datetime | None = None
|
||||||
|
# Assign mitre_version = None
|
||||||
mitre_version: str | None = None
|
mitre_version: str | None = None
|
||||||
|
# Assign mitre_last_modified = None
|
||||||
mitre_last_modified: datetime | None = None
|
mitre_last_modified: datetime | None = None
|
||||||
|
|
||||||
# -- Factory -----------------------------------------------------------
|
# -- Factory -----------------------------------------------------------
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
# Define function create
|
||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
|
# Entry: mitre_id
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
|
# Entry: name
|
||||||
name: str,
|
name: str,
|
||||||
|
# Entry: tactic
|
||||||
tactic: str | None = None,
|
tactic: str | None = None,
|
||||||
|
# Entry: description
|
||||||
description: str | None = None,
|
description: str | None = None,
|
||||||
|
# Entry: platforms
|
||||||
platforms: list[str] | None = None,
|
platforms: list[str] | None = None,
|
||||||
) -> TechniqueEntity:
|
) -> TechniqueEntity:
|
||||||
"""Create a new technique, validating the MITRE ID format."""
|
"""Create a new technique, validating the MITRE ID format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mitre_id (str): MITRE ATT&CK identifier (e.g. ``"T1059"`` or ``"T1059.001"``).
|
||||||
|
name (str): Human-readable name of the technique.
|
||||||
|
tactic (str | None): MITRE tactic category the technique belongs to.
|
||||||
|
description (str | None): Optional free-text description.
|
||||||
|
platforms (list[str] | None): List of platform strings the technique applies to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TechniqueEntity: A new entity with a freshly generated UUID and
|
||||||
|
``status_global`` set to ``not_evaluated``.
|
||||||
|
"""
|
||||||
|
# Assign validated_id = MitreId(mitre_id)
|
||||||
validated_id = MitreId(mitre_id)
|
validated_id = MitreId(mitre_id)
|
||||||
|
# Return cls(
|
||||||
return cls(
|
return cls(
|
||||||
|
# Keyword argument: id
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
|
# Keyword argument: mitre_id
|
||||||
mitre_id=validated_id.value,
|
mitre_id=validated_id.value,
|
||||||
|
# Keyword argument: name
|
||||||
name=name,
|
name=name,
|
||||||
|
# Keyword argument: tactic
|
||||||
tactic=tactic,
|
tactic=tactic,
|
||||||
|
# Keyword argument: description
|
||||||
description=description,
|
description=description,
|
||||||
|
# Keyword argument: platforms
|
||||||
platforms=platforms or [],
|
platforms=platforms or [],
|
||||||
|
# Keyword argument: is_subtechnique
|
||||||
is_subtechnique=validated_id.is_subtechnique,
|
is_subtechnique=validated_id.is_subtechnique,
|
||||||
|
# Keyword argument: parent_mitre_id
|
||||||
parent_mitre_id=validated_id.parent_id,
|
parent_mitre_id=validated_id.parent_id,
|
||||||
|
# Keyword argument: status_global
|
||||||
status_global=TechniqueStatus.not_evaluated,
|
status_global=TechniqueStatus.not_evaluated,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply the @classmethod decorator
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_orm(cls, model: Any) -> TechniqueEntity:
|
# Define function from_orm
|
||||||
"""Build a TechniqueEntity from a SQLAlchemy Technique model."""
|
def from_orm(cls, model: TechniqueORM) -> TechniqueEntity:
|
||||||
|
"""Build a TechniqueEntity from a SQLAlchemy Technique model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (TechniqueORM): The ORM model instance to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TechniqueEntity: A fully populated domain entity reflecting the ORM state.
|
||||||
|
"""
|
||||||
|
# Assign raw_status = model.status_global
|
||||||
raw_status = model.status_global
|
raw_status = model.status_global
|
||||||
|
# Check: raw_status is None
|
||||||
if raw_status is None:
|
if raw_status is None:
|
||||||
|
# Assign status = TechniqueStatus.not_evaluated
|
||||||
status = TechniqueStatus.not_evaluated
|
status = TechniqueStatus.not_evaluated
|
||||||
|
# Alternative: isinstance(raw_status, TechniqueStatus)
|
||||||
elif isinstance(raw_status, TechniqueStatus):
|
elif isinstance(raw_status, TechniqueStatus):
|
||||||
|
# Assign status = raw_status
|
||||||
status = raw_status
|
status = raw_status
|
||||||
|
# Fallback: handle remaining cases
|
||||||
else:
|
else:
|
||||||
|
# Assign status = TechniqueStatus(raw_status)
|
||||||
status = TechniqueStatus(raw_status)
|
status = TechniqueStatus(raw_status)
|
||||||
|
# Return cls(
|
||||||
return cls(
|
return cls(
|
||||||
|
# Keyword argument: id
|
||||||
id=model.id,
|
id=model.id,
|
||||||
|
# Keyword argument: mitre_id
|
||||||
mitre_id=model.mitre_id,
|
mitre_id=model.mitre_id,
|
||||||
|
# Keyword argument: name
|
||||||
name=model.name,
|
name=model.name,
|
||||||
|
# Keyword argument: tactic
|
||||||
tactic=model.tactic,
|
tactic=model.tactic,
|
||||||
|
# Keyword argument: description
|
||||||
description=model.description,
|
description=model.description,
|
||||||
|
# Keyword argument: platforms
|
||||||
platforms=model.platforms or [],
|
platforms=model.platforms or [],
|
||||||
|
# Keyword argument: is_subtechnique
|
||||||
is_subtechnique=model.is_subtechnique or False,
|
is_subtechnique=model.is_subtechnique or False,
|
||||||
|
# Keyword argument: parent_mitre_id
|
||||||
parent_mitre_id=model.parent_mitre_id,
|
parent_mitre_id=model.parent_mitre_id,
|
||||||
|
# Keyword argument: status_global
|
||||||
status_global=status,
|
status_global=status,
|
||||||
|
# Keyword argument: review_required
|
||||||
review_required=model.review_required or False,
|
review_required=model.review_required or False,
|
||||||
|
# Keyword argument: last_review_date
|
||||||
last_review_date=model.last_review_date,
|
last_review_date=model.last_review_date,
|
||||||
|
# Keyword argument: mitre_version
|
||||||
mitre_version=getattr(model, "mitre_version", None),
|
mitre_version=getattr(model, "mitre_version", None),
|
||||||
|
# Keyword argument: mitre_last_modified
|
||||||
mitre_last_modified=getattr(model, "mitre_last_modified", None),
|
mitre_last_modified=getattr(model, "mitre_last_modified", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_to(self, model: Any) -> None:
|
# Define function apply_to
|
||||||
"""Copy mutable fields back onto the ORM model."""
|
def apply_to(self, model: TechniqueORM) -> None:
|
||||||
|
"""Copy mutable fields back onto the ORM model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (TechniqueORM): The ORM model to update in-place.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Assign model.status_global = self.status_global
|
||||||
model.status_global = self.status_global
|
model.status_global = self.status_global
|
||||||
|
# Assign model.review_required = self.review_required
|
||||||
model.review_required = self.review_required
|
model.review_required = self.review_required
|
||||||
|
# Assign model.last_review_date = self.last_review_date
|
||||||
model.last_review_date = self.last_review_date
|
model.last_review_date = self.last_review_date
|
||||||
|
|
||||||
# -- Business logic ----------------------------------------------------
|
# -- Business logic ----------------------------------------------------
|
||||||
|
|
||||||
def recalculate_status(
|
def recalculate_status(
|
||||||
self,
|
self,
|
||||||
|
# Entry: test_snapshots
|
||||||
test_snapshots: list[tuple[str, str | None]],
|
test_snapshots: list[tuple[str, str | None]],
|
||||||
) -> TechniqueStatus:
|
) -> TechniqueStatus:
|
||||||
"""Recompute ``status_global`` from a list of (state, detection_result) pairs.
|
"""Recompute ``status_global`` from a list of (state, detection_result) pairs.
|
||||||
@@ -131,23 +237,37 @@ class TechniqueEntity:
|
|||||||
With only 1 validated+detected test the technique is "partial" to
|
With only 1 validated+detected test the technique is "partial" to
|
||||||
signal that more testing is recommended.
|
signal that more testing is recommended.
|
||||||
|
|
||||||
Returns the new status (also set on the entity).
|
Args:
|
||||||
|
test_snapshots (list[tuple[str, str | None]]): Each element is a
|
||||||
|
``(state, detection_result)`` pair where *state* is a
|
||||||
|
:class:`TestState` value string and *detection_result* is a
|
||||||
|
:class:`TestResult` value string or ``None``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TechniqueStatus: The newly computed status, which is also stored on
|
||||||
|
the entity's ``status_global`` field.
|
||||||
"""
|
"""
|
||||||
_MIN_VALIDATED_FOR_FULL = 2 # require ≥ N validated tests for "validated"
|
_MIN_VALIDATED_FOR_FULL = 2 # require ≥ N validated tests for "validated"
|
||||||
|
|
||||||
tests = [
|
tests = [
|
||||||
_TestSnapshot(
|
_TestSnapshot(
|
||||||
|
# Keyword argument: state
|
||||||
state=s if isinstance(s, TestState) else TestState(s),
|
state=s if isinstance(s, TestState) else TestState(s),
|
||||||
|
# Keyword argument: detection_result
|
||||||
detection_result=dr,
|
detection_result=dr,
|
||||||
)
|
)
|
||||||
for s, dr in test_snapshots
|
for s, dr in test_snapshots
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Check: not tests
|
||||||
if not tests:
|
if not tests:
|
||||||
|
# Assign self.status_global = TechniqueStatus.not_evaluated
|
||||||
self.status_global = TechniqueStatus.not_evaluated
|
self.status_global = TechniqueStatus.not_evaluated
|
||||||
|
# Alternative: all(t.state == TestState.validated for t in tests)
|
||||||
elif all(t.state == TestState.validated for t in tests):
|
elif all(t.state == TestState.validated for t in tests):
|
||||||
validated_count = len(tests)
|
validated_count = len(tests)
|
||||||
results = [t.detection_result for t in tests if t.detection_result]
|
results = [t.detection_result for t in tests if t.detection_result]
|
||||||
|
# Check: results and all(r == TestResult.detected or r == "detected" for r i...
|
||||||
if results and all(r == TestResult.detected or r == "detected" for r in results):
|
if results and all(r == TestResult.detected or r == "detected" for r in results):
|
||||||
# Need at least _MIN_VALIDATED_FOR_FULL tests for "validated"
|
# Need at least _MIN_VALIDATED_FOR_FULL tests for "validated"
|
||||||
if validated_count >= _MIN_VALIDATED_FOR_FULL:
|
if validated_count >= _MIN_VALIDATED_FOR_FULL:
|
||||||
@@ -155,24 +275,46 @@ class TechniqueEntity:
|
|||||||
else:
|
else:
|
||||||
self.status_global = TechniqueStatus.partial
|
self.status_global = TechniqueStatus.partial
|
||||||
elif any(
|
elif any(
|
||||||
|
# Keyword argument: r
|
||||||
r == TestResult.partially_detected or r == "partially_detected"
|
r == TestResult.partially_detected or r == "partially_detected"
|
||||||
for r in results
|
for r in results
|
||||||
):
|
):
|
||||||
|
# Assign self.status_global = TechniqueStatus.partial
|
||||||
self.status_global = TechniqueStatus.partial
|
self.status_global = TechniqueStatus.partial
|
||||||
|
# Fallback: handle remaining cases
|
||||||
else:
|
else:
|
||||||
|
# Assign self.status_global = TechniqueStatus.not_covered
|
||||||
self.status_global = TechniqueStatus.not_covered
|
self.status_global = TechniqueStatus.not_covered
|
||||||
|
# Alternative: any(t.state == TestState.validated for t in tests)
|
||||||
elif any(t.state == TestState.validated for t in tests):
|
elif any(t.state == TestState.validated for t in tests):
|
||||||
|
# Assign self.status_global = TechniqueStatus.partial
|
||||||
self.status_global = TechniqueStatus.partial
|
self.status_global = TechniqueStatus.partial
|
||||||
|
# Fallback: handle remaining cases
|
||||||
else:
|
else:
|
||||||
|
# Assign self.status_global = TechniqueStatus.in_progress
|
||||||
self.status_global = TechniqueStatus.in_progress
|
self.status_global = TechniqueStatus.in_progress
|
||||||
|
|
||||||
|
# Return self.status_global
|
||||||
return self.status_global
|
return self.status_global
|
||||||
|
|
||||||
|
# Define function mark_reviewed
|
||||||
def mark_reviewed(self) -> None:
|
def mark_reviewed(self) -> None:
|
||||||
"""Mark the technique as reviewed, clearing the review flag."""
|
"""Mark the technique as reviewed, clearing the review flag.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Assign self.review_required = False
|
||||||
self.review_required = False
|
self.review_required = False
|
||||||
|
# Assign self.last_review_date = datetime.utcnow()
|
||||||
self.last_review_date = datetime.utcnow()
|
self.last_review_date = datetime.utcnow()
|
||||||
|
|
||||||
|
# Define function flag_for_review
|
||||||
def flag_for_review(self) -> None:
|
def flag_for_review(self) -> None:
|
||||||
"""Flag the technique as needing review."""
|
"""Flag the technique as needing review.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Assign self.review_required = True
|
||||||
self.review_required = True
|
self.review_required = True
|
||||||
|
|||||||
@@ -3,94 +3,204 @@
|
|||||||
Pure domain logic — no framework imports.
|
Pure domain logic — no framework imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import dataclass, field from dataclasses
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
|
||||||
|
# Import TYPE_CHECKING from typing
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
# Check: TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Import ThreatActor as ThreatActorORM from app.models.threat_actor
|
||||||
|
from app.models.threat_actor import ThreatActor as ThreatActorORM
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @dataclass decorator
|
||||||
@dataclass
|
@dataclass
|
||||||
|
# Define class ThreatActorTechniqueRef
|
||||||
class ThreatActorTechniqueRef:
|
class ThreatActorTechniqueRef:
|
||||||
"""Lightweight reference to a technique used by an actor."""
|
"""Lightweight reference to a technique used by an actor."""
|
||||||
|
|
||||||
|
# technique_id: uuid.UUID
|
||||||
technique_id: uuid.UUID
|
technique_id: uuid.UUID
|
||||||
|
# Assign mitre_id = None
|
||||||
mitre_id: str | None = None
|
mitre_id: str | None = None
|
||||||
|
# Assign name = None
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
|
# Assign status = None
|
||||||
status: str | None = None
|
status: str | None = None
|
||||||
|
# Assign usage_description = None
|
||||||
usage_description: str | None = None
|
usage_description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @dataclass decorator
|
||||||
@dataclass
|
@dataclass
|
||||||
|
# Define class ThreatActorEntity
|
||||||
class ThreatActorEntity:
|
class ThreatActorEntity:
|
||||||
|
"""Pure domain representation of a MITRE ATT&CK threat actor (group).
|
||||||
|
|
||||||
|
Aggregates references to the techniques the actor is known to use and
|
||||||
|
provides coverage analysis properties.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# name: str
|
||||||
name: str
|
name: str
|
||||||
|
# Assign id = None
|
||||||
id: uuid.UUID | None = None
|
id: uuid.UUID | None = None
|
||||||
|
# Assign mitre_id = None
|
||||||
mitre_id: str | None = None
|
mitre_id: str | None = None
|
||||||
|
# Assign aliases = field(default_factory=list)
|
||||||
aliases: list[str] = field(default_factory=list)
|
aliases: list[str] = field(default_factory=list)
|
||||||
|
# Assign description = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
|
# Assign country = None
|
||||||
country: str | None = None
|
country: str | None = None
|
||||||
|
# Assign target_sectors = field(default_factory=list)
|
||||||
target_sectors: list[str] = field(default_factory=list)
|
target_sectors: list[str] = field(default_factory=list)
|
||||||
|
# Assign target_regions = field(default_factory=list)
|
||||||
target_regions: list[str] = field(default_factory=list)
|
target_regions: list[str] = field(default_factory=list)
|
||||||
|
# Assign motivation = None
|
||||||
motivation: str | None = None
|
motivation: str | None = None
|
||||||
|
# Assign sophistication = None
|
||||||
sophistication: str | None = None
|
sophistication: str | None = None
|
||||||
|
# Assign first_seen = None
|
||||||
first_seen: str | None = None
|
first_seen: str | None = None
|
||||||
|
# Assign last_seen = None
|
||||||
last_seen: str | None = None
|
last_seen: str | None = None
|
||||||
|
# Assign is_active = True
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
|
# Assign techniques = field(default_factory=list)
|
||||||
techniques: list[ThreatActorTechniqueRef] = field(default_factory=list)
|
techniques: list[ThreatActorTechniqueRef] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function technique_count
|
||||||
def technique_count(self) -> int:
|
def technique_count(self) -> int:
|
||||||
|
"""Return the total number of techniques associated with this actor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Count of technique references.
|
||||||
|
"""
|
||||||
|
# Return len(self.techniques)
|
||||||
return len(self.techniques)
|
return len(self.techniques)
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function covered_techniques
|
||||||
def covered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
def covered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||||
|
"""Return technique references whose coverage status is ``validated`` or ``partial``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[ThreatActorTechniqueRef]: Subset of techniques considered covered.
|
||||||
|
"""
|
||||||
|
# Return [
|
||||||
return [
|
return [
|
||||||
t for t in self.techniques
|
t for t in self.techniques
|
||||||
if t.status in ("validated", "partial")
|
if t.status in ("validated", "partial")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function uncovered_techniques
|
||||||
def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||||
|
"""Return technique references whose coverage status is neither ``validated`` nor ``partial``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[ThreatActorTechniqueRef]: Subset of techniques not yet covered.
|
||||||
|
"""
|
||||||
|
# Return [
|
||||||
return [
|
return [
|
||||||
t for t in self.techniques
|
t for t in self.techniques
|
||||||
if t.status not in ("validated", "partial")
|
if t.status not in ("validated", "partial")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function coverage_pct
|
||||||
def coverage_pct(self) -> float:
|
def coverage_pct(self) -> float:
|
||||||
|
"""Return the percentage of the actor's techniques that are covered.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: A value from 0.0 to 100.0, rounded to one decimal place.
|
||||||
|
Returns 0.0 when the actor has no associated techniques.
|
||||||
|
"""
|
||||||
|
# Check: not self.techniques
|
||||||
if not self.techniques:
|
if not self.techniques:
|
||||||
|
# Return 0.0
|
||||||
return 0.0
|
return 0.0
|
||||||
|
# Return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
|
||||||
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
|
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
|
||||||
|
|
||||||
|
# Apply the @classmethod decorator
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_orm(cls, orm: Any) -> ThreatActorEntity:
|
# Define function from_orm
|
||||||
|
def from_orm(cls, orm: ThreatActorORM) -> ThreatActorEntity:
|
||||||
|
"""Build a ThreatActorEntity from a SQLAlchemy ThreatActor model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
orm (ThreatActorORM): The ORM model instance to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ThreatActorEntity: A fully populated domain entity including
|
||||||
|
technique references resolved from the ORM relationship.
|
||||||
|
"""
|
||||||
|
# Assign techs = []
|
||||||
techs: list[ThreatActorTechniqueRef] = []
|
techs: list[ThreatActorTechniqueRef] = []
|
||||||
|
# Iterate over getattr(orm, "techniques", None) or []
|
||||||
for tat in getattr(orm, "techniques", None) or []:
|
for tat in getattr(orm, "techniques", None) or []:
|
||||||
|
# Assign technique = getattr(tat, "technique", None)
|
||||||
technique = getattr(tat, "technique", None)
|
technique = getattr(tat, "technique", None)
|
||||||
|
# Call techs.append()
|
||||||
techs.append(ThreatActorTechniqueRef(
|
techs.append(ThreatActorTechniqueRef(
|
||||||
|
# Keyword argument: technique_id
|
||||||
technique_id=tat.technique_id,
|
technique_id=tat.technique_id,
|
||||||
|
# Keyword argument: mitre_id
|
||||||
mitre_id=getattr(technique, "mitre_id", None) if technique else None,
|
mitre_id=getattr(technique, "mitre_id", None) if technique else None,
|
||||||
|
# Keyword argument: name
|
||||||
name=getattr(technique, "name", None) if technique else None,
|
name=getattr(technique, "name", None) if technique else None,
|
||||||
|
# Keyword argument: status
|
||||||
status=(
|
status=(
|
||||||
technique.status_global.value
|
technique.status_global.value
|
||||||
if technique and hasattr(technique.status_global, "value")
|
if technique and hasattr(technique.status_global, "value")
|
||||||
else getattr(technique, "status_global", None) if technique else None
|
else getattr(technique, "status_global", None) if technique else None
|
||||||
),
|
),
|
||||||
|
# Keyword argument: usage_description
|
||||||
usage_description=tat.usage_description,
|
usage_description=tat.usage_description,
|
||||||
))
|
))
|
||||||
|
# Return cls(
|
||||||
return cls(
|
return cls(
|
||||||
|
# Keyword argument: id
|
||||||
id=orm.id,
|
id=orm.id,
|
||||||
|
# Keyword argument: name
|
||||||
name=orm.name,
|
name=orm.name,
|
||||||
|
# Keyword argument: mitre_id
|
||||||
mitre_id=orm.mitre_id,
|
mitre_id=orm.mitre_id,
|
||||||
|
# Keyword argument: aliases
|
||||||
aliases=orm.aliases or [],
|
aliases=orm.aliases or [],
|
||||||
|
# Keyword argument: description
|
||||||
description=orm.description,
|
description=orm.description,
|
||||||
|
# Keyword argument: country
|
||||||
country=orm.country,
|
country=orm.country,
|
||||||
|
# Keyword argument: target_sectors
|
||||||
target_sectors=orm.target_sectors or [],
|
target_sectors=orm.target_sectors or [],
|
||||||
|
# Keyword argument: target_regions
|
||||||
target_regions=orm.target_regions or [],
|
target_regions=orm.target_regions or [],
|
||||||
|
# Keyword argument: motivation
|
||||||
motivation=orm.motivation,
|
motivation=orm.motivation,
|
||||||
|
# Keyword argument: sophistication
|
||||||
sophistication=orm.sophistication,
|
sophistication=orm.sophistication,
|
||||||
|
# Keyword argument: first_seen
|
||||||
first_seen=orm.first_seen,
|
first_seen=orm.first_seen,
|
||||||
|
# Keyword argument: last_seen
|
||||||
last_seen=orm.last_seen,
|
last_seen=orm.last_seen,
|
||||||
|
# Keyword argument: is_active
|
||||||
is_active=orm.is_active if orm.is_active is not None else True,
|
is_active=orm.is_active if orm.is_active is not None else True,
|
||||||
|
# Keyword argument: techniques
|
||||||
techniques=techs,
|
techniques=techs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,41 +5,78 @@ truth. ``models/enums.py`` re-exports them so that existing ORM code
|
|||||||
continues to work without changes.
|
continues to work without changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import enum
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TechniqueStatus
|
||||||
class TechniqueStatus(str, enum.Enum):
|
class TechniqueStatus(str, enum.Enum):
|
||||||
|
"""Coverage and evaluation status for a MITRE ATT&CK technique."""
|
||||||
|
|
||||||
|
# Assign not_evaluated = "not_evaluated"
|
||||||
not_evaluated = "not_evaluated"
|
not_evaluated = "not_evaluated"
|
||||||
|
# Assign in_progress = "in_progress"
|
||||||
in_progress = "in_progress"
|
in_progress = "in_progress"
|
||||||
|
# Assign validated = "validated"
|
||||||
validated = "validated"
|
validated = "validated"
|
||||||
|
# Assign partial = "partial"
|
||||||
partial = "partial"
|
partial = "partial"
|
||||||
|
# Assign not_covered = "not_covered"
|
||||||
not_covered = "not_covered"
|
not_covered = "not_covered"
|
||||||
|
# Assign review_required = "review_required"
|
||||||
review_required = "review_required"
|
review_required = "review_required"
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TestState
|
||||||
class TestState(str, enum.Enum):
|
class TestState(str, enum.Enum):
|
||||||
|
"""Lifecycle states in the security test state machine."""
|
||||||
|
|
||||||
|
# Assign draft = "draft"
|
||||||
draft = "draft"
|
draft = "draft"
|
||||||
|
# Assign red_executing = "red_executing"
|
||||||
red_executing = "red_executing"
|
red_executing = "red_executing"
|
||||||
|
# Assign blue_evaluating = "blue_evaluating"
|
||||||
blue_evaluating = "blue_evaluating"
|
blue_evaluating = "blue_evaluating"
|
||||||
|
# Assign in_review = "in_review"
|
||||||
in_review = "in_review"
|
in_review = "in_review"
|
||||||
|
# Assign validated = "validated"
|
||||||
validated = "validated"
|
validated = "validated"
|
||||||
|
# Assign rejected = "rejected"
|
||||||
rejected = "rejected"
|
rejected = "rejected"
|
||||||
disputed = "disputed" # one lead approved, the other rejected
|
disputed = "disputed" # one lead approved, the other rejected
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TeamSide
|
||||||
class TeamSide(str, enum.Enum):
|
class TeamSide(str, enum.Enum):
|
||||||
|
"""Identifies which team (red or blue) an action belongs to."""
|
||||||
|
|
||||||
|
# Assign red = "red"
|
||||||
red = "red"
|
red = "red"
|
||||||
|
# Assign blue = "blue"
|
||||||
blue = "blue"
|
blue = "blue"
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TestResult
|
||||||
class TestResult(str, enum.Enum):
|
class TestResult(str, enum.Enum):
|
||||||
|
"""Outcome of a red-team test from a detection perspective."""
|
||||||
|
|
||||||
|
# Assign detected = "detected"
|
||||||
detected = "detected"
|
detected = "detected"
|
||||||
|
# Assign not_detected = "not_detected"
|
||||||
not_detected = "not_detected"
|
not_detected = "not_detected"
|
||||||
|
# Assign partially_detected = "partially_detected"
|
||||||
partially_detected = "partially_detected"
|
partially_detected = "partially_detected"
|
||||||
|
|
||||||
|
|
||||||
|
# Define class DataClassification
|
||||||
class DataClassification(str, enum.Enum):
|
class DataClassification(str, enum.Enum):
|
||||||
|
"""Data sensitivity classification levels for compliance and retention policies."""
|
||||||
|
|
||||||
|
# Assign public = "public"
|
||||||
public = "public"
|
public = "public"
|
||||||
|
# Assign internal = "internal"
|
||||||
internal = "internal"
|
internal = "internal"
|
||||||
|
# Assign sensitive = "sensitive"
|
||||||
sensitive = "sensitive"
|
sensitive = "sensitive"
|
||||||
|
# Assign restricted = "restricted"
|
||||||
restricted = "restricted"
|
restricted = "restricted"
|
||||||
|
|||||||
@@ -9,15 +9,30 @@ Existing code that imports from ``app.domain.exceptions`` continues to
|
|||||||
work — that module re-exports everything defined here.
|
work — that module re-exports everything defined here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
# Define class DomainError
|
||||||
class DomainError(Exception):
|
class DomainError(Exception):
|
||||||
"""Base for all domain errors."""
|
"""Base for all domain errors."""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, message: str, *, code: str = "DOMAIN_ERROR") -> None:
|
def __init__(self, message: str, *, code: str = "DOMAIN_ERROR") -> None:
|
||||||
|
"""Initialise the domain error with a human-readable message and error code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (str): Human-readable description of the error.
|
||||||
|
code (str): Machine-readable error code used by the HTTP error handler.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Assign self.message = message
|
||||||
self.message = message
|
self.message = message
|
||||||
|
# Assign self.code = code
|
||||||
self.code = code
|
self.code = code
|
||||||
|
# Call super()
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
@@ -27,18 +42,45 @@ class DomainError(Exception):
|
|||||||
class EntityNotFoundError(DomainError):
|
class EntityNotFoundError(DomainError):
|
||||||
"""A requested entity does not exist."""
|
"""A requested entity does not exist."""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, entity: str, identifier: str) -> None:
|
def __init__(self, entity: str, identifier: str) -> None:
|
||||||
|
"""Initialise an entity-not-found error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity (str): Name of the entity type that was not found (e.g. "Technique").
|
||||||
|
identifier (str): The ID or key used in the failed lookup.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call super()
|
||||||
super().__init__(f"{entity} not found: {identifier}", code="NOT_FOUND")
|
super().__init__(f"{entity} not found: {identifier}", code="NOT_FOUND")
|
||||||
|
# Assign self.entity = entity
|
||||||
self.entity = entity
|
self.entity = entity
|
||||||
|
# Assign self.identifier = identifier
|
||||||
self.identifier = identifier
|
self.identifier = identifier
|
||||||
|
|
||||||
|
|
||||||
|
# Define class DuplicateEntityError
|
||||||
class DuplicateEntityError(DomainError):
|
class DuplicateEntityError(DomainError):
|
||||||
"""Creating an entity that already exists."""
|
"""Creating an entity that already exists."""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, entity: str, field: str, value: str) -> None:
|
def __init__(self, entity: str, field: str, value: str) -> None:
|
||||||
|
"""Initialise a duplicate-entity error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity (str): Name of the entity type that already exists (e.g. "Campaign").
|
||||||
|
field (str): Name of the field whose value conflicts (e.g. "name").
|
||||||
|
value (str): The conflicting value that is already in use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call super()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
f"{entity} with {field}='{value}' already exists",
|
f"{entity} with {field}='{value}' already exists",
|
||||||
|
# Keyword argument: code
|
||||||
code="DUPLICATE",
|
code="DUPLICATE",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -46,34 +88,67 @@ class DuplicateEntityError(DomainError):
|
|||||||
# ── State machine ────────────────────────────────────────────────────
|
# ── State machine ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class InvalidStateTransition(DomainError):
|
class InvalidStateTransition(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
|
||||||
"""A state-machine transition is not allowed."""
|
"""A state-machine transition is not allowed."""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
# Entry: current_state
|
||||||
current_state: str,
|
current_state: str,
|
||||||
|
# Entry: target_state
|
||||||
target_state: str,
|
target_state: str,
|
||||||
|
# Entry: valid_transitions
|
||||||
valid_transitions: list[str] | None = None,
|
valid_transitions: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialise an invalid state-transition error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_state (str): The entity's present state (e.g. "draft").
|
||||||
|
target_state (str): The state that was illegally requested.
|
||||||
|
valid_transitions (list[str] | None): Allowed target states from the
|
||||||
|
current state; included in the error message when provided.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Assign msg = f"Cannot transition from '{current_state}' to '{target_state}'"
|
||||||
msg = f"Cannot transition from '{current_state}' to '{target_state}'"
|
msg = f"Cannot transition from '{current_state}' to '{target_state}'"
|
||||||
|
# Check: valid_transitions
|
||||||
if valid_transitions:
|
if valid_transitions:
|
||||||
|
# Assign msg = f". Valid transitions: {valid_transitions}"
|
||||||
msg += f". Valid transitions: {valid_transitions}"
|
msg += f". Valid transitions: {valid_transitions}"
|
||||||
|
# Call super()
|
||||||
super().__init__(msg, code="INVALID_TRANSITION")
|
super().__init__(msg, code="INVALID_TRANSITION")
|
||||||
|
# Assign self.current_state = current_state
|
||||||
self.current_state = current_state
|
self.current_state = current_state
|
||||||
|
# Assign self.target_state = target_state
|
||||||
self.target_state = target_state
|
self.target_state = target_state
|
||||||
|
# Assign self.valid_transitions = valid_transitions or []
|
||||||
self.valid_transitions = valid_transitions or []
|
self.valid_transitions = valid_transitions or []
|
||||||
|
|
||||||
|
|
||||||
# ── Business rules ────────────────────────────────────────────────────
|
# ── Business rules ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class BusinessRuleViolation(DomainError):
|
class BusinessRuleViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
|
||||||
"""An operation violates a business invariant."""
|
"""An operation violates a business invariant."""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, message: str) -> None:
|
def __init__(self, message: str) -> None:
|
||||||
|
"""Initialise a business-rule violation error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (str): Human-readable description of the violated rule.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call super()
|
||||||
super().__init__(message, code="BUSINESS_RULE_VIOLATION")
|
super().__init__(message, code="BUSINESS_RULE_VIOLATION")
|
||||||
|
|
||||||
|
|
||||||
|
# Define class InvalidOperationError
|
||||||
class InvalidOperationError(BusinessRuleViolation):
|
class InvalidOperationError(BusinessRuleViolation):
|
||||||
"""An operation is invalid in the current context.
|
"""An operation is invalid in the current context.
|
||||||
|
|
||||||
@@ -81,16 +156,37 @@ class InvalidOperationError(BusinessRuleViolation):
|
|||||||
:class:`BusinessRuleViolation` directly.
|
:class:`BusinessRuleViolation` directly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, message: str) -> None:
|
def __init__(self, message: str) -> None:
|
||||||
|
"""Initialise an invalid-operation error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (str): Human-readable description of why the operation is invalid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call super()
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
# Assign self.code = "INVALID_OPERATION"
|
||||||
self.code = "INVALID_OPERATION"
|
self.code = "INVALID_OPERATION"
|
||||||
|
|
||||||
|
|
||||||
# ── Authorization ────────────────────────────────────────────────────
|
# ── Authorization ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class PermissionViolation(DomainError):
|
class PermissionViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
|
||||||
"""The user lacks permissions for an action."""
|
"""The user lacks permissions for an action."""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, message: str = "Insufficient permissions") -> None:
|
def __init__(self, message: str = "Insufficient permissions") -> None:
|
||||||
|
"""Initialise a permission-violation error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (str): Human-readable description of the access denial.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call super()
|
||||||
super().__init__(message, code="FORBIDDEN")
|
super().__init__(message, code="FORBIDDEN")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ old import paths so that existing code keeps working without changes::
|
|||||||
from app.domain.exceptions import InvalidTransitionError # still works
|
from app.domain.exceptions import InvalidTransitionError # still works
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import # noqa: F401 from app.domain.errors
|
||||||
from app.domain.errors import ( # noqa: F401
|
from app.domain.errors import ( # noqa: F401
|
||||||
BusinessRuleViolation,
|
BusinessRuleViolation,
|
||||||
DomainError,
|
DomainError,
|
||||||
@@ -18,5 +19,7 @@ from app.domain.errors import ( # noqa: F401
|
|||||||
|
|
||||||
# Legacy aliases — old name → new name
|
# Legacy aliases — old name → new name
|
||||||
DomainException = DomainError
|
DomainException = DomainError
|
||||||
|
# Assign InvalidTransitionError = InvalidStateTransition
|
||||||
InvalidTransitionError = InvalidStateTransition
|
InvalidTransitionError = InvalidStateTransition
|
||||||
|
# Assign AuthorizationError = PermissionViolation
|
||||||
AuthorizationError = PermissionViolation
|
AuthorizationError = PermissionViolation
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Abstract port interfaces that infrastructure adapters must implement."""
|
||||||
|
|||||||
@@ -12,14 +12,19 @@ This satisfies the Open/Closed Principle — the system is open for new
|
|||||||
import sources without modifying existing code.
|
import sources without modifying existing code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import Any, Protocol, runtime_checkable from typing
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @runtime_checkable decorator
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
# Define class ImportService
|
||||||
class ImportService(Protocol):
|
class ImportService(Protocol):
|
||||||
"""Contract for any data-import operation.
|
"""Contract for any data-import operation.
|
||||||
|
|
||||||
@@ -27,62 +32,134 @@ class ImportService(Protocol):
|
|||||||
downloads, parses, and upserts records from an external source.
|
downloads, parses, and upserts records from an external source.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, db: Session) -> dict[str, Any]: ...
|
# Define function __call__
|
||||||
|
def __call__(self, db: Session) -> dict[str, Any]:
|
||||||
|
"""Execute the import operation against the given database session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): Active SQLAlchemy session to use for all DB operations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Any]: Summary statistics for the import run (e.g. created,
|
||||||
|
updated, skipped counts).
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# Define class ImportServiceEntry
|
||||||
class ImportServiceEntry:
|
class ImportServiceEntry:
|
||||||
"""Lazy-loading wrapper that resolves a module-level function on first call."""
|
"""Lazy-loading wrapper that resolves a module-level function on first call."""
|
||||||
|
|
||||||
|
# Assign __slots__ = ("_module_path", "_func_name", "_resolved")
|
||||||
__slots__ = ("_module_path", "_func_name", "_resolved")
|
__slots__ = ("_module_path", "_func_name", "_resolved")
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, module_path: str, func_name: str) -> None:
|
def __init__(self, module_path: str, func_name: str) -> None:
|
||||||
|
"""Initialise the lazy entry with the module path and function name to resolve later.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_path (str): Dotted Python module path, e.g.
|
||||||
|
``"app.services.atomic_import_service"``.
|
||||||
|
func_name (str): Name of the callable to import from *module_path*.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Assign self._module_path = module_path
|
||||||
self._module_path = module_path
|
self._module_path = module_path
|
||||||
|
# Assign self._func_name = func_name
|
||||||
self._func_name = func_name
|
self._func_name = func_name
|
||||||
|
# Assign self._resolved = None
|
||||||
self._resolved: ImportService | None = None
|
self._resolved: ImportService | None = None
|
||||||
|
|
||||||
|
# Define function __call__
|
||||||
def __call__(self, db: Session) -> dict[str, Any]:
|
def __call__(self, db: Session) -> dict[str, Any]:
|
||||||
|
"""Resolve the import function on first call and invoke it with *db*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): SQLAlchemy session passed through to the underlying
|
||||||
|
import function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Any]: Import statistics returned by the underlying function
|
||||||
|
(e.g. counts of created/updated/skipped records).
|
||||||
|
"""
|
||||||
|
# Check: self._resolved is None
|
||||||
if self._resolved is None:
|
if self._resolved is None:
|
||||||
|
# Import importlib
|
||||||
import importlib
|
import importlib
|
||||||
|
# Assign mod = importlib.import_module(self._module_path)
|
||||||
mod = importlib.import_module(self._module_path)
|
mod = importlib.import_module(self._module_path)
|
||||||
|
# Assign self._resolved = getattr(mod, self._func_name)
|
||||||
self._resolved = getattr(mod, self._func_name)
|
self._resolved = getattr(mod, self._func_name)
|
||||||
|
# Return self._resolved(db)
|
||||||
return self._resolved(db)
|
return self._resolved(db)
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function source_info
|
||||||
def source_info(self) -> str:
|
def source_info(self) -> str:
|
||||||
|
"""Return a human-readable identifier for this import entry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The fully qualified function reference as
|
||||||
|
``"<module_path>.<func_name>"``.
|
||||||
|
"""
|
||||||
|
# Return f"{self._module_path}.{self._func_name}"
|
||||||
return f"{self._module_path}.{self._func_name}"
|
return f"{self._module_path}.{self._func_name}"
|
||||||
|
|
||||||
|
|
||||||
|
# Assign IMPORT_REGISTRY = {
|
||||||
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
|
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
|
||||||
|
# Literal argument value
|
||||||
"atomic_red_team": ImportServiceEntry(
|
"atomic_red_team": ImportServiceEntry(
|
||||||
|
# Literal argument value
|
||||||
"app.services.atomic_import_service", "import_atomic_red_team",
|
"app.services.atomic_import_service", "import_atomic_red_team",
|
||||||
),
|
),
|
||||||
|
# Literal argument value
|
||||||
"sigma": ImportServiceEntry(
|
"sigma": ImportServiceEntry(
|
||||||
|
# Literal argument value
|
||||||
"app.services.sigma_import_service", "sync",
|
"app.services.sigma_import_service", "sync",
|
||||||
),
|
),
|
||||||
|
# Literal argument value
|
||||||
"lolbas": ImportServiceEntry(
|
"lolbas": ImportServiceEntry(
|
||||||
|
# Literal argument value
|
||||||
"app.services.lolbas_import_service", "sync",
|
"app.services.lolbas_import_service", "sync",
|
||||||
),
|
),
|
||||||
|
# Literal argument value
|
||||||
"gtfobins": ImportServiceEntry(
|
"gtfobins": ImportServiceEntry(
|
||||||
|
# Literal argument value
|
||||||
"app.services.lolbas_import_service", "sync_gtfobins",
|
"app.services.lolbas_import_service", "sync_gtfobins",
|
||||||
),
|
),
|
||||||
|
# Literal argument value
|
||||||
"caldera": ImportServiceEntry(
|
"caldera": ImportServiceEntry(
|
||||||
|
# Literal argument value
|
||||||
"app.services.caldera_import_service", "sync",
|
"app.services.caldera_import_service", "sync",
|
||||||
),
|
),
|
||||||
|
# Literal argument value
|
||||||
"elastic_rules": ImportServiceEntry(
|
"elastic_rules": ImportServiceEntry(
|
||||||
|
# Literal argument value
|
||||||
"app.services.elastic_import_service", "sync",
|
"app.services.elastic_import_service", "sync",
|
||||||
),
|
),
|
||||||
|
# Literal argument value
|
||||||
"mitre_cti": ImportServiceEntry(
|
"mitre_cti": ImportServiceEntry(
|
||||||
|
# Literal argument value
|
||||||
"app.services.threat_actor_import_service", "sync",
|
"app.services.threat_actor_import_service", "sync",
|
||||||
),
|
),
|
||||||
|
# Literal argument value
|
||||||
"d3fend": ImportServiceEntry(
|
"d3fend": ImportServiceEntry(
|
||||||
|
# Literal argument value
|
||||||
"app.services.d3fend_import_service", "sync",
|
"app.services.d3fend_import_service", "sync",
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Define function get_import_handler
|
||||||
def get_import_handler(source_name: str) -> ImportServiceEntry | None:
|
def get_import_handler(source_name: str) -> ImportServiceEntry | None:
|
||||||
"""Look up the import handler for *source_name*.
|
"""Look up the import handler for *source_name*.
|
||||||
|
|
||||||
Returns ``None`` when no handler is registered.
|
Returns ``None`` when no handler is registered.
|
||||||
"""
|
"""
|
||||||
|
# Return IMPORT_REGISTRY.get(source_name)
|
||||||
return IMPORT_REGISTRY.get(source_name)
|
return IMPORT_REGISTRY.get(source_name)
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
|
"""Abstract repository port interfaces for domain entity persistence."""
|
||||||
|
# Import TechniqueRepository from app.domain.ports.repositories.technique_repository
|
||||||
from app.domain.ports.repositories.technique_repository import TechniqueRepository
|
from app.domain.ports.repositories.technique_repository import TechniqueRepository
|
||||||
|
|
||||||
|
# Import TestRepository from app.domain.ports.repositories.test_repository
|
||||||
from app.domain.ports.repositories.test_repository import TestRepository
|
from app.domain.ports.repositories.test_repository import TestRepository
|
||||||
|
|
||||||
|
# Assign __all__ = ["TechniqueRepository", "TestRepository"]
|
||||||
__all__ = ["TechniqueRepository", "TestRepository"]
|
__all__ = ["TechniqueRepository", "TestRepository"]
|
||||||
|
|||||||
@@ -4,54 +4,157 @@ This is a domain contract — implementations live in infrastructure/.
|
|||||||
The domain layer NEVER imports the implementation.
|
The domain layer NEVER imports the implementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import NamedTuple, Protocol, runtime_checkable from typing
|
||||||
from typing import NamedTuple, Protocol, runtime_checkable
|
from typing import NamedTuple, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
# Import TechniqueEntity from app.domain.entities.technique
|
||||||
from app.domain.entities.technique import TechniqueEntity
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
|
||||||
|
# Import TechniqueStatus from app.domain.enums
|
||||||
from app.domain.enums import TechniqueStatus
|
from app.domain.enums import TechniqueStatus
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TechniqueWithCounts
|
||||||
class TechniqueWithCounts(NamedTuple):
|
class TechniqueWithCounts(NamedTuple):
|
||||||
"""Pre-aggregated technique data for heatmap/scoring."""
|
"""Pre-aggregated technique data for heatmap/scoring."""
|
||||||
|
|
||||||
|
# entity: TechniqueEntity
|
||||||
entity: TechniqueEntity
|
entity: TechniqueEntity
|
||||||
|
# test_count: int
|
||||||
test_count: int
|
test_count: int
|
||||||
|
# validated_test_count: int
|
||||||
validated_test_count: int
|
validated_test_count: int
|
||||||
|
# detection_rule_count: int
|
||||||
detection_rule_count: int
|
detection_rule_count: int
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @runtime_checkable decorator
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
# Define class TechniqueRepository
|
||||||
class TechniqueRepository(Protocol):
|
class TechniqueRepository(Protocol):
|
||||||
"""Data access contract for techniques (one per aggregate root)."""
|
"""Data access contract for techniques (one per aggregate root)."""
|
||||||
|
|
||||||
# -- Single-entity access ----------------------------------------------
|
# -- Single-entity access ----------------------------------------------
|
||||||
|
|
||||||
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: ...
|
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
|
||||||
|
"""Return the technique with the given primary key, or None if absent.
|
||||||
|
|
||||||
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: ...
|
Args:
|
||||||
|
technique_id (uuid.UUID): Primary key of the technique to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TechniqueEntity | None: The matching entity, or None if not found.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
|
# Define function find_by_mitre_id
|
||||||
|
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
|
||||||
|
"""Return the technique matching the given MITRE ATT&CK identifier, or None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mitre_id (str): MITRE ATT&CK ID (e.g. ``"T1059"`` or ``"T1059.001"``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TechniqueEntity | None: The matching entity, or None if not found.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
# -- List access -------------------------------------------------------
|
# -- List access -------------------------------------------------------
|
||||||
|
|
||||||
def list_all(
|
def list_all(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
# Entry: tactic
|
||||||
tactic: str | None = None,
|
tactic: str | None = None,
|
||||||
|
# Entry: status
|
||||||
status: TechniqueStatus | None = None,
|
status: TechniqueStatus | None = None,
|
||||||
|
# Entry: review_required
|
||||||
review_required: bool | None = None,
|
review_required: bool | None = None,
|
||||||
) -> list[TechniqueEntity]: ...
|
) -> list[TechniqueEntity]:
|
||||||
|
"""Return all techniques, optionally filtered by tactic, status, or review flag.
|
||||||
|
|
||||||
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: ...
|
Args:
|
||||||
|
tactic (str | None): When provided, restrict results to this tactic category.
|
||||||
|
status (TechniqueStatus | None): When provided, restrict results to this status.
|
||||||
|
review_required (bool | None): When provided, restrict results to techniques
|
||||||
|
whose ``review_required`` flag matches this value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[TechniqueEntity]: Matching technique entities; may be empty.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
|
# Define function list_by_ids
|
||||||
|
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
|
||||||
|
"""Return all techniques whose primary keys are in *ids*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids (list[uuid.UUID]): List of technique UUIDs to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[TechniqueEntity]: Entities found for the supplied IDs; order
|
||||||
|
is not guaranteed and missing IDs are silently omitted.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
# -- Batch queries (scoring/heatmap performance) -----------------------
|
# -- Batch queries (scoring/heatmap performance) -----------------------
|
||||||
|
|
||||||
def count_by_status(self) -> dict[TechniqueStatus, int]: ...
|
def count_by_status(self) -> dict[TechniqueStatus, int]:
|
||||||
|
"""Return a count of techniques grouped by their global status.
|
||||||
|
|
||||||
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: ...
|
Returns:
|
||||||
|
dict[TechniqueStatus, int]: Mapping from each status value to the
|
||||||
|
number of techniques in that state.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
|
# Define function find_all_with_test_counts
|
||||||
|
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
|
||||||
|
"""Return all techniques together with pre-aggregated test and rule counts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[TechniqueWithCounts]: Each element bundles a TechniqueEntity
|
||||||
|
with its total, validated, and detection-rule counts for use
|
||||||
|
in heatmap and scoring calculations.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
# -- Mutations ---------------------------------------------------------
|
# -- Mutations ---------------------------------------------------------
|
||||||
|
|
||||||
def save(self, technique: TechniqueEntity) -> TechniqueEntity: ...
|
def save(self, technique: TechniqueEntity) -> TechniqueEntity:
|
||||||
|
"""Persist a technique entity and return the saved state.
|
||||||
|
|
||||||
def exists_by_mitre_id(self, mitre_id: str) -> bool: ...
|
Args:
|
||||||
|
technique (TechniqueEntity): The entity to create or update.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TechniqueEntity: The persisted entity, potentially with updated
|
||||||
|
fields (e.g. server-side timestamps).
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
|
# Define function exists_by_mitre_id
|
||||||
|
def exists_by_mitre_id(self, mitre_id: str) -> bool:
|
||||||
|
"""Return True if a technique with the given MITRE ID exists in the repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mitre_id (str): MITRE ATT&CK ID to check (e.g. ``"T1059"``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if a matching technique is found, False otherwise.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|||||||
@@ -3,14 +3,20 @@
|
|||||||
This is a domain contract — implementations live in infrastructure/.
|
This is a domain contract — implementations live in infrastructure/.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Protocol, runtime_checkable
|
|
||||||
|
|
||||||
|
# Import Protocol from typing
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
# Import TestState from app.domain.enums
|
||||||
from app.domain.enums import TestState
|
from app.domain.enums import TestState
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TestRepository
|
||||||
class TestRepository(Protocol):
|
class TestRepository(Protocol):
|
||||||
"""Data access contract for tests."""
|
"""Data access contract for tests."""
|
||||||
|
|
||||||
@@ -22,31 +28,81 @@ class TestRepository(Protocol):
|
|||||||
Returns the ORM model directly (not a domain entity) because
|
Returns the ORM model directly (not a domain entity) because
|
||||||
the TestEntity is constructed at the service layer via
|
the TestEntity is constructed at the service layer via
|
||||||
``TestEntity.from_orm()``.
|
``TestEntity.from_orm()``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_id (uuid.UUID): Primary key of the test to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object | None: The ORM model instance, or None if not found.
|
||||||
"""
|
"""
|
||||||
|
# ...
|
||||||
...
|
...
|
||||||
|
|
||||||
# -- List access -------------------------------------------------------
|
# -- List access -------------------------------------------------------
|
||||||
|
|
||||||
def list_by_technique(self, technique_id: uuid.UUID) -> list[object]: ...
|
def list_by_technique(self, technique_id: uuid.UUID) -> list[object]:
|
||||||
|
"""Return all test ORM models associated with the given technique.
|
||||||
|
|
||||||
def list_by_state(self, state: TestState) -> list[object]: ...
|
Args:
|
||||||
|
technique_id (uuid.UUID): Primary key of the technique whose tests to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[object]: ORM model instances for all tests linked to this technique.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
|
# Define function list_by_state
|
||||||
|
def list_by_state(self, state: TestState) -> list[object]:
|
||||||
|
"""Return all test ORM models in the given state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (TestState): The state to filter tests by.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[object]: ORM model instances for all tests currently in *state*.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
|
...
|
||||||
|
|
||||||
|
# Define function count_by_technique_and_state
|
||||||
def count_by_technique_and_state(
|
def count_by_technique_and_state(
|
||||||
self,
|
self,
|
||||||
|
# Entry: technique_id
|
||||||
technique_id: uuid.UUID,
|
technique_id: uuid.UUID,
|
||||||
) -> dict[TestState, int]:
|
) -> dict[TestState, int]:
|
||||||
"""Return test counts grouped by state for a single technique."""
|
"""Return test counts grouped by state for a single technique.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique_id (uuid.UUID): Primary key of the technique whose test
|
||||||
|
counts to aggregate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[TestState, int]: Mapping from each test state to the number of
|
||||||
|
tests in that state for the given technique.
|
||||||
|
"""
|
||||||
|
# ...
|
||||||
...
|
...
|
||||||
|
|
||||||
# -- Batch queries -----------------------------------------------------
|
# -- Batch queries -----------------------------------------------------
|
||||||
|
|
||||||
def get_states_and_results_for_technique(
|
def get_states_and_results_for_technique(
|
||||||
self,
|
self,
|
||||||
|
# Entry: technique_id
|
||||||
technique_id: uuid.UUID,
|
technique_id: uuid.UUID,
|
||||||
) -> list[tuple[str, str | None]]:
|
) -> list[tuple[str, str | None]]:
|
||||||
"""Return (state, detection_result) pairs for all tests of a technique.
|
"""Return (state, detection_result) pairs for all tests of a technique.
|
||||||
|
|
||||||
Used by TechniqueEntity.recalculate_status() without loading full
|
Used by TechniqueEntity.recalculate_status() without loading full
|
||||||
test models.
|
test models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique_id (uuid.UUID): Primary key of the technique whose test
|
||||||
|
data to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[tuple[str, str | None]]: Each tuple contains the test state
|
||||||
|
string and the detection result string (or None if not yet set).
|
||||||
"""
|
"""
|
||||||
|
# ...
|
||||||
...
|
...
|
||||||
|
|||||||
@@ -20,34 +20,58 @@ After mutations, the service layer copies ``entity.changes`` back onto
|
|||||||
the ORM model and persists via Unit of Work.
|
the ORM model and persists via Unit of Work.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import enum
|
||||||
import enum
|
import enum
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# Import dataclass, field from dataclasses
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
# Import datetime from datetime
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Import TYPE_CHECKING, Any from typing
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
# Import from app.domain.errors
|
||||||
from app.domain.errors import (
|
from app.domain.errors import (
|
||||||
BusinessRuleViolation,
|
BusinessRuleViolation,
|
||||||
InvalidOperationError,
|
InvalidOperationError,
|
||||||
InvalidStateTransition,
|
InvalidStateTransition,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check: TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Import Test as TestORM from app.models.test
|
||||||
|
from app.models.test import Test as TestORM
|
||||||
|
|
||||||
# ── Value objects ────────────────────────────────────────────────────
|
# ── Value objects ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestState(str, enum.Enum):
|
class TestState(str, enum.Enum):
|
||||||
|
"""Ordered lifecycle states for a security test."""
|
||||||
|
|
||||||
|
# Assign draft = "draft"
|
||||||
draft = "draft"
|
draft = "draft"
|
||||||
|
# Assign red_executing = "red_executing"
|
||||||
red_executing = "red_executing"
|
red_executing = "red_executing"
|
||||||
|
# Assign blue_evaluating = "blue_evaluating"
|
||||||
blue_evaluating = "blue_evaluating"
|
blue_evaluating = "blue_evaluating"
|
||||||
|
# Assign in_review = "in_review"
|
||||||
in_review = "in_review"
|
in_review = "in_review"
|
||||||
|
# Assign validated = "validated"
|
||||||
validated = "validated"
|
validated = "validated"
|
||||||
|
# Assign rejected = "rejected"
|
||||||
rejected = "rejected"
|
rejected = "rejected"
|
||||||
disputed = "disputed" # one lead approved, the other rejected
|
disputed = "disputed" # one lead approved, the other rejected
|
||||||
|
|
||||||
|
|
||||||
|
# Assign VALID_TRANSITIONS = {
|
||||||
VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
|
VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
|
||||||
TestState.draft: [TestState.red_executing],
|
TestState.draft: [TestState.red_executing],
|
||||||
TestState.red_executing: [TestState.blue_evaluating],
|
TestState.red_executing: [TestState.blue_evaluating],
|
||||||
@@ -58,6 +82,7 @@ VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
|
|||||||
TestState.validated: [],
|
TestState.validated: [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Assign _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating})
|
||||||
_PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating})
|
_PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating})
|
||||||
|
|
||||||
|
|
||||||
@@ -65,8 +90,13 @@ _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
# Define class DomainEvent
|
||||||
class DomainEvent:
|
class DomainEvent:
|
||||||
|
"""Immutable record of a domain-level event emitted by the test entity."""
|
||||||
|
|
||||||
|
# name: str
|
||||||
name: str
|
name: str
|
||||||
|
# Assign payload = field(default_factory=dict)
|
||||||
payload: dict[str, Any] = field(default_factory=dict)
|
payload: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@@ -74,30 +104,44 @@ class DomainEvent:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
# Define class TestEntity
|
||||||
class TestEntity:
|
class TestEntity:
|
||||||
"""Pure domain representation of a security test."""
|
"""Pure domain representation of a security test."""
|
||||||
|
|
||||||
|
# id: uuid.UUID
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
|
# state: TestState
|
||||||
state: TestState
|
state: TestState
|
||||||
|
|
||||||
# Red validation
|
# Red validation
|
||||||
red_validation_status: str | None = None
|
red_validation_status: str | None = None
|
||||||
|
# Assign red_validated_by = None
|
||||||
red_validated_by: uuid.UUID | None = None
|
red_validated_by: uuid.UUID | None = None
|
||||||
|
# Assign red_validated_at = None
|
||||||
red_validated_at: datetime | None = None
|
red_validated_at: datetime | None = None
|
||||||
|
# Assign red_validation_notes = None
|
||||||
red_validation_notes: str | None = None
|
red_validation_notes: str | None = None
|
||||||
|
|
||||||
# Blue validation
|
# Blue validation
|
||||||
blue_validation_status: str | None = None
|
blue_validation_status: str | None = None
|
||||||
|
# Assign blue_validated_by = None
|
||||||
blue_validated_by: uuid.UUID | None = None
|
blue_validated_by: uuid.UUID | None = None
|
||||||
|
# Assign blue_validated_at = None
|
||||||
blue_validated_at: datetime | None = None
|
blue_validated_at: datetime | None = None
|
||||||
|
# Assign blue_validation_notes = None
|
||||||
blue_validation_notes: str | None = None
|
blue_validation_notes: str | None = None
|
||||||
|
|
||||||
# Phase timing
|
# Phase timing
|
||||||
execution_date: datetime | None = None
|
execution_date: datetime | None = None
|
||||||
|
# Assign red_started_at = None
|
||||||
red_started_at: datetime | None = None
|
red_started_at: datetime | None = None
|
||||||
|
# Assign blue_started_at = None
|
||||||
blue_started_at: datetime | None = None
|
blue_started_at: datetime | None = None
|
||||||
|
# Assign paused_at = None
|
||||||
paused_at: datetime | None = None
|
paused_at: datetime | None = None
|
||||||
|
# Assign red_paused_seconds = 0
|
||||||
red_paused_seconds: int = 0
|
red_paused_seconds: int = 0
|
||||||
|
# Assign blue_paused_seconds = 0
|
||||||
blue_paused_seconds: int = 0
|
blue_paused_seconds: int = 0
|
||||||
|
|
||||||
# Internal bookkeeping (not persisted as-is)
|
# Internal bookkeeping (not persisted as-is)
|
||||||
@@ -106,58 +150,134 @@ class TestEntity:
|
|||||||
# -- Factory --------------------------------------------------------
|
# -- Factory --------------------------------------------------------
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_orm(cls, model: Any) -> TestEntity:
|
# Define function from_orm
|
||||||
"""Build a TestEntity from a SQLAlchemy ``Test`` model instance."""
|
def from_orm(cls, model: TestORM) -> TestEntity:
|
||||||
|
"""Build a TestEntity from a SQLAlchemy ``Test`` model instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (TestORM): The ORM model whose fields will be copied into the entity.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestEntity: A fully populated domain entity reflecting the ORM state.
|
||||||
|
"""
|
||||||
|
# Assign raw_state = model.state
|
||||||
raw_state = model.state
|
raw_state = model.state
|
||||||
|
# Assign state = raw_state if isinstance(raw_state, TestState) else TestState(raw_st...
|
||||||
state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state)
|
state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state)
|
||||||
|
# Return cls(
|
||||||
return cls(
|
return cls(
|
||||||
|
# Keyword argument: id
|
||||||
id=model.id,
|
id=model.id,
|
||||||
|
# Keyword argument: state
|
||||||
state=state,
|
state=state,
|
||||||
|
# Keyword argument: red_validation_status
|
||||||
red_validation_status=model.red_validation_status,
|
red_validation_status=model.red_validation_status,
|
||||||
|
# Keyword argument: red_validated_by
|
||||||
red_validated_by=model.red_validated_by,
|
red_validated_by=model.red_validated_by,
|
||||||
|
# Keyword argument: red_validated_at
|
||||||
red_validated_at=model.red_validated_at,
|
red_validated_at=model.red_validated_at,
|
||||||
|
# Keyword argument: red_validation_notes
|
||||||
red_validation_notes=model.red_validation_notes,
|
red_validation_notes=model.red_validation_notes,
|
||||||
|
# Keyword argument: blue_validation_status
|
||||||
blue_validation_status=model.blue_validation_status,
|
blue_validation_status=model.blue_validation_status,
|
||||||
|
# Keyword argument: blue_validated_by
|
||||||
blue_validated_by=model.blue_validated_by,
|
blue_validated_by=model.blue_validated_by,
|
||||||
|
# Keyword argument: blue_validated_at
|
||||||
blue_validated_at=model.blue_validated_at,
|
blue_validated_at=model.blue_validated_at,
|
||||||
|
# Keyword argument: blue_validation_notes
|
||||||
blue_validation_notes=model.blue_validation_notes,
|
blue_validation_notes=model.blue_validation_notes,
|
||||||
|
# Keyword argument: execution_date
|
||||||
execution_date=model.execution_date,
|
execution_date=model.execution_date,
|
||||||
|
# Keyword argument: red_started_at
|
||||||
red_started_at=model.red_started_at,
|
red_started_at=model.red_started_at,
|
||||||
|
# Keyword argument: blue_started_at
|
||||||
blue_started_at=model.blue_started_at,
|
blue_started_at=model.blue_started_at,
|
||||||
|
# Keyword argument: paused_at
|
||||||
paused_at=model.paused_at,
|
paused_at=model.paused_at,
|
||||||
|
# Keyword argument: red_paused_seconds
|
||||||
red_paused_seconds=model.red_paused_seconds or 0,
|
red_paused_seconds=model.red_paused_seconds or 0,
|
||||||
|
# Keyword argument: blue_paused_seconds
|
||||||
blue_paused_seconds=model.blue_paused_seconds or 0,
|
blue_paused_seconds=model.blue_paused_seconds or 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_to(self, model: Any) -> None:
|
# Define function apply_to
|
||||||
"""Copy the entity's mutable fields back onto the ORM model."""
|
def apply_to(self, model: TestORM) -> None:
|
||||||
|
"""Copy the entity's mutable fields back onto the ORM model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (TestORM): The ORM model to update in-place.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Assign model.state = self.state
|
||||||
model.state = self.state
|
model.state = self.state
|
||||||
|
# Assign model.red_validation_status = self.red_validation_status
|
||||||
model.red_validation_status = self.red_validation_status
|
model.red_validation_status = self.red_validation_status
|
||||||
|
# Assign model.red_validated_by = self.red_validated_by
|
||||||
model.red_validated_by = self.red_validated_by
|
model.red_validated_by = self.red_validated_by
|
||||||
|
# Assign model.red_validated_at = self.red_validated_at
|
||||||
model.red_validated_at = self.red_validated_at
|
model.red_validated_at = self.red_validated_at
|
||||||
|
# Assign model.red_validation_notes = self.red_validation_notes
|
||||||
model.red_validation_notes = self.red_validation_notes
|
model.red_validation_notes = self.red_validation_notes
|
||||||
|
# Assign model.blue_validation_status = self.blue_validation_status
|
||||||
model.blue_validation_status = self.blue_validation_status
|
model.blue_validation_status = self.blue_validation_status
|
||||||
|
# Assign model.blue_validated_by = self.blue_validated_by
|
||||||
model.blue_validated_by = self.blue_validated_by
|
model.blue_validated_by = self.blue_validated_by
|
||||||
|
# Assign model.blue_validated_at = self.blue_validated_at
|
||||||
model.blue_validated_at = self.blue_validated_at
|
model.blue_validated_at = self.blue_validated_at
|
||||||
|
# Assign model.blue_validation_notes = self.blue_validation_notes
|
||||||
model.blue_validation_notes = self.blue_validation_notes
|
model.blue_validation_notes = self.blue_validation_notes
|
||||||
|
# Assign model.execution_date = self.execution_date
|
||||||
model.execution_date = self.execution_date
|
model.execution_date = self.execution_date
|
||||||
|
# Assign model.red_started_at = self.red_started_at
|
||||||
model.red_started_at = self.red_started_at
|
model.red_started_at = self.red_started_at
|
||||||
|
# Assign model.blue_started_at = self.blue_started_at
|
||||||
model.blue_started_at = self.blue_started_at
|
model.blue_started_at = self.blue_started_at
|
||||||
|
# Assign model.paused_at = self.paused_at
|
||||||
model.paused_at = self.paused_at
|
model.paused_at = self.paused_at
|
||||||
|
# Assign model.red_paused_seconds = self.red_paused_seconds
|
||||||
model.red_paused_seconds = self.red_paused_seconds
|
model.red_paused_seconds = self.red_paused_seconds
|
||||||
|
# Assign model.blue_paused_seconds = self.blue_paused_seconds
|
||||||
model.blue_paused_seconds = self.blue_paused_seconds
|
model.blue_paused_seconds = self.blue_paused_seconds
|
||||||
|
|
||||||
# -- Query helpers --------------------------------------------------
|
# -- Query helpers --------------------------------------------------
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
# Define function events
|
||||||
def events(self) -> list[DomainEvent]:
|
def events(self) -> list[DomainEvent]:
|
||||||
|
"""Return a snapshot of all domain events raised on this entity.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[DomainEvent]: Ordered list of events emitted since the entity
|
||||||
|
was constructed or last cleared.
|
||||||
|
"""
|
||||||
|
# Return list(self._events)
|
||||||
return list(self._events)
|
return list(self._events)
|
||||||
|
|
||||||
|
# Define function can_transition
|
||||||
def can_transition(self, target: TestState) -> bool:
|
def can_transition(self, target: TestState) -> bool:
|
||||||
|
"""Check whether a transition from the current state to *target* is valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (TestState): The desired next state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the transition is allowed, False otherwise.
|
||||||
|
"""
|
||||||
|
# Return target in VALID_TRANSITIONS.get(self.state, [])
|
||||||
return target in VALID_TRANSITIONS.get(self.state, [])
|
return target in VALID_TRANSITIONS.get(self.state, [])
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function is_terminal
|
||||||
def is_terminal(self) -> bool:
|
def is_terminal(self) -> bool:
|
||||||
|
"""Return True if the test has reached its final (validated) state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True when state is ``validated``, False for all other states.
|
||||||
|
"""
|
||||||
|
# Return self.state == TestState.validated
|
||||||
return self.state == TestState.validated
|
return self.state == TestState.validated
|
||||||
|
|
||||||
# -- Core transition ------------------------------------------------
|
# -- Core transition ------------------------------------------------
|
||||||
@@ -171,148 +291,305 @@ class TestEntity:
|
|||||||
Returns the *previous* state value as a plain string.
|
Returns the *previous* state value as a plain string.
|
||||||
|
|
||||||
Raises :class:`InvalidStateTransition` when the move is illegal.
|
Raises :class:`InvalidStateTransition` when the move is illegal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (TestState | str): The desired next state, as an enum member
|
||||||
|
or its string equivalent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The previous state value before the transition.
|
||||||
"""
|
"""
|
||||||
|
# Assign value = target.value if hasattr(target, "value") else str(target)
|
||||||
value = target.value if hasattr(target, "value") else str(target)
|
value = target.value if hasattr(target, "value") else str(target)
|
||||||
|
# Assign resolved = target if isinstance(target, TestState) else TestState(value)
|
||||||
resolved = target if isinstance(target, TestState) else TestState(value)
|
resolved = target if isinstance(target, TestState) else TestState(value)
|
||||||
|
# Return self._transition(resolved)
|
||||||
return self._transition(resolved)
|
return self._transition(resolved)
|
||||||
|
|
||||||
|
# Define function _transition
|
||||||
def _transition(self, target: TestState) -> str:
|
def _transition(self, target: TestState) -> str:
|
||||||
"""Internal: validate and apply; return previous state value."""
|
"""Validate and apply a state transition, returning the previous state value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (TestState): The desired next state enum member.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The previous state value before the transition was applied.
|
||||||
|
"""
|
||||||
|
# Check: not self.can_transition(target)
|
||||||
if not self.can_transition(target):
|
if not self.can_transition(target):
|
||||||
|
# Assign valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])]
|
||||||
valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])]
|
valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])]
|
||||||
|
# Raise InvalidStateTransition
|
||||||
raise InvalidStateTransition(
|
raise InvalidStateTransition(
|
||||||
|
# Keyword argument: current_state
|
||||||
current_state=self.state.value,
|
current_state=self.state.value,
|
||||||
|
# Keyword argument: target_state
|
||||||
target_state=target.value,
|
target_state=target.value,
|
||||||
|
# Keyword argument: valid_transitions
|
||||||
valid_transitions=valid,
|
valid_transitions=valid,
|
||||||
)
|
)
|
||||||
|
# Assign previous = self.state.value
|
||||||
previous = self.state.value
|
previous = self.state.value
|
||||||
|
# Assign self.state = target
|
||||||
self.state = target
|
self.state = target
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent(
|
self._events.append(DomainEvent(
|
||||||
|
# Literal argument value
|
||||||
"state_changed",
|
"state_changed",
|
||||||
{"previous": previous, "new": target.value},
|
{"previous": previous, "new": target.value},
|
||||||
))
|
))
|
||||||
|
# Return previous
|
||||||
return previous
|
return previous
|
||||||
|
|
||||||
# -- Lifecycle commands --------------------------------------------
|
# -- Lifecycle commands --------------------------------------------
|
||||||
|
|
||||||
def start_execution(self) -> None:
|
def start_execution(self) -> None:
|
||||||
"""``draft`` -> ``red_executing``."""
|
"""Transition the test from ``draft`` to ``red_executing``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call self._transition()
|
||||||
self._transition(TestState.red_executing)
|
self._transition(TestState.red_executing)
|
||||||
|
# Assign now = datetime.utcnow()
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
# Assign self.execution_date = now
|
||||||
self.execution_date = now
|
self.execution_date = now
|
||||||
|
# Assign self.red_started_at = now
|
||||||
self.red_started_at = now
|
self.red_started_at = now
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent("execution_started"))
|
self._events.append(DomainEvent("execution_started"))
|
||||||
|
|
||||||
|
# Define function submit_red_evidence
|
||||||
def submit_red_evidence(self) -> int:
|
def submit_red_evidence(self) -> int:
|
||||||
"""``red_executing`` -> ``blue_evaluating``.
|
"""Transition the test from ``red_executing`` to ``blue_evaluating``.
|
||||||
|
|
||||||
Auto-resumes if paused. Returns paused seconds accumulated
|
Auto-resumes if paused. Returns paused seconds accumulated
|
||||||
during this phase (for worklog calculation).
|
during this phase (for worklog calculation).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Total seconds the red phase was paused.
|
||||||
"""
|
"""
|
||||||
|
# Assign paused_extra = self._auto_resume()
|
||||||
paused_extra = self._auto_resume()
|
paused_extra = self._auto_resume()
|
||||||
|
# Call self._transition()
|
||||||
self._transition(TestState.blue_evaluating)
|
self._transition(TestState.blue_evaluating)
|
||||||
|
# Assign total_paused = self.red_paused_seconds + paused_extra
|
||||||
total_paused = self.red_paused_seconds + paused_extra
|
total_paused = self.red_paused_seconds + paused_extra
|
||||||
|
# Assign self.blue_started_at = datetime.utcnow()
|
||||||
self.blue_started_at = datetime.utcnow()
|
self.blue_started_at = datetime.utcnow()
|
||||||
|
# Assign self.blue_paused_seconds = 0
|
||||||
self.blue_paused_seconds = 0
|
self.blue_paused_seconds = 0
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent(
|
self._events.append(DomainEvent(
|
||||||
|
# Literal argument value
|
||||||
"red_evidence_submitted",
|
"red_evidence_submitted",
|
||||||
{"red_paused_seconds": total_paused},
|
{"red_paused_seconds": total_paused},
|
||||||
))
|
))
|
||||||
|
# Return total_paused
|
||||||
return total_paused
|
return total_paused
|
||||||
|
|
||||||
|
# Define function submit_blue_evidence
|
||||||
def submit_blue_evidence(self) -> int:
|
def submit_blue_evidence(self) -> int:
|
||||||
"""``blue_evaluating`` -> ``in_review``.
|
"""Transition the test from ``blue_evaluating`` to ``in_review``.
|
||||||
|
|
||||||
Auto-resumes if paused. Returns paused seconds accumulated
|
Auto-resumes if paused. Returns paused seconds accumulated
|
||||||
during this phase (for worklog calculation).
|
during this phase (for worklog calculation).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Total seconds the blue phase was paused.
|
||||||
"""
|
"""
|
||||||
|
# Assign paused_extra = self._auto_resume()
|
||||||
paused_extra = self._auto_resume()
|
paused_extra = self._auto_resume()
|
||||||
|
# Call self._transition()
|
||||||
self._transition(TestState.in_review)
|
self._transition(TestState.in_review)
|
||||||
|
# Assign total_paused = self.blue_paused_seconds + paused_extra
|
||||||
total_paused = self.blue_paused_seconds + paused_extra
|
total_paused = self.blue_paused_seconds + paused_extra
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent(
|
self._events.append(DomainEvent(
|
||||||
|
# Literal argument value
|
||||||
"blue_evidence_submitted",
|
"blue_evidence_submitted",
|
||||||
{"blue_paused_seconds": total_paused},
|
{"blue_paused_seconds": total_paused},
|
||||||
))
|
))
|
||||||
|
# Return total_paused
|
||||||
return total_paused
|
return total_paused
|
||||||
|
|
||||||
|
# Define function pause_timer
|
||||||
def pause_timer(self) -> None:
|
def pause_timer(self) -> None:
|
||||||
"""Pause the active phase timer."""
|
"""Pause the active phase timer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Check: self.state not in _PAUSABLE_STATES
|
||||||
if self.state not in _PAUSABLE_STATES:
|
if self.state not in _PAUSABLE_STATES:
|
||||||
|
# Raise BusinessRuleViolation
|
||||||
raise BusinessRuleViolation(
|
raise BusinessRuleViolation(
|
||||||
f"Cannot pause timer in '{self.state.value}' state"
|
f"Cannot pause timer in '{self.state.value}' state"
|
||||||
)
|
)
|
||||||
|
# Check: self.paused_at is not None
|
||||||
if self.paused_at is not None:
|
if self.paused_at is not None:
|
||||||
|
# Raise BusinessRuleViolation
|
||||||
raise BusinessRuleViolation("Timer is already paused")
|
raise BusinessRuleViolation("Timer is already paused")
|
||||||
|
# Assign self.paused_at = datetime.utcnow()
|
||||||
self.paused_at = datetime.utcnow()
|
self.paused_at = datetime.utcnow()
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent("timer_paused"))
|
self._events.append(DomainEvent("timer_paused"))
|
||||||
|
|
||||||
|
# Define function resume_timer
|
||||||
def resume_timer(self) -> int:
|
def resume_timer(self) -> int:
|
||||||
"""Resume a paused timer. Returns seconds that were paused."""
|
"""Resume a paused timer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Number of seconds the timer was paused for.
|
||||||
|
"""
|
||||||
|
# Check: self.paused_at is None
|
||||||
if self.paused_at is None:
|
if self.paused_at is None:
|
||||||
|
# Raise BusinessRuleViolation
|
||||||
raise BusinessRuleViolation("Timer is not paused")
|
raise BusinessRuleViolation("Timer is not paused")
|
||||||
|
# Assign now = datetime.utcnow()
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
# Assign paused_seconds = max(int((now - self.paused_at).total_seconds()), 0)
|
||||||
paused_seconds = max(int((now - self.paused_at).total_seconds()), 0)
|
paused_seconds = max(int((now - self.paused_at).total_seconds()), 0)
|
||||||
|
# Check: self.state == TestState.red_executing
|
||||||
if self.state == TestState.red_executing:
|
if self.state == TestState.red_executing:
|
||||||
|
# Assign self.red_paused_seconds = paused_seconds
|
||||||
self.red_paused_seconds += paused_seconds
|
self.red_paused_seconds += paused_seconds
|
||||||
|
# Alternative: self.state == TestState.blue_evaluating
|
||||||
elif self.state == TestState.blue_evaluating:
|
elif self.state == TestState.blue_evaluating:
|
||||||
|
# Assign self.blue_paused_seconds = paused_seconds
|
||||||
self.blue_paused_seconds += paused_seconds
|
self.blue_paused_seconds += paused_seconds
|
||||||
|
# Assign self.paused_at = None
|
||||||
self.paused_at = None
|
self.paused_at = None
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent("timer_resumed", {"paused_seconds": paused_seconds}))
|
self._events.append(DomainEvent("timer_resumed", {"paused_seconds": paused_seconds}))
|
||||||
|
# Return paused_seconds
|
||||||
return paused_seconds
|
return paused_seconds
|
||||||
|
|
||||||
|
# Define function validate_red
|
||||||
def validate_red(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
|
def validate_red(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
|
||||||
"""Record Red Lead's validation decision."""
|
"""Record Red Lead's validation decision.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``.
|
||||||
|
by (uuid.UUID): UUID of the Red Lead recording the decision.
|
||||||
|
notes (str | None): Optional free-text notes about the decision.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call self._assert_in_review()
|
||||||
self._assert_in_review("red")
|
self._assert_in_review("red")
|
||||||
|
# Call self._assert_valid_vote()
|
||||||
self._assert_valid_vote(status)
|
self._assert_valid_vote(status)
|
||||||
|
# Assign now = datetime.utcnow()
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
# Assign self.red_validation_status = status
|
||||||
self.red_validation_status = status
|
self.red_validation_status = status
|
||||||
|
# Assign self.red_validated_by = by
|
||||||
self.red_validated_by = by
|
self.red_validated_by = by
|
||||||
|
# Assign self.red_validated_at = now
|
||||||
self.red_validated_at = now
|
self.red_validated_at = now
|
||||||
|
# Assign self.red_validation_notes = notes
|
||||||
self.red_validation_notes = notes
|
self.red_validation_notes = notes
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent("red_validated", {"status": status}))
|
self._events.append(DomainEvent("red_validated", {"status": status}))
|
||||||
|
# Call self._check_dual_validation()
|
||||||
self._check_dual_validation()
|
self._check_dual_validation()
|
||||||
|
|
||||||
|
# Define function validate_blue
|
||||||
def validate_blue(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
|
def validate_blue(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
|
||||||
"""Record Blue Lead's validation decision."""
|
"""Record Blue Lead's validation decision.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``.
|
||||||
|
by (uuid.UUID): UUID of the Blue Lead recording the decision.
|
||||||
|
notes (str | None): Optional free-text notes about the decision.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call self._assert_in_review()
|
||||||
self._assert_in_review("blue")
|
self._assert_in_review("blue")
|
||||||
|
# Call self._assert_valid_vote()
|
||||||
self._assert_valid_vote(status)
|
self._assert_valid_vote(status)
|
||||||
|
# Assign now = datetime.utcnow()
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
# Assign self.blue_validation_status = status
|
||||||
self.blue_validation_status = status
|
self.blue_validation_status = status
|
||||||
|
# Assign self.blue_validated_by = by
|
||||||
self.blue_validated_by = by
|
self.blue_validated_by = by
|
||||||
|
# Assign self.blue_validated_at = now
|
||||||
self.blue_validated_at = now
|
self.blue_validated_at = now
|
||||||
|
# Assign self.blue_validation_notes = notes
|
||||||
self.blue_validation_notes = notes
|
self.blue_validation_notes = notes
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent("blue_validated", {"status": status}))
|
self._events.append(DomainEvent("blue_validated", {"status": status}))
|
||||||
|
# Call self._check_dual_validation()
|
||||||
self._check_dual_validation()
|
self._check_dual_validation()
|
||||||
|
|
||||||
|
# Define function reopen
|
||||||
def reopen(self) -> None:
|
def reopen(self) -> None:
|
||||||
"""``rejected`` -> ``draft``, clearing all validation/timing fields."""
|
"""Transition the test from ``rejected`` back to ``draft``, clearing all validation and timing fields.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call self._transition()
|
||||||
self._transition(TestState.draft)
|
self._transition(TestState.draft)
|
||||||
|
# Assign self.red_validation_status = None
|
||||||
self.red_validation_status = None
|
self.red_validation_status = None
|
||||||
|
# Assign self.red_validated_by = None
|
||||||
self.red_validated_by = None
|
self.red_validated_by = None
|
||||||
|
# Assign self.red_validated_at = None
|
||||||
self.red_validated_at = None
|
self.red_validated_at = None
|
||||||
|
# Assign self.red_validation_notes = None
|
||||||
self.red_validation_notes = None
|
self.red_validation_notes = None
|
||||||
|
# Assign self.blue_validation_status = None
|
||||||
self.blue_validation_status = None
|
self.blue_validation_status = None
|
||||||
|
# Assign self.blue_validated_by = None
|
||||||
self.blue_validated_by = None
|
self.blue_validated_by = None
|
||||||
|
# Assign self.blue_validated_at = None
|
||||||
self.blue_validated_at = None
|
self.blue_validated_at = None
|
||||||
|
# Assign self.blue_validation_notes = None
|
||||||
self.blue_validation_notes = None
|
self.blue_validation_notes = None
|
||||||
|
# Assign self.red_started_at = None
|
||||||
self.red_started_at = None
|
self.red_started_at = None
|
||||||
|
# Assign self.blue_started_at = None
|
||||||
self.blue_started_at = None
|
self.blue_started_at = None
|
||||||
|
# Assign self.paused_at = None
|
||||||
self.paused_at = None
|
self.paused_at = None
|
||||||
|
# Assign self.red_paused_seconds = 0
|
||||||
self.red_paused_seconds = 0
|
self.red_paused_seconds = 0
|
||||||
|
# Assign self.blue_paused_seconds = 0
|
||||||
self.blue_paused_seconds = 0
|
self.blue_paused_seconds = 0
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent("test_reopened"))
|
self._events.append(DomainEvent("test_reopened"))
|
||||||
|
|
||||||
# -- Private -------------------------------------------------------
|
# -- Private -------------------------------------------------------
|
||||||
|
|
||||||
def _auto_resume(self) -> int:
|
def _auto_resume(self) -> int:
|
||||||
"""If paused, accumulate pause time and clear. Returns extra seconds."""
|
"""Accumulate pause time and clear the paused timestamp if currently paused.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Extra seconds that were accumulated from the current pause, or 0
|
||||||
|
if the timer was not paused.
|
||||||
|
"""
|
||||||
|
# Check: self.paused_at is None
|
||||||
if self.paused_at is None:
|
if self.paused_at is None:
|
||||||
|
# Return 0
|
||||||
return 0
|
return 0
|
||||||
|
# Assign now = datetime.utcnow()
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
# Assign extra = max(int((now - self.paused_at).total_seconds()), 0)
|
||||||
extra = max(int((now - self.paused_at).total_seconds()), 0)
|
extra = max(int((now - self.paused_at).total_seconds()), 0)
|
||||||
|
# Assign self.paused_at = None
|
||||||
self.paused_at = None
|
self.paused_at = None
|
||||||
|
# Return extra
|
||||||
return extra
|
return extra
|
||||||
|
|
||||||
|
# Define function check_dual_validation
|
||||||
def check_dual_validation(self) -> None:
|
def check_dual_validation(self) -> None:
|
||||||
"""Evaluate both leads' votes and advance state if appropriate.
|
"""Evaluate both leads' votes and advance state if appropriate.
|
||||||
|
|
||||||
@@ -324,8 +601,10 @@ class TestEntity:
|
|||||||
|
|
||||||
Called automatically by :meth:`validate_red` and :meth:`validate_blue`.
|
Called automatically by :meth:`validate_red` and :meth:`validate_blue`.
|
||||||
"""
|
"""
|
||||||
|
# Call self._check_dual_validation()
|
||||||
self._check_dual_validation()
|
self._check_dual_validation()
|
||||||
|
|
||||||
|
# Define function _assert_in_review
|
||||||
def _assert_in_review(self, side: str) -> None:
|
def _assert_in_review(self, side: str) -> None:
|
||||||
if self.state not in (TestState.in_review, TestState.disputed):
|
if self.state not in (TestState.in_review, TestState.disputed):
|
||||||
raise InvalidOperationError(
|
raise InvalidOperationError(
|
||||||
@@ -333,19 +612,34 @@ class TestEntity:
|
|||||||
f"'{self.state.value}' state (must be in_review or disputed)"
|
f"'{self.state.value}' state (must be in_review or disputed)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply the @staticmethod decorator
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# Define function _assert_valid_vote
|
||||||
def _assert_valid_vote(status: str) -> None:
|
def _assert_valid_vote(status: str) -> None:
|
||||||
|
"""Raise InvalidOperationError if *status* is not a valid vote value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status (str): The vote value to validate; must be ``"approved"`` or ``"rejected"``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Check: status not in ("approved", "rejected")
|
||||||
if status not in ("approved", "rejected"):
|
if status not in ("approved", "rejected"):
|
||||||
|
# Raise InvalidOperationError
|
||||||
raise InvalidOperationError(
|
raise InvalidOperationError(
|
||||||
|
# Literal argument value
|
||||||
"validation_status must be 'approved' or 'rejected'"
|
"validation_status must be 'approved' or 'rejected'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Define function _check_dual_validation
|
||||||
def _check_dual_validation(self) -> None:
|
def _check_dual_validation(self) -> None:
|
||||||
"""Advance the test state once both leads have voted."""
|
"""Advance the test state once both leads have voted."""
|
||||||
r, b = self.red_validation_status, self.blue_validation_status
|
r, b = self.red_validation_status, self.blue_validation_status
|
||||||
|
|
||||||
if r == "approved" and b == "approved":
|
if r == "approved" and b == "approved":
|
||||||
self.state = TestState.validated
|
self.state = TestState.validated
|
||||||
|
# Call self._events.append()
|
||||||
self._events.append(DomainEvent("dual_validation_approved"))
|
self._events.append(DomainEvent("dual_validation_approved"))
|
||||||
|
|
||||||
elif r == "rejected" and b == "rejected":
|
elif r == "rejected" and b == "rejected":
|
||||||
|
|||||||
@@ -20,36 +20,84 @@ Services should **never** call ``db.commit()``; they use ``db.add()`` /
|
|||||||
osint_enrichment_service.enrich_technique_with_cves).
|
osint_enrichment_service.enrich_technique_with_cves).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import TracebackType from types
|
||||||
|
from types import TracebackType
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
# Define class UnitOfWork
|
||||||
class UnitOfWork:
|
class UnitOfWork:
|
||||||
"""Lightweight transaction wrapper around an existing SQLAlchemy session."""
|
"""Lightweight transaction wrapper around an existing SQLAlchemy session."""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, session: Session) -> None:
|
def __init__(self, session: Session) -> None:
|
||||||
|
"""Wrap an existing SQLAlchemy session in a Unit of Work.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session (Session): The active SQLAlchemy session to manage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Assign self._session = session
|
||||||
self._session = session
|
self._session = session
|
||||||
|
|
||||||
# -- context manager -----------------------------------------------------
|
# -- context manager -----------------------------------------------------
|
||||||
|
|
||||||
def __enter__(self) -> "UnitOfWork":
|
def __enter__(self) -> "UnitOfWork":
|
||||||
|
"""Enter the runtime context, returning this UnitOfWork instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UnitOfWork: The UnitOfWork itself, for use in ``with`` statements.
|
||||||
|
"""
|
||||||
|
# Return self
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
# Define function __exit__
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
# Entry: exc_type
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
# Entry: exc_val
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
# Entry: exc_tb
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
"""Exit the runtime context, rolling back if an exception propagated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_type (type[BaseException] | None): Exception class, if raised.
|
||||||
|
exc_val (BaseException | None): Exception instance, if raised.
|
||||||
|
exc_tb (TracebackType | None): Traceback object, if an exception was raised.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Check: exc_type is not None
|
||||||
if exc_type is not None:
|
if exc_type is not None:
|
||||||
|
# Call self.rollback()
|
||||||
self.rollback()
|
self.rollback()
|
||||||
|
|
||||||
# -- public API ----------------------------------------------------------
|
# -- public API ----------------------------------------------------------
|
||||||
|
|
||||||
def commit(self) -> None:
|
def commit(self) -> None:
|
||||||
"""Flush pending changes and commit the transaction."""
|
"""Flush pending changes and commit the transaction."""
|
||||||
|
# Call self._session.commit()
|
||||||
self._session.commit()
|
self._session.commit()
|
||||||
|
|
||||||
|
# Define function rollback
|
||||||
def rollback(self) -> None:
|
def rollback(self) -> None:
|
||||||
"""Roll back the current transaction."""
|
"""Roll back the current transaction."""
|
||||||
|
# Call self._session.rollback()
|
||||||
self._session.rollback()
|
self._session.rollback()
|
||||||
|
|
||||||
|
# Define function flush
|
||||||
def flush(self) -> None:
|
def flush(self) -> None:
|
||||||
"""Flush pending changes without committing (useful for getting IDs)."""
|
"""Flush pending changes without committing (useful for getting IDs)."""
|
||||||
|
# Call self._session.flush()
|
||||||
self._session.flush()
|
self._session.flush()
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
|
"""Immutable domain value objects."""
|
||||||
|
# Import MitreId from app.domain.value_objects.mitre_id
|
||||||
from app.domain.value_objects.mitre_id import MitreId
|
from app.domain.value_objects.mitre_id import MitreId
|
||||||
|
|
||||||
|
# Import ScoringWeights from app.domain.value_objects.scoring_weights
|
||||||
from app.domain.value_objects.scoring_weights import ScoringWeights
|
from app.domain.value_objects.scoring_weights import ScoringWeights
|
||||||
|
|
||||||
|
# Assign __all__ = ["MitreId", "ScoringWeights"]
|
||||||
__all__ = ["MitreId", "ScoringWeights"]
|
__all__ = ["MitreId", "ScoringWeights"]
|
||||||
|
|||||||
@@ -5,47 +5,111 @@ format: ``T`` followed by 4 digits, optionally a dot and 3 more digits
|
|||||||
for sub-techniques (e.g. ``T1059``, ``T1059.001``).
|
for sub-techniques (e.g. ``T1059``, ``T1059.001``).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import re
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
# Import dataclass from dataclasses
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
# Assign _MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
|
||||||
_MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
|
_MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @dataclass decorator
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
|
# Define class MitreId
|
||||||
class MitreId:
|
class MitreId:
|
||||||
"""Validated MITRE ATT&CK technique identifier."""
|
"""Validated MITRE ATT&CK technique identifier."""
|
||||||
|
|
||||||
|
# value: str
|
||||||
value: str
|
value: str
|
||||||
|
|
||||||
|
# Define function __post_init__
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate that *value* matches the expected MITRE ATT&CK ID format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Check: not _MITRE_ID_RE.match(self.value)
|
||||||
if not _MITRE_ID_RE.match(self.value):
|
if not _MITRE_ID_RE.match(self.value):
|
||||||
|
# Raise ValueError
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid MITRE ATT&CK ID '{self.value}'. "
|
f"Invalid MITRE ATT&CK ID '{self.value}'. "
|
||||||
|
# Literal argument value
|
||||||
"Expected format: T1234 or T1234.001"
|
"Expected format: T1234 or T1234.001"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function is_subtechnique
|
||||||
def is_subtechnique(self) -> bool:
|
def is_subtechnique(self) -> bool:
|
||||||
|
"""Return True if this identifier represents a sub-technique.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True when the ID contains a dot (e.g. ``T1059.001``).
|
||||||
|
"""
|
||||||
|
# Return "." in self.value
|
||||||
return "." in self.value
|
return "." in self.value
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function parent_id
|
||||||
def parent_id(self) -> str | None:
|
def parent_id(self) -> str | None:
|
||||||
"""Return the parent technique ID (e.g. T1059 for T1059.001)."""
|
"""Return the parent technique ID (e.g. ``T1059`` for ``T1059.001``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str | None: The parent ID string, or None if this is not a sub-technique.
|
||||||
|
"""
|
||||||
|
# Check: not self.is_subtechnique
|
||||||
if not self.is_subtechnique:
|
if not self.is_subtechnique:
|
||||||
|
# Return None
|
||||||
return None
|
return None
|
||||||
|
# Return self.value.split(".")[0]
|
||||||
return self.value.split(".")[0]
|
return self.value.split(".")[0]
|
||||||
|
|
||||||
|
# Define function __str__
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
"""Return the string representation of the MITRE ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The raw identifier string (e.g. ``"T1059.001"``).
|
||||||
|
"""
|
||||||
|
# Return self.value
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
# Define function __eq__
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
|
"""Compare this MitreId to another MitreId or a plain string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (object): The value to compare against; may be a
|
||||||
|
:class:`MitreId` instance or a plain ``str``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the identifiers are equal, NotImplemented for
|
||||||
|
unsupported types.
|
||||||
|
"""
|
||||||
|
# Check: isinstance(other, MitreId)
|
||||||
if isinstance(other, MitreId):
|
if isinstance(other, MitreId):
|
||||||
|
# Return self.value == other.value
|
||||||
return self.value == other.value
|
return self.value == other.value
|
||||||
|
# Check: isinstance(other, str)
|
||||||
if isinstance(other, str):
|
if isinstance(other, str):
|
||||||
|
# Return self.value == other
|
||||||
return self.value == other
|
return self.value == other
|
||||||
|
# Return NotImplemented
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
|
# Define function __hash__
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
|
"""Return the hash of the identifier string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Hash value derived from the raw identifier string.
|
||||||
|
"""
|
||||||
|
# Return hash(self.value)
|
||||||
return hash(self.value)
|
return hash(self.value)
|
||||||
|
|||||||
@@ -3,22 +3,38 @@
|
|||||||
Enforces that all five weights are non-negative and sum to exactly 100.
|
Enforces that all five weights are non-negative and sum to exactly 100.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import dataclass from dataclasses
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @dataclass decorator
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
|
# Define class ScoringWeights
|
||||||
class ScoringWeights:
|
class ScoringWeights:
|
||||||
"""Five scoring dimension weights that must sum to 100."""
|
"""Five scoring dimension weights that must sum to 100."""
|
||||||
|
|
||||||
|
# tests: float
|
||||||
tests: float
|
tests: float
|
||||||
|
# detection_rules: float
|
||||||
detection_rules: float
|
detection_rules: float
|
||||||
|
# d3fend: float
|
||||||
d3fend: float
|
d3fend: float
|
||||||
|
# recency: float
|
||||||
recency: float
|
recency: float
|
||||||
|
# severity: float
|
||||||
severity: float
|
severity: float
|
||||||
|
|
||||||
|
# Define function __post_init__
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate that all weights are non-negative and sum to exactly 100.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Assign fields = [
|
||||||
fields = [
|
fields = [
|
||||||
self.tests,
|
self.tests,
|
||||||
self.detection_rules,
|
self.detection_rules,
|
||||||
@@ -26,32 +42,66 @@ class ScoringWeights:
|
|||||||
self.recency,
|
self.recency,
|
||||||
self.severity,
|
self.severity,
|
||||||
]
|
]
|
||||||
|
# Iterate over fields
|
||||||
for f in fields:
|
for f in fields:
|
||||||
|
# Check: f < 0
|
||||||
if f < 0:
|
if f < 0:
|
||||||
|
# Raise ValueError
|
||||||
raise ValueError("Scoring weights must be non-negative")
|
raise ValueError("Scoring weights must be non-negative")
|
||||||
|
|
||||||
|
# Assign total = sum(fields)
|
||||||
total = sum(fields)
|
total = sum(fields)
|
||||||
|
# Check: abs(total - 100) > 0.01
|
||||||
if abs(total - 100) > 0.01:
|
if abs(total - 100) > 0.01:
|
||||||
|
# Raise ValueError
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Scoring weights must sum to 100, got {total}"
|
f"Scoring weights must sum to 100, got {total}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply the @classmethod decorator
|
||||||
@classmethod
|
@classmethod
|
||||||
|
# Define function default
|
||||||
def default(cls) -> ScoringWeights:
|
def default(cls) -> ScoringWeights:
|
||||||
"""Return the default weight distribution."""
|
"""Return the default weight distribution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ScoringWeights: A weight set with tests=40, detection_rules=25,
|
||||||
|
d3fend=15, recency=10, severity=10.
|
||||||
|
"""
|
||||||
|
# Return cls(
|
||||||
return cls(
|
return cls(
|
||||||
|
# Keyword argument: tests
|
||||||
tests=40.0,
|
tests=40.0,
|
||||||
|
# Keyword argument: detection_rules
|
||||||
detection_rules=25.0,
|
detection_rules=25.0,
|
||||||
|
# Keyword argument: d3fend
|
||||||
d3fend=15.0,
|
d3fend=15.0,
|
||||||
|
# Keyword argument: recency
|
||||||
recency=10.0,
|
recency=10.0,
|
||||||
|
# Keyword argument: severity
|
||||||
severity=10.0,
|
severity=10.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Backward-compatible aliases for older API payloads
|
# Backward-compatible aliases for older API payloads
|
||||||
@property
|
@property
|
||||||
|
# Define function freshness
|
||||||
def freshness(self) -> float:
|
def freshness(self) -> float:
|
||||||
|
"""Return the recency weight (backward-compatible alias).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The value of the ``recency`` weight.
|
||||||
|
"""
|
||||||
|
# Return self.recency
|
||||||
return self.recency
|
return self.recency
|
||||||
|
|
||||||
|
# Apply the @property decorator
|
||||||
@property
|
@property
|
||||||
|
# Define function platform_diversity
|
||||||
def platform_diversity(self) -> float:
|
def platform_diversity(self) -> float:
|
||||||
|
"""Return the severity weight (backward-compatible alias).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: The value of the ``severity`` weight.
|
||||||
|
"""
|
||||||
|
# Return self.severity
|
||||||
return self.severity
|
return self.severity
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Infrastructure adapters — persistence, caching, and external services."""
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""SQLAlchemy-based persistence adapters for the domain repository ports."""
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""ORM-to-domain entity mapper functions."""
|
||||||
|
|||||||
@@ -1,20 +1,28 @@
|
|||||||
"""Technique ORM model <-> domain entity mapper."""
|
"""Technique ORM model <-> domain entity mapper."""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import TechniqueEntity from app.domain.entities.technique
|
||||||
from app.domain.entities.technique import TechniqueEntity
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
from app.domain.enums import TechniqueStatus
|
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TechniqueMapper
|
||||||
class TechniqueMapper:
|
class TechniqueMapper:
|
||||||
"""Converts between SQLAlchemy Technique model and TechniqueEntity."""
|
"""Converts between SQLAlchemy Technique model and TechniqueEntity."""
|
||||||
|
|
||||||
|
# Apply the @staticmethod decorator
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# Define function to_entity
|
||||||
def to_entity(model: object) -> TechniqueEntity:
|
def to_entity(model: object) -> TechniqueEntity:
|
||||||
"""Convert an ORM Technique model to a domain TechniqueEntity."""
|
"""Convert an ORM Technique model to a domain TechniqueEntity."""
|
||||||
|
# Return TechniqueEntity.from_orm(model)
|
||||||
return TechniqueEntity.from_orm(model)
|
return TechniqueEntity.from_orm(model)
|
||||||
|
|
||||||
|
# Apply the @staticmethod decorator
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# Define function to_model_updates
|
||||||
def to_model_updates(entity: TechniqueEntity, model: object) -> None:
|
def to_model_updates(entity: TechniqueEntity, model: object) -> None:
|
||||||
"""Apply entity changes back onto an existing ORM model."""
|
"""Apply entity changes back onto an existing ORM model."""
|
||||||
|
# Call entity.apply_to()
|
||||||
entity.apply_to(model)
|
entity.apply_to(model)
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
|
"""Concrete SQLAlchemy repository implementations."""
|
||||||
|
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
|
||||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||||
SATechniqueRepository,
|
SATechniqueRepository,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import from app.infrastructure.persistence.repositories.sa_test_repository
|
||||||
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||||
SATestRepository,
|
SATestRepository,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign __all__ = ["SATechniqueRepository", "SATestRepository"]
|
||||||
__all__ = ["SATechniqueRepository", "SATestRepository"]
|
__all__ = ["SATechniqueRepository", "SATestRepository"]
|
||||||
|
|||||||
@@ -4,44 +4,95 @@ Receives a Session from the caller — does NOT create its own.
|
|||||||
Does NOT call commit() — the Unit of Work owns that.
|
Does NOT call commit() — the Unit of Work owns that.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import func from sqlalchemy
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import TechniqueEntity from app.domain.entities.technique
|
||||||
from app.domain.entities.technique import TechniqueEntity
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
|
||||||
|
# Import TechniqueStatus, TestState from app.domain.enums
|
||||||
from app.domain.enums import TechniqueStatus, TestState
|
from app.domain.enums import TechniqueStatus, TestState
|
||||||
|
|
||||||
|
# Import TechniqueWithCounts from app.domain.ports.repositories.technique_repository
|
||||||
from app.domain.ports.repositories.technique_repository import TechniqueWithCounts
|
from app.domain.ports.repositories.technique_repository import TechniqueWithCounts
|
||||||
|
|
||||||
|
# Import TechniqueMapper from app.infrastructure.persistence.mappers.technique_mapper
|
||||||
from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper
|
from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper
|
||||||
|
|
||||||
|
# Import DetectionRule from app.models.detection_rule
|
||||||
from app.models.detection_rule import DetectionRule
|
from app.models.detection_rule import DetectionRule
|
||||||
|
|
||||||
|
# Import Technique from app.models.technique
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
|
|
||||||
|
# Import Test from app.models.test
|
||||||
from app.models.test import Test
|
from app.models.test import Test
|
||||||
|
|
||||||
|
|
||||||
|
# Define class SATechniqueRepository
|
||||||
class SATechniqueRepository:
|
class SATechniqueRepository:
|
||||||
"""Concrete repository backed by SQLAlchemy."""
|
"""Concrete repository backed by SQLAlchemy."""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, session: Session) -> None:
|
def __init__(self, session: Session) -> None:
|
||||||
|
"""Initialise the repository with a caller-provided session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session (Session): The SQLAlchemy session to use for all queries.
|
||||||
|
"""
|
||||||
|
# Assign self._session = session
|
||||||
self._session = session
|
self._session = session
|
||||||
|
|
||||||
# -- Single-entity access ----------------------------------------------
|
# -- Single-entity access ----------------------------------------------
|
||||||
|
|
||||||
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
|
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
|
||||||
|
"""Return a single technique by its primary key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique_id (uuid.UUID): The UUID primary key of the technique.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TechniqueEntity | None: The matching entity, or ``None`` if not found.
|
||||||
|
"""
|
||||||
|
# Assign model = (
|
||||||
model = (
|
model = (
|
||||||
self._session.query(Technique)
|
self._session.query(Technique)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Technique.id == technique_id)
|
.filter(Technique.id == technique_id)
|
||||||
|
# Chain .first() call
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
# Return TechniqueMapper.to_entity(model) if model else None
|
||||||
return TechniqueMapper.to_entity(model) if model else None
|
return TechniqueMapper.to_entity(model) if model else None
|
||||||
|
|
||||||
|
# Define function find_by_mitre_id
|
||||||
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
|
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
|
||||||
|
"""Return a single technique by its MITRE ATT&CK ID (e.g. ``T1059.001``).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mitre_id (str): The MITRE ATT&CK identifier string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TechniqueEntity | None: The matching entity, or ``None`` if not found.
|
||||||
|
"""
|
||||||
|
# Assign model = (
|
||||||
model = (
|
model = (
|
||||||
self._session.query(Technique)
|
self._session.query(Technique)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Technique.mitre_id == mitre_id)
|
.filter(Technique.mitre_id == mitre_id)
|
||||||
|
# Chain .first() call
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
# Return TechniqueMapper.to_entity(model) if model else None
|
||||||
return TechniqueMapper.to_entity(model) if model else None
|
return TechniqueMapper.to_entity(model) if model else None
|
||||||
|
|
||||||
# -- List access -------------------------------------------------------
|
# -- List access -------------------------------------------------------
|
||||||
@@ -49,57 +100,111 @@ class SATechniqueRepository:
|
|||||||
def list_all(
|
def list_all(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
# Entry: tactic
|
||||||
tactic: str | None = None,
|
tactic: str | None = None,
|
||||||
|
# Entry: status
|
||||||
status: TechniqueStatus | None = None,
|
status: TechniqueStatus | None = None,
|
||||||
|
# Entry: review_required
|
||||||
review_required: bool | None = None,
|
review_required: bool | None = None,
|
||||||
) -> list[TechniqueEntity]:
|
) -> list[TechniqueEntity]:
|
||||||
|
"""Return all techniques, optionally filtered by tactic, status, or review flag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tactic (str | None): Filter to techniques belonging to this tactic name.
|
||||||
|
status (TechniqueStatus | None): Filter to techniques with this coverage status.
|
||||||
|
review_required (bool | None): Filter to techniques where ``review_required`` matches.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[TechniqueEntity]: Ordered list of matching technique entities.
|
||||||
|
"""
|
||||||
|
# Assign query = self._session.query(Technique)
|
||||||
query = self._session.query(Technique)
|
query = self._session.query(Technique)
|
||||||
|
# Check: tactic is not None
|
||||||
if tactic is not None:
|
if tactic is not None:
|
||||||
|
# Assign query = query.filter(Technique.tactic == tactic)
|
||||||
query = query.filter(Technique.tactic == tactic)
|
query = query.filter(Technique.tactic == tactic)
|
||||||
|
# Check: status is not None
|
||||||
if status is not None:
|
if status is not None:
|
||||||
|
# Assign query = query.filter(Technique.status_global == status)
|
||||||
query = query.filter(Technique.status_global == status)
|
query = query.filter(Technique.status_global == status)
|
||||||
|
# Check: review_required is not None
|
||||||
if review_required is not None:
|
if review_required is not None:
|
||||||
|
# Assign query = query.filter(Technique.review_required == review_required)
|
||||||
query = query.filter(Technique.review_required == review_required)
|
query = query.filter(Technique.review_required == review_required)
|
||||||
|
# Assign models = query.order_by(Technique.mitre_id).all()
|
||||||
models = query.order_by(Technique.mitre_id).all()
|
models = query.order_by(Technique.mitre_id).all()
|
||||||
|
# Return [TechniqueMapper.to_entity(m) for m in models]
|
||||||
return [TechniqueMapper.to_entity(m) for m in models]
|
return [TechniqueMapper.to_entity(m) for m in models]
|
||||||
|
|
||||||
|
# Define function list_by_ids
|
||||||
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
|
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
|
||||||
|
"""Return techniques matching the provided list of UUIDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids (list[uuid.UUID]): UUIDs of the techniques to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[TechniqueEntity]: Technique entities corresponding to the given IDs.
|
||||||
|
"""
|
||||||
|
# Check: not ids
|
||||||
if not ids:
|
if not ids:
|
||||||
|
# Return []
|
||||||
return []
|
return []
|
||||||
|
# Assign models = (
|
||||||
models = (
|
models = (
|
||||||
self._session.query(Technique)
|
self._session.query(Technique)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Technique.id.in_(ids))
|
.filter(Technique.id.in_(ids))
|
||||||
|
# Chain .all() call
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
# Return [TechniqueMapper.to_entity(m) for m in models]
|
||||||
return [TechniqueMapper.to_entity(m) for m in models]
|
return [TechniqueMapper.to_entity(m) for m in models]
|
||||||
|
|
||||||
# -- Batch queries (for scoring/heatmap) -------------------------------
|
# -- Batch queries (for scoring/heatmap) -------------------------------
|
||||||
|
|
||||||
def count_by_status(self) -> dict[TechniqueStatus, int]:
|
def count_by_status(self) -> dict[TechniqueStatus, int]:
|
||||||
|
"""Return a count of techniques grouped by their coverage status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[TechniqueStatus, int]: Mapping of each status value to its technique count.
|
||||||
|
"""
|
||||||
|
# Assign rows = (
|
||||||
rows = (
|
rows = (
|
||||||
self._session.query(
|
self._session.query(
|
||||||
Technique.status_global,
|
Technique.status_global,
|
||||||
func.count(Technique.id),
|
func.count(Technique.id),
|
||||||
)
|
)
|
||||||
|
# Chain .group_by() call
|
||||||
.group_by(Technique.status_global)
|
.group_by(Technique.status_global)
|
||||||
|
# Chain .all() call
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
# Assign result = {s: 0 for s in TechniqueStatus}
|
||||||
result = {s: 0 for s in TechniqueStatus}
|
result = {s: 0 for s in TechniqueStatus}
|
||||||
|
# Iterate over rows
|
||||||
for status_val, count in rows:
|
for status_val, count in rows:
|
||||||
|
# Assign key = (
|
||||||
key = (
|
key = (
|
||||||
status_val
|
status_val
|
||||||
if isinstance(status_val, TechniqueStatus)
|
if isinstance(status_val, TechniqueStatus)
|
||||||
else TechniqueStatus(status_val)
|
else TechniqueStatus(status_val)
|
||||||
)
|
)
|
||||||
|
# Assign result[key] = count
|
||||||
result[key] = count
|
result[key] = count
|
||||||
|
# Return result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# Define function find_all_with_test_counts
|
||||||
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
|
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
|
||||||
"""Single query replacing the N+1 pattern.
|
"""Return all techniques with pre-aggregated test and detection rule counts.
|
||||||
|
|
||||||
Returns all techniques with pre-aggregated test and detection
|
Uses a single query with subqueries to avoid the N+1 pattern.
|
||||||
rule counts via subqueries.
|
|
||||||
|
Returns:
|
||||||
|
list[TechniqueWithCounts]: All techniques with their associated counts.
|
||||||
"""
|
"""
|
||||||
|
# Assign test_count_sq = (
|
||||||
test_count_sq = (
|
test_count_sq = (
|
||||||
self._session.query(
|
self._session.query(
|
||||||
Test.technique_id,
|
Test.technique_id,
|
||||||
@@ -108,18 +213,24 @@ class SATechniqueRepository:
|
|||||||
func.cast(Test.state == TestState.validated, self._int_type())
|
func.cast(Test.state == TestState.validated, self._int_type())
|
||||||
).label("validated_count"),
|
).label("validated_count"),
|
||||||
)
|
)
|
||||||
|
# Chain .group_by() call
|
||||||
.group_by(Test.technique_id)
|
.group_by(Test.technique_id)
|
||||||
|
# Chain .subquery() call
|
||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
|
# Assign rule_count_sq = (
|
||||||
rule_count_sq = (
|
rule_count_sq = (
|
||||||
self._session.query(
|
self._session.query(
|
||||||
DetectionRule.mitre_technique_id,
|
DetectionRule.mitre_technique_id,
|
||||||
func.count(DetectionRule.id).label("rule_count"),
|
func.count(DetectionRule.id).label("rule_count"),
|
||||||
)
|
)
|
||||||
|
# Chain .group_by() call
|
||||||
.group_by(DetectionRule.mitre_technique_id)
|
.group_by(DetectionRule.mitre_technique_id)
|
||||||
|
# Chain .subquery() call
|
||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign rows = (
|
||||||
rows = (
|
rows = (
|
||||||
self._session.query(
|
self._session.query(
|
||||||
Technique,
|
Technique,
|
||||||
@@ -127,20 +238,29 @@ class SATechniqueRepository:
|
|||||||
func.coalesce(test_count_sq.c.validated_count, 0),
|
func.coalesce(test_count_sq.c.validated_count, 0),
|
||||||
func.coalesce(rule_count_sq.c.rule_count, 0),
|
func.coalesce(rule_count_sq.c.rule_count, 0),
|
||||||
)
|
)
|
||||||
|
# Chain .outerjoin() call
|
||||||
.outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id)
|
.outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id)
|
||||||
|
# Chain .outerjoin() call
|
||||||
.outerjoin(
|
.outerjoin(
|
||||||
rule_count_sq,
|
rule_count_sq,
|
||||||
Technique.mitre_id == rule_count_sq.c.mitre_technique_id,
|
Technique.mitre_id == rule_count_sq.c.mitre_technique_id,
|
||||||
)
|
)
|
||||||
|
# Chain .order_by() call
|
||||||
.order_by(Technique.mitre_id)
|
.order_by(Technique.mitre_id)
|
||||||
|
# Chain .all() call
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Return [
|
||||||
return [
|
return [
|
||||||
TechniqueWithCounts(
|
TechniqueWithCounts(
|
||||||
|
# Keyword argument: entity
|
||||||
entity=TechniqueMapper.to_entity(tech),
|
entity=TechniqueMapper.to_entity(tech),
|
||||||
|
# Keyword argument: test_count
|
||||||
test_count=int(tc),
|
test_count=int(tc),
|
||||||
|
# Keyword argument: validated_test_count
|
||||||
validated_test_count=int(vtc),
|
validated_test_count=int(vtc),
|
||||||
|
# Keyword argument: detection_rule_count
|
||||||
detection_rule_count=int(rc),
|
detection_rule_count=int(rc),
|
||||||
)
|
)
|
||||||
for tech, tc, vtc, rc in rows
|
for tech, tc, vtc, rc in rows
|
||||||
@@ -149,55 +269,112 @@ class SATechniqueRepository:
|
|||||||
# -- Mutations ---------------------------------------------------------
|
# -- Mutations ---------------------------------------------------------
|
||||||
|
|
||||||
def save(self, technique: TechniqueEntity) -> TechniqueEntity:
|
def save(self, technique: TechniqueEntity) -> TechniqueEntity:
|
||||||
|
"""Persist a technique entity, inserting or updating as needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique (TechniqueEntity): The domain entity to persist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TechniqueEntity: The persisted entity reflecting the current DB state.
|
||||||
|
"""
|
||||||
|
# Assign existing = (
|
||||||
existing = (
|
existing = (
|
||||||
self._session.query(Technique)
|
self._session.query(Technique)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Technique.id == technique.id)
|
.filter(Technique.id == technique.id)
|
||||||
|
# Chain .first() call
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
# Check: existing
|
||||||
if existing:
|
if existing:
|
||||||
|
# Call technique.apply_to()
|
||||||
technique.apply_to(existing)
|
technique.apply_to(existing)
|
||||||
|
# Assign existing.mitre_id = technique.mitre_id
|
||||||
existing.mitre_id = technique.mitre_id
|
existing.mitre_id = technique.mitre_id
|
||||||
|
# Assign existing.name = technique.name
|
||||||
existing.name = technique.name
|
existing.name = technique.name
|
||||||
|
# Assign existing.tactic = technique.tactic
|
||||||
existing.tactic = technique.tactic
|
existing.tactic = technique.tactic
|
||||||
|
# Assign existing.description = technique.description
|
||||||
existing.description = technique.description
|
existing.description = technique.description
|
||||||
|
# Assign existing.platforms = technique.platforms
|
||||||
existing.platforms = technique.platforms
|
existing.platforms = technique.platforms
|
||||||
|
# Assign existing.is_subtechnique = technique.is_subtechnique
|
||||||
existing.is_subtechnique = technique.is_subtechnique
|
existing.is_subtechnique = technique.is_subtechnique
|
||||||
|
# Assign existing.parent_mitre_id = technique.parent_mitre_id
|
||||||
existing.parent_mitre_id = technique.parent_mitre_id
|
existing.parent_mitre_id = technique.parent_mitre_id
|
||||||
|
# Assign existing.mitre_version = technique.mitre_version
|
||||||
existing.mitre_version = technique.mitre_version
|
existing.mitre_version = technique.mitre_version
|
||||||
|
# Assign existing.mitre_last_modified = technique.mitre_last_modified
|
||||||
existing.mitre_last_modified = technique.mitre_last_modified
|
existing.mitre_last_modified = technique.mitre_last_modified
|
||||||
|
# Call self._session.flush()
|
||||||
self._session.flush()
|
self._session.flush()
|
||||||
|
# Return TechniqueMapper.to_entity(existing)
|
||||||
return TechniqueMapper.to_entity(existing)
|
return TechniqueMapper.to_entity(existing)
|
||||||
|
# Fallback: handle remaining cases
|
||||||
else:
|
else:
|
||||||
|
# Assign model = Technique(
|
||||||
model = Technique(
|
model = Technique(
|
||||||
|
# Keyword argument: id
|
||||||
id=technique.id,
|
id=technique.id,
|
||||||
|
# Keyword argument: mitre_id
|
||||||
mitre_id=technique.mitre_id,
|
mitre_id=technique.mitre_id,
|
||||||
|
# Keyword argument: name
|
||||||
name=technique.name,
|
name=technique.name,
|
||||||
|
# Keyword argument: tactic
|
||||||
tactic=technique.tactic,
|
tactic=technique.tactic,
|
||||||
|
# Keyword argument: description
|
||||||
description=technique.description,
|
description=technique.description,
|
||||||
|
# Keyword argument: platforms
|
||||||
platforms=technique.platforms,
|
platforms=technique.platforms,
|
||||||
|
# Keyword argument: is_subtechnique
|
||||||
is_subtechnique=technique.is_subtechnique,
|
is_subtechnique=technique.is_subtechnique,
|
||||||
|
# Keyword argument: parent_mitre_id
|
||||||
parent_mitre_id=technique.parent_mitre_id,
|
parent_mitre_id=technique.parent_mitre_id,
|
||||||
|
# Keyword argument: status_global
|
||||||
status_global=technique.status_global,
|
status_global=technique.status_global,
|
||||||
|
# Keyword argument: review_required
|
||||||
review_required=technique.review_required,
|
review_required=technique.review_required,
|
||||||
|
# Keyword argument: last_review_date
|
||||||
last_review_date=technique.last_review_date,
|
last_review_date=technique.last_review_date,
|
||||||
|
# Keyword argument: mitre_version
|
||||||
mitre_version=technique.mitre_version,
|
mitre_version=technique.mitre_version,
|
||||||
|
# Keyword argument: mitre_last_modified
|
||||||
mitre_last_modified=technique.mitre_last_modified,
|
mitre_last_modified=technique.mitre_last_modified,
|
||||||
)
|
)
|
||||||
|
# Call self._session.add()
|
||||||
self._session.add(model)
|
self._session.add(model)
|
||||||
|
# Call self._session.flush()
|
||||||
self._session.flush()
|
self._session.flush()
|
||||||
|
# Return TechniqueMapper.to_entity(model)
|
||||||
return TechniqueMapper.to_entity(model)
|
return TechniqueMapper.to_entity(model)
|
||||||
|
|
||||||
|
# Define function exists_by_mitre_id
|
||||||
def exists_by_mitre_id(self, mitre_id: str) -> bool:
|
def exists_by_mitre_id(self, mitre_id: str) -> bool:
|
||||||
|
"""Check whether a technique with the given MITRE ID already exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mitre_id (str): The MITRE ATT&CK identifier to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: ``True`` if the technique exists, ``False`` otherwise.
|
||||||
|
"""
|
||||||
|
# Return (
|
||||||
return (
|
return (
|
||||||
self._session.query(Technique.id)
|
self._session.query(Technique.id)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Technique.mitre_id == mitre_id)
|
.filter(Technique.mitre_id == mitre_id)
|
||||||
|
# Chain .first() call
|
||||||
.first()
|
.first()
|
||||||
) is not None
|
) is not None
|
||||||
|
|
||||||
# -- Internal ----------------------------------------------------------
|
# -- Internal ----------------------------------------------------------
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _int_type():
|
# Define function _int_type
|
||||||
|
def _int_type() -> type:
|
||||||
"""Return an Integer type for CAST expressions (SQLite-compatible)."""
|
"""Return an Integer type for CAST expressions (SQLite-compatible)."""
|
||||||
|
# Import Integer from sqlalchemy
|
||||||
from sqlalchemy import Integer
|
from sqlalchemy import Integer
|
||||||
|
# Return Integer
|
||||||
return Integer
|
return Integer
|
||||||
|
|||||||
@@ -1,78 +1,163 @@
|
|||||||
"""SQLAlchemy implementation of TestRepository."""
|
"""SQLAlchemy implementation of TestRepository."""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import func from sqlalchemy
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import TestState from app.domain.enums
|
||||||
from app.domain.enums import TestState
|
from app.domain.enums import TestState
|
||||||
|
|
||||||
|
# Import Test from app.models.test
|
||||||
from app.models.test import Test
|
from app.models.test import Test
|
||||||
|
|
||||||
|
|
||||||
|
# Define class SATestRepository
|
||||||
class SATestRepository:
|
class SATestRepository:
|
||||||
"""Concrete test repository backed by SQLAlchemy."""
|
"""Concrete test repository backed by SQLAlchemy."""
|
||||||
|
|
||||||
|
# Define function __init__
|
||||||
def __init__(self, session: Session) -> None:
|
def __init__(self, session: Session) -> None:
|
||||||
|
"""Initialise the repository with a caller-provided session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session (Session): The SQLAlchemy session to use for all queries.
|
||||||
|
"""
|
||||||
|
# Assign self._session = session
|
||||||
self._session = session
|
self._session = session
|
||||||
|
|
||||||
|
# Define function find_by_id
|
||||||
def find_by_id(self, test_id: uuid.UUID) -> Test | None:
|
def find_by_id(self, test_id: uuid.UUID) -> Test | None:
|
||||||
|
"""Return a single test by its primary key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_id (uuid.UUID): The UUID primary key of the test.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Test | None: The ORM model instance, or ``None`` if not found.
|
||||||
|
"""
|
||||||
|
# Return (
|
||||||
return (
|
return (
|
||||||
self._session.query(Test)
|
self._session.query(Test)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Test.id == test_id)
|
.filter(Test.id == test_id)
|
||||||
|
# Chain .first() call
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Define function list_by_technique
|
||||||
def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]:
|
def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]:
|
||||||
|
"""Return all tests for a given technique, ordered by creation date.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique_id (uuid.UUID): The UUID of the parent technique.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Test]: ORM model instances ordered by ``created_at`` ascending.
|
||||||
|
"""
|
||||||
|
# Return (
|
||||||
return (
|
return (
|
||||||
self._session.query(Test)
|
self._session.query(Test)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Test.technique_id == technique_id)
|
.filter(Test.technique_id == technique_id)
|
||||||
|
# Chain .order_by() call
|
||||||
.order_by(Test.created_at)
|
.order_by(Test.created_at)
|
||||||
|
# Chain .all() call
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Define function list_by_state
|
||||||
def list_by_state(self, state: TestState) -> list[Test]:
|
def list_by_state(self, state: TestState) -> list[Test]:
|
||||||
|
"""Return all tests that are currently in the given workflow state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (TestState): The workflow state to filter on.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Test]: All ORM model instances with the specified state.
|
||||||
|
"""
|
||||||
|
# Return (
|
||||||
return (
|
return (
|
||||||
self._session.query(Test)
|
self._session.query(Test)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Test.state == state)
|
.filter(Test.state == state)
|
||||||
|
# Chain .all() call
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Define function count_by_technique_and_state
|
||||||
def count_by_technique_and_state(
|
def count_by_technique_and_state(
|
||||||
self,
|
self,
|
||||||
|
# Entry: technique_id
|
||||||
technique_id: uuid.UUID,
|
technique_id: uuid.UUID,
|
||||||
) -> dict[TestState, int]:
|
) -> dict[TestState, int]:
|
||||||
|
"""Return per-state test counts for a specific technique.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique_id (uuid.UUID): The UUID of the technique to aggregate for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[TestState, int]: Mapping of each state to the number of tests in that state.
|
||||||
|
"""
|
||||||
|
# Assign rows = (
|
||||||
rows = (
|
rows = (
|
||||||
self._session.query(Test.state, func.count(Test.id))
|
self._session.query(Test.state, func.count(Test.id))
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Test.technique_id == technique_id)
|
.filter(Test.technique_id == technique_id)
|
||||||
|
# Chain .group_by() call
|
||||||
.group_by(Test.state)
|
.group_by(Test.state)
|
||||||
|
# Chain .all() call
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
# Assign result = {}
|
||||||
result: dict[TestState, int] = {}
|
result: dict[TestState, int] = {}
|
||||||
|
# Iterate over rows
|
||||||
for state_val, count in rows:
|
for state_val, count in rows:
|
||||||
|
# Assign key = (
|
||||||
key = (
|
key = (
|
||||||
state_val
|
state_val
|
||||||
if isinstance(state_val, TestState)
|
if isinstance(state_val, TestState)
|
||||||
else TestState(state_val)
|
else TestState(state_val)
|
||||||
)
|
)
|
||||||
|
# Assign result[key] = count
|
||||||
result[key] = count
|
result[key] = count
|
||||||
|
# Return result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# Define function get_states_and_results_for_technique
|
||||||
def get_states_and_results_for_technique(
|
def get_states_and_results_for_technique(
|
||||||
self,
|
self,
|
||||||
|
# Entry: technique_id
|
||||||
technique_id: uuid.UUID,
|
technique_id: uuid.UUID,
|
||||||
) -> list[tuple[str, str | None]]:
|
) -> list[tuple[str, str | None]]:
|
||||||
"""Return lightweight (state, detection_result) pairs.
|
"""Return lightweight ``(state, detection_result)`` pairs for a technique.
|
||||||
|
|
||||||
Used by TechniqueEntity.recalculate_status() without loading
|
Used by ``TechniqueEntity.recalculate_status()`` to avoid loading full
|
||||||
full Test models.
|
``Test`` models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique_id (uuid.UUID): The UUID of the technique to query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[tuple[str, str | None]]: Each tuple contains the state string
|
||||||
|
and the detection result string (or ``None``).
|
||||||
"""
|
"""
|
||||||
|
# Assign rows = (
|
||||||
rows = (
|
rows = (
|
||||||
self._session.query(Test.state, Test.detection_result)
|
self._session.query(Test.state, Test.detection_result)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(Test.technique_id == technique_id)
|
.filter(Test.technique_id == technique_id)
|
||||||
|
# Chain .all() call
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
# Return [
|
||||||
return [
|
return [
|
||||||
(
|
(
|
||||||
r.state.value if hasattr(r.state, "value") else str(r.state),
|
r.state.value if hasattr(r.state, "value") else str(r.state),
|
||||||
|
|||||||
@@ -13,54 +13,79 @@ Usage::
|
|||||||
get_redis_blacklist().setex("blacklist:…", ttl, "1")
|
get_redis_blacklist().setex("blacklist:…", ttl, "1")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import urlparse, urlunparse from urllib.parse
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
|
# Import redis
|
||||||
import redis
|
import redis
|
||||||
|
|
||||||
|
# Import settings from app.config
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Assign _clients = {}
|
||||||
_clients: dict[str, redis.Redis] = {}
|
_clients: dict[str, redis.Redis] = {}
|
||||||
|
|
||||||
|
|
||||||
|
# Define function _redis_url_with_db
|
||||||
def _redis_url_with_db(base_url: str, db_index: int) -> str:
|
def _redis_url_with_db(base_url: str, db_index: int) -> str:
|
||||||
"""Return *base_url* with its path replaced by ``/{db_index}``."""
|
"""Return *base_url* with its path replaced by ``/{db_index}``."""
|
||||||
|
# Assign parsed = urlparse(base_url)
|
||||||
parsed = urlparse(base_url)
|
parsed = urlparse(base_url)
|
||||||
|
# Assign path = f"/{db_index}"
|
||||||
path = f"/{db_index}"
|
path = f"/{db_index}"
|
||||||
|
# Return urlunparse(
|
||||||
return urlunparse(
|
return urlunparse(
|
||||||
(parsed.scheme, parsed.netloc, path, "", "", ""),
|
(parsed.scheme, parsed.netloc, path, "", "", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Define function _get_client
|
||||||
def _get_client(url: str) -> redis.Redis:
|
def _get_client(url: str) -> redis.Redis:
|
||||||
|
# Check: url not in _clients
|
||||||
if url not in _clients:
|
if url not in _clients:
|
||||||
|
# Assign _clients[url] = redis.from_url(url, decode_responses=True)
|
||||||
_clients[url] = redis.from_url(url, decode_responses=True)
|
_clients[url] = redis.from_url(url, decode_responses=True)
|
||||||
|
# Log info: "Redis client connected to %s", url
|
||||||
logger.info("Redis client connected to %s", url)
|
logger.info("Redis client connected to %s", url)
|
||||||
|
# Return _clients[url]
|
||||||
return _clients[url]
|
return _clients[url]
|
||||||
|
|
||||||
|
|
||||||
|
# Define function get_redis
|
||||||
def get_redis() -> redis.Redis:
|
def get_redis() -> redis.Redis:
|
||||||
"""Default Redis connection (URL from ``settings.REDIS_URL``)."""
|
"""Default Redis connection (URL from ``settings.REDIS_URL``)."""
|
||||||
|
# Return _get_client(settings.REDIS_URL)
|
||||||
return _get_client(settings.REDIS_URL)
|
return _get_client(settings.REDIS_URL)
|
||||||
|
|
||||||
|
|
||||||
|
# Define function get_redis_blacklist
|
||||||
def get_redis_blacklist() -> redis.Redis:
|
def get_redis_blacklist() -> redis.Redis:
|
||||||
"""Redis DB used for JWT revocation (``jti`` keys with TTL)."""
|
"""Redis DB used for JWT revocation (``jti`` keys with TTL)."""
|
||||||
|
# Assign url = _redis_url_with_db(
|
||||||
url = _redis_url_with_db(
|
url = _redis_url_with_db(
|
||||||
settings.REDIS_URL,
|
settings.REDIS_URL,
|
||||||
settings.REDIS_TOKEN_BLACKLIST_DB,
|
settings.REDIS_TOKEN_BLACKLIST_DB,
|
||||||
)
|
)
|
||||||
|
# Return _get_client(url)
|
||||||
return _get_client(url)
|
return _get_client(url)
|
||||||
|
|
||||||
|
|
||||||
|
# Define function get_redis_cache
|
||||||
def get_redis_cache() -> redis.Redis:
|
def get_redis_cache() -> redis.Redis:
|
||||||
"""Redis DB reserved for shared cache (scores, queues, etc.)."""
|
"""Redis DB reserved for shared cache (scores, queues, etc.)."""
|
||||||
|
# Assign url = _redis_url_with_db(
|
||||||
url = _redis_url_with_db(
|
url = _redis_url_with_db(
|
||||||
settings.REDIS_URL,
|
settings.REDIS_URL,
|
||||||
settings.REDIS_CACHE_DB,
|
settings.REDIS_CACHE_DB,
|
||||||
)
|
)
|
||||||
|
# Return _get_client(url)
|
||||||
return _get_client(url)
|
return _get_client(url)
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Background scheduler jobs (MITRE sync, Jira sync, data retention)."""
|
||||||
|
|||||||
@@ -1,37 +1,65 @@
|
|||||||
"""Scheduled job — syncs all Jira links hourly."""
|
"""Scheduled job — syncs all Jira links hourly."""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import settings from app.config
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
|
# Import SessionLocal from app.database
|
||||||
from app.database import SessionLocal
|
from app.database import SessionLocal
|
||||||
|
|
||||||
|
# Import JiraLink from app.models.jira_link
|
||||||
from app.models.jira_link import JiraLink
|
from app.models.jira_link import JiraLink
|
||||||
|
|
||||||
|
# Import jira_service from app.services
|
||||||
from app.services import jira_service
|
from app.services import jira_service
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Define function sync_all_jira_links
|
||||||
def sync_all_jira_links() -> None:
|
def sync_all_jira_links() -> None:
|
||||||
"""Pull latest status from Jira for every stored link.
|
"""Pull latest status from Jira for every stored link.
|
||||||
|
|
||||||
Silently skips if ``JIRA_ENABLED`` is ``False``. Individual link
|
Silently skips if ``JIRA_ENABLED`` is ``False``. Individual link
|
||||||
failures are logged but do not abort the rest of the batch.
|
failures are logged but do not abort the rest of the batch.
|
||||||
"""
|
"""
|
||||||
|
# Check: not settings.JIRA_ENABLED
|
||||||
if not settings.JIRA_ENABLED:
|
if not settings.JIRA_ENABLED:
|
||||||
|
# Return control to caller
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign links = db.query(JiraLink).all()
|
||||||
links = db.query(JiraLink).all()
|
links = db.query(JiraLink).all()
|
||||||
|
# Assign synced = 0
|
||||||
synced = 0
|
synced = 0
|
||||||
|
# Iterate over links
|
||||||
for link in links:
|
for link in links:
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Call jira_service.sync_jira_to_aegis()
|
||||||
jira_service.sync_jira_to_aegis(db, link)
|
jira_service.sync_jira_to_aegis(db, link)
|
||||||
|
# Assign synced = 1
|
||||||
synced += 1
|
synced += 1
|
||||||
|
# Handle Exception
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# Log warning: "Jira sync failed for link %s: %s", link.id, e
|
||||||
logger.warning("Jira sync failed for link %s: %s", link.id, e)
|
logger.warning("Jira sync failed for link %s: %s", link.id, e)
|
||||||
|
# Commit all pending changes to the database
|
||||||
db.commit()
|
db.commit()
|
||||||
|
# Log info: "Jira sync completed: %d/%d links updated", synced
|
||||||
logger.info("Jira sync completed: %d/%d links updated", synced, len(links))
|
logger.info("Jira sync completed: %d/%d links updated", synced, len(links))
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log exception: "Jira sync batch job failed"
|
||||||
logger.exception("Jira sync batch job failed")
|
logger.exception("Jira sync batch job failed")
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
@@ -10,22 +10,44 @@ Each job manages its own database session (created on entry, closed in
|
|||||||
sessions.
|
sessions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
# Import BackgroundScheduler from apscheduler.schedulers.background
|
||||||
from apscheduler.schedulers.background import BackgroundScheduler
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
|
||||||
|
# Import SessionLocal from app.database
|
||||||
from app.database import SessionLocal
|
from app.database import SessionLocal
|
||||||
from app.services.mitre_sync_service import sync_mitre
|
|
||||||
from app.services.intel_service import scan_intel
|
# Import sync_all_jira_links from app.jobs.jira_sync_job
|
||||||
from app.services.notification_service import cleanup_old_notifications
|
|
||||||
from app.services.snapshot_service import create_snapshot, cleanup_old_snapshots
|
|
||||||
from app.services.campaign_scheduler_service import check_and_run_recurring_campaigns
|
|
||||||
from app.jobs.jira_sync_job import sync_all_jira_links
|
from app.jobs.jira_sync_job import sync_all_jira_links
|
||||||
from app.services.osint_enrichment_service import enrich_all_techniques
|
|
||||||
from app.services.stale_detection_service import detect_stale_coverage
|
# Import run_retention_job from app.jobs.retention_job
|
||||||
from app.jobs.retention_job import run_retention_job
|
from app.jobs.retention_job import run_retention_job
|
||||||
|
|
||||||
|
# Import check_and_run_recurring_campaigns from app.services.campaign_scheduler_service
|
||||||
|
from app.services.campaign_scheduler_service import check_and_run_recurring_campaigns
|
||||||
|
|
||||||
|
# Import scan_intel from app.services.intel_service
|
||||||
|
from app.services.intel_service import scan_intel
|
||||||
|
|
||||||
|
# Import sync_mitre from app.services.mitre_sync_service
|
||||||
|
from app.services.mitre_sync_service import sync_mitre
|
||||||
|
|
||||||
|
# Import cleanup_old_notifications from app.services.notification_service
|
||||||
|
from app.services.notification_service import cleanup_old_notifications
|
||||||
|
|
||||||
|
# Import enrich_all_techniques from app.services.osint_enrichment_service
|
||||||
|
from app.services.osint_enrichment_service import enrich_all_techniques
|
||||||
|
|
||||||
|
# Import cleanup_old_snapshots, create_snapshot from app.services.snapshot_service
|
||||||
|
from app.services.snapshot_service import cleanup_old_snapshots, create_snapshot
|
||||||
|
|
||||||
|
# Import detect_stale_coverage from app.services.stale_detection_service
|
||||||
|
from app.services.stale_detection_service import detect_stale_coverage
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -44,60 +66,101 @@ def _run_mitre_sync() -> None:
|
|||||||
"""Execute a MITRE sync inside its own DB session."""
|
"""Execute a MITRE sync inside its own DB session."""
|
||||||
from app.services.webhook_service import dispatch_webhook
|
from app.services.webhook_service import dispatch_webhook
|
||||||
logger.info("Scheduled MITRE sync job starting...")
|
logger.info("Scheduled MITRE sync job starting...")
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign summary = sync_mitre(db)
|
||||||
summary = sync_mitre(db)
|
summary = sync_mitre(db)
|
||||||
|
# Log info: "Scheduled MITRE sync job finished — %s", summary
|
||||||
logger.info("Scheduled MITRE sync job finished — %s", summary)
|
logger.info("Scheduled MITRE sync job finished — %s", summary)
|
||||||
dispatch_webhook("mitre.synced", {"created": summary.get("created", 0), "updated": summary.get("updated", 0)})
|
dispatch_webhook("mitre.synced", {"created": summary.get("created", 0), "updated": summary.get("updated", 0)})
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log exception: "Scheduled MITRE sync job failed"
|
||||||
logger.exception("Scheduled MITRE sync job failed")
|
logger.exception("Scheduled MITRE sync job failed")
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Define function _run_notification_cleanup
|
||||||
def _run_notification_cleanup() -> None:
|
def _run_notification_cleanup() -> None:
|
||||||
"""Clean up old read notifications."""
|
"""Clean up old read notifications."""
|
||||||
|
# Log info: "Scheduled notification cleanup job starting..."
|
||||||
logger.info("Scheduled notification cleanup job starting...")
|
logger.info("Scheduled notification cleanup job starting...")
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign deleted = cleanup_old_notifications(db, days=90)
|
||||||
deleted = cleanup_old_notifications(db, days=90)
|
deleted = cleanup_old_notifications(db, days=90)
|
||||||
|
# Log info: "Notification cleanup finished — deleted %d old no
|
||||||
logger.info("Notification cleanup finished — deleted %d old notifications", deleted)
|
logger.info("Notification cleanup finished — deleted %d old notifications", deleted)
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log exception: "Notification cleanup job failed"
|
||||||
logger.exception("Notification cleanup job failed")
|
logger.exception("Notification cleanup job failed")
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Define function _run_weekly_snapshot
|
||||||
def _run_weekly_snapshot() -> None:
|
def _run_weekly_snapshot() -> None:
|
||||||
"""Create a weekly coverage snapshot and clean up old ones."""
|
"""Create a weekly coverage snapshot and clean up old ones."""
|
||||||
|
# Log info: "Scheduled weekly snapshot job starting..."
|
||||||
logger.info("Scheduled weekly snapshot job starting...")
|
logger.info("Scheduled weekly snapshot job starting...")
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign snapshot = create_snapshot(db, name="Auto-weekly")
|
||||||
snapshot = create_snapshot(db, name="Auto-weekly")
|
snapshot = create_snapshot(db, name="Auto-weekly")
|
||||||
|
# Log info:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
# Literal argument value
|
||||||
"Weekly snapshot created — score %.1f, %d techniques",
|
"Weekly snapshot created — score %.1f, %d techniques",
|
||||||
snapshot.organization_score,
|
snapshot.organization_score,
|
||||||
snapshot.total_techniques,
|
snapshot.total_techniques,
|
||||||
)
|
)
|
||||||
|
# Assign deleted = cleanup_old_snapshots(db, keep_last=52)
|
||||||
deleted = cleanup_old_snapshots(db, keep_last=52)
|
deleted = cleanup_old_snapshots(db, keep_last=52)
|
||||||
|
# Check: deleted
|
||||||
if deleted:
|
if deleted:
|
||||||
|
# Log info: "Cleaned up %d old snapshots", deleted
|
||||||
logger.info("Cleaned up %d old snapshots", deleted)
|
logger.info("Cleaned up %d old snapshots", deleted)
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log exception: "Weekly snapshot job failed"
|
||||||
logger.exception("Weekly snapshot job failed")
|
logger.exception("Weekly snapshot job failed")
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Define function _run_recurring_campaigns
|
||||||
def _run_recurring_campaigns() -> None:
|
def _run_recurring_campaigns() -> None:
|
||||||
"""Check and run any due recurring campaigns."""
|
"""Check and run any due recurring campaigns."""
|
||||||
|
# Log info: "Scheduled recurring campaigns check starting..."
|
||||||
logger.info("Scheduled recurring campaigns check starting...")
|
logger.info("Scheduled recurring campaigns check starting...")
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign spawned = check_and_run_recurring_campaigns(db)
|
||||||
spawned = check_and_run_recurring_campaigns(db)
|
spawned = check_and_run_recurring_campaigns(db)
|
||||||
|
# Log info: "Recurring campaigns check finished — spawned %d c
|
||||||
logger.info("Recurring campaigns check finished — spawned %d campaigns", spawned)
|
logger.info("Recurring campaigns check finished — spawned %d campaigns", spawned)
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log exception: "Recurring campaigns check failed"
|
||||||
logger.exception("Recurring campaigns check failed")
|
logger.exception("Recurring campaigns check failed")
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@@ -193,14 +256,23 @@ def _run_scheduled_campaign_activation() -> None:
|
|||||||
|
|
||||||
def _run_intel_scan() -> None:
|
def _run_intel_scan() -> None:
|
||||||
"""Execute an intel scan inside its own DB session."""
|
"""Execute an intel scan inside its own DB session."""
|
||||||
|
# Log info: "Scheduled intel scan job starting..."
|
||||||
logger.info("Scheduled intel scan job starting...")
|
logger.info("Scheduled intel scan job starting...")
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign summary = scan_intel(db)
|
||||||
summary = scan_intel(db)
|
summary = scan_intel(db)
|
||||||
|
# Log info: "Scheduled intel scan job finished — %s", summary
|
||||||
logger.info("Scheduled intel scan job finished — %s", summary)
|
logger.info("Scheduled intel scan job finished — %s", summary)
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log exception: "Scheduled intel scan job failed"
|
||||||
logger.exception("Scheduled intel scan job failed")
|
logger.exception("Scheduled intel scan job failed")
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@@ -283,14 +355,23 @@ def _run_evaluation_round_check() -> None:
|
|||||||
|
|
||||||
def _run_osint_enrichment() -> None:
|
def _run_osint_enrichment() -> None:
|
||||||
"""Execute weekly OSINT enrichment inside its own DB session."""
|
"""Execute weekly OSINT enrichment inside its own DB session."""
|
||||||
|
# Log info: "Scheduled OSINT enrichment job starting..."
|
||||||
logger.info("Scheduled OSINT enrichment job starting...")
|
logger.info("Scheduled OSINT enrichment job starting...")
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign total = enrich_all_techniques(db)
|
||||||
total = enrich_all_techniques(db)
|
total = enrich_all_techniques(db)
|
||||||
|
# Log info: "OSINT enrichment finished — %d new items", total
|
||||||
logger.info("OSINT enrichment finished — %d new items", total)
|
logger.info("OSINT enrichment finished — %d new items", total)
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log exception: "OSINT enrichment job failed"
|
||||||
logger.exception("OSINT enrichment job failed")
|
logger.exception("OSINT enrichment job failed")
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@@ -351,14 +432,23 @@ def _run_data_sources_sync() -> None:
|
|||||||
|
|
||||||
def _run_stale_detection() -> None:
|
def _run_stale_detection() -> None:
|
||||||
"""Execute daily stale coverage detection inside its own DB session."""
|
"""Execute daily stale coverage detection inside its own DB session."""
|
||||||
|
# Log info: "Scheduled stale coverage detection starting..."
|
||||||
logger.info("Scheduled stale coverage detection starting...")
|
logger.info("Scheduled stale coverage detection starting...")
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign count = detect_stale_coverage(db)
|
||||||
count = detect_stale_coverage(db)
|
count = detect_stale_coverage(db)
|
||||||
|
# Log info: "Stale detection finished — %d techniques flagged"
|
||||||
logger.info("Stale detection finished — %d techniques flagged", count)
|
logger.info("Stale detection finished — %d techniques flagged", count)
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log exception: "Stale coverage detection job failed"
|
||||||
logger.exception("Stale coverage detection job failed")
|
logger.exception("Stale coverage detection job failed")
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@@ -424,40 +514,67 @@ def start_scheduler() -> None:
|
|||||||
|
|
||||||
Neither job fires immediately on startup.
|
Neither job fires immediately on startup.
|
||||||
"""
|
"""
|
||||||
|
# Call scheduler.add_job()
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_mitre_sync,
|
_run_mitre_sync,
|
||||||
|
# Keyword argument: trigger
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
|
# Keyword argument: hours
|
||||||
hours=24,
|
hours=24,
|
||||||
|
# Keyword argument: id
|
||||||
id="mitre_sync",
|
id="mitre_sync",
|
||||||
|
# Keyword argument: name
|
||||||
name="MITRE ATT&CK sync (every 24h)",
|
name="MITRE ATT&CK sync (every 24h)",
|
||||||
|
# Keyword argument: replace_existing
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
# Call scheduler.add_job()
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_intel_scan,
|
_run_intel_scan,
|
||||||
|
# Keyword argument: trigger
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
|
# Keyword argument: weeks
|
||||||
weeks=1,
|
weeks=1,
|
||||||
|
# Keyword argument: id
|
||||||
id="intel_scan",
|
id="intel_scan",
|
||||||
|
# Keyword argument: name
|
||||||
name="Intel scan (every 7d)",
|
name="Intel scan (every 7d)",
|
||||||
|
# Keyword argument: replace_existing
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
# Call scheduler.add_job()
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_notification_cleanup,
|
_run_notification_cleanup,
|
||||||
|
# Keyword argument: trigger
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
|
# Keyword argument: hours
|
||||||
hours=24,
|
hours=24,
|
||||||
|
# Keyword argument: id
|
||||||
id="notification_cleanup",
|
id="notification_cleanup",
|
||||||
|
# Keyword argument: name
|
||||||
name="Notification cleanup (daily)",
|
name="Notification cleanup (daily)",
|
||||||
|
# Keyword argument: replace_existing
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
# Call scheduler.add_job()
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_weekly_snapshot,
|
_run_weekly_snapshot,
|
||||||
|
# Keyword argument: trigger
|
||||||
trigger="cron",
|
trigger="cron",
|
||||||
|
# Keyword argument: day_of_week
|
||||||
day_of_week="sun",
|
day_of_week="sun",
|
||||||
|
# Keyword argument: hour
|
||||||
hour=0,
|
hour=0,
|
||||||
|
# Keyword argument: minute
|
||||||
minute=0,
|
minute=0,
|
||||||
|
# Keyword argument: id
|
||||||
id="weekly_snapshot",
|
id="weekly_snapshot",
|
||||||
|
# Keyword argument: name
|
||||||
name="Weekly coverage snapshot (Sundays 00:00)",
|
name="Weekly coverage snapshot (Sundays 00:00)",
|
||||||
|
# Keyword argument: replace_existing
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
# Call scheduler.add_job()
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_scheduled_campaign_activation,
|
_run_scheduled_campaign_activation,
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
@@ -468,42 +585,71 @@ def start_scheduler() -> None:
|
|||||||
)
|
)
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_recurring_campaigns,
|
_run_recurring_campaigns,
|
||||||
|
# Keyword argument: trigger
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
|
# Keyword argument: hours
|
||||||
hours=24,
|
hours=24,
|
||||||
|
# Keyword argument: id
|
||||||
id="recurring_campaigns",
|
id="recurring_campaigns",
|
||||||
|
# Keyword argument: name
|
||||||
name="Recurring campaigns check (daily)",
|
name="Recurring campaigns check (daily)",
|
||||||
|
# Keyword argument: replace_existing
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
# Call scheduler.add_job()
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
sync_all_jira_links,
|
sync_all_jira_links,
|
||||||
|
# Keyword argument: trigger
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
|
# Keyword argument: hours
|
||||||
hours=1,
|
hours=1,
|
||||||
|
# Keyword argument: id
|
||||||
id="jira_sync",
|
id="jira_sync",
|
||||||
|
# Keyword argument: name
|
||||||
name="Jira link sync (hourly)",
|
name="Jira link sync (hourly)",
|
||||||
|
# Keyword argument: replace_existing
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
# Call scheduler.add_job()
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_osint_enrichment,
|
_run_osint_enrichment,
|
||||||
|
# Keyword argument: trigger
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
|
# Keyword argument: weeks
|
||||||
weeks=1,
|
weeks=1,
|
||||||
|
# Keyword argument: id
|
||||||
id="osint_enrichment",
|
id="osint_enrichment",
|
||||||
|
# Keyword argument: name
|
||||||
name="OSINT enrichment (weekly)",
|
name="OSINT enrichment (weekly)",
|
||||||
|
# Keyword argument: replace_existing
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
# Call scheduler.add_job()
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_stale_detection,
|
_run_stale_detection,
|
||||||
|
# Keyword argument: trigger
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
|
# Keyword argument: hours
|
||||||
hours=24,
|
hours=24,
|
||||||
|
# Keyword argument: id
|
||||||
id="stale_detection",
|
id="stale_detection",
|
||||||
|
# Keyword argument: name
|
||||||
name="Stale coverage detection (daily)",
|
name="Stale coverage detection (daily)",
|
||||||
|
# Keyword argument: replace_existing
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
# Call scheduler.add_job()
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
run_retention_job,
|
run_retention_job,
|
||||||
|
# Keyword argument: trigger
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
|
# Keyword argument: hours
|
||||||
hours=24,
|
hours=24,
|
||||||
|
# Keyword argument: id
|
||||||
id="retention_policies",
|
id="retention_policies",
|
||||||
|
# Keyword argument: name
|
||||||
name="Data retention policies (daily)",
|
name="Data retention policies (daily)",
|
||||||
|
# Keyword argument: replace_existing
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
@@ -551,10 +697,15 @@ def start_scheduler() -> None:
|
|||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
# Log info:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
# Literal argument value
|
||||||
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
||||||
|
# Literal argument value
|
||||||
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
|
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
|
||||||
|
# Literal argument value
|
||||||
"recurring_campaigns (daily), jira_sync (1h), "
|
"recurring_campaigns (daily), jira_sync (1h), "
|
||||||
|
# Literal argument value
|
||||||
"osint_enrichment (weekly), stale_detection (daily), "
|
"osint_enrichment (weekly), stale_detection (daily), "
|
||||||
"retention_policies (daily), data_sources_sync (6h), "
|
"retention_policies (daily), data_sources_sync (6h), "
|
||||||
"alert_evaluation (1h), attck_evaluation_check (Mondays 06:00)"
|
"alert_evaluation (1h), attck_evaluation_check (Mondays 06:00)"
|
||||||
|
|||||||
@@ -1,53 +1,89 @@
|
|||||||
"""Data retention policies — scheduled cleanup of aged records."""
|
"""Data retention policies — scheduled cleanup of aged records."""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import datetime, timedelta, timezone from datetime
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import SessionLocal from app.database
|
||||||
from app.database import SessionLocal
|
from app.database import SessionLocal
|
||||||
|
|
||||||
|
# Import AuditLog from app.models.audit
|
||||||
from app.models.audit import AuditLog
|
from app.models.audit import AuditLog
|
||||||
|
|
||||||
|
# Import cleanup_old_notifications from app.services.notification_service
|
||||||
from app.services.notification_service import cleanup_old_notifications
|
from app.services.notification_service import cleanup_old_notifications
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Assign AUDIT_LOG_RETENTION_DAYS = 730
|
||||||
AUDIT_LOG_RETENTION_DAYS = 730
|
AUDIT_LOG_RETENTION_DAYS = 730
|
||||||
|
|
||||||
|
|
||||||
|
# Define function apply_retention_policies
|
||||||
def apply_retention_policies(db: Session) -> dict[str, int]:
|
def apply_retention_policies(db: Session) -> dict[str, int]:
|
||||||
"""Apply retention rules. Commits the session before returning."""
|
"""Apply retention rules. Commits the session before returning."""
|
||||||
|
# Assign cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
|
||||||
cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
|
cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
|
||||||
|
# Assign deleted_audit = (
|
||||||
deleted_audit = (
|
deleted_audit = (
|
||||||
db.query(AuditLog)
|
db.query(AuditLog)
|
||||||
|
# Chain .filter() call
|
||||||
.filter(AuditLog.timestamp < cutoff)
|
.filter(AuditLog.timestamp < cutoff)
|
||||||
|
# Chain .delete() call
|
||||||
.delete(synchronize_session=False)
|
.delete(synchronize_session=False)
|
||||||
)
|
)
|
||||||
|
# Check: deleted_audit
|
||||||
if deleted_audit:
|
if deleted_audit:
|
||||||
|
# Log info:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
# Literal argument value
|
||||||
"Retention: deleted %d audit logs older than %d days",
|
"Retention: deleted %d audit logs older than %d days",
|
||||||
deleted_audit,
|
deleted_audit,
|
||||||
AUDIT_LOG_RETENTION_DAYS,
|
AUDIT_LOG_RETENTION_DAYS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign deleted_notifications = cleanup_old_notifications(db, days=90)
|
||||||
deleted_notifications = cleanup_old_notifications(db, days=90)
|
deleted_notifications = cleanup_old_notifications(db, days=90)
|
||||||
|
# Commit all pending changes to the database
|
||||||
db.commit()
|
db.commit()
|
||||||
|
# Return {
|
||||||
return {
|
return {
|
||||||
|
# Literal argument value
|
||||||
"audit_logs_deleted": deleted_audit,
|
"audit_logs_deleted": deleted_audit,
|
||||||
|
# Literal argument value
|
||||||
"notifications_deleted": deleted_notifications,
|
"notifications_deleted": deleted_notifications,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Define function run_retention_job
|
||||||
def run_retention_job() -> None:
|
def run_retention_job() -> None:
|
||||||
"""Entry point for the daily retention scheduler job."""
|
"""Entry point for the daily retention scheduler job."""
|
||||||
|
# Log info: "Scheduled retention job starting..."
|
||||||
logger.info("Scheduled retention job starting...")
|
logger.info("Scheduled retention job starting...")
|
||||||
|
# Assign db = SessionLocal()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign summary = apply_retention_policies(db)
|
||||||
summary = apply_retention_policies(db)
|
summary = apply_retention_policies(db)
|
||||||
|
# Log info: "Retention job finished — %s", summary
|
||||||
logger.info("Retention job finished — %s", summary)
|
logger.info("Retention job finished — %s", summary)
|
||||||
|
# Handle Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Log exception: "Retention job failed"
|
||||||
logger.exception("Retention job failed")
|
logger.exception("Retention job failed")
|
||||||
|
# Roll back all uncommitted changes
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
# Always execute this cleanup block
|
||||||
finally:
|
finally:
|
||||||
|
# Close the database session
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
"""Shared SlowAPI rate limiter for all routers."""
|
"""Shared SlowAPI rate limiter for all routers."""
|
||||||
|
|
||||||
|
# Import Limiter from slowapi
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
|
|
||||||
|
# Import get_remote_address from slowapi.util
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
|
|
||||||
|
# Assign limiter = Limiter(key_func=get_remote_address)
|
||||||
limiter = Limiter(key_func=get_remote_address)
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|||||||
@@ -8,60 +8,101 @@ In **development** (default), uses a human-readable text format for
|
|||||||
comfortable local work.
|
comfortable local work.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Enable future language features for compatibility
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Import json
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import os
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
# Import sys
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
# Import datetime, timezone from datetime
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
# Define class _JSONFormatter
|
||||||
class _JSONFormatter(logging.Formatter):
|
class _JSONFormatter(logging.Formatter):
|
||||||
"""Emit each log record as a single-line JSON object."""
|
"""Emit each log record as a single-line JSON object."""
|
||||||
|
|
||||||
|
# Define function format
|
||||||
def format(self, record: logging.LogRecord) -> str:
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
# Assign payload = {
|
||||||
payload: dict = {
|
payload: dict = {
|
||||||
|
# Literal argument value
|
||||||
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||||
|
# Literal argument value
|
||||||
"level": record.levelname,
|
"level": record.levelname,
|
||||||
|
# Literal argument value
|
||||||
"logger": record.name,
|
"logger": record.name,
|
||||||
|
# Literal argument value
|
||||||
"message": record.getMessage(),
|
"message": record.getMessage(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Check: record.exc_info and record.exc_info[1] is not None
|
||||||
if record.exc_info and record.exc_info[1] is not None:
|
if record.exc_info and record.exc_info[1] is not None:
|
||||||
|
# Assign payload["exception"] = self.formatException(record.exc_info)
|
||||||
payload["exception"] = self.formatException(record.exc_info)
|
payload["exception"] = self.formatException(record.exc_info)
|
||||||
|
|
||||||
|
# Assign extra = getattr(record, "_extra", None)
|
||||||
extra = getattr(record, "_extra", None)
|
extra = getattr(record, "_extra", None)
|
||||||
|
# Check: extra
|
||||||
if extra:
|
if extra:
|
||||||
|
# Call payload.update()
|
||||||
payload.update(extra)
|
payload.update(extra)
|
||||||
|
|
||||||
|
# Return json.dumps(payload, default=str)
|
||||||
return json.dumps(payload, default=str)
|
return json.dumps(payload, default=str)
|
||||||
|
|
||||||
|
|
||||||
|
# Assign _DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
|
||||||
_DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
|
_DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
|
||||||
|
|
||||||
|
|
||||||
|
# Define function setup_logging
|
||||||
def setup_logging() -> None:
|
def setup_logging() -> None:
|
||||||
"""Configure the root logger based on the environment."""
|
"""Configure the root logger based on the environment."""
|
||||||
|
# Assign is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||||
is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||||
|
# Assign level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||||
level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
|
level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||||
|
# Assign level = getattr(logging, level_name, logging.INFO)
|
||||||
level = getattr(logging, level_name, logging.INFO)
|
level = getattr(logging, level_name, logging.INFO)
|
||||||
|
|
||||||
|
# Assign root = logging.getLogger()
|
||||||
root = logging.getLogger()
|
root = logging.getLogger()
|
||||||
|
# Call root.setLevel()
|
||||||
root.setLevel(level)
|
root.setLevel(level)
|
||||||
|
|
||||||
|
# Check: root.handlers
|
||||||
if root.handlers:
|
if root.handlers:
|
||||||
|
# Call root.handlers.clear()
|
||||||
root.handlers.clear()
|
root.handlers.clear()
|
||||||
|
|
||||||
|
# Assign handler = logging.StreamHandler(sys.stdout)
|
||||||
handler = logging.StreamHandler(sys.stdout)
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
# Call handler.setLevel()
|
||||||
handler.setLevel(level)
|
handler.setLevel(level)
|
||||||
|
|
||||||
|
# Check: is_production
|
||||||
if is_production:
|
if is_production:
|
||||||
|
# Call handler.setFormatter()
|
||||||
handler.setFormatter(_JSONFormatter())
|
handler.setFormatter(_JSONFormatter())
|
||||||
|
# Fallback: handle remaining cases
|
||||||
else:
|
else:
|
||||||
|
# Call handler.setFormatter()
|
||||||
handler.setFormatter(logging.Formatter(_DEV_FORMAT))
|
handler.setFormatter(logging.Formatter(_DEV_FORMAT))
|
||||||
|
|
||||||
|
# Call root.addHandler()
|
||||||
root.addHandler(handler)
|
root.addHandler(handler)
|
||||||
|
|
||||||
|
# Call logging.getLogger()
|
||||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||||
|
# Call logging.getLogger()
|
||||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
|
|||||||
+264
-23
@@ -1,13 +1,41 @@
|
|||||||
|
"""FastAPI application factory and global middleware/exception configuration.
|
||||||
|
|
||||||
|
Builds the ``app`` instance, wires up CORS, rate limiting, domain-error
|
||||||
|
mapping, all API routers, and async lifespan hooks (MinIO bucket creation,
|
||||||
|
APScheduler startup/shutdown).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import os
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
# Import AsyncGenerator from collections.abc
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
# Import asynccontextmanager from contextlib
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
# Import FastAPI, Request, status from fastapi
|
||||||
from fastapi import FastAPI, Request, status
|
from fastapi import FastAPI, Request, status
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.responses import JSONResponse
|
# Import RequestValidationError from fastapi.exceptions
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
|
||||||
|
# Import CORSMiddleware from fastapi.middleware.cors
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
# Import JSONResponse from fastapi.responses
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
# Import _rate_limit_exceeded_handler from slowapi
|
||||||
from slowapi import _rate_limit_exceeded_handler
|
from slowapi import _rate_limit_exceeded_handler
|
||||||
|
|
||||||
|
# Import RateLimitExceeded from slowapi.errors
|
||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
|
# Import SQLAlchemyError from sqlalchemy.exc
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from app.routers import auth as auth_router
|
from app.routers import auth as auth_router
|
||||||
@@ -50,24 +78,130 @@ from app.routers import api_keys as api_keys_router
|
|||||||
from app.routers import sso as sso_router
|
from app.routers import sso as sso_router
|
||||||
from app.routers import operational_alerts as alerts_router
|
from app.routers import operational_alerts as alerts_router
|
||||||
from app.domain.errors import DomainError
|
from app.domain.errors import DomainError
|
||||||
from app.middleware.error_handler import domain_exception_handler
|
|
||||||
from app.middleware.request_context import RequestContextMiddleware
|
# Import scheduler, start_scheduler from app.jobs.mitre_sync_job
|
||||||
|
from app.jobs.mitre_sync_job import scheduler, start_scheduler
|
||||||
|
|
||||||
|
# Import limiter from app.limiter
|
||||||
from app.limiter import limiter
|
from app.limiter import limiter
|
||||||
|
|
||||||
|
# Import setup_logging from app.logging_config
|
||||||
|
from app.logging_config import setup_logging
|
||||||
|
|
||||||
|
# Import domain_exception_handler from app.middleware.error_handler
|
||||||
|
from app.middleware.error_handler import domain_exception_handler
|
||||||
|
|
||||||
|
# Import RequestContextMiddleware from app.middleware.request_context
|
||||||
|
from app.middleware.request_context import RequestContextMiddleware
|
||||||
|
|
||||||
|
# Import advanced_metrics as advanced_metrics_router from app.routers
|
||||||
|
from app.routers import advanced_metrics as advanced_metrics_router
|
||||||
|
|
||||||
|
# Import analytics as analytics_router from app.routers
|
||||||
|
from app.routers import analytics as analytics_router
|
||||||
|
|
||||||
|
# Import audit as audit_router from app.routers
|
||||||
|
from app.routers import audit as audit_router
|
||||||
|
|
||||||
|
# Import auth as auth_router from app.routers
|
||||||
|
from app.routers import auth as auth_router
|
||||||
|
|
||||||
|
# Import campaigns as campaigns_router from app.routers
|
||||||
|
from app.routers import campaigns as campaigns_router
|
||||||
|
|
||||||
|
# Import compliance as compliance_router from app.routers
|
||||||
|
from app.routers import compliance as compliance_router
|
||||||
|
|
||||||
|
# Import d3fend as d3fend_router from app.routers
|
||||||
|
from app.routers import d3fend as d3fend_router
|
||||||
|
|
||||||
|
# Import data_sources as data_sources_router from app.routers
|
||||||
|
from app.routers import data_sources as data_sources_router
|
||||||
|
|
||||||
|
# Import detection_rules as detection_rules_router from app.routers
|
||||||
|
from app.routers import detection_rules as detection_rules_router
|
||||||
|
|
||||||
|
# Import evidence as evidence_router from app.routers
|
||||||
|
from app.routers import evidence as evidence_router
|
||||||
|
|
||||||
|
# Import heatmap as heatmap_router from app.routers
|
||||||
|
from app.routers import heatmap as heatmap_router
|
||||||
|
|
||||||
|
# Import jira as jira_router from app.routers
|
||||||
|
from app.routers import jira as jira_router
|
||||||
|
|
||||||
|
# Import metrics as metrics_router from app.routers
|
||||||
|
from app.routers import metrics as metrics_router
|
||||||
|
|
||||||
|
# Import notifications as notifications_router from app.routers
|
||||||
|
from app.routers import notifications as notifications_router
|
||||||
|
|
||||||
|
# Import operational_metrics as operational_metrics_router from app.routers
|
||||||
|
from app.routers import operational_metrics as operational_metrics_router
|
||||||
|
|
||||||
|
# Import osint as osint_router from app.routers
|
||||||
|
from app.routers import osint as osint_router
|
||||||
|
|
||||||
|
# Import professional_reports as professional_reports_ro... from app.routers
|
||||||
|
from app.routers import professional_reports as professional_reports_router
|
||||||
|
|
||||||
|
# Import reports as reports_router from app.routers
|
||||||
|
from app.routers import reports as reports_router
|
||||||
|
|
||||||
|
# Import scores as scores_router from app.routers
|
||||||
|
from app.routers import scores as scores_router
|
||||||
|
|
||||||
|
# Import snapshots as snapshots_router from app.routers
|
||||||
|
from app.routers import snapshots as snapshots_router
|
||||||
|
|
||||||
|
# Import system as system_router from app.routers
|
||||||
|
from app.routers import system as system_router
|
||||||
|
|
||||||
|
# Import techniques as techniques_router from app.routers
|
||||||
|
from app.routers import techniques as techniques_router
|
||||||
|
|
||||||
|
# Import test_templates as test_templates_router from app.routers
|
||||||
|
from app.routers import test_templates as test_templates_router
|
||||||
|
|
||||||
|
# Import tests as tests_router from app.routers
|
||||||
|
from app.routers import tests as tests_router
|
||||||
|
|
||||||
|
# Import threat_actors as threat_actors_router from app.routers
|
||||||
|
from app.routers import threat_actors as threat_actors_router
|
||||||
|
|
||||||
|
# Import users as users_router from app.routers
|
||||||
|
from app.routers import users as users_router
|
||||||
|
|
||||||
|
# Import worklogs as worklogs_router from app.routers
|
||||||
|
from app.routers import worklogs as worklogs_router
|
||||||
|
|
||||||
|
# Import ensure_bucket_exists from app.storage
|
||||||
from app.storage import ensure_bucket_exists
|
from app.storage import ensure_bucket_exists
|
||||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
|
||||||
|
# Import settings as _settings from app.config
|
||||||
|
from app.config import settings as _settings
|
||||||
|
|
||||||
|
# Configure structured logging before any module initialises its own logger
|
||||||
|
setup_logging()
|
||||||
|
|
||||||
# ── Environment detection ─────────────────────────────────────────────────
|
# ── Environment detection ─────────────────────────────────────────────────
|
||||||
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||||
|
|
||||||
# ── Logging ───────────────────────────────────────────────────────────────
|
# Apply the @asynccontextmanager decorator
|
||||||
from app.logging_config import setup_logging
|
|
||||||
|
|
||||||
setup_logging()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
# Define async function lifespan
|
||||||
"""Startup / shutdown logic."""
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
"""Manage application startup and shutdown lifecycle.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app (FastAPI): The FastAPI application instance.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
None: Control is yielded to the running application.
|
||||||
|
"""
|
||||||
|
# Call ensure_bucket_exists()
|
||||||
ensure_bucket_exists()
|
ensure_bucket_exists()
|
||||||
|
# Call start_scheduler()
|
||||||
start_scheduler()
|
start_scheduler()
|
||||||
# Seed decay policies
|
# Seed decay policies
|
||||||
from app.database import SessionLocal
|
from app.database import SessionLocal
|
||||||
@@ -95,17 +229,24 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
# ── In production, disable Swagger UI and ReDoc to hide API surface ──────
|
# ── In production, disable Swagger UI and ReDoc to hide API surface ──────
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
|
# Keyword argument: title
|
||||||
title="Attack Coverage Platform",
|
title="Attack Coverage Platform",
|
||||||
|
# Keyword argument: lifespan
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
|
# Keyword argument: docs_url
|
||||||
docs_url=None if _IS_PRODUCTION else "/docs",
|
docs_url=None if _IS_PRODUCTION else "/docs",
|
||||||
|
# Keyword argument: redoc_url
|
||||||
redoc_url=None if _IS_PRODUCTION else "/redoc",
|
redoc_url=None if _IS_PRODUCTION else "/redoc",
|
||||||
|
# Keyword argument: openapi_url
|
||||||
openapi_url=None if _IS_PRODUCTION else "/openapi.json",
|
openapi_url=None if _IS_PRODUCTION else "/openapi.json",
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Rate Limiter ──────────────────────────────────────────────────────────
|
# ── Rate Limiter ──────────────────────────────────────────────────────────
|
||||||
app.state.limiter = limiter
|
app.state.limiter = limiter
|
||||||
|
# Call app.add_exception_handler()
|
||||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
|
||||||
|
# Call app.add_middleware()
|
||||||
app.add_middleware(RequestContextMiddleware)
|
app.add_middleware(RequestContextMiddleware)
|
||||||
|
|
||||||
|
|
||||||
@@ -130,49 +271,77 @@ app.add_middleware(NoCacheAPIMiddleware)
|
|||||||
app.add_exception_handler(DomainError, domain_exception_handler)
|
app.add_exception_handler(DomainError, domain_exception_handler)
|
||||||
|
|
||||||
# ── CORS ──────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────
|
||||||
from app.config import settings as _settings
|
|
||||||
|
|
||||||
_cors_origins: list[str] = [
|
_cors_origins: list[str] = [
|
||||||
o.strip() for o in _settings.CORS_ORIGINS.split(",") if o.strip()
|
o.strip() for o in _settings.CORS_ORIGINS.split(",") if o.strip()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Call app.add_middleware()
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
|
# Keyword argument: allow_origins
|
||||||
allow_origins=_cors_origins,
|
allow_origins=_cors_origins,
|
||||||
|
# Keyword argument: allow_credentials
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
|
# Keyword argument: allow_methods
|
||||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||||
|
# Keyword argument: allow_headers
|
||||||
allow_headers=["Authorization", "Content-Type"],
|
allow_headers=["Authorization", "Content-Type"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Routers ──────────────────────────────────────────────────────────────
|
# ── Routers ──────────────────────────────────────────────────────────────
|
||||||
app.include_router(auth_router.router, prefix="/api/v1")
|
app.include_router(auth_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(techniques_router.router, prefix="/api/v1")
|
app.include_router(techniques_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(tests_router.router, prefix="/api/v1")
|
app.include_router(tests_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(evidence_router.router, prefix="/api/v1")
|
app.include_router(evidence_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(test_templates_router.router, prefix="/api/v1")
|
app.include_router(test_templates_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(system_router.router, prefix="/api/v1")
|
app.include_router(system_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(metrics_router.router, prefix="/api/v1")
|
app.include_router(metrics_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(users_router.router, prefix="/api/v1")
|
app.include_router(users_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(audit_router.router, prefix="/api/v1")
|
app.include_router(audit_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(notifications_router.router, prefix="/api/v1")
|
app.include_router(notifications_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(reports_router.router, prefix="/api/v1")
|
app.include_router(reports_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(data_sources_router.router, prefix="/api/v1")
|
app.include_router(data_sources_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(threat_actors_router.router, prefix="/api/v1")
|
app.include_router(threat_actors_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(d3fend_router.router, prefix="/api/v1")
|
app.include_router(d3fend_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(detection_rules_router.router, prefix="/api/v1")
|
app.include_router(detection_rules_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(campaigns_router.router, prefix="/api/v1")
|
app.include_router(campaigns_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(heatmap_router.router, prefix="/api/v1")
|
app.include_router(heatmap_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(scores_router.router, prefix="/api/v1")
|
app.include_router(scores_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(operational_metrics_router.router, prefix="/api/v1")
|
app.include_router(operational_metrics_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(compliance_router.router, prefix="/api/v1")
|
app.include_router(compliance_router.router, prefix="/api/v1")
|
||||||
app.include_router(intel_router.router, prefix="/api/v1")
|
app.include_router(intel_router.router, prefix="/api/v1")
|
||||||
app.include_router(admin_config_router.router, prefix="/api/v1")
|
app.include_router(admin_config_router.router, prefix="/api/v1")
|
||||||
app.include_router(snapshots_router.router, prefix="/api/v1")
|
app.include_router(snapshots_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(jira_router.router, prefix="/api/v1")
|
app.include_router(jira_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(worklogs_router.router, prefix="/api/v1")
|
app.include_router(worklogs_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(professional_reports_router.router, prefix="/api/v1")
|
app.include_router(professional_reports_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(analytics_router.router, prefix="/api/v1")
|
app.include_router(analytics_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(advanced_metrics_router.router, prefix="/api/v1")
|
app.include_router(advanced_metrics_router.router, prefix="/api/v1")
|
||||||
|
# Call app.include_router()
|
||||||
app.include_router(osint_router.router, prefix="/api/v1")
|
app.include_router(osint_router.router, prefix="/api/v1")
|
||||||
app.include_router(webhooks_router.router, prefix="/api/v1")
|
app.include_router(webhooks_router.router, prefix="/api/v1")
|
||||||
app.include_router(detection_lifecycle_router.router, prefix="/api/v1")
|
app.include_router(detection_lifecycle_router.router, prefix="/api/v1")
|
||||||
@@ -186,13 +355,19 @@ app.include_router(sso_router.router, prefix="/api/v1")
|
|||||||
app.include_router(alerts_router.router, prefix="/api/v1")
|
app.include_router(alerts_router.router, prefix="/api/v1")
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @app.get decorator
|
||||||
@app.get("/health", include_in_schema=False)
|
@app.get("/health", include_in_schema=False)
|
||||||
def health():
|
# Define function health
|
||||||
"""Minimal health check — returns only an HTTP 200 with no service metadata.
|
def health() -> dict[str, str]:
|
||||||
|
"""Return a minimal liveness probe response.
|
||||||
|
|
||||||
Access is restricted to internal networks at the Nginx level
|
Access is restricted to internal networks at the Nginx level
|
||||||
(see ``frontend/nginx.conf``).
|
(see ``frontend/nginx.conf``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, str]: A dict with ``{"status": "ok"}``.
|
||||||
"""
|
"""
|
||||||
|
# Return {"status": "ok"}
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
@@ -200,51 +375,117 @@ def health():
|
|||||||
|
|
||||||
|
|
||||||
def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
|
def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
|
||||||
"""Return validation errors safe for JSON (no raw exception objects)."""
|
"""Return validation errors safe for JSON serialization.
|
||||||
|
|
||||||
|
Converts non-serializable values inside ``ctx`` dictionaries to strings
|
||||||
|
so the response body can be safely encoded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc (RequestValidationError): The Pydantic validation exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[dict]: A list of sanitised error detail dictionaries.
|
||||||
|
"""
|
||||||
|
# Assign serialized = []
|
||||||
serialized: list[dict] = []
|
serialized: list[dict] = []
|
||||||
|
# Iterate over exc.errors()
|
||||||
for err in exc.errors():
|
for err in exc.errors():
|
||||||
|
# Assign item = dict(err)
|
||||||
item = dict(err)
|
item = dict(err)
|
||||||
|
# Assign ctx = item.get("ctx")
|
||||||
ctx = item.get("ctx")
|
ctx = item.get("ctx")
|
||||||
|
# Check: isinstance(ctx, dict)
|
||||||
if isinstance(ctx, dict):
|
if isinstance(ctx, dict):
|
||||||
|
# Assign item["ctx"] = {key: str(value) for key, value in ctx.items()}
|
||||||
item["ctx"] = {key: str(value) for key, value in ctx.items()}
|
item["ctx"] = {key: str(value) for key, value in ctx.items()}
|
||||||
|
# Call serialized.append()
|
||||||
serialized.append(item)
|
serialized.append(item)
|
||||||
|
# Return serialized
|
||||||
return serialized
|
return serialized
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @app.exception_handler decorator
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
# Define async function validation_exception_handler
|
||||||
"""Handle validation errors with consistent format."""
|
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||||
|
"""Handle Pydantic validation errors and return a structured 422 response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): The incoming HTTP request.
|
||||||
|
exc (RequestValidationError): The caught validation exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSONResponse: A 422 response with a ``VALIDATION_ERROR`` code and error details.
|
||||||
|
"""
|
||||||
|
# Return JSONResponse(
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
# Keyword argument: content
|
||||||
content={
|
content={
|
||||||
|
# Literal argument value
|
||||||
"detail": "Validation error",
|
"detail": "Validation error",
|
||||||
|
# Literal argument value
|
||||||
"code": "VALIDATION_ERROR",
|
"code": "VALIDATION_ERROR",
|
||||||
|
# Literal argument value
|
||||||
"errors": _serialize_validation_errors(exc),
|
"errors": _serialize_validation_errors(exc),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @app.exception_handler decorator
|
||||||
@app.exception_handler(SQLAlchemyError)
|
@app.exception_handler(SQLAlchemyError)
|
||||||
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError):
|
# Define async function sqlalchemy_exception_handler
|
||||||
"""Handle database errors."""
|
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError) -> JSONResponse:
|
||||||
|
"""Handle SQLAlchemy database errors and return a structured 500 response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): The incoming HTTP request.
|
||||||
|
exc (SQLAlchemyError): The caught SQLAlchemy exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSONResponse: A 500 response with a ``DATABASE_ERROR`` code.
|
||||||
|
"""
|
||||||
|
# Log error: f"Database error: {exc}"
|
||||||
logging.error(f"Database error: {exc}")
|
logging.error(f"Database error: {exc}")
|
||||||
|
# Return JSONResponse(
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
# Keyword argument: content
|
||||||
content={
|
content={
|
||||||
|
# Literal argument value
|
||||||
"detail": "Database error occurred",
|
"detail": "Database error occurred",
|
||||||
|
# Literal argument value
|
||||||
"code": "DATABASE_ERROR",
|
"code": "DATABASE_ERROR",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @app.exception_handler decorator
|
||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def general_exception_handler(request: Request, exc: Exception):
|
# Define async function general_exception_handler
|
||||||
"""Handle all unhandled exceptions."""
|
async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||||
|
"""Handle all otherwise-unhandled exceptions and return a structured 500 response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): The incoming HTTP request.
|
||||||
|
exc (Exception): The unhandled exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSONResponse: A 500 response with an ``INTERNAL_ERROR`` code.
|
||||||
|
"""
|
||||||
|
# Log error: f"Unhandled exception: {exc}"
|
||||||
logging.error(f"Unhandled exception: {exc}")
|
logging.error(f"Unhandled exception: {exc}")
|
||||||
|
# Return JSONResponse(
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
# Keyword argument: content
|
||||||
content={
|
content={
|
||||||
|
# Literal argument value
|
||||||
"detail": "An internal server error occurred",
|
"detail": "An internal server error occurred",
|
||||||
|
# Literal argument value
|
||||||
"code": "INTERNAL_ERROR",
|
"code": "INTERNAL_ERROR",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""ASGI middleware components for request context, error handling, and rate limiting."""
|
||||||
|
|||||||
@@ -5,9 +5,13 @@ domain-layer errors into structured JSON responses, keeping
|
|||||||
the service layer free from FastAPI's ``HTTPException``.
|
the service layer free from FastAPI's ``HTTPException``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import Request from fastapi
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
|
# Import JSONResponse from fastapi.responses
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
# Import from app.domain.errors
|
||||||
from app.domain.errors import (
|
from app.domain.errors import (
|
||||||
BusinessRuleViolation,
|
BusinessRuleViolation,
|
||||||
DomainError,
|
DomainError,
|
||||||
@@ -18,28 +22,45 @@ from app.domain.errors import (
|
|||||||
PermissionViolation,
|
PermissionViolation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign EXCEPTION_STATUS_MAP = {
|
||||||
EXCEPTION_STATUS_MAP: dict[type[DomainError], int] = {
|
EXCEPTION_STATUS_MAP: dict[type[DomainError], int] = {
|
||||||
|
# Entry: EntityNotFoundError
|
||||||
EntityNotFoundError: 404,
|
EntityNotFoundError: 404,
|
||||||
|
# Entry: DuplicateEntityError
|
||||||
DuplicateEntityError: 409,
|
DuplicateEntityError: 409,
|
||||||
|
# Entry: InvalidStateTransition
|
||||||
InvalidStateTransition: 400,
|
InvalidStateTransition: 400,
|
||||||
|
# Entry: InvalidOperationError
|
||||||
InvalidOperationError: 400,
|
InvalidOperationError: 400,
|
||||||
|
# Entry: BusinessRuleViolation
|
||||||
BusinessRuleViolation: 400,
|
BusinessRuleViolation: 400,
|
||||||
|
# Entry: PermissionViolation
|
||||||
PermissionViolation: 403,
|
PermissionViolation: 403,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Define async function domain_exception_handler
|
||||||
async def domain_exception_handler(
|
async def domain_exception_handler(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: exc
|
||||||
exc: DomainError,
|
exc: DomainError,
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Convert a :class:`DomainError` into a JSON error response."""
|
"""Convert a :class:`DomainError` into a JSON error response."""
|
||||||
|
# Assign status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400)
|
||||||
status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400)
|
status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400)
|
||||||
|
|
||||||
|
# Assign content = {"detail": exc.message, "code": exc.code}
|
||||||
content: dict = {"detail": exc.message, "code": exc.code}
|
content: dict = {"detail": exc.message, "code": exc.code}
|
||||||
|
|
||||||
|
# Check: isinstance(exc, InvalidStateTransition)
|
||||||
if isinstance(exc, InvalidStateTransition):
|
if isinstance(exc, InvalidStateTransition):
|
||||||
|
# Assign content["current_state"] = exc.current_state
|
||||||
content["current_state"] = exc.current_state
|
content["current_state"] = exc.current_state
|
||||||
|
# Assign content["target_state"] = exc.target_state
|
||||||
content["target_state"] = exc.target_state
|
content["target_state"] = exc.target_state
|
||||||
|
# Assign content["valid_transitions"] = exc.valid_transitions
|
||||||
content["valid_transitions"] = exc.valid_transitions
|
content["valid_transitions"] = exc.valid_transitions
|
||||||
|
|
||||||
|
# Return JSONResponse(status_code=status_code, content=content)
|
||||||
return JSONResponse(status_code=status_code, content=content)
|
return JSONResponse(status_code=status_code, content=content)
|
||||||
|
|||||||
@@ -1,26 +1,74 @@
|
|||||||
"""Request context middleware — captures client IP and User-Agent per request."""
|
"""Request context middleware — captures client IP and User-Agent per request."""
|
||||||
|
|
||||||
|
# Import Awaitable, Callable from collections.abc
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
# Import ContextVar from contextvars
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
# Import Request from fastapi
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
|
# Import BaseHTTPMiddleware from starlette.middleware.base
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
# Import Response from starlette.responses
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
# Assign request_ip = ContextVar("request_ip", default="")
|
||||||
request_ip: ContextVar[str] = ContextVar("request_ip", default="")
|
request_ip: ContextVar[str] = ContextVar("request_ip", default="")
|
||||||
|
# Assign request_user_agent = ContextVar("request_user_agent", default="")
|
||||||
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
|
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
|
||||||
|
|
||||||
|
|
||||||
|
# Define function resolve_client_ip
|
||||||
def resolve_client_ip(request: Request) -> str:
|
def resolve_client_ip(request: Request) -> str:
|
||||||
"""Extract the client IP, honouring ``X-Forwarded-For`` when present."""
|
"""Extract the real client IP, honouring ``X-Forwarded-For`` when present.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): The incoming Starlette/FastAPI request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The resolved client IP address, or ``"unknown"`` when unavailable.
|
||||||
|
"""
|
||||||
|
# Assign forwarded = request.headers.get("X-Forwarded-For")
|
||||||
forwarded = request.headers.get("X-Forwarded-For")
|
forwarded = request.headers.get("X-Forwarded-For")
|
||||||
|
# Check: forwarded
|
||||||
if forwarded:
|
if forwarded:
|
||||||
|
# Return forwarded.split(",")[0].strip()
|
||||||
return forwarded.split(",")[0].strip()
|
return forwarded.split(",")[0].strip()
|
||||||
|
# Check: request.client
|
||||||
if request.client:
|
if request.client:
|
||||||
|
# Return request.client.host
|
||||||
return request.client.host
|
return request.client.host
|
||||||
|
# Return "unknown"
|
||||||
return "unknown"
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# Define class RequestContextMiddleware
|
||||||
class RequestContextMiddleware(BaseHTTPMiddleware):
|
class RequestContextMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
"""Middleware that captures client IP and User-Agent into context variables."""
|
||||||
|
|
||||||
|
# Define async function dispatch
|
||||||
|
async def dispatch(
|
||||||
|
self,
|
||||||
|
# Entry: request
|
||||||
|
request: Request,
|
||||||
|
# Entry: call_next
|
||||||
|
call_next: Callable[[Request], Awaitable[Response]],
|
||||||
|
) -> Response:
|
||||||
|
"""Store client IP and User-Agent in context vars for the current request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): The incoming HTTP request.
|
||||||
|
call_next (Callable[[Request], Awaitable[Response]]): The next middleware or route handler.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response: The HTTP response produced by the downstream handler.
|
||||||
|
"""
|
||||||
|
# Call request_ip.set()
|
||||||
request_ip.set(resolve_client_ip(request))
|
request_ip.set(resolve_client_ip(request))
|
||||||
|
# Call request_user_agent.set()
|
||||||
request_user_agent.set(request.headers.get("User-Agent", ""))
|
request_user_agent.set(request.headers.get("User-Agent", ""))
|
||||||
|
# Return await call_next(request)
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|||||||
@@ -1,10 +1,5 @@
|
|||||||
|
"""SQLAlchemy ORM model definitions for all database tables."""
|
||||||
# Import all models here so Alembic can detect them
|
# Import all models here so Alembic can detect them
|
||||||
from app.models.user import User
|
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.test import Test
|
|
||||||
from app.models.test_template import TestTemplate
|
|
||||||
from app.models.evidence import Evidence
|
|
||||||
from app.models.intel import IntelItem
|
|
||||||
from app.models.audit import AuditLog
|
from app.models.audit import AuditLog
|
||||||
from app.models.notification import Notification
|
from app.models.notification import Notification
|
||||||
from app.models.data_source import DataSource
|
from app.models.data_source import DataSource
|
||||||
@@ -45,17 +40,96 @@ from app.models.api_key import ApiKey
|
|||||||
from app.models.sso_config import SsoConfig
|
from app.models.sso_config import SsoConfig
|
||||||
from app.models.operational_alert import AlertRule, AlertInstance
|
from app.models.operational_alert import AlertRule, AlertInstance
|
||||||
|
|
||||||
|
# Import Campaign, CampaignTest from app.models.campaign
|
||||||
|
from app.models.campaign import Campaign, CampaignTest
|
||||||
|
|
||||||
|
# Import from app.models.compliance
|
||||||
|
from app.models.compliance import (
|
||||||
|
ComplianceControl,
|
||||||
|
ComplianceControlMapping,
|
||||||
|
ComplianceFramework,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot
|
||||||
|
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||||
|
|
||||||
|
# Import DataSource from app.models.data_source
|
||||||
|
from app.models.data_source import DataSource
|
||||||
|
|
||||||
|
# Import DefensiveTechnique, DefensiveTechniqueMapping from app.models.defensive_technique
|
||||||
|
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||||
|
|
||||||
|
# Import DetectionRule from app.models.detection_rule
|
||||||
|
from app.models.detection_rule import DetectionRule
|
||||||
|
|
||||||
|
# Import TeamSide, TechniqueStatus, TestResult, TestState from app.models.enums
|
||||||
|
from app.models.enums import TeamSide, TechniqueStatus, TestResult, TestState
|
||||||
|
|
||||||
|
# Import Evidence from app.models.evidence
|
||||||
|
from app.models.evidence import Evidence
|
||||||
|
|
||||||
|
# Import IntelItem from app.models.intel
|
||||||
|
from app.models.intel import IntelItem
|
||||||
|
|
||||||
|
# Import JiraLink, JiraLinkEntityType, JiraSyncDirection from app.models.jira_link
|
||||||
|
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||||
|
|
||||||
|
# Import Notification from app.models.notification
|
||||||
|
from app.models.notification import Notification
|
||||||
|
|
||||||
|
# Import OsintItem from app.models.osint_item
|
||||||
|
from app.models.osint_item import OsintItem
|
||||||
|
|
||||||
|
# Import ScoringConfig from app.models.scoring_config
|
||||||
|
from app.models.scoring_config import ScoringConfig
|
||||||
|
|
||||||
|
# Import Technique from app.models.technique
|
||||||
|
from app.models.technique import Technique
|
||||||
|
|
||||||
|
# Import Test from app.models.test
|
||||||
|
from app.models.test import Test
|
||||||
|
|
||||||
|
# Import TestDetectionResult from app.models.test_detection_result
|
||||||
|
from app.models.test_detection_result import TestDetectionResult
|
||||||
|
|
||||||
|
# Import TestTemplate from app.models.test_template
|
||||||
|
from app.models.test_template import TestTemplate
|
||||||
|
|
||||||
|
# Import TestTemplateDetectionRule from app.models.test_template_detection_rule
|
||||||
|
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||||
|
|
||||||
|
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
|
||||||
|
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import Worklog from app.models.worklog
|
||||||
|
from app.models.worklog import Worklog
|
||||||
|
|
||||||
|
# Assign __all__ = [
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Literal argument value
|
||||||
"User", "Technique", "Test", "TestTemplate", "Evidence",
|
"User", "Technique", "Test", "TestTemplate", "Evidence",
|
||||||
|
# Literal argument value
|
||||||
"IntelItem", "AuditLog", "Notification", "DataSource",
|
"IntelItem", "AuditLog", "Notification", "DataSource",
|
||||||
|
# Literal argument value
|
||||||
"DetectionRule", "ThreatActor", "ThreatActorTechnique",
|
"DetectionRule", "ThreatActor", "ThreatActorTechnique",
|
||||||
|
# Literal argument value
|
||||||
"DefensiveTechnique", "DefensiveTechniqueMapping",
|
"DefensiveTechnique", "DefensiveTechniqueMapping",
|
||||||
|
# Literal argument value
|
||||||
"TestTemplateDetectionRule", "TestDetectionResult",
|
"TestTemplateDetectionRule", "TestDetectionResult",
|
||||||
|
# Literal argument value
|
||||||
"Campaign", "CampaignTest",
|
"Campaign", "CampaignTest",
|
||||||
|
# Literal argument value
|
||||||
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
|
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
|
||||||
|
# Literal argument value
|
||||||
"CoverageSnapshot", "SnapshotTechniqueState",
|
"CoverageSnapshot", "SnapshotTechniqueState",
|
||||||
|
# Literal argument value
|
||||||
"JiraLink", "JiraLinkEntityType", "JiraSyncDirection",
|
"JiraLink", "JiraLinkEntityType", "JiraSyncDirection",
|
||||||
|
# Literal argument value
|
||||||
"Worklog", "OsintItem", "ScoringConfig",
|
"Worklog", "OsintItem", "ScoringConfig",
|
||||||
|
# Literal argument value
|
||||||
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
||||||
"WebhookConfig", "SystemConfig",
|
"WebhookConfig", "SystemConfig",
|
||||||
"DetectionAsset", "DetectionTechniqueMapping", "DetectionValidation",
|
"DetectionAsset", "DetectionTechniqueMapping", "DetectionValidation",
|
||||||
|
|||||||
@@ -1,35 +1,58 @@
|
|||||||
|
"""SQLAlchemy model for the audit log table."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, DateTime, ForeignKey, Index, func
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
# Import Column, DateTime, ForeignKey, Index, String, func from sqlalchemy
|
||||||
|
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, func
|
||||||
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class AuditLog
|
||||||
class AuditLog(Base):
|
class AuditLog(Base):
|
||||||
"""
|
"""Audit log model for tracking all system actions.
|
||||||
Audit log model for tracking all system actions.
|
|
||||||
|
|
||||||
Records user actions, entity changes, and system events
|
Records user actions, entity changes, and system events
|
||||||
for security auditing and compliance purposes.
|
for security auditing and compliance purposes.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "audit_logs"
|
||||||
__tablename__ = "audit_logs"
|
__tablename__ = "audit_logs"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
|
# Assign action = Column(String, nullable=False)
|
||||||
action = Column(String, nullable=False)
|
action = Column(String, nullable=False)
|
||||||
|
# Assign entity_type = Column(String, nullable=True)
|
||||||
entity_type = Column(String, nullable=True)
|
entity_type = Column(String, nullable=True)
|
||||||
|
# Assign entity_id = Column(String, nullable=True)
|
||||||
entity_id = Column(String, nullable=True)
|
entity_id = Column(String, nullable=True)
|
||||||
|
# Assign timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
# Assign details = Column(JSONB, nullable=True)
|
||||||
details = Column(JSONB, nullable=True)
|
details = Column(JSONB, nullable=True)
|
||||||
|
# Assign ip_address = Column(String(45), nullable=True)
|
||||||
ip_address = Column(String(45), nullable=True)
|
ip_address = Column(String(45), nullable=True)
|
||||||
|
# Assign user_agent = Column(String(500), nullable=True)
|
||||||
user_agent = Column(String(500), nullable=True)
|
user_agent = Column(String(500), nullable=True)
|
||||||
|
# Assign integrity_hash = Column(String(64), nullable=True)
|
||||||
integrity_hash = Column(String(64), nullable=True)
|
integrity_hash = Column(String(64), nullable=True)
|
||||||
|
# Assign session_id = Column(String(100), nullable=True)
|
||||||
session_id = Column(String(100), nullable=True)
|
session_id = Column(String(100), nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
user = relationship("User")
|
user = relationship("User")
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_audit_logs_entity", "entity_type", "entity_id"),
|
Index("ix_audit_logs_entity", "entity_type", "entity_id"),
|
||||||
Index("ix_audit_logs_timestamp", "timestamp"),
|
Index("ix_audit_logs_timestamp", "timestamp"),
|
||||||
|
|||||||
@@ -4,20 +4,35 @@ Campaigns group multiple tests into a kill chain sequence,
|
|||||||
enabling simulation of complete attack chains and APT emulations.
|
enabling simulation of complete attack chains and APT emulations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import from sqlalchemy
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column, String, Text, Integer, Boolean, DateTime,
|
Boolean,
|
||||||
ForeignKey, Index, func,
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
func,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class Campaign
|
||||||
class Campaign(Base):
|
class Campaign(Base):
|
||||||
"""
|
"""A campaign groups multiple tests into a sequenced attack chain.
|
||||||
A campaign groups multiple tests into a sequenced attack chain.
|
|
||||||
|
|
||||||
Types:
|
Types:
|
||||||
- custom: manually created campaign
|
- custom: manually created campaign
|
||||||
@@ -31,62 +46,97 @@ class Campaign(Base):
|
|||||||
- completed: all tests done
|
- completed: all tests done
|
||||||
- archived: historical record
|
- archived: historical record
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "campaigns"
|
||||||
__tablename__ = "campaigns"
|
__tablename__ = "campaigns"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign name = Column(String, nullable=False)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign type = Column(String, nullable=False, default="custom") # custom, ap...
|
||||||
type = Column(String, nullable=False, default="custom") # custom, apt_emulation, kill_chain, compliance
|
type = Column(String, nullable=False, default="custom") # custom, apt_emulation, kill_chain, compliance
|
||||||
|
# Assign threat_actor_id = Column(
|
||||||
threat_actor_id = Column(
|
threat_actor_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("threat_actors.id", ondelete="SET NULL"),
|
ForeignKey("threat_actors.id", ondelete="SET NULL"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
|
# Assign status = Column(String, nullable=False, default="draft") # draft, activ...
|
||||||
status = Column(String, nullable=False, default="draft") # draft, active, completed, archived
|
status = Column(String, nullable=False, default="draft") # draft, active, completed, archived
|
||||||
|
# Assign created_by = Column(
|
||||||
created_by = Column(
|
created_by = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("users.id", ondelete="SET NULL"),
|
ForeignKey("users.id", ondelete="SET NULL"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
start_date = Column(DateTime, nullable=True) # campaign won't activate before this date
|
start_date = Column(DateTime, nullable=True) # campaign won't activate before this date
|
||||||
scheduled_at = Column(DateTime, nullable=True)
|
scheduled_at = Column(DateTime, nullable=True)
|
||||||
|
# Assign completed_at = Column(DateTime, nullable=True)
|
||||||
completed_at = Column(DateTime, nullable=True)
|
completed_at = Column(DateTime, nullable=True)
|
||||||
|
# Assign target_platform = Column(String, nullable=True)
|
||||||
target_platform = Column(String, nullable=True)
|
target_platform = Column(String, nullable=True)
|
||||||
|
# Assign tags = Column(JSONB, nullable=True, default=[])
|
||||||
tags = Column(JSONB, nullable=True, default=[])
|
tags = Column(JSONB, nullable=True, default=[])
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
# Assign data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||||
data_classification = Column(String(20), nullable=False, server_default="internal")
|
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||||
|
|
||||||
# Recurring scheduling fields
|
# Recurring scheduling fields
|
||||||
is_recurring = Column(Boolean, default=False)
|
is_recurring = Column(Boolean, default=False)
|
||||||
|
# Assign recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly
|
||||||
recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly
|
recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly
|
||||||
|
# Assign next_run_at = Column(DateTime, nullable=True)
|
||||||
next_run_at = Column(DateTime, nullable=True)
|
next_run_at = Column(DateTime, nullable=True)
|
||||||
|
# Assign last_run_at = Column(DateTime, nullable=True)
|
||||||
last_run_at = Column(DateTime, nullable=True)
|
last_run_at = Column(DateTime, nullable=True)
|
||||||
|
# Assign parent_campaign_id = Column(
|
||||||
parent_campaign_id = Column(
|
parent_campaign_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("campaigns.id", ondelete="SET NULL"),
|
ForeignKey("campaigns.id", ondelete="SET NULL"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
threat_actor = relationship("ThreatActor")
|
threat_actor = relationship("ThreatActor")
|
||||||
|
# Assign creator = relationship("User", foreign_keys=[created_by])
|
||||||
creator = relationship("User", foreign_keys=[created_by])
|
creator = relationship("User", foreign_keys=[created_by])
|
||||||
|
# Assign campaign_tests = relationship(
|
||||||
campaign_tests = relationship(
|
campaign_tests = relationship(
|
||||||
|
# Literal argument value
|
||||||
"CampaignTest",
|
"CampaignTest",
|
||||||
|
# Keyword argument: back_populates
|
||||||
back_populates="campaign",
|
back_populates="campaign",
|
||||||
|
# Keyword argument: cascade
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
|
# Keyword argument: order_by
|
||||||
order_by="CampaignTest.order_index",
|
order_by="CampaignTest.order_index",
|
||||||
)
|
)
|
||||||
|
# Assign parent_campaign = relationship(
|
||||||
parent_campaign = relationship(
|
parent_campaign = relationship(
|
||||||
|
# Literal argument value
|
||||||
"Campaign",
|
"Campaign",
|
||||||
|
# Keyword argument: remote_side
|
||||||
remote_side="Campaign.id",
|
remote_side="Campaign.id",
|
||||||
|
# Keyword argument: foreign_keys
|
||||||
foreign_keys=[parent_campaign_id],
|
foreign_keys=[parent_campaign_id],
|
||||||
)
|
)
|
||||||
|
# Assign child_campaigns = relationship(
|
||||||
child_campaigns = relationship(
|
child_campaigns = relationship(
|
||||||
|
# Literal argument value
|
||||||
"Campaign",
|
"Campaign",
|
||||||
|
# Keyword argument: foreign_keys
|
||||||
foreign_keys=[parent_campaign_id],
|
foreign_keys=[parent_campaign_id],
|
||||||
|
# Keyword argument: back_populates
|
||||||
back_populates="parent_campaign",
|
back_populates="parent_campaign",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_campaigns_status', 'status'),
|
Index('ix_campaigns_status', 'status'),
|
||||||
Index('ix_campaigns_type', 'type'),
|
Index('ix_campaigns_type', 'type'),
|
||||||
@@ -98,56 +148,83 @@ class Campaign(Base):
|
|||||||
|
|
||||||
# Kill chain phases in order (for sorting and validation)
|
# Kill chain phases in order (for sorting and validation)
|
||||||
KILL_CHAIN_PHASES = [
|
KILL_CHAIN_PHASES = [
|
||||||
|
# Literal argument value
|
||||||
"reconnaissance",
|
"reconnaissance",
|
||||||
|
# Literal argument value
|
||||||
"resource_development",
|
"resource_development",
|
||||||
|
# Literal argument value
|
||||||
"initial_access",
|
"initial_access",
|
||||||
|
# Literal argument value
|
||||||
"execution",
|
"execution",
|
||||||
|
# Literal argument value
|
||||||
"persistence",
|
"persistence",
|
||||||
|
# Literal argument value
|
||||||
"privilege_escalation",
|
"privilege_escalation",
|
||||||
|
# Literal argument value
|
||||||
"defense_evasion",
|
"defense_evasion",
|
||||||
|
# Literal argument value
|
||||||
"credential_access",
|
"credential_access",
|
||||||
|
# Literal argument value
|
||||||
"discovery",
|
"discovery",
|
||||||
|
# Literal argument value
|
||||||
"lateral_movement",
|
"lateral_movement",
|
||||||
|
# Literal argument value
|
||||||
"collection",
|
"collection",
|
||||||
|
# Literal argument value
|
||||||
"command_and_control",
|
"command_and_control",
|
||||||
|
# Literal argument value
|
||||||
"exfiltration",
|
"exfiltration",
|
||||||
|
# Literal argument value
|
||||||
"impact",
|
"impact",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Define class CampaignTest
|
||||||
class CampaignTest(Base):
|
class CampaignTest(Base):
|
||||||
"""
|
"""A test within a campaign, with ordering and dependency information.
|
||||||
A test within a campaign, with ordering and dependency information.
|
|
||||||
|
|
||||||
``depends_on`` creates a self-referential chain (A -> B -> C).
|
``depends_on`` creates a self-referential chain (A -> B -> C).
|
||||||
Circular dependencies are validated at the service layer.
|
Circular dependencies are validated at the service layer.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "campaign_tests"
|
||||||
__tablename__ = "campaign_tests"
|
__tablename__ = "campaign_tests"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign campaign_id = Column(
|
||||||
campaign_id = Column(
|
campaign_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("campaigns.id", ondelete="CASCADE"),
|
ForeignKey("campaigns.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign test_id = Column(
|
||||||
test_id = Column(
|
test_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("tests.id", ondelete="CASCADE"),
|
ForeignKey("tests.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign order_index = Column(Integer, nullable=False, default=0)
|
||||||
order_index = Column(Integer, nullable=False, default=0)
|
order_index = Column(Integer, nullable=False, default=0)
|
||||||
|
# Assign depends_on = Column(
|
||||||
depends_on = Column(
|
depends_on = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("campaign_tests.id", ondelete="SET NULL"),
|
ForeignKey("campaign_tests.id", ondelete="SET NULL"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
|
# Assign phase = Column(String, nullable=True) # kill chain phase
|
||||||
phase = Column(String, nullable=True) # kill chain phase
|
phase = Column(String, nullable=True) # kill chain phase
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
campaign = relationship("Campaign", back_populates="campaign_tests")
|
campaign = relationship("Campaign", back_populates="campaign_tests")
|
||||||
|
# Assign test = relationship("Test")
|
||||||
test = relationship("Test")
|
test = relationship("Test")
|
||||||
|
# Assign dependency = relationship("CampaignTest", remote_side="CampaignTest.id")
|
||||||
dependency = relationship("CampaignTest", remote_side="CampaignTest.id")
|
dependency = relationship("CampaignTest", remote_side="CampaignTest.id")
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_campaign_tests_campaign', 'campaign_id'),
|
Index('ix_campaign_tests_campaign', 'campaign_id'),
|
||||||
Index('ix_campaign_tests_test', 'test_id'),
|
Index('ix_campaign_tests_test', 'test_id'),
|
||||||
|
|||||||
@@ -4,92 +4,145 @@ Maps compliance frameworks (NIST 800-53, DORA, NIS2, ISO 27001) to
|
|||||||
MITRE ATT&CK techniques, enabling compliance gap analysis.
|
MITRE ATT&CK techniques, enabling compliance gap analysis.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import from sqlalchemy
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column, String, Text, Boolean, DateTime,
|
Boolean,
|
||||||
ForeignKey, Index, UniqueConstraint, func,
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class ComplianceFramework
|
||||||
class ComplianceFramework(Base):
|
class ComplianceFramework(Base):
|
||||||
"""A compliance framework (e.g. NIST 800-53, ISO 27001)."""
|
"""A compliance framework (e.g. NIST 800-53, ISO 27001)."""
|
||||||
|
# Assign __tablename__ = "compliance_frameworks"
|
||||||
__tablename__ = "compliance_frameworks"
|
__tablename__ = "compliance_frameworks"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign name = Column(String, unique=True, nullable=False)
|
||||||
name = Column(String, unique=True, nullable=False)
|
name = Column(String, unique=True, nullable=False)
|
||||||
|
# Assign version = Column(String, nullable=True)
|
||||||
version = Column(String, nullable=True)
|
version = Column(String, nullable=True)
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign url = Column(String, nullable=True)
|
||||||
url = Column(String, nullable=True)
|
url = Column(String, nullable=True)
|
||||||
|
# Assign is_active = Column(Boolean, default=True)
|
||||||
is_active = Column(Boolean, default=True)
|
is_active = Column(Boolean, default=True)
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
controls = relationship(
|
controls = relationship(
|
||||||
|
# Literal argument value
|
||||||
"ComplianceControl",
|
"ComplianceControl",
|
||||||
|
# Keyword argument: back_populates
|
||||||
back_populates="framework",
|
back_populates="framework",
|
||||||
|
# Keyword argument: cascade
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Define class ComplianceControl
|
||||||
class ComplianceControl(Base):
|
class ComplianceControl(Base):
|
||||||
"""A control within a compliance framework (e.g. AC-2, PR.AC-1)."""
|
"""A control within a compliance framework (e.g. AC-2, PR.AC-1)."""
|
||||||
|
# Assign __tablename__ = "compliance_controls"
|
||||||
__tablename__ = "compliance_controls"
|
__tablename__ = "compliance_controls"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign framework_id = Column(
|
||||||
framework_id = Column(
|
framework_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("compliance_frameworks.id", ondelete="CASCADE"),
|
ForeignKey("compliance_frameworks.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign control_id = Column(String, nullable=False) # e.g. "AC-2"
|
||||||
control_id = Column(String, nullable=False) # e.g. "AC-2"
|
control_id = Column(String, nullable=False) # e.g. "AC-2"
|
||||||
|
# Assign title = Column(String, nullable=False)
|
||||||
title = Column(String, nullable=False)
|
title = Column(String, nullable=False)
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign category = Column(String, nullable=True)
|
||||||
category = Column(String, nullable=True)
|
category = Column(String, nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
framework = relationship("ComplianceFramework", back_populates="controls")
|
framework = relationship("ComplianceFramework", back_populates="controls")
|
||||||
|
# Assign technique_mappings = relationship(
|
||||||
technique_mappings = relationship(
|
technique_mappings = relationship(
|
||||||
|
# Literal argument value
|
||||||
"ComplianceControlMapping",
|
"ComplianceControlMapping",
|
||||||
|
# Keyword argument: back_populates
|
||||||
back_populates="compliance_control",
|
back_populates="compliance_control",
|
||||||
|
# Keyword argument: cascade
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_compliance_controls_framework', 'framework_id'),
|
Index('ix_compliance_controls_framework', 'framework_id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Define class ComplianceControlMapping
|
||||||
class ComplianceControlMapping(Base):
|
class ComplianceControlMapping(Base):
|
||||||
"""Maps a compliance control to a MITRE ATT&CK technique."""
|
"""Maps a compliance control to a MITRE ATT&CK technique."""
|
||||||
|
# Assign __tablename__ = "compliance_control_mappings"
|
||||||
__tablename__ = "compliance_control_mappings"
|
__tablename__ = "compliance_control_mappings"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign compliance_control_id = Column(
|
||||||
compliance_control_id = Column(
|
compliance_control_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("compliance_controls.id", ondelete="CASCADE"),
|
ForeignKey("compliance_controls.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign technique_id = Column(
|
||||||
technique_id = Column(
|
technique_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
compliance_control = relationship(
|
compliance_control = relationship(
|
||||||
|
# Literal argument value
|
||||||
"ComplianceControl", back_populates="technique_mappings"
|
"ComplianceControl", back_populates="technique_mappings"
|
||||||
)
|
)
|
||||||
|
# Assign technique = relationship("Technique")
|
||||||
technique = relationship("Technique")
|
technique = relationship("Technique")
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_compliance_mappings_control', 'compliance_control_id'),
|
Index('ix_compliance_mappings_control', 'compliance_control_id'),
|
||||||
Index('ix_compliance_mappings_technique', 'technique_id'),
|
Index('ix_compliance_mappings_technique', 'technique_id'),
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
|
# Literal argument value
|
||||||
'compliance_control_id', 'technique_id',
|
'compliance_control_id', 'technique_id',
|
||||||
|
# Keyword argument: name
|
||||||
name='uq_control_technique',
|
name='uq_control_technique',
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,76 +5,125 @@ SnapshotTechniqueState stores per-technique state (normalized, one row
|
|||||||
per technique per snapshot) to avoid bloated JSONB fields.
|
per technique per snapshot) to avoid bloated JSONB fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import from sqlalchemy
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column, String, Float, Integer, DateTime,
|
Column,
|
||||||
ForeignKey, Index, func,
|
DateTime,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class CoverageSnapshot
|
||||||
class CoverageSnapshot(Base):
|
class CoverageSnapshot(Base):
|
||||||
"""A point-in-time snapshot of the organisation's overall coverage."""
|
"""A point-in-time snapshot of the organisation's overall coverage."""
|
||||||
|
|
||||||
|
# Assign __tablename__ = "coverage_snapshots"
|
||||||
__tablename__ = "coverage_snapshots"
|
__tablename__ = "coverage_snapshots"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
|
||||||
name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
|
name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
|
||||||
|
# Assign organization_score = Column(Float, nullable=False)
|
||||||
organization_score = Column(Float, nullable=False)
|
organization_score = Column(Float, nullable=False)
|
||||||
|
# Assign total_techniques = Column(Integer, nullable=False)
|
||||||
total_techniques = Column(Integer, nullable=False)
|
total_techniques = Column(Integer, nullable=False)
|
||||||
|
# Assign validated_count = Column(Integer, nullable=False)
|
||||||
validated_count = Column(Integer, nullable=False)
|
validated_count = Column(Integer, nullable=False)
|
||||||
|
# Assign partial_count = Column(Integer, nullable=False)
|
||||||
partial_count = Column(Integer, nullable=False)
|
partial_count = Column(Integer, nullable=False)
|
||||||
|
# Assign not_covered_count = Column(Integer, nullable=False)
|
||||||
not_covered_count = Column(Integer, nullable=False)
|
not_covered_count = Column(Integer, nullable=False)
|
||||||
|
# Assign in_progress_count = Column(Integer, nullable=False)
|
||||||
in_progress_count = Column(Integer, nullable=False)
|
in_progress_count = Column(Integer, nullable=False)
|
||||||
|
# Assign not_evaluated_count = Column(Integer, nullable=False)
|
||||||
not_evaluated_count = Column(Integer, nullable=False)
|
not_evaluated_count = Column(Integer, nullable=False)
|
||||||
|
# Assign coverage_percentage = Column(Float, nullable=False, default=0.0)
|
||||||
coverage_percentage = Column(Float, nullable=False, default=0.0)
|
coverage_percentage = Column(Float, nullable=False, default=0.0)
|
||||||
|
# Assign by_tactic = Column(JSONB, nullable=False, default=dict)
|
||||||
by_tactic = Column(JSONB, nullable=False, default=dict)
|
by_tactic = Column(JSONB, nullable=False, default=dict)
|
||||||
|
# Assign by_status = Column(JSONB, nullable=False, default=dict)
|
||||||
by_status = Column(JSONB, nullable=False, default=dict)
|
by_status = Column(JSONB, nullable=False, default=dict)
|
||||||
|
# Assign stale_count = Column(Integer, nullable=False, default=0)
|
||||||
stale_count = Column(Integer, nullable=False, default=0)
|
stale_count = Column(Integer, nullable=False, default=0)
|
||||||
|
# Assign never_tested_count = Column(Integer, nullable=False, default=0)
|
||||||
never_tested_count = Column(Integer, nullable=False, default=0)
|
never_tested_count = Column(Integer, nullable=False, default=0)
|
||||||
|
# Assign created_by = Column(
|
||||||
created_by = Column(
|
created_by = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("users.id", ondelete="SET NULL"),
|
ForeignKey("users.id", ondelete="SET NULL"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
creator = relationship("User", foreign_keys=[created_by])
|
creator = relationship("User", foreign_keys=[created_by])
|
||||||
|
# Assign technique_states = relationship(
|
||||||
technique_states = relationship(
|
technique_states = relationship(
|
||||||
|
# Literal argument value
|
||||||
"SnapshotTechniqueState",
|
"SnapshotTechniqueState",
|
||||||
|
# Keyword argument: back_populates
|
||||||
back_populates="snapshot",
|
back_populates="snapshot",
|
||||||
|
# Keyword argument: cascade
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Define class SnapshotTechniqueState
|
||||||
class SnapshotTechniqueState(Base):
|
class SnapshotTechniqueState(Base):
|
||||||
"""Per-technique state within a snapshot (normalised storage)."""
|
"""Per-technique state within a snapshot (normalised storage)."""
|
||||||
|
|
||||||
|
# Assign __tablename__ = "snapshot_technique_states"
|
||||||
__tablename__ = "snapshot_technique_states"
|
__tablename__ = "snapshot_technique_states"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign snapshot_id = Column(
|
||||||
snapshot_id = Column(
|
snapshot_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("coverage_snapshots.id", ondelete="CASCADE"),
|
ForeignKey("coverage_snapshots.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign technique_id = Column(
|
||||||
technique_id = Column(
|
technique_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign mitre_id = Column(String, nullable=False) # denormalised for fast queries
|
||||||
mitre_id = Column(String, nullable=False) # denormalised for fast queries
|
mitre_id = Column(String, nullable=False) # denormalised for fast queries
|
||||||
|
# Assign status = Column(String, nullable=False)
|
||||||
status = Column(String, nullable=False)
|
status = Column(String, nullable=False)
|
||||||
|
# Assign score = Column(Float, nullable=True)
|
||||||
score = Column(Float, nullable=True)
|
score = Column(Float, nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
snapshot = relationship("CoverageSnapshot", back_populates="technique_states")
|
snapshot = relationship("CoverageSnapshot", back_populates="technique_states")
|
||||||
|
# Assign technique = relationship("Technique")
|
||||||
technique = relationship("Technique")
|
technique = relationship("Technique")
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_snapshot_technique_states_snapshot", "snapshot_id"),
|
Index("ix_snapshot_technique_states_snapshot", "snapshot_id"),
|
||||||
Index("ix_snapshot_technique_states_technique", "technique_id"),
|
Index("ix_snapshot_technique_states_technique", "technique_id"),
|
||||||
|
|||||||
@@ -1,36 +1,56 @@
|
|||||||
"""DataSource model — registry of external data sources for import."""
|
"""DataSource model — registry of external data sources for import."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index, func
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
|
||||||
|
|
||||||
|
# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func
|
||||||
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class DataSource
|
||||||
class DataSource(Base):
|
class DataSource(Base):
|
||||||
"""
|
"""Unified registry of all external data sources.
|
||||||
Unified registry of all external data sources (attack procedures,
|
|
||||||
detection rules, threat intel, defensive techniques).
|
|
||||||
|
|
||||||
Each source can be independently enabled/disabled and tracks its own
|
Covers attack procedures, detection rules, threat intel, and defensive techniques.
|
||||||
synchronisation state.
|
Each source can be independently enabled/disabled and tracks its own synchronisation state.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "data_sources"
|
||||||
__tablename__ = "data_sources"
|
__tablename__ = "data_sources"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign name = Column(String, unique=True, nullable=False) # e.g. "atom...
|
||||||
name = Column(String, unique=True, nullable=False) # e.g. "atomic_red_team"
|
name = Column(String, unique=True, nullable=False) # e.g. "atomic_red_team"
|
||||||
|
# Assign display_name = Column(String, nullable=False) # e.g. "Atomic Red ...
|
||||||
display_name = Column(String, nullable=False) # e.g. "Atomic Red Team"
|
display_name = Column(String, nullable=False) # e.g. "Atomic Red Team"
|
||||||
type = Column(String, nullable=False) # attack_procedure / detection_rule / threat_intel / defensive_technique
|
# Values: attack_procedure / detection_rule / threat_intel / defensive_technique
|
||||||
|
type = Column(String, nullable=False)
|
||||||
|
# Assign url = Column(String, nullable=True) # URL base...
|
||||||
url = Column(String, nullable=True) # URL base of repo/API
|
url = Column(String, nullable=True) # URL base of repo/API
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign is_enabled = Column(Boolean, default=True)
|
||||||
is_enabled = Column(Boolean, default=True)
|
is_enabled = Column(Boolean, default=True)
|
||||||
|
# Assign last_sync_at = Column(DateTime, nullable=True)
|
||||||
last_sync_at = Column(DateTime, nullable=True)
|
last_sync_at = Column(DateTime, nullable=True)
|
||||||
|
# Assign last_sync_status = Column(String, nullable=True) # success / error / in_...
|
||||||
last_sync_status = Column(String, nullable=True) # success / error / in_progress
|
last_sync_status = Column(String, nullable=True) # success / error / in_progress
|
||||||
|
# Assign last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "upd...
|
||||||
last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "updated": Y, ...}
|
last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "updated": Y, ...}
|
||||||
|
# Assign sync_frequency = Column(String, nullable=True) # daily / weekly / mo...
|
||||||
sync_frequency = Column(String, nullable=True) # daily / weekly / monthly / manual
|
sync_frequency = Column(String, nullable=True) # daily / weekly / monthly / manual
|
||||||
|
# Assign config = Column(JSONB, nullable=True) # source-spec...
|
||||||
config = Column(JSONB, nullable=True) # source-specific configuration
|
config = Column(JSONB, nullable=True) # source-specific configuration
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_data_sources_type', 'type'),
|
Index('ix_data_sources_type', 'type'),
|
||||||
Index('ix_data_sources_is_enabled', 'is_enabled'),
|
Index('ix_data_sources_is_enabled', 'is_enabled'),
|
||||||
|
|||||||
@@ -4,74 +4,108 @@ Stores MITRE D3FEND defensive techniques and their mappings to
|
|||||||
ATT&CK techniques, enabling recommended countermeasure lookups.
|
ATT&CK techniques, enabling recommended countermeasure lookups.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import from sqlalchemy
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column, String, Text, DateTime,
|
Column,
|
||||||
ForeignKey, Index, UniqueConstraint, func,
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class DefensiveTechnique
|
||||||
class DefensiveTechnique(Base):
|
class DefensiveTechnique(Base):
|
||||||
"""
|
"""MITRE D3FEND defensive technique.
|
||||||
MITRE D3FEND defensive technique.
|
|
||||||
|
|
||||||
Represents a countermeasure from the D3FEND framework that can be
|
Represents a countermeasure from the D3FEND framework that can be
|
||||||
mapped to one or more ATT&CK techniques via DefensiveTechniqueMapping.
|
mapped to one or more ATT&CK techniques via DefensiveTechniqueMapping.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "defensive_techniques"
|
||||||
__tablename__ = "defensive_techniques"
|
__tablename__ = "defensive_techniques"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL"
|
||||||
d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL"
|
d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL"
|
||||||
|
# Assign name = Column(String, nullable=False)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign tactic = Column(String, nullable=True) # Detect, ...
|
||||||
tactic = Column(String, nullable=True) # Detect, Isolate, Deceive, Evict, etc.
|
tactic = Column(String, nullable=True) # Detect, Isolate, Deceive, Evict, etc.
|
||||||
|
# Assign d3fend_url = Column(String, nullable=True)
|
||||||
d3fend_url = Column(String, nullable=True)
|
d3fend_url = Column(String, nullable=True)
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
attack_mappings = relationship(
|
attack_mappings = relationship(
|
||||||
|
# Literal argument value
|
||||||
"DefensiveTechniqueMapping",
|
"DefensiveTechniqueMapping",
|
||||||
|
# Keyword argument: back_populates
|
||||||
back_populates="defensive_technique",
|
back_populates="defensive_technique",
|
||||||
|
# Keyword argument: cascade
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_defensive_techniques_tactic', 'tactic'),
|
Index('ix_defensive_techniques_tactic', 'tactic'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Define class DefensiveTechniqueMapping
|
||||||
class DefensiveTechniqueMapping(Base):
|
class DefensiveTechniqueMapping(Base):
|
||||||
"""
|
"""Association between a MITRE ATT&CK technique and a D3FEND defensive technique."""
|
||||||
Association between a MITRE ATT&CK technique and a D3FEND
|
# Assign __tablename__ = "defensive_technique_mappings"
|
||||||
defensive technique.
|
|
||||||
"""
|
|
||||||
__tablename__ = "defensive_technique_mappings"
|
__tablename__ = "defensive_technique_mappings"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign attack_technique_id = Column(
|
||||||
attack_technique_id = Column(
|
attack_technique_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign defensive_technique_id = Column(
|
||||||
defensive_technique_id = Column(
|
defensive_technique_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("defensive_techniques.id", ondelete="CASCADE"),
|
ForeignKey("defensive_techniques.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
attack_technique = relationship("Technique")
|
attack_technique = relationship("Technique")
|
||||||
|
# Assign defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings")
|
||||||
defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings")
|
defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings")
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_dtm_attack_technique', 'attack_technique_id'),
|
Index('ix_dtm_attack_technique', 'attack_technique_id'),
|
||||||
Index('ix_dtm_defensive_technique', 'defensive_technique_id'),
|
Index('ix_dtm_defensive_technique', 'defensive_technique_id'),
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
|
# Literal argument value
|
||||||
'attack_technique_id', 'defensive_technique_id',
|
'attack_technique_id', 'defensive_technique_id',
|
||||||
|
# Keyword argument: name
|
||||||
name='uq_attack_defensive_technique',
|
name='uq_attack_defensive_technique',
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,38 +1,61 @@
|
|||||||
"""DetectionRule model — detection rules from multiple sources."""
|
"""DetectionRule model — detection rules from multiple sources."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index, func
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
|
||||||
|
|
||||||
|
# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func
|
||||||
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class DetectionRule
|
||||||
class DetectionRule(Base):
|
class DetectionRule(Base):
|
||||||
"""
|
"""Detection rule from an external source (Sigma, Elastic, Splunk, custom).
|
||||||
Detection rule from an external source (Sigma, Elastic, Splunk, custom).
|
|
||||||
|
|
||||||
Each rule is mapped to one MITRE ATT&CK technique via
|
Each rule is mapped to one MITRE ATT&CK technique via
|
||||||
``mitre_technique_id`` and stores the complete rule content in
|
``mitre_technique_id`` and stores the complete rule content in
|
||||||
``rule_content``.
|
``rule_content``.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "detection_rules"
|
||||||
__tablename__ = "detection_rules"
|
__tablename__ = "detection_rules"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
||||||
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
||||||
|
# Assign title = Column(String, nullable=False)
|
||||||
title = Column(String, nullable=False)
|
title = Column(String, nullable=False)
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign source = Column(String, nullable=False) # sigma / ela...
|
||||||
source = Column(String, nullable=False) # sigma / elastic / splunk / custom
|
source = Column(String, nullable=False) # sigma / elastic / splunk / custom
|
||||||
|
# Assign source_id = Column(String, nullable=True) # ID in the sour...
|
||||||
source_id = Column(String, nullable=True) # ID in the source repo (for dedup)
|
source_id = Column(String, nullable=True) # ID in the source repo (for dedup)
|
||||||
|
# Assign source_url = Column(String, nullable=True)
|
||||||
source_url = Column(String, nullable=True)
|
source_url = Column(String, nullable=True)
|
||||||
|
# Assign rule_content = Column(Text, nullable=False) # YAML / KQL / SPL ...
|
||||||
rule_content = Column(Text, nullable=False) # YAML / KQL / SPL content
|
rule_content = Column(Text, nullable=False) # YAML / KQL / SPL content
|
||||||
|
# Assign rule_format = Column(String, nullable=False) # sigma_yaml / kql...
|
||||||
rule_format = Column(String, nullable=False) # sigma_yaml / kql / spl / custom
|
rule_format = Column(String, nullable=False) # sigma_yaml / kql / spl / custom
|
||||||
|
# Assign severity = Column(String, nullable=True) # informational...
|
||||||
severity = Column(String, nullable=True) # informational / low / medium / high / critical
|
severity = Column(String, nullable=True) # informational / low / medium / high / critical
|
||||||
|
# Assign platforms = Column(JSONB, nullable=True, default=[])
|
||||||
platforms = Column(JSONB, nullable=True, default=[])
|
platforms = Column(JSONB, nullable=True, default=[])
|
||||||
|
# Assign log_sources = Column(JSONB, nullable=True) # e.g. {"product":...
|
||||||
log_sources = Column(JSONB, nullable=True) # e.g. {"product": "windows", "service": "sysmon"}
|
log_sources = Column(JSONB, nullable=True) # e.g. {"product": "windows", "service": "sysmon"}
|
||||||
|
# Assign false_positive_rate = Column(String, nullable=True) # low / medium / high
|
||||||
false_positive_rate = Column(String, nullable=True) # low / medium / high
|
false_positive_rate = Column(String, nullable=True) # low / medium / high
|
||||||
|
# Assign is_active = Column(Boolean, default=True)
|
||||||
is_active = Column(Boolean, default=True)
|
is_active = Column(Boolean, default=True)
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_detection_rules_mitre_technique_id', 'mitre_technique_id'),
|
Index('ix_detection_rules_mitre_technique_id', 'mitre_technique_id'),
|
||||||
Index('ix_detection_rules_source', 'source'),
|
Index('ix_detection_rules_source', 'source'),
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ re-exports every enum so that existing model and router code keeps
|
|||||||
working with ``from app.models.enums import ...``.
|
working with ``from app.models.enums import ...``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import # noqa: F401 from app.domain.enums
|
||||||
from app.domain.enums import ( # noqa: F401
|
from app.domain.enums import ( # noqa: F401
|
||||||
DataClassification,
|
DataClassification,
|
||||||
TeamSide,
|
TeamSide,
|
||||||
|
|||||||
@@ -1,35 +1,59 @@
|
|||||||
|
"""SQLAlchemy model for the evidence table."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum, func
|
|
||||||
|
# Import Column, DateTime, Enum, ForeignKey, String, Tex... from sqlalchemy
|
||||||
|
from sqlalchemy import Column, DateTime, Enum, ForeignKey, String, Text, func
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
# Import TeamSide from app.models.enums
|
||||||
from app.models.enums import TeamSide
|
from app.models.enums import TeamSide
|
||||||
|
|
||||||
|
|
||||||
|
# Define class Evidence
|
||||||
class Evidence(Base):
|
class Evidence(Base):
|
||||||
"""
|
"""Evidence model for storing file metadata associated with tests.
|
||||||
Evidence model for storing file metadata associated with tests.
|
|
||||||
|
|
||||||
Files are stored in MinIO, and this model tracks the file location,
|
Files are stored in MinIO, and this model tracks the file location,
|
||||||
integrity hash, and upload metadata.
|
integrity hash, and upload metadata.
|
||||||
|
|
||||||
The ``team`` field distinguishes whether this evidence was uploaded by
|
The ``team`` field distinguishes whether this evidence was uploaded by
|
||||||
Red Team (attack evidence) or Blue Team (detection evidence).
|
Red Team (attack evidence) or Blue Team (detection evidence).
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "evidences"
|
||||||
__tablename__ = "evidences"
|
__tablename__ = "evidences"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False)
|
||||||
test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False)
|
test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False)
|
||||||
|
# Assign file_name = Column(String, nullable=False)
|
||||||
file_name = Column(String, nullable=False)
|
file_name = Column(String, nullable=False)
|
||||||
|
# Assign file_path = Column(String, nullable=False) # Path in MinIO
|
||||||
file_path = Column(String, nullable=False) # Path in MinIO
|
file_path = Column(String, nullable=False) # Path in MinIO
|
||||||
|
# Assign sha256_hash = Column(String, nullable=False)
|
||||||
sha256_hash = Column(String, nullable=False)
|
sha256_hash = Column(String, nullable=False)
|
||||||
|
# Assign uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
|
# Assign uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
|
uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
# Assign team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=Tea...
|
||||||
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
|
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
|
||||||
|
# Assign notes = Column(Text, nullable=True)
|
||||||
notes = Column(Text, nullable=True)
|
notes = Column(Text, nullable=True)
|
||||||
|
# Assign data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||||
data_classification = Column(String(20), nullable=False, server_default="internal")
|
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
test = relationship("Test", back_populates="evidences")
|
test = relationship("Test", back_populates="evidences")
|
||||||
|
# Assign uploader = relationship("User", foreign_keys=[uploaded_by])
|
||||||
uploader = relationship("User", foreign_keys=[uploaded_by])
|
uploader = relationship("User", foreign_keys=[uploaded_by])
|
||||||
|
|||||||
@@ -1,26 +1,44 @@
|
|||||||
|
"""SQLAlchemy model for the intel_items table."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, func
|
|
||||||
|
# Import Boolean, Column, DateTime, ForeignKey, String, ... from sqlalchemy
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, func
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class IntelItem
|
||||||
class IntelItem(Base):
|
class IntelItem(Base):
|
||||||
"""
|
"""Intelligence item model for tracking threat intelligence related to techniques.
|
||||||
Intelligence item model for tracking threat intelligence related to techniques.
|
|
||||||
|
|
||||||
Stores URLs and metadata from automated intel scans that may indicate
|
Stores URLs and metadata from automated intel scans that may indicate
|
||||||
new attack variations or detection bypasses for specific techniques.
|
new attack variations or detection bypasses for specific techniques.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "intel_items"
|
||||||
__tablename__ = "intel_items"
|
__tablename__ = "intel_items"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True)
|
||||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True)
|
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True)
|
||||||
|
# Assign url = Column(String, nullable=False)
|
||||||
url = Column(String, nullable=False)
|
url = Column(String, nullable=False)
|
||||||
|
# Assign title = Column(String, nullable=True)
|
||||||
title = Column(String, nullable=True)
|
title = Column(String, nullable=True)
|
||||||
|
# Assign source = Column(String, nullable=True)
|
||||||
source = Column(String, nullable=True)
|
source = Column(String, nullable=True)
|
||||||
|
# Assign detected_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
detected_at = Column(DateTime(timezone=True), server_default=func.now())
|
detected_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
# Assign reviewed = Column(Boolean, default=False)
|
||||||
reviewed = Column(Boolean, default=False)
|
reviewed = Column(Boolean, default=False)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
|
|||||||
@@ -1,53 +1,99 @@
|
|||||||
"""Jira integration models — link Aegis entities to Jira issues."""
|
"""Jira integration models — link Aegis entities to Jira issues."""
|
||||||
|
|
||||||
|
# Import enum
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, DateTime, ForeignKey, Enum as SQLEnum, Index, func
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
# Import Column, DateTime, ForeignKey, Index, String, func from sqlalchemy
|
||||||
|
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, func
|
||||||
|
|
||||||
|
# Import Enum as SQLEnum from sqlalchemy
|
||||||
|
from sqlalchemy import Enum as SQLEnum
|
||||||
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class JiraLinkEntityType
|
||||||
class JiraLinkEntityType(str, enum.Enum):
|
class JiraLinkEntityType(str, enum.Enum):
|
||||||
|
"""Aegis entity types that can be linked to a Jira issue."""
|
||||||
|
|
||||||
|
# Assign test = "test"
|
||||||
test = "test"
|
test = "test"
|
||||||
|
# Assign technique = "technique"
|
||||||
technique = "technique"
|
technique = "technique"
|
||||||
|
# Assign campaign = "campaign"
|
||||||
campaign = "campaign"
|
campaign = "campaign"
|
||||||
|
# Assign evidence = "evidence"
|
||||||
evidence = "evidence"
|
evidence = "evidence"
|
||||||
|
|
||||||
|
|
||||||
|
# Define class JiraSyncDirection
|
||||||
class JiraSyncDirection(str, enum.Enum):
|
class JiraSyncDirection(str, enum.Enum):
|
||||||
|
"""Direction of synchronisation between Aegis and Jira."""
|
||||||
|
|
||||||
|
# Assign aegis_to_jira = "aegis_to_jira"
|
||||||
aegis_to_jira = "aegis_to_jira"
|
aegis_to_jira = "aegis_to_jira"
|
||||||
|
# Assign jira_to_aegis = "jira_to_aegis"
|
||||||
jira_to_aegis = "jira_to_aegis"
|
jira_to_aegis = "jira_to_aegis"
|
||||||
|
# Assign bidirectional = "bidirectional"
|
||||||
bidirectional = "bidirectional"
|
bidirectional = "bidirectional"
|
||||||
|
|
||||||
|
|
||||||
|
# Define class JiraLink
|
||||||
class JiraLink(Base):
|
class JiraLink(Base):
|
||||||
"""Associates an Aegis entity with a Jira issue for bidirectional sync."""
|
"""Associates an Aegis entity with a Jira issue for bidirectional sync."""
|
||||||
|
|
||||||
|
# Assign __tablename__ = "jira_links"
|
||||||
__tablename__ = "jira_links"
|
__tablename__ = "jira_links"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False)
|
||||||
entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False)
|
entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False)
|
||||||
|
# Assign entity_id = Column(UUID(as_uuid=True), nullable=False)
|
||||||
entity_id = Column(UUID(as_uuid=True), nullable=False)
|
entity_id = Column(UUID(as_uuid=True), nullable=False)
|
||||||
|
# Assign jira_issue_key = Column(String(50), nullable=False)
|
||||||
jira_issue_key = Column(String(50), nullable=False)
|
jira_issue_key = Column(String(50), nullable=False)
|
||||||
|
# Assign jira_issue_id = Column(String(50))
|
||||||
jira_issue_id = Column(String(50))
|
jira_issue_id = Column(String(50))
|
||||||
|
# Assign jira_project_key = Column(String(20))
|
||||||
jira_project_key = Column(String(20))
|
jira_project_key = Column(String(20))
|
||||||
|
# Assign jira_status = Column(String(100))
|
||||||
jira_status = Column(String(100))
|
jira_status = Column(String(100))
|
||||||
|
# Assign jira_priority = Column(String(50))
|
||||||
jira_priority = Column(String(50))
|
jira_priority = Column(String(50))
|
||||||
|
# Assign jira_assignee = Column(String(255))
|
||||||
jira_assignee = Column(String(255))
|
jira_assignee = Column(String(255))
|
||||||
|
# Assign jira_story_points = Column(String(10))
|
||||||
jira_story_points = Column(String(10))
|
jira_story_points = Column(String(10))
|
||||||
|
# Assign sync_direction = Column(
|
||||||
sync_direction = Column(
|
sync_direction = Column(
|
||||||
SQLEnum(JiraSyncDirection), default=JiraSyncDirection.bidirectional
|
SQLEnum(JiraSyncDirection), default=JiraSyncDirection.bidirectional
|
||||||
)
|
)
|
||||||
|
# Assign last_synced_at = Column(DateTime)
|
||||||
last_synced_at = Column(DateTime)
|
last_synced_at = Column(DateTime)
|
||||||
|
# Assign sync_metadata = Column(JSONB, default={})
|
||||||
sync_metadata = Column(JSONB, default={})
|
sync_metadata = Column(JSONB, default={})
|
||||||
|
# Assign created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
|
||||||
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
|
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
# Assign updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate...
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||||
|
|
||||||
|
# Assign creator = relationship("User", foreign_keys=[created_by])
|
||||||
creator = relationship("User", foreign_keys=[created_by])
|
creator = relationship("User", foreign_keys=[created_by])
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_jira_links_entity_id", "entity_id"),
|
Index("ix_jira_links_entity_id", "entity_id"),
|
||||||
Index("ix_jira_links_issue_key", "jira_issue_key"),
|
Index("ix_jira_links_issue_key", "jira_issue_key"),
|
||||||
|
|||||||
@@ -1,35 +1,54 @@
|
|||||||
"""Notification model — in-app notifications for user actions."""
|
"""Notification model — in-app notifications for user actions."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Index, func
|
|
||||||
|
# Import Boolean, Column, DateTime, ForeignKey, Index, S... from sqlalchemy
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String, Text, func
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class Notification
|
||||||
class Notification(Base):
|
class Notification(Base):
|
||||||
"""
|
"""In-app notification for alerting users when they need to act.
|
||||||
In-app notification for alerting users when they need to act.
|
|
||||||
|
|
||||||
Types include: test_assigned, validation_needed, test_rejected,
|
Types include: test_assigned, validation_needed, test_rejected,
|
||||||
test_validated, test_state_changed, etc.
|
test_validated, test_state_changed, etc.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "notifications"
|
||||||
__tablename__ = "notifications"
|
__tablename__ = "notifications"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||||
|
# Assign type = Column(String, nullable=False)
|
||||||
type = Column(String, nullable=False)
|
type = Column(String, nullable=False)
|
||||||
|
# Assign title = Column(String, nullable=False)
|
||||||
title = Column(String, nullable=False)
|
title = Column(String, nullable=False)
|
||||||
|
# Assign message = Column(Text, nullable=True)
|
||||||
message = Column(Text, nullable=True)
|
message = Column(Text, nullable=True)
|
||||||
|
# Assign entity_type = Column(String, nullable=True)
|
||||||
entity_type = Column(String, nullable=True)
|
entity_type = Column(String, nullable=True)
|
||||||
|
# Assign entity_id = Column(UUID(as_uuid=True), nullable=True)
|
||||||
entity_id = Column(UUID(as_uuid=True), nullable=True)
|
entity_id = Column(UUID(as_uuid=True), nullable=True)
|
||||||
|
# Assign read = Column(Boolean, default=False)
|
||||||
read = Column(Boolean, default=False)
|
read = Column(Boolean, default=False)
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
user = relationship("User")
|
user = relationship("User")
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_notifications_user_id", "user_id"),
|
Index("ix_notifications_user_id", "user_id"),
|
||||||
Index("ix_notifications_read", "read"),
|
Index("ix_notifications_read", "read"),
|
||||||
|
|||||||
@@ -1,37 +1,58 @@
|
|||||||
"""OSINT enrichment items — CVEs, blogs, PoCs, and advisories linked to techniques."""
|
"""OSINT enrichment items — CVEs, blogs, PoCs, and advisories linked to techniques."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, func
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
# Import Boolean, Column, DateTime, ForeignKey, String, ... from sqlalchemy
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, Text, func
|
||||||
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class OsintItem
|
||||||
class OsintItem(Base):
|
class OsintItem(Base):
|
||||||
"""Represents an OSINT data point (CVE, blog, PoC, advisory) associated
|
"""Represents an OSINT data point (CVE, blog, PoC, advisory) associated with a MITRE ATT&CK technique.
|
||||||
with a MITRE ATT&CK technique.
|
|
||||||
|
|
||||||
Used by the enrichment pipeline to surface relevant threat intelligence
|
Used by the enrichment pipeline to surface relevant threat intelligence
|
||||||
for each technique, flagging those that need review.
|
for each technique, flagging those that need review.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Assign __tablename__ = "osint_items"
|
||||||
__tablename__ = "osint_items"
|
__tablename__ = "osint_items"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign technique_id = Column(
|
||||||
technique_id = Column(
|
technique_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("techniques.id"),
|
ForeignKey("techniques.id"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
|
# Keyword argument: index
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
# Assign source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory"
|
||||||
source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory"
|
source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory"
|
||||||
|
# Assign source_url = Column(Text, nullable=False)
|
||||||
source_url = Column(Text, nullable=False)
|
source_url = Column(Text, nullable=False)
|
||||||
|
# Assign title = Column(String(500), nullable=False)
|
||||||
title = Column(String(500), nullable=False)
|
title = Column(String(500), nullable=False)
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, U...
|
||||||
severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, UNKNOWN
|
severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, UNKNOWN
|
||||||
|
# Assign discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable...
|
||||||
discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
|
# Assign reviewed = Column(Boolean, default=False)
|
||||||
reviewed = Column(Boolean, default=False)
|
reviewed = Column(Boolean, default=False)
|
||||||
|
# Assign metadata_ = Column("metadata", JSONB, default={})
|
||||||
metadata_ = Column("metadata", JSONB, default={})
|
metadata_ = Column("metadata", JSONB, default={})
|
||||||
|
|
||||||
# ── Relationships ─────────────────────────────────────────────────
|
# ── Relationships ─────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -1,25 +1,43 @@
|
|||||||
"""ScoringConfig — single-row table for persisted scoring weights."""
|
"""ScoringConfig — single-row table for persisted scoring weights."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from sqlalchemy import Column, Float, DateTime, ForeignKey, func
|
# Import Column, DateTime, Float, ForeignKey, func from sqlalchemy
|
||||||
|
from sqlalchemy import Column, DateTime, Float, ForeignKey, func
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class ScoringConfig
|
||||||
class ScoringConfig(Base):
|
class ScoringConfig(Base):
|
||||||
|
"""Single-row table persisting the active scoring weight configuration."""
|
||||||
|
|
||||||
|
# Assign __tablename__ = "scoring_config"
|
||||||
__tablename__ = "scoring_config"
|
__tablename__ = "scoring_config"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign weight_tests = Column(Float, nullable=False, default=40.0)
|
||||||
weight_tests = Column(Float, nullable=False, default=40.0)
|
weight_tests = Column(Float, nullable=False, default=40.0)
|
||||||
|
# Assign weight_detection_rules = Column(Float, nullable=False, default=25.0)
|
||||||
weight_detection_rules = Column(Float, nullable=False, default=25.0)
|
weight_detection_rules = Column(Float, nullable=False, default=25.0)
|
||||||
|
# Assign weight_d3fend = Column(Float, nullable=False, default=15.0)
|
||||||
weight_d3fend = Column(Float, nullable=False, default=15.0)
|
weight_d3fend = Column(Float, nullable=False, default=15.0)
|
||||||
|
# Assign weight_recency = Column(Float, nullable=False, default=10.0)
|
||||||
weight_recency = Column(Float, nullable=False, default=10.0)
|
weight_recency = Column(Float, nullable=False, default=10.0)
|
||||||
|
# Assign weight_severity = Column(Float, nullable=False, default=10.0)
|
||||||
weight_severity = Column(Float, nullable=False, default=10.0)
|
weight_severity = Column(Float, nullable=False, default=10.0)
|
||||||
|
# Assign updated_by = Column(
|
||||||
updated_by = Column(
|
updated_by = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("users.id", ondelete="SET NULL"),
|
ForeignKey("users.id", ondelete="SET NULL"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
|
# Assign updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate...
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||||
|
|||||||
@@ -1,38 +1,63 @@
|
|||||||
import uuid
|
"""SQLAlchemy model for the techniques table."""
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Enum
|
# Import uuid
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
import uuid
|
||||||
|
|
||||||
|
# Import Boolean, Column, DateTime, Enum, String, Text from sqlalchemy
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, Enum, String, Text
|
||||||
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
# Import TechniqueStatus from app.models.enums
|
||||||
from app.models.enums import TechniqueStatus
|
from app.models.enums import TechniqueStatus
|
||||||
|
|
||||||
|
|
||||||
|
# Define class Technique
|
||||||
class Technique(Base):
|
class Technique(Base):
|
||||||
"""
|
"""MITRE ATT&CK Technique model.
|
||||||
MITRE ATT&CK Technique model.
|
|
||||||
|
|
||||||
Represents an attack technique from the MITRE ATT&CK framework,
|
Represents an attack technique from the MITRE ATT&CK framework,
|
||||||
including its coverage status and associated tests.
|
including its coverage status and associated tests.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "techniques"
|
||||||
__tablename__ = "techniques"
|
__tablename__ = "techniques"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001"
|
||||||
mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001"
|
mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001"
|
||||||
|
# Assign name = Column(String, nullable=False)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign tactic = Column(String, nullable=True)
|
||||||
tactic = Column(String, nullable=True)
|
tactic = Column(String, nullable=True)
|
||||||
|
# Assign platforms = Column(JSONB, nullable=True, default=[])
|
||||||
platforms = Column(JSONB, nullable=True, default=[])
|
platforms = Column(JSONB, nullable=True, default=[])
|
||||||
|
# Assign mitre_version = Column(String, nullable=True)
|
||||||
mitre_version = Column(String, nullable=True)
|
mitre_version = Column(String, nullable=True)
|
||||||
|
# Assign mitre_last_modified = Column(DateTime, nullable=True)
|
||||||
mitre_last_modified = Column(DateTime, nullable=True)
|
mitre_last_modified = Column(DateTime, nullable=True)
|
||||||
|
# Assign is_subtechnique = Column(Boolean, default=False)
|
||||||
is_subtechnique = Column(Boolean, default=False)
|
is_subtechnique = Column(Boolean, default=False)
|
||||||
|
# Assign parent_mitre_id = Column(String, nullable=True)
|
||||||
parent_mitre_id = Column(String, nullable=True)
|
parent_mitre_id = Column(String, nullable=True)
|
||||||
|
# Assign status_global = Column(
|
||||||
status_global = Column(
|
status_global = Column(
|
||||||
Enum(TechniqueStatus, name="techniquestatus"),
|
Enum(TechniqueStatus, name="techniquestatus"),
|
||||||
|
# Keyword argument: default
|
||||||
default=TechniqueStatus.not_evaluated
|
default=TechniqueStatus.not_evaluated
|
||||||
)
|
)
|
||||||
|
# Assign review_required = Column(Boolean, default=False)
|
||||||
review_required = Column(Boolean, default=False)
|
review_required = Column(Boolean, default=False)
|
||||||
|
# Assign last_review_date = Column(DateTime, nullable=True)
|
||||||
last_review_date = Column(DateTime, nullable=True)
|
last_review_date = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
|
|||||||
@@ -1,80 +1,140 @@
|
|||||||
|
"""SQLAlchemy model for the tests table."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, ForeignKey, Enum, Index, func
|
|
||||||
|
# Import from sqlalchemy
|
||||||
|
from sqlalchemy import (
|
||||||
|
Boolean,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
from app.models.enums import TestState, TestResult
|
|
||||||
|
# Import TestResult, TestState from app.models.enums
|
||||||
|
from app.models.enums import TestResult, TestState
|
||||||
|
|
||||||
|
|
||||||
|
# Define class Test
|
||||||
class Test(Base):
|
class Test(Base):
|
||||||
"""
|
"""Test model representing a security test for a MITRE ATT&CK technique.
|
||||||
Test model representing a security test for a MITRE ATT&CK technique.
|
|
||||||
|
|
||||||
Each test documents an attempt to validate coverage of a specific technique,
|
Each test documents an attempt to validate coverage of a specific technique,
|
||||||
including the procedure, tools used, and outcome. V2 introduces dual
|
including the procedure, tools used, and outcome. V2 introduces dual
|
||||||
validation: Red Lead and Blue Lead must each approve independently.
|
validation: Red Lead and Blue Lead must each approve independently.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "tests"
|
||||||
__tablename__ = "tests"
|
__tablename__ = "tests"
|
||||||
|
|
||||||
# ── Core fields ─────────────────────────────────────────────────
|
# ── Core fields ─────────────────────────────────────────────────
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=Fa...
|
||||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False)
|
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False)
|
||||||
|
# Assign name = Column(String, nullable=False)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign platform = Column(String, nullable=True)
|
||||||
platform = Column(String, nullable=True)
|
platform = Column(String, nullable=True)
|
||||||
|
# Assign procedure_text = Column(Text, nullable=True)
|
||||||
procedure_text = Column(Text, nullable=True)
|
procedure_text = Column(Text, nullable=True)
|
||||||
|
# Assign tool_used = Column(String, nullable=True)
|
||||||
tool_used = Column(String, nullable=True)
|
tool_used = Column(String, nullable=True)
|
||||||
|
# Assign execution_date = Column(DateTime, nullable=True)
|
||||||
execution_date = Column(DateTime, nullable=True)
|
execution_date = Column(DateTime, nullable=True)
|
||||||
|
# Assign created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
|
# Assign result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
||||||
result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
||||||
|
# Assign state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
|
||||||
state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
|
state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
# ── Red Team fields ─────────────────────────────────────────────
|
# ── Red Team fields ─────────────────────────────────────────────
|
||||||
red_summary = Column(Text, nullable=True)
|
red_summary = Column(Text, nullable=True)
|
||||||
|
# Assign attack_success = Column(Boolean, nullable=True)
|
||||||
attack_success = Column(Boolean, nullable=True)
|
attack_success = Column(Boolean, nullable=True)
|
||||||
|
# Assign red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
|
# Assign red_validated_at = Column(DateTime, nullable=True)
|
||||||
red_validated_at = Column(DateTime, nullable=True)
|
red_validated_at = Column(DateTime, nullable=True)
|
||||||
|
# Assign red_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
||||||
red_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
red_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
||||||
|
# Assign red_validation_notes = Column(Text, nullable=True)
|
||||||
red_validation_notes = Column(Text, nullable=True)
|
red_validation_notes = Column(Text, nullable=True)
|
||||||
|
|
||||||
# ── Blue Team fields ────────────────────────────────────────────
|
# ── Blue Team fields ────────────────────────────────────────────
|
||||||
blue_summary = Column(Text, nullable=True)
|
blue_summary = Column(Text, nullable=True)
|
||||||
|
# Assign detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
||||||
detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
||||||
|
# Assign blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
|
# Assign blue_validated_at = Column(DateTime, nullable=True)
|
||||||
blue_validated_at = Column(DateTime, nullable=True)
|
blue_validated_at = Column(DateTime, nullable=True)
|
||||||
|
# Assign blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
||||||
blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
||||||
|
# Assign blue_validation_notes = Column(Text, nullable=True)
|
||||||
blue_validation_notes = Column(Text, nullable=True)
|
blue_validation_notes = Column(Text, nullable=True)
|
||||||
|
|
||||||
# ── Phase timing fields (for automatic Tempo worklogs) ──────────
|
# ── Phase timing fields (for automatic Tempo worklogs) ──────────
|
||||||
red_started_at = Column(DateTime, nullable=True)
|
red_started_at = Column(DateTime, nullable=True)
|
||||||
|
# Assign blue_started_at = Column(DateTime, nullable=True)
|
||||||
blue_started_at = Column(DateTime, nullable=True)
|
blue_started_at = Column(DateTime, nullable=True)
|
||||||
blue_work_started_at = Column(DateTime, nullable=True) # when blue tech picks up (Tempo start)
|
blue_work_started_at = Column(DateTime, nullable=True) # when blue tech picks up (Tempo start)
|
||||||
paused_at = Column(DateTime, nullable=True)
|
paused_at = Column(DateTime, nullable=True)
|
||||||
|
# Assign red_paused_seconds = Column(Integer, default=0)
|
||||||
red_paused_seconds = Column(Integer, default=0)
|
red_paused_seconds = Column(Integer, default=0)
|
||||||
|
# Assign blue_paused_seconds = Column(Integer, default=0)
|
||||||
blue_paused_seconds = Column(Integer, default=0)
|
blue_paused_seconds = Column(Integer, default=0)
|
||||||
|
|
||||||
# ── Remediation fields ───────────────────────────────────────────
|
# ── Remediation fields ───────────────────────────────────────────
|
||||||
remediation_steps = Column(Text, nullable=True)
|
remediation_steps = Column(Text, nullable=True)
|
||||||
|
# Assign remediation_status = Column(String, nullable=True) # pending / in_progress / completed ...
|
||||||
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable
|
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable
|
||||||
|
# Assign remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||||
|
|
||||||
# ── Re-test fields ────────────────────────────────────────────
|
# ── Re-test fields ────────────────────────────────────────────
|
||||||
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
|
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
|
||||||
|
# Assign retest_count = Column(Integer, default=0)
|
||||||
retest_count = Column(Integer, default=0)
|
retest_count = Column(Integer, default=0)
|
||||||
|
# Assign data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||||
data_classification = Column(String(20), nullable=False, server_default="internal")
|
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||||
|
|
||||||
# ── Relationships ───────────────────────────────────────────────
|
# ── Relationships ───────────────────────────────────────────────
|
||||||
technique = relationship("Technique", back_populates="tests")
|
technique = relationship("Technique", back_populates="tests")
|
||||||
|
# Assign evidences = relationship("Evidence", back_populates="test")
|
||||||
evidences = relationship("Evidence", back_populates="test")
|
evidences = relationship("Evidence", back_populates="test")
|
||||||
|
# Assign creator = relationship("User", foreign_keys=[created_by])
|
||||||
creator = relationship("User", foreign_keys=[created_by])
|
creator = relationship("User", foreign_keys=[created_by])
|
||||||
|
# Assign red_validator = relationship("User", foreign_keys=[red_validated_by])
|
||||||
red_validator = relationship("User", foreign_keys=[red_validated_by])
|
red_validator = relationship("User", foreign_keys=[red_validated_by])
|
||||||
|
# Assign blue_validator = relationship("User", foreign_keys=[blue_validated_by])
|
||||||
blue_validator = relationship("User", foreign_keys=[blue_validated_by])
|
blue_validator = relationship("User", foreign_keys=[blue_validated_by])
|
||||||
|
# Assign remediation_user = relationship("User", foreign_keys=[remediation_assignee])
|
||||||
remediation_user = relationship("User", foreign_keys=[remediation_assignee])
|
remediation_user = relationship("User", foreign_keys=[remediation_assignee])
|
||||||
|
# Assign original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
|
||||||
original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
|
original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
|
||||||
|
# Assign retests = relationship("Test", foreign_keys=[retest_of], back_populates="orig...
|
||||||
retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test")
|
retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test")
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_tests_technique_id", "technique_id"),
|
Index("ix_tests_technique_id", "technique_id"),
|
||||||
Index("ix_tests_state", "state"),
|
Index("ix_tests_state", "state"),
|
||||||
|
|||||||
@@ -4,51 +4,79 @@ When the Blue Team evaluates a test, they mark each associated detection
|
|||||||
rule as triggered / not triggered / not applicable, along with notes.
|
rule as triggered / not triggered / not applicable, along with notes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Index, UniqueConstraint
|
# Import from sqlalchemy
|
||||||
|
from sqlalchemy import (
|
||||||
|
Boolean,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TestDetectionResult
|
||||||
class TestDetectionResult(Base):
|
class TestDetectionResult(Base):
|
||||||
"""
|
"""Per-test, per-rule evaluation result.
|
||||||
Per-test, per-rule evaluation result.
|
|
||||||
|
|
||||||
- ``triggered`` = True: rule detected the attack
|
- ``triggered`` = True: rule detected the attack
|
||||||
- ``triggered`` = False: rule did NOT detect the attack
|
- ``triggered`` = False: rule did NOT detect the attack
|
||||||
- ``triggered`` = None: not yet evaluated
|
- ``triggered`` = None: not yet evaluated
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "test_detection_results"
|
||||||
__tablename__ = "test_detection_results"
|
__tablename__ = "test_detection_results"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign test_id = Column(
|
||||||
test_id = Column(
|
test_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("tests.id", ondelete="CASCADE"),
|
ForeignKey("tests.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign detection_rule_id = Column(
|
||||||
detection_rule_id = Column(
|
detection_rule_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("detection_rules.id", ondelete="CASCADE"),
|
ForeignKey("detection_rules.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign triggered = Column(Boolean, nullable=True) # None = not evaluated
|
||||||
triggered = Column(Boolean, nullable=True) # None = not evaluated
|
triggered = Column(Boolean, nullable=True) # None = not evaluated
|
||||||
|
# Assign notes = Column(Text, nullable=True)
|
||||||
notes = Column(Text, nullable=True)
|
notes = Column(Text, nullable=True)
|
||||||
|
# Assign evaluated_by = Column(
|
||||||
evaluated_by = Column(
|
evaluated_by = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("users.id", ondelete="SET NULL"),
|
ForeignKey("users.id", ondelete="SET NULL"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
|
# Assign evaluated_at = Column(DateTime, nullable=True)
|
||||||
evaluated_at = Column(DateTime, nullable=True)
|
evaluated_at = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
test = relationship("Test")
|
test = relationship("Test")
|
||||||
|
# Assign detection_rule = relationship("DetectionRule")
|
||||||
detection_rule = relationship("DetectionRule")
|
detection_rule = relationship("DetectionRule")
|
||||||
|
# Assign evaluator = relationship("User", foreign_keys=[evaluated_by])
|
||||||
evaluator = relationship("User", foreign_keys=[evaluated_by])
|
evaluator = relationship("User", foreign_keys=[evaluated_by])
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_tdr_test', 'test_id'),
|
Index('ix_tdr_test', 'test_id'),
|
||||||
Index('ix_tdr_rule', 'detection_rule_id'),
|
Index('ix_tdr_rule', 'detection_rule_id'),
|
||||||
|
|||||||
@@ -1,15 +1,21 @@
|
|||||||
"""TestTemplate model — predefined test catalog entries."""
|
"""TestTemplate model — predefined test catalog entries."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index, func
|
|
||||||
|
# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TestTemplate
|
||||||
class TestTemplate(Base):
|
class TestTemplate(Base):
|
||||||
"""
|
"""Predefined test template mapped to a MITRE ATT&CK technique.
|
||||||
Predefined test template mapped to a MITRE ATT&CK technique.
|
|
||||||
|
|
||||||
Templates come from several sources:
|
Templates come from several sources:
|
||||||
- **atomic_red_team**: Atomic Red Team by Red Canary
|
- **atomic_red_team**: Atomic Red Team by Red Canary
|
||||||
@@ -18,24 +24,41 @@ class TestTemplate(Base):
|
|||||||
|
|
||||||
Users can instantiate a real Test from a template.
|
Users can instantiate a real Test from a template.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "test_templates"
|
||||||
__tablename__ = "test_templates"
|
__tablename__ = "test_templates"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
||||||
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
||||||
|
# Assign name = Column(String, nullable=False)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign source = Column(String, nullable=False) # atomic_red_te...
|
||||||
source = Column(String, nullable=False) # atomic_red_team / mitre / custom
|
source = Column(String, nullable=False) # atomic_red_team / mitre / custom
|
||||||
|
# Assign source_url = Column(String, nullable=True)
|
||||||
source_url = Column(String, nullable=True)
|
source_url = Column(String, nullable=True)
|
||||||
|
# Assign attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
|
||||||
attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
|
attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
|
||||||
|
# Assign expected_detection = Column(Text, nullable=True) # What blue team should detect
|
||||||
expected_detection = Column(Text, nullable=True) # What blue team should detect
|
expected_detection = Column(Text, nullable=True) # What blue team should detect
|
||||||
|
# Assign platform = Column(String, nullable=True) # windows / linux...
|
||||||
platform = Column(String, nullable=True) # windows / linux / macos
|
platform = Column(String, nullable=True) # windows / linux / macos
|
||||||
|
# Assign tool_suggested = Column(String, nullable=True)
|
||||||
tool_suggested = Column(String, nullable=True)
|
tool_suggested = Column(String, nullable=True)
|
||||||
|
# Assign severity = Column(String, nullable=True) # low / medium / ...
|
||||||
severity = Column(String, nullable=True) # low / medium / high / critical
|
severity = Column(String, nullable=True) # low / medium / high / critical
|
||||||
|
# Assign atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team...
|
||||||
atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo
|
atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo
|
||||||
|
# Assign suggested_remediation = Column(Text, nullable=True)
|
||||||
suggested_remediation = Column(Text, nullable=True)
|
suggested_remediation = Column(Text, nullable=True)
|
||||||
|
# Assign is_active = Column(Boolean, default=True)
|
||||||
is_active = Column(Boolean, default=True)
|
is_active = Column(Boolean, default=True)
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'),
|
Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'),
|
||||||
Index('ix_test_templates_source', 'source'),
|
Index('ix_test_templates_source', 'source'),
|
||||||
|
|||||||
@@ -4,47 +4,64 @@ Enables the Blue Team to see which detection rules should fire
|
|||||||
for a given test template / attack procedure.
|
for a given test template / attack procedure.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import Column, Boolean, ForeignKey, Index, UniqueConstraint
|
# Import Boolean, Column, ForeignKey, Index, UniqueConst... from sqlalchemy
|
||||||
|
from sqlalchemy import Boolean, Column, ForeignKey, Index, UniqueConstraint
|
||||||
|
|
||||||
|
# Import UUID from sqlalchemy.dialects.postgresql
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TestTemplateDetectionRule
|
||||||
class TestTemplateDetectionRule(Base):
|
class TestTemplateDetectionRule(Base):
|
||||||
"""
|
"""Association between a test template and a detection rule.
|
||||||
Association between a test template and a detection rule.
|
|
||||||
|
|
||||||
Auto-generated by matching mitre_technique_id, or manually curated.
|
Auto-generated by matching mitre_technique_id, or manually curated.
|
||||||
``is_primary`` marks rules with severity >= high as primary detections.
|
``is_primary`` marks rules with severity >= high as primary detections.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "test_template_detection_rules"
|
||||||
__tablename__ = "test_template_detection_rules"
|
__tablename__ = "test_template_detection_rules"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign test_template_id = Column(
|
||||||
test_template_id = Column(
|
test_template_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("test_templates.id", ondelete="CASCADE"),
|
ForeignKey("test_templates.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
|
# Assign detection_rule_id = Column(
|
||||||
detection_rule_id = Column(
|
detection_rule_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("detection_rules.id", ondelete="CASCADE"),
|
ForeignKey("detection_rules.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign is_primary = Column(Boolean, default=False)
|
||||||
is_primary = Column(Boolean, default=False)
|
is_primary = Column(Boolean, default=False)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
test_template = relationship("TestTemplate")
|
test_template = relationship("TestTemplate")
|
||||||
|
# Assign detection_rule = relationship("DetectionRule")
|
||||||
detection_rule = relationship("DetectionRule")
|
detection_rule = relationship("DetectionRule")
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_ttdr_template', 'test_template_id'),
|
Index('ix_ttdr_template', 'test_template_id'),
|
||||||
Index('ix_ttdr_rule', 'detection_rule_id'),
|
Index('ix_ttdr_rule', 'detection_rule_id'),
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
|
# Literal argument value
|
||||||
'test_template_id', 'detection_rule_id',
|
'test_template_id', 'detection_rule_id',
|
||||||
|
# Keyword argument: name
|
||||||
name='uq_template_detection_rule',
|
name='uq_template_detection_rule',
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,87 +4,135 @@ Stores profiles of APT groups and their associated MITRE ATT&CK
|
|||||||
techniques, imported from MITRE CTI (STIX 2.0).
|
techniques, imported from MITRE CTI (STIX 2.0).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import from sqlalchemy
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column, String, Text, Boolean, DateTime,
|
Boolean,
|
||||||
ForeignKey, Index, UniqueConstraint, func,
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
func,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class ThreatActor
|
||||||
class ThreatActor(Base):
|
class ThreatActor(Base):
|
||||||
"""
|
"""Threat actor / APT group profile.
|
||||||
Threat actor / APT group profile.
|
|
||||||
|
|
||||||
Imported from MITRE CTI ``intrusion-set`` STIX objects.
|
Imported from MITRE CTI ``intrusion-set`` STIX objects.
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "threat_actors"
|
||||||
__tablename__ = "threat_actors"
|
__tablename__ = "threat_actors"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign mitre_id = Column(String, unique=True, nullable=True) # e.g. "G00...
|
||||||
mitre_id = Column(String, unique=True, nullable=True) # e.g. "G0016" (APT29)
|
mitre_id = Column(String, unique=True, nullable=True) # e.g. "G0016" (APT29)
|
||||||
|
# Assign name = Column(String, nullable=False)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
|
# Assign aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy ...
|
||||||
aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy Bear", "The Dukes", ...]
|
aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy Bear", "The Dukes", ...]
|
||||||
|
# Assign description = Column(Text, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
# Assign country = Column(String, nullable=True)
|
||||||
country = Column(String, nullable=True)
|
country = Column(String, nullable=True)
|
||||||
|
# Assign target_sectors = Column(JSONB, nullable=True, default=[]) # ["government",...
|
||||||
target_sectors = Column(JSONB, nullable=True, default=[]) # ["government", "defense", ...]
|
target_sectors = Column(JSONB, nullable=True, default=[]) # ["government", "defense", ...]
|
||||||
|
# Assign target_regions = Column(JSONB, nullable=True, default=[]) # ["north-americ...
|
||||||
target_regions = Column(JSONB, nullable=True, default=[]) # ["north-america", "europe", ...]
|
target_regions = Column(JSONB, nullable=True, default=[]) # ["north-america", "europe", ...]
|
||||||
|
# Assign motivation = Column(String, nullable=True) # espionage ...
|
||||||
motivation = Column(String, nullable=True) # espionage / financial / destruction / ...
|
motivation = Column(String, nullable=True) # espionage / financial / destruction / ...
|
||||||
|
# Assign sophistication = Column(String, nullable=True) # low / medium /...
|
||||||
sophistication = Column(String, nullable=True) # low / medium / high / advanced
|
sophistication = Column(String, nullable=True) # low / medium / high / advanced
|
||||||
|
# Assign first_seen = Column(String, nullable=True)
|
||||||
first_seen = Column(String, nullable=True)
|
first_seen = Column(String, nullable=True)
|
||||||
|
# Assign last_seen = Column(String, nullable=True)
|
||||||
last_seen = Column(String, nullable=True)
|
last_seen = Column(String, nullable=True)
|
||||||
|
# Assign references = Column(JSONB, nullable=True, default=[]) # [{"url": "...
|
||||||
references = Column(JSONB, nullable=True, default=[]) # [{"url": "...", "description": "..."}]
|
references = Column(JSONB, nullable=True, default=[]) # [{"url": "...", "description": "..."}]
|
||||||
|
# Assign mitre_url = Column(String, nullable=True)
|
||||||
mitre_url = Column(String, nullable=True)
|
mitre_url = Column(String, nullable=True)
|
||||||
|
# Assign is_active = Column(Boolean, default=True)
|
||||||
is_active = Column(Boolean, default=True)
|
is_active = Column(Boolean, default=True)
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
techniques = relationship(
|
techniques = relationship(
|
||||||
|
# Literal argument value
|
||||||
"ThreatActorTechnique",
|
"ThreatActorTechnique",
|
||||||
|
# Keyword argument: back_populates
|
||||||
back_populates="threat_actor",
|
back_populates="threat_actor",
|
||||||
|
# Keyword argument: cascade
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_threat_actors_country', 'country'),
|
Index('ix_threat_actors_country', 'country'),
|
||||||
Index('ix_threat_actors_motivation', 'motivation'),
|
Index('ix_threat_actors_motivation', 'motivation'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Define class ThreatActorTechnique
|
||||||
class ThreatActorTechnique(Base):
|
class ThreatActorTechnique(Base):
|
||||||
"""
|
"""Association between a threat actor and a MITRE ATT&CK technique.
|
||||||
Association between a threat actor and a MITRE ATT&CK technique.
|
|
||||||
|
|
||||||
Stores additional context about how the actor uses the technique
|
Stores additional context about how the actor uses the technique
|
||||||
(from the STIX ``relationship`` ``uses`` objects).
|
(from the STIX ``relationship`` ``uses`` objects).
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "threat_actor_techniques"
|
||||||
__tablename__ = "threat_actor_techniques"
|
__tablename__ = "threat_actor_techniques"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign threat_actor_id = Column(
|
||||||
threat_actor_id = Column(
|
threat_actor_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("threat_actors.id", ondelete="CASCADE"),
|
ForeignKey("threat_actors.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign technique_id = Column(
|
||||||
technique_id = Column(
|
technique_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||||
|
# Keyword argument: nullable
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
# Assign usage_description = Column(Text, nullable=True)
|
||||||
usage_description = Column(Text, nullable=True)
|
usage_description = Column(Text, nullable=True)
|
||||||
|
# Assign first_seen_using = Column(String, nullable=True)
|
||||||
first_seen_using = Column(String, nullable=True)
|
first_seen_using = Column(String, nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
threat_actor = relationship("ThreatActor", back_populates="techniques")
|
threat_actor = relationship("ThreatActor", back_populates="techniques")
|
||||||
|
# Assign technique = relationship("Technique")
|
||||||
technique = relationship("Technique")
|
technique = relationship("Technique")
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('ix_threat_actor_techniques_actor', 'threat_actor_id'),
|
Index('ix_threat_actor_techniques_actor', 'threat_actor_id'),
|
||||||
Index('ix_threat_actor_techniques_technique', 'technique_id'),
|
Index('ix_threat_actor_techniques_technique', 'technique_id'),
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
|
# Literal argument value
|
||||||
'threat_actor_id', 'technique_id',
|
'threat_actor_id', 'technique_id',
|
||||||
|
# Keyword argument: name
|
||||||
name='uq_actor_technique',
|
name='uq_actor_technique',
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
|
"""SQLAlchemy model for the users table."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, func
|
from sqlalchemy import Column, String, Boolean, DateTime, func
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class User
|
||||||
class User(Base):
|
class User(Base):
|
||||||
"""
|
"""User model for authentication and authorization.
|
||||||
User model for authentication and authorization.
|
|
||||||
|
|
||||||
Possible roles:
|
Possible roles:
|
||||||
- admin: Full system access
|
- admin: Full system access
|
||||||
- red_tech: Red team technician - can create and edit tests
|
- red_tech: Red team technician - can create and edit tests
|
||||||
@@ -17,16 +21,26 @@ class User(Base):
|
|||||||
- blue_lead: Blue team lead - can validate tests
|
- blue_lead: Blue team lead - can validate tests
|
||||||
- viewer: Read-only access (default)
|
- viewer: Read-only access (default)
|
||||||
"""
|
"""
|
||||||
|
# Assign __tablename__ = "users"
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign username = Column(String, unique=True, nullable=False)
|
||||||
username = Column(String, unique=True, nullable=False)
|
username = Column(String, unique=True, nullable=False)
|
||||||
|
# Assign email = Column(String, nullable=True)
|
||||||
email = Column(String, nullable=True)
|
email = Column(String, nullable=True)
|
||||||
|
# Assign hashed_password = Column(String, nullable=False)
|
||||||
hashed_password = Column(String, nullable=False)
|
hashed_password = Column(String, nullable=False)
|
||||||
|
# Assign role = Column(String, nullable=False, default="viewer")
|
||||||
role = Column(String, nullable=False, default="viewer")
|
role = Column(String, nullable=False, default="viewer")
|
||||||
|
# Assign is_active = Column(Boolean, default=True)
|
||||||
is_active = Column(Boolean, default=True)
|
is_active = Column(Boolean, default=True)
|
||||||
|
# Assign must_change_password = Column(Boolean, default=True)
|
||||||
must_change_password = Column(Boolean, default=True)
|
must_change_password = Column(Boolean, default=True)
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
# Assign last_login = Column(DateTime, nullable=True)
|
||||||
last_login = Column(DateTime, nullable=True)
|
last_login = Column(DateTime, nullable=True)
|
||||||
notification_preferences = Column(JSONB, nullable=True, server_default='{"email_on_test_validated": true, "email_on_campaign_completed": true, "email_on_new_mitre_techniques": false, "in_app_all": true}')
|
notification_preferences = Column(JSONB, nullable=True, server_default='{"email_on_test_validated": true, "email_on_campaign_completed": true, "email_on_new_mitre_techniques": false, "in_app_all": true}')
|
||||||
jira_account_id = Column(String(100), nullable=True)
|
jira_account_id = Column(String(100), nullable=True)
|
||||||
|
|||||||
@@ -1,13 +1,22 @@
|
|||||||
"""Worklog model — immutable internal time-tracking records."""
|
"""Worklog model — immutable internal time-tracking records."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Integer, DateTime, ForeignKey, Text, Index, func
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
# Import Column, DateTime, ForeignKey, Index, Integer, S... from sqlalchemy
|
||||||
|
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text, func
|
||||||
|
|
||||||
|
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# Import relationship from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
# Import Base from app.database
|
||||||
from app.database import Base
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# Define class Worklog
|
||||||
class Worklog(Base):
|
class Worklog(Base):
|
||||||
"""Internal worklog entry with integrity hash for audit compliance.
|
"""Internal worklog entry with integrity hash for audit compliance.
|
||||||
|
|
||||||
@@ -16,25 +25,42 @@ class Worklog(Base):
|
|||||||
the immutable fields so tampering can be detected.
|
the immutable fields so tampering can be detected.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Assign __tablename__ = "worklogs"
|
||||||
__tablename__ = "worklogs"
|
__tablename__ = "worklogs"
|
||||||
|
|
||||||
|
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
# Assign entity_type = Column(String(50), nullable=False)
|
||||||
entity_type = Column(String(50), nullable=False)
|
entity_type = Column(String(50), nullable=False)
|
||||||
|
# Assign entity_id = Column(UUID(as_uuid=True), nullable=False)
|
||||||
entity_id = Column(UUID(as_uuid=True), nullable=False)
|
entity_id = Column(UUID(as_uuid=True), nullable=False)
|
||||||
|
# Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||||
|
# Assign activity_type = Column(String(100), nullable=False)
|
||||||
activity_type = Column(String(100), nullable=False)
|
activity_type = Column(String(100), nullable=False)
|
||||||
|
# Assign started_at = Column(DateTime, nullable=False)
|
||||||
started_at = Column(DateTime, nullable=False)
|
started_at = Column(DateTime, nullable=False)
|
||||||
|
# Assign ended_at = Column(DateTime)
|
||||||
ended_at = Column(DateTime)
|
ended_at = Column(DateTime)
|
||||||
|
# Assign duration_seconds = Column(Integer, nullable=False)
|
||||||
duration_seconds = Column(Integer, nullable=False)
|
duration_seconds = Column(Integer, nullable=False)
|
||||||
|
# Assign description = Column(Text)
|
||||||
description = Column(Text)
|
description = Column(Text)
|
||||||
|
# Assign tempo_synced = Column(DateTime)
|
||||||
tempo_synced = Column(DateTime)
|
tempo_synced = Column(DateTime)
|
||||||
|
# Assign tempo_worklog_id = Column(String(100))
|
||||||
tempo_worklog_id = Column(String(100))
|
tempo_worklog_id = Column(String(100))
|
||||||
|
# Assign integrity_hash = Column(String(64))
|
||||||
integrity_hash = Column(String(64))
|
integrity_hash = Column(String(64))
|
||||||
|
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
# Assign extra_metadata = Column("metadata", JSONB, default={})
|
||||||
extra_metadata = Column("metadata", JSONB, default={})
|
extra_metadata = Column("metadata", JSONB, default={})
|
||||||
|
|
||||||
|
# Assign user = relationship("User", foreign_keys=[user_id])
|
||||||
user = relationship("User", foreign_keys=[user_id])
|
user = relationship("User", foreign_keys=[user_id])
|
||||||
|
|
||||||
|
# Assign __table_args__ = (
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_worklogs_entity_id", "entity_id"),
|
Index("ix_worklogs_entity_id", "entity_id"),
|
||||||
Index("ix_worklogs_user_id", "user_id"),
|
Index("ix_worklogs_user_id", "user_id"),
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""FastAPI router modules — one router per feature domain."""
|
||||||
|
|||||||
@@ -1,50 +1,81 @@
|
|||||||
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
|
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
|
||||||
|
|
||||||
|
# Import APIRouter, Depends from fastapi
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import advanced_metrics_service from app.services
|
||||||
from app.services import advanced_metrics_service
|
from app.services import advanced_metrics_service
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
||||||
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/coverage-by-tactic")
|
@router.get("/coverage-by-tactic")
|
||||||
|
# Define function coverage_by_tactic
|
||||||
def coverage_by_tactic(
|
def coverage_by_tactic(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
||||||
|
# Return advanced_metrics_service.get_coverage_by_tactic(db)
|
||||||
return advanced_metrics_service.get_coverage_by_tactic(db)
|
return advanced_metrics_service.get_coverage_by_tactic(db)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/never-tested")
|
@router.get("/never-tested")
|
||||||
|
# Define function never_tested_techniques
|
||||||
def never_tested_techniques(
|
def never_tested_techniques(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Techniques that have never had a test created."""
|
"""Techniques that have never had a test created."""
|
||||||
|
# Return advanced_metrics_service.get_never_tested_techniques(db)
|
||||||
return advanced_metrics_service.get_never_tested_techniques(db)
|
return advanced_metrics_service.get_never_tested_techniques(db)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/avg-validation-time")
|
@router.get("/avg-validation-time")
|
||||||
|
# Define function avg_validation_time
|
||||||
def avg_validation_time(
|
def avg_validation_time(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Average time from test creation to validation, computed from audit logs.
|
"""Average time from test creation to validation, computed from audit logs.
|
||||||
|
|
||||||
Returns overall average and per-phase averages where data is available.
|
Returns overall average and per-phase averages where data is available.
|
||||||
"""
|
"""
|
||||||
|
# Return advanced_metrics_service.get_avg_validation_time(db)
|
||||||
return advanced_metrics_service.get_avg_validation_time(db)
|
return advanced_metrics_service.get_avg_validation_time(db)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/detection-rate-trend")
|
@router.get("/detection-rate-trend")
|
||||||
|
# Define function detection_rate_trend
|
||||||
def detection_rate_trend(
|
def detection_rate_trend(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Monthly detection rate trend for the last 12 months."""
|
"""Monthly detection rate trend for the last 12 months."""
|
||||||
|
# Return advanced_metrics_service.get_detection_rate_trend(db)
|
||||||
return advanced_metrics_service.get_detection_rate_trend(db)
|
return advanced_metrics_service.get_detection_rate_trend(db)
|
||||||
|
|||||||
@@ -4,52 +4,85 @@ Returns complete datasets without pagination so BI tools can ingest
|
|||||||
directly from URL. All endpoints require authentication.
|
directly from URL. All endpoints require authentication.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import analytics_service from app.services
|
||||||
from app.services import analytics_service
|
from app.services import analytics_service
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||||
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/coverage")
|
@router.get("/coverage")
|
||||||
|
# Define function analytics_coverage
|
||||||
def analytics_coverage(
|
def analytics_coverage(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Coverage per technique — flat format for BI dashboards."""
|
"""Coverage per technique — flat format for BI dashboards."""
|
||||||
|
# Return analytics_service.get_coverage_analytics(db)
|
||||||
return analytics_service.get_coverage_analytics(db)
|
return analytics_service.get_coverage_analytics(db)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/tests")
|
@router.get("/tests")
|
||||||
|
# Define function analytics_tests
|
||||||
def analytics_tests(
|
def analytics_tests(
|
||||||
|
# Entry: date_from
|
||||||
date_from: str = Query(None, description="ISO date filter (>=)"),
|
date_from: str = Query(None, description="ISO date filter (>=)"),
|
||||||
|
# Entry: date_to
|
||||||
date_to: str = Query(None, description="ISO date filter (<=)"),
|
date_to: str = Query(None, description="ISO date filter (<=)"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""All tests with timestamps — flat format for BI dashboards."""
|
"""All tests with timestamps — flat format for BI dashboards."""
|
||||||
|
# Return analytics_service.get_tests_analytics(
|
||||||
return analytics_service.get_tests_analytics(
|
return analytics_service.get_tests_analytics(
|
||||||
db, date_from=date_from, date_to=date_to
|
db, date_from=date_from, date_to=date_to
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/trends")
|
@router.get("/trends")
|
||||||
|
# Define function analytics_trends
|
||||||
def analytics_trends(
|
def analytics_trends(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Historical coverage snapshots for trend visualization."""
|
"""Historical coverage snapshots for trend visualization."""
|
||||||
|
# Return analytics_service.get_trends_analytics(db)
|
||||||
return analytics_service.get_trends_analytics(db)
|
return analytics_service.get_trends_analytics(db)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/operators")
|
@router.get("/operators")
|
||||||
|
# Define function analytics_operators
|
||||||
def analytics_operators(
|
def analytics_operators(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(require_role("admin")),
|
user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> list:
|
||||||
"""Per-operator metrics — for workload management dashboards."""
|
"""Per-operator metrics — for workload management dashboards."""
|
||||||
|
# Return analytics_service.get_operators_analytics(db)
|
||||||
return analytics_service.get_operators_analytics(db)
|
return analytics_service.get_operators_analytics(db)
|
||||||
|
|||||||
@@ -1,77 +1,127 @@
|
|||||||
"""Audit log viewer router (admin only)."""
|
"""Audit log viewer router (admin only)."""
|
||||||
|
|
||||||
|
# Import datetime from datetime
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import require_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import require_role
|
from app.dependencies.auth import require_role
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import AuditLogOut, AuditLogPage from app.schemas.audit
|
||||||
from app.schemas.audit import AuditLogOut, AuditLogPage
|
from app.schemas.audit import AuditLogOut, AuditLogPage
|
||||||
|
|
||||||
|
# Import from app.services.audit_query_service
|
||||||
from app.services.audit_query_service import (
|
from app.services.audit_query_service import (
|
||||||
list_distinct_actions,
|
list_distinct_actions,
|
||||||
list_distinct_entity_types,
|
list_distinct_entity_types,
|
||||||
list_logs,
|
list_logs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||||
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("", response_model=AuditLogPage)
|
@router.get("", response_model=AuditLogPage)
|
||||||
|
# Define function list_audit_logs
|
||||||
def list_audit_logs(
|
def list_audit_logs(
|
||||||
|
# Entry: user_id
|
||||||
user_id: Optional[str] = Query(None, description="Filter by user ID"),
|
user_id: Optional[str] = Query(None, description="Filter by user ID"),
|
||||||
|
# Entry: action
|
||||||
action: Optional[str] = Query(None, description="Filter by action type"),
|
action: Optional[str] = Query(None, description="Filter by action type"),
|
||||||
|
# Entry: entity_type
|
||||||
entity_type: Optional[str] = Query(None, description="Filter by entity type"),
|
entity_type: Optional[str] = Query(None, description="Filter by entity type"),
|
||||||
|
# Entry: start_date
|
||||||
start_date: Optional[datetime] = Query(None, description="Filter by start date"),
|
start_date: Optional[datetime] = Query(None, description="Filter by start date"),
|
||||||
|
# Entry: end_date
|
||||||
end_date: Optional[datetime] = Query(None, description="Filter by end date"),
|
end_date: Optional[datetime] = Query(None, description="Filter by end date"),
|
||||||
|
# Entry: offset
|
||||||
offset: int = Query(0, ge=0, description="Number of records to skip"),
|
offset: int = Query(0, ge=0, description="Number of records to skip"),
|
||||||
|
# Entry: limit
|
||||||
limit: int = Query(50, ge=1, le=100, description="Max records to return"),
|
limit: int = Query(50, ge=1, le=100, description="Max records to return"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> AuditLogPage:
|
||||||
"""Return paginated audit logs with optional filters.
|
"""Return paginated audit logs with optional filters.
|
||||||
|
|
||||||
**Requires admin role.**
|
**Requires admin role.**
|
||||||
"""
|
"""
|
||||||
|
# Assign result = list_logs(
|
||||||
result = list_logs(
|
result = list_logs(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
# Keyword argument: action
|
||||||
action=action,
|
action=action,
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type=entity_type,
|
entity_type=entity_type,
|
||||||
|
# Keyword argument: start_date
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
|
# Keyword argument: end_date
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
|
# Keyword argument: offset
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
# Keyword argument: limit
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
# Return AuditLogPage(
|
||||||
return AuditLogPage(
|
return AuditLogPage(
|
||||||
|
# Keyword argument: items
|
||||||
items=[AuditLogOut(**item) for item in result["items"]],
|
items=[AuditLogOut(**item) for item in result["items"]],
|
||||||
|
# Keyword argument: total
|
||||||
total=result["total"],
|
total=result["total"],
|
||||||
|
# Keyword argument: offset
|
||||||
offset=result["offset"],
|
offset=result["offset"],
|
||||||
|
# Keyword argument: limit
|
||||||
limit=result["limit"],
|
limit=result["limit"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/actions", response_model=list[str])
|
@router.get("/actions", response_model=list[str])
|
||||||
|
# Define function list_actions
|
||||||
def list_actions(
|
def list_actions(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> list[str]:
|
||||||
"""Return a list of distinct action types in the audit log.
|
"""Return a list of distinct action types in the audit log.
|
||||||
|
|
||||||
**Requires admin role.**
|
**Requires admin role.**
|
||||||
"""
|
"""
|
||||||
|
# Return list_distinct_actions(db)
|
||||||
return list_distinct_actions(db)
|
return list_distinct_actions(db)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/entity-types", response_model=list[str])
|
@router.get("/entity-types", response_model=list[str])
|
||||||
|
# Define function list_entity_types
|
||||||
def list_entity_types(
|
def list_entity_types(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> list[str]:
|
||||||
"""Return a list of distinct entity types in the audit log.
|
"""Return a list of distinct entity types in the audit log.
|
||||||
|
|
||||||
**Requires admin role.**
|
**Requires admin role.**
|
||||||
"""
|
"""
|
||||||
|
# Return list_distinct_entity_types(db)
|
||||||
return list_distinct_entity_types(db)
|
return list_distinct_entity_types(db)
|
||||||
|
|||||||
+133
-12
@@ -7,31 +7,68 @@ the token in the body for backwards compatibility and for clients that
|
|||||||
cannot use cookies (e.g. Swagger UI).
|
cannot use cookies (e.g. Swagger UI).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import os
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
# Import APIRouter, Cookie, Depends, Request, Response from fastapi
|
||||||
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
||||||
|
|
||||||
|
# Import OAuth2PasswordRequestForm from fastapi.security
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
|
# Import jwt (PyJWT)
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from jose import jwt, JWTError
|
# Import blacklist_token, create_access_token, verify_pa... from app.auth
|
||||||
|
from app.auth import blacklist_token, create_access_token, verify_password
|
||||||
|
|
||||||
from app.auth import create_access_token, blacklist_token, verify_password
|
# Import settings from app.config
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
|
||||||
|
# Import BusinessRuleViolation, PermissionViolation from app.domain.errors
|
||||||
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import limiter from app.limiter
|
||||||
from app.limiter import limiter
|
from app.limiter import limiter
|
||||||
|
|
||||||
|
# Import resolve_client_ip from app.middleware.request_context
|
||||||
from app.middleware.request_context import resolve_client_ip
|
from app.middleware.request_context import resolve_client_ip
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.auth_service import (
|
|
||||||
_DUMMY_HASH,
|
# Import TokenResponse, UserOut from app.schemas.auth
|
||||||
change_password as auth_change_password,
|
|
||||||
)
|
|
||||||
from app.services.audit_service import log_action
|
|
||||||
from app.schemas.auth import TokenResponse, UserOut
|
from app.schemas.auth import TokenResponse, UserOut
|
||||||
|
|
||||||
|
# Import PasswordChange from app.schemas.user
|
||||||
from app.schemas.user import PasswordChange
|
from app.schemas.user import PasswordChange
|
||||||
|
|
||||||
|
# Import log_action from app.services.audit_service
|
||||||
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
|
# Import from app.services.auth_service
|
||||||
|
from app.services.auth_service import (
|
||||||
|
_DUMMY_HASH,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.auth_service
|
||||||
|
from app.services.auth_service import (
|
||||||
|
change_password as auth_change_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
# SECURE_COOKIES desacopla la seguridad de la cookie del entorno de ejecucion.
|
# SECURE_COOKIES desacopla la seguridad de la cookie del entorno de ejecucion.
|
||||||
@@ -47,111 +84,182 @@ else: # "auto" — activo solo si AEGIS_ENV=production
|
|||||||
_COOKIE_NAME = "aegis_token"
|
_COOKIE_NAME = "aegis_token"
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/login", response_model=TokenResponse)
|
@router.post("/login", response_model=TokenResponse)
|
||||||
|
# Apply the @limiter.limit decorator
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
|
# Define function login
|
||||||
def login(
|
def login(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: response
|
||||||
response: Response,
|
response: Response,
|
||||||
|
# Entry: form_data
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
) -> TokenResponse:
|
||||||
"""Authenticate a user and return a JWT access token.
|
"""Authenticate a user and return a JWT access token.
|
||||||
|
|
||||||
Rate-limited to **5 attempts per minute per IP**. Failed and successful
|
Rate-limited to **5 attempts per minute per IP**. Failed and successful
|
||||||
logins are recorded in the audit log (SEC-009).
|
logins are recorded in the audit log (SEC-009).
|
||||||
"""
|
"""
|
||||||
|
# Assign user = db.query(User).filter(User.username == form_data.username).first()
|
||||||
user = db.query(User).filter(User.username == form_data.username).first()
|
user = db.query(User).filter(User.username == form_data.username).first()
|
||||||
|
# Assign target_hash = user.hashed_password if user else _DUMMY_HASH
|
||||||
target_hash = user.hashed_password if user else _DUMMY_HASH
|
target_hash = user.hashed_password if user else _DUMMY_HASH
|
||||||
|
# Assign password_valid = verify_password(form_data.password, target_hash)
|
||||||
password_valid = verify_password(form_data.password, target_hash)
|
password_valid = verify_password(form_data.password, target_hash)
|
||||||
|
# Assign ip = resolve_client_ip(request)
|
||||||
ip = resolve_client_ip(request)
|
ip = resolve_client_ip(request)
|
||||||
|
|
||||||
|
# Check: user is None or not password_valid
|
||||||
if user is None or not password_valid:
|
if user is None or not password_valid:
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user.id if user else None,
|
user.id if user else None,
|
||||||
|
# Literal argument value
|
||||||
"LOGIN_FAILED",
|
"LOGIN_FAILED",
|
||||||
|
# Literal argument value
|
||||||
"auth",
|
"auth",
|
||||||
|
# Literal argument value
|
||||||
None,
|
None,
|
||||||
|
# Keyword argument: details
|
||||||
details={
|
details={
|
||||||
|
# Literal argument value
|
||||||
"username": form_data.username,
|
"username": form_data.username,
|
||||||
|
# Literal argument value
|
||||||
"ip": ip,
|
"ip": ip,
|
||||||
|
# Literal argument value
|
||||||
"reason": "invalid_credentials",
|
"reason": "invalid_credentials",
|
||||||
},
|
},
|
||||||
|
# Keyword argument: ip_address
|
||||||
ip_address=ip,
|
ip_address=ip,
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Raise BusinessRuleViolation
|
||||||
raise BusinessRuleViolation("Incorrect username or password")
|
raise BusinessRuleViolation("Incorrect username or password")
|
||||||
|
|
||||||
|
# Check: not user.is_active
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
|
# Raise PermissionViolation
|
||||||
raise PermissionViolation("Account is disabled. Contact an administrator.")
|
raise PermissionViolation("Account is disabled. Contact an administrator.")
|
||||||
|
|
||||||
|
# Assign access_token = create_access_token(data={"sub": user.username})
|
||||||
access_token = create_access_token(data={"sub": user.username})
|
access_token = create_access_token(data={"sub": user.username})
|
||||||
|
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user.id,
|
user.id,
|
||||||
|
# Literal argument value
|
||||||
"LOGIN_SUCCESS",
|
"LOGIN_SUCCESS",
|
||||||
|
# Literal argument value
|
||||||
"auth",
|
"auth",
|
||||||
str(user.id),
|
str(user.id),
|
||||||
|
# Keyword argument: details
|
||||||
details={"username": user.username, "ip": ip},
|
details={"username": user.username, "ip": ip},
|
||||||
|
# Keyword argument: ip_address
|
||||||
ip_address=ip,
|
ip_address=ip,
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Call response.set_cookie()
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
|
# Keyword argument: key
|
||||||
key=_COOKIE_NAME,
|
key=_COOKIE_NAME,
|
||||||
|
# Keyword argument: value
|
||||||
value=access_token,
|
value=access_token,
|
||||||
|
# Keyword argument: httponly
|
||||||
httponly=True,
|
httponly=True,
|
||||||
|
# Keyword argument: secure
|
||||||
secure=_IS_HTTPS,
|
secure=_IS_HTTPS,
|
||||||
|
# Keyword argument: samesite
|
||||||
samesite="strict",
|
samesite="strict",
|
||||||
|
# Keyword argument: max_age
|
||||||
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||||
|
# Keyword argument: path
|
||||||
path="/",
|
path="/",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Return TokenResponse(access_token=access_token)
|
||||||
return TokenResponse(access_token=access_token)
|
return TokenResponse(access_token=access_token)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/logout")
|
@router.post("/logout")
|
||||||
|
# Define function logout
|
||||||
def logout(
|
def logout(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: response
|
||||||
response: Response,
|
response: Response,
|
||||||
|
# Entry: aegis_token
|
||||||
aegis_token: str | None = Cookie(None),
|
aegis_token: str | None = Cookie(None),
|
||||||
):
|
) -> dict:
|
||||||
"""Clear the authentication cookie and revoke the current token."""
|
"""Clear the authentication cookie and revoke the current token."""
|
||||||
|
# Assign bearer = (
|
||||||
bearer = (
|
bearer = (
|
||||||
request.headers.get("Authorization")
|
request.headers.get("Authorization")
|
||||||
or request.headers.get("authorization")
|
or request.headers.get("authorization")
|
||||||
or ""
|
or ""
|
||||||
)
|
)
|
||||||
|
# Assign bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip()
|
||||||
bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip()
|
bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip()
|
||||||
|
|
||||||
|
# Assign seen = set()
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
|
# Iterate over (aegis_token, bearer)
|
||||||
for raw in (aegis_token, bearer):
|
for raw in (aegis_token, bearer):
|
||||||
|
# Check: not raw or raw in seen
|
||||||
if not raw or raw in seen:
|
if not raw or raw in seen:
|
||||||
|
# Skip to the next loop iteration
|
||||||
continue
|
continue
|
||||||
|
# Call seen.add()
|
||||||
seen.add(raw)
|
seen.add(raw)
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign payload = jwt.decode(
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
raw,
|
raw,
|
||||||
settings.SECRET_KEY,
|
settings.SECRET_KEY,
|
||||||
|
# Keyword argument: algorithms
|
||||||
algorithms=[settings.ALGORITHM],
|
algorithms=[settings.ALGORITHM],
|
||||||
)
|
)
|
||||||
|
# Assign jti = payload.get("jti")
|
||||||
jti = payload.get("jti")
|
jti = payload.get("jti")
|
||||||
|
# Assign exp = payload.get("exp", 0)
|
||||||
exp = payload.get("exp", 0)
|
exp = payload.get("exp", 0)
|
||||||
|
# Check: jti
|
||||||
if jti:
|
if jti:
|
||||||
|
# Call blacklist_token()
|
||||||
blacklist_token(jti, float(exp))
|
blacklist_token(jti, float(exp))
|
||||||
except JWTError:
|
# Handle any JWT validation error during logout (token may be expired or malformed)
|
||||||
|
except jwt.exceptions.InvalidTokenError:
|
||||||
|
# Intentional no-op placeholder
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Call response.delete_cookie()
|
||||||
response.delete_cookie(
|
response.delete_cookie(
|
||||||
|
# Keyword argument: key
|
||||||
key=_COOKIE_NAME,
|
key=_COOKIE_NAME,
|
||||||
|
# Keyword argument: httponly
|
||||||
httponly=True,
|
httponly=True,
|
||||||
|
# Keyword argument: secure
|
||||||
secure=_IS_HTTPS,
|
secure=_IS_HTTPS,
|
||||||
|
# Keyword argument: samesite
|
||||||
samesite="strict",
|
samesite="strict",
|
||||||
|
# Keyword argument: path
|
||||||
path="/",
|
path="/",
|
||||||
)
|
)
|
||||||
|
# Return {"detail": "Logged out"}
|
||||||
return {"detail": "Logged out"}
|
return {"detail": "Logged out"}
|
||||||
|
|
||||||
|
|
||||||
@@ -207,25 +315,38 @@ def refresh_token(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserOut)
|
@router.get("/me", response_model=UserOut)
|
||||||
def read_current_user(current_user: User = Depends(get_current_user)):
|
# Define function read_current_user
|
||||||
|
def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut:
|
||||||
"""Return the profile of the currently authenticated user."""
|
"""Return the profile of the currently authenticated user."""
|
||||||
|
# Return current_user
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/change-password")
|
@router.post("/change-password")
|
||||||
|
# Define function change_password
|
||||||
def change_password(
|
def change_password(
|
||||||
|
# Entry: body
|
||||||
body: PasswordChange,
|
body: PasswordChange,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Change the current user's password."""
|
"""Change the current user's password."""
|
||||||
|
# Call auth_change_password()
|
||||||
auth_change_password(
|
auth_change_password(
|
||||||
db,
|
db,
|
||||||
current_user,
|
current_user,
|
||||||
|
# Keyword argument: current_password
|
||||||
current_password=body.current_password,
|
current_password=body.current_password,
|
||||||
|
# Keyword argument: new_password
|
||||||
new_password=body.new_password,
|
new_password=body.new_password,
|
||||||
)
|
)
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return {"detail": "Password changed successfully"}
|
||||||
return {"detail": "Password changed successfully"}
|
return {"detail": "Password changed successfully"}
|
||||||
|
|||||||
@@ -1,80 +1,169 @@
|
|||||||
"""Campaign endpoints — CRUD, test management, activation, and auto-generation.
|
"""Campaign endpoints — CRUD, test management, activation, and auto-generation.
|
||||||
|
|
||||||
Provides comprehensive campaign lifecycle management including
|
Provides comprehensive campaign lifecycle management including test ordering,
|
||||||
test ordering, progress tracking, and threat actor integration.
|
progress tracking, and threat actor integration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
# Import BaseModel, Field from pydantic
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.campaign import Campaign, CampaignTest
|
from app.models.campaign import Campaign, CampaignTest
|
||||||
from app.models.test import Test
|
from app.models.test import Test
|
||||||
from app.services.campaign_service import generate_campaign_from_threat_actor
|
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||||
from app.services.campaign_crud_service import (
|
from app.services.campaign_crud_service import (
|
||||||
add_test_to_campaign as crud_add_test,
|
add_test_to_campaign as crud_add_test,
|
||||||
activate_campaign as crud_activate,
|
)
|
||||||
|
|
||||||
|
# Import from app.services.campaign_crud_service
|
||||||
|
from app.services.campaign_crud_service import (
|
||||||
complete_campaign as crud_complete,
|
complete_campaign as crud_complete,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.campaign_crud_service
|
||||||
|
from app.services.campaign_crud_service import (
|
||||||
create_campaign as crud_create,
|
create_campaign as crud_create,
|
||||||
delete_campaign as crud_delete,
|
delete_campaign as crud_delete,
|
||||||
get_campaign_detail as crud_get_detail,
|
get_campaign_detail as crud_get_detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.campaign_crud_service
|
||||||
|
from app.services.campaign_crud_service import (
|
||||||
get_campaign_history as crud_get_history,
|
get_campaign_history as crud_get_history,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.campaign_crud_service
|
||||||
|
from app.services.campaign_crud_service import (
|
||||||
get_campaign_progress_data as crud_get_progress,
|
get_campaign_progress_data as crud_get_progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.campaign_crud_service
|
||||||
|
from app.services.campaign_crud_service import (
|
||||||
list_campaigns as crud_list,
|
list_campaigns as crud_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.campaign_crud_service
|
||||||
|
from app.services.campaign_crud_service import (
|
||||||
remove_test_from_campaign as crud_remove_test,
|
remove_test_from_campaign as crud_remove_test,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.campaign_crud_service
|
||||||
|
from app.services.campaign_crud_service import (
|
||||||
schedule_campaign as crud_schedule,
|
schedule_campaign as crud_schedule,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.campaign_crud_service
|
||||||
|
from app.services.campaign_crud_service import (
|
||||||
serialize_campaign,
|
serialize_campaign,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.campaign_crud_service
|
||||||
|
from app.services.campaign_crud_service import (
|
||||||
update_campaign as crud_update,
|
update_campaign as crud_update,
|
||||||
)
|
)
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
|
||||||
from app.services.audit_service import log_action
|
# Import generate_campaign_from_threat_actor from app.services.campaign_service
|
||||||
|
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||||
|
|
||||||
|
# Import notify_role from app.services.notification_service
|
||||||
from app.services.notification_service import notify_role
|
from app.services.notification_service import notify_role
|
||||||
from app.services.webhook_service import dispatch_webhook
|
from app.services.webhook_service import dispatch_webhook
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/campaigns", tags=["campaigns"])
|
||||||
router = APIRouter(prefix="/campaigns", tags=["campaigns"])
|
router = APIRouter(prefix="/campaigns", tags=["campaigns"])
|
||||||
|
|
||||||
|
|
||||||
# ── Pydantic schemas ─────────────────────────────────────────────────
|
# ── Pydantic schemas ─────────────────────────────────────────────────
|
||||||
|
|
||||||
class CampaignCreate(BaseModel):
|
class CampaignCreate(BaseModel):
|
||||||
|
"""Payload for creating a new campaign."""
|
||||||
|
|
||||||
|
# name: str
|
||||||
name: str
|
name: str
|
||||||
|
# Assign description = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
# Assign type = "custom"
|
||||||
type: str = "custom"
|
type: str = "custom"
|
||||||
|
# Assign threat_actor_id = None
|
||||||
threat_actor_id: Optional[str] = None
|
threat_actor_id: Optional[str] = None
|
||||||
|
# Assign target_platform = None
|
||||||
target_platform: Optional[str] = None
|
target_platform: Optional[str] = None
|
||||||
|
# Assign tags = Field(default_factory=list)
|
||||||
tags: Optional[list[str]] = Field(default_factory=list)
|
tags: Optional[list[str]] = Field(default_factory=list)
|
||||||
|
# Assign scheduled_at = None
|
||||||
scheduled_at: Optional[str] = None
|
scheduled_at: Optional[str] = None
|
||||||
start_date: Optional[str] = None # ISO date — campaign won't activate before this
|
start_date: Optional[str] = None # ISO date — campaign won't activate before this
|
||||||
|
|
||||||
|
|
||||||
|
# Define class CampaignUpdate
|
||||||
class CampaignUpdate(BaseModel):
|
class CampaignUpdate(BaseModel):
|
||||||
|
"""Payload for updating an existing campaign's metadata."""
|
||||||
|
|
||||||
|
# Assign name = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
# Assign description = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
# Assign type = None
|
||||||
type: Optional[str] = None
|
type: Optional[str] = None
|
||||||
|
# Assign target_platform = None
|
||||||
target_platform: Optional[str] = None
|
target_platform: Optional[str] = None
|
||||||
|
# Assign tags = None
|
||||||
tags: Optional[list[str]] = None
|
tags: Optional[list[str]] = None
|
||||||
|
# Assign scheduled_at = None
|
||||||
scheduled_at: Optional[str] = None
|
scheduled_at: Optional[str] = None
|
||||||
start_date: Optional[str] = None # ISO date — can be updated while still in draft
|
start_date: Optional[str] = None # ISO date — can be updated while still in draft
|
||||||
|
|
||||||
|
|
||||||
|
# Define class AddTestPayload
|
||||||
class AddTestPayload(BaseModel):
|
class AddTestPayload(BaseModel):
|
||||||
|
"""Payload for adding a test to a campaign."""
|
||||||
|
|
||||||
|
# test_id: str
|
||||||
test_id: str
|
test_id: str
|
||||||
|
# Assign order_index = None
|
||||||
order_index: Optional[int] = None
|
order_index: Optional[int] = None
|
||||||
|
# Assign depends_on = None
|
||||||
depends_on: Optional[str] = None
|
depends_on: Optional[str] = None
|
||||||
|
# Assign phase = None
|
||||||
phase: Optional[str] = None
|
phase: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Define class SchedulePayload
|
||||||
class SchedulePayload(BaseModel):
|
class SchedulePayload(BaseModel):
|
||||||
|
"""Payload for scheduling or rescheduling a campaign run."""
|
||||||
|
|
||||||
|
# is_recurring: bool
|
||||||
is_recurring: bool
|
is_recurring: bool
|
||||||
|
# Assign recurrence_pattern = None # weekly, monthly, quarterly
|
||||||
recurrence_pattern: Optional[str] = None # weekly, monthly, quarterly
|
recurrence_pattern: Optional[str] = None # weekly, monthly, quarterly
|
||||||
|
# Assign next_run_at = None
|
||||||
next_run_at: Optional[str] = None
|
next_run_at: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -83,24 +172,54 @@ class SchedulePayload(BaseModel):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
|
# Define function list_campaigns
|
||||||
def list_campaigns(
|
def list_campaigns(
|
||||||
|
# Entry: type
|
||||||
type: Optional[str] = Query(None),
|
type: Optional[str] = Query(None),
|
||||||
|
# Entry: status
|
||||||
status: Optional[str] = Query(None),
|
status: Optional[str] = Query(None),
|
||||||
|
# Entry: threat_actor_id
|
||||||
threat_actor_id: Optional[str] = Query(None),
|
threat_actor_id: Optional[str] = Query(None),
|
||||||
|
# Entry: search
|
||||||
search: Optional[str] = Query(None),
|
search: Optional[str] = Query(None),
|
||||||
|
# Entry: offset
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
|
# Entry: limit
|
||||||
limit: int = Query(50, ge=1, le=200),
|
limit: int = Query(50, ge=1, le=200),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""List campaigns with optional filters and pagination."""
|
"""List campaigns with optional filters and pagination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
type (Optional[str]): Filter by campaign type (e.g. ``custom``, ``threat_actor``).
|
||||||
|
status (Optional[str]): Filter by campaign status (e.g. ``draft``, ``active``).
|
||||||
|
threat_actor_id (Optional[str]): Filter campaigns linked to a specific threat actor.
|
||||||
|
search (Optional[str]): Free-text search against campaign name.
|
||||||
|
offset (int): Number of records to skip for pagination.
|
||||||
|
limit (int): Maximum number of records to return.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: Serialised list of campaign summary dicts.
|
||||||
|
"""
|
||||||
|
# Return crud_list(
|
||||||
return crud_list(
|
return crud_list(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: type
|
||||||
type=type,
|
type=type,
|
||||||
|
# Keyword argument: status
|
||||||
status=status,
|
status=status,
|
||||||
|
# Keyword argument: threat_actor_id
|
||||||
threat_actor_id=threat_actor_id,
|
threat_actor_id=threat_actor_id,
|
||||||
|
# Keyword argument: search
|
||||||
search=search,
|
search=search,
|
||||||
|
# Keyword argument: offset
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
# Keyword argument: limit
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -110,36 +229,64 @@ def list_campaigns(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.post("", status_code=201)
|
@router.post("", status_code=201)
|
||||||
|
# Define function create_campaign
|
||||||
def create_campaign(
|
def create_campaign(
|
||||||
|
# Entry: payload
|
||||||
payload: CampaignCreate,
|
payload: CampaignCreate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Create a new campaign."""
|
"""Create a new campaign.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload (CampaignCreate): Fields for the new campaign (name, type, threat actor, etc.).
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead creating the campaign.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Serialised representation of the newly created campaign.
|
||||||
|
"""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign result = crud_create(
|
||||||
result = crud_create(
|
result = crud_create(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: creator_id
|
||||||
creator_id=current_user.id,
|
creator_id=current_user.id,
|
||||||
|
# Keyword argument: name
|
||||||
name=payload.name,
|
name=payload.name,
|
||||||
|
# Keyword argument: description
|
||||||
description=payload.description,
|
description=payload.description,
|
||||||
|
# Keyword argument: type
|
||||||
type=payload.type,
|
type=payload.type,
|
||||||
|
# Keyword argument: threat_actor_id
|
||||||
threat_actor_id=payload.threat_actor_id,
|
threat_actor_id=payload.threat_actor_id,
|
||||||
|
# Keyword argument: target_platform
|
||||||
target_platform=payload.target_platform,
|
target_platform=payload.target_platform,
|
||||||
|
# Keyword argument: tags
|
||||||
tags=payload.tags,
|
tags=payload.tags,
|
||||||
|
# Keyword argument: scheduled_at
|
||||||
scheduled_at=payload.scheduled_at,
|
scheduled_at=payload.scheduled_at,
|
||||||
start_date=payload.start_date,
|
start_date=payload.start_date,
|
||||||
)
|
)
|
||||||
campaign_id = result["id"]
|
campaign_id = result["id"]
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="create_campaign",
|
action="create_campaign",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
entity_id=campaign_id,
|
entity_id=campaign_id,
|
||||||
details={"name": payload.name, "type": payload.type},
|
details={"name": payload.name, "type": payload.type},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -148,12 +295,26 @@ def create_campaign(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/{campaign_id}")
|
@router.get("/{campaign_id}")
|
||||||
|
# Define function get_campaign
|
||||||
def get_campaign(
|
def get_campaign(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get detailed campaign info including tests and progress."""
|
"""Get detailed campaign info including tests and progress.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
campaign_id (str): UUID string of the campaign to retrieve.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Campaign detail including associated tests and progress metrics.
|
||||||
|
"""
|
||||||
|
# Return crud_get_detail(db, campaign_id)
|
||||||
return crud_get_detail(db, campaign_id)
|
return crud_get_detail(db, campaign_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -162,32 +323,60 @@ def get_campaign(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.patch("/{campaign_id}")
|
@router.patch("/{campaign_id}")
|
||||||
|
# Define function update_campaign
|
||||||
def update_campaign(
|
def update_campaign(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
|
# Entry: payload
|
||||||
payload: CampaignUpdate,
|
payload: CampaignUpdate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Update a campaign. Only allowed in draft or active state."""
|
"""Update a campaign. Only allowed in draft or active state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
campaign_id (str): UUID string of the campaign to update.
|
||||||
|
payload (CampaignUpdate): Partial update payload; only set fields are applied.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead performing the update.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Serialised representation of the updated campaign.
|
||||||
|
"""
|
||||||
|
# Assign update_data = payload.model_dump(exclude_unset=True)
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign result = crud_update(
|
||||||
result = crud_update(
|
result = crud_update(
|
||||||
db,
|
db,
|
||||||
campaign_id,
|
campaign_id,
|
||||||
|
# Keyword argument: updater_id
|
||||||
updater_id=current_user.id,
|
updater_id=current_user.id,
|
||||||
|
# Keyword argument: updater_role
|
||||||
updater_role=current_user.role,
|
updater_role=current_user.role,
|
||||||
**update_data,
|
**update_data,
|
||||||
)
|
)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="update_campaign",
|
action="update_campaign",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=campaign_id,
|
entity_id=campaign_id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"updated_fields": list(update_data.keys())},
|
details={"updated_fields": list(update_data.keys())},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -227,22 +416,44 @@ def delete_campaign(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.post("/{campaign_id}/tests")
|
@router.post("/{campaign_id}/tests")
|
||||||
|
# Define function add_test_to_campaign
|
||||||
def add_test_to_campaign(
|
def add_test_to_campaign(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
|
# Entry: payload
|
||||||
payload: AddTestPayload,
|
payload: AddTestPayload,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Add a test to a campaign with optional ordering and dependency."""
|
"""Add a test to a campaign with optional ordering and dependency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
campaign_id (str): UUID string of the target campaign.
|
||||||
|
payload (AddTestPayload): Test ID plus optional order index, dependency, and phase.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead adding the test.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The created campaign-test association record.
|
||||||
|
"""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign result = crud_add_test(
|
||||||
result = crud_add_test(
|
result = crud_add_test(
|
||||||
db,
|
db,
|
||||||
campaign_id,
|
campaign_id,
|
||||||
|
# Keyword argument: test_id
|
||||||
test_id=payload.test_id,
|
test_id=payload.test_id,
|
||||||
|
# Keyword argument: order_index
|
||||||
order_index=payload.order_index,
|
order_index=payload.order_index,
|
||||||
|
# Keyword argument: depends_on
|
||||||
depends_on=payload.depends_on,
|
depends_on=payload.depends_on,
|
||||||
|
# Keyword argument: phase
|
||||||
phase=payload.phase,
|
phase=payload.phase,
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -253,16 +464,35 @@ def add_test_to_campaign(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.delete("/{campaign_id}/tests/{campaign_test_id}")
|
@router.delete("/{campaign_id}/tests/{campaign_test_id}")
|
||||||
|
# Define function remove_test_from_campaign
|
||||||
def remove_test_from_campaign(
|
def remove_test_from_campaign(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
|
# Entry: campaign_test_id
|
||||||
campaign_test_id: str,
|
campaign_test_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Remove a test from a campaign."""
|
"""Remove a test from a campaign.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
campaign_id (str): UUID string of the campaign.
|
||||||
|
campaign_test_id (str): UUID string of the campaign-test association to remove.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead removing the test.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Confirmation message with key ``detail``.
|
||||||
|
"""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call crud_remove_test()
|
||||||
crud_remove_test(db, campaign_id, campaign_test_id)
|
crud_remove_test(db, campaign_id, campaign_test_id)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Return {"detail": "Test removed from campaign"}
|
||||||
return {"detail": "Test removed from campaign"}
|
return {"detail": "Test removed from campaign"}
|
||||||
|
|
||||||
|
|
||||||
@@ -271,10 +501,13 @@ def remove_test_from_campaign(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.post("/{campaign_id}/activate")
|
@router.post("/{campaign_id}/activate")
|
||||||
|
# Define function activate_campaign
|
||||||
def activate_campaign(
|
def activate_campaign(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
force: bool = Query(False, description="Activate even if start_date is in the future"),
|
force: bool = Query(False, description="Activate even if start_date is in the future"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Activate a campaign, moving it from draft to active.
|
"""Activate a campaign, moving it from draft to active.
|
||||||
@@ -303,25 +536,41 @@ def activate_campaign(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign campaign = crud_activate(db, campaign_id)
|
||||||
campaign = crud_activate(db, campaign_id)
|
campaign = crud_activate(db, campaign_id)
|
||||||
|
# Call notify_role()
|
||||||
notify_role(
|
notify_role(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: role
|
||||||
role="red_tech",
|
role="red_tech",
|
||||||
|
# Keyword argument: type
|
||||||
type="campaign_activated",
|
type="campaign_activated",
|
||||||
|
# Keyword argument: title
|
||||||
title="Campaign activated",
|
title="Campaign activated",
|
||||||
|
# Keyword argument: message
|
||||||
message=f'Campaign "{campaign.name}" has been activated.',
|
message=f'Campaign "{campaign.name}" has been activated.',
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=campaign.id,
|
entity_id=campaign.id,
|
||||||
)
|
)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="activate_campaign",
|
action="activate_campaign",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=campaign.id,
|
entity_id=campaign.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"name": campaign.name},
|
details={"name": campaign.name},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
# Create Jira tickets for campaign and tests at activation time (non-fatal).
|
# Create Jira tickets for campaign and tests at activation time (non-fatal).
|
||||||
@@ -359,26 +608,50 @@ def activate_campaign(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.post("/{campaign_id}/complete")
|
@router.post("/{campaign_id}/complete")
|
||||||
|
# Define function complete_campaign
|
||||||
def complete_campaign(
|
def complete_campaign(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Mark a campaign as completed."""
|
"""Mark a campaign as completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
campaign_id (str): UUID string of the campaign to complete.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or admin completing the campaign.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Serialised representation of the completed campaign.
|
||||||
|
"""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign campaign = crud_complete(db, campaign_id)
|
||||||
campaign = crud_complete(db, campaign_id)
|
campaign = crud_complete(db, campaign_id)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="complete_campaign",
|
action="complete_campaign",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=campaign.id,
|
entity_id=campaign.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"name": campaign.name},
|
details={"name": campaign.name},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
dispatch_webhook("campaign.completed", {"campaign_id": str(campaign.id), "name": campaign.name})
|
dispatch_webhook("campaign.completed", {"campaign_id": str(campaign.id), "name": campaign.name})
|
||||||
|
|
||||||
|
# Return serialize_campaign(db, campaign)
|
||||||
return serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
@@ -387,12 +660,26 @@ def complete_campaign(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/{campaign_id}/progress")
|
@router.get("/{campaign_id}/progress")
|
||||||
|
# Define function get_campaign_progress_endpoint
|
||||||
def get_campaign_progress_endpoint(
|
def get_campaign_progress_endpoint(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get progress statistics for a campaign."""
|
"""Get progress statistics for a campaign.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
campaign_id (str): UUID string of the campaign.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Progress breakdown including counts by test state and overall percentage.
|
||||||
|
"""
|
||||||
|
# Return crud_get_progress(db, campaign_id)
|
||||||
return crud_get_progress(db, campaign_id)
|
return crud_get_progress(db, campaign_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -405,16 +692,27 @@ class GenerateFromActorPayload(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/from-threat-actor/{actor_id}", status_code=201)
|
@router.post("/from-threat-actor/{actor_id}", status_code=201)
|
||||||
|
# Define function generate_campaign_from_actor
|
||||||
def generate_campaign_from_actor(
|
def generate_campaign_from_actor(
|
||||||
|
# Entry: actor_id
|
||||||
actor_id: str,
|
actor_id: str,
|
||||||
payload: GenerateFromActorPayload = GenerateFromActorPayload(),
|
payload: GenerateFromActorPayload = GenerateFromActorPayload(),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Auto-generate a campaign from a threat actor's uncovered techniques.
|
"""Auto-generate a campaign from a threat actor's uncovered techniques.
|
||||||
|
|
||||||
Creates tests from the best available templates and orders them
|
Creates tests from the best available templates and orders them
|
||||||
by kill chain phase.
|
by kill chain phase.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor_id (str): UUID string of the threat actor to generate a campaign for.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead requesting the generation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Serialised representation of the newly generated campaign.
|
||||||
"""
|
"""
|
||||||
start_date_parsed = (
|
start_date_parsed = (
|
||||||
datetime.fromisoformat(payload.start_date) if payload.start_date else None
|
datetime.fromisoformat(payload.start_date) if payload.start_date else None
|
||||||
@@ -426,17 +724,26 @@ def generate_campaign_from_actor(
|
|||||||
start_date=start_date_parsed,
|
start_date=start_date_parsed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="generate_campaign",
|
action="generate_campaign",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=campaign.id,
|
entity_id=campaign.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return serialize_campaign(db, campaign)
|
||||||
return serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
@@ -445,41 +752,74 @@ def generate_campaign_from_actor(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.patch("/{campaign_id}/schedule")
|
@router.patch("/{campaign_id}/schedule")
|
||||||
|
# Define function schedule_campaign
|
||||||
def schedule_campaign(
|
def schedule_campaign(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
|
# Entry: payload
|
||||||
payload: SchedulePayload,
|
payload: SchedulePayload,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Configure or update the recurrence schedule for a campaign.
|
"""Configure or update the recurrence schedule for a campaign.
|
||||||
|
|
||||||
Only the campaign creator or admin can change scheduling.
|
Only the campaign creator or admin can change scheduling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
campaign_id (str): UUID string of the campaign to schedule.
|
||||||
|
payload (SchedulePayload): Recurrence flag, pattern, and next run timestamp.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead (must be owner or admin).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Serialised representation of the campaign with updated schedule fields.
|
||||||
"""
|
"""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign campaign = crud_schedule(
|
||||||
campaign = crud_schedule(
|
campaign = crud_schedule(
|
||||||
db,
|
db,
|
||||||
campaign_id,
|
campaign_id,
|
||||||
|
# Keyword argument: owner_id
|
||||||
owner_id=current_user.id,
|
owner_id=current_user.id,
|
||||||
|
# Keyword argument: owner_role
|
||||||
owner_role=current_user.role,
|
owner_role=current_user.role,
|
||||||
|
# Keyword argument: is_recurring
|
||||||
is_recurring=payload.is_recurring,
|
is_recurring=payload.is_recurring,
|
||||||
|
# Keyword argument: recurrence_pattern
|
||||||
recurrence_pattern=payload.recurrence_pattern,
|
recurrence_pattern=payload.recurrence_pattern,
|
||||||
|
# Keyword argument: next_run_at
|
||||||
next_run_at=payload.next_run_at,
|
next_run_at=payload.next_run_at,
|
||||||
)
|
)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="schedule_campaign",
|
action="schedule_campaign",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=campaign.id,
|
entity_id=campaign.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={
|
details={
|
||||||
|
# Literal argument value
|
||||||
"is_recurring": campaign.is_recurring,
|
"is_recurring": campaign.is_recurring,
|
||||||
|
# Literal argument value
|
||||||
"recurrence_pattern": campaign.recurrence_pattern,
|
"recurrence_pattern": campaign.recurrence_pattern,
|
||||||
|
# Literal argument value
|
||||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
|
# Return serialize_campaign(db, campaign)
|
||||||
return serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
@@ -488,12 +828,26 @@ def schedule_campaign(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/{campaign_id}/history")
|
@router.get("/{campaign_id}/history")
|
||||||
|
# Define function get_campaign_history
|
||||||
def get_campaign_history(
|
def get_campaign_history(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""List all child campaigns (execution history) of a recurring campaign."""
|
"""List all child campaigns (execution history) of a recurring campaign.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
campaign_id (str): UUID string of the parent recurring campaign.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: Serialised list of child campaign dicts ordered by creation date.
|
||||||
|
"""
|
||||||
|
# Return crud_get_history(db, campaign_id)
|
||||||
return crud_get_history(db, campaign_id)
|
return crud_get_history(db, campaign_id)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,32 +1,45 @@
|
|||||||
"""Compliance endpoints — framework status, reports, and gap analysis.
|
"""Compliance endpoints — framework status, reports, and gap analysis.
|
||||||
|
|
||||||
Thin HTTP adapter: delegates all data logic to compliance_service.
|
Thin HTTP adapter that delegates all data logic to compliance_service.
|
||||||
|
|
||||||
Provides compliance posture assessment by mapping MITRE ATT&CK technique
|
Provides compliance posture assessment by mapping MITRE ATT&CK technique
|
||||||
coverage to compliance framework controls.
|
coverage to compliance framework controls.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import APIRouter, Depends from fastapi
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
# Import StreamingResponse from fastapi.responses
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.compliance_service import (
|
|
||||||
list_frameworks,
|
# Import from app.services.compliance_import_service
|
||||||
get_framework_status,
|
|
||||||
build_framework_report_csv,
|
|
||||||
get_framework_gaps,
|
|
||||||
)
|
|
||||||
from app.services.compliance_import_service import (
|
from app.services.compliance_import_service import (
|
||||||
import_nist_800_53_mappings,
|
|
||||||
import_cis_controls_v8_mappings,
|
import_cis_controls_v8_mappings,
|
||||||
import_dora_mappings,
|
import_dora_mappings,
|
||||||
import_iso_27001_mappings,
|
import_iso_27001_mappings,
|
||||||
import_iso_42001_mappings,
|
import_iso_42001_mappings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import from app.services.compliance_service
|
||||||
|
from app.services.compliance_service import (
|
||||||
|
build_framework_report_csv,
|
||||||
|
get_framework_gaps,
|
||||||
|
get_framework_status,
|
||||||
|
list_frameworks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||||
router = APIRouter(prefix="/compliance", tags=["compliance"])
|
router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||||
|
|
||||||
|
|
||||||
@@ -34,11 +47,23 @@ router = APIRouter(prefix="/compliance", tags=["compliance"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/frameworks")
|
@router.get("/frameworks")
|
||||||
|
# Define function list_frameworks_endpoint
|
||||||
def list_frameworks_endpoint(
|
def list_frameworks_endpoint(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""List all available compliance frameworks."""
|
"""List all available compliance frameworks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of framework summary dicts containing id, name, and control counts.
|
||||||
|
"""
|
||||||
|
# Return list_frameworks(db)
|
||||||
return list_frameworks(db)
|
return list_frameworks(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -46,12 +71,26 @@ def list_frameworks_endpoint(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/frameworks/{framework_id}/status")
|
@router.get("/frameworks/{framework_id}/status")
|
||||||
|
# Define function framework_status
|
||||||
def framework_status(
|
def framework_status(
|
||||||
|
# Entry: framework_id
|
||||||
framework_id: str,
|
framework_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get compliance status for each control in a framework."""
|
"""Get compliance status for each control in a framework.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
framework_id (str): Identifier of the compliance framework (e.g. ``nist-800-53``).
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Mapping of control IDs to their coverage status and linked techniques.
|
||||||
|
"""
|
||||||
|
# Return get_framework_status(db, framework_id)
|
||||||
return get_framework_status(db, framework_id)
|
return get_framework_status(db, framework_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -59,12 +98,26 @@ def framework_status(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/frameworks/{framework_id}/report")
|
@router.get("/frameworks/{framework_id}/report")
|
||||||
|
# Define function framework_report
|
||||||
def framework_report(
|
def framework_report(
|
||||||
|
# Entry: framework_id
|
||||||
framework_id: str,
|
framework_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get the full compliance report (same as status but marked as report)."""
|
"""Get the full compliance report (same as status but marked as report).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
framework_id (str): Identifier of the compliance framework.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Full compliance report with per-control coverage details.
|
||||||
|
"""
|
||||||
|
# Return get_framework_status(db, framework_id)
|
||||||
return get_framework_status(db, framework_id)
|
return get_framework_status(db, framework_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -72,17 +125,35 @@ def framework_report(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/frameworks/{framework_id}/report/csv")
|
@router.get("/frameworks/{framework_id}/report/csv")
|
||||||
|
# Define function framework_report_csv
|
||||||
def framework_report_csv(
|
def framework_report_csv(
|
||||||
|
# Entry: framework_id
|
||||||
framework_id: str,
|
framework_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> StreamingResponse:
|
||||||
"""Export compliance report as CSV."""
|
"""Export compliance report as CSV.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
framework_id (str): Identifier of the compliance framework to export.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamingResponse: CSV file attachment with compliance coverage data.
|
||||||
|
"""
|
||||||
|
# csv_bytes, filename = build_framework_report_csv(db, framework_id)
|
||||||
csv_bytes, filename = build_framework_report_csv(db, framework_id)
|
csv_bytes, filename = build_framework_report_csv(db, framework_id)
|
||||||
|
# Return StreamingResponse(
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
iter([csv_bytes]),
|
iter([csv_bytes]),
|
||||||
|
# Keyword argument: media_type
|
||||||
media_type="text/csv",
|
media_type="text/csv",
|
||||||
|
# Keyword argument: headers
|
||||||
headers={
|
headers={
|
||||||
|
# Literal argument value
|
||||||
"Content-Disposition": f"attachment; filename={filename}",
|
"Content-Disposition": f"attachment; filename={filename}",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -92,12 +163,26 @@ def framework_report_csv(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/frameworks/{framework_id}/gaps")
|
@router.get("/frameworks/{framework_id}/gaps")
|
||||||
|
# Define function framework_gaps
|
||||||
def framework_gaps(
|
def framework_gaps(
|
||||||
|
# Entry: framework_id
|
||||||
framework_id: str,
|
framework_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get controls with techniques that are not adequately covered."""
|
"""Get controls with techniques that are not adequately covered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
framework_id (str): Identifier of the compliance framework to analyse.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Controls flagged as gaps, with linked technique IDs and coverage ratios.
|
||||||
|
"""
|
||||||
|
# Return get_framework_gaps(db, framework_id)
|
||||||
return get_framework_gaps(db, framework_id)
|
return get_framework_gaps(db, framework_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -105,22 +190,49 @@ def framework_gaps(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/import/nist-800-53")
|
@router.post("/import/nist-800-53")
|
||||||
|
# Define function import_nist
|
||||||
def import_nist(
|
def import_nist(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Import NIST 800-53 Rev 5 mappings (admin only)."""
|
"""Import NIST 800-53 Rev 5 mappings (admin only).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated admin user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Import result with counts of created and updated control mappings.
|
||||||
|
"""
|
||||||
|
# Assign result = import_nist_800_53_mappings(db)
|
||||||
result = import_nist_800_53_mappings(db)
|
result = import_nist_800_53_mappings(db)
|
||||||
|
# Return result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/import/cis-controls-v8")
|
@router.post("/import/cis-controls-v8")
|
||||||
|
# Define function import_cis
|
||||||
def import_cis(
|
def import_cis(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Import CIS Controls v8 mappings (admin only)."""
|
"""Import CIS Controls v8 mappings (admin only).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated admin user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Import result with counts of created and updated control mappings.
|
||||||
|
"""
|
||||||
|
# Assign result = import_cis_controls_v8_mappings(db)
|
||||||
result = import_cis_controls_v8_mappings(db)
|
result = import_cis_controls_v8_mappings(db)
|
||||||
|
# Return result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,26 +1,47 @@
|
|||||||
"""D3FEND endpoints — defensive technique listings, mappings, and import trigger."""
|
"""D3FEND endpoints — defensive technique listings, mappings, and import trigger."""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.services.d3fend_import_service
|
||||||
from app.services.d3fend_import_service import (
|
from app.services.d3fend_import_service import (
|
||||||
import_d3fend_techniques,
|
|
||||||
import_d3fend_mappings,
|
import_d3fend_mappings,
|
||||||
|
import_d3fend_techniques,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import from app.services.d3fend_query_service
|
||||||
|
from app.services.d3fend_query_service import (
|
||||||
|
get_defenses_for_attack_technique,
|
||||||
|
list_d3fend_tactics,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.d3fend_query_service
|
||||||
from app.services.d3fend_query_service import (
|
from app.services.d3fend_query_service import (
|
||||||
list_defensive_techniques as list_defensive_techniques_svc,
|
list_defensive_techniques as list_defensive_techniques_svc,
|
||||||
list_d3fend_tactics,
|
|
||||||
get_defenses_for_attack_technique,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
||||||
router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
||||||
|
|
||||||
|
|
||||||
@@ -29,15 +50,23 @@ router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
|
# Define function list_defensive_techniques
|
||||||
def list_defensive_techniques(
|
def list_defensive_techniques(
|
||||||
|
# Entry: tactic
|
||||||
tactic: Optional[str] = Query(None),
|
tactic: Optional[str] = Query(None),
|
||||||
|
# Entry: search
|
||||||
search: Optional[str] = Query(None),
|
search: Optional[str] = Query(None),
|
||||||
|
# Entry: offset
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
|
# Entry: limit
|
||||||
limit: int = Query(50, ge=1, le=200),
|
limit: int = Query(50, ge=1, le=200),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""List all D3FEND defensive techniques with optional filters."""
|
"""List all D3FEND defensive techniques with optional filters."""
|
||||||
|
# Return list_defensive_techniques_svc(
|
||||||
return list_defensive_techniques_svc(
|
return list_defensive_techniques_svc(
|
||||||
db, tactic=tactic, search=search, offset=offset, limit=limit
|
db, tactic=tactic, search=search, offset=offset, limit=limit
|
||||||
)
|
)
|
||||||
@@ -48,11 +77,15 @@ def list_defensive_techniques(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/tactics")
|
@router.get("/tactics")
|
||||||
|
# Define function list_d3fend_tactics_endpoint
|
||||||
def list_d3fend_tactics_endpoint(
|
def list_d3fend_tactics_endpoint(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Return a list of all D3FEND tactics with counts."""
|
"""Return a list of all D3FEND tactics with counts."""
|
||||||
|
# Return list_d3fend_tactics(db)
|
||||||
return list_d3fend_tactics(db)
|
return list_d3fend_tactics(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -61,12 +94,17 @@ def list_d3fend_tactics_endpoint(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/for-technique/{mitre_id}")
|
@router.get("/for-technique/{mitre_id}")
|
||||||
|
# Define function get_defenses_for_attack_technique_endpoint
|
||||||
def get_defenses_for_attack_technique_endpoint(
|
def get_defenses_for_attack_technique_endpoint(
|
||||||
|
# Entry: mitre_id
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||||
|
# Return get_defenses_for_attack_technique(db, mitre_id)
|
||||||
return get_defenses_for_attack_technique(db, mitre_id)
|
return get_defenses_for_attack_technique(db, mitre_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,15 +113,23 @@ def get_defenses_for_attack_technique_endpoint(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.post("/import")
|
@router.post("/import")
|
||||||
|
# Define function trigger_d3fend_import
|
||||||
def trigger_d3fend_import(
|
def trigger_d3fend_import(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Import D3FEND techniques and ATT&CK mappings. Admin only."""
|
"""Import D3FEND techniques and ATT&CK mappings. Admin only."""
|
||||||
|
# Assign tech_result = import_d3fend_techniques(db)
|
||||||
tech_result = import_d3fend_techniques(db)
|
tech_result = import_d3fend_techniques(db)
|
||||||
|
# Assign mapping_result = import_d3fend_mappings(db)
|
||||||
mapping_result = import_d3fend_mappings(db)
|
mapping_result = import_d3fend_mappings(db)
|
||||||
|
|
||||||
|
# Return {
|
||||||
return {
|
return {
|
||||||
|
# Literal argument value
|
||||||
"techniques": tech_result,
|
"techniques": tech_result,
|
||||||
|
# Literal argument value
|
||||||
"mappings": mapping_result,
|
"mappings": mapping_result,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,16 +5,34 @@ Provides a centralized panel for managing all external data sources
|
|||||||
including sync triggers, enable/disable toggles, and statistics.
|
including sync triggers, enable/disable toggles, and statistics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
# Import Optional from typing
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends from fastapi
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
# Import BaseModel from pydantic
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import require_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import require_role
|
from app.dependencies.auth import require_role
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import log_action from app.services.audit_service
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
|
# Import from app.services.data_source_service
|
||||||
from app.services.data_source_service import (
|
from app.services.data_source_service import (
|
||||||
get_source_stats,
|
get_source_stats,
|
||||||
list_sources,
|
list_sources,
|
||||||
@@ -23,18 +41,21 @@ from app.services.data_source_service import (
|
|||||||
update_source,
|
update_source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Pydantic schemas for request validation
|
# Pydantic schemas for request validation
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
class DataSourceUpdate(BaseModel):
|
class DataSourceUpdate(BaseModel):
|
||||||
"""Payload for updating a data source — only allowed fields."""
|
"""Payload for updating a data source — only allowed fields."""
|
||||||
|
# Assign is_enabled = None
|
||||||
is_enabled: Optional[bool] = None
|
is_enabled: Optional[bool] = None
|
||||||
|
# Assign sync_frequency = None
|
||||||
sync_frequency: Optional[str] = None
|
sync_frequency: Optional[str] = None
|
||||||
|
# Assign config = None
|
||||||
config: Optional[dict] = None
|
config: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||||
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||||
|
|
||||||
|
|
||||||
@@ -44,90 +65,137 @@ router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
|
# Define function list_data_sources
|
||||||
def list_data_sources(
|
def list_data_sources(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> list:
|
||||||
"""List all registered data sources.
|
"""List all registered data sources.
|
||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
|
# Return list_sources(db)
|
||||||
return list_sources(db)
|
return list_sources(db)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.patch decorator
|
||||||
@router.patch("/{source_id}")
|
@router.patch("/{source_id}")
|
||||||
|
# Define function update_data_source
|
||||||
def update_data_source(
|
def update_data_source(
|
||||||
|
# Entry: source_id
|
||||||
source_id: str,
|
source_id: str,
|
||||||
|
# Entry: body
|
||||||
body: DataSourceUpdate,
|
body: DataSourceUpdate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Update a data source (enable/disable, change config).
|
"""Update a data source (enable/disable, change config).
|
||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
|
# Assign update_data = body.model_dump(exclude_unset=True)
|
||||||
update_data = body.model_dump(exclude_unset=True)
|
update_data = body.model_dump(exclude_unset=True)
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call update_source()
|
||||||
update_source(db, source_id, **update_data)
|
update_source(db, source_id, **update_data)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="update_data_source",
|
action="update_data_source",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="data_source",
|
entity_type="data_source",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=source_id,
|
entity_id=source_id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"updates": update_data},
|
details={"updates": update_data},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return {"message": "Data source updated", "id": source_id}
|
||||||
return {"message": "Data source updated", "id": source_id}
|
return {"message": "Data source updated", "id": source_id}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/{source_id}/sync")
|
@router.post("/{source_id}/sync")
|
||||||
|
# Define function sync_data_source
|
||||||
def sync_data_source(
|
def sync_data_source(
|
||||||
|
# Entry: source_id
|
||||||
source_id: str,
|
source_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Trigger sync/import for a specific data source.
|
"""Trigger sync/import for a specific data source.
|
||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
|
# Return sync_source(db, source_id)
|
||||||
return sync_source(db, source_id)
|
return sync_source(db, source_id)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/sync-all")
|
@router.post("/sync-all")
|
||||||
|
# Define function sync_all_data_sources
|
||||||
def sync_all_data_sources(
|
def sync_all_data_sources(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Trigger sync for all enabled data sources (sequentially).
|
"""Trigger sync for all enabled data sources (sequentially).
|
||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
|
# Assign results = sync_all_sources(db)
|
||||||
results = sync_all_sources(db)
|
results = sync_all_sources(db)
|
||||||
|
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="sync_all_data_sources",
|
action="sync_all_data_sources",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="data_source",
|
entity_type="data_source",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=None,
|
entity_id=None,
|
||||||
|
# Keyword argument: details
|
||||||
details={"results": results},
|
details={"results": results},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return {"message": "Sync all complete", "results": results}
|
||||||
return {"message": "Sync all complete", "results": results}
|
return {"message": "Sync all complete", "results": results}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/{source_id}/stats")
|
@router.get("/{source_id}/stats")
|
||||||
|
# Define function get_data_source_stats
|
||||||
def get_data_source_stats(
|
def get_data_source_stats(
|
||||||
|
# Entry: source_id
|
||||||
source_id: str,
|
source_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Get detailed statistics for a specific data source.
|
"""Get detailed statistics for a specific data source.
|
||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
|
# Return get_source_stats(db, source_id)
|
||||||
return get_source_stats(db, source_id)
|
return get_source_stats(db, source_id)
|
||||||
|
|||||||
@@ -6,36 +6,55 @@ Provides endpoints for browsing detection rules, querying rules by technique,
|
|||||||
and managing the template ↔ detection rule associations.
|
and managing the template ↔ detection rule associations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import BaseModel from pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
|
||||||
from app.models.user import User
|
|
||||||
from app.services.detection_rule_service import (
|
|
||||||
list_rules,
|
|
||||||
get_rules_for_template,
|
|
||||||
auto_associate_rules,
|
|
||||||
get_rules_for_test,
|
|
||||||
evaluate_rule,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
|
||||||
|
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.services.detection_rule_service
|
||||||
|
from app.services.detection_rule_service import (
|
||||||
|
auto_associate_rules,
|
||||||
|
evaluate_rule,
|
||||||
|
get_rules_for_template,
|
||||||
|
get_rules_for_test,
|
||||||
|
list_rules,
|
||||||
|
)
|
||||||
|
|
||||||
# ── Pydantic schemas for request validation ────────────────────────────
|
# ── Pydantic schemas for request validation ────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class DetectionRuleEvaluate(BaseModel):
|
class DetectionRuleEvaluate(BaseModel):
|
||||||
"""Payload for evaluating a detection rule against a test."""
|
"""Payload for evaluating a detection rule against a test."""
|
||||||
|
# test_id: uuid.UUID
|
||||||
test_id: uuid.UUID
|
test_id: uuid.UUID
|
||||||
|
# detection_rule_id: uuid.UUID
|
||||||
detection_rule_id: uuid.UUID
|
detection_rule_id: uuid.UUID
|
||||||
|
# Assign triggered = None
|
||||||
triggered: Optional[bool] = None
|
triggered: Optional[bool] = None
|
||||||
|
# Assign notes = None
|
||||||
notes: Optional[str] = None
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
||||||
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
||||||
|
|
||||||
|
|
||||||
@@ -43,24 +62,40 @@ router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
|
# Define function list_detection_rules
|
||||||
def list_detection_rules(
|
def list_detection_rules(
|
||||||
|
# Entry: technique
|
||||||
technique: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
|
technique: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
|
||||||
|
# Entry: source
|
||||||
source: Optional[str] = Query(None, description="Filter by source (sigma, elastic, splunk, custom)"),
|
source: Optional[str] = Query(None, description="Filter by source (sigma, elastic, splunk, custom)"),
|
||||||
|
# Entry: severity
|
||||||
severity: Optional[str] = Query(None),
|
severity: Optional[str] = Query(None),
|
||||||
|
# Entry: search
|
||||||
search: Optional[str] = Query(None),
|
search: Optional[str] = Query(None),
|
||||||
|
# Entry: offset
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
|
# Entry: limit
|
||||||
limit: int = Query(50, ge=1, le=200),
|
limit: int = Query(50, ge=1, le=200),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""List detection rules with optional filters and pagination."""
|
"""List detection rules with optional filters and pagination."""
|
||||||
|
# Return list_rules(
|
||||||
return list_rules(
|
return list_rules(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: technique
|
||||||
technique=technique,
|
technique=technique,
|
||||||
|
# Keyword argument: source
|
||||||
source=source,
|
source=source,
|
||||||
|
# Keyword argument: severity
|
||||||
severity=severity,
|
severity=severity,
|
||||||
|
# Keyword argument: search
|
||||||
search=search,
|
search=search,
|
||||||
|
# Keyword argument: offset
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
# Keyword argument: limit
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -69,12 +104,17 @@ def list_detection_rules(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/for-template/{template_id}")
|
@router.get("/for-template/{template_id}")
|
||||||
|
# Define function get_detection_rules_for_template
|
||||||
def get_detection_rules_for_template(
|
def get_detection_rules_for_template(
|
||||||
|
# Entry: template_id
|
||||||
template_id: str,
|
template_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Get detection rules associated with a test template."""
|
"""Get detection rules associated with a test template."""
|
||||||
|
# Return get_rules_for_template(db, template_id)
|
||||||
return get_rules_for_template(db, template_id)
|
return get_rules_for_template(db, template_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -82,16 +122,20 @@ def get_detection_rules_for_template(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/auto-associate")
|
@router.post("/auto-associate")
|
||||||
|
# Define function auto_associate_detection_rules
|
||||||
def auto_associate_detection_rules(
|
def auto_associate_detection_rules(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Auto-associate test templates with detection rules by MITRE technique ID.
|
"""Auto-associate test templates with detection rules by MITRE technique ID.
|
||||||
|
|
||||||
For each active template, find all active detection rules for the same
|
For each active template, find all active detection rules for the same
|
||||||
technique and create associations. Rules with severity >= high are marked
|
technique and create associations. Rules with severity >= high are marked
|
||||||
as primary.
|
as primary.
|
||||||
"""
|
"""
|
||||||
|
# Return auto_associate_rules(db)
|
||||||
return auto_associate_rules(db)
|
return auto_associate_rules(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -99,16 +143,21 @@ def auto_associate_detection_rules(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/for-test/{test_id}")
|
@router.get("/for-test/{test_id}")
|
||||||
|
# Define function get_detection_rules_for_test
|
||||||
def get_detection_rules_for_test(
|
def get_detection_rules_for_test(
|
||||||
|
# Entry: test_id
|
||||||
test_id: str,
|
test_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Get detection rules relevant to a test, along with their evaluation results.
|
"""Get detection rules relevant to a test, along with their evaluation results.
|
||||||
|
|
||||||
Finds rules by matching the test's technique_id to detection rules,
|
Finds rules by matching the test's technique_id to detection rules,
|
||||||
and returns any existing evaluation results.
|
and returns any existing evaluation results.
|
||||||
"""
|
"""
|
||||||
|
# Return get_rules_for_test(db, test_id)
|
||||||
return get_rules_for_test(db, test_id)
|
return get_rules_for_test(db, test_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -116,17 +165,27 @@ def get_detection_rules_for_test(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/evaluate")
|
@router.post("/evaluate")
|
||||||
|
# Define function evaluate_detection_rule
|
||||||
def evaluate_detection_rule(
|
def evaluate_detection_rule(
|
||||||
|
# Entry: payload
|
||||||
payload: DetectionRuleEvaluate,
|
payload: DetectionRuleEvaluate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Save or update the evaluation result for a detection rule on a test."""
|
"""Save or update the evaluation result for a detection rule on a test."""
|
||||||
|
# Return evaluate_rule(
|
||||||
return evaluate_rule(
|
return evaluate_rule(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: test_id
|
||||||
test_id=payload.test_id,
|
test_id=payload.test_id,
|
||||||
|
# Keyword argument: detection_rule_id
|
||||||
detection_rule_id=payload.detection_rule_id,
|
detection_rule_id=payload.detection_rule_id,
|
||||||
|
# Keyword argument: triggered
|
||||||
triggered=payload.triggered,
|
triggered=payload.triggered,
|
||||||
|
# Keyword argument: notes
|
||||||
notes=payload.notes,
|
notes=payload.notes,
|
||||||
|
# Keyword argument: evaluator_id
|
||||||
evaluator_id=current_user.id,
|
evaluator_id=current_user.id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,30 +20,54 @@ Access Control
|
|||||||
``validated``, or ``rejected``.
|
``validated``, or ``rejected``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import hashlib
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, File, Form, Query, Request,... from fastapi
|
||||||
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
|
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
|
||||||
|
# Import get_current_user from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import limiter from app.limiter
|
||||||
|
from app.limiter import limiter
|
||||||
|
|
||||||
|
# Import TeamSide from app.models.enums
|
||||||
from app.models.enums import TeamSide
|
from app.models.enums import TeamSide
|
||||||
|
|
||||||
|
# Import Evidence from app.models.evidence
|
||||||
from app.models.evidence import Evidence
|
from app.models.evidence import Evidence
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import EvidenceOut from app.schemas.evidence
|
||||||
from app.schemas.evidence import EvidenceOut
|
from app.schemas.evidence import EvidenceOut
|
||||||
|
|
||||||
|
# Import log_action from app.services.audit_service
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
|
# Import from app.services.evidence_service
|
||||||
from app.services.evidence_service import (
|
from app.services.evidence_service import (
|
||||||
|
MAX_UPLOAD_SIZE,
|
||||||
get_evidence_or_raise,
|
get_evidence_or_raise,
|
||||||
get_test_or_raise,
|
get_test_or_raise,
|
||||||
list_evidence_for_test,
|
list_evidence_for_test,
|
||||||
MAX_UPLOAD_SIZE,
|
|
||||||
validate_delete_permission,
|
validate_delete_permission,
|
||||||
validate_file,
|
validate_file,
|
||||||
validate_upload_permission,
|
validate_upload_permission,
|
||||||
@@ -53,6 +77,7 @@ from app.storage import download_file, upload_file
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(tags=["evidence"])
|
||||||
router = APIRouter(tags=["evidence"])
|
router = APIRouter(tags=["evidence"])
|
||||||
|
|
||||||
|
|
||||||
@@ -67,13 +92,21 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
|||||||
never needs direct access to MinIO.
|
never needs direct access to MinIO.
|
||||||
"""
|
"""
|
||||||
return EvidenceOut(
|
return EvidenceOut(
|
||||||
|
# Keyword argument: id
|
||||||
id=evidence.id,
|
id=evidence.id,
|
||||||
|
# Keyword argument: test_id
|
||||||
test_id=evidence.test_id,
|
test_id=evidence.test_id,
|
||||||
|
# Keyword argument: file_name
|
||||||
file_name=evidence.file_name,
|
file_name=evidence.file_name,
|
||||||
|
# Keyword argument: sha256_hash
|
||||||
sha256_hash=evidence.sha256_hash,
|
sha256_hash=evidence.sha256_hash,
|
||||||
|
# Keyword argument: uploaded_by
|
||||||
uploaded_by=evidence.uploaded_by,
|
uploaded_by=evidence.uploaded_by,
|
||||||
|
# Keyword argument: uploaded_at
|
||||||
uploaded_at=evidence.uploaded_at,
|
uploaded_at=evidence.uploaded_at,
|
||||||
|
# Keyword argument: team
|
||||||
team=evidence.team,
|
team=evidence.team,
|
||||||
|
# Keyword argument: notes
|
||||||
notes=evidence.notes,
|
notes=evidence.notes,
|
||||||
download_url=f"/api/v1/evidence/{evidence.id}/file",
|
download_url=f"/api/v1/evidence/{evidence.id}/file",
|
||||||
)
|
)
|
||||||
@@ -85,30 +118,47 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
|||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
# Literal argument value
|
||||||
"/tests/{test_id}/evidence",
|
"/tests/{test_id}/evidence",
|
||||||
|
# Keyword argument: response_model
|
||||||
response_model=EvidenceOut,
|
response_model=EvidenceOut,
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
)
|
)
|
||||||
|
# Apply the @limiter.limit decorator
|
||||||
@limiter.limit("10/minute")
|
@limiter.limit("10/minute")
|
||||||
|
# Define async function upload_evidence
|
||||||
async def upload_evidence(
|
async def upload_evidence(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: test_id
|
||||||
test_id: _uuid.UUID,
|
test_id: _uuid.UUID,
|
||||||
|
# Entry: file
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
|
# Entry: team
|
||||||
team: TeamSide = Form(TeamSide.red),
|
team: TeamSide = Form(TeamSide.red),
|
||||||
|
# Entry: notes
|
||||||
notes: Optional[str] = Form(None),
|
notes: Optional[str] = Form(None),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> EvidenceOut:
|
||||||
"""Upload a file as evidence for the given test.
|
"""Upload a file as evidence for the given test.
|
||||||
|
|
||||||
The ``team`` field (sent as form data) determines whether this is
|
The ``team`` field (sent as form data) determines whether this is
|
||||||
Red Team (attack) or Blue Team (detection) evidence.
|
Red Team (attack) or Blue Team (detection) evidence.
|
||||||
"""
|
"""
|
||||||
|
# Assign test = get_test_or_raise(db, test_id)
|
||||||
test = get_test_or_raise(db, test_id)
|
test = get_test_or_raise(db, test_id)
|
||||||
|
# Call validate_upload_permission()
|
||||||
validate_upload_permission(test, team, current_user.role)
|
validate_upload_permission(test, team, current_user.role)
|
||||||
|
|
||||||
|
# Assign file_name = file.filename or "unnamed"
|
||||||
file_name = file.filename or "unnamed"
|
file_name = file.filename or "unnamed"
|
||||||
|
# Assign content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||||
|
# Call validate_file()
|
||||||
validate_file(file_name, len(content))
|
validate_file(file_name, len(content))
|
||||||
|
|
||||||
# Hash
|
# Hash
|
||||||
@@ -116,6 +166,7 @@ async def upload_evidence(
|
|||||||
|
|
||||||
# 4. Object key (sanitise filename to prevent path traversal in storage)
|
# 4. Object key (sanitise filename to prevent path traversal in storage)
|
||||||
safe_name = os.path.basename(file_name)
|
safe_name = os.path.basename(file_name)
|
||||||
|
# Assign key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
|
||||||
key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
|
key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
|
||||||
|
|
||||||
# 5. Upload to MinIO
|
# 5. Upload to MinIO
|
||||||
@@ -123,32 +174,53 @@ async def upload_evidence(
|
|||||||
|
|
||||||
# 6. Persist metadata and audit
|
# 6. Persist metadata and audit
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign evidence = Evidence(
|
||||||
evidence = Evidence(
|
evidence = Evidence(
|
||||||
|
# Keyword argument: test_id
|
||||||
test_id=test_id,
|
test_id=test_id,
|
||||||
|
# Keyword argument: file_name
|
||||||
file_name=safe_name,
|
file_name=safe_name,
|
||||||
|
# Keyword argument: file_path
|
||||||
file_path=key,
|
file_path=key,
|
||||||
|
# Keyword argument: sha256_hash
|
||||||
sha256_hash=sha256,
|
sha256_hash=sha256,
|
||||||
|
# Keyword argument: uploaded_by
|
||||||
uploaded_by=current_user.id,
|
uploaded_by=current_user.id,
|
||||||
uploaded_at=datetime.utcnow(), # set explicitly — DB column has no server default
|
uploaded_at=datetime.utcnow(), # set explicitly — DB column has no server default
|
||||||
team=team,
|
team=team,
|
||||||
|
# Keyword argument: notes
|
||||||
notes=notes,
|
notes=notes,
|
||||||
)
|
)
|
||||||
|
# Stage new record(s) for database insertion
|
||||||
db.add(evidence)
|
db.add(evidence)
|
||||||
|
# Flush changes to DB without committing the transaction
|
||||||
db.flush() # Get evidence.id for audit
|
db.flush() # Get evidence.id for audit
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="upload_evidence",
|
action="upload_evidence",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="evidence",
|
entity_type="evidence",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=evidence.id,
|
entity_id=evidence.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={
|
details={
|
||||||
|
# Literal argument value
|
||||||
"file_name": safe_name,
|
"file_name": safe_name,
|
||||||
|
# Literal argument value
|
||||||
"sha256": sha256,
|
"sha256": sha256,
|
||||||
|
# Literal argument value
|
||||||
"test_id": str(test_id),
|
"test_id": str(test_id),
|
||||||
|
# Literal argument value
|
||||||
"team": team.value,
|
"team": team.value,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(evidence)
|
db.refresh(evidence)
|
||||||
|
|
||||||
# 7. Attach to Jira ticket if one exists (non-fatal)
|
# 7. Attach to Jira ticket if one exists (non-fatal)
|
||||||
@@ -194,15 +266,23 @@ def _attach_evidence_to_jira(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/tests/{test_id}/evidence", response_model=list[EvidenceOut])
|
@router.get("/tests/{test_id}/evidence", response_model=list[EvidenceOut])
|
||||||
|
# Define function list_evidence
|
||||||
def list_evidence(
|
def list_evidence(
|
||||||
|
# Entry: test_id
|
||||||
test_id: _uuid.UUID,
|
test_id: _uuid.UUID,
|
||||||
|
# Entry: team
|
||||||
team: Optional[str] = Query(None, description="Filter by team: red or blue"),
|
team: Optional[str] = Query(None, description="Filter by team: red or blue"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list[EvidenceOut]:
|
||||||
"""List all evidences for a test, optionally filtered by team."""
|
"""List all evidences for a test, optionally filtered by team."""
|
||||||
|
# Call get_test_or_raise()
|
||||||
get_test_or_raise(db, test_id)
|
get_test_or_raise(db, test_id)
|
||||||
|
# Assign evidences = list_evidence_for_test(db, test_id, team=team)
|
||||||
evidences = list_evidence_for_test(db, test_id, team=team)
|
evidences = list_evidence_for_test(db, test_id, team=team)
|
||||||
|
# Return [_evidence_to_out(e) for e in evidences]
|
||||||
return [_evidence_to_out(e) for e in evidences]
|
return [_evidence_to_out(e) for e in evidences]
|
||||||
|
|
||||||
|
|
||||||
@@ -212,13 +292,18 @@ def list_evidence(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/evidence/{evidence_id}", response_model=EvidenceOut)
|
@router.get("/evidence/{evidence_id}", response_model=EvidenceOut)
|
||||||
|
# Define function get_evidence
|
||||||
def get_evidence(
|
def get_evidence(
|
||||||
|
# Entry: evidence_id
|
||||||
evidence_id: _uuid.UUID,
|
evidence_id: _uuid.UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return evidence metadata. ``download_url`` is a backend proxy URL."""
|
"""Return evidence metadata. ``download_url`` is a backend proxy URL."""
|
||||||
evidence = get_evidence_or_raise(db, evidence_id)
|
evidence = get_evidence_or_raise(db, evidence_id)
|
||||||
|
# Return _evidence_to_out(evidence)
|
||||||
return _evidence_to_out(evidence)
|
return _evidence_to_out(evidence)
|
||||||
|
|
||||||
|
|
||||||
@@ -265,11 +350,15 @@ def download_evidence_file(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/evidence/{evidence_id}", status_code=status.HTTP_200_OK)
|
@router.delete("/evidence/{evidence_id}", status_code=status.HTTP_200_OK)
|
||||||
|
# Define function delete_evidence
|
||||||
def delete_evidence(
|
def delete_evidence(
|
||||||
|
# Entry: evidence_id
|
||||||
evidence_id: _uuid.UUID,
|
evidence_id: _uuid.UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Delete an evidence record.
|
"""Delete an evidence record.
|
||||||
|
|
||||||
Only allowed in editable states:
|
Only allowed in editable states:
|
||||||
@@ -277,24 +366,40 @@ def delete_evidence(
|
|||||||
- Blue evidence: ``blue_evaluating``
|
- Blue evidence: ``blue_evaluating``
|
||||||
- No deletions in ``in_review``, ``validated``, ``rejected``
|
- No deletions in ``in_review``, ``validated``, ``rejected``
|
||||||
"""
|
"""
|
||||||
|
# Assign evidence = get_evidence_or_raise(db, evidence_id)
|
||||||
evidence = get_evidence_or_raise(db, evidence_id)
|
evidence = get_evidence_or_raise(db, evidence_id)
|
||||||
|
# Assign test = get_test_or_raise(db, evidence.test_id)
|
||||||
test = get_test_or_raise(db, evidence.test_id)
|
test = get_test_or_raise(db, evidence.test_id)
|
||||||
|
# Call validate_delete_permission()
|
||||||
validate_delete_permission(test, evidence, current_user.role, current_user.id)
|
validate_delete_permission(test, evidence, current_user.role, current_user.id)
|
||||||
|
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="delete_evidence",
|
action="delete_evidence",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="evidence",
|
entity_type="evidence",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=evidence.id,
|
entity_id=evidence.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={
|
details={
|
||||||
|
# Literal argument value
|
||||||
"file_name": evidence.file_name,
|
"file_name": evidence.file_name,
|
||||||
|
# Literal argument value
|
||||||
"test_id": str(evidence.test_id),
|
"test_id": str(evidence.test_id),
|
||||||
|
# Literal argument value
|
||||||
"team": evidence.team.value if evidence.team else None,
|
"team": evidence.team.value if evidence.team else None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# Mark record for deletion on next commit
|
||||||
db.delete(evidence)
|
db.delete(evidence)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return {"detail": "Evidence deleted"}
|
||||||
return {"detail": "Evidence deleted"}
|
return {"detail": "Evidence deleted"}
|
||||||
|
|||||||
@@ -5,101 +5,169 @@ No business logic lives here — only request validation and response
|
|||||||
formatting.
|
formatting.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import io
|
||||||
import io
|
import io
|
||||||
|
|
||||||
|
# Import json
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import StreamingResponse from fastapi.responses
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import heatmap_service from app.services
|
||||||
from app.services import heatmap_service
|
from app.services import heatmap_service
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/heatmap", tags=["heatmap"])
|
||||||
router = APIRouter(prefix="/heatmap", tags=["heatmap"])
|
router = APIRouter(prefix="/heatmap", tags=["heatmap"])
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/coverage")
|
@router.get("/coverage")
|
||||||
|
# Define function heatmap_coverage
|
||||||
def heatmap_coverage(
|
def heatmap_coverage(
|
||||||
|
# Entry: platforms
|
||||||
platforms: Optional[str] = Query(None, description="Comma-separated platforms"),
|
platforms: Optional[str] = Query(None, description="Comma-separated platforms"),
|
||||||
|
# Entry: tactics
|
||||||
tactics: Optional[str] = Query(None, description="Comma-separated tactics"),
|
tactics: Optional[str] = Query(None, description="Comma-separated tactics"),
|
||||||
|
# Entry: min_score
|
||||||
min_score: int = Query(0, ge=0, le=100),
|
min_score: int = Query(0, ge=0, le=100),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Coverage layer — score based on status_global of each technique."""
|
"""Coverage layer — score based on status_global of each technique."""
|
||||||
|
# Return heatmap_service.build_coverage_layer(
|
||||||
return heatmap_service.build_coverage_layer(
|
return heatmap_service.build_coverage_layer(
|
||||||
db, platforms=platforms, tactics=tactics, min_score=min_score,
|
db, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/threat-actor/{actor_id}")
|
@router.get("/threat-actor/{actor_id}")
|
||||||
|
# Define function heatmap_threat_actor
|
||||||
def heatmap_threat_actor(
|
def heatmap_threat_actor(
|
||||||
|
# Entry: actor_id
|
||||||
actor_id: str,
|
actor_id: str,
|
||||||
|
# Entry: platforms
|
||||||
platforms: Optional[str] = Query(None),
|
platforms: Optional[str] = Query(None),
|
||||||
|
# Entry: tactics
|
||||||
tactics: Optional[str] = Query(None),
|
tactics: Optional[str] = Query(None),
|
||||||
|
# Entry: min_score
|
||||||
min_score: int = Query(0, ge=0, le=100),
|
min_score: int = Query(0, ge=0, le=100),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Threat actor layer — techniques used by an actor with coverage color."""
|
"""Threat actor layer — techniques used by an actor with coverage color."""
|
||||||
|
# Return heatmap_service.build_threat_actor_layer(
|
||||||
return heatmap_service.build_threat_actor_layer(
|
return heatmap_service.build_threat_actor_layer(
|
||||||
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/detection-rules")
|
@router.get("/detection-rules")
|
||||||
|
# Define function heatmap_detection_rules
|
||||||
def heatmap_detection_rules(
|
def heatmap_detection_rules(
|
||||||
|
# Entry: platforms
|
||||||
platforms: Optional[str] = Query(None),
|
platforms: Optional[str] = Query(None),
|
||||||
|
# Entry: tactics
|
||||||
tactics: Optional[str] = Query(None),
|
tactics: Optional[str] = Query(None),
|
||||||
|
# Entry: min_score
|
||||||
min_score: int = Query(0, ge=0, le=100),
|
min_score: int = Query(0, ge=0, le=100),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Detection rules layer — score based on ratio of rules available vs total."""
|
"""Detection rules layer — score based on ratio of rules available vs total."""
|
||||||
|
# Return heatmap_service.build_detection_rules_layer(
|
||||||
return heatmap_service.build_detection_rules_layer(
|
return heatmap_service.build_detection_rules_layer(
|
||||||
db, platforms=platforms, tactics=tactics, min_score=min_score,
|
db, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/campaign/{campaign_id}")
|
@router.get("/campaign/{campaign_id}")
|
||||||
|
# Define function heatmap_campaign
|
||||||
def heatmap_campaign(
|
def heatmap_campaign(
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: str,
|
campaign_id: str,
|
||||||
|
# Entry: platforms
|
||||||
platforms: Optional[str] = Query(None),
|
platforms: Optional[str] = Query(None),
|
||||||
|
# Entry: tactics
|
||||||
tactics: Optional[str] = Query(None),
|
tactics: Optional[str] = Query(None),
|
||||||
|
# Entry: min_score
|
||||||
min_score: int = Query(0, ge=0, le=100),
|
min_score: int = Query(0, ge=0, le=100),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Campaign layer — only techniques in the campaign, colored by test state."""
|
"""Campaign layer — only techniques in the campaign, colored by test state."""
|
||||||
|
# Return heatmap_service.build_campaign_layer(
|
||||||
return heatmap_service.build_campaign_layer(
|
return heatmap_service.build_campaign_layer(
|
||||||
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/export-navigator")
|
@router.get("/export-navigator")
|
||||||
|
# Define function export_navigator
|
||||||
def export_navigator(
|
def export_navigator(
|
||||||
|
# Entry: layer
|
||||||
layer: str = Query(..., description="Layer type: coverage, threat-actor, detection-rules, campaign"),
|
layer: str = Query(..., description="Layer type: coverage, threat-actor, detection-rules, campaign"),
|
||||||
|
# Entry: layer_id
|
||||||
layer_id: Optional[str] = Query(None, description="Actor ID or Campaign ID (if applicable)"),
|
layer_id: Optional[str] = Query(None, description="Actor ID or Campaign ID (if applicable)"),
|
||||||
|
# Entry: platforms
|
||||||
platforms: Optional[str] = Query(None),
|
platforms: Optional[str] = Query(None),
|
||||||
|
# Entry: tactics
|
||||||
tactics: Optional[str] = Query(None),
|
tactics: Optional[str] = Query(None),
|
||||||
|
# Entry: min_score
|
||||||
min_score: int = Query(0, ge=0, le=100),
|
min_score: int = Query(0, ge=0, le=100),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> StreamingResponse:
|
||||||
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
|
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
|
||||||
|
# Assign data = heatmap_service.build_navigator_export(
|
||||||
data = heatmap_service.build_navigator_export(
|
data = heatmap_service.build_navigator_export(
|
||||||
db, layer, layer_id=layer_id,
|
db, layer, layer_id=layer_id,
|
||||||
|
# Keyword argument: platforms
|
||||||
platforms=platforms, tactics=tactics, min_score=min_score,
|
platforms=platforms, tactics=tactics, min_score=min_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign json_content = json.dumps(data, indent=2, default=str)
|
||||||
json_content = json.dumps(data, indent=2, default=str)
|
json_content = json.dumps(data, indent=2, default=str)
|
||||||
|
# Assign buffer = io.BytesIO(json_content.encode("utf-8"))
|
||||||
buffer = io.BytesIO(json_content.encode("utf-8"))
|
buffer = io.BytesIO(json_content.encode("utf-8"))
|
||||||
|
|
||||||
|
# Return StreamingResponse(
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
buffer,
|
buffer,
|
||||||
|
# Keyword argument: media_type
|
||||||
media_type="application/json",
|
media_type="application/json",
|
||||||
|
# Keyword argument: headers
|
||||||
headers={"Content-Disposition": f"attachment; filename=aegis_{layer}_layer.json"},
|
headers={"Content-Disposition": f"attachment; filename=aegis_{layer}_layer.json"},
|
||||||
)
|
)
|
||||||
|
|||||||
+103
-6
@@ -1,138 +1,235 @@
|
|||||||
"""Jira integration router — link, search, sync, create issues."""
|
"""Jira integration router — link, search, sync, create issues."""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import UUID from uuid
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import JiraLinkEntityType from app.models.jira_link
|
||||||
from app.models.jira_link import JiraLinkEntityType
|
from app.models.jira_link import JiraLinkEntityType
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.schemas.jira_schema
|
||||||
from app.schemas.jira_schema import (
|
from app.schemas.jira_schema import (
|
||||||
JiraIssueResult,
|
JiraIssueResult,
|
||||||
JiraLinkCreate,
|
JiraLinkCreate,
|
||||||
JiraLinkOut,
|
JiraLinkOut,
|
||||||
)
|
)
|
||||||
from app.services import jira_service, audit_service
|
|
||||||
|
|
||||||
|
# Import audit_service, jira_service from app.services
|
||||||
|
from app.services import audit_service, jira_service
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/jira", tags=["jira"])
|
||||||
router = APIRouter(prefix="/jira", tags=["jira"])
|
router = APIRouter(prefix="/jira", tags=["jira"])
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/search", response_model=list[JiraIssueResult])
|
@router.get("/search", response_model=list[JiraIssueResult])
|
||||||
|
# Define function search_issues
|
||||||
def search_issues(
|
def search_issues(
|
||||||
|
# Entry: q
|
||||||
q: str = Query(..., min_length=2),
|
q: str = Query(..., min_length=2),
|
||||||
|
# Entry: max_results
|
||||||
max_results: int = Query(10, le=50),
|
max_results: int = Query(10, le=50),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> list[JiraIssueResult]:
|
||||||
"""Search Jira issues by JQL or free text."""
|
"""Search Jira issues by JQL or free text."""
|
||||||
|
# Return jira_service.search_jira_issues(q, max_results)
|
||||||
return jira_service.search_jira_issues(q, max_results)
|
return jira_service.search_jira_issues(q, max_results)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/links", response_model=JiraLinkOut, status_code=201)
|
@router.post("/links", response_model=JiraLinkOut, status_code=201)
|
||||||
|
# Define function create_link
|
||||||
def create_link(
|
def create_link(
|
||||||
|
# Entry: body
|
||||||
body: JiraLinkCreate,
|
body: JiraLinkCreate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> JiraLinkOut:
|
||||||
"""Associate an Aegis entity with a Jira issue."""
|
"""Associate an Aegis entity with a Jira issue."""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign link = jira_service.create_link(
|
||||||
link = jira_service.create_link(
|
link = jira_service.create_link(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type=body.entity_type,
|
entity_type=body.entity_type,
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=body.entity_id,
|
entity_id=body.entity_id,
|
||||||
|
# Keyword argument: jira_issue_key
|
||||||
jira_issue_key=body.jira_issue_key,
|
jira_issue_key=body.jira_issue_key,
|
||||||
|
# Keyword argument: sync_direction
|
||||||
sync_direction=body.sync_direction,
|
sync_direction=body.sync_direction,
|
||||||
|
# Keyword argument: created_by
|
||||||
created_by=user.id,
|
created_by=user.id,
|
||||||
)
|
)
|
||||||
|
# Call audit_service.log_action()
|
||||||
audit_service.log_action(
|
audit_service.log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="JIRA_LINK_CREATED",
|
action="JIRA_LINK_CREATED",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="jira_link",
|
entity_type="jira_link",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=str(link.id),
|
entity_id=str(link.id),
|
||||||
|
# Keyword argument: details
|
||||||
details={
|
details={
|
||||||
|
# Literal argument value
|
||||||
"linked_entity_type": body.entity_type.value,
|
"linked_entity_type": body.entity_type.value,
|
||||||
|
# Literal argument value
|
||||||
"linked_entity_id": str(body.entity_id),
|
"linked_entity_id": str(body.entity_id),
|
||||||
|
# Literal argument value
|
||||||
"jira_issue_key": body.jira_issue_key,
|
"jira_issue_key": body.jira_issue_key,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(link)
|
db.refresh(link)
|
||||||
|
|
||||||
|
# Return link
|
||||||
return link
|
return link
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/links", response_model=list[JiraLinkOut])
|
@router.get("/links", response_model=list[JiraLinkOut])
|
||||||
|
# Define function list_links
|
||||||
def list_links(
|
def list_links(
|
||||||
|
# Entry: entity_type
|
||||||
entity_type: Optional[JiraLinkEntityType] = None,
|
entity_type: Optional[JiraLinkEntityType] = None,
|
||||||
|
# Entry: entity_id
|
||||||
entity_id: Optional[UUID] = None,
|
entity_id: Optional[UUID] = None,
|
||||||
entity_ids: Optional[list[UUID]] = Query(default=None, description="Filter by multiple entity IDs"),
|
entity_ids: Optional[list[UUID]] = Query(default=None, description="Filter by multiple entity IDs"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List Jira links, optionally filtered by entity or a list of entity IDs."""
|
"""List Jira links, optionally filtered by entity or a list of entity IDs."""
|
||||||
return jira_service.list_links(
|
return jira_service.list_links(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type=entity_type,
|
entity_type=entity_type,
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
entity_ids=entity_ids,
|
entity_ids=entity_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/links/{link_id}/sync")
|
@router.post("/links/{link_id}/sync")
|
||||||
|
# Define function sync_link
|
||||||
def sync_link(
|
def sync_link(
|
||||||
|
# Entry: link_id
|
||||||
link_id: UUID,
|
link_id: UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(require_role("admin")),
|
user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Force bidirectional sync for a specific Jira link."""
|
"""Force bidirectional sync for a specific Jira link."""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign link = jira_service.get_link_or_raise(db, link_id)
|
||||||
link = jira_service.get_link_or_raise(db, link_id)
|
link = jira_service.get_link_or_raise(db, link_id)
|
||||||
|
# Call jira_service.sync_jira_to_aegis()
|
||||||
jira_service.sync_jira_to_aegis(db, link)
|
jira_service.sync_jira_to_aegis(db, link)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Return {"message": "Sync completed", "jira_status": link.jira_status}
|
||||||
return {"message": "Sync completed", "jira_status": link.jira_status}
|
return {"message": "Sync completed", "jira_status": link.jira_status}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.delete decorator
|
||||||
@router.delete("/links/{link_id}", status_code=204)
|
@router.delete("/links/{link_id}", status_code=204)
|
||||||
|
# Define function delete_link
|
||||||
def delete_link(
|
def delete_link(
|
||||||
|
# Entry: link_id
|
||||||
link_id: UUID,
|
link_id: UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> None:
|
||||||
"""Remove a Jira link."""
|
"""Remove a Jira link."""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign link = jira_service.delete_link(db, link_id)
|
||||||
link = jira_service.delete_link(db, link_id)
|
link = jira_service.delete_link(db, link_id)
|
||||||
|
# Call audit_service.log_action()
|
||||||
audit_service.log_action(
|
audit_service.log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="jira_link_deleted",
|
action="jira_link_deleted",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="jira_link",
|
entity_type="jira_link",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=str(link_id),
|
entity_id=str(link_id),
|
||||||
|
# Keyword argument: details
|
||||||
details={"jira_issue_key": link.jira_issue_key},
|
details={"jira_issue_key": link.jira_issue_key},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/create-issue")
|
@router.post("/create-issue")
|
||||||
|
# Define function create_issue_from_entity
|
||||||
def create_issue_from_entity(
|
def create_issue_from_entity(
|
||||||
|
# Entry: entity_type
|
||||||
entity_type: JiraLinkEntityType,
|
entity_type: JiraLinkEntityType,
|
||||||
|
# Entry: entity_id
|
||||||
entity_id: UUID,
|
entity_id: UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Auto-create a Jira issue from an Aegis entity and link them."""
|
"""Auto-create a Jira issue from an Aegis entity and link them."""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign result = jira_service.create_issue_and_link(
|
||||||
result = jira_service.create_issue_and_link(
|
result = jira_service.create_issue_and_link(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type=entity_type,
|
entity_type=entity_type,
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
|
# Keyword argument: created_by
|
||||||
created_by=user.id,
|
created_by=user.id,
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Return result
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -7,12 +7,22 @@ validation-rate endpoints for the Red/Blue workflow.
|
|||||||
Thin HTTP adapter: delegates all data logic to metrics_query_service.
|
Thin HTTP adapter: delegates all data logic to metrics_query_service.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import APIRouter, Depends from fastapi
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.schemas.metrics
|
||||||
from app.schemas.metrics import (
|
from app.schemas.metrics import (
|
||||||
CoverageSummary,
|
CoverageSummary,
|
||||||
RecentTestItem,
|
RecentTestItem,
|
||||||
@@ -21,6 +31,8 @@ from app.schemas.metrics import (
|
|||||||
TestPipelineCounts,
|
TestPipelineCounts,
|
||||||
ValidationRate,
|
ValidationRate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import from app.services.metrics_query_service
|
||||||
from app.services.metrics_query_service import (
|
from app.services.metrics_query_service import (
|
||||||
get_coverage_by_tactic,
|
get_coverage_by_tactic,
|
||||||
get_coverage_summary,
|
get_coverage_summary,
|
||||||
@@ -30,6 +42,7 @@ from app.services.metrics_query_service import (
|
|||||||
get_validation_rate,
|
get_validation_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/metrics", tags=["metrics"])
|
||||||
router = APIRouter(prefix="/metrics", tags=["metrics"])
|
router = APIRouter(prefix="/metrics", tags=["metrics"])
|
||||||
|
|
||||||
|
|
||||||
@@ -39,11 +52,15 @@ router = APIRouter(prefix="/metrics", tags=["metrics"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/summary", response_model=CoverageSummary)
|
@router.get("/summary", response_model=CoverageSummary)
|
||||||
|
# Define function coverage_summary
|
||||||
def coverage_summary(
|
def coverage_summary(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> CoverageSummary:
|
||||||
"""Return a global coverage summary across all techniques."""
|
"""Return a global coverage summary across all techniques."""
|
||||||
|
# Return get_coverage_summary(db)
|
||||||
return get_coverage_summary(db)
|
return get_coverage_summary(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -53,11 +70,15 @@ def coverage_summary(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/by-tactic", response_model=list[TacticCoverage])
|
@router.get("/by-tactic", response_model=list[TacticCoverage])
|
||||||
|
# Define function coverage_by_tactic
|
||||||
def coverage_by_tactic(
|
def coverage_by_tactic(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list[TacticCoverage]:
|
||||||
"""Return coverage breakdown grouped by tactic."""
|
"""Return coverage breakdown grouped by tactic."""
|
||||||
|
# Return get_coverage_by_tactic(db)
|
||||||
return get_coverage_by_tactic(db)
|
return get_coverage_by_tactic(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -67,11 +88,15 @@ def coverage_by_tactic(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/test-pipeline", response_model=TestPipelineCounts)
|
@router.get("/test-pipeline", response_model=TestPipelineCounts)
|
||||||
|
# Define function test_pipeline
|
||||||
def test_pipeline(
|
def test_pipeline(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> TestPipelineCounts:
|
||||||
"""Return how many tests are in each pipeline state."""
|
"""Return how many tests are in each pipeline state."""
|
||||||
|
# Return get_test_pipeline_counts(db)
|
||||||
return get_test_pipeline_counts(db)
|
return get_test_pipeline_counts(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -81,11 +106,15 @@ def test_pipeline(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/team-activity", response_model=list[TeamActivity])
|
@router.get("/team-activity", response_model=list[TeamActivity])
|
||||||
|
# Define function team_activity
|
||||||
def team_activity(
|
def team_activity(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list[TeamActivity]:
|
||||||
"""Return activity summary for Red and Blue teams."""
|
"""Return activity summary for Red and Blue teams."""
|
||||||
|
# Return get_team_activity(db)
|
||||||
return get_team_activity(db)
|
return get_team_activity(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -95,11 +124,15 @@ def team_activity(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/validation-rate", response_model=list[ValidationRate])
|
@router.get("/validation-rate", response_model=list[ValidationRate])
|
||||||
|
# Define function validation_rate
|
||||||
def validation_rate(
|
def validation_rate(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list[ValidationRate]:
|
||||||
"""Return approval and rejection rates for Red Lead and Blue Lead."""
|
"""Return approval and rejection rates for Red Lead and Blue Lead."""
|
||||||
|
# Return get_validation_rate(db)
|
||||||
return get_validation_rate(db)
|
return get_validation_rate(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -109,9 +142,13 @@ def validation_rate(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/recent-tests", response_model=list[RecentTestItem])
|
@router.get("/recent-tests", response_model=list[RecentTestItem])
|
||||||
|
# Define function recent_tests
|
||||||
def recent_tests(
|
def recent_tests(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list[RecentTestItem]:
|
||||||
"""Return the 10 most recently created tests."""
|
"""Return the 10 most recently created tests."""
|
||||||
|
# Return get_recent_tests(db, limit=10)
|
||||||
return get_recent_tests(db, limit=10)
|
return get_recent_tests(db, limit=10)
|
||||||
|
|||||||
@@ -8,23 +8,39 @@ PATCH /notifications/{id}/read — mark one notification as read
|
|||||||
POST /notifications/read-all — mark all as read
|
POST /notifications/read-all — mark all as read
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import NotificationOut, UnreadCountOut from app.schemas.notification
|
||||||
from app.schemas.notification import NotificationOut, UnreadCountOut
|
from app.schemas.notification import NotificationOut, UnreadCountOut
|
||||||
|
|
||||||
|
# Import from app.services.notification_service
|
||||||
from app.services.notification_service import (
|
from app.services.notification_service import (
|
||||||
list_notifications,
|
|
||||||
mark_as_read,
|
|
||||||
mark_all_as_read,
|
|
||||||
get_unread_count,
|
get_unread_count,
|
||||||
|
list_notifications,
|
||||||
|
mark_all_as_read,
|
||||||
|
mark_as_read,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||||
|
|
||||||
|
|
||||||
@@ -34,13 +50,19 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[NotificationOut])
|
@router.get("", response_model=list[NotificationOut])
|
||||||
|
# Define function list_notifications_endpoint
|
||||||
def list_notifications_endpoint(
|
def list_notifications_endpoint(
|
||||||
|
# Entry: offset
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
|
# Entry: limit
|
||||||
limit: int = Query(20, ge=1, le=100),
|
limit: int = Query(20, ge=1, le=100),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list[NotificationOut]:
|
||||||
"""Return paginated notifications for the current user, newest first."""
|
"""Return paginated notifications for the current user, newest first."""
|
||||||
|
# Return list_notifications(db, current_user.id, offset=offset, limit=limit)
|
||||||
return list_notifications(db, current_user.id, offset=offset, limit=limit)
|
return list_notifications(db, current_user.id, offset=offset, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
@@ -50,12 +72,17 @@ def list_notifications_endpoint(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/unread-count", response_model=UnreadCountOut)
|
@router.get("/unread-count", response_model=UnreadCountOut)
|
||||||
|
# Define function unread_count
|
||||||
def unread_count(
|
def unread_count(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> UnreadCountOut:
|
||||||
"""Return the number of unread notifications for the current user."""
|
"""Return the number of unread notifications for the current user."""
|
||||||
|
# Assign count = get_unread_count(db, current_user.id)
|
||||||
count = get_unread_count(db, current_user.id)
|
count = get_unread_count(db, current_user.id)
|
||||||
|
# Return UnreadCountOut(unread_count=count)
|
||||||
return UnreadCountOut(unread_count=count)
|
return UnreadCountOut(unread_count=count)
|
||||||
|
|
||||||
|
|
||||||
@@ -65,15 +92,23 @@ def unread_count(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{notification_id}/read", response_model=NotificationOut)
|
@router.patch("/{notification_id}/read", response_model=NotificationOut)
|
||||||
|
# Define function read_notification
|
||||||
def read_notification(
|
def read_notification(
|
||||||
|
# Entry: notification_id
|
||||||
notification_id: uuid.UUID,
|
notification_id: uuid.UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> NotificationOut:
|
||||||
"""Mark a single notification as read."""
|
"""Mark a single notification as read."""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign notif = mark_as_read(db, notification_id, current_user.id)
|
||||||
notif = mark_as_read(db, notification_id, current_user.id)
|
notif = mark_as_read(db, notification_id, current_user.id)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Return notif
|
||||||
return notif
|
return notif
|
||||||
|
|
||||||
|
|
||||||
@@ -83,12 +118,19 @@ def read_notification(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/read-all")
|
@router.post("/read-all")
|
||||||
|
# Define function read_all_notifications
|
||||||
def read_all_notifications(
|
def read_all_notifications(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Mark all notifications for the current user as read."""
|
"""Mark all notifications for the current user as read."""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign count = mark_all_as_read(db, current_user.id)
|
||||||
count = mark_all_as_read(db, current_user.id)
|
count = mark_all_as_read(db, current_user.id)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Return {"detail": f"Marked {count} notifications as read"}
|
||||||
return {"detail": f"Marked {count} notifications as read"}
|
return {"detail": f"Marked {count} notifications as read"}
|
||||||
|
|||||||
@@ -4,18 +4,28 @@ Provides operational KPIs for security teams with trend analysis
|
|||||||
and team-level breakdowns.
|
and team-level breakdowns.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.services.operational_metrics_service
|
||||||
from app.services.operational_metrics_service import (
|
from app.services.operational_metrics_service import (
|
||||||
get_all_operational_metrics,
|
|
||||||
get_operational_trend,
|
|
||||||
get_metrics_by_team,
|
get_metrics_by_team,
|
||||||
|
get_operational_trend,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
|
||||||
router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
|
router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
|
||||||
|
|
||||||
|
|
||||||
@@ -23,13 +33,18 @@ router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
|
# Define function operational_metrics
|
||||||
def operational_metrics(
|
def operational_metrics(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min."""
|
"""Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min."""
|
||||||
|
# Import get_operational_metrics_cached from app.services.score_cache
|
||||||
from app.services.score_cache import get_operational_metrics_cached
|
from app.services.score_cache import get_operational_metrics_cached
|
||||||
|
|
||||||
|
# Return get_operational_metrics_cached(db)
|
||||||
return get_operational_metrics_cached(db)
|
return get_operational_metrics_cached(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -37,12 +52,17 @@ def operational_metrics(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/trend")
|
@router.get("/trend")
|
||||||
|
# Define function operational_trend
|
||||||
def operational_trend(
|
def operational_trend(
|
||||||
|
# Entry: period
|
||||||
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
|
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get weekly trend data for operational metrics."""
|
"""Get weekly trend data for operational metrics."""
|
||||||
|
# Return get_operational_trend(db, period)
|
||||||
return get_operational_trend(db, period)
|
return get_operational_trend(db, period)
|
||||||
|
|
||||||
|
|
||||||
@@ -50,9 +70,13 @@ def operational_trend(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/by-team")
|
@router.get("/by-team")
|
||||||
|
# Define function metrics_by_team
|
||||||
def metrics_by_team(
|
def metrics_by_team(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get metrics broken down by Red Team vs Blue Team."""
|
"""Get metrics broken down by Red Team vs Blue Team."""
|
||||||
|
# Return get_metrics_by_team(db)
|
||||||
return get_metrics_by_team(db)
|
return get_metrics_by_team(db)
|
||||||
|
|||||||
+162
-15
@@ -1,26 +1,44 @@
|
|||||||
"""OSINT enrichment endpoints — view, review, and trigger enrichment of
|
"""OSINT enrichment endpoints — view, review, and trigger enrichment of OSINT items linked to techniques."""
|
||||||
OSINT items (CVEs, advisories, etc.) linked to techniques.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
# Import UUID from uuid
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
# Import APIRouter, Depends, HTTPException, Query, status from fastapi
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
|
||||||
|
# Import BaseModel from pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.services.osint_enrichment_service
|
||||||
from app.services.osint_enrichment_service import (
|
from app.services.osint_enrichment_service import (
|
||||||
enrich_technique_with_cves,
|
enrich_technique_with_cves,
|
||||||
get_osint_items_for_technique,
|
get_osint_items_for_technique,
|
||||||
get_osint_summary,
|
get_osint_summary,
|
||||||
get_technique_or_raise,
|
get_technique_or_raise,
|
||||||
list_osint_items as service_list_osint_items,
|
|
||||||
mark_osint_reviewed,
|
mark_osint_reviewed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import from app.services.osint_enrichment_service
|
||||||
|
from app.services.osint_enrichment_service import (
|
||||||
|
list_osint_items as service_list_osint_items,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/osint", tags=["osint"])
|
||||||
router = APIRouter(prefix="/osint", tags=["osint"])
|
router = APIRouter(prefix="/osint", tags=["osint"])
|
||||||
|
|
||||||
|
|
||||||
@@ -28,18 +46,34 @@ router = APIRouter(prefix="/osint", tags=["osint"])
|
|||||||
|
|
||||||
|
|
||||||
class OsintItemOut(BaseModel):
|
class OsintItemOut(BaseModel):
|
||||||
|
"""Serialized OSINT item returned by the API."""
|
||||||
|
|
||||||
|
# id: str
|
||||||
id: str
|
id: str
|
||||||
|
# technique_id: str
|
||||||
technique_id: str
|
technique_id: str
|
||||||
|
# source_type: str
|
||||||
source_type: str
|
source_type: str
|
||||||
|
# source_url: str
|
||||||
source_url: str
|
source_url: str
|
||||||
|
# title: str
|
||||||
title: str
|
title: str
|
||||||
|
# description: str | None
|
||||||
description: str | None
|
description: str | None
|
||||||
|
# severity: str | None
|
||||||
severity: str | None
|
severity: str | None
|
||||||
|
# discovered_at: str | None
|
||||||
discovered_at: str | None
|
discovered_at: str | None
|
||||||
|
# reviewed: bool
|
||||||
reviewed: bool
|
reviewed: bool
|
||||||
|
# Assign metadata_ = None
|
||||||
metadata_: dict | None = None
|
metadata_: dict | None = None
|
||||||
|
|
||||||
|
# Define class Config
|
||||||
class Config:
|
class Config:
|
||||||
|
"""ORM mode configuration for SQLAlchemy model mapping."""
|
||||||
|
|
||||||
|
# Assign from_attributes = True
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
@@ -47,94 +81,207 @@ class OsintItemOut(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/items")
|
@router.get("/items")
|
||||||
|
# Define function list_osint_items
|
||||||
def list_osint_items(
|
def list_osint_items(
|
||||||
|
# Entry: technique_id
|
||||||
technique_id: UUID | None = Query(None),
|
technique_id: UUID | None = Query(None),
|
||||||
|
# Entry: source_type
|
||||||
source_type: str | None = Query(None),
|
source_type: str | None = Query(None),
|
||||||
|
# Entry: reviewed
|
||||||
reviewed: bool | None = Query(None),
|
reviewed: bool | None = Query(None),
|
||||||
|
# Entry: offset
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
|
# Entry: limit
|
||||||
limit: int = Query(50, ge=1, le=200),
|
limit: int = Query(50, ge=1, le=200),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""List OSINT items with optional filters."""
|
"""List OSINT items with optional filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique_id (UUID | None): Filter by the technique's UUID.
|
||||||
|
source_type (str | None): Filter by source type (e.g. ``nvd_cve``, ``advisory``).
|
||||||
|
reviewed (bool | None): Filter by review status; ``None`` returns all.
|
||||||
|
offset (int): Number of records to skip for pagination.
|
||||||
|
limit (int): Maximum number of records to return.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: Serialised list of OSINT item dicts matching the filters.
|
||||||
|
"""
|
||||||
|
# Return service_list_osint_items(
|
||||||
return service_list_osint_items(
|
return service_list_osint_items(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: technique_id
|
||||||
technique_id=technique_id,
|
technique_id=technique_id,
|
||||||
|
# Keyword argument: source_type
|
||||||
source_type=source_type,
|
source_type=source_type,
|
||||||
|
# Keyword argument: reviewed
|
||||||
reviewed=reviewed,
|
reviewed=reviewed,
|
||||||
|
# Keyword argument: offset
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
# Keyword argument: limit
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/summary")
|
@router.get("/summary")
|
||||||
|
# Define function osint_summary
|
||||||
def osint_summary(
|
def osint_summary(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Summary statistics for OSINT items."""
|
"""Return summary statistics for OSINT items.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Counts of total, reviewed, and unreviewed items broken down by source type.
|
||||||
|
"""
|
||||||
|
# Return get_osint_summary(db)
|
||||||
return get_osint_summary(db)
|
return get_osint_summary(db)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/items/{item_id}/review")
|
@router.post("/items/{item_id}/review")
|
||||||
|
# Define function review_osint_item
|
||||||
def review_osint_item(
|
def review_osint_item(
|
||||||
|
# Entry: item_id
|
||||||
item_id: UUID,
|
item_id: UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Mark an OSINT item as reviewed."""
|
"""Mark an OSINT item as reviewed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item_id (UUID): Primary key of the OSINT item to mark reviewed.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
user (User): Authenticated user performing the review.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Contains ``id`` (str) and ``reviewed`` (bool ``True``).
|
||||||
|
"""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign item = mark_osint_reviewed(db, str(item_id))
|
||||||
item = mark_osint_reviewed(db, str(item_id))
|
item = mark_osint_reviewed(db, str(item_id))
|
||||||
|
# Check: not item
|
||||||
if not item:
|
if not item:
|
||||||
|
# Raise HTTPException
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
# Keyword argument: detail
|
||||||
detail="OSINT item not found",
|
detail="OSINT item not found",
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Return {"id": str(item.id), "reviewed": True}
|
||||||
return {"id": str(item.id), "reviewed": True}
|
return {"id": str(item.id), "reviewed": True}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/enrich/{technique_id}")
|
@router.post("/enrich/{technique_id}")
|
||||||
|
# Define function trigger_technique_enrichment
|
||||||
def trigger_technique_enrichment(
|
def trigger_technique_enrichment(
|
||||||
|
# Entry: technique_id
|
||||||
technique_id: UUID,
|
technique_id: UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Manually trigger OSINT enrichment for a single technique."""
|
"""Manually trigger OSINT enrichment for a single technique.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique_id (UUID): Primary key of the technique to enrich.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
user (User): Authenticated red_lead or blue_lead requesting enrichment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Contains ``technique_id`` (str), ``mitre_id`` (str), and ``new_items`` (int).
|
||||||
|
"""
|
||||||
|
# Assign technique = get_technique_or_raise(db, technique_id)
|
||||||
technique = get_technique_or_raise(db, technique_id)
|
technique = get_technique_or_raise(db, technique_id)
|
||||||
|
# Assign count = enrich_technique_with_cves(db, technique)
|
||||||
count = enrich_technique_with_cves(db, technique)
|
count = enrich_technique_with_cves(db, technique)
|
||||||
|
# Return {
|
||||||
return {
|
return {
|
||||||
|
# Literal argument value
|
||||||
"technique_id": str(technique.id),
|
"technique_id": str(technique.id),
|
||||||
|
# Literal argument value
|
||||||
"mitre_id": technique.mitre_id,
|
"mitre_id": technique.mitre_id,
|
||||||
|
# Literal argument value
|
||||||
"new_items": count,
|
"new_items": count,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/technique/{technique_id}")
|
@router.get("/technique/{technique_id}")
|
||||||
|
# Define function get_technique_osint
|
||||||
def get_technique_osint(
|
def get_technique_osint(
|
||||||
|
# Entry: technique_id
|
||||||
technique_id: UUID,
|
technique_id: UUID,
|
||||||
|
# Entry: source_type
|
||||||
source_type: str | None = Query(None),
|
source_type: str | None = Query(None),
|
||||||
|
# Entry: reviewed
|
||||||
reviewed: bool | None = Query(None),
|
reviewed: bool | None = Query(None),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Get all OSINT items for a specific technique."""
|
"""Get all OSINT items for a specific technique.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
technique_id (UUID): Primary key of the technique.
|
||||||
|
source_type (str | None): Filter by source type (e.g. ``nvd_cve``).
|
||||||
|
reviewed (bool | None): Filter by review status; ``None`` returns all.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: Dicts with OSINT item fields including source URL, severity, and review status.
|
||||||
|
"""
|
||||||
|
# Assign items = get_osint_items_for_technique(
|
||||||
items = get_osint_items_for_technique(
|
items = get_osint_items_for_technique(
|
||||||
db,
|
db,
|
||||||
str(technique_id),
|
str(technique_id),
|
||||||
|
# Keyword argument: source_type
|
||||||
source_type=source_type,
|
source_type=source_type,
|
||||||
|
# Keyword argument: reviewed
|
||||||
reviewed=reviewed,
|
reviewed=reviewed,
|
||||||
)
|
)
|
||||||
|
# Return [
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
|
# Literal argument value
|
||||||
"id": str(item.id),
|
"id": str(item.id),
|
||||||
|
# Literal argument value
|
||||||
"source_type": item.source_type,
|
"source_type": item.source_type,
|
||||||
|
# Literal argument value
|
||||||
"source_url": item.source_url,
|
"source_url": item.source_url,
|
||||||
|
# Literal argument value
|
||||||
"title": item.title,
|
"title": item.title,
|
||||||
|
# Literal argument value
|
||||||
"description": item.description,
|
"description": item.description,
|
||||||
|
# Literal argument value
|
||||||
"severity": item.severity,
|
"severity": item.severity,
|
||||||
|
# Literal argument value
|
||||||
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
||||||
|
# Literal argument value
|
||||||
"reviewed": item.reviewed,
|
"reviewed": item.reviewed,
|
||||||
|
# Literal argument value
|
||||||
"metadata": item.metadata_,
|
"metadata": item.metadata_,
|
||||||
}
|
}
|
||||||
for item in items
|
for item in items
|
||||||
|
|||||||
@@ -1,118 +1,195 @@
|
|||||||
"""Professional report generation endpoints — PDF, DOCX, HTML output."""
|
"""Professional report generation endpoints — PDF, DOCX, HTML output."""
|
||||||
|
|
||||||
|
# Import UUID from uuid
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query, Request from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
|
|
||||||
|
# Import FileResponse from fastapi.responses
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.user import User
|
|
||||||
|
# Import limiter from app.limiter
|
||||||
from app.limiter import limiter
|
from app.limiter import limiter
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import report_generation_service from app.services
|
||||||
from app.services import report_generation_service
|
from app.services import report_generation_service
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
||||||
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
||||||
|
|
||||||
|
# Assign _MEDIA_TYPES = {
|
||||||
_MEDIA_TYPES = {
|
_MEDIA_TYPES = {
|
||||||
|
# Literal argument value
|
||||||
"pdf": "application/pdf",
|
"pdf": "application/pdf",
|
||||||
|
# Literal argument value
|
||||||
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
|
# Literal argument value
|
||||||
"html": "text/html",
|
"html": "text/html",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/purple-campaign/{campaign_id}")
|
@router.get("/purple-campaign/{campaign_id}")
|
||||||
|
# Apply the @limiter.limit decorator
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
|
# Define function generate_purple_report
|
||||||
def generate_purple_report(
|
def generate_purple_report(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: campaign_id
|
||||||
campaign_id: UUID,
|
campaign_id: UUID,
|
||||||
|
# Entry: format
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||||
):
|
) -> FileResponse:
|
||||||
"""Generate a Purple Team campaign assessment report."""
|
"""Generate a Purple Team campaign assessment report."""
|
||||||
|
# Assign filepath = report_generation_service.generate_purple_campaign_report(
|
||||||
filepath = report_generation_service.generate_purple_campaign_report(
|
filepath = report_generation_service.generate_purple_campaign_report(
|
||||||
db, str(campaign_id), output_format=format,
|
db, str(campaign_id), output_format=format,
|
||||||
)
|
)
|
||||||
|
# Return FileResponse(
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
filepath,
|
filepath,
|
||||||
|
# Keyword argument: media_type
|
||||||
media_type=_MEDIA_TYPES[format],
|
media_type=_MEDIA_TYPES[format],
|
||||||
|
# Keyword argument: filename
|
||||||
filename=f"purple_report.{format}",
|
filename=f"purple_report.{format}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/coverage-summary")
|
@router.get("/coverage-summary")
|
||||||
|
# Apply the @limiter.limit decorator
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
|
# Define function generate_coverage_report
|
||||||
def generate_coverage_report(
|
def generate_coverage_report(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: format
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||||
):
|
) -> FileResponse:
|
||||||
"""Generate an organization-wide MITRE ATT&CK coverage report."""
|
"""Generate an organization-wide MITRE ATT&CK coverage report."""
|
||||||
|
# Assign filepath = report_generation_service.generate_coverage_report(
|
||||||
filepath = report_generation_service.generate_coverage_report(
|
filepath = report_generation_service.generate_coverage_report(
|
||||||
db, output_format=format,
|
db, output_format=format,
|
||||||
)
|
)
|
||||||
|
# Return FileResponse(
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
filepath,
|
filepath,
|
||||||
|
# Keyword argument: media_type
|
||||||
media_type=_MEDIA_TYPES[format],
|
media_type=_MEDIA_TYPES[format],
|
||||||
|
# Keyword argument: filename
|
||||||
filename=f"coverage_report.{format}",
|
filename=f"coverage_report.{format}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/executive-summary")
|
@router.get("/executive-summary")
|
||||||
|
# Apply the @limiter.limit decorator
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
|
# Define function generate_executive_report
|
||||||
def generate_executive_report(
|
def generate_executive_report(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: format
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||||
):
|
) -> FileResponse:
|
||||||
"""Generate an executive security summary report."""
|
"""Generate an executive security summary report."""
|
||||||
|
# Assign filepath = report_generation_service.generate_executive_summary(
|
||||||
filepath = report_generation_service.generate_executive_summary(
|
filepath = report_generation_service.generate_executive_summary(
|
||||||
db, output_format=format,
|
db, output_format=format,
|
||||||
)
|
)
|
||||||
|
# Return FileResponse(
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
filepath,
|
filepath,
|
||||||
|
# Keyword argument: media_type
|
||||||
media_type=_MEDIA_TYPES[format],
|
media_type=_MEDIA_TYPES[format],
|
||||||
|
# Keyword argument: filename
|
||||||
filename=f"executive_summary.{format}",
|
filename=f"executive_summary.{format}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/quarterly-summary")
|
@router.get("/quarterly-summary")
|
||||||
|
# Apply the @limiter.limit decorator
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
|
# Define function generate_quarterly_report
|
||||||
def generate_quarterly_report(
|
def generate_quarterly_report(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: format
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||||
):
|
) -> FileResponse:
|
||||||
"""Generate a quarterly security summary report."""
|
"""Generate a quarterly security summary report."""
|
||||||
|
# Assign filepath = report_generation_service.generate_quarterly_summary(
|
||||||
filepath = report_generation_service.generate_quarterly_summary(
|
filepath = report_generation_service.generate_quarterly_summary(
|
||||||
db, output_format=format,
|
db, output_format=format,
|
||||||
)
|
)
|
||||||
|
# Return FileResponse(
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
filepath,
|
filepath,
|
||||||
|
# Keyword argument: media_type
|
||||||
media_type=_MEDIA_TYPES[format],
|
media_type=_MEDIA_TYPES[format],
|
||||||
|
# Keyword argument: filename
|
||||||
filename=f"quarterly_summary.{format}",
|
filename=f"quarterly_summary.{format}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/technique/{technique_id}")
|
@router.get("/technique/{technique_id}")
|
||||||
|
# Apply the @limiter.limit decorator
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
|
# Define function generate_technique_report
|
||||||
def generate_technique_report(
|
def generate_technique_report(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: technique_id
|
||||||
technique_id: UUID,
|
technique_id: UUID,
|
||||||
|
# Entry: format
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
) -> FileResponse:
|
||||||
"""Generate a detailed report for one MITRE technique."""
|
"""Generate a detailed report for one MITRE technique."""
|
||||||
|
# Assign filepath = report_generation_service.generate_technique_detail_report(
|
||||||
filepath = report_generation_service.generate_technique_detail_report(
|
filepath = report_generation_service.generate_technique_detail_report(
|
||||||
db, str(technique_id), output_format=format,
|
db, str(technique_id), output_format=format,
|
||||||
)
|
)
|
||||||
|
# Return FileResponse(
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
filepath,
|
filepath,
|
||||||
|
# Keyword argument: media_type
|
||||||
media_type=_MEDIA_TYPES[format],
|
media_type=_MEDIA_TYPES[format],
|
||||||
|
# Keyword argument: filename
|
||||||
filename=f"technique_{technique_id}.{format}",
|
filename=f"technique_{technique_id}.{format}",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,18 +10,37 @@ GET /reports/test-results — test results report (JSON)
|
|||||||
GET /reports/remediation-status — remediation status report (JSON)
|
GET /reports/remediation-status — remediation status report (JSON)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import csv
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
|
# Import io
|
||||||
import io
|
import io
|
||||||
|
|
||||||
|
# Import datetime from datetime
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import StreamingResponse from fastapi.responses
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.services.coverage_report_service
|
||||||
from app.services.coverage_report_service import (
|
from app.services.coverage_report_service import (
|
||||||
build_coverage_csv_rows,
|
build_coverage_csv_rows,
|
||||||
build_coverage_summary,
|
build_coverage_summary,
|
||||||
@@ -29,61 +48,99 @@ from app.services.coverage_report_service import (
|
|||||||
build_test_results_report,
|
build_test_results_report,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/reports", tags=["reports"])
|
||||||
router = APIRouter(prefix="/reports", tags=["reports"])
|
router = APIRouter(prefix="/reports", tags=["reports"])
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/coverage-summary")
|
@router.get("/coverage-summary")
|
||||||
|
# Define function coverage_summary
|
||||||
def coverage_summary(
|
def coverage_summary(
|
||||||
|
# Entry: tactic
|
||||||
tactic: Optional[str] = Query(None, description="Filter by tactic"),
|
tactic: Optional[str] = Query(None, description="Filter by tactic"),
|
||||||
|
# Entry: platform
|
||||||
platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"),
|
platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Full coverage report as JSON — technique-by-technique with test counts."""
|
"""Full coverage report as JSON — technique-by-technique with test counts."""
|
||||||
|
# Return build_coverage_summary(db, tactic=tactic, platform=platform)
|
||||||
return build_coverage_summary(db, tactic=tactic, platform=platform)
|
return build_coverage_summary(db, tactic=tactic, platform=platform)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/coverage-csv")
|
@router.get("/coverage-csv")
|
||||||
|
# Define function coverage_csv
|
||||||
def coverage_csv(
|
def coverage_csv(
|
||||||
|
# Entry: tactic
|
||||||
tactic: Optional[str] = Query(None),
|
tactic: Optional[str] = Query(None),
|
||||||
|
# Entry: platform
|
||||||
platform: Optional[str] = Query(None),
|
platform: Optional[str] = Query(None),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> StreamingResponse:
|
||||||
"""Export coverage as a downloadable CSV."""
|
"""Export coverage as a downloadable CSV."""
|
||||||
|
# Assign rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
|
||||||
rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
|
rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
|
||||||
|
|
||||||
|
# Assign output = io.StringIO()
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
|
# Assign writer = csv.writer(output)
|
||||||
writer = csv.writer(output)
|
writer = csv.writer(output)
|
||||||
|
# Iterate over rows
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
# Call writer.writerow()
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
|
# Call output.seek()
|
||||||
output.seek(0)
|
output.seek(0)
|
||||||
|
# Assign filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
|
||||||
filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
|
filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
|
||||||
|
# Return StreamingResponse(
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
iter([output.getvalue()]),
|
iter([output.getvalue()]),
|
||||||
|
# Keyword argument: media_type
|
||||||
media_type="text/csv",
|
media_type="text/csv",
|
||||||
|
# Keyword argument: headers
|
||||||
headers={"Content-Disposition": f"attachment; filename={filename}"},
|
headers={"Content-Disposition": f"attachment; filename={filename}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/test-results")
|
@router.get("/test-results")
|
||||||
|
# Define function test_results
|
||||||
def test_results(
|
def test_results(
|
||||||
|
# Entry: state
|
||||||
state: Optional[str] = Query(None),
|
state: Optional[str] = Query(None),
|
||||||
|
# Entry: date_from
|
||||||
date_from: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
|
date_from: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
|
||||||
|
# Entry: date_to
|
||||||
date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
|
date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Report of test results with optional filters."""
|
"""Report of test results with optional filters."""
|
||||||
|
# Return build_test_results_report(db, state=state, date_from=date_from, dat...
|
||||||
return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
|
return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/remediation-status")
|
@router.get("/remediation-status")
|
||||||
|
# Define function remediation_status
|
||||||
def remediation_status(
|
def remediation_status(
|
||||||
|
# Entry: status
|
||||||
status: Optional[str] = Query(None, description="Filter by remediation status"),
|
status: Optional[str] = Query(None, description="Filter by remediation status"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Report of remediation status across all tests."""
|
"""Report of remediation status across all tests."""
|
||||||
|
# Return build_remediation_status_report(db, status=status)
|
||||||
return build_remediation_status_report(db, status=status)
|
return build_remediation_status_report(db, status=status)
|
||||||
|
|||||||
+154
-20
@@ -3,28 +3,45 @@
|
|||||||
Provides granular scoring with breakdowns and configurable weights.
|
Provides granular scoring with breakdowns and configurable weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import BaseModel from pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.scoring_service import (
|
|
||||||
score_technique_by_mitre_id,
|
# Import from app.services.scoring_config_service
|
||||||
score_actor_by_id,
|
|
||||||
calculate_tactic_score,
|
|
||||||
calculate_organization_score,
|
|
||||||
get_score_history,
|
|
||||||
)
|
|
||||||
from app.services.scoring_config_service import (
|
from app.services.scoring_config_service import (
|
||||||
get_weights_dict,
|
get_weights_dict,
|
||||||
update_scoring_weights,
|
update_scoring_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import from app.services.scoring_service
|
||||||
|
from app.services.scoring_service import (
|
||||||
|
calculate_tactic_score,
|
||||||
|
get_score_history,
|
||||||
|
score_actor_by_id,
|
||||||
|
score_technique_by_mitre_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/scores", tags=["scores"])
|
||||||
router = APIRouter(prefix="/scores", tags=["scores"])
|
router = APIRouter(prefix="/scores", tags=["scores"])
|
||||||
|
|
||||||
|
|
||||||
@@ -32,12 +49,26 @@ router = APIRouter(prefix="/scores", tags=["scores"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/technique/{mitre_id}")
|
@router.get("/technique/{mitre_id}")
|
||||||
|
# Define function score_technique
|
||||||
def score_technique(
|
def score_technique(
|
||||||
|
# Entry: mitre_id
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get detailed score with breakdown for a specific technique."""
|
"""Get detailed score with breakdown for a specific technique.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mitre_id (str): MITRE ATT&CK technique ID (e.g. ``T1059``).
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Score value and component breakdown (tests, detection rules, recency, etc.).
|
||||||
|
"""
|
||||||
|
# Return score_technique_by_mitre_id(db, mitre_id)
|
||||||
return score_technique_by_mitre_id(db, mitre_id)
|
return score_technique_by_mitre_id(db, mitre_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -45,12 +76,26 @@ def score_technique(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/tactic/{tactic}")
|
@router.get("/tactic/{tactic}")
|
||||||
|
# Define function score_tactic
|
||||||
def score_tactic(
|
def score_tactic(
|
||||||
|
# Entry: tactic
|
||||||
tactic: str,
|
tactic: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get average score for a tactic."""
|
"""Get average score for a tactic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tactic (str): MITRE ATT&CK tactic slug (e.g. ``initial-access``).
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Average score and per-technique breakdown for the tactic.
|
||||||
|
"""
|
||||||
|
# Return calculate_tactic_score(tactic, db)
|
||||||
return calculate_tactic_score(tactic, db)
|
return calculate_tactic_score(tactic, db)
|
||||||
|
|
||||||
|
|
||||||
@@ -58,12 +103,26 @@ def score_tactic(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/threat-actor/{actor_id}")
|
@router.get("/threat-actor/{actor_id}")
|
||||||
|
# Define function score_threat_actor
|
||||||
def score_threat_actor(
|
def score_threat_actor(
|
||||||
|
# Entry: actor_id
|
||||||
actor_id: str,
|
actor_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get coverage score against a specific threat actor."""
|
"""Get coverage score against a specific threat actor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor_id (str): UUID string of the threat actor to score against.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Coverage score and per-technique breakdown for the threat actor.
|
||||||
|
"""
|
||||||
|
# Return score_actor_by_id(db, actor_id)
|
||||||
return score_actor_by_id(db, actor_id)
|
return score_actor_by_id(db, actor_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,13 +130,26 @@ def score_threat_actor(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/organization")
|
@router.get("/organization")
|
||||||
|
# Define function score_organization
|
||||||
def score_organization(
|
def score_organization(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get the overall organization security score (cached for 5 min)."""
|
"""Get the overall organization security score (cached for 5 min).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Aggregate organization score with tactic-level breakdowns.
|
||||||
|
"""
|
||||||
|
# Import get_organization_score_cached from app.services.score_cache
|
||||||
from app.services.score_cache import get_organization_score_cached
|
from app.services.score_cache import get_organization_score_cached
|
||||||
|
|
||||||
|
# Return get_organization_score_cached(db)
|
||||||
return get_organization_score_cached(db)
|
return get_organization_score_cached(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -85,12 +157,26 @@ def score_organization(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/history")
|
@router.get("/history")
|
||||||
|
# Define function score_history
|
||||||
def score_history(
|
def score_history(
|
||||||
|
# Entry: period
|
||||||
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
|
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get historical score data points (weekly)."""
|
"""Get historical score data points (weekly).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
period (str): Time window for history — one of ``30d``, ``90d``, or ``1y``.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Weekly score data points for the requested period.
|
||||||
|
"""
|
||||||
|
# Return get_score_history(db, period)
|
||||||
return get_score_history(db, period)
|
return get_score_history(db, period)
|
||||||
|
|
||||||
|
|
||||||
@@ -98,11 +184,23 @@ def score_history(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/config")
|
@router.get("/config")
|
||||||
|
# Define function get_scoring_config
|
||||||
def get_scoring_config(
|
def get_scoring_config(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Get current scoring weights (admin only)."""
|
"""Get current scoring weights (admin only).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated admin user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Current weight values for each scoring component.
|
||||||
|
"""
|
||||||
|
# Return get_weights_dict(db)
|
||||||
return get_weights_dict(db)
|
return get_weights_dict(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -110,41 +208,77 @@ def get_scoring_config(
|
|||||||
|
|
||||||
|
|
||||||
class ScoringConfigUpdate(BaseModel):
|
class ScoringConfigUpdate(BaseModel):
|
||||||
|
"""Partial update payload for the scoring weight configuration."""
|
||||||
|
|
||||||
|
# Assign tests = None
|
||||||
tests: Optional[float] = None
|
tests: Optional[float] = None
|
||||||
|
# Assign detection_rules = None
|
||||||
detection_rules: Optional[float] = None
|
detection_rules: Optional[float] = None
|
||||||
|
# Assign d3fend = None
|
||||||
d3fend: Optional[float] = None
|
d3fend: Optional[float] = None
|
||||||
|
# Assign recency = None
|
||||||
recency: Optional[float] = None
|
recency: Optional[float] = None
|
||||||
|
# Assign severity = None
|
||||||
severity: Optional[float] = None
|
severity: Optional[float] = None
|
||||||
|
# Assign freshness = None
|
||||||
freshness: Optional[float] = None
|
freshness: Optional[float] = None
|
||||||
|
# Assign platform_diversity = None
|
||||||
platform_diversity: Optional[float] = None
|
platform_diversity: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.patch decorator
|
||||||
@router.patch("/config")
|
@router.patch("/config")
|
||||||
|
# Define function update_scoring_config
|
||||||
def update_scoring_config(
|
def update_scoring_config(
|
||||||
|
# Entry: payload
|
||||||
payload: ScoringConfigUpdate,
|
payload: ScoringConfigUpdate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Update scoring weights (admin only).
|
"""Update scoring weights (admin only).
|
||||||
|
|
||||||
Weights are persisted in the database and survive restarts.
|
Weights are persisted in the database and survive restarts.
|
||||||
Validation enforces that all weights are non-negative and sum to 100.
|
Validation enforces that all weights are non-negative and sum to 100.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload (ScoringConfigUpdate): Partial weight update; only set fields are changed.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated admin user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Confirmation message plus the full updated weight configuration.
|
||||||
"""
|
"""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign result = update_scoring_weights(
|
||||||
result = update_scoring_weights(
|
result = update_scoring_weights(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: tests
|
||||||
tests=payload.tests,
|
tests=payload.tests,
|
||||||
|
# Keyword argument: detection_rules
|
||||||
detection_rules=payload.detection_rules,
|
detection_rules=payload.detection_rules,
|
||||||
|
# Keyword argument: d3fend
|
||||||
d3fend=payload.d3fend,
|
d3fend=payload.d3fend,
|
||||||
|
# Keyword argument: recency
|
||||||
recency=payload.recency,
|
recency=payload.recency,
|
||||||
|
# Keyword argument: severity
|
||||||
severity=payload.severity,
|
severity=payload.severity,
|
||||||
|
# Keyword argument: freshness
|
||||||
freshness=payload.freshness,
|
freshness=payload.freshness,
|
||||||
|
# Keyword argument: platform_diversity
|
||||||
platform_diversity=payload.platform_diversity,
|
platform_diversity=payload.platform_diversity,
|
||||||
|
# Keyword argument: updated_by
|
||||||
updated_by=current_user.id,
|
updated_by=current_user.id,
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Import invalidate from app.services.score_cache
|
||||||
from app.services.score_cache import invalidate
|
from app.services.score_cache import invalidate
|
||||||
|
# Call invalidate()
|
||||||
invalidate()
|
invalidate()
|
||||||
|
|
||||||
|
# Return {"message": "Scoring config updated", **result}
|
||||||
return {"message": "Scoring config updated", **result}
|
return {"message": "Scoring config updated", **result}
|
||||||
|
|||||||
@@ -4,40 +4,71 @@ Provides periodic and manual snapshots of the organisation's coverage
|
|||||||
state, plus temporal comparison between any two snapshots.
|
state, plus temporal comparison between any two snapshots.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import BaseModel from pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||||
|
|
||||||
|
# Import BusinessRuleViolation from app.domain.errors
|
||||||
from app.domain.errors import BusinessRuleViolation
|
from app.domain.errors import BusinessRuleViolation
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.snapshot_service import (
|
|
||||||
create_snapshot,
|
# Import log_action from app.services.audit_service
|
||||||
compare_snapshots,
|
|
||||||
cleanup_old_snapshots,
|
|
||||||
get_coverage_evolution,
|
|
||||||
serialize_snapshot_summary,
|
|
||||||
list_snapshots as list_snapshots_svc,
|
|
||||||
get_snapshot_or_raise,
|
|
||||||
get_snapshot_detail,
|
|
||||||
delete_snapshot,
|
|
||||||
)
|
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
|
# Import from app.services.snapshot_service
|
||||||
|
from app.services.snapshot_service import (
|
||||||
|
compare_snapshots,
|
||||||
|
create_snapshot,
|
||||||
|
delete_snapshot,
|
||||||
|
get_coverage_evolution,
|
||||||
|
get_snapshot_detail,
|
||||||
|
get_snapshot_or_raise,
|
||||||
|
serialize_snapshot_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.snapshot_service
|
||||||
|
from app.services.snapshot_service import (
|
||||||
|
list_snapshots as list_snapshots_svc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/snapshots", tags=["snapshots"])
|
||||||
router = APIRouter(prefix="/snapshots", tags=["snapshots"])
|
router = APIRouter(prefix="/snapshots", tags=["snapshots"])
|
||||||
|
|
||||||
|
|
||||||
# ── Pydantic schemas ─────────────────────────────────────────────────
|
# ── Pydantic schemas ─────────────────────────────────────────────────
|
||||||
|
|
||||||
class SnapshotCreate(BaseModel):
|
class SnapshotCreate(BaseModel):
|
||||||
|
"""Payload for creating a new coverage snapshot."""
|
||||||
|
|
||||||
|
# Assign name = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -46,13 +77,19 @@ class SnapshotCreate(BaseModel):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
|
# Define function list_snapshots
|
||||||
def list_snapshots(
|
def list_snapshots(
|
||||||
|
# Entry: offset
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
|
# Entry: limit
|
||||||
limit: int = Query(50, ge=1, le=200),
|
limit: int = Query(50, ge=1, le=200),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||||
|
# Return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||||
return list_snapshots_svc(db, offset=offset, limit=limit)
|
return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
@@ -61,25 +98,39 @@ def list_snapshots(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.post("", status_code=201)
|
@router.post("", status_code=201)
|
||||||
|
# Define function create_snapshot_endpoint
|
||||||
def create_snapshot_endpoint(
|
def create_snapshot_endpoint(
|
||||||
|
# Entry: payload
|
||||||
payload: SnapshotCreate,
|
payload: SnapshotCreate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Create a manual coverage snapshot with an optional name."""
|
"""Create a manual coverage snapshot with an optional name."""
|
||||||
|
# Assign snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
|
||||||
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
|
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
|
||||||
|
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="create_snapshot",
|
action="create_snapshot",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="snapshot",
|
entity_type="snapshot",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=snapshot.id,
|
entity_id=snapshot.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"name": snapshot.name, "score": snapshot.organization_score},
|
details={"name": snapshot.name, "score": snapshot.organization_score},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return serialize_snapshot_summary(snapshot)
|
||||||
return serialize_snapshot_summary(snapshot)
|
return serialize_snapshot_summary(snapshot)
|
||||||
|
|
||||||
|
|
||||||
@@ -89,12 +140,17 @@ def create_snapshot_endpoint(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/evolution")
|
@router.get("/evolution")
|
||||||
|
# Define function coverage_evolution
|
||||||
def coverage_evolution(
|
def coverage_evolution(
|
||||||
|
# Entry: months
|
||||||
months: int = Query(12, ge=1, le=36),
|
months: int = Query(12, ge=1, le=36),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Return coverage snapshots for trend charts (last *months* months)."""
|
"""Return coverage snapshots for trend charts (last *months* months)."""
|
||||||
|
# Return get_coverage_evolution(db, months=months)
|
||||||
return get_coverage_evolution(db, months=months)
|
return get_coverage_evolution(db, months=months)
|
||||||
|
|
||||||
|
|
||||||
@@ -103,19 +159,30 @@ def coverage_evolution(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/compare")
|
@router.get("/compare")
|
||||||
|
# Define function compare_snapshots_endpoint
|
||||||
def compare_snapshots_endpoint(
|
def compare_snapshots_endpoint(
|
||||||
|
# Entry: a
|
||||||
a: str = Query(..., description="Snapshot A ID"),
|
a: str = Query(..., description="Snapshot A ID"),
|
||||||
|
# Entry: b
|
||||||
b: str = Query(..., description="Snapshot B ID"),
|
b: str = Query(..., description="Snapshot B ID"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Compare two snapshots showing improved, worsened, and unchanged techniques."""
|
"""Compare two snapshots showing improved, worsened, and unchanged techniques."""
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign a_id = uuid.UUID(a)
|
||||||
a_id = uuid.UUID(a)
|
a_id = uuid.UUID(a)
|
||||||
|
# Assign b_id = uuid.UUID(b)
|
||||||
b_id = uuid.UUID(b)
|
b_id = uuid.UUID(b)
|
||||||
|
# Handle ValueError
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
# Raise BusinessRuleViolation
|
||||||
raise BusinessRuleViolation("Invalid snapshot ID format")
|
raise BusinessRuleViolation("Invalid snapshot ID format")
|
||||||
|
|
||||||
|
# Return compare_snapshots(db, a_id, b_id)
|
||||||
return compare_snapshots(db, a_id, b_id)
|
return compare_snapshots(db, a_id, b_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -124,12 +191,17 @@ def compare_snapshots_endpoint(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/{snapshot_id}")
|
@router.get("/{snapshot_id}")
|
||||||
|
# Define function get_snapshot
|
||||||
def get_snapshot(
|
def get_snapshot(
|
||||||
|
# Entry: snapshot_id
|
||||||
snapshot_id: str,
|
snapshot_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get detailed snapshot information including per-technique states."""
|
"""Get detailed snapshot information including per-technique states."""
|
||||||
|
# Return get_snapshot_detail(db, snapshot_id)
|
||||||
return get_snapshot_detail(db, snapshot_id)
|
return get_snapshot_detail(db, snapshot_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -138,24 +210,39 @@ def get_snapshot(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.delete("/{snapshot_id}")
|
@router.delete("/{snapshot_id}")
|
||||||
|
# Define function delete_snapshot_endpoint
|
||||||
def delete_snapshot_endpoint(
|
def delete_snapshot_endpoint(
|
||||||
|
# Entry: snapshot_id
|
||||||
snapshot_id: str,
|
snapshot_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Delete a snapshot (admin only)."""
|
"""Delete a snapshot (admin only)."""
|
||||||
|
# Assign snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||||
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||||
|
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="delete_snapshot",
|
action="delete_snapshot",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="snapshot",
|
entity_type="snapshot",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=snapshot.id,
|
entity_id=snapshot.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"name": snapshot.name},
|
details={"name": snapshot.name},
|
||||||
)
|
)
|
||||||
|
# Call delete_snapshot()
|
||||||
delete_snapshot(db, snapshot_id)
|
delete_snapshot(db, snapshot_id)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return {"detail": "Snapshot deleted"}
|
||||||
return {"detail": "Snapshot deleted"}
|
return {"detail": "Snapshot deleted"}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ Also exposes email configuration CRUD (admin only) that writes to the
|
|||||||
system_configs table so settings survive container restarts.
|
system_configs table so settings survive container restarts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -22,10 +23,26 @@ from app.services.mitre_sync_service import sync_mitre
|
|||||||
from app.services.intel_service import scan_intel
|
from app.services.intel_service import scan_intel
|
||||||
from app.services.atomic_import_service import import_atomic_red_team
|
from app.services.atomic_import_service import import_atomic_red_team
|
||||||
from app.jobs.mitre_sync_job import scheduler
|
from app.jobs.mitre_sync_job import scheduler
|
||||||
|
|
||||||
|
# Import limiter from app.limiter
|
||||||
from app.limiter import limiter
|
from app.limiter import limiter
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import import_atomic_red_team from app.services.atomic_import_service
|
||||||
|
from app.services.atomic_import_service import import_atomic_red_team
|
||||||
|
|
||||||
|
# Import scan_intel from app.services.intel_service
|
||||||
|
from app.services.intel_service import scan_intel
|
||||||
|
|
||||||
|
# Import sync_mitre from app.services.mitre_sync_service
|
||||||
|
from app.services.mitre_sync_service import sync_mitre
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/system", tags=["system"])
|
||||||
router = APIRouter(prefix="/system", tags=["system"])
|
router = APIRouter(prefix="/system", tags=["system"])
|
||||||
|
|
||||||
|
|
||||||
@@ -105,8 +122,11 @@ def _bg_mitre_sync() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/sync-mitre")
|
@router.post("/sync-mitre")
|
||||||
|
# Apply the @limiter.limit decorator
|
||||||
@limiter.limit("2/hour")
|
@limiter.limit("2/hour")
|
||||||
|
# Define function trigger_mitre_sync
|
||||||
def trigger_mitre_sync(
|
def trigger_mitre_sync(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
@@ -127,11 +147,15 @@ def trigger_mitre_sync(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/run-intel-scan")
|
@router.post("/run-intel-scan")
|
||||||
|
# Define function trigger_intel_scan
|
||||||
def trigger_intel_scan(
|
def trigger_intel_scan(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Manually trigger a threat-intelligence scan.
|
"""Manually trigger a threat-intelligence scan.
|
||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
@@ -139,20 +163,30 @@ def trigger_intel_scan(
|
|||||||
Returns a JSON object with the scan summary including the count of
|
Returns a JSON object with the scan summary including the count of
|
||||||
new intel items found.
|
new intel items found.
|
||||||
"""
|
"""
|
||||||
|
# Assign summary = scan_intel(db)
|
||||||
summary = scan_intel(db)
|
summary = scan_intel(db)
|
||||||
|
# Return {
|
||||||
return {
|
return {
|
||||||
|
# Literal argument value
|
||||||
"message": "Intel scan completed",
|
"message": "Intel scan completed",
|
||||||
|
# Literal argument value
|
||||||
"new_items": summary["new_items"],
|
"new_items": summary["new_items"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.post decorator
|
||||||
@router.post("/import-atomic-tests")
|
@router.post("/import-atomic-tests")
|
||||||
|
# Apply the @limiter.limit decorator
|
||||||
@limiter.limit("2/hour")
|
@limiter.limit("2/hour")
|
||||||
|
# Define function trigger_atomic_import
|
||||||
def trigger_atomic_import(
|
def trigger_atomic_import(
|
||||||
|
# Entry: request
|
||||||
request: Request,
|
request: Request,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Trigger an import of Atomic Red Team tests as TestTemplates.
|
"""Trigger an import of Atomic Red Team tests as TestTemplates.
|
||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
@@ -163,37 +197,58 @@ def trigger_atomic_import(
|
|||||||
|
|
||||||
Returns a JSON object with import statistics.
|
Returns a JSON object with import statistics.
|
||||||
"""
|
"""
|
||||||
|
# Attempt the following; catch errors below
|
||||||
try:
|
try:
|
||||||
|
# Assign summary = import_atomic_red_team(db)
|
||||||
summary = import_atomic_red_team(db)
|
summary = import_atomic_red_team(db)
|
||||||
|
# Handle Exception
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
# Log error: "Atomic Red Team import failed: %s", exc, exc_info
|
||||||
logger.error("Atomic Red Team import failed: %s", exc, exc_info=True)
|
logger.error("Atomic Red Team import failed: %s", exc, exc_info=True)
|
||||||
|
# Return {
|
||||||
return {
|
return {
|
||||||
|
# Literal argument value
|
||||||
"message": "Import failed. Check server logs for details.",
|
"message": "Import failed. Check server logs for details.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Return {
|
||||||
return {
|
return {
|
||||||
|
# Literal argument value
|
||||||
"message": "Import completed",
|
"message": "Import completed",
|
||||||
|
# Literal argument value
|
||||||
"imported": summary["created"],
|
"imported": summary["created"],
|
||||||
|
# Literal argument value
|
||||||
"skipped": summary["skipped_existing"],
|
"skipped": summary["skipped_existing"],
|
||||||
|
# Literal argument value
|
||||||
"total_parsed": summary["total_tests_parsed"],
|
"total_parsed": summary["total_tests_parsed"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/scheduler-status")
|
@router.get("/scheduler-status")
|
||||||
|
# Define function scheduler_status
|
||||||
def scheduler_status(
|
def scheduler_status(
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> dict:
|
||||||
"""Return the current state of the background scheduler.
|
"""Return the current state of the background scheduler.
|
||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
|
# Assign jobs = scheduler.get_jobs()
|
||||||
jobs = scheduler.get_jobs()
|
jobs = scheduler.get_jobs()
|
||||||
|
# Return {
|
||||||
return {
|
return {
|
||||||
|
# Literal argument value
|
||||||
"running": scheduler.running,
|
"running": scheduler.running,
|
||||||
|
# Literal argument value
|
||||||
"jobs": [
|
"jobs": [
|
||||||
{
|
{
|
||||||
|
# Literal argument value
|
||||||
"id": job.id,
|
"id": job.id,
|
||||||
|
# Literal argument value
|
||||||
"name": job.name,
|
"name": job.name,
|
||||||
|
# Literal argument value
|
||||||
"next_run_time": str(job.next_run_time) if job.next_run_time else None,
|
"next_run_time": str(job.next_run_time) if job.next_run_time else None,
|
||||||
}
|
}
|
||||||
for job in jobs
|
for job in jobs
|
||||||
|
|||||||
@@ -5,29 +5,56 @@ for error signaling. The error_handler middleware maps domain
|
|||||||
exceptions to HTTP responses automatically.
|
exceptions to HTTP responses automatically.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query, status from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
|
||||||
|
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
|
||||||
|
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||||
|
|
||||||
|
# Import get_technique_repository from app.dependencies.repositories
|
||||||
from app.dependencies.repositories import get_technique_repository
|
from app.dependencies.repositories import get_technique_repository
|
||||||
|
|
||||||
|
# Import TechniqueEntity from app.domain.entities.technique
|
||||||
from app.domain.entities.technique import TechniqueEntity
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
from app.domain.errors import DuplicateEntityError, EntityNotFoundError
|
|
||||||
|
# Import TechniqueStatus from app.domain.enums
|
||||||
from app.domain.enums import TechniqueStatus
|
from app.domain.enums import TechniqueStatus
|
||||||
|
|
||||||
|
# Import DuplicateEntityError, EntityNotFoundError from app.domain.errors
|
||||||
|
from app.domain.errors import DuplicateEntityError, EntityNotFoundError
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
|
||||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||||
SATechniqueRepository,
|
SATechniqueRepository,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.schemas.technique
|
||||||
from app.schemas.technique import (
|
from app.schemas.technique import (
|
||||||
TechniqueCreate,
|
TechniqueCreate,
|
||||||
TechniqueOut,
|
TechniqueOut,
|
||||||
TechniqueSummary,
|
TechniqueSummary,
|
||||||
TechniqueUpdate,
|
TechniqueUpdate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import log_action from app.services.audit_service
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
|
# Import get_technique_detail from app.services.technique_query_service
|
||||||
from app.services.technique_query_service import get_technique_detail
|
from app.services.technique_query_service import get_technique_detail
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||||
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||||
|
|
||||||
|
|
||||||
@@ -37,19 +64,29 @@ router = APIRouter(prefix="/techniques", tags=["techniques"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[TechniqueSummary])
|
@router.get("", response_model=list[TechniqueSummary])
|
||||||
|
# Define function list_techniques
|
||||||
def list_techniques(
|
def list_techniques(
|
||||||
|
# Entry: tactic
|
||||||
tactic: str | None = Query(None, description="Filter by tactic name"),
|
tactic: str | None = Query(None, description="Filter by tactic name"),
|
||||||
|
# Entry: status_global
|
||||||
status_global: TechniqueStatus | None = Query(
|
status_global: TechniqueStatus | None = Query(
|
||||||
None, alias="status", description="Filter by global status"
|
None, alias="status", description="Filter by global status"
|
||||||
),
|
),
|
||||||
|
# Entry: review_required
|
||||||
review_required: bool | None = Query(None, description="Filter by review flag"),
|
review_required: bool | None = Query(None, description="Filter by review flag"),
|
||||||
|
# Entry: repo
|
||||||
repo: SATechniqueRepository = Depends(get_technique_repository),
|
repo: SATechniqueRepository = Depends(get_technique_repository),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Return a lightweight list of techniques, optionally filtered."""
|
"""Return a lightweight list of techniques, optionally filtered."""
|
||||||
|
# Return repo.list_all(
|
||||||
return repo.list_all(
|
return repo.list_all(
|
||||||
|
# Keyword argument: tactic
|
||||||
tactic=tactic,
|
tactic=tactic,
|
||||||
|
# Keyword argument: status
|
||||||
status=status_global,
|
status=status_global,
|
||||||
|
# Keyword argument: review_required
|
||||||
review_required=review_required,
|
review_required=review_required,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -60,12 +97,17 @@ def list_techniques(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{mitre_id}")
|
@router.get("/{mitre_id}")
|
||||||
|
# Define function get_technique
|
||||||
def get_technique(
|
def get_technique(
|
||||||
|
# Entry: mitre_id
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Return full details for a single technique, including its tests and D3FEND defenses."""
|
"""Return full details for a single technique, including its tests and D3FEND defenses."""
|
||||||
|
# Return get_technique_detail(db, mitre_id)
|
||||||
return get_technique_detail(db, mitre_id)
|
return get_technique_detail(db, mitre_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,40 +117,66 @@ def get_technique(
|
|||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
# Literal argument value
|
||||||
"",
|
"",
|
||||||
|
# Keyword argument: response_model
|
||||||
response_model=TechniqueOut,
|
response_model=TechniqueOut,
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
)
|
)
|
||||||
|
# Define function create_technique
|
||||||
def create_technique(
|
def create_technique(
|
||||||
|
# Entry: payload
|
||||||
payload: TechniqueCreate,
|
payload: TechniqueCreate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: repo
|
||||||
repo: SATechniqueRepository = Depends(get_technique_repository),
|
repo: SATechniqueRepository = Depends(get_technique_repository),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> TechniqueOut:
|
||||||
"""Create a new technique manually."""
|
"""Create a new technique manually."""
|
||||||
|
# Check: repo.exists_by_mitre_id(payload.mitre_id)
|
||||||
if repo.exists_by_mitre_id(payload.mitre_id):
|
if repo.exists_by_mitre_id(payload.mitre_id):
|
||||||
|
# Raise DuplicateEntityError
|
||||||
raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id)
|
raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id)
|
||||||
|
|
||||||
|
# Assign entity = TechniqueEntity.create(
|
||||||
entity = TechniqueEntity.create(
|
entity = TechniqueEntity.create(
|
||||||
|
# Keyword argument: mitre_id
|
||||||
mitre_id=payload.mitre_id,
|
mitre_id=payload.mitre_id,
|
||||||
|
# Keyword argument: name
|
||||||
name=payload.name,
|
name=payload.name,
|
||||||
|
# Keyword argument: description
|
||||||
description=payload.description,
|
description=payload.description,
|
||||||
|
# Keyword argument: tactic
|
||||||
tactic=payload.tactic,
|
tactic=payload.tactic,
|
||||||
|
# Keyword argument: platforms
|
||||||
platforms=payload.platforms,
|
platforms=payload.platforms,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign saved = repo.save(entity)
|
||||||
saved = repo.save(entity)
|
saved = repo.save(entity)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="create_technique",
|
action="create_technique",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="technique",
|
entity_type="technique",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=saved.id,
|
entity_id=saved.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"mitre_id": saved.mitre_id, "name": saved.name},
|
details={"mitre_id": saved.mitre_id, "name": saved.name},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return saved
|
||||||
return saved
|
return saved
|
||||||
|
|
||||||
|
|
||||||
@@ -118,34 +186,56 @@ def create_technique(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{mitre_id}", response_model=TechniqueOut)
|
@router.patch("/{mitre_id}", response_model=TechniqueOut)
|
||||||
|
# Define function update_technique
|
||||||
def update_technique(
|
def update_technique(
|
||||||
|
# Entry: mitre_id
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
|
# Entry: payload
|
||||||
payload: TechniqueUpdate,
|
payload: TechniqueUpdate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: repo
|
||||||
repo: SATechniqueRepository = Depends(get_technique_repository),
|
repo: SATechniqueRepository = Depends(get_technique_repository),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> TechniqueOut:
|
||||||
"""Update one or more fields of an existing technique."""
|
"""Update one or more fields of an existing technique."""
|
||||||
|
# Assign entity = repo.find_by_mitre_id(mitre_id)
|
||||||
entity = repo.find_by_mitre_id(mitre_id)
|
entity = repo.find_by_mitre_id(mitre_id)
|
||||||
|
# Check: entity is None
|
||||||
if entity is None:
|
if entity is None:
|
||||||
|
# Raise EntityNotFoundError
|
||||||
raise EntityNotFoundError("Technique", mitre_id)
|
raise EntityNotFoundError("Technique", mitre_id)
|
||||||
|
|
||||||
|
# Assign update_data = payload.model_dump(exclude_unset=True)
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
|
# Iterate over update_data.items()
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
|
# Call setattr()
|
||||||
setattr(entity, field, value)
|
setattr(entity, field, value)
|
||||||
|
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign saved = repo.save(entity)
|
||||||
saved = repo.save(entity)
|
saved = repo.save(entity)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="update_technique",
|
action="update_technique",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="technique",
|
entity_type="technique",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=saved.id,
|
entity_id=saved.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())},
|
details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return saved
|
||||||
return saved
|
return saved
|
||||||
|
|
||||||
|
|
||||||
@@ -155,33 +245,52 @@ def update_technique(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{mitre_id}/review", response_model=TechniqueOut)
|
@router.patch("/{mitre_id}/review", response_model=TechniqueOut)
|
||||||
|
# Define function review_technique
|
||||||
def review_technique(
|
def review_technique(
|
||||||
|
# Entry: mitre_id
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: repo
|
||||||
repo: SATechniqueRepository = Depends(get_technique_repository),
|
repo: SATechniqueRepository = Depends(get_technique_repository),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> TechniqueOut:
|
||||||
"""Mark a technique as reviewed.
|
"""Mark a technique as reviewed.
|
||||||
|
|
||||||
Sets ``review_required`` to *False* and records the current timestamp
|
Sets ``review_required`` to *False* and records the current timestamp
|
||||||
in ``last_review_date``.
|
in ``last_review_date``.
|
||||||
"""
|
"""
|
||||||
|
# Assign entity = repo.find_by_mitre_id(mitre_id)
|
||||||
entity = repo.find_by_mitre_id(mitre_id)
|
entity = repo.find_by_mitre_id(mitre_id)
|
||||||
|
# Check: entity is None
|
||||||
if entity is None:
|
if entity is None:
|
||||||
|
# Raise EntityNotFoundError
|
||||||
raise EntityNotFoundError("Technique", mitre_id)
|
raise EntityNotFoundError("Technique", mitre_id)
|
||||||
|
|
||||||
|
# Call entity.mark_reviewed()
|
||||||
entity.mark_reviewed()
|
entity.mark_reviewed()
|
||||||
|
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign saved = repo.save(entity)
|
||||||
saved = repo.save(entity)
|
saved = repo.save(entity)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="review_technique",
|
action="review_technique",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="technique",
|
entity_type="technique",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=saved.id,
|
entity_id=saved.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"mitre_id": mitre_id},
|
details={"mitre_id": mitre_id},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return saved
|
||||||
return saved
|
return saved
|
||||||
|
|||||||
@@ -22,35 +22,69 @@ Filters (GET /test-templates)
|
|||||||
- offset / limit: pagination (default limit=50)
|
- offset / limit: pagination (default limit=50)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query, status from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.schemas.test_template
|
||||||
from app.schemas.test_template import (
|
from app.schemas.test_template import (
|
||||||
TestTemplateCreate,
|
TestTemplateCreate,
|
||||||
TestTemplateOut,
|
TestTemplateOut,
|
||||||
TestTemplateSummary,
|
TestTemplateSummary,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import log_action from app.services.audit_service
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
|
# Import from app.services.test_template_service
|
||||||
from app.services.test_template_service import (
|
from app.services.test_template_service import (
|
||||||
bulk_activate,
|
bulk_activate,
|
||||||
create_template as create_template_svc,
|
|
||||||
get_template_or_raise,
|
get_template_or_raise,
|
||||||
get_template_stats,
|
get_template_stats,
|
||||||
get_templates_by_technique as templates_by_technique,
|
|
||||||
list_templates,
|
list_templates,
|
||||||
soft_delete_template,
|
soft_delete_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.test_template_service
|
||||||
|
from app.services.test_template_service import (
|
||||||
|
create_template as create_template_svc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.test_template_service
|
||||||
|
from app.services.test_template_service import (
|
||||||
|
get_templates_by_technique as templates_by_technique,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.test_template_service
|
||||||
|
from app.services.test_template_service import (
|
||||||
toggle_template_active as toggle_template_active_svc,
|
toggle_template_active as toggle_template_active_svc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import from app.services.test_template_service
|
||||||
|
from app.services.test_template_service import (
|
||||||
update_template as update_template_svc,
|
update_template as update_template_svc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
||||||
router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
||||||
|
|
||||||
|
|
||||||
@@ -60,28 +94,64 @@ router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[TestTemplateSummary])
|
@router.get("", response_model=list[TestTemplateSummary])
|
||||||
|
# Define function _list_templates_handler
|
||||||
def _list_templates_handler(
|
def _list_templates_handler(
|
||||||
|
# Entry: source
|
||||||
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
|
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
|
||||||
|
# Entry: platform
|
||||||
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
|
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
|
||||||
|
# Entry: severity
|
||||||
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
|
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
|
||||||
|
# Entry: mitre_technique_id
|
||||||
mitre_technique_id: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
|
mitre_technique_id: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
|
||||||
|
# Entry: search
|
||||||
search: Optional[str] = Query(None, description="Search in name and description"),
|
search: Optional[str] = Query(None, description="Search in name and description"),
|
||||||
|
# Entry: is_active
|
||||||
is_active: Optional[bool] = Query(None, description="Filter by active status (true/false). Omit to return all."),
|
is_active: Optional[bool] = Query(None, description="Filter by active status (true/false). Omit to return all."),
|
||||||
|
# Entry: offset
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
|
# Entry: limit
|
||||||
limit: int = Query(50, ge=1, le=200),
|
limit: int = Query(50, ge=1, le=200),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Return a paginated, filterable list of test templates."""
|
"""Return a paginated, filterable list of test templates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source (Optional[str]): Filter by source (``atomic_red_team``, ``mitre``, ``custom``).
|
||||||
|
platform (Optional[str]): Filter by platform (``windows``, ``linux``, ``macos``).
|
||||||
|
severity (Optional[str]): Filter by severity (``low``, ``medium``, ``high``, ``critical``).
|
||||||
|
mitre_technique_id (Optional[str]): Filter by MITRE technique ID string.
|
||||||
|
search (Optional[str]): Full-text search across name and description.
|
||||||
|
is_active (Optional[bool]): Filter by active status; omit to return all.
|
||||||
|
offset (int): Number of records to skip for pagination.
|
||||||
|
limit (int): Maximum number of records to return.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: Serialised list of :class:`TestTemplateSummary` objects.
|
||||||
|
"""
|
||||||
|
# Return list_templates(
|
||||||
return list_templates(
|
return list_templates(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: source
|
||||||
source=source,
|
source=source,
|
||||||
|
# Keyword argument: platform
|
||||||
platform=platform,
|
platform=platform,
|
||||||
|
# Keyword argument: severity
|
||||||
severity=severity,
|
severity=severity,
|
||||||
|
# Keyword argument: mitre_technique_id
|
||||||
mitre_technique_id=mitre_technique_id,
|
mitre_technique_id=mitre_technique_id,
|
||||||
|
# Keyword argument: search
|
||||||
search=search,
|
search=search,
|
||||||
|
# Keyword argument: is_active
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
|
# Keyword argument: offset
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
# Keyword argument: limit
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -92,11 +162,23 @@ def _list_templates_handler(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats")
|
@router.get("/stats")
|
||||||
|
# Define function template_stats
|
||||||
def template_stats(
|
def template_stats(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Return catalog statistics: active, by_source, by_platform."""
|
"""Return catalog statistics: active, by_source, by_platform.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Counts of active templates broken down by source and platform.
|
||||||
|
"""
|
||||||
|
# Return get_template_stats(db)
|
||||||
return get_template_stats(db)
|
return get_template_stats(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -106,27 +188,53 @@ def template_stats(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/bulk-activate")
|
@router.patch("/bulk-activate")
|
||||||
|
# Define function bulk_activate_templates
|
||||||
def bulk_activate_templates(
|
def bulk_activate_templates(
|
||||||
|
# Entry: activate
|
||||||
activate: bool = Query(True, description="True to activate all, False to deactivate all"),
|
activate: bool = Query(True, description="True to activate all, False to deactivate all"),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Set all templates to active or inactive."""
|
"""Set all templates to active or inactive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
activate (bool): ``True`` to activate all templates, ``False`` to deactivate all.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Confirmation message with ``affected`` count and the applied ``is_active`` flag.
|
||||||
|
"""
|
||||||
|
# Assign count = bulk_activate(db, activate=activate)
|
||||||
count = bulk_activate(db, activate=activate)
|
count = bulk_activate(db, activate=activate)
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="bulk_activate_templates" if activate else "bulk_deactivate_templates",
|
action="bulk_activate_templates" if activate else "bulk_deactivate_templates",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="test_template",
|
entity_type="test_template",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=None,
|
entity_id=None,
|
||||||
|
# Keyword argument: details
|
||||||
details={"affected": count, "is_active": activate},
|
details={"affected": count, "is_active": activate},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return {
|
||||||
return {
|
return {
|
||||||
|
# Literal argument value
|
||||||
"detail": f"{'Activated' if activate else 'Deactivated'} {count} templates",
|
"detail": f"{'Activated' if activate else 'Deactivated'} {count} templates",
|
||||||
|
# Literal argument value
|
||||||
"affected": count,
|
"affected": count,
|
||||||
|
# Literal argument value
|
||||||
"is_active": activate,
|
"is_active": activate,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,12 +245,26 @@ def bulk_activate_templates(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
|
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
|
||||||
|
# Define function _templates_by_technique_handler
|
||||||
def _templates_by_technique_handler(
|
def _templates_by_technique_handler(
|
||||||
|
# Entry: mitre_id
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Return all active templates mapped to a specific MITRE technique."""
|
"""Return all active templates mapped to a specific MITRE technique.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mitre_id (str): MITRE ATT&CK technique ID (e.g. ``T1059.001``).
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: Serialised list of :class:`TestTemplateSummary` objects for the technique.
|
||||||
|
"""
|
||||||
|
# Return templates_by_technique(db, mitre_id)
|
||||||
return templates_by_technique(db, mitre_id)
|
return templates_by_technique(db, mitre_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -152,12 +274,26 @@ def _templates_by_technique_handler(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{template_id}", response_model=TestTemplateOut)
|
@router.get("/{template_id}", response_model=TestTemplateOut)
|
||||||
|
# Define function get_template
|
||||||
def get_template(
|
def get_template(
|
||||||
|
# Entry: template_id
|
||||||
template_id: uuid.UUID,
|
template_id: uuid.UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> TestTemplateOut:
|
||||||
"""Return full details for a single test template."""
|
"""Return full details for a single test template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id (uuid.UUID): Primary key of the template to retrieve.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated user making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestTemplateOut: Full template detail including all fields.
|
||||||
|
"""
|
||||||
|
# Return get_template_or_raise(db, template_id)
|
||||||
return get_template_or_raise(db, template_id)
|
return get_template_or_raise(db, template_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -167,17 +303,35 @@ def get_template(
|
|||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
# Literal argument value
|
||||||
"",
|
"",
|
||||||
|
# Keyword argument: response_model
|
||||||
response_model=TestTemplateOut,
|
response_model=TestTemplateOut,
|
||||||
|
# Keyword argument: status_code
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
)
|
)
|
||||||
|
# Define function create_template
|
||||||
def create_template(
|
def create_template(
|
||||||
|
# Entry: payload
|
||||||
payload: TestTemplateCreate,
|
payload: TestTemplateCreate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> TestTemplateOut:
|
||||||
"""Create a custom test template."""
|
"""Create a custom test template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload (TestTemplateCreate): All fields for the new template.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead creating the template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestTemplateOut: The newly created template with all fields populated.
|
||||||
|
"""
|
||||||
|
# Assign template = create_template_svc(db, **payload.model_dump())
|
||||||
template = create_template_svc(db, **payload.model_dump())
|
template = create_template_svc(db, **payload.model_dump())
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
# Flag the associated technique for review — new template available
|
# Flag the associated technique for review — new template available
|
||||||
if template.mitre_technique_id:
|
if template.mitre_technique_id:
|
||||||
@@ -190,19 +344,30 @@ def create_template(
|
|||||||
technique.review_required = True
|
technique.review_required = True
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="create_test_template",
|
action="create_test_template",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="test_template",
|
entity_type="test_template",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=template.id,
|
entity_id=template.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={
|
details={
|
||||||
|
# Literal argument value
|
||||||
"name": template.name,
|
"name": template.name,
|
||||||
|
# Literal argument value
|
||||||
"source": template.source,
|
"source": template.source,
|
||||||
|
# Literal argument value
|
||||||
"mitre_technique_id": template.mitre_technique_id,
|
"mitre_technique_id": template.mitre_technique_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(template)
|
db.refresh(template)
|
||||||
|
|
||||||
|
# Return template
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
@@ -212,26 +377,52 @@ def create_template(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{template_id}", response_model=TestTemplateOut)
|
@router.patch("/{template_id}", response_model=TestTemplateOut)
|
||||||
|
# Define function update_template
|
||||||
def update_template(
|
def update_template(
|
||||||
|
# Entry: template_id
|
||||||
template_id: uuid.UUID,
|
template_id: uuid.UUID,
|
||||||
|
# Entry: payload
|
||||||
payload: TestTemplateCreate,
|
payload: TestTemplateCreate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> TestTemplateOut:
|
||||||
"""Update fields of an existing test template."""
|
"""Update fields of an existing test template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id (uuid.UUID): Primary key of the template to update.
|
||||||
|
payload (TestTemplateCreate): Fields to update; only set fields are applied.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead updating the template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestTemplateOut: The updated template with refreshed field values.
|
||||||
|
"""
|
||||||
|
# Assign template = update_template_svc(db, template_id, **payload.model_dump(exclude_u...
|
||||||
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
|
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="update_test_template",
|
action="update_test_template",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="test_template",
|
entity_type="test_template",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=template.id,
|
entity_id=template.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(template)
|
db.refresh(template)
|
||||||
|
|
||||||
|
# Return template
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
@@ -241,25 +432,49 @@ def update_template(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{template_id}/toggle-active", response_model=TestTemplateOut)
|
@router.patch("/{template_id}/toggle-active", response_model=TestTemplateOut)
|
||||||
|
# Define function toggle_template_active
|
||||||
def toggle_template_active(
|
def toggle_template_active(
|
||||||
|
# Entry: template_id
|
||||||
template_id: uuid.UUID,
|
template_id: uuid.UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> TestTemplateOut:
|
||||||
"""Toggle a template between active and inactive (is_active = not is_active)."""
|
"""Toggle a template between active and inactive (is_active = not is_active).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id (uuid.UUID): Primary key of the template to toggle.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestTemplateOut: The template with the updated ``is_active`` flag.
|
||||||
|
"""
|
||||||
|
# Assign template = toggle_template_active_svc(db, template_id)
|
||||||
template = toggle_template_active_svc(db, template_id)
|
template = toggle_template_active_svc(db, template_id)
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="toggle_test_template",
|
action="toggle_test_template",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="test_template",
|
entity_type="test_template",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=template.id,
|
entity_id=template.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"name": template.name, "is_active": template.is_active},
|
details={"name": template.name, "is_active": template.is_active},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(template)
|
db.refresh(template)
|
||||||
|
|
||||||
|
# Return template
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
@@ -269,23 +484,47 @@ def toggle_template_active(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{template_id}", status_code=status.HTTP_200_OK)
|
@router.delete("/{template_id}", status_code=status.HTTP_200_OK)
|
||||||
|
# Define function delete_template
|
||||||
def delete_template(
|
def delete_template(
|
||||||
|
# Entry: template_id
|
||||||
template_id: uuid.UUID,
|
template_id: uuid.UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
) -> dict:
|
||||||
"""Soft-delete a test template by setting ``is_active=False``."""
|
"""Soft-delete a test template by setting ``is_active=False``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id (uuid.UUID): Primary key of the template to delete.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
current_user (User): Authenticated red_lead or blue_lead.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Confirmation message with key ``detail``.
|
||||||
|
"""
|
||||||
|
# Assign template = get_template_or_raise(db, template_id)
|
||||||
template = get_template_or_raise(db, template_id)
|
template = get_template_or_raise(db, template_id)
|
||||||
|
# Call soft_delete_template()
|
||||||
soft_delete_template(db, template_id)
|
soft_delete_template(db, template_id)
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="delete_test_template",
|
action="delete_test_template",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="test_template",
|
entity_type="test_template",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=template.id,
|
entity_id=template.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"name": template.name},
|
details={"name": template.name},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
|
||||||
|
# Return {"detail": "Test template deactivated"}
|
||||||
return {"detail": "Test template deactivated"}
|
return {"detail": "Test template deactivated"}
|
||||||
|
|||||||
+590
-40
File diff suppressed because it is too large
Load Diff
@@ -4,15 +4,28 @@ Provides listing, detail, coverage analysis, and gap analysis for
|
|||||||
threat actor profiles imported from MITRE CTI.
|
threat actor profiles imported from MITRE CTI.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Import logging
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, Query from fastapi
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import from app.services.threat_actor_service
|
||||||
from app.services.threat_actor_service import (
|
from app.services.threat_actor_service import (
|
||||||
get_actor_coverage,
|
get_actor_coverage,
|
||||||
get_actor_detail,
|
get_actor_detail,
|
||||||
@@ -20,58 +33,90 @@ from app.services.threat_actor_service import (
|
|||||||
list_actors,
|
list_actors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign logger = logging.getLogger(__name__)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/threat-actors", tags=["threat-actors"])
|
||||||
router = APIRouter(prefix="/threat-actors", tags=["threat-actors"])
|
router = APIRouter(prefix="/threat-actors", tags=["threat-actors"])
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("")
|
@router.get("")
|
||||||
|
# Define function list_threat_actors
|
||||||
def list_threat_actors(
|
def list_threat_actors(
|
||||||
|
# Entry: search
|
||||||
search: Optional[str] = Query(None),
|
search: Optional[str] = Query(None),
|
||||||
|
# Entry: country
|
||||||
country: Optional[str] = Query(None),
|
country: Optional[str] = Query(None),
|
||||||
|
# Entry: motivation
|
||||||
motivation: Optional[str] = Query(None),
|
motivation: Optional[str] = Query(None),
|
||||||
|
# Entry: sophistication
|
||||||
sophistication: Optional[str] = Query(None),
|
sophistication: Optional[str] = Query(None),
|
||||||
|
# Entry: target_sectors
|
||||||
target_sectors: Optional[str] = Query(None),
|
target_sectors: Optional[str] = Query(None),
|
||||||
|
# Entry: offset
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
|
# Entry: limit
|
||||||
limit: int = Query(50, ge=1, le=200),
|
limit: int = Query(50, ge=1, le=200),
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""List threat actors with optional filters and pagination.
|
"""List threat actors with optional filters and pagination.
|
||||||
|
|
||||||
**Requires** authentication (any role).
|
**Requires** authentication (any role).
|
||||||
"""
|
"""
|
||||||
|
# Return list_actors(
|
||||||
return list_actors(
|
return list_actors(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: search
|
||||||
search=search,
|
search=search,
|
||||||
|
# Keyword argument: country
|
||||||
country=country,
|
country=country,
|
||||||
|
# Keyword argument: motivation
|
||||||
motivation=motivation,
|
motivation=motivation,
|
||||||
|
# Keyword argument: sophistication
|
||||||
sophistication=sophistication,
|
sophistication=sophistication,
|
||||||
|
# Keyword argument: target_sectors
|
||||||
target_sectors=target_sectors,
|
target_sectors=target_sectors,
|
||||||
|
# Keyword argument: offset
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
# Keyword argument: limit
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/{actor_id}")
|
@router.get("/{actor_id}")
|
||||||
|
# Define function get_threat_actor
|
||||||
def get_threat_actor(
|
def get_threat_actor(
|
||||||
|
# Entry: actor_id
|
||||||
actor_id: str,
|
actor_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Get detailed info about a threat actor including techniques.
|
"""Get detailed info about a threat actor including techniques.
|
||||||
|
|
||||||
**Requires** authentication (any role).
|
**Requires** authentication (any role).
|
||||||
"""
|
"""
|
||||||
|
# Return get_actor_detail(db, actor_id)
|
||||||
return get_actor_detail(db, actor_id)
|
return get_actor_detail(db, actor_id)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/{actor_id}/coverage")
|
@router.get("/{actor_id}/coverage")
|
||||||
|
# Define function get_threat_actor_coverage
|
||||||
def get_threat_actor_coverage(
|
def get_threat_actor_coverage(
|
||||||
|
# Entry: actor_id
|
||||||
actor_id: str,
|
actor_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Calculate coverage percentage against a specific threat actor.
|
"""Calculate coverage percentage against a specific threat actor.
|
||||||
|
|
||||||
**Requires** authentication (any role).
|
**Requires** authentication (any role).
|
||||||
@@ -79,19 +124,26 @@ def get_threat_actor_coverage(
|
|||||||
Returns the percentage of the actor's techniques that have been
|
Returns the percentage of the actor's techniques that have been
|
||||||
validated or partially validated, along with a breakdown.
|
validated or partially validated, along with a breakdown.
|
||||||
"""
|
"""
|
||||||
|
# Return get_actor_coverage(db, actor_id)
|
||||||
return get_actor_coverage(db, actor_id)
|
return get_actor_coverage(db, actor_id)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/{actor_id}/gaps")
|
@router.get("/{actor_id}/gaps")
|
||||||
|
# Define function get_threat_actor_gaps
|
||||||
def get_threat_actor_gaps(
|
def get_threat_actor_gaps(
|
||||||
|
# Entry: actor_id
|
||||||
actor_id: str,
|
actor_id: str,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list:
|
||||||
"""Identify techniques of this actor that are NOT fully validated.
|
"""Identify techniques of this actor that are NOT fully validated.
|
||||||
|
|
||||||
**Requires** authentication (any role).
|
**Requires** authentication (any role).
|
||||||
|
|
||||||
Returns list of gap techniques with available templates.
|
Returns list of gap techniques with available templates.
|
||||||
"""
|
"""
|
||||||
|
# Return get_actor_gaps(db, actor_id)
|
||||||
return get_actor_gaps(db, actor_id)
|
return get_actor_gaps(db, actor_id)
|
||||||
|
|||||||
@@ -1,17 +1,30 @@
|
|||||||
"""User management router (admin only)."""
|
"""User management router (admin only)."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import APIRouter, Depends, status from fastapi
|
||||||
from fastapi import APIRouter, Depends, status
|
from fastapi import APIRouter, Depends, status
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import require_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import require_role
|
from app.dependencies.auth import require_role
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
from app.schemas.user import UserCreate, UserUpdate, UserOut, UserPreferencesUpdate
|
from app.schemas.user import UserCreate, UserUpdate, UserOut, UserPreferencesUpdate
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
|
# Import from app.services.user_service
|
||||||
from app.services.user_service import (
|
from app.services.user_service import (
|
||||||
create_user,
|
create_user,
|
||||||
get_user_or_raise,
|
get_user_or_raise,
|
||||||
@@ -19,6 +32,7 @@ from app.services.user_service import (
|
|||||||
update_user,
|
update_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/users", tags=["users"])
|
||||||
router = APIRouter(prefix="/users", tags=["users"])
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
|
|
||||||
|
|
||||||
@@ -69,11 +83,15 @@ def get_me(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[UserOut])
|
@router.get("", response_model=list[UserOut])
|
||||||
|
# Define function list_users_route
|
||||||
def list_users_route(
|
def list_users_route(
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> list[UserOut]:
|
||||||
"""Return a list of all users. **Requires admin role.**"""
|
"""Return a list of all users. **Requires admin role.**."""
|
||||||
|
# Return list_users(db)
|
||||||
return list_users(db)
|
return list_users(db)
|
||||||
|
|
||||||
|
|
||||||
@@ -83,31 +101,50 @@ def list_users_route(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||||
|
# Define function create_user_route
|
||||||
def create_user_route(
|
def create_user_route(
|
||||||
|
# Entry: payload
|
||||||
payload: UserCreate,
|
payload: UserCreate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> UserOut:
|
||||||
"""Create a new user. **Requires admin role.**"""
|
"""Create a new user. **Requires admin role.**."""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign user = create_user(
|
||||||
user = create_user(
|
user = create_user(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: username
|
||||||
username=payload.username,
|
username=payload.username,
|
||||||
|
# Keyword argument: email
|
||||||
email=payload.email,
|
email=payload.email,
|
||||||
|
# Keyword argument: password
|
||||||
password=payload.password,
|
password=payload.password,
|
||||||
|
# Keyword argument: role
|
||||||
role=payload.role,
|
role=payload.role,
|
||||||
)
|
)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="create_user",
|
action="create_user",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="user",
|
entity_type="user",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=user.id,
|
entity_id=user.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"username": user.username, "role": user.role},
|
details={"username": user.username, "role": user.role},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
|
|
||||||
|
# Return user
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@@ -117,12 +154,17 @@ def create_user_route(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=UserOut)
|
@router.get("/{user_id}", response_model=UserOut)
|
||||||
|
# Define function get_user
|
||||||
def get_user(
|
def get_user(
|
||||||
|
# Entry: user_id
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> UserOut:
|
||||||
"""Return a single user by ID. **Requires admin role.**"""
|
"""Return a single user by ID. **Requires admin role.**."""
|
||||||
|
# Return get_user_or_raise(db, user_id)
|
||||||
return get_user_or_raise(db, user_id)
|
return get_user_or_raise(db, user_id)
|
||||||
|
|
||||||
|
|
||||||
@@ -132,25 +174,42 @@ def get_user(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{user_id}", response_model=UserOut)
|
@router.patch("/{user_id}", response_model=UserOut)
|
||||||
|
# Define function update_user_route
|
||||||
def update_user_route(
|
def update_user_route(
|
||||||
|
# Entry: user_id
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
|
# Entry: payload
|
||||||
payload: UserUpdate,
|
payload: UserUpdate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: current_user
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
) -> UserOut:
|
||||||
"""Update one or more fields of an existing user. **Requires admin role.**"""
|
"""Update one or more fields of an existing user. **Requires admin role.**."""
|
||||||
|
# Assign update_data = payload.model_dump(exclude_unset=True)
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign user = update_user(db, user_id, **update_data)
|
||||||
user = update_user(db, user_id, **update_data)
|
user = update_user(db, user_id, **update_data)
|
||||||
|
# Call log_action()
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
# Keyword argument: action
|
||||||
action="update_user",
|
action="update_user",
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type="user",
|
entity_type="user",
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=user.id,
|
entity_id=user.id,
|
||||||
|
# Keyword argument: details
|
||||||
details={"updated_fields": list(update_data.keys())},
|
details={"updated_fields": list(update_data.keys())},
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
|
|
||||||
|
# Return user
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -1,19 +1,39 @@
|
|||||||
"""Worklog router — internal time-tracking records with integrity verification."""
|
"""Worklog router — internal time-tracking records with integrity verification."""
|
||||||
|
|
||||||
|
# Import datetime from datetime
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import UUID from uuid
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
# Import APIRouter, Depends from fastapi
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
# Import BaseModel, Field from pydantic
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# Import Session from sqlalchemy.orm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
# Import get_db from app.database
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
|
|
||||||
|
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
|
|
||||||
|
# Import UnitOfWork from app.domain.unit_of_work
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
# Import User from app.models.user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Import worklog_service from app.services
|
||||||
from app.services import worklog_service
|
from app.services import worklog_service
|
||||||
|
|
||||||
|
# Assign router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
||||||
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
||||||
|
|
||||||
|
|
||||||
@@ -21,30 +41,58 @@ router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
|||||||
|
|
||||||
|
|
||||||
class WorklogCreate(BaseModel):
|
class WorklogCreate(BaseModel):
|
||||||
|
"""Payload for logging a work session against an entity."""
|
||||||
|
|
||||||
|
# Assign entity_type = Field(..., max_length=50)
|
||||||
entity_type: str = Field(..., max_length=50)
|
entity_type: str = Field(..., max_length=50)
|
||||||
|
# entity_id: UUID
|
||||||
entity_id: UUID
|
entity_id: UUID
|
||||||
|
# Assign activity_type = Field(..., max_length=100)
|
||||||
activity_type: str = Field(..., max_length=100)
|
activity_type: str = Field(..., max_length=100)
|
||||||
|
# started_at: datetime
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
|
# Assign ended_at = None
|
||||||
ended_at: Optional[datetime] = None
|
ended_at: Optional[datetime] = None
|
||||||
|
# Assign duration_seconds = Field(..., gt=0)
|
||||||
duration_seconds: int = Field(..., gt=0)
|
duration_seconds: int = Field(..., gt=0)
|
||||||
|
# Assign description = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Define class WorklogOut
|
||||||
class WorklogOut(BaseModel):
|
class WorklogOut(BaseModel):
|
||||||
|
"""Serialized worklog entry returned by the API."""
|
||||||
|
|
||||||
|
# id: UUID
|
||||||
id: UUID
|
id: UUID
|
||||||
|
# entity_type: str
|
||||||
entity_type: str
|
entity_type: str
|
||||||
|
# entity_id: UUID
|
||||||
entity_id: UUID
|
entity_id: UUID
|
||||||
|
# user_id: UUID
|
||||||
user_id: UUID
|
user_id: UUID
|
||||||
|
# activity_type: str
|
||||||
activity_type: str
|
activity_type: str
|
||||||
|
# started_at: datetime
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
|
# Assign ended_at = None
|
||||||
ended_at: Optional[datetime] = None
|
ended_at: Optional[datetime] = None
|
||||||
|
# duration_seconds: int
|
||||||
duration_seconds: int
|
duration_seconds: int
|
||||||
|
# Assign description = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
# Assign tempo_synced = None
|
||||||
tempo_synced: Optional[datetime] = None
|
tempo_synced: Optional[datetime] = None
|
||||||
|
# Assign integrity_hash = None
|
||||||
integrity_hash: Optional[str] = None
|
integrity_hash: Optional[str] = None
|
||||||
|
# created_at: datetime
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
# Define class Config
|
||||||
class Config:
|
class Config:
|
||||||
|
"""ORM mode configuration for SQLAlchemy model mapping."""
|
||||||
|
|
||||||
|
# Assign from_attributes = True
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
@@ -52,65 +100,146 @@ class WorklogOut(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=WorklogOut, status_code=201)
|
@router.post("", response_model=WorklogOut, status_code=201)
|
||||||
|
# Define function create
|
||||||
def create(
|
def create(
|
||||||
|
# Entry: body
|
||||||
body: WorklogCreate,
|
body: WorklogCreate,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: user
|
||||||
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||||
):
|
) -> WorklogOut:
|
||||||
"""Create a manually-logged worklog entry."""
|
"""Create a manually-logged worklog entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
body (WorklogCreate): Worklog fields including entity, activity type, and duration.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
user (User): Authenticated team member creating the worklog.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WorklogOut: The newly created worklog with integrity hash and all fields.
|
||||||
|
"""
|
||||||
|
# Open context manager
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
|
# Assign wl = worklog_service.create_worklog(
|
||||||
wl = worklog_service.create_worklog(
|
wl = worklog_service.create_worklog(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type=body.entity_type,
|
entity_type=body.entity_type,
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=body.entity_id,
|
entity_id=body.entity_id,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
|
# Keyword argument: activity_type
|
||||||
activity_type=body.activity_type,
|
activity_type=body.activity_type,
|
||||||
|
# Keyword argument: started_at
|
||||||
started_at=body.started_at,
|
started_at=body.started_at,
|
||||||
|
# Keyword argument: ended_at
|
||||||
ended_at=body.ended_at,
|
ended_at=body.ended_at,
|
||||||
|
# Keyword argument: duration_seconds
|
||||||
duration_seconds=body.duration_seconds,
|
duration_seconds=body.duration_seconds,
|
||||||
|
# Keyword argument: description
|
||||||
description=body.description,
|
description=body.description,
|
||||||
)
|
)
|
||||||
|
# Call uow.commit()
|
||||||
uow.commit()
|
uow.commit()
|
||||||
|
# Reload ORM object attributes from the database
|
||||||
db.refresh(wl)
|
db.refresh(wl)
|
||||||
|
# Return wl
|
||||||
return wl
|
return wl
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("", response_model=list[WorklogOut])
|
@router.get("", response_model=list[WorklogOut])
|
||||||
|
# Define function list_all
|
||||||
def list_all(
|
def list_all(
|
||||||
|
# Entry: entity_type
|
||||||
entity_type: Optional[str] = None,
|
entity_type: Optional[str] = None,
|
||||||
|
# Entry: entity_id
|
||||||
entity_id: Optional[UUID] = None,
|
entity_id: Optional[UUID] = None,
|
||||||
|
# Entry: user_id
|
||||||
user_id: Optional[UUID] = None,
|
user_id: Optional[UUID] = None,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: _user
|
||||||
_user: User = Depends(get_current_user),
|
_user: User = Depends(get_current_user),
|
||||||
):
|
) -> list[WorklogOut]:
|
||||||
"""List worklogs with optional filters."""
|
"""List worklogs with optional filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity_type (Optional[str]): Filter by entity type (e.g. ``test``, ``campaign``).
|
||||||
|
entity_id (Optional[UUID]): Filter by the UUID of the associated entity.
|
||||||
|
user_id (Optional[UUID]): Filter by the UUID of the worklog author.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
_user (User): Authenticated user making the request (unused, enforces auth).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[WorklogOut]: Serialised list of worklog entries matching the filters.
|
||||||
|
"""
|
||||||
|
# Return worklog_service.list_worklogs(
|
||||||
return worklog_service.list_worklogs(
|
return worklog_service.list_worklogs(
|
||||||
db,
|
db,
|
||||||
|
# Keyword argument: entity_type
|
||||||
entity_type=entity_type,
|
entity_type=entity_type,
|
||||||
|
# Keyword argument: entity_id
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
|
# Keyword argument: user_id
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/{worklog_id}", response_model=WorklogOut)
|
@router.get("/{worklog_id}", response_model=WorklogOut)
|
||||||
|
# Define function get_one
|
||||||
def get_one(
|
def get_one(
|
||||||
|
# Entry: worklog_id
|
||||||
worklog_id: UUID,
|
worklog_id: UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: _user
|
||||||
_user: User = Depends(get_current_user),
|
_user: User = Depends(get_current_user),
|
||||||
):
|
) -> WorklogOut:
|
||||||
"""Get a single worklog by ID."""
|
"""Get a single worklog by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
worklog_id (UUID): Primary key of the worklog to retrieve.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
_user (User): Authenticated user making the request (unused, enforces auth).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WorklogOut: Full worklog detail including integrity hash.
|
||||||
|
"""
|
||||||
|
# Return worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||||
return worklog_service.get_worklog_or_raise(db, worklog_id)
|
return worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply the @router.get decorator
|
||||||
@router.get("/{worklog_id}/verify")
|
@router.get("/{worklog_id}/verify")
|
||||||
|
# Define function verify_integrity
|
||||||
def verify_integrity(
|
def verify_integrity(
|
||||||
|
# Entry: worklog_id
|
||||||
worklog_id: UUID,
|
worklog_id: UUID,
|
||||||
|
# Entry: db
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
# Entry: _user
|
||||||
_user: User = Depends(get_current_user),
|
_user: User = Depends(get_current_user),
|
||||||
):
|
) -> dict:
|
||||||
"""Check whether a worklog's integrity hash is still valid."""
|
"""Check whether a worklog's integrity hash is still valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
worklog_id (UUID): Primary key of the worklog to verify.
|
||||||
|
db (Session): SQLAlchemy database session.
|
||||||
|
_user (User): Authenticated user making the request (unused, enforces auth).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Contains ``worklog_id`` (str) and ``integrity_valid`` (bool).
|
||||||
|
"""
|
||||||
|
# Assign wl = worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||||
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
|
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||||
|
# Return {
|
||||||
return {
|
return {
|
||||||
|
# Literal argument value
|
||||||
"worklog_id": str(wl.id),
|
"worklog_id": str(wl.id),
|
||||||
|
# Literal argument value
|
||||||
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
|
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
"""Pydantic schemas — re-exported for convenient imports."""
|
"""Pydantic schemas — re-exported for convenient imports."""
|
||||||
|
|
||||||
|
# Import LoginRequest, TokenResponse, UserOut from app.schemas.auth
|
||||||
from app.schemas.auth import LoginRequest, TokenResponse, UserOut
|
from app.schemas.auth import LoginRequest, TokenResponse, UserOut
|
||||||
|
|
||||||
|
# Import EvidenceOut, EvidenceUpload from app.schemas.evidence
|
||||||
|
from app.schemas.evidence import EvidenceOut, EvidenceUpload
|
||||||
|
|
||||||
|
# Import from app.schemas.technique
|
||||||
from app.schemas.technique import (
|
from app.schemas.technique import (
|
||||||
TechniqueCreate,
|
TechniqueCreate,
|
||||||
TechniqueOut,
|
TechniqueOut,
|
||||||
@@ -9,51 +14,68 @@ from app.schemas.technique import (
|
|||||||
TechniqueUpdate,
|
TechniqueUpdate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import from app.schemas.test
|
||||||
from app.schemas.test import (
|
from app.schemas.test import (
|
||||||
|
TestBlueUpdate,
|
||||||
|
TestBlueValidate,
|
||||||
TestCreate,
|
TestCreate,
|
||||||
TestOut,
|
TestOut,
|
||||||
|
TestRedUpdate,
|
||||||
|
TestRedValidate,
|
||||||
TestUpdate,
|
TestUpdate,
|
||||||
TestValidate,
|
TestValidate,
|
||||||
TestRedUpdate,
|
|
||||||
TestBlueUpdate,
|
|
||||||
TestRedValidate,
|
|
||||||
TestBlueValidate,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.schemas.evidence import EvidenceOut, EvidenceUpload
|
# Import from app.schemas.test_template
|
||||||
|
|
||||||
from app.schemas.test_template import (
|
from app.schemas.test_template import (
|
||||||
TestTemplateOut,
|
|
||||||
TestTemplateCreate,
|
TestTemplateCreate,
|
||||||
TestTemplateSummary,
|
|
||||||
TestTemplateInstantiate,
|
TestTemplateInstantiate,
|
||||||
|
TestTemplateOut,
|
||||||
|
TestTemplateSummary,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Assign __all__ = [
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Auth
|
# Auth
|
||||||
"LoginRequest",
|
"LoginRequest",
|
||||||
|
# Literal argument value
|
||||||
"TokenResponse",
|
"TokenResponse",
|
||||||
|
# Literal argument value
|
||||||
"UserOut",
|
"UserOut",
|
||||||
# Technique
|
# Technique
|
||||||
"TechniqueCreate",
|
"TechniqueCreate",
|
||||||
|
# Literal argument value
|
||||||
"TechniqueOut",
|
"TechniqueOut",
|
||||||
|
# Literal argument value
|
||||||
"TechniqueSummary",
|
"TechniqueSummary",
|
||||||
|
# Literal argument value
|
||||||
"TechniqueUpdate",
|
"TechniqueUpdate",
|
||||||
# Test
|
# Test
|
||||||
"TestCreate",
|
"TestCreate",
|
||||||
|
# Literal argument value
|
||||||
"TestOut",
|
"TestOut",
|
||||||
|
# Literal argument value
|
||||||
"TestUpdate",
|
"TestUpdate",
|
||||||
|
# Literal argument value
|
||||||
"TestValidate",
|
"TestValidate",
|
||||||
|
# Literal argument value
|
||||||
"TestRedUpdate",
|
"TestRedUpdate",
|
||||||
|
# Literal argument value
|
||||||
"TestBlueUpdate",
|
"TestBlueUpdate",
|
||||||
|
# Literal argument value
|
||||||
"TestRedValidate",
|
"TestRedValidate",
|
||||||
|
# Literal argument value
|
||||||
"TestBlueValidate",
|
"TestBlueValidate",
|
||||||
# Evidence
|
# Evidence
|
||||||
"EvidenceOut",
|
"EvidenceOut",
|
||||||
|
# Literal argument value
|
||||||
"EvidenceUpload",
|
"EvidenceUpload",
|
||||||
# Test Template
|
# Test Template
|
||||||
"TestTemplateOut",
|
"TestTemplateOut",
|
||||||
|
# Literal argument value
|
||||||
"TestTemplateCreate",
|
"TestTemplateCreate",
|
||||||
|
# Literal argument value
|
||||||
"TestTemplateSummary",
|
"TestTemplateSummary",
|
||||||
|
# Literal argument value
|
||||||
"TestTemplateInstantiate",
|
"TestTemplateInstantiate",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,31 +1,48 @@
|
|||||||
"""Pydantic schemas for Audit Log endpoints."""
|
"""Pydantic schemas for Audit Log endpoints."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import datetime from datetime
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# Import BaseModel, ConfigDict from pydantic
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
# Define class AuditLogOut
|
||||||
class AuditLogOut(BaseModel):
|
class AuditLogOut(BaseModel):
|
||||||
"""Complete representation of an audit log entry."""
|
"""Complete representation of an audit log entry."""
|
||||||
|
|
||||||
|
# id: uuid.UUID
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
|
# Assign user_id = None
|
||||||
user_id: uuid.UUID | None = None
|
user_id: uuid.UUID | None = None
|
||||||
|
# Assign username = None # Populated from user relationship
|
||||||
username: str | None = None # Populated from user relationship
|
username: str | None = None # Populated from user relationship
|
||||||
|
# action: str
|
||||||
action: str
|
action: str
|
||||||
|
# Assign entity_type = None
|
||||||
entity_type: str | None = None
|
entity_type: str | None = None
|
||||||
|
# Assign entity_id = None
|
||||||
entity_id: str | None = None
|
entity_id: str | None = None
|
||||||
timestamp: Optional[datetime] = None
|
timestamp: Optional[datetime] = None
|
||||||
details: dict[str, Any] | None = None
|
details: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
# Assign model_config = ConfigDict(from_attributes=True)
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Define class AuditLogPage
|
||||||
class AuditLogPage(BaseModel):
|
class AuditLogPage(BaseModel):
|
||||||
"""Paginated response for audit logs."""
|
"""Paginated response for audit logs."""
|
||||||
|
|
||||||
|
# items: list[AuditLogOut]
|
||||||
items: list[AuditLogOut]
|
items: list[AuditLogOut]
|
||||||
|
# total: int
|
||||||
total: int
|
total: int
|
||||||
|
# offset: int
|
||||||
offset: int
|
offset: int
|
||||||
|
# limit: int
|
||||||
limit: int
|
limit: int
|
||||||
|
|||||||
@@ -1,34 +1,56 @@
|
|||||||
"""Pydantic schemas for authentication endpoints."""
|
"""Pydantic schemas for authentication endpoints."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import BaseModel from pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# Define class LoginRequest
|
||||||
class LoginRequest(BaseModel):
|
class LoginRequest(BaseModel):
|
||||||
"""Body for the login endpoint (unused directly — we rely on
|
"""Body for the login endpoint.
|
||||||
``OAuth2PasswordRequestForm``, but kept for documentation / testing)."""
|
|
||||||
|
|
||||||
|
Unused directly — we rely on ``OAuth2PasswordRequestForm``, but kept for
|
||||||
|
documentation and testing purposes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# username: str
|
||||||
username: str
|
username: str
|
||||||
|
# password: str
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
# Define class TokenResponse
|
||||||
class TokenResponse(BaseModel):
|
class TokenResponse(BaseModel):
|
||||||
"""Response returned after a successful login."""
|
"""Response returned after a successful login."""
|
||||||
|
|
||||||
|
# access_token: str
|
||||||
access_token: str
|
access_token: str
|
||||||
|
# Assign token_type = "bearer"
|
||||||
token_type: str = "bearer"
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
# Define class UserOut
|
||||||
class UserOut(BaseModel):
|
class UserOut(BaseModel):
|
||||||
"""Public representation of a user (no password hash)."""
|
"""Public representation of a user (no password hash)."""
|
||||||
|
|
||||||
|
# id: uuid.UUID
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
|
# username: str
|
||||||
username: str
|
username: str
|
||||||
|
# Assign email = None
|
||||||
email: str | None = None
|
email: str | None = None
|
||||||
|
# role: str
|
||||||
role: str
|
role: str
|
||||||
|
# is_active: bool
|
||||||
is_active: bool
|
is_active: bool
|
||||||
|
# Assign must_change_password = True
|
||||||
must_change_password: bool = True
|
must_change_password: bool = True
|
||||||
|
|
||||||
|
# Define class Config
|
||||||
class Config:
|
class Config:
|
||||||
|
"""ORM mode configuration for SQLAlchemy model mapping."""
|
||||||
|
|
||||||
|
# Assign from_attributes = True
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|||||||
@@ -1,34 +1,53 @@
|
|||||||
"""Pydantic schemas for Evidence endpoints."""
|
"""Pydantic schemas for Evidence endpoints."""
|
||||||
|
|
||||||
|
# Import uuid
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
# Import datetime from datetime
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Import BaseModel, ConfigDict from pydantic
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
# Import TeamSide from app.models.enums
|
||||||
from app.models.enums import TeamSide
|
from app.models.enums import TeamSide
|
||||||
|
|
||||||
|
|
||||||
|
# Define class EvidenceOut
|
||||||
class EvidenceOut(BaseModel):
|
class EvidenceOut(BaseModel):
|
||||||
"""Representation of an evidence record returned by the API.
|
"""Representation of an evidence record returned by the API.
|
||||||
|
|
||||||
``download_url`` is a presigned URL generated at response time.
|
``download_url`` is a presigned URL generated at response time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# id: uuid.UUID
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
|
# test_id: uuid.UUID
|
||||||
test_id: uuid.UUID
|
test_id: uuid.UUID
|
||||||
|
# file_name: str
|
||||||
file_name: str
|
file_name: str
|
||||||
|
# sha256_hash: str
|
||||||
sha256_hash: str
|
sha256_hash: str
|
||||||
|
# Assign uploaded_by = None
|
||||||
uploaded_by: uuid.UUID | None = None
|
uploaded_by: uuid.UUID | None = None
|
||||||
|
# Assign uploaded_at = None
|
||||||
uploaded_at: datetime | None = None
|
uploaded_at: datetime | None = None
|
||||||
|
# Assign team = TeamSide.red
|
||||||
team: TeamSide = TeamSide.red
|
team: TeamSide = TeamSide.red
|
||||||
|
# Assign notes = None
|
||||||
notes: str | None = None
|
notes: str | None = None
|
||||||
|
# Assign download_url = None
|
||||||
download_url: str | None = None
|
download_url: str | None = None
|
||||||
|
|
||||||
|
# Assign model_config = ConfigDict(from_attributes=True)
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Define class EvidenceUpload
|
||||||
class EvidenceUpload(BaseModel):
|
class EvidenceUpload(BaseModel):
|
||||||
"""Metadata sent alongside an evidence file upload."""
|
"""Metadata sent alongside an evidence file upload."""
|
||||||
|
|
||||||
|
# team: TeamSide
|
||||||
team: TeamSide
|
team: TeamSide
|
||||||
|
# Assign notes = None
|
||||||
notes: str | None = None
|
notes: str | None = None
|
||||||
|
|||||||
@@ -1,46 +1,91 @@
|
|||||||
"""Pydantic schemas for Jira integration endpoints."""
|
"""Pydantic schemas for Jira integration endpoints."""
|
||||||
|
|
||||||
|
# Import datetime from datetime
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Import Optional from typing
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import UUID from uuid
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
# Import BaseModel, Field from pydantic
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# Import JiraLinkEntityType, JiraSyncDirection from app.models.jira_link
|
||||||
from app.models.jira_link import JiraLinkEntityType, JiraSyncDirection
|
from app.models.jira_link import JiraLinkEntityType, JiraSyncDirection
|
||||||
|
|
||||||
|
|
||||||
|
# Define class JiraLinkCreate
|
||||||
class JiraLinkCreate(BaseModel):
|
class JiraLinkCreate(BaseModel):
|
||||||
|
"""Payload for linking an Aegis entity to an existing Jira issue."""
|
||||||
|
|
||||||
|
# entity_type: JiraLinkEntityType
|
||||||
entity_type: JiraLinkEntityType
|
entity_type: JiraLinkEntityType
|
||||||
|
# entity_id: UUID
|
||||||
entity_id: UUID
|
entity_id: UUID
|
||||||
|
# Assign jira_issue_key = Field(..., pattern=r"^[A-Z][A-Z0-9]+-\d+$")
|
||||||
jira_issue_key: str = Field(..., pattern=r"^[A-Z][A-Z0-9]+-\d+$")
|
jira_issue_key: str = Field(..., pattern=r"^[A-Z][A-Z0-9]+-\d+$")
|
||||||
|
# Assign sync_direction = JiraSyncDirection.bidirectional
|
||||||
sync_direction: JiraSyncDirection = JiraSyncDirection.bidirectional
|
sync_direction: JiraSyncDirection = JiraSyncDirection.bidirectional
|
||||||
|
|
||||||
|
|
||||||
|
# Define class JiraLinkOut
|
||||||
class JiraLinkOut(BaseModel):
|
class JiraLinkOut(BaseModel):
|
||||||
|
"""Full representation of a Jira link returned by the API."""
|
||||||
|
|
||||||
|
# id: UUID
|
||||||
id: UUID
|
id: UUID
|
||||||
|
# entity_type: JiraLinkEntityType
|
||||||
entity_type: JiraLinkEntityType
|
entity_type: JiraLinkEntityType
|
||||||
|
# entity_id: UUID
|
||||||
entity_id: UUID
|
entity_id: UUID
|
||||||
|
# jira_issue_key: str
|
||||||
jira_issue_key: str
|
jira_issue_key: str
|
||||||
|
# Assign jira_issue_id = None
|
||||||
jira_issue_id: Optional[str] = None
|
jira_issue_id: Optional[str] = None
|
||||||
|
# Assign jira_project_key = None
|
||||||
jira_project_key: Optional[str] = None
|
jira_project_key: Optional[str] = None
|
||||||
|
# Assign jira_status = None
|
||||||
jira_status: Optional[str] = None
|
jira_status: Optional[str] = None
|
||||||
|
# Assign jira_priority = None
|
||||||
jira_priority: Optional[str] = None
|
jira_priority: Optional[str] = None
|
||||||
|
# Assign jira_assignee = None
|
||||||
jira_assignee: Optional[str] = None
|
jira_assignee: Optional[str] = None
|
||||||
|
# Assign jira_story_points = None
|
||||||
jira_story_points: Optional[str] = None
|
jira_story_points: Optional[str] = None
|
||||||
|
# Assign last_synced_at = None
|
||||||
last_synced_at: Optional[datetime] = None
|
last_synced_at: Optional[datetime] = None
|
||||||
|
# created_at: datetime
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
# Define class Config
|
||||||
class Config:
|
class Config:
|
||||||
|
"""ORM mode configuration for SQLAlchemy model mapping."""
|
||||||
|
|
||||||
|
# Assign from_attributes = True
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
# Define class JiraIssueSearch
|
||||||
class JiraIssueSearch(BaseModel):
|
class JiraIssueSearch(BaseModel):
|
||||||
|
"""Payload for searching Jira issues by free-text query."""
|
||||||
|
|
||||||
|
# query: str
|
||||||
query: str
|
query: str
|
||||||
|
|
||||||
|
|
||||||
|
# Define class JiraIssueResult
|
||||||
class JiraIssueResult(BaseModel):
|
class JiraIssueResult(BaseModel):
|
||||||
|
"""Lightweight Jira issue representation returned by search results."""
|
||||||
|
|
||||||
|
# issue_key: str
|
||||||
issue_key: str
|
issue_key: str
|
||||||
|
# summary: str
|
||||||
summary: str
|
summary: str
|
||||||
|
# status: str
|
||||||
status: str
|
status: str
|
||||||
|
# Assign assignee = None
|
||||||
assignee: Optional[str] = None
|
assignee: Optional[str] = None
|
||||||
|
# Assign priority = None
|
||||||
priority: Optional[str] = None
|
priority: Optional[str] = None
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user