diff --git a/backend/app/__init__.py b/backend/app/__init__.py index e69de29..3c405cb 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -0,0 +1 @@ +"""Aegis — MITRE ATT&CK Coverage Platform application package.""" diff --git a/backend/app/auth.py b/backend/app/auth.py index 146cfb0..dfcef53 100644 --- a/backend/app/auth.py +++ b/backend/app/auth.py @@ -1,23 +1,32 @@ -""" -Security utilities: password hashing and JWT token management. +"""Security utilities: password hashing and JWT token management. This module provides pure functions for: - Hashing and verifying passwords using bcrypt via passlib. -- Creating JWT access tokens using python-jose. +- Creating JWT access tokens using PyJWT. - Managing a Redis-backed token blacklist for revocation. No endpoints are defined here. """ +# Import logging import logging + +# Import uuid import uuid as _uuid + +# Import datetime, timedelta, timezone from datetime from datetime import datetime, timedelta, timezone -from jose import jwt +# Import jwt (PyJWT) +import jwt + +# Import CryptContext from passlib.context from passlib.context import CryptContext +# Import settings from app.config from app.config import settings +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -27,13 +36,17 @@ logger = logging.getLogger(__name__) pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +# Define function hash_password def hash_password(password: str) -> str: """Return a bcrypt hash of *password*.""" + # Return pwd_context.hash(password) return pwd_context.hash(password) +# Define function verify_password def verify_password(plain: str, hashed: str) -> bool: """Return ``True`` if *plain* matches the bcrypt *hashed* value.""" + # Return pwd_context.verify(plain, hashed) return pwd_context.verify(plain, hashed) @@ -48,14 +61,21 @@ def create_access_token(data: dict) -> str: - ``jti`` (JWT ID): 
unique identifier that enables token revocation. - ``exp``: expiration timestamp based on ``ACCESS_TOKEN_EXPIRE_MINUTES``. """ + # Assign to_encode = data.copy() to_encode = data.copy() + # Assign expire = datetime.now(timezone.utc) + timedelta( expire = datetime.now(timezone.utc) + timedelta( + # Keyword argument: minutes minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES, ) + # Call to_encode.update() to_encode.update({ + # Literal argument value "exp": expire, + # Literal argument value "jti": str(_uuid.uuid4()), }) + # Return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGOR... return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) @@ -73,6 +93,7 @@ def create_access_token(data: dict) -> str: _BLACKLIST_PREFIX = "blacklist:" +# Define function blacklist_token def blacklist_token(jti: str, exp: float) -> None: """Add *jti* to the Redis blacklist with a TTL derived from *exp*. @@ -80,23 +101,38 @@ def blacklist_token(jti: str, exp: float) -> None: to ``exp - now`` so the key vanishes when the token would have expired naturally. 
""" + # Import get_redis_blacklist from app.infrastructure.redis_client from app.infrastructure.redis_client import get_redis_blacklist + # Assign ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1) ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1) + # Attempt the following; catch errors below try: + # Assign r = get_redis_blacklist() r = get_redis_blacklist() + # Call r.setex() r.setex(f"{_BLACKLIST_PREFIX}{jti}", ttl, "1") + # Handle Exception except Exception: + # Log warning: "Failed to blacklist token %s in Redis", jti, exc_ logger.warning("Failed to blacklist token %s in Redis", jti, exc_info=True) +# Define function is_token_blacklisted def is_token_blacklisted(jti: str) -> bool: """Return ``True`` if *jti* has been revoked (exists in Redis).""" + # Import get_redis_blacklist from app.infrastructure.redis_client from app.infrastructure.redis_client import get_redis_blacklist + # Attempt the following; catch errors below try: + # Assign r = get_redis_blacklist() r = get_redis_blacklist() + # Return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0 return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0 + # Handle Exception except Exception: + # Log warning: "Failed to check blacklist for %s in Redis", jti, logger.warning("Failed to check blacklist for %s in Redis", jti, exc_info=True) + # Return False return False diff --git a/backend/app/config.py b/backend/app/config.py index 6951912..3b4e42b 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,7 +1,21 @@ +"""Application configuration for the Aegis MITRE ATT&CK Coverage Platform. + +Loads settings from environment variables and ``.env`` files via +``pydantic-settings``. Validates critical secrets at import time and raises +``RuntimeError`` (production) or issues a ``UserWarning`` (development) when +unsafe defaults are detected. 
+""" + +# Import os import os + +# Import secrets import secrets + +# Import warnings import warnings +# Import BaseSettings from pydantic_settings from pydantic_settings import BaseSettings # --------------------------------------------------------------------------- @@ -10,7 +24,11 @@ from pydantic_settings import BaseSettings _is_production = os.environ.get("AEGIS_ENV", "").lower() == "production" +# Define class Settings class Settings(BaseSettings): + """Application settings loaded from environment variables and .env file.""" + + # Assign DATABASE_URL = "postgresql://postgres:postgres@postgres:5432/attackdb" DATABASE_URL: str = "postgresql://postgres:postgres@postgres:5432/attackdb" # ── Security ────────────────────────────────────────────────────── @@ -19,6 +37,7 @@ class Settings(BaseSettings): # for local dev). In production it MUST be supplied via env/.env # so tokens survive restarts. SECRET_KEY: str = "" + # Assign ALGORITHM = "HS256" ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = 480 # 8 hours — /auth/refresh extends active sessions @@ -26,6 +45,7 @@ class Settings(BaseSettings): REDIS_URL: str = "redis://redis:6379/0" # Logical DB indices on the same Redis instance (PATH in URL is overridden). REDIS_TOKEN_BLACKLIST_DB: int = 1 + # Assign REDIS_CACHE_DB = 2 REDIS_CACHE_DB: int = 2 # ── CORS ───────────────────────────────────────────────────────── @@ -41,8 +61,11 @@ class Settings(BaseSettings): # the browser can reach MinIO directly. Defaults to MINIO_ENDPOINT. 
MINIO_PUBLIC_ENDPOINT: str = "" MINIO_ACCESS_KEY: str = "minioadmin" + # Assign MINIO_SECRET_KEY = "minioadmin" MINIO_SECRET_KEY: str = "minioadmin" + # Assign MINIO_BUCKET = "evidence" MINIO_BUCKET: str = "evidence" + # Assign MINIO_SECURE = False # True → use HTTPS to connect to MinIO MINIO_SECURE: bool = False # True → use HTTPS to connect to MinIO # ── Re-testing ─────────────────────────────────────────────────── @@ -50,10 +73,15 @@ class Settings(BaseSettings): # ── Jira Integration ──────────────────────────────────────────── JIRA_ENABLED: bool = False + # Assign JIRA_URL = "" JIRA_URL: str = "" + # Assign JIRA_USERNAME = "" JIRA_USERNAME: str = "" + # Assign JIRA_API_TOKEN = "" JIRA_API_TOKEN: str = "" + # Assign JIRA_IS_CLOUD = True JIRA_IS_CLOUD: bool = True + # Assign JIRA_DEFAULT_PROJECT = "" JIRA_DEFAULT_PROJECT: str = "" JIRA_ISSUE_TYPE_TEST: str = "Task" # tests (campaign or standalone) JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic" # campaigns (under Initiative) @@ -63,8 +91,11 @@ class Settings(BaseSettings): # ── Tempo Integration ───────────────────────────────────────────── TEMPO_ENABLED: bool = False + # Assign TEMPO_API_TOKEN = "" TEMPO_API_TOKEN: str = "" + # Assign TEMPO_API_VERSION = 4 TEMPO_API_VERSION: int = 4 + # Assign TEMPO_DEFAULT_WORK_TYPE = "Red Team" TEMPO_DEFAULT_WORK_TYPE: str = "Red Team" # Tempo API base URL — use https://api.eu.tempo.io/4 for EU workspaces. # Can also be set via system_configs key "tempo.base_url" at runtime. 
@@ -72,12 +103,16 @@ class Settings(BaseSettings): # ── OSINT / Intelligence ──────────────────────────────────────── NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s + # Assign STALE_THRESHOLD_DAYS = 365 # days before coverage is considered stale STALE_THRESHOLD_DAYS: int = 365 # days before coverage is considered stale # ── Reporting ───────────────────────────────────────────────────── REPORT_TEMPLATES_DIR: str = "app/templates/reports" + # Assign REPORT_OUTPUT_DIR = "/tmp/aegis_reports" REPORT_OUTPUT_DIR: str = "/tmp/aegis_reports" + # Assign COMPANY_NAME = "Organization" COMPANY_NAME: str = "Organization" + # Assign COMPANY_LOGO_PATH = "app/templates/reports/assets/logo.png" COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png" # ── Email / SMTP ────────────────────────────────────────────────── @@ -92,43 +127,68 @@ class Settings(BaseSettings): # ── Scoring weights (must sum to 100) ──────────────────────────── SCORING_WEIGHT_TESTS: int = 40 + # Assign SCORING_WEIGHT_DETECTION_RULES = 25 SCORING_WEIGHT_DETECTION_RULES: int = 25 + # Assign SCORING_WEIGHT_D3FEND = 15 SCORING_WEIGHT_D3FEND: int = 15 + # Assign SCORING_WEIGHT_RECENCY = 10 SCORING_WEIGHT_RECENCY: int = 10 + # Assign SCORING_WEIGHT_SEVERITY = 10 SCORING_WEIGHT_SEVERITY: int = 10 # Legacy env names (mapped in scoring_config_service) SCORING_WEIGHT_FRESHNESS: int = 10 + # Assign SCORING_WEIGHT_PLATFORM_DIVERSITY = 10 SCORING_WEIGHT_PLATFORM_DIVERSITY: int = 10 + # Define class Config class Config: + """Pydantic BaseSettings configuration — load from .env file.""" + + # Assign env_file = ".env" env_file = ".env" +# Assign settings = Settings() settings = Settings() # --------------------------------------------------------------------------- # Post-init validation for SECRET_KEY # --------------------------------------------------------------------------- _UNSAFE_SECRETS = { + # Literal argument value "", + # Literal argument value "change-me-in-production", + 
# Literal argument value "change-me-in-production-use-a-long-random-string", } +# Check: settings.SECRET_KEY in _UNSAFE_SECRETS if settings.SECRET_KEY in _UNSAFE_SECRETS: + # Check: _is_production if _is_production: + # Raise RuntimeError raise RuntimeError( + # Literal argument value "CRITICAL: SECRET_KEY is not configured. " + # Literal argument value "Set a strong random value (>= 32 chars) via the SECRET_KEY " + # Literal argument value "environment variable or in your .env file before running in " + # Literal argument value "production. Example: openssl rand -hex 32" ) # Development: auto-generate an ephemeral key and warn settings.SECRET_KEY = secrets.token_hex(32) + # Call warnings.warn() warnings.warn( + # Literal argument value "SECRET_KEY was not set — using an auto-generated ephemeral key. " + # Literal argument value "JWT tokens will be invalidated on every restart. " + # Literal argument value "Set SECRET_KEY in your environment for persistent sessions.", + # Keyword argument: stacklevel stacklevel=2, ) @@ -136,12 +196,16 @@ if settings.SECRET_KEY in _UNSAFE_SECRETS: # SEC-002: Reject default credentials in production # --------------------------------------------------------------------------- if _is_production: + # Assign _DEFAULT_CREDS = { _DEFAULT_CREDS = { ("MINIO_ACCESS_KEY", settings.MINIO_ACCESS_KEY, "minioadmin"), ("MINIO_SECRET_KEY", settings.MINIO_SECRET_KEY, "minioadmin"), } + # Iterate over _DEFAULT_CREDS for name, current, default in _DEFAULT_CREDS: + # Check: current == default if current == default: + # Raise RuntimeError raise RuntimeError( f"CRITICAL: {name} is using the default value '{default}'. 
" f"Set a strong value via the {name} environment variable " diff --git a/backend/app/database.py b/backend/app/database.py index 0c6adb0..6d0dd9f 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -1,68 +1,164 @@ -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, declarative_base +"""Database engine and session management for the Aegis platform. +The engine and session factory are created lazily so that tests can override +``DATABASE_URL`` via environment variables before any import triggers real +PostgreSQL engine creation (which requires psycopg2). +""" + +# Import Generator from collections.abc +from collections.abc import Generator + +# Import create_engine from sqlalchemy +from sqlalchemy import create_engine + +# Import Engine from sqlalchemy.engine +from sqlalchemy.engine import Engine + +# Import Session, declarative_base, sessionmaker from sqlalchemy.orm +from sqlalchemy.orm import Session, declarative_base, sessionmaker + +# Assign Base = declarative_base() Base = declarative_base() # Engine and session factory are created lazily so that tests can # override DATABASE_URL via environment *before* any import triggers # the real PostgreSQL engine creation (which requires psycopg2). _engine = None +# Assign _SessionLocal = None _SessionLocal = None -def _get_engine(): +# Define function _get_engine +def _get_engine() -> Engine: + """Return the shared SQLAlchemy engine, creating it on first call. + + Returns: + Engine: Configured SQLAlchemy engine for the application database. 
+ """ + # Declare global variable global _engine + # Check: _engine is None if _engine is None: + # Import settings from app.config from app.config import settings + # Assign url = settings.DATABASE_URL url = settings.DATABASE_URL + # Assign kwargs = {} kwargs: dict = {} + # Check: url.startswith("postgresql") if url.startswith("postgresql"): + # Call kwargs.update() kwargs.update( + # Keyword argument: pool_size pool_size=20, + # Keyword argument: max_overflow max_overflow=10, + # Keyword argument: pool_recycle pool_recycle=3600, + # Keyword argument: pool_pre_ping pool_pre_ping=True, ) + # Assign _engine = create_engine(url, **kwargs) _engine = create_engine(url, **kwargs) + # Return _engine return _engine -def _get_session_factory(): +# Define function _get_session_factory +def _get_session_factory() -> sessionmaker: + """Return the shared sessionmaker, creating it on first call. + + Returns: + sessionmaker: Configured sessionmaker bound to the application engine. + """ + # Declare global variable global _SessionLocal + # Check: _SessionLocal is None if _SessionLocal is None: + # Assign _SessionLocal = sessionmaker( _SessionLocal = sessionmaker( + # Keyword argument: autocommit autocommit=False, autoflush=False, bind=_get_engine() ) + # Return _SessionLocal return _SessionLocal +# Define class _LazySessionLocal class _LazySessionLocal: - """Proxy so ``SessionLocal()`` keeps working as before but the real - sessionmaker is only created on first call.""" + """Proxy so ``SessionLocal()`` keeps working as before but the real sessionmaker is only created on first call.""" - def __call__(self, *args, **kwargs): + # Define function __call__ + def __call__(self, *args: object, **kwargs: object) -> Session: + """Create and return a new database session. + + Args: + *args (object): Positional arguments forwarded to the sessionmaker. + **kwargs (object): Keyword arguments forwarded to the sessionmaker. + + Returns: + Session: A new SQLAlchemy database session. 
+ """ + # Return _get_session_factory()(*args, **kwargs) return _get_session_factory()(*args, **kwargs) - def __getattr__(self, name): + # Define function __getattr__ + def __getattr__(self, name: str) -> object: + """Delegate attribute access to the underlying sessionmaker. + + Args: + name (str): Attribute name to look up on the sessionmaker. + + Returns: + object: The attribute value from the underlying sessionmaker. + """ + # Return getattr(_get_session_factory(), name) return getattr(_get_session_factory(), name) +# Assign SessionLocal = _LazySessionLocal() SessionLocal = _LazySessionLocal() +# Define class _EngineProxy class _EngineProxy: """Thin proxy so ``from app.database import engine`` still works.""" - def __getattr__(self, name): + + # Define function __getattr__ + def __getattr__(self, name: str) -> object: + """Delegate attribute access to the lazily-created engine. + + Args: + name (str): Attribute name to look up on the real engine. + + Returns: + object: The attribute value from the underlying SQLAlchemy engine. + """ + # Return getattr(_get_engine(), name) return getattr(_get_engine(), name) +# Assign engine = _EngineProxy() # type: ignore[assignment] engine = _EngineProxy() # type: ignore[assignment] -def get_db(): +# Define function get_db +def get_db() -> Generator[Session, None, None]: + """Yield a database session and close it when the request is done. + + Intended for use as a FastAPI dependency. + + Yields: + Session: An active SQLAlchemy session for the current request. 
+ """ + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Yield db yield db + # Always execute this cleanup block finally: + # Close the database session db.close() diff --git a/backend/app/dependencies/__init__.py b/backend/app/dependencies/__init__.py index e69de29..b5990de 100644 --- a/backend/app/dependencies/__init__.py +++ b/backend/app/dependencies/__init__.py @@ -0,0 +1 @@ +"""FastAPI dependency injection helpers for auth, DB, and shared state.""" diff --git a/backend/app/dependencies/auth.py b/backend/app/dependencies/auth.py index bc45475..d06ffd7 100644 --- a/backend/app/dependencies/auth.py +++ b/backend/app/dependencies/auth.py @@ -1,5 +1,4 @@ -""" -Authentication and RBAC dependencies for FastAPI. +"""Authentication and RBAC dependencies for FastAPI. Provides: - ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or @@ -9,16 +8,34 @@ Provides: (admins always pass). """ +# Import Callable from collections.abc +from collections.abc import Callable + +# Import Optional from typing from typing import Optional +# Import Cookie, Depends, HTTPException, status from fastapi from fastapi import Cookie, Depends, HTTPException, status + +# Import OAuth2PasswordBearer from fastapi.security from fastapi.security import OAuth2PasswordBearer -from jose import JWTError, jwt + +# Import jwt (PyJWT) +import jwt + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import auth as auth_lib from app from app import auth as auth_lib + +# Import settings from app.config from app.config import settings + +# Import get_db from app.database from app.database import get_db + +# Import User from app.models.user from app.models.user import User from app.models.api_key import KEY_PREFIX @@ -37,8 +54,11 @@ _COOKIE_NAME = "aegis_token" async def get_current_user( + # Entry: aegis_token aegis_token: Optional[str] = Cookie(None), + # Entry: bearer_token bearer_token: Optional[str] = 
Depends(oauth2_scheme), + # Entry: db db: Session = Depends(get_db), ) -> User: """Decode the JWT, look up the user in *db*, and return it. @@ -54,20 +74,30 @@ async def get_current_user( - the ``sub`` claim is missing, or - no matching active user exists in the database. """ + # Assign credentials_exception = HTTPException( credentials_exception = HTTPException( + # Keyword argument: status_code status_code=status.HTTP_401_UNAUTHORIZED, + # Keyword argument: detail detail="Could not validate credentials", + # Keyword argument: headers headers={"WWW-Authenticate": "Bearer"}, ) + # Assign revoked_exception = HTTPException( revoked_exception = HTTPException( + # Keyword argument: status_code status_code=status.HTTP_401_UNAUTHORIZED, + # Keyword argument: detail detail="Token has been revoked", + # Keyword argument: headers headers={"WWW-Authenticate": "Bearer"}, ) # Prefer cookie, fall back to header token = aegis_token or bearer_token + # Check: token is None if token is None: + # Raise credentials_exception raise credentials_exception # ── API Key path (Bearer token starts with "aegis_") ────────────────── @@ -80,25 +110,38 @@ async def get_current_user( # ── JWT path ────────────────────────────────────────────────────────── try: + # Assign payload = jwt.decode( payload = jwt.decode( token, settings.SECRET_KEY, + # Keyword argument: algorithms algorithms=[settings.ALGORITHM], ) + # Assign username = payload.get("sub") username: str | None = payload.get("sub") + # Check: username is None if username is None: + # Raise credentials_exception raise credentials_exception # Check token blacklist (revoked tokens) jti: str | None = payload.get("jti") + # Check: jti and auth_lib.is_token_blacklisted(jti) if jti and auth_lib.is_token_blacklisted(jti): + # Raise revoked_exception raise revoked_exception - except JWTError: + # Handle any JWT validation error (expired, invalid signature, malformed) + except jwt.exceptions.InvalidTokenError: + # Raise credentials_exception 
raise credentials_exception + # Assign user = db.query(User).filter(User.username == username).first() user = db.query(User).filter(User.username == username).first() + # Check: user is None or not user.is_active if user is None or not user.is_active: + # Raise credentials_exception raise credentials_exception + # Return user return user @@ -108,6 +151,7 @@ async def get_current_user( async def require_password_changed( + # Entry: current_user current_user: User = Depends(get_current_user), ) -> User: """Block all requests when the user still needs to change their password. @@ -115,11 +159,16 @@ async def require_password_changed( Only ``/auth/change-password`` and ``/auth/me`` are exempt — those endpoints do **not** depend on this function. """ + # Check: getattr(current_user, "must_change_password", False) if getattr(current_user, "must_change_password", False): + # Raise HTTPException raise HTTPException( + # Keyword argument: status_code status_code=status.HTTP_403_FORBIDDEN, + # Keyword argument: detail detail="PASSWORD_CHANGE_REQUIRED", ) + # Return current_user return current_user @@ -147,22 +196,30 @@ def require_role(required_role: str): Otherwise it raises :class:`~fastapi.HTTPException` **403**. 
""" + # Define async function role_checker async def role_checker( + # Entry: current_user current_user: User = Depends(get_current_user), ) -> User: + # Check: current_user.role != required_role and current_user.role != "admin" if current_user.role != required_role and current_user.role != "admin": + # Raise HTTPException raise HTTPException( + # Keyword argument: status_code status_code=status.HTTP_403_FORBIDDEN, + # Keyword argument: detail detail="Not enough permissions", ) scope = "admin" if required_role == "admin" else "write" _check_api_key_scope(current_user, scope) return current_user + # Return role_checker return role_checker -def require_any_role(*roles: str): +# Define function require_any_role +def require_any_role(*roles: str) -> Callable[..., object]: """Return a FastAPI dependency that enforces **any** of the given *roles*. Admins always pass. Also enforces API key scopes: if the only accepted @@ -174,18 +231,25 @@ def require_any_role(*roles: str): @router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))]) """ + # Define async function role_checker async def role_checker( + # Entry: current_user current_user: User = Depends(get_current_user), ) -> User: + # Check: current_user.role != "admin" and current_user.role not in roles if current_user.role != "admin" and current_user.role not in roles: + # Raise HTTPException raise HTTPException( + # Keyword argument: status_code status_code=status.HTTP_403_FORBIDDEN, + # Keyword argument: detail detail="Not enough permissions", ) scope = "admin" if set(roles) == {"admin"} else "write" _check_api_key_scope(current_user, scope) return current_user + # Return role_checker return role_checker diff --git a/backend/app/dependencies/repositories.py b/backend/app/dependencies/repositories.py index eae5d62..ecc8950 100644 --- a/backend/app/dependencies/repositories.py +++ b/backend/app/dependencies/repositories.py @@ -4,27 +4,41 @@ Wiring lives ONLY in the presentation layer — 
use cases and services never know which concrete repository implementation they receive. """ +# Import Depends from fastapi from fastapi import Depends + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import from app.infrastructure.persistence.repositories.sa_technique_repository from app.infrastructure.persistence.repositories.sa_technique_repository import ( SATechniqueRepository, ) + +# Import from app.infrastructure.persistence.repositories.sa_test_repository from app.infrastructure.persistence.repositories.sa_test_repository import ( SATestRepository, ) +# Define function get_technique_repository def get_technique_repository( + # Entry: db db: Session = Depends(get_db), ) -> SATechniqueRepository: """Provide a TechniqueRepository backed by the current DB session.""" + # Return SATechniqueRepository(db) return SATechniqueRepository(db) +# Define function get_test_repository def get_test_repository( + # Entry: db db: Session = Depends(get_db), ) -> SATestRepository: """Provide a TestRepository backed by the current DB session.""" + # Return SATestRepository(db) return SATestRepository(db) diff --git a/backend/app/domain/__init__.py b/backend/app/domain/__init__.py index e69de29..eb0fb57 100644 --- a/backend/app/domain/__init__.py +++ b/backend/app/domain/__init__.py @@ -0,0 +1 @@ +"""Domain layer — entities, value objects, errors, and repository ports.""" diff --git a/backend/app/domain/entities/__init__.py b/backend/app/domain/entities/__init__.py index 29c3ae2..c9ca2a7 100644 --- a/backend/app/domain/entities/__init__.py +++ b/backend/app/domain/entities/__init__.py @@ -1,18 +1,34 @@ +"""Domain entity classes representing core business objects.""" +# Import CampaignEntity from app.domain.entities.campaign from app.domain.entities.campaign import CampaignEntity + +# Import from app.domain.entities.compliance from app.domain.entities.compliance import ( 
ComplianceControlEntity, ComplianceFrameworkEntity, ControlCoverageStatus, ) + +# Import TechniqueEntity from app.domain.entities.technique from app.domain.entities.technique import TechniqueEntity + +# Import ThreatActorEntity, ThreatActorTechniqueRef from app.domain.entities.threat_actor from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef +# Assign __all__ = [ __all__ = [ + # Literal argument value "CampaignEntity", + # Literal argument value "ComplianceControlEntity", + # Literal argument value "ComplianceFrameworkEntity", + # Literal argument value "ControlCoverageStatus", + # Literal argument value "TechniqueEntity", + # Literal argument value "ThreatActorEntity", + # Literal argument value "ThreatActorTechniqueRef", ] diff --git a/backend/app/domain/entities/campaign.py b/backend/app/domain/entities/campaign.py index 02c1487..53e25b6 100644 --- a/backend/app/domain/entities/campaign.py +++ b/backend/app/domain/entities/campaign.py @@ -3,30 +3,59 @@ Pure domain logic — no framework imports. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import enum import enum -import uuid -from dataclasses import dataclass, field -from typing import Any +# Import uuid +import uuid + +# Import dataclass, field from dataclasses +from dataclasses import dataclass, field + +# Import TYPE_CHECKING from typing +from typing import TYPE_CHECKING + +# Import BusinessRuleViolation, InvalidStateTransition from app.domain.errors from app.domain.errors import BusinessRuleViolation, InvalidStateTransition +# Check: TYPE_CHECKING +if TYPE_CHECKING: + # Import Campaign as CampaignORM from app.models.campaign + from app.models.campaign import Campaign as CampaignORM + +# Define class CampaignStatus class CampaignStatus(str, enum.Enum): + """Lifecycle states for a campaign.""" + + # Assign draft = "draft" draft = "draft" + # Assign active = "active" active = "active" + # Assign completed = "completed" completed = "completed" + # Assign archived = "archived" archived = "archived" +# Define class CampaignType class CampaignType(str, enum.Enum): + """Classification of the campaign's testing methodology.""" + + # Assign custom = "custom" custom = "custom" + # Assign apt_emulation = "apt_emulation" apt_emulation = "apt_emulation" + # Assign kill_chain = "kill_chain" kill_chain = "kill_chain" + # Assign compliance = "compliance" compliance = "compliance" +# Assign VALID_TRANSITIONS = { VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = { CampaignStatus.draft: [CampaignStatus.active], CampaignStatus.active: [CampaignStatus.completed], @@ -35,69 +64,156 @@ VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = { } +# Apply the @dataclass decorator @dataclass +# Define class CampaignEntity class CampaignEntity: + """Pure domain representation of a security testing campaign. + + Owns all lifecycle state-machine logic for campaign activation, + completion, and archival. 
+ """ + + # name: str name: str + # Assign type = CampaignType.custom type: CampaignType = CampaignType.custom + # Assign status = CampaignStatus.draft status: CampaignStatus = CampaignStatus.draft + # Assign id = None id: uuid.UUID | None = None + # Assign description = None description: str | None = None + # Assign threat_actor_id = None threat_actor_id: uuid.UUID | None = None + # Assign created_by = None created_by: uuid.UUID | None = None + # Assign target_platform = None target_platform: str | None = None + # Assign tags = field(default_factory=list) tags: list[str] = field(default_factory=list) + # Assign test_count = 0 test_count: int = 0 + # Define function can_transition_to def can_transition_to(self, target: CampaignStatus) -> bool: + """Check whether transitioning from the current status to *target* is valid. + + Args: + target (CampaignStatus): The desired next status. + + Returns: + bool: True if the transition is allowed, False otherwise. + """ + # Return target in VALID_TRANSITIONS.get(self.status, []) return target in VALID_TRANSITIONS.get(self.status, []) + # Define function activate def activate(self) -> None: + """Transition the campaign from ``draft`` to ``active``. + + Returns: + None + """ + # Check: not self.can_transition_to(CampaignStatus.active) if not self.can_transition_to(CampaignStatus.active): + # Raise InvalidStateTransition raise InvalidStateTransition( self.status.value, CampaignStatus.active.value, [s.value for s in VALID_TRANSITIONS[self.status]], ) + # Check: self.test_count == 0 if self.test_count == 0: + # Raise BusinessRuleViolation raise BusinessRuleViolation( + # Literal argument value "Campaign must have at least one test to activate" ) + # Assign self.status = CampaignStatus.active self.status = CampaignStatus.active + # Define function complete def complete(self) -> None: + """Transition the campaign from ``active`` to ``completed``. 
+ + Returns: + None + """ + # Check: not self.can_transition_to(CampaignStatus.completed) if not self.can_transition_to(CampaignStatus.completed): + # Raise InvalidStateTransition raise InvalidStateTransition( self.status.value, CampaignStatus.completed.value, [s.value for s in VALID_TRANSITIONS[self.status]], ) + # Assign self.status = CampaignStatus.completed self.status = CampaignStatus.completed + # Define function archive def archive(self) -> None: + """Transition the campaign from ``completed`` to ``archived``. + + Returns: + None + """ + # Check: not self.can_transition_to(CampaignStatus.archived) if not self.can_transition_to(CampaignStatus.archived): + # Raise InvalidStateTransition raise InvalidStateTransition( self.status.value, CampaignStatus.archived.value, [s.value for s in VALID_TRANSITIONS[self.status]], ) + # Assign self.status = CampaignStatus.archived self.status = CampaignStatus.archived + # Define function ensure_modifiable def ensure_modifiable(self) -> None: + """Raise BusinessRuleViolation if the campaign is not in a modifiable state. + + Returns: + None + """ + # Check: self.status not in (CampaignStatus.draft, CampaignStatus.active) if self.status not in (CampaignStatus.draft, CampaignStatus.active): + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot modify campaign in '{self.status.value}' state" ) + # Apply the @classmethod decorator @classmethod - def from_orm(cls, orm: Any) -> CampaignEntity: - """Build a CampaignEntity from a SQLAlchemy Campaign model.""" + # Define function from_orm + def from_orm(cls, orm: CampaignORM) -> CampaignEntity: + """Build a CampaignEntity from a SQLAlchemy Campaign model. + + Args: + orm (CampaignORM): The SQLAlchemy Campaign ORM model instance. + + Returns: + CampaignEntity: A fully populated domain entity reflecting the ORM state. 
+ """ + # Assign test_count = len(getattr(orm, "campaign_tests", None) or []) test_count = len(getattr(orm, "campaign_tests", None) or []) + # Return cls( return cls( + # Keyword argument: id id=orm.id, + # Keyword argument: name name=orm.name, + # Keyword argument: type type=CampaignType(orm.type) if orm.type else CampaignType.custom, + # Keyword argument: status status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft, + # Keyword argument: description description=orm.description, + # Keyword argument: threat_actor_id threat_actor_id=orm.threat_actor_id, + # Keyword argument: created_by created_by=orm.created_by, + # Keyword argument: target_platform target_platform=orm.target_platform, + # Keyword argument: tags tags=orm.tags or [], + # Keyword argument: test_count test_count=test_count, ) diff --git a/backend/app/domain/entities/compliance.py b/backend/app/domain/entities/compliance.py index 549eb6b..30fe36c 100644 --- a/backend/app/domain/entities/compliance.py +++ b/backend/app/domain/entities/compliance.py @@ -3,68 +3,161 @@ Pure domain logic — no framework imports. """ +# Enable future language features for compatibility from __future__ import annotations +# Import enum import enum + +# Import uuid import uuid + +# Import dataclass, field from dataclasses from dataclasses import dataclass, field +# Define class ControlCoverageStatus class ControlCoverageStatus(str, enum.Enum): + """Computed coverage level for a single compliance control.""" + + # Assign covered = "covered" covered = "covered" + # Assign partially_covered = "partially_covered" partially_covered = "partially_covered" + # Assign not_covered = "not_covered" not_covered = "not_covered" +# Apply the @dataclass decorator @dataclass +# Define class ComplianceControlEntity class ComplianceControlEntity: + """Pure domain representation of a single compliance framework control. 
+ + Derives its coverage status from the technique statuses associated + with it via the ``technique_statuses`` list. + """ + + # control_id: str control_id: str + # title: str title: str + # Assign id = None id: uuid.UUID | None = None + # Assign description = None description: str | None = None + # Assign category = None category: str | None = None + # Assign technique_statuses = field(default_factory=list) technique_statuses: list[str] = field(default_factory=list) + # Apply the @property decorator @property + # Define function coverage_status def coverage_status(self) -> ControlCoverageStatus: + """Compute the coverage status for this control based on linked technique statuses. + + Returns: + ControlCoverageStatus: ``covered`` when all techniques are covered, + ``partially_covered`` when at least one is covered, and + ``not_covered`` when none are covered or the control has no techniques. + """ + # Check: not self.technique_statuses if not self.technique_statuses: + # Return ControlCoverageStatus.not_covered return ControlCoverageStatus.not_covered + # Assign covered_statuses = {"validated", "partial"} covered_statuses = {"validated", "partial"} + # Assign covered = [s for s in self.technique_statuses if s in covered_statuses] covered = [s for s in self.technique_statuses if s in covered_statuses] + # Check: len(covered) == len(self.technique_statuses) if len(covered) == len(self.technique_statuses): + # Return ControlCoverageStatus.covered return ControlCoverageStatus.covered + # Alternative: len(covered) > 0 elif len(covered) > 0: + # Return ControlCoverageStatus.partially_covered return ControlCoverageStatus.partially_covered + # Return ControlCoverageStatus.not_covered return ControlCoverageStatus.not_covered +# Apply the @dataclass decorator @dataclass +# Define class ComplianceFrameworkEntity class ComplianceFrameworkEntity: + """Pure domain representation of a compliance framework (e.g. NIST 800-53, PCI-DSS). 
+ + Aggregates a collection of controls and provides aggregate coverage statistics. + """ + + # name: str name: str + # Assign id = None id: uuid.UUID | None = None + # Assign version = None version: str | None = None + # Assign description = None description: str | None = None + # Assign is_active = True is_active: bool = True + # Assign controls = field(default_factory=list) controls: list[ComplianceControlEntity] = field(default_factory=list) + # Apply the @property decorator @property + # Define function total_controls def total_controls(self) -> int: + """Return the total number of controls in this framework. + + Returns: + int: Count of all controls regardless of coverage status. + """ + # Return len(self.controls) return len(self.controls) + # Apply the @property decorator @property + # Define function covered_controls def covered_controls(self) -> int: + """Return the number of fully covered controls in this framework. + + Returns: + int: Count of controls with ``ControlCoverageStatus.covered`` status. + """ + # Return sum( return sum( + # Literal argument value 1 for c in self.controls if c.coverage_status == ControlCoverageStatus.covered ) + # Apply the @property decorator @property + # Define function coverage_pct def coverage_pct(self) -> float: + """Return the percentage of controls that are fully covered. + + Returns: + float: A value from 0.0 to 100.0, rounded to one decimal place. + Returns 0.0 when the framework has no controls. + """ + # Check: self.total_controls == 0 if self.total_controls == 0: + # Return 0.0 return 0.0 + # Return round(self.covered_controls / self.total_controls * 100, 1) return round(self.covered_controls / self.total_controls * 100, 1) + # Define function get_gap_controls def get_gap_controls(self) -> list[ComplianceControlEntity]: + """Return controls that are not fully covered. + + Returns: + list[ComplianceControlEntity]: Controls with ``partially_covered`` or + ``not_covered`` status. 
+ """ + # Return [ return [ c for c in self.controls if c.coverage_status != ControlCoverageStatus.covered diff --git a/backend/app/domain/entities/technique.py b/backend/app/domain/entities/technique.py index b30132f..1a5b6ef 100644 --- a/backend/app/domain/entities/technique.py +++ b/backend/app/domain/entities/technique.py @@ -12,105 +12,211 @@ Usage:: entity.apply_to(technique_orm_model) """ +# Enable future language features for compatibility from __future__ import annotations +# Import uuid import uuid -from dataclasses import dataclass, field -from datetime import datetime -from typing import Any +# Import dataclass, field from dataclasses +from dataclasses import dataclass, field + +# Import datetime from datetime +from datetime import datetime + +# Import TYPE_CHECKING from typing +from typing import TYPE_CHECKING + +# Import TechniqueStatus, TestResult, TestState from app.domain.enums from app.domain.enums import TechniqueStatus, TestResult, TestState + +# Import MitreId from app.domain.value_objects.mitre_id from app.domain.value_objects.mitre_id import MitreId +# Check: TYPE_CHECKING +if TYPE_CHECKING: + # Import Technique as TechniqueORM from app.models.technique + from app.models.technique import Technique as TechniqueORM + +# Apply the @dataclass decorator @dataclass(frozen=True) +# Define class _TestSnapshot class _TestSnapshot: """Minimal read-only view of a test for status calculation.""" + # state: TestState state: TestState + # detection_result: str | None detection_result: str | None +# Apply the @dataclass decorator @dataclass +# Define class TechniqueEntity class TechniqueEntity: """Pure domain representation of a MITRE ATT&CK technique.""" + # id: uuid.UUID id: uuid.UUID + # mitre_id: str mitre_id: str + # name: str name: str + # Assign tactic = None tactic: str | None = None + # Assign description = None description: str | None = None + # Assign platforms = field(default_factory=list) platforms: list[str] = field(default_factory=list) + # 
Assign is_subtechnique = False is_subtechnique: bool = False + # Assign parent_mitre_id = None parent_mitre_id: str | None = None + # Assign status_global = TechniqueStatus.not_evaluated status_global: TechniqueStatus = TechniqueStatus.not_evaluated + # Assign review_required = False review_required: bool = False + # Assign last_review_date = None last_review_date: datetime | None = None + # Assign mitre_version = None mitre_version: str | None = None + # Assign mitre_last_modified = None mitre_last_modified: datetime | None = None # -- Factory ----------------------------------------------------------- @classmethod + # Define function create def create( cls, *, + # Entry: mitre_id mitre_id: str, + # Entry: name name: str, + # Entry: tactic tactic: str | None = None, + # Entry: description description: str | None = None, + # Entry: platforms platforms: list[str] | None = None, ) -> TechniqueEntity: - """Create a new technique, validating the MITRE ID format.""" + """Create a new technique, validating the MITRE ID format. + + Args: + mitre_id (str): MITRE ATT&CK identifier (e.g. ``"T1059"`` or ``"T1059.001"``). + name (str): Human-readable name of the technique. + tactic (str | None): MITRE tactic category the technique belongs to. + description (str | None): Optional free-text description. + platforms (list[str] | None): List of platform strings the technique applies to. + + Returns: + TechniqueEntity: A new entity with a freshly generated UUID and + ``status_global`` set to ``not_evaluated``. 
+ """ + # Assign validated_id = MitreId(mitre_id) validated_id = MitreId(mitre_id) + # Return cls( return cls( + # Keyword argument: id id=uuid.uuid4(), + # Keyword argument: mitre_id mitre_id=validated_id.value, + # Keyword argument: name name=name, + # Keyword argument: tactic tactic=tactic, + # Keyword argument: description description=description, + # Keyword argument: platforms platforms=platforms or [], + # Keyword argument: is_subtechnique is_subtechnique=validated_id.is_subtechnique, + # Keyword argument: parent_mitre_id parent_mitre_id=validated_id.parent_id, + # Keyword argument: status_global status_global=TechniqueStatus.not_evaluated, ) + # Apply the @classmethod decorator @classmethod - def from_orm(cls, model: Any) -> TechniqueEntity: - """Build a TechniqueEntity from a SQLAlchemy Technique model.""" + # Define function from_orm + def from_orm(cls, model: TechniqueORM) -> TechniqueEntity: + """Build a TechniqueEntity from a SQLAlchemy Technique model. + + Args: + model (TechniqueORM): The ORM model instance to convert. + + Returns: + TechniqueEntity: A fully populated domain entity reflecting the ORM state. 
+ """ + # Assign raw_status = model.status_global raw_status = model.status_global + # Check: raw_status is None if raw_status is None: + # Assign status = TechniqueStatus.not_evaluated status = TechniqueStatus.not_evaluated + # Alternative: isinstance(raw_status, TechniqueStatus) elif isinstance(raw_status, TechniqueStatus): + # Assign status = raw_status status = raw_status + # Fallback: handle remaining cases else: + # Assign status = TechniqueStatus(raw_status) status = TechniqueStatus(raw_status) + # Return cls( return cls( + # Keyword argument: id id=model.id, + # Keyword argument: mitre_id mitre_id=model.mitre_id, + # Keyword argument: name name=model.name, + # Keyword argument: tactic tactic=model.tactic, + # Keyword argument: description description=model.description, + # Keyword argument: platforms platforms=model.platforms or [], + # Keyword argument: is_subtechnique is_subtechnique=model.is_subtechnique or False, + # Keyword argument: parent_mitre_id parent_mitre_id=model.parent_mitre_id, + # Keyword argument: status_global status_global=status, + # Keyword argument: review_required review_required=model.review_required or False, + # Keyword argument: last_review_date last_review_date=model.last_review_date, + # Keyword argument: mitre_version mitre_version=getattr(model, "mitre_version", None), + # Keyword argument: mitre_last_modified mitre_last_modified=getattr(model, "mitre_last_modified", None), ) - def apply_to(self, model: Any) -> None: - """Copy mutable fields back onto the ORM model.""" + # Define function apply_to + def apply_to(self, model: TechniqueORM) -> None: + """Copy mutable fields back onto the ORM model. + + Args: + model (TechniqueORM): The ORM model to update in-place. 
+ + Returns: + None + """ + # Assign model.status_global = self.status_global model.status_global = self.status_global + # Assign model.review_required = self.review_required model.review_required = self.review_required + # Assign model.last_review_date = self.last_review_date model.last_review_date = self.last_review_date # -- Business logic ---------------------------------------------------- def recalculate_status( self, + # Entry: test_snapshots test_snapshots: list[tuple[str, str | None]], ) -> TechniqueStatus: """Recompute ``status_global`` from a list of (state, detection_result) pairs. @@ -131,23 +237,37 @@ class TechniqueEntity: With only 1 validated+detected test the technique is "partial" to signal that more testing is recommended. - Returns the new status (also set on the entity). + Args: + test_snapshots (list[tuple[str, str | None]]): Each element is a + ``(state, detection_result)`` pair where *state* is a + :class:`TestState` value string and *detection_result* is a + :class:`TestResult` value string or ``None``. + + Returns: + TechniqueStatus: The newly computed status, which is also stored on + the entity's ``status_global`` field. """ _MIN_VALIDATED_FOR_FULL = 2 # require ≥ N validated tests for "validated" tests = [ _TestSnapshot( + # Keyword argument: state state=s if isinstance(s, TestState) else TestState(s), + # Keyword argument: detection_result detection_result=dr, ) for s, dr in test_snapshots ] + # Check: not tests if not tests: + # Assign self.status_global = TechniqueStatus.not_evaluated self.status_global = TechniqueStatus.not_evaluated + # Alternative: all(t.state == TestState.validated for t in tests) elif all(t.state == TestState.validated for t in tests): validated_count = len(tests) results = [t.detection_result for t in tests if t.detection_result] + # Check: results and all(r == TestResult.detected or r == "detected" for r i... 
if results and all(r == TestResult.detected or r == "detected" for r in results): # Need at least _MIN_VALIDATED_FOR_FULL tests for "validated" if validated_count >= _MIN_VALIDATED_FOR_FULL: @@ -155,24 +275,46 @@ class TechniqueEntity: else: self.status_global = TechniqueStatus.partial elif any( + # Keyword argument: r r == TestResult.partially_detected or r == "partially_detected" for r in results ): + # Assign self.status_global = TechniqueStatus.partial self.status_global = TechniqueStatus.partial + # Fallback: handle remaining cases else: + # Assign self.status_global = TechniqueStatus.not_covered self.status_global = TechniqueStatus.not_covered + # Alternative: any(t.state == TestState.validated for t in tests) elif any(t.state == TestState.validated for t in tests): + # Assign self.status_global = TechniqueStatus.partial self.status_global = TechniqueStatus.partial + # Fallback: handle remaining cases else: + # Assign self.status_global = TechniqueStatus.in_progress self.status_global = TechniqueStatus.in_progress + # Return self.status_global return self.status_global + # Define function mark_reviewed def mark_reviewed(self) -> None: - """Mark the technique as reviewed, clearing the review flag.""" + """Mark the technique as reviewed, clearing the review flag. + + Returns: + None + """ + # Assign self.review_required = False self.review_required = False + # Assign self.last_review_date = datetime.utcnow() self.last_review_date = datetime.utcnow() + # Define function flag_for_review def flag_for_review(self) -> None: - """Flag the technique as needing review.""" + """Flag the technique as needing review. 
+ + Returns: + None + """ + # Assign self.review_required = True self.review_required = True diff --git a/backend/app/domain/entities/threat_actor.py b/backend/app/domain/entities/threat_actor.py index d477014..dc68c09 100644 --- a/backend/app/domain/entities/threat_actor.py +++ b/backend/app/domain/entities/threat_actor.py @@ -3,94 +3,204 @@ Pure domain logic — no framework imports. """ +# Enable future language features for compatibility from __future__ import annotations +# Import uuid import uuid + +# Import dataclass, field from dataclasses from dataclasses import dataclass, field -from typing import Any + +# Import TYPE_CHECKING from typing +from typing import TYPE_CHECKING + +# Check: TYPE_CHECKING +if TYPE_CHECKING: + # Import ThreatActor as ThreatActorORM from app.models.threat_actor + from app.models.threat_actor import ThreatActor as ThreatActorORM +# Apply the @dataclass decorator @dataclass +# Define class ThreatActorTechniqueRef class ThreatActorTechniqueRef: """Lightweight reference to a technique used by an actor.""" + # technique_id: uuid.UUID technique_id: uuid.UUID + # Assign mitre_id = None mitre_id: str | None = None + # Assign name = None name: str | None = None + # Assign status = None status: str | None = None + # Assign usage_description = None usage_description: str | None = None +# Apply the @dataclass decorator @dataclass +# Define class ThreatActorEntity class ThreatActorEntity: + """Pure domain representation of a MITRE ATT&CK threat actor (group). + + Aggregates references to the techniques the actor is known to use and + provides coverage analysis properties. 
+ """ + + # name: str name: str + # Assign id = None id: uuid.UUID | None = None + # Assign mitre_id = None mitre_id: str | None = None + # Assign aliases = field(default_factory=list) aliases: list[str] = field(default_factory=list) + # Assign description = None description: str | None = None + # Assign country = None country: str | None = None + # Assign target_sectors = field(default_factory=list) target_sectors: list[str] = field(default_factory=list) + # Assign target_regions = field(default_factory=list) target_regions: list[str] = field(default_factory=list) + # Assign motivation = None motivation: str | None = None + # Assign sophistication = None sophistication: str | None = None + # Assign first_seen = None first_seen: str | None = None + # Assign last_seen = None last_seen: str | None = None + # Assign is_active = True is_active: bool = True + # Assign techniques = field(default_factory=list) techniques: list[ThreatActorTechniqueRef] = field(default_factory=list) + # Apply the @property decorator @property + # Define function technique_count def technique_count(self) -> int: + """Return the total number of techniques associated with this actor. + + Returns: + int: Count of technique references. + """ + # Return len(self.techniques) return len(self.techniques) + # Apply the @property decorator @property + # Define function covered_techniques def covered_techniques(self) -> list[ThreatActorTechniqueRef]: + """Return technique references whose coverage status is ``validated`` or ``partial``. + + Returns: + list[ThreatActorTechniqueRef]: Subset of techniques considered covered. + """ + # Return [ return [ t for t in self.techniques if t.status in ("validated", "partial") ] + # Apply the @property decorator @property + # Define function uncovered_techniques def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]: + """Return technique references whose coverage status is neither ``validated`` nor ``partial``. 
+ + Returns: + list[ThreatActorTechniqueRef]: Subset of techniques not yet covered. + """ + # Return [ return [ t for t in self.techniques if t.status not in ("validated", "partial") ] + # Apply the @property decorator @property + # Define function coverage_pct def coverage_pct(self) -> float: + """Return the percentage of the actor's techniques that are covered. + + Returns: + float: A value from 0.0 to 100.0, rounded to one decimal place. + Returns 0.0 when the actor has no associated techniques. + """ + # Check: not self.techniques if not self.techniques: + # Return 0.0 return 0.0 + # Return round(len(self.covered_techniques) / len(self.techniques) * 100, 1) return round(len(self.covered_techniques) / len(self.techniques) * 100, 1) + # Apply the @classmethod decorator @classmethod - def from_orm(cls, orm: Any) -> ThreatActorEntity: + # Define function from_orm + def from_orm(cls, orm: ThreatActorORM) -> ThreatActorEntity: + """Build a ThreatActorEntity from a SQLAlchemy ThreatActor model. + + Args: + orm (ThreatActorORM): The ORM model instance to convert. + + Returns: + ThreatActorEntity: A fully populated domain entity including + technique references resolved from the ORM relationship. 
+ """ + # Assign techs = [] techs: list[ThreatActorTechniqueRef] = [] + # Iterate over getattr(orm, "techniques", None) or [] for tat in getattr(orm, "techniques", None) or []: + # Assign technique = getattr(tat, "technique", None) technique = getattr(tat, "technique", None) + # Call techs.append() techs.append(ThreatActorTechniqueRef( + # Keyword argument: technique_id technique_id=tat.technique_id, + # Keyword argument: mitre_id mitre_id=getattr(technique, "mitre_id", None) if technique else None, + # Keyword argument: name name=getattr(technique, "name", None) if technique else None, + # Keyword argument: status status=( technique.status_global.value if technique and hasattr(technique.status_global, "value") else getattr(technique, "status_global", None) if technique else None ), + # Keyword argument: usage_description usage_description=tat.usage_description, )) + # Return cls( return cls( + # Keyword argument: id id=orm.id, + # Keyword argument: name name=orm.name, + # Keyword argument: mitre_id mitre_id=orm.mitre_id, + # Keyword argument: aliases aliases=orm.aliases or [], + # Keyword argument: description description=orm.description, + # Keyword argument: country country=orm.country, + # Keyword argument: target_sectors target_sectors=orm.target_sectors or [], + # Keyword argument: target_regions target_regions=orm.target_regions or [], + # Keyword argument: motivation motivation=orm.motivation, + # Keyword argument: sophistication sophistication=orm.sophistication, + # Keyword argument: first_seen first_seen=orm.first_seen, + # Keyword argument: last_seen last_seen=orm.last_seen, + # Keyword argument: is_active is_active=orm.is_active if orm.is_active is not None else True, + # Keyword argument: techniques techniques=techs, ) diff --git a/backend/app/domain/enums.py b/backend/app/domain/enums.py index 820f6be..6161db8 100644 --- a/backend/app/domain/enums.py +++ b/backend/app/domain/enums.py @@ -5,41 +5,78 @@ truth. 
``models/enums.py`` re-exports them so that existing ORM code continues to work without changes. """ +# Import enum import enum +# Define class TechniqueStatus class TechniqueStatus(str, enum.Enum): + """Coverage and evaluation status for a MITRE ATT&CK technique.""" + + # Assign not_evaluated = "not_evaluated" not_evaluated = "not_evaluated" + # Assign in_progress = "in_progress" in_progress = "in_progress" + # Assign validated = "validated" validated = "validated" + # Assign partial = "partial" partial = "partial" + # Assign not_covered = "not_covered" not_covered = "not_covered" + # Assign review_required = "review_required" review_required = "review_required" +# Define class TestState class TestState(str, enum.Enum): + """Lifecycle states in the security test state machine.""" + + # Assign draft = "draft" draft = "draft" + # Assign red_executing = "red_executing" red_executing = "red_executing" + # Assign blue_evaluating = "blue_evaluating" blue_evaluating = "blue_evaluating" + # Assign in_review = "in_review" in_review = "in_review" + # Assign validated = "validated" validated = "validated" + # Assign rejected = "rejected" rejected = "rejected" disputed = "disputed" # one lead approved, the other rejected +# Define class TeamSide class TeamSide(str, enum.Enum): + """Identifies which team (red or blue) an action belongs to.""" + + # Assign red = "red" red = "red" + # Assign blue = "blue" blue = "blue" +# Define class TestResult class TestResult(str, enum.Enum): + """Outcome of a red-team test from a detection perspective.""" + + # Assign detected = "detected" detected = "detected" + # Assign not_detected = "not_detected" not_detected = "not_detected" + # Assign partially_detected = "partially_detected" partially_detected = "partially_detected" +# Define class DataClassification class DataClassification(str, enum.Enum): + """Data sensitivity classification levels for compliance and retention policies.""" + + # Assign public = "public" public = "public" + # 
Assign internal = "internal" internal = "internal" + # Assign sensitive = "sensitive" sensitive = "sensitive" + # Assign restricted = "restricted" restricted = "restricted" diff --git a/backend/app/domain/errors.py b/backend/app/domain/errors.py index e23f0d5..19502ff 100644 --- a/backend/app/domain/errors.py +++ b/backend/app/domain/errors.py @@ -9,15 +9,30 @@ Existing code that imports from ``app.domain.exceptions`` continues to work — that module re-exports everything defined here. """ +# Enable future language features for compatibility from __future__ import annotations +# Define class DomainError class DomainError(Exception): """Base for all domain errors.""" + # Define function __init__ def __init__(self, message: str, *, code: str = "DOMAIN_ERROR") -> None: + """Initialise the domain error with a human-readable message and error code. + + Args: + message (str): Human-readable description of the error. + code (str): Machine-readable error code used by the HTTP error handler. + + Returns: + None + """ + # Assign self.message = message self.message = message + # Assign self.code = code self.code = code + # Call super() super().__init__(message) @@ -27,18 +42,45 @@ class DomainError(Exception): class EntityNotFoundError(DomainError): """A requested entity does not exist.""" + # Define function __init__ def __init__(self, entity: str, identifier: str) -> None: + """Initialise an entity-not-found error. + + Args: + entity (str): Name of the entity type that was not found (e.g. "Technique"). + identifier (str): The ID or key used in the failed lookup. 
+ + Returns: + None + """ + # Call super() super().__init__(f"{entity} not found: {identifier}", code="NOT_FOUND") + # Assign self.entity = entity self.entity = entity + # Assign self.identifier = identifier self.identifier = identifier +# Define class DuplicateEntityError class DuplicateEntityError(DomainError): """Creating an entity that already exists.""" + # Define function __init__ def __init__(self, entity: str, field: str, value: str) -> None: + """Initialise a duplicate-entity error. + + Args: + entity (str): Name of the entity type that already exists (e.g. "Campaign"). + field (str): Name of the field whose value conflicts (e.g. "name"). + value (str): The conflicting value that is already in use. + + Returns: + None + """ + # Call super() super().__init__( f"{entity} with {field}='{value}' already exists", + # Keyword argument: code code="DUPLICATE", ) @@ -46,34 +88,67 @@ class DuplicateEntityError(DomainError): # ── State machine ──────────────────────────────────────────────────── -class InvalidStateTransition(DomainError): +class InvalidStateTransition(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites """A state-machine transition is not allowed.""" + # Define function __init__ def __init__( self, + # Entry: current_state current_state: str, + # Entry: target_state target_state: str, + # Entry: valid_transitions valid_transitions: list[str] | None = None, ) -> None: + """Initialise an invalid state-transition error. + + Args: + current_state (str): The entity's present state (e.g. "draft"). + target_state (str): The state that was illegally requested. + valid_transitions (list[str] | None): Allowed target states from the + current state; included in the error message when provided. + + Returns: + None + """ + # Assign msg = f"Cannot transition from '{current_state}' to '{target_state}'" msg = f"Cannot transition from '{current_state}' to '{target_state}'" + # Check: valid_transitions if valid_transitions: + # Assign msg = f". 
Valid transitions: {valid_transitions}" msg += f". Valid transitions: {valid_transitions}" + # Call super() super().__init__(msg, code="INVALID_TRANSITION") + # Assign self.current_state = current_state self.current_state = current_state + # Assign self.target_state = target_state self.target_state = target_state + # Assign self.valid_transitions = valid_transitions or [] self.valid_transitions = valid_transitions or [] # ── Business rules ──────────────────────────────────────────────────── -class BusinessRuleViolation(DomainError): +class BusinessRuleViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites """An operation violates a business invariant.""" + # Define function __init__ def __init__(self, message: str) -> None: + """Initialise a business-rule violation error. + + Args: + message (str): Human-readable description of the violated rule. + + Returns: + None + """ + # Call super() super().__init__(message, code="BUSINESS_RULE_VIOLATION") +# Define class InvalidOperationError class InvalidOperationError(BusinessRuleViolation): """An operation is invalid in the current context. @@ -81,16 +156,37 @@ class InvalidOperationError(BusinessRuleViolation): :class:`BusinessRuleViolation` directly. """ + # Define function __init__ def __init__(self, message: str) -> None: + """Initialise an invalid-operation error. + + Args: + message (str): Human-readable description of why the operation is invalid. 
+ + Returns: + None + """ + # Call super() super().__init__(message) + # Assign self.code = "INVALID_OPERATION" self.code = "INVALID_OPERATION" # ── Authorization ──────────────────────────────────────────────────── -class PermissionViolation(DomainError): +class PermissionViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites """The user lacks permissions for an action.""" + # Define function __init__ def __init__(self, message: str = "Insufficient permissions") -> None: + """Initialise a permission-violation error. + + Args: + message (str): Human-readable description of the access denial. + + Returns: + None + """ + # Call super() super().__init__(message, code="FORBIDDEN") diff --git a/backend/app/domain/exceptions.py b/backend/app/domain/exceptions.py index 7564750..09c44e2 100644 --- a/backend/app/domain/exceptions.py +++ b/backend/app/domain/exceptions.py @@ -6,6 +6,7 @@ old import paths so that existing code keeps working without changes:: from app.domain.exceptions import InvalidTransitionError # still works """ +# Import # noqa: F401 from app.domain.errors from app.domain.errors import ( # noqa: F401 BusinessRuleViolation, DomainError, @@ -18,5 +19,7 @@ from app.domain.errors import ( # noqa: F401 # Legacy aliases — old name → new name DomainException = DomainError +# Assign InvalidTransitionError = InvalidStateTransition InvalidTransitionError = InvalidStateTransition +# Assign AuthorizationError = PermissionViolation AuthorizationError = PermissionViolation diff --git a/backend/app/domain/ports/__init__.py b/backend/app/domain/ports/__init__.py index e69de29..f50c579 100644 --- a/backend/app/domain/ports/__init__.py +++ b/backend/app/domain/ports/__init__.py @@ -0,0 +1 @@ +"""Abstract port interfaces that infrastructure adapters must implement.""" diff --git a/backend/app/domain/ports/import_service.py b/backend/app/domain/ports/import_service.py index be7b56d..d7e0171 100644 --- a/backend/app/domain/ports/import_service.py 
+++ b/backend/app/domain/ports/import_service.py @@ -12,14 +12,19 @@ This satisfies the Open/Closed Principle — the system is open for new import sources without modifying existing code. """ +# Enable future language features for compatibility from __future__ import annotations +# Import Any, Protocol, runtime_checkable from typing from typing import Any, Protocol, runtime_checkable +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Apply the @runtime_checkable decorator @runtime_checkable +# Define class ImportService class ImportService(Protocol): """Contract for any data-import operation. @@ -27,62 +32,134 @@ class ImportService(Protocol): downloads, parses, and upserts records from an external source. """ - def __call__(self, db: Session) -> dict[str, Any]: ... + # Define function __call__ + def __call__(self, db: Session) -> dict[str, Any]: + """Execute the import operation against the given database session. + + Args: + db (Session): Active SQLAlchemy session to use for all DB operations. + + Returns: + dict[str, Any]: Summary statistics for the import run (e.g. created, + updated, skipped counts). + """ + # ... + ... +# Define class ImportServiceEntry class ImportServiceEntry: """Lazy-loading wrapper that resolves a module-level function on first call.""" + # Assign __slots__ = ("_module_path", "_func_name", "_resolved") __slots__ = ("_module_path", "_func_name", "_resolved") + # Define function __init__ def __init__(self, module_path: str, func_name: str) -> None: + """Initialise the lazy entry with the module path and function name to resolve later. + + Args: + module_path (str): Dotted Python module path, e.g. + ``"app.services.atomic_import_service"``. + func_name (str): Name of the callable to import from *module_path*. 
+ + Returns: + None + """ + # Assign self._module_path = module_path self._module_path = module_path + # Assign self._func_name = func_name self._func_name = func_name + # Assign self._resolved = None self._resolved: ImportService | None = None + # Define function __call__ def __call__(self, db: Session) -> dict[str, Any]: + """Resolve the import function on first call and invoke it with *db*. + + Args: + db (Session): SQLAlchemy session passed through to the underlying + import function. + + Returns: + dict[str, Any]: Import statistics returned by the underlying function + (e.g. counts of created/updated/skipped records). + """ + # Check: self._resolved is None if self._resolved is None: + # Import importlib import importlib + # Assign mod = importlib.import_module(self._module_path) mod = importlib.import_module(self._module_path) + # Assign self._resolved = getattr(mod, self._func_name) self._resolved = getattr(mod, self._func_name) + # Return self._resolved(db) return self._resolved(db) + # Apply the @property decorator @property + # Define function source_info def source_info(self) -> str: + """Return a human-readable identifier for this import entry. + + Returns: + str: The fully qualified function reference as + ``"<module_path>.<func_name>"``. 
+ """ + # Return f"{self._module_path}.{self._func_name}" return f"{self._module_path}.{self._func_name}" +# Assign IMPORT_REGISTRY = { IMPORT_REGISTRY: dict[str, ImportServiceEntry] = { + # Literal argument value "atomic_red_team": ImportServiceEntry( + # Literal argument value "app.services.atomic_import_service", "import_atomic_red_team", ), + # Literal argument value "sigma": ImportServiceEntry( + # Literal argument value "app.services.sigma_import_service", "sync", ), + # Literal argument value "lolbas": ImportServiceEntry( + # Literal argument value "app.services.lolbas_import_service", "sync", ), + # Literal argument value "gtfobins": ImportServiceEntry( + # Literal argument value "app.services.lolbas_import_service", "sync_gtfobins", ), + # Literal argument value "caldera": ImportServiceEntry( + # Literal argument value "app.services.caldera_import_service", "sync", ), + # Literal argument value "elastic_rules": ImportServiceEntry( + # Literal argument value "app.services.elastic_import_service", "sync", ), + # Literal argument value "mitre_cti": ImportServiceEntry( + # Literal argument value "app.services.threat_actor_import_service", "sync", ), + # Literal argument value "d3fend": ImportServiceEntry( + # Literal argument value "app.services.d3fend_import_service", "sync", ), } +# Define function get_import_handler def get_import_handler(source_name: str) -> ImportServiceEntry | None: """Look up the import handler for *source_name*. Returns ``None`` when no handler is registered. 
""" + # Return IMPORT_REGISTRY.get(source_name) return IMPORT_REGISTRY.get(source_name) diff --git a/backend/app/domain/ports/repositories/__init__.py b/backend/app/domain/ports/repositories/__init__.py index 1260672..ed54d8d 100644 --- a/backend/app/domain/ports/repositories/__init__.py +++ b/backend/app/domain/ports/repositories/__init__.py @@ -1,4 +1,9 @@ +"""Abstract repository port interfaces for domain entity persistence.""" +# Import TechniqueRepository from app.domain.ports.repositories.technique_repository from app.domain.ports.repositories.technique_repository import TechniqueRepository + +# Import TestRepository from app.domain.ports.repositories.test_repository from app.domain.ports.repositories.test_repository import TestRepository +# Assign __all__ = ["TechniqueRepository", "TestRepository"] __all__ = ["TechniqueRepository", "TestRepository"] diff --git a/backend/app/domain/ports/repositories/technique_repository.py b/backend/app/domain/ports/repositories/technique_repository.py index b5a45e7..bfb9366 100644 --- a/backend/app/domain/ports/repositories/technique_repository.py +++ b/backend/app/domain/ports/repositories/technique_repository.py @@ -4,54 +4,157 @@ This is a domain contract — implementations live in infrastructure/. The domain layer NEVER imports the implementation. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import uuid import uuid + +# Import NamedTuple, Protocol, runtime_checkable from typing from typing import NamedTuple, Protocol, runtime_checkable +# Import TechniqueEntity from app.domain.entities.technique from app.domain.entities.technique import TechniqueEntity + +# Import TechniqueStatus from app.domain.enums from app.domain.enums import TechniqueStatus +# Define class TechniqueWithCounts class TechniqueWithCounts(NamedTuple): """Pre-aggregated technique data for heatmap/scoring.""" + # entity: TechniqueEntity entity: TechniqueEntity + # test_count: int test_count: int + # validated_test_count: int validated_test_count: int + # detection_rule_count: int detection_rule_count: int +# Apply the @runtime_checkable decorator @runtime_checkable +# Define class TechniqueRepository class TechniqueRepository(Protocol): """Data access contract for techniques (one per aggregate root).""" # -- Single-entity access ---------------------------------------------- - def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: ... + def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: + """Return the technique with the given primary key, or None if absent. - def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: ... + Args: + technique_id (uuid.UUID): Primary key of the technique to look up. + + Returns: + TechniqueEntity | None: The matching entity, or None if not found. + """ + # ... + ... + + # Define function find_by_mitre_id + def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: + """Return the technique matching the given MITRE ATT&CK identifier, or None. + + Args: + mitre_id (str): MITRE ATT&CK ID (e.g. ``"T1059"`` or ``"T1059.001"``). + + Returns: + TechniqueEntity | None: The matching entity, or None if not found. + """ + # ... + ... 
# -- List access ------------------------------------------------------- def list_all( self, *, + # Entry: tactic tactic: str | None = None, + # Entry: status status: TechniqueStatus | None = None, + # Entry: review_required review_required: bool | None = None, - ) -> list[TechniqueEntity]: ... + ) -> list[TechniqueEntity]: + """Return all techniques, optionally filtered by tactic, status, or review flag. - def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: ... + Args: + tactic (str | None): When provided, restrict results to this tactic category. + status (TechniqueStatus | None): When provided, restrict results to this status. + review_required (bool | None): When provided, restrict results to techniques + whose ``review_required`` flag matches this value. + + Returns: + list[TechniqueEntity]: Matching technique entities; may be empty. + """ + # ... + ... + + # Define function list_by_ids + def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: + """Return all techniques whose primary keys are in *ids*. + + Args: + ids (list[uuid.UUID]): List of technique UUIDs to retrieve. + + Returns: + list[TechniqueEntity]: Entities found for the supplied IDs; order + is not guaranteed and missing IDs are silently omitted. + """ + # ... + ... # -- Batch queries (scoring/heatmap performance) ----------------------- - def count_by_status(self) -> dict[TechniqueStatus, int]: ... + def count_by_status(self) -> dict[TechniqueStatus, int]: + """Return a count of techniques grouped by their global status. - def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: ... + Returns: + dict[TechniqueStatus, int]: Mapping from each status value to the + number of techniques in that state. + """ + # ... + ... + + # Define function find_all_with_test_counts + def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: + """Return all techniques together with pre-aggregated test and rule counts. 
+ + Returns: + list[TechniqueWithCounts]: Each element bundles a TechniqueEntity + with its total, validated, and detection-rule counts for use + in heatmap and scoring calculations. + """ + # ... + ... # -- Mutations --------------------------------------------------------- - def save(self, technique: TechniqueEntity) -> TechniqueEntity: ... + def save(self, technique: TechniqueEntity) -> TechniqueEntity: + """Persist a technique entity and return the saved state. - def exists_by_mitre_id(self, mitre_id: str) -> bool: ... + Args: + technique (TechniqueEntity): The entity to create or update. + + Returns: + TechniqueEntity: The persisted entity, potentially with updated + fields (e.g. server-side timestamps). + """ + # ... + ... + + # Define function exists_by_mitre_id + def exists_by_mitre_id(self, mitre_id: str) -> bool: + """Return True if a technique with the given MITRE ID exists in the repository. + + Args: + mitre_id (str): MITRE ATT&CK ID to check (e.g. ``"T1059"``). + + Returns: + bool: True if a matching technique is found, False otherwise. + """ + # ... + ... diff --git a/backend/app/domain/ports/repositories/test_repository.py b/backend/app/domain/ports/repositories/test_repository.py index 79b6a26..96d3806 100644 --- a/backend/app/domain/ports/repositories/test_repository.py +++ b/backend/app/domain/ports/repositories/test_repository.py @@ -3,14 +3,20 @@ This is a domain contract — implementations live in infrastructure/. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import uuid import uuid -from typing import Protocol, runtime_checkable +# Import Protocol from typing +from typing import Protocol + +# Import TestState from app.domain.enums from app.domain.enums import TestState +# Define class TestRepository class TestRepository(Protocol): """Data access contract for tests.""" @@ -22,31 +28,81 @@ class TestRepository(Protocol): Returns the ORM model directly (not a domain entity) because the TestEntity is constructed at the service layer via ``TestEntity.from_orm()``. + + Args: + test_id (uuid.UUID): Primary key of the test to look up. + + Returns: + object | None: The ORM model instance, or None if not found. """ + # ... ... # -- List access ------------------------------------------------------- - def list_by_technique(self, technique_id: uuid.UUID) -> list[object]: ... + def list_by_technique(self, technique_id: uuid.UUID) -> list[object]: + """Return all test ORM models associated with the given technique. - def list_by_state(self, state: TestState) -> list[object]: ... + Args: + technique_id (uuid.UUID): Primary key of the technique whose tests to retrieve. + Returns: + list[object]: ORM model instances for all tests linked to this technique. + """ + # ... + ... + + # Define function list_by_state + def list_by_state(self, state: TestState) -> list[object]: + """Return all test ORM models in the given state. + + Args: + state (TestState): The state to filter tests by. + + Returns: + list[object]: ORM model instances for all tests currently in *state*. + """ + # ... + ... + + # Define function count_by_technique_and_state def count_by_technique_and_state( self, + # Entry: technique_id technique_id: uuid.UUID, ) -> dict[TestState, int]: - """Return test counts grouped by state for a single technique.""" + """Return test counts grouped by state for a single technique. 
+ + Args: + technique_id (uuid.UUID): Primary key of the technique whose test + counts to aggregate. + + Returns: + dict[TestState, int]: Mapping from each test state to the number of + tests in that state for the given technique. + """ + # ... ... # -- Batch queries ----------------------------------------------------- def get_states_and_results_for_technique( self, + # Entry: technique_id technique_id: uuid.UUID, ) -> list[tuple[str, str | None]]: """Return (state, detection_result) pairs for all tests of a technique. Used by TechniqueEntity.recalculate_status() without loading full test models. + + Args: + technique_id (uuid.UUID): Primary key of the technique whose test + data to retrieve. + + Returns: + list[tuple[str, str | None]]: Each tuple contains the test state + string and the detection result string (or None if not yet set). """ + # ... ... diff --git a/backend/app/domain/test_entity.py b/backend/app/domain/test_entity.py index 2e94209..51efeeb 100644 --- a/backend/app/domain/test_entity.py +++ b/backend/app/domain/test_entity.py @@ -20,34 +20,58 @@ After mutations, the service layer copies ``entity.changes`` back onto the ORM model and persists via Unit of Work. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import enum import enum -import uuid -from dataclasses import dataclass, field -from datetime import datetime -from typing import Any +# Import uuid +import uuid + +# Import dataclass, field from dataclasses +from dataclasses import dataclass, field + +# Import datetime from datetime +from datetime import datetime + +# Import TYPE_CHECKING, Any from typing +from typing import TYPE_CHECKING, Any + +# Import from app.domain.errors from app.domain.errors import ( BusinessRuleViolation, InvalidOperationError, InvalidStateTransition, ) +# Check: TYPE_CHECKING +if TYPE_CHECKING: + # Import Test as TestORM from app.models.test + from app.models.test import Test as TestORM # ── Value objects ──────────────────────────────────────────────────── class TestState(str, enum.Enum): + """Ordered lifecycle states for a security test.""" + + # Assign draft = "draft" draft = "draft" + # Assign red_executing = "red_executing" red_executing = "red_executing" + # Assign blue_evaluating = "blue_evaluating" blue_evaluating = "blue_evaluating" + # Assign in_review = "in_review" in_review = "in_review" + # Assign validated = "validated" validated = "validated" + # Assign rejected = "rejected" rejected = "rejected" disputed = "disputed" # one lead approved, the other rejected +# Assign VALID_TRANSITIONS = { VALID_TRANSITIONS: dict[TestState, list[TestState]] = { TestState.draft: [TestState.red_executing], TestState.red_executing: [TestState.blue_evaluating], @@ -58,6 +82,7 @@ VALID_TRANSITIONS: dict[TestState, list[TestState]] = { TestState.validated: [], } +# Assign _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating}) _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating}) @@ -65,8 +90,13 @@ _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating @dataclass(frozen=True) +# Define class DomainEvent class DomainEvent: 
+ """Immutable record of a domain-level event emitted by the test entity.""" + + # name: str name: str + # Assign payload = field(default_factory=dict) payload: dict[str, Any] = field(default_factory=dict) @@ -74,30 +104,44 @@ class DomainEvent: @dataclass +# Define class TestEntity class TestEntity: """Pure domain representation of a security test.""" + # id: uuid.UUID id: uuid.UUID + # state: TestState state: TestState # Red validation red_validation_status: str | None = None + # Assign red_validated_by = None red_validated_by: uuid.UUID | None = None + # Assign red_validated_at = None red_validated_at: datetime | None = None + # Assign red_validation_notes = None red_validation_notes: str | None = None # Blue validation blue_validation_status: str | None = None + # Assign blue_validated_by = None blue_validated_by: uuid.UUID | None = None + # Assign blue_validated_at = None blue_validated_at: datetime | None = None + # Assign blue_validation_notes = None blue_validation_notes: str | None = None # Phase timing execution_date: datetime | None = None + # Assign red_started_at = None red_started_at: datetime | None = None + # Assign blue_started_at = None blue_started_at: datetime | None = None + # Assign paused_at = None paused_at: datetime | None = None + # Assign red_paused_seconds = 0 red_paused_seconds: int = 0 + # Assign blue_paused_seconds = 0 blue_paused_seconds: int = 0 # Internal bookkeeping (not persisted as-is) @@ -106,58 +150,134 @@ class TestEntity: # -- Factory -------------------------------------------------------- @classmethod - def from_orm(cls, model: Any) -> TestEntity: - """Build a TestEntity from a SQLAlchemy ``Test`` model instance.""" + # Define function from_orm + def from_orm(cls, model: TestORM) -> TestEntity: + """Build a TestEntity from a SQLAlchemy ``Test`` model instance. + + Args: + model (TestORM): The ORM model whose fields will be copied into the entity. 
+ + Returns: + TestEntity: A fully populated domain entity reflecting the ORM state. + """ + # Assign raw_state = model.state raw_state = model.state + # Assign state = raw_state if isinstance(raw_state, TestState) else TestState(raw_st... state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state) + # Return cls( return cls( + # Keyword argument: id id=model.id, + # Keyword argument: state state=state, + # Keyword argument: red_validation_status red_validation_status=model.red_validation_status, + # Keyword argument: red_validated_by red_validated_by=model.red_validated_by, + # Keyword argument: red_validated_at red_validated_at=model.red_validated_at, + # Keyword argument: red_validation_notes red_validation_notes=model.red_validation_notes, + # Keyword argument: blue_validation_status blue_validation_status=model.blue_validation_status, + # Keyword argument: blue_validated_by blue_validated_by=model.blue_validated_by, + # Keyword argument: blue_validated_at blue_validated_at=model.blue_validated_at, + # Keyword argument: blue_validation_notes blue_validation_notes=model.blue_validation_notes, + # Keyword argument: execution_date execution_date=model.execution_date, + # Keyword argument: red_started_at red_started_at=model.red_started_at, + # Keyword argument: blue_started_at blue_started_at=model.blue_started_at, + # Keyword argument: paused_at paused_at=model.paused_at, + # Keyword argument: red_paused_seconds red_paused_seconds=model.red_paused_seconds or 0, + # Keyword argument: blue_paused_seconds blue_paused_seconds=model.blue_paused_seconds or 0, ) - def apply_to(self, model: Any) -> None: - """Copy the entity's mutable fields back onto the ORM model.""" + # Define function apply_to + def apply_to(self, model: TestORM) -> None: + """Copy the entity's mutable fields back onto the ORM model. + + Args: + model (TestORM): The ORM model to update in-place. 
+ + Returns: + None + """ + # Assign model.state = self.state model.state = self.state + # Assign model.red_validation_status = self.red_validation_status model.red_validation_status = self.red_validation_status + # Assign model.red_validated_by = self.red_validated_by model.red_validated_by = self.red_validated_by + # Assign model.red_validated_at = self.red_validated_at model.red_validated_at = self.red_validated_at + # Assign model.red_validation_notes = self.red_validation_notes model.red_validation_notes = self.red_validation_notes + # Assign model.blue_validation_status = self.blue_validation_status model.blue_validation_status = self.blue_validation_status + # Assign model.blue_validated_by = self.blue_validated_by model.blue_validated_by = self.blue_validated_by + # Assign model.blue_validated_at = self.blue_validated_at model.blue_validated_at = self.blue_validated_at + # Assign model.blue_validation_notes = self.blue_validation_notes model.blue_validation_notes = self.blue_validation_notes + # Assign model.execution_date = self.execution_date model.execution_date = self.execution_date + # Assign model.red_started_at = self.red_started_at model.red_started_at = self.red_started_at + # Assign model.blue_started_at = self.blue_started_at model.blue_started_at = self.blue_started_at + # Assign model.paused_at = self.paused_at model.paused_at = self.paused_at + # Assign model.red_paused_seconds = self.red_paused_seconds model.red_paused_seconds = self.red_paused_seconds + # Assign model.blue_paused_seconds = self.blue_paused_seconds model.blue_paused_seconds = self.blue_paused_seconds # -- Query helpers -------------------------------------------------- @property + # Define function events def events(self) -> list[DomainEvent]: + """Return a snapshot of all domain events raised on this entity. + + Returns: + list[DomainEvent]: Ordered list of events emitted since the entity + was constructed or last cleared. 
+ """ + # Return list(self._events) return list(self._events) + # Define function can_transition def can_transition(self, target: TestState) -> bool: + """Check whether a transition from the current state to *target* is valid. + + Args: + target (TestState): The desired next state. + + Returns: + bool: True if the transition is allowed, False otherwise. + """ + # Return target in VALID_TRANSITIONS.get(self.state, []) return target in VALID_TRANSITIONS.get(self.state, []) + # Apply the @property decorator @property + # Define function is_terminal def is_terminal(self) -> bool: + """Return True if the test has reached its final (validated) state. + + Returns: + bool: True when state is ``validated``, False for all other states. + """ + # Return self.state == TestState.validated return self.state == TestState.validated # -- Core transition ------------------------------------------------ @@ -171,148 +291,305 @@ class TestEntity: Returns the *previous* state value as a plain string. Raises :class:`InvalidStateTransition` when the move is illegal. + + Args: + target (TestState | str): The desired next state, as an enum member + or its string equivalent. + + Returns: + str: The previous state value before the transition. """ + # Assign value = target.value if hasattr(target, "value") else str(target) value = target.value if hasattr(target, "value") else str(target) + # Assign resolved = target if isinstance(target, TestState) else TestState(value) resolved = target if isinstance(target, TestState) else TestState(value) + # Return self._transition(resolved) return self._transition(resolved) + # Define function _transition def _transition(self, target: TestState) -> str: - """Internal: validate and apply; return previous state value.""" + """Validate and apply a state transition, returning the previous state value. + + Args: + target (TestState): The desired next state enum member. + + Returns: + str: The previous state value before the transition was applied. 
+ """ + # Check: not self.can_transition(target) if not self.can_transition(target): + # Assign valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])] valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])] + # Raise InvalidStateTransition raise InvalidStateTransition( + # Keyword argument: current_state current_state=self.state.value, + # Keyword argument: target_state target_state=target.value, + # Keyword argument: valid_transitions valid_transitions=valid, ) + # Assign previous = self.state.value previous = self.state.value + # Assign self.state = target self.state = target + # Call self._events.append() self._events.append(DomainEvent( + # Literal argument value "state_changed", {"previous": previous, "new": target.value}, )) + # Return previous return previous # -- Lifecycle commands -------------------------------------------- def start_execution(self) -> None: - """``draft`` -> ``red_executing``.""" + """Transition the test from ``draft`` to ``red_executing``. + + Returns: + None + """ + # Call self._transition() self._transition(TestState.red_executing) + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign self.execution_date = now self.execution_date = now + # Assign self.red_started_at = now self.red_started_at = now + # Call self._events.append() self._events.append(DomainEvent("execution_started")) + # Define function submit_red_evidence def submit_red_evidence(self) -> int: - """``red_executing`` -> ``blue_evaluating``. + """Transition the test from ``red_executing`` to ``blue_evaluating``. Auto-resumes if paused. Returns paused seconds accumulated during this phase (for worklog calculation). + + Returns: + int: Total seconds the red phase was paused. 
""" + # Assign paused_extra = self._auto_resume() paused_extra = self._auto_resume() + # Call self._transition() self._transition(TestState.blue_evaluating) + # Assign total_paused = self.red_paused_seconds + paused_extra total_paused = self.red_paused_seconds + paused_extra + # Assign self.blue_started_at = datetime.utcnow() self.blue_started_at = datetime.utcnow() + # Assign self.blue_paused_seconds = 0 self.blue_paused_seconds = 0 + # Call self._events.append() self._events.append(DomainEvent( + # Literal argument value "red_evidence_submitted", {"red_paused_seconds": total_paused}, )) + # Return total_paused return total_paused + # Define function submit_blue_evidence def submit_blue_evidence(self) -> int: - """``blue_evaluating`` -> ``in_review``. + """Transition the test from ``blue_evaluating`` to ``in_review``. Auto-resumes if paused. Returns paused seconds accumulated during this phase (for worklog calculation). + + Returns: + int: Total seconds the blue phase was paused. """ + # Assign paused_extra = self._auto_resume() paused_extra = self._auto_resume() + # Call self._transition() self._transition(TestState.in_review) + # Assign total_paused = self.blue_paused_seconds + paused_extra total_paused = self.blue_paused_seconds + paused_extra + # Call self._events.append() self._events.append(DomainEvent( + # Literal argument value "blue_evidence_submitted", {"blue_paused_seconds": total_paused}, )) + # Return total_paused return total_paused + # Define function pause_timer def pause_timer(self) -> None: - """Pause the active phase timer.""" + """Pause the active phase timer. 
+ + Returns: + None + """ + # Check: self.state not in _PAUSABLE_STATES if self.state not in _PAUSABLE_STATES: + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot pause timer in '{self.state.value}' state" ) + # Check: self.paused_at is not None if self.paused_at is not None: + # Raise BusinessRuleViolation raise BusinessRuleViolation("Timer is already paused") + # Assign self.paused_at = datetime.utcnow() self.paused_at = datetime.utcnow() + # Call self._events.append() self._events.append(DomainEvent("timer_paused")) + # Define function resume_timer def resume_timer(self) -> int: - """Resume a paused timer. Returns seconds that were paused.""" + """Resume a paused timer. + + Returns: + int: Number of seconds the timer was paused for. + """ + # Check: self.paused_at is None if self.paused_at is None: + # Raise BusinessRuleViolation raise BusinessRuleViolation("Timer is not paused") + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign paused_seconds = max(int((now - self.paused_at).total_seconds()), 0) paused_seconds = max(int((now - self.paused_at).total_seconds()), 0) + # Check: self.state == TestState.red_executing if self.state == TestState.red_executing: + # Add paused_seconds to self.red_paused_seconds (accumulates across pauses) self.red_paused_seconds += paused_seconds + # Alternative: self.state == TestState.blue_evaluating elif self.state == TestState.blue_evaluating: + # Add paused_seconds to self.blue_paused_seconds (accumulates across pauses) self.blue_paused_seconds += paused_seconds + # Assign self.paused_at = None self.paused_at = None + # Call self._events.append() self._events.append(DomainEvent("timer_resumed", {"paused_seconds": paused_seconds})) + # Return paused_seconds return paused_seconds + # Define function validate_red def validate_red(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None: - """Record Red Lead's validation decision.""" + """Record Red Lead's validation decision. 
+ + Args: + status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``. + by (uuid.UUID): UUID of the Red Lead recording the decision. + notes (str | None): Optional free-text notes about the decision. + + Returns: + None + """ + # Call self._assert_in_review() self._assert_in_review("red") + # Call self._assert_valid_vote() self._assert_valid_vote(status) + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign self.red_validation_status = status self.red_validation_status = status + # Assign self.red_validated_by = by self.red_validated_by = by + # Assign self.red_validated_at = now self.red_validated_at = now + # Assign self.red_validation_notes = notes self.red_validation_notes = notes + # Call self._events.append() self._events.append(DomainEvent("red_validated", {"status": status})) + # Call self._check_dual_validation() self._check_dual_validation() + # Define function validate_blue def validate_blue(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None: - """Record Blue Lead's validation decision.""" + """Record Blue Lead's validation decision. + + Args: + status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``. + by (uuid.UUID): UUID of the Blue Lead recording the decision. + notes (str | None): Optional free-text notes about the decision. 
+ + Returns: + None + """ + # Call self._assert_in_review() self._assert_in_review("blue") + # Call self._assert_valid_vote() self._assert_valid_vote(status) + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign self.blue_validation_status = status self.blue_validation_status = status + # Assign self.blue_validated_by = by self.blue_validated_by = by + # Assign self.blue_validated_at = now self.blue_validated_at = now + # Assign self.blue_validation_notes = notes self.blue_validation_notes = notes + # Call self._events.append() self._events.append(DomainEvent("blue_validated", {"status": status})) + # Call self._check_dual_validation() self._check_dual_validation() + # Define function reopen def reopen(self) -> None: - """``rejected`` -> ``draft``, clearing all validation/timing fields.""" + """Transition the test from ``rejected`` back to ``draft``, clearing all validation and timing fields. + + Returns: + None + """ + # Call self._transition() self._transition(TestState.draft) + # Assign self.red_validation_status = None self.red_validation_status = None + # Assign self.red_validated_by = None self.red_validated_by = None + # Assign self.red_validated_at = None self.red_validated_at = None + # Assign self.red_validation_notes = None self.red_validation_notes = None + # Assign self.blue_validation_status = None self.blue_validation_status = None + # Assign self.blue_validated_by = None self.blue_validated_by = None + # Assign self.blue_validated_at = None self.blue_validated_at = None + # Assign self.blue_validation_notes = None self.blue_validation_notes = None + # Assign self.red_started_at = None self.red_started_at = None + # Assign self.blue_started_at = None self.blue_started_at = None + # Assign self.paused_at = None self.paused_at = None + # Assign self.red_paused_seconds = 0 self.red_paused_seconds = 0 + # Assign self.blue_paused_seconds = 0 self.blue_paused_seconds = 0 + # Call self._events.append() 
self._events.append(DomainEvent("test_reopened")) # -- Private ------------------------------------------------------- def _auto_resume(self) -> int: - """If paused, accumulate pause time and clear. Returns extra seconds.""" + """Accumulate pause time and clear the paused timestamp if currently paused. + + Returns: + int: Extra seconds that were accumulated from the current pause, or 0 + if the timer was not paused. + """ + # Check: self.paused_at is None if self.paused_at is None: + # Return 0 return 0 + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign extra = max(int((now - self.paused_at).total_seconds()), 0) extra = max(int((now - self.paused_at).total_seconds()), 0) + # Assign self.paused_at = None self.paused_at = None + # Return extra return extra + # Define function check_dual_validation def check_dual_validation(self) -> None: """Evaluate both leads' votes and advance state if appropriate. @@ -324,8 +601,10 @@ class TestEntity: Called automatically by :meth:`validate_red` and :meth:`validate_blue`. """ + # Call self._check_dual_validation() self._check_dual_validation() + # Define function _assert_in_review def _assert_in_review(self, side: str) -> None: if self.state not in (TestState.in_review, TestState.disputed): raise InvalidOperationError( @@ -333,19 +612,34 @@ class TestEntity: f"'{self.state.value}' state (must be in_review or disputed)" ) + # Apply the @staticmethod decorator @staticmethod + # Define function _assert_valid_vote def _assert_valid_vote(status: str) -> None: + """Raise InvalidOperationError if *status* is not a valid vote value. + + Args: + status (str): The vote value to validate; must be ``"approved"`` or ``"rejected"``. 
+ + Returns: + None + """ + # Check: status not in ("approved", "rejected") if status not in ("approved", "rejected"): + # Raise InvalidOperationError raise InvalidOperationError( + # Literal argument value "validation_status must be 'approved' or 'rejected'" ) + # Define function _check_dual_validation def _check_dual_validation(self) -> None: """Advance the test state once both leads have voted.""" r, b = self.red_validation_status, self.blue_validation_status if r == "approved" and b == "approved": self.state = TestState.validated + # Call self._events.append() self._events.append(DomainEvent("dual_validation_approved")) elif r == "rejected" and b == "rejected": diff --git a/backend/app/domain/unit_of_work.py b/backend/app/domain/unit_of_work.py index 83b2400..add3192 100644 --- a/backend/app/domain/unit_of_work.py +++ b/backend/app/domain/unit_of_work.py @@ -20,36 +20,84 @@ Services should **never** call ``db.commit()``; they use ``db.add()`` / osint_enrichment_service.enrich_technique_with_cves). """ +# Enable future language features for compatibility from __future__ import annotations +# Import TracebackType from types +from types import TracebackType + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Define class UnitOfWork class UnitOfWork: """Lightweight transaction wrapper around an existing SQLAlchemy session.""" + # Define function __init__ def __init__(self, session: Session) -> None: + """Wrap an existing SQLAlchemy session in a Unit of Work. + + Args: + session (Session): The active SQLAlchemy session to manage. + + Returns: + None + """ + # Assign self._session = session self._session = session # -- context manager ----------------------------------------------------- def __enter__(self) -> "UnitOfWork": + """Enter the runtime context, returning this UnitOfWork instance. + + Returns: + UnitOfWork: The UnitOfWork itself, for use in ``with`` statements. 
+ """ + # Return self return self - def __exit__(self, exc_type, exc_val, exc_tb) -> None: + # Define function __exit__ + def __exit__( + self, + # Entry: exc_type + exc_type: type[BaseException] | None, + # Entry: exc_val + exc_val: BaseException | None, + # Entry: exc_tb + exc_tb: TracebackType | None, + ) -> None: + """Exit the runtime context, rolling back if an exception propagated. + + Args: + exc_type (type[BaseException] | None): Exception class, if raised. + exc_val (BaseException | None): Exception instance, if raised. + exc_tb (TracebackType | None): Traceback object, if an exception was raised. + + Returns: + None + """ + # Check: exc_type is not None if exc_type is not None: + # Call self.rollback() self.rollback() # -- public API ---------------------------------------------------------- def commit(self) -> None: """Flush pending changes and commit the transaction.""" + # Call self._session.commit() self._session.commit() + # Define function rollback def rollback(self) -> None: """Roll back the current transaction.""" + # Call self._session.rollback() self._session.rollback() + # Define function flush def flush(self) -> None: """Flush pending changes without committing (useful for getting IDs).""" + # Call self._session.flush() self._session.flush() diff --git a/backend/app/domain/value_objects/__init__.py b/backend/app/domain/value_objects/__init__.py index bc332a6..8390a2b 100644 --- a/backend/app/domain/value_objects/__init__.py +++ b/backend/app/domain/value_objects/__init__.py @@ -1,4 +1,9 @@ +"""Immutable domain value objects.""" +# Import MitreId from app.domain.value_objects.mitre_id from app.domain.value_objects.mitre_id import MitreId + +# Import ScoringWeights from app.domain.value_objects.scoring_weights from app.domain.value_objects.scoring_weights import ScoringWeights +# Assign __all__ = ["MitreId", "ScoringWeights"] __all__ = ["MitreId", "ScoringWeights"] diff --git a/backend/app/domain/value_objects/mitre_id.py 
b/backend/app/domain/value_objects/mitre_id.py index 092a5a3..55bc944 100644 --- a/backend/app/domain/value_objects/mitre_id.py +++ b/backend/app/domain/value_objects/mitre_id.py @@ -5,47 +5,111 @@ format: ``T`` followed by 4 digits, optionally a dot and 3 more digits for sub-techniques (e.g. ``T1059``, ``T1059.001``). """ +# Enable future language features for compatibility from __future__ import annotations +# Import re import re + +# Import dataclass from dataclasses from dataclasses import dataclass +# Assign _MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$") _MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$") +# Apply the @dataclass decorator @dataclass(frozen=True, slots=True) +# Define class MitreId class MitreId: """Validated MITRE ATT&CK technique identifier.""" + # value: str value: str + # Define function __post_init__ def __post_init__(self) -> None: + """Validate that *value* matches the expected MITRE ATT&CK ID format. + + Returns: + None + """ + # Check: not _MITRE_ID_RE.match(self.value) if not _MITRE_ID_RE.match(self.value): + # Raise ValueError raise ValueError( f"Invalid MITRE ATT&CK ID '{self.value}'. " + # Literal argument value "Expected format: T1234 or T1234.001" ) + # Apply the @property decorator @property + # Define function is_subtechnique def is_subtechnique(self) -> bool: + """Return True if this identifier represents a sub-technique. + + Returns: + bool: True when the ID contains a dot (e.g. ``T1059.001``). + """ + # Return "." in self.value return "." in self.value + # Apply the @property decorator @property + # Define function parent_id def parent_id(self) -> str | None: - """Return the parent technique ID (e.g. T1059 for T1059.001).""" + """Return the parent technique ID (e.g. ``T1059`` for ``T1059.001``). + + Returns: + str | None: The parent ID string, or None if this is not a sub-technique. 
+ """ + # Check: not self.is_subtechnique if not self.is_subtechnique: + # Return None return None + # Return self.value.split(".")[0] return self.value.split(".")[0] + # Define function __str__ def __str__(self) -> str: + """Return the string representation of the MITRE ID. + + Returns: + str: The raw identifier string (e.g. ``"T1059.001"``). + """ + # Return self.value return self.value + # Define function __eq__ def __eq__(self, other: object) -> bool: + """Compare this MitreId to another MitreId or a plain string. + + Args: + other (object): The value to compare against; may be a + :class:`MitreId` instance or a plain ``str``. + + Returns: + bool: True if the identifiers are equal, NotImplemented for + unsupported types. + """ + # Check: isinstance(other, MitreId) if isinstance(other, MitreId): + # Return self.value == other.value return self.value == other.value + # Check: isinstance(other, str) if isinstance(other, str): + # Return self.value == other return self.value == other + # Return NotImplemented return NotImplemented + # Define function __hash__ def __hash__(self) -> int: + """Return the hash of the identifier string. + + Returns: + int: Hash value derived from the raw identifier string. + """ + # Return hash(self.value) return hash(self.value) diff --git a/backend/app/domain/value_objects/scoring_weights.py b/backend/app/domain/value_objects/scoring_weights.py index 6bcb292..01f56da 100644 --- a/backend/app/domain/value_objects/scoring_weights.py +++ b/backend/app/domain/value_objects/scoring_weights.py @@ -3,22 +3,38 @@ Enforces that all five weights are non-negative and sum to exactly 100. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import dataclass from dataclasses from dataclasses import dataclass +# Apply the @dataclass decorator @dataclass(frozen=True, slots=True) +# Define class ScoringWeights class ScoringWeights: """Five scoring dimension weights that must sum to 100.""" + # tests: float tests: float + # detection_rules: float detection_rules: float + # d3fend: float d3fend: float + # recency: float recency: float + # severity: float severity: float + # Define function __post_init__ def __post_init__(self) -> None: + """Validate that all weights are non-negative and sum to exactly 100. + + Returns: + None + """ + # Assign fields = [ fields = [ self.tests, self.detection_rules, @@ -26,32 +42,66 @@ class ScoringWeights: self.recency, self.severity, ] + # Iterate over fields for f in fields: + # Check: f < 0 if f < 0: + # Raise ValueError raise ValueError("Scoring weights must be non-negative") + # Assign total = sum(fields) total = sum(fields) + # Check: abs(total - 100) > 0.01 if abs(total - 100) > 0.01: + # Raise ValueError raise ValueError( f"Scoring weights must sum to 100, got {total}" ) + # Apply the @classmethod decorator @classmethod + # Define function default def default(cls) -> ScoringWeights: - """Return the default weight distribution.""" + """Return the default weight distribution. + + Returns: + ScoringWeights: A weight set with tests=40, detection_rules=25, + d3fend=15, recency=10, severity=10. + """ + # Return cls( return cls( + # Keyword argument: tests tests=40.0, + # Keyword argument: detection_rules detection_rules=25.0, + # Keyword argument: d3fend d3fend=15.0, + # Keyword argument: recency recency=10.0, + # Keyword argument: severity severity=10.0, ) # Backward-compatible aliases for older API payloads @property + # Define function freshness def freshness(self) -> float: + """Return the recency weight (backward-compatible alias). 
+ + Returns: + float: The value of the ``recency`` weight. + """ + # Return self.recency return self.recency + # Apply the @property decorator @property + # Define function platform_diversity def platform_diversity(self) -> float: + """Return the severity weight (backward-compatible alias). + + Returns: + float: The value of the ``severity`` weight. + """ + # Return self.severity return self.severity diff --git a/backend/app/infrastructure/__init__.py b/backend/app/infrastructure/__init__.py index e69de29..3623735 100644 --- a/backend/app/infrastructure/__init__.py +++ b/backend/app/infrastructure/__init__.py @@ -0,0 +1 @@ +"""Infrastructure adapters — persistence, caching, and external services.""" diff --git a/backend/app/infrastructure/persistence/__init__.py b/backend/app/infrastructure/persistence/__init__.py index e69de29..01fbb04 100644 --- a/backend/app/infrastructure/persistence/__init__.py +++ b/backend/app/infrastructure/persistence/__init__.py @@ -0,0 +1 @@ +"""SQLAlchemy-based persistence adapters for the domain repository ports.""" diff --git a/backend/app/infrastructure/persistence/mappers/__init__.py b/backend/app/infrastructure/persistence/mappers/__init__.py index e69de29..5ada5b1 100644 --- a/backend/app/infrastructure/persistence/mappers/__init__.py +++ b/backend/app/infrastructure/persistence/mappers/__init__.py @@ -0,0 +1 @@ +"""ORM-to-domain entity mapper functions.""" diff --git a/backend/app/infrastructure/persistence/mappers/technique_mapper.py b/backend/app/infrastructure/persistence/mappers/technique_mapper.py index 74cd588..fd0ee6f 100644 --- a/backend/app/infrastructure/persistence/mappers/technique_mapper.py +++ b/backend/app/infrastructure/persistence/mappers/technique_mapper.py @@ -1,20 +1,28 @@ """Technique ORM model <-> domain entity mapper.""" +# Enable future language features for compatibility from __future__ import annotations +# Import TechniqueEntity from app.domain.entities.technique from app.domain.entities.technique 
import TechniqueEntity -from app.domain.enums import TechniqueStatus +# Define class TechniqueMapper class TechniqueMapper: """Converts between SQLAlchemy Technique model and TechniqueEntity.""" + # Apply the @staticmethod decorator @staticmethod + # Define function to_entity def to_entity(model: object) -> TechniqueEntity: """Convert an ORM Technique model to a domain TechniqueEntity.""" + # Return TechniqueEntity.from_orm(model) return TechniqueEntity.from_orm(model) + # Apply the @staticmethod decorator @staticmethod + # Define function to_model_updates def to_model_updates(entity: TechniqueEntity, model: object) -> None: """Apply entity changes back onto an existing ORM model.""" + # Call entity.apply_to() entity.apply_to(model) diff --git a/backend/app/infrastructure/persistence/repositories/__init__.py b/backend/app/infrastructure/persistence/repositories/__init__.py index d6c0338..38d913d 100644 --- a/backend/app/infrastructure/persistence/repositories/__init__.py +++ b/backend/app/infrastructure/persistence/repositories/__init__.py @@ -1,8 +1,13 @@ +"""Concrete SQLAlchemy repository implementations.""" +# Import from app.infrastructure.persistence.repositories.sa_technique_repository from app.infrastructure.persistence.repositories.sa_technique_repository import ( SATechniqueRepository, ) + +# Import from app.infrastructure.persistence.repositories.sa_test_repository from app.infrastructure.persistence.repositories.sa_test_repository import ( SATestRepository, ) +# Assign __all__ = ["SATechniqueRepository", "SATestRepository"] __all__ = ["SATechniqueRepository", "SATestRepository"] diff --git a/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py b/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py index f1a3828..b142085 100644 --- a/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py +++ b/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py @@ -4,44 
+4,95 @@ Receives a Session from the caller — does NOT create its own. Does NOT call commit() — the Unit of Work owns that. """ +# Enable future language features for compatibility from __future__ import annotations +# Import uuid import uuid +# Import func from sqlalchemy from sqlalchemy import func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import TechniqueEntity from app.domain.entities.technique from app.domain.entities.technique import TechniqueEntity + +# Import TechniqueStatus, TestState from app.domain.enums from app.domain.enums import TechniqueStatus, TestState + +# Import TechniqueWithCounts from app.domain.ports.repositories.technique_repository from app.domain.ports.repositories.technique_repository import TechniqueWithCounts + +# Import TechniqueMapper from app.infrastructure.persistence.mappers.technique_mapper from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper + +# Import DetectionRule from app.models.detection_rule from app.models.detection_rule import DetectionRule + +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import Test from app.models.test from app.models.test import Test +# Define class SATechniqueRepository class SATechniqueRepository: """Concrete repository backed by SQLAlchemy.""" + # Define function __init__ def __init__(self, session: Session) -> None: + """Initialise the repository with a caller-provided session. + + Args: + session (Session): The SQLAlchemy session to use for all queries. + """ + # Assign self._session = session self._session = session # -- Single-entity access ---------------------------------------------- def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: + """Return a single technique by its primary key. + + Args: + technique_id (uuid.UUID): The UUID primary key of the technique. + + Returns: + TechniqueEntity | None: The matching entity, or ``None`` if not found. 
+ """ + # Assign model = ( model = ( self._session.query(Technique) + # Chain .filter() call .filter(Technique.id == technique_id) + # Chain .first() call .first() ) + # Return TechniqueMapper.to_entity(model) if model else None return TechniqueMapper.to_entity(model) if model else None + # Define function find_by_mitre_id def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: + """Return a single technique by its MITRE ATT&CK ID (e.g. ``T1059.001``). + + Args: + mitre_id (str): The MITRE ATT&CK identifier string. + + Returns: + TechniqueEntity | None: The matching entity, or ``None`` if not found. + """ + # Assign model = ( model = ( self._session.query(Technique) + # Chain .filter() call .filter(Technique.mitre_id == mitre_id) + # Chain .first() call .first() ) + # Return TechniqueMapper.to_entity(model) if model else None return TechniqueMapper.to_entity(model) if model else None # -- List access ------------------------------------------------------- @@ -49,57 +100,111 @@ class SATechniqueRepository: def list_all( self, *, + # Entry: tactic tactic: str | None = None, + # Entry: status status: TechniqueStatus | None = None, + # Entry: review_required review_required: bool | None = None, ) -> list[TechniqueEntity]: + """Return all techniques, optionally filtered by tactic, status, or review flag. + + Args: + tactic (str | None): Filter to techniques belonging to this tactic name. + status (TechniqueStatus | None): Filter to techniques with this coverage status. + review_required (bool | None): Filter to techniques where ``review_required`` matches. + + Returns: + list[TechniqueEntity]: Ordered list of matching technique entities. 
+ """ + # Assign query = self._session.query(Technique) query = self._session.query(Technique) + # Check: tactic is not None if tactic is not None: + # Assign query = query.filter(Technique.tactic == tactic) query = query.filter(Technique.tactic == tactic) + # Check: status is not None if status is not None: + # Assign query = query.filter(Technique.status_global == status) query = query.filter(Technique.status_global == status) + # Check: review_required is not None if review_required is not None: + # Assign query = query.filter(Technique.review_required == review_required) query = query.filter(Technique.review_required == review_required) + # Assign models = query.order_by(Technique.mitre_id).all() models = query.order_by(Technique.mitre_id).all() + # Return [TechniqueMapper.to_entity(m) for m in models] return [TechniqueMapper.to_entity(m) for m in models] + # Define function list_by_ids def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: + """Return techniques matching the provided list of UUIDs. + + Args: + ids (list[uuid.UUID]): UUIDs of the techniques to retrieve. + + Returns: + list[TechniqueEntity]: Technique entities corresponding to the given IDs. + """ + # Check: not ids if not ids: + # Return [] return [] + # Assign models = ( models = ( self._session.query(Technique) + # Chain .filter() call .filter(Technique.id.in_(ids)) + # Chain .all() call .all() ) + # Return [TechniqueMapper.to_entity(m) for m in models] return [TechniqueMapper.to_entity(m) for m in models] # -- Batch queries (for scoring/heatmap) ------------------------------- def count_by_status(self) -> dict[TechniqueStatus, int]: + """Return a count of techniques grouped by their coverage status. + + Returns: + dict[TechniqueStatus, int]: Mapping of each status value to its technique count. 
+ """ + # Assign rows = ( rows = ( self._session.query( Technique.status_global, func.count(Technique.id), ) + # Chain .group_by() call .group_by(Technique.status_global) + # Chain .all() call .all() ) + # Assign result = {s: 0 for s in TechniqueStatus} result = {s: 0 for s in TechniqueStatus} + # Iterate over rows for status_val, count in rows: + # Assign key = ( key = ( status_val if isinstance(status_val, TechniqueStatus) else TechniqueStatus(status_val) ) + # Assign result[key] = count result[key] = count + # Return result return result + # Define function find_all_with_test_counts def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: - """Single query replacing the N+1 pattern. + """Return all techniques with pre-aggregated test and detection rule counts. - Returns all techniques with pre-aggregated test and detection - rule counts via subqueries. + Uses a single query with subqueries to avoid the N+1 pattern. + + Returns: + list[TechniqueWithCounts]: All techniques with their associated counts. 
""" + # Assign test_count_sq = ( test_count_sq = ( self._session.query( Test.technique_id, @@ -108,18 +213,24 @@ class SATechniqueRepository: func.cast(Test.state == TestState.validated, self._int_type()) ).label("validated_count"), ) + # Chain .group_by() call .group_by(Test.technique_id) + # Chain .subquery() call .subquery() ) + # Assign rule_count_sq = ( rule_count_sq = ( self._session.query( DetectionRule.mitre_technique_id, func.count(DetectionRule.id).label("rule_count"), ) + # Chain .group_by() call .group_by(DetectionRule.mitre_technique_id) + # Chain .subquery() call .subquery() ) + # Assign rows = ( rows = ( self._session.query( Technique, @@ -127,20 +238,29 @@ class SATechniqueRepository: func.coalesce(test_count_sq.c.validated_count, 0), func.coalesce(rule_count_sq.c.rule_count, 0), ) + # Chain .outerjoin() call .outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id) + # Chain .outerjoin() call .outerjoin( rule_count_sq, Technique.mitre_id == rule_count_sq.c.mitre_technique_id, ) + # Chain .order_by() call .order_by(Technique.mitre_id) + # Chain .all() call .all() ) + # Return [ return [ TechniqueWithCounts( + # Keyword argument: entity entity=TechniqueMapper.to_entity(tech), + # Keyword argument: test_count test_count=int(tc), + # Keyword argument: validated_test_count validated_test_count=int(vtc), + # Keyword argument: detection_rule_count detection_rule_count=int(rc), ) for tech, tc, vtc, rc in rows @@ -149,55 +269,112 @@ class SATechniqueRepository: # -- Mutations --------------------------------------------------------- def save(self, technique: TechniqueEntity) -> TechniqueEntity: + """Persist a technique entity, inserting or updating as needed. + + Args: + technique (TechniqueEntity): The domain entity to persist. + + Returns: + TechniqueEntity: The persisted entity reflecting the current DB state. 
+ """ + # Assign existing = ( existing = ( self._session.query(Technique) + # Chain .filter() call .filter(Technique.id == technique.id) + # Chain .first() call .first() ) + # Check: existing if existing: + # Call technique.apply_to() technique.apply_to(existing) + # Assign existing.mitre_id = technique.mitre_id existing.mitre_id = technique.mitre_id + # Assign existing.name = technique.name existing.name = technique.name + # Assign existing.tactic = technique.tactic existing.tactic = technique.tactic + # Assign existing.description = technique.description existing.description = technique.description + # Assign existing.platforms = technique.platforms existing.platforms = technique.platforms + # Assign existing.is_subtechnique = technique.is_subtechnique existing.is_subtechnique = technique.is_subtechnique + # Assign existing.parent_mitre_id = technique.parent_mitre_id existing.parent_mitre_id = technique.parent_mitre_id + # Assign existing.mitre_version = technique.mitre_version existing.mitre_version = technique.mitre_version + # Assign existing.mitre_last_modified = technique.mitre_last_modified existing.mitre_last_modified = technique.mitre_last_modified + # Call self._session.flush() self._session.flush() + # Return TechniqueMapper.to_entity(existing) return TechniqueMapper.to_entity(existing) + # Fallback: handle remaining cases else: + # Assign model = Technique( model = Technique( + # Keyword argument: id id=technique.id, + # Keyword argument: mitre_id mitre_id=technique.mitre_id, + # Keyword argument: name name=technique.name, + # Keyword argument: tactic tactic=technique.tactic, + # Keyword argument: description description=technique.description, + # Keyword argument: platforms platforms=technique.platforms, + # Keyword argument: is_subtechnique is_subtechnique=technique.is_subtechnique, + # Keyword argument: parent_mitre_id parent_mitre_id=technique.parent_mitre_id, + # Keyword argument: status_global status_global=technique.status_global, + # Keyword 
argument: review_required review_required=technique.review_required, + # Keyword argument: last_review_date last_review_date=technique.last_review_date, + # Keyword argument: mitre_version mitre_version=technique.mitre_version, + # Keyword argument: mitre_last_modified mitre_last_modified=technique.mitre_last_modified, ) + # Call self._session.add() self._session.add(model) + # Call self._session.flush() self._session.flush() + # Return TechniqueMapper.to_entity(model) return TechniqueMapper.to_entity(model) + # Define function exists_by_mitre_id def exists_by_mitre_id(self, mitre_id: str) -> bool: + """Check whether a technique with the given MITRE ID already exists. + + Args: + mitre_id (str): The MITRE ATT&CK identifier to look up. + + Returns: + bool: ``True`` if the technique exists, ``False`` otherwise. + """ + # Return ( return ( self._session.query(Technique.id) + # Chain .filter() call .filter(Technique.mitre_id == mitre_id) + # Chain .first() call .first() ) is not None # -- Internal ---------------------------------------------------------- @staticmethod - def _int_type(): + # Define function _int_type + def _int_type() -> type: """Return an Integer type for CAST expressions (SQLite-compatible).""" + # Import Integer from sqlalchemy from sqlalchemy import Integer + # Return Integer return Integer diff --git a/backend/app/infrastructure/persistence/repositories/sa_test_repository.py b/backend/app/infrastructure/persistence/repositories/sa_test_repository.py index 0a893f8..7a2a370 100644 --- a/backend/app/infrastructure/persistence/repositories/sa_test_repository.py +++ b/backend/app/infrastructure/persistence/repositories/sa_test_repository.py @@ -1,78 +1,163 @@ """SQLAlchemy implementation of TestRepository.""" +# Enable future language features for compatibility from __future__ import annotations +# Import uuid import uuid +# Import func from sqlalchemy from sqlalchemy import func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import 
Session +# Import TestState from app.domain.enums from app.domain.enums import TestState + +# Import Test from app.models.test from app.models.test import Test +# Define class SATestRepository class SATestRepository: """Concrete test repository backed by SQLAlchemy.""" + # Define function __init__ def __init__(self, session: Session) -> None: + """Initialise the repository with a caller-provided session. + + Args: + session (Session): The SQLAlchemy session to use for all queries. + """ + # Assign self._session = session self._session = session + # Define function find_by_id def find_by_id(self, test_id: uuid.UUID) -> Test | None: + """Return a single test by its primary key. + + Args: + test_id (uuid.UUID): The UUID primary key of the test. + + Returns: + Test | None: The ORM model instance, or ``None`` if not found. + """ + # Return ( return ( self._session.query(Test) + # Chain .filter() call .filter(Test.id == test_id) + # Chain .first() call .first() ) + # Define function list_by_technique def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]: + """Return all tests for a given technique, ordered by creation date. + + Args: + technique_id (uuid.UUID): The UUID of the parent technique. + + Returns: + list[Test]: ORM model instances ordered by ``created_at`` ascending. + """ + # Return ( return ( self._session.query(Test) + # Chain .filter() call .filter(Test.technique_id == technique_id) + # Chain .order_by() call .order_by(Test.created_at) + # Chain .all() call .all() ) + # Define function list_by_state def list_by_state(self, state: TestState) -> list[Test]: + """Return all tests that are currently in the given workflow state. + + Args: + state (TestState): The workflow state to filter on. + + Returns: + list[Test]: All ORM model instances with the specified state. 
+ """ + # Return ( return ( self._session.query(Test) + # Chain .filter() call .filter(Test.state == state) + # Chain .all() call .all() ) + # Define function count_by_technique_and_state def count_by_technique_and_state( self, + # Entry: technique_id technique_id: uuid.UUID, ) -> dict[TestState, int]: + """Return per-state test counts for a specific technique. + + Args: + technique_id (uuid.UUID): The UUID of the technique to aggregate for. + + Returns: + dict[TestState, int]: Mapping of each state to the number of tests in that state. + """ + # Assign rows = ( rows = ( self._session.query(Test.state, func.count(Test.id)) + # Chain .filter() call .filter(Test.technique_id == technique_id) + # Chain .group_by() call .group_by(Test.state) + # Chain .all() call .all() ) + # Assign result = {} result: dict[TestState, int] = {} + # Iterate over rows for state_val, count in rows: + # Assign key = ( key = ( state_val if isinstance(state_val, TestState) else TestState(state_val) ) + # Assign result[key] = count result[key] = count + # Return result return result + # Define function get_states_and_results_for_technique def get_states_and_results_for_technique( self, + # Entry: technique_id technique_id: uuid.UUID, ) -> list[tuple[str, str | None]]: - """Return lightweight (state, detection_result) pairs. + """Return lightweight ``(state, detection_result)`` pairs for a technique. - Used by TechniqueEntity.recalculate_status() without loading - full Test models. + Used by ``TechniqueEntity.recalculate_status()`` to avoid loading full + ``Test`` models. + + Args: + technique_id (uuid.UUID): The UUID of the technique to query. + + Returns: + list[tuple[str, str | None]]: Each tuple contains the state string + and the detection result string (or ``None``). 
""" + # Assign rows = ( rows = ( self._session.query(Test.state, Test.detection_result) + # Chain .filter() call .filter(Test.technique_id == technique_id) + # Chain .all() call .all() ) + # Return [ return [ ( r.state.value if hasattr(r.state, "value") else str(r.state), diff --git a/backend/app/infrastructure/redis_client.py b/backend/app/infrastructure/redis_client.py index 1c8f148..e112a3c 100644 --- a/backend/app/infrastructure/redis_client.py +++ b/backend/app/infrastructure/redis_client.py @@ -13,54 +13,79 @@ Usage:: get_redis_blacklist().setex("blacklist:…", ttl, "1") """ +# Enable future language features for compatibility from __future__ import annotations +# Import logging import logging + +# Import urlparse, urlunparse from urllib.parse from urllib.parse import urlparse, urlunparse +# Import redis import redis +# Import settings from app.config from app.config import settings +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign _clients = {} _clients: dict[str, redis.Redis] = {} +# Define function _redis_url_with_db def _redis_url_with_db(base_url: str, db_index: int) -> str: """Return *base_url* with its path replaced by ``/{db_index}``.""" + # Assign parsed = urlparse(base_url) parsed = urlparse(base_url) + # Assign path = f"/{db_index}" path = f"/{db_index}" + # Return urlunparse( return urlunparse( (parsed.scheme, parsed.netloc, path, "", "", ""), ) +# Define function _get_client def _get_client(url: str) -> redis.Redis: + # Check: url not in _clients if url not in _clients: + # Assign _clients[url] = redis.from_url(url, decode_responses=True) _clients[url] = redis.from_url(url, decode_responses=True) + # Log info: "Redis client connected to %s", url logger.info("Redis client connected to %s", url) + # Return _clients[url] return _clients[url] +# Define function get_redis def get_redis() -> redis.Redis: """Default Redis connection (URL from ``settings.REDIS_URL``).""" + # Return 
_get_client(settings.REDIS_URL) return _get_client(settings.REDIS_URL) +# Define function get_redis_blacklist def get_redis_blacklist() -> redis.Redis: """Redis DB used for JWT revocation (``jti`` keys with TTL).""" + # Assign url = _redis_url_with_db( url = _redis_url_with_db( settings.REDIS_URL, settings.REDIS_TOKEN_BLACKLIST_DB, ) + # Return _get_client(url) return _get_client(url) +# Define function get_redis_cache def get_redis_cache() -> redis.Redis: """Redis DB reserved for shared cache (scores, queues, etc.).""" + # Assign url = _redis_url_with_db( url = _redis_url_with_db( settings.REDIS_URL, settings.REDIS_CACHE_DB, ) + # Return _get_client(url) return _get_client(url) diff --git a/backend/app/jobs/__init__.py b/backend/app/jobs/__init__.py index e69de29..811cda3 100644 --- a/backend/app/jobs/__init__.py +++ b/backend/app/jobs/__init__.py @@ -0,0 +1 @@ +"""Background scheduler jobs (MITRE sync, Jira sync, data retention).""" diff --git a/backend/app/jobs/jira_sync_job.py b/backend/app/jobs/jira_sync_job.py index 8bc3cc5..1c8740b 100644 --- a/backend/app/jobs/jira_sync_job.py +++ b/backend/app/jobs/jira_sync_job.py @@ -1,37 +1,65 @@ """Scheduled job — syncs all Jira links hourly.""" +# Import logging import logging +# Import settings from app.config from app.config import settings + +# Import SessionLocal from app.database from app.database import SessionLocal + +# Import JiraLink from app.models.jira_link from app.models.jira_link import JiraLink + +# Import jira_service from app.services from app.services import jira_service +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Define function sync_all_jira_links def sync_all_jira_links() -> None: """Pull latest status from Jira for every stored link. Silently skips if ``JIRA_ENABLED`` is ``False``. Individual link failures are logged but do not abort the rest of the batch. 
""" + # Check: not settings.JIRA_ENABLED if not settings.JIRA_ENABLED: + # Return control to caller return + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign links = db.query(JiraLink).all() links = db.query(JiraLink).all() + # Assign synced = 0 synced = 0 + # Iterate over links for link in links: + # Attempt the following; catch errors below try: + # Call jira_service.sync_jira_to_aegis() jira_service.sync_jira_to_aegis(db, link) + # Assign synced = 1 synced += 1 + # Handle Exception except Exception as e: + # Log warning: "Jira sync failed for link %s: %s", link.id, e logger.warning("Jira sync failed for link %s: %s", link.id, e) + # Commit all pending changes to the database db.commit() + # Log info: "Jira sync completed: %d/%d links updated", synced logger.info("Jira sync completed: %d/%d links updated", synced, len(links)) + # Handle Exception except Exception: + # Log exception: "Jira sync batch job failed" logger.exception("Jira sync batch job failed") + # Always execute this cleanup block finally: + # Close the database session db.close() diff --git a/backend/app/jobs/mitre_sync_job.py b/backend/app/jobs/mitre_sync_job.py index 2c69e1f..f591441 100644 --- a/backend/app/jobs/mitre_sync_job.py +++ b/backend/app/jobs/mitre_sync_job.py @@ -10,22 +10,44 @@ Each job manages its own database session (created on entry, closed in sessions. 
""" +# Import logging import logging from datetime import datetime, timedelta, timezone +# Import BackgroundScheduler from apscheduler.schedulers.background from apscheduler.schedulers.background import BackgroundScheduler +# Import SessionLocal from app.database from app.database import SessionLocal -from app.services.mitre_sync_service import sync_mitre -from app.services.intel_service import scan_intel -from app.services.notification_service import cleanup_old_notifications -from app.services.snapshot_service import create_snapshot, cleanup_old_snapshots -from app.services.campaign_scheduler_service import check_and_run_recurring_campaigns + +# Import sync_all_jira_links from app.jobs.jira_sync_job from app.jobs.jira_sync_job import sync_all_jira_links -from app.services.osint_enrichment_service import enrich_all_techniques -from app.services.stale_detection_service import detect_stale_coverage + +# Import run_retention_job from app.jobs.retention_job from app.jobs.retention_job import run_retention_job +# Import check_and_run_recurring_campaigns from app.services.campaign_scheduler_service +from app.services.campaign_scheduler_service import check_and_run_recurring_campaigns + +# Import scan_intel from app.services.intel_service +from app.services.intel_service import scan_intel + +# Import sync_mitre from app.services.mitre_sync_service +from app.services.mitre_sync_service import sync_mitre + +# Import cleanup_old_notifications from app.services.notification_service +from app.services.notification_service import cleanup_old_notifications + +# Import enrich_all_techniques from app.services.osint_enrichment_service +from app.services.osint_enrichment_service import enrich_all_techniques + +# Import cleanup_old_snapshots, create_snapshot from app.services.snapshot_service +from app.services.snapshot_service import cleanup_old_snapshots, create_snapshot + +# Import detect_stale_coverage from app.services.stale_detection_service +from 
app.services.stale_detection_service import detect_stale_coverage + +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -44,60 +66,101 @@ def _run_mitre_sync() -> None: """Execute a MITRE sync inside its own DB session.""" from app.services.webhook_service import dispatch_webhook logger.info("Scheduled MITRE sync job starting...") + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign summary = sync_mitre(db) summary = sync_mitre(db) + # Log info: "Scheduled MITRE sync job finished — %s", summary logger.info("Scheduled MITRE sync job finished — %s", summary) dispatch_webhook("mitre.synced", {"created": summary.get("created", 0), "updated": summary.get("updated", 0)}) except Exception: + # Log exception: "Scheduled MITRE sync job failed" logger.exception("Scheduled MITRE sync job failed") + # Always execute this cleanup block finally: + # Close the database session db.close() +# Define function _run_notification_cleanup def _run_notification_cleanup() -> None: """Clean up old read notifications.""" + # Log info: "Scheduled notification cleanup job starting..." 
logger.info("Scheduled notification cleanup job starting...") + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign deleted = cleanup_old_notifications(db, days=90) deleted = cleanup_old_notifications(db, days=90) + # Log info: "Notification cleanup finished — deleted %d old no logger.info("Notification cleanup finished — deleted %d old notifications", deleted) + # Handle Exception except Exception: + # Log exception: "Notification cleanup job failed" logger.exception("Notification cleanup job failed") + # Always execute this cleanup block finally: + # Close the database session db.close() +# Define function _run_weekly_snapshot def _run_weekly_snapshot() -> None: """Create a weekly coverage snapshot and clean up old ones.""" + # Log info: "Scheduled weekly snapshot job starting..." logger.info("Scheduled weekly snapshot job starting...") + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign snapshot = create_snapshot(db, name="Auto-weekly") snapshot = create_snapshot(db, name="Auto-weekly") + # Log info: logger.info( + # Literal argument value "Weekly snapshot created — score %.1f, %d techniques", snapshot.organization_score, snapshot.total_techniques, ) + # Assign deleted = cleanup_old_snapshots(db, keep_last=52) deleted = cleanup_old_snapshots(db, keep_last=52) + # Check: deleted if deleted: + # Log info: "Cleaned up %d old snapshots", deleted logger.info("Cleaned up %d old snapshots", deleted) + # Handle Exception except Exception: + # Log exception: "Weekly snapshot job failed" logger.exception("Weekly snapshot job failed") + # Always execute this cleanup block finally: + # Close the database session db.close() +# Define function _run_recurring_campaigns def _run_recurring_campaigns() -> None: """Check and run any due recurring campaigns.""" + # Log info: "Scheduled recurring campaigns check starting..." 
logger.info("Scheduled recurring campaigns check starting...") + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign spawned = check_and_run_recurring_campaigns(db) spawned = check_and_run_recurring_campaigns(db) + # Log info: "Recurring campaigns check finished — spawned %d c logger.info("Recurring campaigns check finished — spawned %d campaigns", spawned) + # Handle Exception except Exception: + # Log exception: "Recurring campaigns check failed" logger.exception("Recurring campaigns check failed") + # Always execute this cleanup block finally: + # Close the database session db.close() @@ -193,14 +256,23 @@ def _run_scheduled_campaign_activation() -> None: def _run_intel_scan() -> None: """Execute an intel scan inside its own DB session.""" + # Log info: "Scheduled intel scan job starting..." logger.info("Scheduled intel scan job starting...") + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign summary = scan_intel(db) summary = scan_intel(db) + # Log info: "Scheduled intel scan job finished — %s", summary logger.info("Scheduled intel scan job finished — %s", summary) + # Handle Exception except Exception: + # Log exception: "Scheduled intel scan job failed" logger.exception("Scheduled intel scan job failed") + # Always execute this cleanup block finally: + # Close the database session db.close() @@ -283,14 +355,23 @@ def _run_evaluation_round_check() -> None: def _run_osint_enrichment() -> None: """Execute weekly OSINT enrichment inside its own DB session.""" + # Log info: "Scheduled OSINT enrichment job starting..." 
logger.info("Scheduled OSINT enrichment job starting...") + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign total = enrich_all_techniques(db) total = enrich_all_techniques(db) + # Log info: "OSINT enrichment finished — %d new items", total logger.info("OSINT enrichment finished — %d new items", total) + # Handle Exception except Exception: + # Log exception: "OSINT enrichment job failed" logger.exception("OSINT enrichment job failed") + # Always execute this cleanup block finally: + # Close the database session db.close() @@ -351,14 +432,23 @@ def _run_data_sources_sync() -> None: def _run_stale_detection() -> None: """Execute daily stale coverage detection inside its own DB session.""" + # Log info: "Scheduled stale coverage detection starting..." logger.info("Scheduled stale coverage detection starting...") + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign count = detect_stale_coverage(db) count = detect_stale_coverage(db) + # Log info: "Stale detection finished — %d techniques flagged" logger.info("Stale detection finished — %d techniques flagged", count) + # Handle Exception except Exception: + # Log exception: "Stale coverage detection job failed" logger.exception("Stale coverage detection job failed") + # Always execute this cleanup block finally: + # Close the database session db.close() @@ -424,40 +514,67 @@ def start_scheduler() -> None: Neither job fires immediately on startup. 
""" + # Call scheduler.add_job() scheduler.add_job( _run_mitre_sync, + # Keyword argument: trigger trigger="interval", + # Keyword argument: hours hours=24, + # Keyword argument: id id="mitre_sync", + # Keyword argument: name name="MITRE ATT&CK sync (every 24h)", + # Keyword argument: replace_existing replace_existing=True, ) + # Call scheduler.add_job() scheduler.add_job( _run_intel_scan, + # Keyword argument: trigger trigger="interval", + # Keyword argument: weeks weeks=1, + # Keyword argument: id id="intel_scan", + # Keyword argument: name name="Intel scan (every 7d)", + # Keyword argument: replace_existing replace_existing=True, ) + # Call scheduler.add_job() scheduler.add_job( _run_notification_cleanup, + # Keyword argument: trigger trigger="interval", + # Keyword argument: hours hours=24, + # Keyword argument: id id="notification_cleanup", + # Keyword argument: name name="Notification cleanup (daily)", + # Keyword argument: replace_existing replace_existing=True, ) + # Call scheduler.add_job() scheduler.add_job( _run_weekly_snapshot, + # Keyword argument: trigger trigger="cron", + # Keyword argument: day_of_week day_of_week="sun", + # Keyword argument: hour hour=0, + # Keyword argument: minute minute=0, + # Keyword argument: id id="weekly_snapshot", + # Keyword argument: name name="Weekly coverage snapshot (Sundays 00:00)", + # Keyword argument: replace_existing replace_existing=True, ) + # Call scheduler.add_job() scheduler.add_job( _run_scheduled_campaign_activation, trigger="interval", @@ -468,42 +585,71 @@ def start_scheduler() -> None: ) scheduler.add_job( _run_recurring_campaigns, + # Keyword argument: trigger trigger="interval", + # Keyword argument: hours hours=24, + # Keyword argument: id id="recurring_campaigns", + # Keyword argument: name name="Recurring campaigns check (daily)", + # Keyword argument: replace_existing replace_existing=True, ) + # Call scheduler.add_job() scheduler.add_job( sync_all_jira_links, + # Keyword argument: trigger 
trigger="interval", + # Keyword argument: hours hours=1, + # Keyword argument: id id="jira_sync", + # Keyword argument: name name="Jira link sync (hourly)", + # Keyword argument: replace_existing replace_existing=True, ) + # Call scheduler.add_job() scheduler.add_job( _run_osint_enrichment, + # Keyword argument: trigger trigger="interval", + # Keyword argument: weeks weeks=1, + # Keyword argument: id id="osint_enrichment", + # Keyword argument: name name="OSINT enrichment (weekly)", + # Keyword argument: replace_existing replace_existing=True, ) + # Call scheduler.add_job() scheduler.add_job( _run_stale_detection, + # Keyword argument: trigger trigger="interval", + # Keyword argument: hours hours=24, + # Keyword argument: id id="stale_detection", + # Keyword argument: name name="Stale coverage detection (daily)", + # Keyword argument: replace_existing replace_existing=True, ) + # Call scheduler.add_job() scheduler.add_job( run_retention_job, + # Keyword argument: trigger trigger="interval", + # Keyword argument: hours hours=24, + # Keyword argument: id id="retention_policies", + # Keyword argument: name name="Data retention policies (daily)", + # Keyword argument: replace_existing replace_existing=True, ) scheduler.add_job( @@ -551,10 +697,15 @@ def start_scheduler() -> None: replace_existing=True, ) scheduler.start() + # Log info: logger.info( + # Literal argument value "Background scheduler started — mitre_sync (24h), intel_scan (7d), " + # Literal argument value "notification_cleanup (24h), weekly_snapshot (Sundays 00:00), " + # Literal argument value "recurring_campaigns (daily), jira_sync (1h), " + # Literal argument value "osint_enrichment (weekly), stale_detection (daily), " "retention_policies (daily), data_sources_sync (6h), " "alert_evaluation (1h), attck_evaluation_check (Mondays 06:00)" diff --git a/backend/app/jobs/retention_job.py b/backend/app/jobs/retention_job.py index 1cd056e..5e79cbd 100644 --- a/backend/app/jobs/retention_job.py +++ 
b/backend/app/jobs/retention_job.py @@ -1,53 +1,89 @@ """Data retention policies — scheduled cleanup of aged records.""" +# Enable future language features for compatibility from __future__ import annotations +# Import logging import logging + +# Import datetime, timedelta, timezone from datetime from datetime import datetime, timedelta, timezone +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import SessionLocal from app.database from app.database import SessionLocal + +# Import AuditLog from app.models.audit from app.models.audit import AuditLog + +# Import cleanup_old_notifications from app.services.notification_service from app.services.notification_service import cleanup_old_notifications +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign AUDIT_LOG_RETENTION_DAYS = 730 AUDIT_LOG_RETENTION_DAYS = 730 +# Define function apply_retention_policies def apply_retention_policies(db: Session) -> dict[str, int]: """Apply retention rules. 
Commits the session before returning.""" + # Assign cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS) cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS) + # Assign deleted_audit = ( deleted_audit = ( db.query(AuditLog) + # Chain .filter() call .filter(AuditLog.timestamp < cutoff) + # Chain .delete() call .delete(synchronize_session=False) ) + # Check: deleted_audit if deleted_audit: + # Log info: logger.info( + # Literal argument value "Retention: deleted %d audit logs older than %d days", deleted_audit, AUDIT_LOG_RETENTION_DAYS, ) + # Assign deleted_notifications = cleanup_old_notifications(db, days=90) deleted_notifications = cleanup_old_notifications(db, days=90) + # Commit all pending changes to the database db.commit() + # Return { return { + # Literal argument value "audit_logs_deleted": deleted_audit, + # Literal argument value "notifications_deleted": deleted_notifications, } +# Define function run_retention_job def run_retention_job() -> None: """Entry point for the daily retention scheduler job.""" + # Log info: "Scheduled retention job starting..." 
logger.info("Scheduled retention job starting...") + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign summary = apply_retention_policies(db) summary = apply_retention_policies(db) + # Log info: "Retention job finished — %s", summary logger.info("Retention job finished — %s", summary) + # Handle Exception except Exception: + # Log exception: "Retention job failed" logger.exception("Retention job failed") + # Roll back all uncommitted changes db.rollback() + # Always execute this cleanup block finally: + # Close the database session db.close() diff --git a/backend/app/limiter.py b/backend/app/limiter.py index a2a2a99..3ac2acb 100644 --- a/backend/app/limiter.py +++ b/backend/app/limiter.py @@ -1,6 +1,10 @@ """Shared SlowAPI rate limiter for all routers.""" +# Import Limiter from slowapi from slowapi import Limiter + +# Import get_remote_address from slowapi.util from slowapi.util import get_remote_address +# Assign limiter = Limiter(key_func=get_remote_address) limiter = Limiter(key_func=get_remote_address) diff --git a/backend/app/logging_config.py b/backend/app/logging_config.py index 84dbb96..9c513b7 100644 --- a/backend/app/logging_config.py +++ b/backend/app/logging_config.py @@ -8,60 +8,101 @@ In **development** (default), uses a human-readable text format for comfortable local work. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import json import json + +# Import logging import logging + +# Import os import os + +# Import sys import sys + +# Import datetime, timezone from datetime from datetime import datetime, timezone +# Define class _JSONFormatter class _JSONFormatter(logging.Formatter): """Emit each log record as a single-line JSON object.""" + # Define function format def format(self, record: logging.LogRecord) -> str: + # Assign payload = { payload: dict = { + # Literal argument value "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(), + # Literal argument value "level": record.levelname, + # Literal argument value "logger": record.name, + # Literal argument value "message": record.getMessage(), } + # Check: record.exc_info and record.exc_info[1] is not None if record.exc_info and record.exc_info[1] is not None: + # Assign payload["exception"] = self.formatException(record.exc_info) payload["exception"] = self.formatException(record.exc_info) + # Assign extra = getattr(record, "_extra", None) extra = getattr(record, "_extra", None) + # Check: extra if extra: + # Call payload.update() payload.update(extra) + # Return json.dumps(payload, default=str) return json.dumps(payload, default=str) +# Assign _DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s" _DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s" +# Define function setup_logging def setup_logging() -> None: """Configure the root logger based on the environment.""" + # Assign is_production = os.environ.get("AEGIS_ENV", "").lower() == "production" is_production = os.environ.get("AEGIS_ENV", "").lower() == "production" + # Assign level_name = os.environ.get("LOG_LEVEL", "INFO").upper() level_name = os.environ.get("LOG_LEVEL", "INFO").upper() + # Assign level = getattr(logging, level_name, logging.INFO) level = getattr(logging, level_name, logging.INFO) + # Assign root = 
logging.getLogger() root = logging.getLogger() + # Call root.setLevel() root.setLevel(level) + # Check: root.handlers if root.handlers: + # Call root.handlers.clear() root.handlers.clear() + # Assign handler = logging.StreamHandler(sys.stdout) handler = logging.StreamHandler(sys.stdout) + # Call handler.setLevel() handler.setLevel(level) + # Check: is_production if is_production: + # Call handler.setFormatter() handler.setFormatter(_JSONFormatter()) + # Fallback: handle remaining cases else: + # Call handler.setFormatter() handler.setFormatter(logging.Formatter(_DEV_FORMAT)) + # Call root.addHandler() root.addHandler(handler) + # Call logging.getLogger() logging.getLogger("uvicorn.access").setLevel(logging.WARNING) + # Call logging.getLogger() logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING) diff --git a/backend/app/main.py b/backend/app/main.py index 1b42f5e..6d02982 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,13 +1,41 @@ +"""FastAPI application factory and global middleware/exception configuration. + +Builds the ``app`` instance, wires up CORS, rate limiting, domain-error +mapping, all API routers, and async lifespan hooks (MinIO bucket creation, +APScheduler startup/shutdown). 
+""" + +# Import logging import logging + +# Import os import os + +# Import AsyncGenerator from collections.abc +from collections.abc import AsyncGenerator + +# Import asynccontextmanager from contextlib from contextlib import asynccontextmanager +# Import FastAPI, Request, status from fastapi from fastapi import FastAPI, Request, status -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse + +# Import RequestValidationError from fastapi.exceptions from fastapi.exceptions import RequestValidationError + +# Import CORSMiddleware from fastapi.middleware.cors +from fastapi.middleware.cors import CORSMiddleware + +# Import JSONResponse from fastapi.responses +from fastapi.responses import JSONResponse + +# Import _rate_limit_exceeded_handler from slowapi from slowapi import _rate_limit_exceeded_handler + +# Import RateLimitExceeded from slowapi.errors from slowapi.errors import RateLimitExceeded + +# Import SQLAlchemyError from sqlalchemy.exc from sqlalchemy.exc import SQLAlchemyError from app.routers import auth as auth_router @@ -50,24 +78,127 @@ from app.routers import api_keys as api_keys_router from app.routers import sso as sso_router from app.routers import operational_alerts as alerts_router from app.domain.errors import DomainError -from app.middleware.error_handler import domain_exception_handler -from app.middleware.request_context import RequestContextMiddleware + +# Import scheduler, start_scheduler from app.jobs.mitre_sync_job +from app.jobs.mitre_sync_job import scheduler, start_scheduler + +# Import limiter from app.limiter from app.limiter import limiter + +# Import setup_logging from app.logging_config +from app.logging_config import setup_logging + +# Import domain_exception_handler from app.middleware.error_handler +from app.middleware.error_handler import domain_exception_handler + +# Import RequestContextMiddleware from app.middleware.request_context +from app.middleware.request_context import 
RequestContextMiddleware + +# Import advanced_metrics as advanced_metrics_router from app.routers +from app.routers import advanced_metrics as advanced_metrics_router + +# Import analytics as analytics_router from app.routers +from app.routers import analytics as analytics_router + +# Import audit as audit_router from app.routers +from app.routers import audit as audit_router + +# Import auth as auth_router from app.routers +from app.routers import auth as auth_router + +# Import campaigns as campaigns_router from app.routers +from app.routers import campaigns as campaigns_router + +# Import compliance as compliance_router from app.routers +from app.routers import compliance as compliance_router + +# Import d3fend as d3fend_router from app.routers +from app.routers import d3fend as d3fend_router + +# Import data_sources as data_sources_router from app.routers +from app.routers import data_sources as data_sources_router + +# Import detection_rules as detection_rules_router from app.routers +from app.routers import detection_rules as detection_rules_router + +# Import evidence as evidence_router from app.routers +from app.routers import evidence as evidence_router + +# Import heatmap as heatmap_router from app.routers +from app.routers import heatmap as heatmap_router + +# Import jira as jira_router from app.routers +from app.routers import jira as jira_router + +# Import metrics as metrics_router from app.routers +from app.routers import metrics as metrics_router + +# Import notifications as notifications_router from app.routers +from app.routers import notifications as notifications_router + +# Import operational_metrics as operational_metrics_router from app.routers +from app.routers import operational_metrics as operational_metrics_router + +# Import osint as osint_router from app.routers +from app.routers import osint as osint_router + +# Import professional_reports as professional_reports_ro... 
from app.routers +from app.routers import professional_reports as professional_reports_router + +# Import reports as reports_router from app.routers +from app.routers import reports as reports_router + +# Import scores as scores_router from app.routers +from app.routers import scores as scores_router + +# Import snapshots as snapshots_router from app.routers +from app.routers import snapshots as snapshots_router + +# Import system as system_router from app.routers +from app.routers import system as system_router + +# Import techniques as techniques_router from app.routers +from app.routers import techniques as techniques_router + +# Import test_templates as test_templates_router from app.routers +from app.routers import test_templates as test_templates_router + +# Import tests as tests_router from app.routers +from app.routers import tests as tests_router + +# Import threat_actors as threat_actors_router from app.routers +from app.routers import threat_actors as threat_actors_router + +# Import users as users_router from app.routers +from app.routers import users as users_router + +# Import worklogs as worklogs_router from app.routers +from app.routers import worklogs as worklogs_router + +# Import ensure_bucket_exists from app.storage from app.storage import ensure_bucket_exists -from app.jobs.mitre_sync_job import start_scheduler, scheduler + +# Configure structured logging before any module initialises its own logger +setup_logging() # ── Environment detection ───────────────────────────────────────────────── _IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production" -# ── Logging ─────────────────────────────────────────────────────────────── -from app.logging_config import setup_logging - -setup_logging() - +# Apply the @asynccontextmanager decorator @asynccontextmanager -async def lifespan(app: FastAPI): - """Startup / shutdown logic.""" +# Define async function lifespan +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + 
"""Manage application startup and shutdown lifecycle. + + Args: + app (FastAPI): The FastAPI application instance. + + Yields: + None: Control is yielded to the running application. + """ + # Call ensure_bucket_exists() ensure_bucket_exists() + # Call start_scheduler() start_scheduler() # Seed decay policies from app.database import SessionLocal @@ -95,17 +226,24 @@ async def lifespan(app: FastAPI): # ── In production, disable Swagger UI and ReDoc to hide API surface ────── app = FastAPI( + # Keyword argument: title title="Attack Coverage Platform", + # Keyword argument: lifespan lifespan=lifespan, + # Keyword argument: docs_url docs_url=None if _IS_PRODUCTION else "/docs", + # Keyword argument: redoc_url redoc_url=None if _IS_PRODUCTION else "/redoc", + # Keyword argument: openapi_url openapi_url=None if _IS_PRODUCTION else "/openapi.json", ) # ── Rate Limiter ────────────────────────────────────────────────────────── app.state.limiter = limiter +# Call app.add_exception_handler() app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) +# Call app.add_middleware() app.add_middleware(RequestContextMiddleware) @@ -130,49 +268,77 @@ app.add_middleware(NoCacheAPIMiddleware) app.add_exception_handler(DomainError, domain_exception_handler) # ── CORS ────────────────────────────────────────────────────────────────── -from app.config import settings as _settings - _cors_origins: list[str] = [ o.strip() for o in _settings.CORS_ORIGINS.split(",") if o.strip() ] +# Call app.add_middleware() app.add_middleware( CORSMiddleware, + # Keyword argument: allow_origins allow_origins=_cors_origins, + # Keyword argument: allow_credentials allow_credentials=True, + # Keyword argument: allow_methods allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + # Keyword argument: allow_headers allow_headers=["Authorization", "Content-Type"], ) # ── Routers ────────────────────────────────────────────────────────────── app.include_router(auth_router.router, 
prefix="/api/v1") +# Call app.include_router() app.include_router(techniques_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(tests_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(evidence_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(test_templates_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(system_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(metrics_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(users_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(audit_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(notifications_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(reports_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(data_sources_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(threat_actors_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(d3fend_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(detection_rules_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(campaigns_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(heatmap_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(scores_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(operational_metrics_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(compliance_router.router, prefix="/api/v1") app.include_router(intel_router.router, prefix="/api/v1") app.include_router(admin_config_router.router, prefix="/api/v1") app.include_router(snapshots_router.router, prefix="/api/v1") +# Call app.include_router() 
app.include_router(jira_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(worklogs_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(professional_reports_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(analytics_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(advanced_metrics_router.router, prefix="/api/v1") +# Call app.include_router() app.include_router(osint_router.router, prefix="/api/v1") app.include_router(webhooks_router.router, prefix="/api/v1") app.include_router(detection_lifecycle_router.router, prefix="/api/v1") @@ -186,13 +352,19 @@ app.include_router(sso_router.router, prefix="/api/v1") app.include_router(alerts_router.router, prefix="/api/v1") +# Apply the @app.get decorator @app.get("/health", include_in_schema=False) -def health(): - """Minimal health check — returns only an HTTP 200 with no service metadata. +# Define function health +def health() -> dict[str, str]: + """Return a minimal liveness probe response. Access is restricted to internal networks at the Nginx level (see ``frontend/nginx.conf``). + + Returns: + dict[str, str]: A dict with ``{"status": "ok"}``. """ + # Return {"status": "ok"} return {"status": "ok"} @@ -200,51 +372,117 @@ def health(): def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]: - """Return validation errors safe for JSON (no raw exception objects).""" + """Return validation errors safe for JSON serialization. + + Converts non-serializable values inside ``ctx`` dictionaries to strings + so the response body can be safely encoded. + + Args: + exc (RequestValidationError): The Pydantic validation exception. + + Returns: + list[dict]: A list of sanitised error detail dictionaries. 
+ """ + # Assign serialized = [] serialized: list[dict] = [] + # Iterate over exc.errors() for err in exc.errors(): + # Assign item = dict(err) item = dict(err) + # Assign ctx = item.get("ctx") ctx = item.get("ctx") + # Check: isinstance(ctx, dict) if isinstance(ctx, dict): + # Assign item["ctx"] = {key: str(value) for key, value in ctx.items()} item["ctx"] = {key: str(value) for key, value in ctx.items()} + # Call serialized.append() serialized.append(item) + # Return serialized return serialized +# Apply the @app.exception_handler decorator @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError): - """Handle validation errors with consistent format.""" +# Define async function validation_exception_handler +async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: + """Handle Pydantic validation errors and return a structured 422 response. + + Args: + request (Request): The incoming HTTP request. + exc (RequestValidationError): The caught validation exception. + + Returns: + JSONResponse: A 422 response with a ``VALIDATION_ERROR`` code and error details. + """ + # Return JSONResponse( return JSONResponse( + # Keyword argument: status_code status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + # Keyword argument: content content={ + # Literal argument value "detail": "Validation error", + # Literal argument value "code": "VALIDATION_ERROR", + # Literal argument value "errors": _serialize_validation_errors(exc), }, ) +# Apply the @app.exception_handler decorator @app.exception_handler(SQLAlchemyError) -async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError): - """Handle database errors.""" +# Define async function sqlalchemy_exception_handler +async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError) -> JSONResponse: + """Handle SQLAlchemy database errors and return a structured 500 response. 
+ + Args: + request (Request): The incoming HTTP request. + exc (SQLAlchemyError): The caught SQLAlchemy exception. + + Returns: + JSONResponse: A 500 response with a ``DATABASE_ERROR`` code. + """ + # Log error: f"Database error: {exc}" logging.error(f"Database error: {exc}") + # Return JSONResponse( return JSONResponse( + # Keyword argument: status_code status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + # Keyword argument: content content={ + # Literal argument value "detail": "Database error occurred", + # Literal argument value "code": "DATABASE_ERROR", }, ) +# Apply the @app.exception_handler decorator @app.exception_handler(Exception) -async def general_exception_handler(request: Request, exc: Exception): - """Handle all unhandled exceptions.""" +# Define async function general_exception_handler +async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse: + """Handle all otherwise-unhandled exceptions and return a structured 500 response. + + Args: + request (Request): The incoming HTTP request. + exc (Exception): The unhandled exception. + + Returns: + JSONResponse: A 500 response with an ``INTERNAL_ERROR`` code. 
+ """ + # Log error: f"Unhandled exception: {exc}" logging.error(f"Unhandled exception: {exc}") + # Return JSONResponse( return JSONResponse( + # Keyword argument: status_code status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + # Keyword argument: content content={ + # Literal argument value "detail": "An internal server error occurred", + # Literal argument value "code": "INTERNAL_ERROR", }, ) diff --git a/backend/app/middleware/__init__.py b/backend/app/middleware/__init__.py index e69de29..966c57b 100644 --- a/backend/app/middleware/__init__.py +++ b/backend/app/middleware/__init__.py @@ -0,0 +1 @@ +"""ASGI middleware components for request context, error handling, and rate limiting.""" diff --git a/backend/app/middleware/error_handler.py b/backend/app/middleware/error_handler.py index b815da9..28ca79d 100644 --- a/backend/app/middleware/error_handler.py +++ b/backend/app/middleware/error_handler.py @@ -5,9 +5,13 @@ domain-layer errors into structured JSON responses, keeping the service layer free from FastAPI's ``HTTPException``. 
""" +# Import Request from fastapi from fastapi import Request + +# Import JSONResponse from fastapi.responses from fastapi.responses import JSONResponse +# Import from app.domain.errors from app.domain.errors import ( BusinessRuleViolation, DomainError, @@ -18,28 +22,45 @@ from app.domain.errors import ( PermissionViolation, ) +# Assign EXCEPTION_STATUS_MAP = { EXCEPTION_STATUS_MAP: dict[type[DomainError], int] = { + # Entry: EntityNotFoundError EntityNotFoundError: 404, + # Entry: DuplicateEntityError DuplicateEntityError: 409, + # Entry: InvalidStateTransition InvalidStateTransition: 400, + # Entry: InvalidOperationError InvalidOperationError: 400, + # Entry: BusinessRuleViolation BusinessRuleViolation: 400, + # Entry: PermissionViolation PermissionViolation: 403, } +# Define async function domain_exception_handler async def domain_exception_handler( + # Entry: request request: Request, + # Entry: exc exc: DomainError, ) -> JSONResponse: """Convert a :class:`DomainError` into a JSON error response.""" + # Assign status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400) status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400) + # Assign content = {"detail": exc.message, "code": exc.code} content: dict = {"detail": exc.message, "code": exc.code} + # Check: isinstance(exc, InvalidStateTransition) if isinstance(exc, InvalidStateTransition): + # Assign content["current_state"] = exc.current_state content["current_state"] = exc.current_state + # Assign content["target_state"] = exc.target_state content["target_state"] = exc.target_state + # Assign content["valid_transitions"] = exc.valid_transitions content["valid_transitions"] = exc.valid_transitions + # Return JSONResponse(status_code=status_code, content=content) return JSONResponse(status_code=status_code, content=content) diff --git a/backend/app/middleware/request_context.py b/backend/app/middleware/request_context.py index f49ef57..30c01a4 100644 --- a/backend/app/middleware/request_context.py +++ 
b/backend/app/middleware/request_context.py @@ -1,26 +1,74 @@ """Request context middleware — captures client IP and User-Agent per request.""" +# Import Awaitable, Callable from collections.abc +from collections.abc import Awaitable, Callable + +# Import ContextVar from contextvars from contextvars import ContextVar +# Import Request from fastapi from fastapi import Request + +# Import BaseHTTPMiddleware from starlette.middleware.base from starlette.middleware.base import BaseHTTPMiddleware +# Import Response from starlette.responses +from starlette.responses import Response + +# Assign request_ip = ContextVar("request_ip", default="") request_ip: ContextVar[str] = ContextVar("request_ip", default="") +# Assign request_user_agent = ContextVar("request_user_agent", default="") request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="") +# Define function resolve_client_ip def resolve_client_ip(request: Request) -> str: - """Extract the client IP, honouring ``X-Forwarded-For`` when present.""" + """Extract the real client IP, honouring ``X-Forwarded-For`` when present. + + Args: + request (Request): The incoming Starlette/FastAPI request. + + Returns: + str: The resolved client IP address, or ``"unknown"`` when unavailable. 
+ """ + # Assign forwarded = request.headers.get("X-Forwarded-For") forwarded = request.headers.get("X-Forwarded-For") + # Check: forwarded if forwarded: + # Return forwarded.split(",")[0].strip() return forwarded.split(",")[0].strip() + # Check: request.client if request.client: + # Return request.client.host return request.client.host + # Return "unknown" return "unknown" +# Define class RequestContextMiddleware class RequestContextMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): + """Middleware that captures client IP and User-Agent into context variables.""" + + # Define async function dispatch + async def dispatch( + self, + # Entry: request + request: Request, + # Entry: call_next + call_next: Callable[[Request], Awaitable[Response]], + ) -> Response: + """Store client IP and User-Agent in context vars for the current request. + + Args: + request (Request): The incoming HTTP request. + call_next (Callable[[Request], Awaitable[Response]]): The next middleware or route handler. + + Returns: + Response: The HTTP response produced by the downstream handler. 
+ """ + # Call request_ip.set() request_ip.set(resolve_client_ip(request)) + # Call request_user_agent.set() request_user_agent.set(request.headers.get("User-Agent", "")) + # Return await call_next(request) return await call_next(request) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index e094e4c..91a0743 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,10 +1,5 @@ +"""SQLAlchemy ORM model definitions for all database tables.""" # Import all models here so Alembic can detect them -from app.models.user import User -from app.models.technique import Technique -from app.models.test import Test -from app.models.test_template import TestTemplate -from app.models.evidence import Evidence -from app.models.intel import IntelItem from app.models.audit import AuditLog from app.models.notification import Notification from app.models.data_source import DataSource @@ -45,17 +40,96 @@ from app.models.api_key import ApiKey from app.models.sso_config import SsoConfig from app.models.operational_alert import AlertRule, AlertInstance +# Import Campaign, CampaignTest from app.models.campaign +from app.models.campaign import Campaign, CampaignTest + +# Import from app.models.compliance +from app.models.compliance import ( + ComplianceControl, + ComplianceControlMapping, + ComplianceFramework, +) + +# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot +from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState + +# Import DataSource from app.models.data_source +from app.models.data_source import DataSource + +# Import DefensiveTechnique, DefensiveTechniqueMapping from app.models.defensive_technique +from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping + +# Import DetectionRule from app.models.detection_rule +from app.models.detection_rule import DetectionRule + +# Import TeamSide, TechniqueStatus, TestResult, TestState from 
app.models.enums +from app.models.enums import TeamSide, TechniqueStatus, TestResult, TestState + +# Import Evidence from app.models.evidence +from app.models.evidence import Evidence + +# Import IntelItem from app.models.intel +from app.models.intel import IntelItem + +# Import JiraLink, JiraLinkEntityType, JiraSyncDirection from app.models.jira_link +from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection + +# Import Notification from app.models.notification +from app.models.notification import Notification + +# Import OsintItem from app.models.osint_item +from app.models.osint_item import OsintItem + +# Import ScoringConfig from app.models.scoring_config +from app.models.scoring_config import ScoringConfig + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import Test from app.models.test +from app.models.test import Test + +# Import TestDetectionResult from app.models.test_detection_result +from app.models.test_detection_result import TestDetectionResult + +# Import TestTemplate from app.models.test_template +from app.models.test_template import TestTemplate + +# Import TestTemplateDetectionRule from app.models.test_template_detection_rule +from app.models.test_template_detection_rule import TestTemplateDetectionRule + +# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor +from app.models.threat_actor import ThreatActor, ThreatActorTechnique + +# Import User from app.models.user +from app.models.user import User + +# Import Worklog from app.models.worklog +from app.models.worklog import Worklog + +# Assign __all__ = [ __all__ = [ + # Literal argument value "User", "Technique", "Test", "TestTemplate", "Evidence", + # Literal argument value "IntelItem", "AuditLog", "Notification", "DataSource", + # Literal argument value "DetectionRule", "ThreatActor", "ThreatActorTechnique", + # Literal argument value "DefensiveTechnique", "DefensiveTechniqueMapping", + # Literal argument 
value "TestTemplateDetectionRule", "TestDetectionResult", + # Literal argument value "Campaign", "CampaignTest", + # Literal argument value "ComplianceFramework", "ComplianceControl", "ComplianceControlMapping", + # Literal argument value "CoverageSnapshot", "SnapshotTechniqueState", + # Literal argument value "JiraLink", "JiraLinkEntityType", "JiraSyncDirection", + # Literal argument value "Worklog", "OsintItem", "ScoringConfig", + # Literal argument value "TechniqueStatus", "TestState", "TestResult", "TeamSide", "WebhookConfig", "SystemConfig", "DetectionAsset", "DetectionTechniqueMapping", "DetectionValidation", diff --git a/backend/app/models/audit.py b/backend/app/models/audit.py index dda16b5..5571829 100644 --- a/backend/app/models/audit.py +++ b/backend/app/models/audit.py @@ -1,35 +1,58 @@ +"""SQLAlchemy model for the audit log table.""" + +# Import uuid import uuid -from sqlalchemy import Column, String, DateTime, ForeignKey, Index, func -from sqlalchemy.dialects.postgresql import UUID, JSONB + +# Import Column, DateTime, ForeignKey, Index, String, func from sqlalchemy +from sqlalchemy import Column, DateTime, ForeignKey, Index, String, func + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class AuditLog class AuditLog(Base): - """ - Audit log model for tracking all system actions. - + """Audit log model for tracking all system actions. + Records user actions, entity changes, and system events for security auditing and compliance purposes. 
""" + # Assign __tablename__ = "audit_logs" __tablename__ = "audit_logs" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + # Assign action = Column(String, nullable=False) action = Column(String, nullable=False) + # Assign entity_type = Column(String, nullable=True) entity_type = Column(String, nullable=True) + # Assign entity_id = Column(String, nullable=True) entity_id = Column(String, nullable=True) + # Assign timestamp = Column(DateTime(timezone=True), server_default=func.now()) timestamp = Column(DateTime(timezone=True), server_default=func.now()) + # Assign details = Column(JSONB, nullable=True) details = Column(JSONB, nullable=True) + # Assign ip_address = Column(String(45), nullable=True) ip_address = Column(String(45), nullable=True) + # Assign user_agent = Column(String(500), nullable=True) user_agent = Column(String(500), nullable=True) + # Assign integrity_hash = Column(String(64), nullable=True) integrity_hash = Column(String(64), nullable=True) + # Assign session_id = Column(String(100), nullable=True) session_id = Column(String(100), nullable=True) # Relationships user = relationship("User") + # Assign __table_args__ = ( __table_args__ = ( Index("ix_audit_logs_entity", "entity_type", "entity_id"), Index("ix_audit_logs_timestamp", "timestamp"), diff --git a/backend/app/models/campaign.py b/backend/app/models/campaign.py index 957f8f2..0d4f8b2 100644 --- a/backend/app/models/campaign.py +++ b/backend/app/models/campaign.py @@ -4,20 +4,35 @@ Campaigns group multiple tests into a kill chain sequence, enabling simulation of complete attack chains and APT emulations. 
""" +# Import uuid import uuid + +# Import from sqlalchemy from sqlalchemy import ( - Column, String, Text, Integer, Boolean, DateTime, - ForeignKey, Index, func, + Boolean, + Column, + DateTime, + ForeignKey, + Index, + Integer, + String, + Text, + func, ) -from sqlalchemy.dialects.postgresql import UUID, JSONB + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class Campaign class Campaign(Base): - """ - A campaign groups multiple tests into a sequenced attack chain. + """A campaign groups multiple tests into a sequenced attack chain. Types: - custom: manually created campaign @@ -31,62 +46,97 @@ class Campaign(Base): - completed: all tests done - archived: historical record """ + # Assign __tablename__ = "campaigns" __tablename__ = "campaigns" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign name = Column(String, nullable=False) name = Column(String, nullable=False) + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign type = Column(String, nullable=False, default="custom") # custom, ap... type = Column(String, nullable=False, default="custom") # custom, apt_emulation, kill_chain, compliance + # Assign threat_actor_id = Column( threat_actor_id = Column( UUID(as_uuid=True), ForeignKey("threat_actors.id", ondelete="SET NULL"), + # Keyword argument: nullable nullable=True, ) + # Assign status = Column(String, nullable=False, default="draft") # draft, activ... 
status = Column(String, nullable=False, default="draft") # draft, active, completed, archived + # Assign created_by = Column( created_by = Column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), + # Keyword argument: nullable nullable=True, ) start_date = Column(DateTime, nullable=True) # campaign won't activate before this date scheduled_at = Column(DateTime, nullable=True) + # Assign completed_at = Column(DateTime, nullable=True) completed_at = Column(DateTime, nullable=True) + # Assign target_platform = Column(String, nullable=True) target_platform = Column(String, nullable=True) + # Assign tags = Column(JSONB, nullable=True, default=[]) tags = Column(JSONB, nullable=True, default=[]) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) + # Assign data_classification = Column(String(20), nullable=False, server_default="internal") data_classification = Column(String(20), nullable=False, server_default="internal") # Recurring scheduling fields is_recurring = Column(Boolean, default=False) + # Assign recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly + # Assign next_run_at = Column(DateTime, nullable=True) next_run_at = Column(DateTime, nullable=True) + # Assign last_run_at = Column(DateTime, nullable=True) last_run_at = Column(DateTime, nullable=True) + # Assign parent_campaign_id = Column( parent_campaign_id = Column( UUID(as_uuid=True), ForeignKey("campaigns.id", ondelete="SET NULL"), + # Keyword argument: nullable nullable=True, ) # Relationships threat_actor = relationship("ThreatActor") + # Assign creator = relationship("User", foreign_keys=[created_by]) creator = relationship("User", foreign_keys=[created_by]) + # Assign campaign_tests = relationship( campaign_tests = relationship( + # Literal argument value "CampaignTest", + # 
Keyword argument: back_populates back_populates="campaign", + # Keyword argument: cascade cascade="all, delete-orphan", + # Keyword argument: order_by order_by="CampaignTest.order_index", ) + # Assign parent_campaign = relationship( parent_campaign = relationship( + # Literal argument value "Campaign", + # Keyword argument: remote_side remote_side="Campaign.id", + # Keyword argument: foreign_keys foreign_keys=[parent_campaign_id], ) + # Assign child_campaigns = relationship( child_campaigns = relationship( + # Literal argument value "Campaign", + # Keyword argument: foreign_keys foreign_keys=[parent_campaign_id], + # Keyword argument: back_populates back_populates="parent_campaign", ) + # Assign __table_args__ = ( __table_args__ = ( Index('ix_campaigns_status', 'status'), Index('ix_campaigns_type', 'type'), @@ -98,56 +148,83 @@ class Campaign(Base): # Kill chain phases in order (for sorting and validation) KILL_CHAIN_PHASES = [ + # Literal argument value "reconnaissance", + # Literal argument value "resource_development", + # Literal argument value "initial_access", + # Literal argument value "execution", + # Literal argument value "persistence", + # Literal argument value "privilege_escalation", + # Literal argument value "defense_evasion", + # Literal argument value "credential_access", + # Literal argument value "discovery", + # Literal argument value "lateral_movement", + # Literal argument value "collection", + # Literal argument value "command_and_control", + # Literal argument value "exfiltration", + # Literal argument value "impact", ] +# Define class CampaignTest class CampaignTest(Base): - """ - A test within a campaign, with ordering and dependency information. + """A test within a campaign, with ordering and dependency information. ``depends_on`` creates a self-referential chain (A -> B -> C). Circular dependencies are validated at the service layer. 
""" + # Assign __tablename__ = "campaign_tests" __tablename__ = "campaign_tests" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign campaign_id = Column( campaign_id = Column( UUID(as_uuid=True), ForeignKey("campaigns.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign test_id = Column( test_id = Column( UUID(as_uuid=True), ForeignKey("tests.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign order_index = Column(Integer, nullable=False, default=0) order_index = Column(Integer, nullable=False, default=0) + # Assign depends_on = Column( depends_on = Column( UUID(as_uuid=True), ForeignKey("campaign_tests.id", ondelete="SET NULL"), + # Keyword argument: nullable nullable=True, ) + # Assign phase = Column(String, nullable=True) # kill chain phase phase = Column(String, nullable=True) # kill chain phase # Relationships campaign = relationship("Campaign", back_populates="campaign_tests") + # Assign test = relationship("Test") test = relationship("Test") + # Assign dependency = relationship("CampaignTest", remote_side="CampaignTest.id") dependency = relationship("CampaignTest", remote_side="CampaignTest.id") + # Assign __table_args__ = ( __table_args__ = ( Index('ix_campaign_tests_campaign', 'campaign_id'), Index('ix_campaign_tests_test', 'test_id'), diff --git a/backend/app/models/compliance.py b/backend/app/models/compliance.py index 4ad2a8e..bffd282 100644 --- a/backend/app/models/compliance.py +++ b/backend/app/models/compliance.py @@ -4,92 +4,145 @@ Maps compliance frameworks (NIST 800-53, DORA, NIS2, ISO 27001) to MITRE ATT&CK techniques, enabling compliance gap analysis. 
""" +# Import uuid import uuid + +# Import from sqlalchemy from sqlalchemy import ( - Column, String, Text, Boolean, DateTime, - ForeignKey, Index, UniqueConstraint, func, + Boolean, + Column, + DateTime, + ForeignKey, + Index, + String, + Text, + UniqueConstraint, + func, ) + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class ComplianceFramework class ComplianceFramework(Base): """A compliance framework (e.g. NIST 800-53, ISO 27001).""" + # Assign __tablename__ = "compliance_frameworks" __tablename__ = "compliance_frameworks" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign name = Column(String, unique=True, nullable=False) name = Column(String, unique=True, nullable=False) + # Assign version = Column(String, nullable=True) version = Column(String, nullable=True) + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign url = Column(String, nullable=True) url = Column(String, nullable=True) + # Assign is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) # Relationships controls = relationship( + # Literal argument value "ComplianceControl", + # Keyword argument: back_populates back_populates="framework", + # Keyword argument: cascade cascade="all, delete-orphan", ) +# Define class ComplianceControl class ComplianceControl(Base): """A control within a compliance framework (e.g. 
AC-2, PR.AC-1).""" + # Assign __tablename__ = "compliance_controls" __tablename__ = "compliance_controls" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign framework_id = Column( framework_id = Column( UUID(as_uuid=True), ForeignKey("compliance_frameworks.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign control_id = Column(String, nullable=False) # e.g. "AC-2" control_id = Column(String, nullable=False) # e.g. "AC-2" + # Assign title = Column(String, nullable=False) title = Column(String, nullable=False) + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign category = Column(String, nullable=True) category = Column(String, nullable=True) # Relationships framework = relationship("ComplianceFramework", back_populates="controls") + # Assign technique_mappings = relationship( technique_mappings = relationship( + # Literal argument value "ComplianceControlMapping", + # Keyword argument: back_populates back_populates="compliance_control", + # Keyword argument: cascade cascade="all, delete-orphan", ) + # Assign __table_args__ = ( __table_args__ = ( Index('ix_compliance_controls_framework', 'framework_id'), ) +# Define class ComplianceControlMapping class ComplianceControlMapping(Base): """Maps a compliance control to a MITRE ATT&CK technique.""" + # Assign __tablename__ = "compliance_control_mappings" __tablename__ = "compliance_control_mappings" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign compliance_control_id = Column( compliance_control_id = Column( UUID(as_uuid=True), ForeignKey("compliance_controls.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign technique_id = Column( technique_id = Column( UUID(as_uuid=True), 
ForeignKey("techniques.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) # Relationships compliance_control = relationship( + # Literal argument value "ComplianceControl", back_populates="technique_mappings" ) + # Assign technique = relationship("Technique") technique = relationship("Technique") + # Assign __table_args__ = ( __table_args__ = ( Index('ix_compliance_mappings_control', 'compliance_control_id'), Index('ix_compliance_mappings_technique', 'technique_id'), UniqueConstraint( + # Literal argument value 'compliance_control_id', 'technique_id', + # Keyword argument: name name='uq_control_technique', ), ) diff --git a/backend/app/models/coverage_snapshot.py b/backend/app/models/coverage_snapshot.py index 42fed60..9397643 100644 --- a/backend/app/models/coverage_snapshot.py +++ b/backend/app/models/coverage_snapshot.py @@ -5,76 +5,125 @@ SnapshotTechniqueState stores per-technique state (normalized, one row per technique per snapshot) to avoid bloated JSONB fields. """ +# Import uuid import uuid + +# Import from sqlalchemy from sqlalchemy import ( - Column, String, Float, Integer, DateTime, - ForeignKey, Index, func, + Column, + DateTime, + Float, + ForeignKey, + Index, + Integer, + String, + func, ) + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class CoverageSnapshot class CoverageSnapshot(Base): """A point-in-time snapshot of the organisation's overall coverage.""" + # Assign __tablename__ = "coverage_snapshots" __tablename__ = "coverage_snapshots" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign name = Column(String, nullable=True) # e.g. 
"Pre-remediación Q1" name = Column(String, nullable=True) # e.g. "Pre-remediación Q1" + # Assign organization_score = Column(Float, nullable=False) organization_score = Column(Float, nullable=False) + # Assign total_techniques = Column(Integer, nullable=False) total_techniques = Column(Integer, nullable=False) + # Assign validated_count = Column(Integer, nullable=False) validated_count = Column(Integer, nullable=False) + # Assign partial_count = Column(Integer, nullable=False) partial_count = Column(Integer, nullable=False) + # Assign not_covered_count = Column(Integer, nullable=False) not_covered_count = Column(Integer, nullable=False) + # Assign in_progress_count = Column(Integer, nullable=False) in_progress_count = Column(Integer, nullable=False) + # Assign not_evaluated_count = Column(Integer, nullable=False) not_evaluated_count = Column(Integer, nullable=False) + # Assign coverage_percentage = Column(Float, nullable=False, default=0.0) coverage_percentage = Column(Float, nullable=False, default=0.0) + # Assign by_tactic = Column(JSONB, nullable=False, default=dict) by_tactic = Column(JSONB, nullable=False, default=dict) + # Assign by_status = Column(JSONB, nullable=False, default=dict) by_status = Column(JSONB, nullable=False, default=dict) + # Assign stale_count = Column(Integer, nullable=False, default=0) stale_count = Column(Integer, nullable=False, default=0) + # Assign never_tested_count = Column(Integer, nullable=False, default=0) never_tested_count = Column(Integer, nullable=False, default=0) + # Assign created_by = Column( created_by = Column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), + # Keyword argument: nullable nullable=True, ) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) # Relationships creator = relationship("User", foreign_keys=[created_by]) + # Assign technique_states = relationship( technique_states = 
relationship( + # Literal argument value "SnapshotTechniqueState", + # Keyword argument: back_populates back_populates="snapshot", + # Keyword argument: cascade cascade="all, delete-orphan", ) +# Define class SnapshotTechniqueState class SnapshotTechniqueState(Base): """Per-technique state within a snapshot (normalised storage).""" + # Assign __tablename__ = "snapshot_technique_states" __tablename__ = "snapshot_technique_states" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign snapshot_id = Column( snapshot_id = Column( UUID(as_uuid=True), ForeignKey("coverage_snapshots.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign technique_id = Column( technique_id = Column( UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign mitre_id = Column(String, nullable=False) # denormalised for fast queries mitre_id = Column(String, nullable=False) # denormalised for fast queries + # Assign status = Column(String, nullable=False) status = Column(String, nullable=False) + # Assign score = Column(Float, nullable=True) score = Column(Float, nullable=True) # Relationships snapshot = relationship("CoverageSnapshot", back_populates="technique_states") + # Assign technique = relationship("Technique") technique = relationship("Technique") + # Assign __table_args__ = ( __table_args__ = ( Index("ix_snapshot_technique_states_snapshot", "snapshot_id"), Index("ix_snapshot_technique_states_technique", "technique_id"), diff --git a/backend/app/models/data_source.py b/backend/app/models/data_source.py index 60e0da3..609a2cb 100644 --- a/backend/app/models/data_source.py +++ b/backend/app/models/data_source.py @@ -1,36 +1,56 @@ """DataSource model — registry of external data sources for import.""" +# Import uuid import uuid -from sqlalchemy import Column, String, Text, Boolean, 
DateTime, Index, func -from sqlalchemy.dialects.postgresql import UUID, JSONB +# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy +from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import Base from app.database from app.database import Base +# Define class DataSource class DataSource(Base): - """ - Unified registry of all external data sources (attack procedures, - detection rules, threat intel, defensive techniques). + """Unified registry of all external data sources. - Each source can be independently enabled/disabled and tracks its own - synchronisation state. + Covers attack procedures, detection rules, threat intel, and defensive techniques. + Each source can be independently enabled/disabled and tracks its own synchronisation state. """ + # Assign __tablename__ = "data_sources" __tablename__ = "data_sources" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign name = Column(String, unique=True, nullable=False) # e.g. "atom... name = Column(String, unique=True, nullable=False) # e.g. "atomic_red_team" + # Assign display_name = Column(String, nullable=False) # e.g. "Atomic Red ... display_name = Column(String, nullable=False) # e.g. "Atomic Red Team" - type = Column(String, nullable=False) # attack_procedure / detection_rule / threat_intel / defensive_technique + # Values: attack_procedure / detection_rule / threat_intel / defensive_technique + type = Column(String, nullable=False) + # Assign url = Column(String, nullable=True) # URL base... 
url = Column(String, nullable=True) # URL base of repo/API + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign is_enabled = Column(Boolean, default=True) is_enabled = Column(Boolean, default=True) + # Assign last_sync_at = Column(DateTime, nullable=True) last_sync_at = Column(DateTime, nullable=True) + # Assign last_sync_status = Column(String, nullable=True) # success / error / in_... last_sync_status = Column(String, nullable=True) # success / error / in_progress + # Assign last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "upd... last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "updated": Y, ...} + # Assign sync_frequency = Column(String, nullable=True) # daily / weekly / mo... sync_frequency = Column(String, nullable=True) # daily / weekly / monthly / manual + # Assign config = Column(JSONB, nullable=True) # source-spec... config = Column(JSONB, nullable=True) # source-specific configuration + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) + # Assign __table_args__ = ( __table_args__ = ( Index('ix_data_sources_type', 'type'), Index('ix_data_sources_is_enabled', 'is_enabled'), diff --git a/backend/app/models/defensive_technique.py b/backend/app/models/defensive_technique.py index afff7b7..1bdc2fd 100644 --- a/backend/app/models/defensive_technique.py +++ b/backend/app/models/defensive_technique.py @@ -4,74 +4,108 @@ Stores MITRE D3FEND defensive techniques and their mappings to ATT&CK techniques, enabling recommended countermeasure lookups. 
""" +# Import uuid import uuid + +# Import from sqlalchemy from sqlalchemy import ( - Column, String, Text, DateTime, - ForeignKey, Index, UniqueConstraint, func, + Column, + DateTime, + ForeignKey, + Index, + String, + Text, + UniqueConstraint, + func, ) + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class DefensiveTechnique class DefensiveTechnique(Base): - """ - MITRE D3FEND defensive technique. + """MITRE D3FEND defensive technique. Represents a countermeasure from the D3FEND framework that can be mapped to one or more ATT&CK techniques via DefensiveTechniqueMapping. """ + # Assign __tablename__ = "defensive_techniques" __tablename__ = "defensive_techniques" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL" d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL" + # Assign name = Column(String, nullable=False) name = Column(String, nullable=False) + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign tactic = Column(String, nullable=True) # Detect, ... tactic = Column(String, nullable=True) # Detect, Isolate, Deceive, Evict, etc. 
+ # Assign d3fend_url = Column(String, nullable=True) d3fend_url = Column(String, nullable=True) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) # Relationships attack_mappings = relationship( + # Literal argument value "DefensiveTechniqueMapping", + # Keyword argument: back_populates back_populates="defensive_technique", + # Keyword argument: cascade cascade="all, delete-orphan", ) + # Assign __table_args__ = ( __table_args__ = ( Index('ix_defensive_techniques_tactic', 'tactic'), ) +# Define class DefensiveTechniqueMapping class DefensiveTechniqueMapping(Base): - """ - Association between a MITRE ATT&CK technique and a D3FEND - defensive technique. - """ + """Association between a MITRE ATT&CK technique and a D3FEND defensive technique.""" + # Assign __tablename__ = "defensive_technique_mappings" __tablename__ = "defensive_technique_mappings" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign attack_technique_id = Column( attack_technique_id = Column( UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign defensive_technique_id = Column( defensive_technique_id = Column( UUID(as_uuid=True), ForeignKey("defensive_techniques.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) # Relationships attack_technique = relationship("Technique") + # Assign defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings") defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings") + # Assign __table_args__ = ( __table_args__ = ( Index('ix_dtm_attack_technique', 'attack_technique_id'), Index('ix_dtm_defensive_technique', 'defensive_technique_id'), UniqueConstraint( + # Literal argument value 'attack_technique_id', 
'defensive_technique_id', + # Keyword argument: name name='uq_attack_defensive_technique', ), ) diff --git a/backend/app/models/detection_rule.py b/backend/app/models/detection_rule.py index c411415..d59a595 100644 --- a/backend/app/models/detection_rule.py +++ b/backend/app/models/detection_rule.py @@ -1,38 +1,61 @@ """DetectionRule model — detection rules from multiple sources.""" +# Import uuid import uuid -from sqlalchemy import Column, String, Text, Boolean, DateTime, Index, func -from sqlalchemy.dialects.postgresql import UUID, JSONB +# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy +from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import Base from app.database from app.database import Base +# Define class DetectionRule class DetectionRule(Base): - """ - Detection rule from an external source (Sigma, Elastic, Splunk, custom). + """Detection rule from an external source (Sigma, Elastic, Splunk, custom). Each rule is mapped to one MITRE ATT&CK technique via ``mitre_technique_id`` and stores the complete rule content in ``rule_content``. """ + # Assign __tablename__ = "detection_rules" __tablename__ = "detection_rules" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001" mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001" + # Assign title = Column(String, nullable=False) title = Column(String, nullable=False) + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign source = Column(String, nullable=False) # sigma / ela... 
source = Column(String, nullable=False) # sigma / elastic / splunk / custom + # Assign source_id = Column(String, nullable=True) # ID in the sour... source_id = Column(String, nullable=True) # ID in the source repo (for dedup) + # Assign source_url = Column(String, nullable=True) source_url = Column(String, nullable=True) + # Assign rule_content = Column(Text, nullable=False) # YAML / KQL / SPL ... rule_content = Column(Text, nullable=False) # YAML / KQL / SPL content + # Assign rule_format = Column(String, nullable=False) # sigma_yaml / kql... rule_format = Column(String, nullable=False) # sigma_yaml / kql / spl / custom + # Assign severity = Column(String, nullable=True) # informational... severity = Column(String, nullable=True) # informational / low / medium / high / critical + # Assign platforms = Column(JSONB, nullable=True, default=[]) platforms = Column(JSONB, nullable=True, default=[]) + # Assign log_sources = Column(JSONB, nullable=True) # e.g. {"product":... log_sources = Column(JSONB, nullable=True) # e.g. {"product": "windows", "service": "sysmon"} + # Assign false_positive_rate = Column(String, nullable=True) # low / medium / high false_positive_rate = Column(String, nullable=True) # low / medium / high + # Assign is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) + # Assign __table_args__ = ( __table_args__ = ( Index('ix_detection_rules_mitre_technique_id', 'mitre_technique_id'), Index('ix_detection_rules_source', 'source'), diff --git a/backend/app/models/enums.py b/backend/app/models/enums.py index b941d3e..909110d 100644 --- a/backend/app/models/enums.py +++ b/backend/app/models/enums.py @@ -5,6 +5,7 @@ re-exports every enum so that existing model and router code keeps working with ``from app.models.enums import ...``. 
""" +# Import # noqa: F401 from app.domain.enums from app.domain.enums import ( # noqa: F401 DataClassification, TeamSide, diff --git a/backend/app/models/evidence.py b/backend/app/models/evidence.py index 0f87db2..df92165 100644 --- a/backend/app/models/evidence.py +++ b/backend/app/models/evidence.py @@ -1,35 +1,59 @@ +"""SQLAlchemy model for the evidence table.""" + +# Import uuid import uuid -from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum, func + +# Import Column, DateTime, Enum, ForeignKey, String, Tex... from sqlalchemy +from sqlalchemy import Column, DateTime, Enum, ForeignKey, String, Text, func + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base + +# Import TeamSide from app.models.enums from app.models.enums import TeamSide +# Define class Evidence class Evidence(Base): - """ - Evidence model for storing file metadata associated with tests. - + """Evidence model for storing file metadata associated with tests. + Files are stored in MinIO, and this model tracks the file location, integrity hash, and upload metadata. - + The ``team`` field distinguishes whether this evidence was uploaded by Red Team (attack evidence) or Blue Team (detection evidence). 
""" + # Assign __tablename__ = "evidences" __tablename__ = "evidences" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False) test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False) + # Assign file_name = Column(String, nullable=False) file_name = Column(String, nullable=False) + # Assign file_path = Column(String, nullable=False) # Path in MinIO file_path = Column(String, nullable=False) # Path in MinIO + # Assign sha256_hash = Column(String, nullable=False) sha256_hash = Column(String, nullable=False) + # Assign uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + # Assign uploaded_at = Column(DateTime(timezone=True), server_default=func.now()) uploaded_at = Column(DateTime(timezone=True), server_default=func.now()) + # Assign team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=Tea... 
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red) + # Assign notes = Column(Text, nullable=True) notes = Column(Text, nullable=True) + # Assign data_classification = Column(String(20), nullable=False, server_default="internal") data_classification = Column(String(20), nullable=False, server_default="internal") # Relationships test = relationship("Test", back_populates="evidences") + # Assign uploader = relationship("User", foreign_keys=[uploaded_by]) uploader = relationship("User", foreign_keys=[uploaded_by]) diff --git a/backend/app/models/intel.py b/backend/app/models/intel.py index 69056a2..64ccada 100644 --- a/backend/app/models/intel.py +++ b/backend/app/models/intel.py @@ -1,26 +1,44 @@ +"""SQLAlchemy model for the intel_items table.""" + +# Import uuid import uuid -from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, func + +# Import Boolean, Column, DateTime, ForeignKey, String, ... from sqlalchemy +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, func + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class IntelItem class IntelItem(Base): - """ - Intelligence item model for tracking threat intelligence related to techniques. - + """Intelligence item model for tracking threat intelligence related to techniques. + Stores URLs and metadata from automated intel scans that may indicate new attack variations or detection bypasses for specific techniques. 
""" + # Assign __tablename__ = "intel_items" __tablename__ = "intel_items" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True) technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True) + # Assign url = Column(String, nullable=False) url = Column(String, nullable=False) + # Assign title = Column(String, nullable=True) title = Column(String, nullable=True) + # Assign source = Column(String, nullable=True) source = Column(String, nullable=True) + # Assign detected_at = Column(DateTime(timezone=True), server_default=func.now()) detected_at = Column(DateTime(timezone=True), server_default=func.now()) + # Assign reviewed = Column(Boolean, default=False) reviewed = Column(Boolean, default=False) # Relationships diff --git a/backend/app/models/jira_link.py b/backend/app/models/jira_link.py index 4f43728..8a6efd8 100644 --- a/backend/app/models/jira_link.py +++ b/backend/app/models/jira_link.py @@ -1,53 +1,99 @@ """Jira integration models — link Aegis entities to Jira issues.""" +# Import enum import enum + +# Import uuid import uuid -from sqlalchemy import Column, String, DateTime, ForeignKey, Enum as SQLEnum, Index, func -from sqlalchemy.dialects.postgresql import UUID, JSONB + +# Import Column, DateTime, ForeignKey, Index, String, func from sqlalchemy +from sqlalchemy import Column, DateTime, ForeignKey, Index, String, func + +# Import Enum as SQLEnum from sqlalchemy +from sqlalchemy import Enum as SQLEnum + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class JiraLinkEntityType class JiraLinkEntityType(str, enum.Enum): + """Aegis 
entity types that can be linked to a Jira issue.""" + + # Assign test = "test" test = "test" + # Assign technique = "technique" technique = "technique" + # Assign campaign = "campaign" campaign = "campaign" + # Assign evidence = "evidence" evidence = "evidence" +# Define class JiraSyncDirection class JiraSyncDirection(str, enum.Enum): + """Direction of synchronisation between Aegis and Jira.""" + + # Assign aegis_to_jira = "aegis_to_jira" aegis_to_jira = "aegis_to_jira" + # Assign jira_to_aegis = "jira_to_aegis" jira_to_aegis = "jira_to_aegis" + # Assign bidirectional = "bidirectional" bidirectional = "bidirectional" +# Define class JiraLink class JiraLink(Base): """Associates an Aegis entity with a Jira issue for bidirectional sync.""" + # Assign __tablename__ = "jira_links" __tablename__ = "jira_links" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False) entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False) + # Assign entity_id = Column(UUID(as_uuid=True), nullable=False) entity_id = Column(UUID(as_uuid=True), nullable=False) + # Assign jira_issue_key = Column(String(50), nullable=False) jira_issue_key = Column(String(50), nullable=False) + # Assign jira_issue_id = Column(String(50)) jira_issue_id = Column(String(50)) + # Assign jira_project_key = Column(String(20)) jira_project_key = Column(String(20)) + # Assign jira_status = Column(String(100)) jira_status = Column(String(100)) + # Assign jira_priority = Column(String(50)) jira_priority = Column(String(50)) + # Assign jira_assignee = Column(String(255)) jira_assignee = Column(String(255)) + # Assign jira_story_points = Column(String(10)) jira_story_points = Column(String(10)) + # Assign sync_direction = Column( sync_direction = Column( SQLEnum(JiraSyncDirection), default=JiraSyncDirection.bidirectional ) + # Assign 
last_synced_at = Column(DateTime) last_synced_at = Column(DateTime) + # Assign sync_metadata = Column(JSONB, default={}) sync_metadata = Column(JSONB, default={}) + # Assign created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) + # Assign updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate... updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) + # Assign creator = relationship("User", foreign_keys=[created_by]) creator = relationship("User", foreign_keys=[created_by]) + # Assign __table_args__ = ( __table_args__ = ( Index("ix_jira_links_entity_id", "entity_id"), Index("ix_jira_links_issue_key", "jira_issue_key"), diff --git a/backend/app/models/notification.py b/backend/app/models/notification.py index 17e30a3..4d0224b 100644 --- a/backend/app/models/notification.py +++ b/backend/app/models/notification.py @@ -1,35 +1,54 @@ """Notification model — in-app notifications for user actions.""" +# Import uuid import uuid -from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Index, func + +# Import Boolean, Column, DateTime, ForeignKey, Index, S... from sqlalchemy +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String, Text, func + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class Notification class Notification(Base): - """ - In-app notification for alerting users when they need to act. + """In-app notification for alerting users when they need to act. 
Types include: test_assigned, validation_needed, test_rejected, test_validated, test_state_changed, etc. """ + # Assign __tablename__ = "notifications" __tablename__ = "notifications" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) + # Assign type = Column(String, nullable=False) type = Column(String, nullable=False) + # Assign title = Column(String, nullable=False) title = Column(String, nullable=False) + # Assign message = Column(Text, nullable=True) message = Column(Text, nullable=True) + # Assign entity_type = Column(String, nullable=True) entity_type = Column(String, nullable=True) + # Assign entity_id = Column(UUID(as_uuid=True), nullable=True) entity_id = Column(UUID(as_uuid=True), nullable=True) + # Assign read = Column(Boolean, default=False) read = Column(Boolean, default=False) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) # Relationships user = relationship("User") + # Assign __table_args__ = ( __table_args__ = ( Index("ix_notifications_user_id", "user_id"), Index("ix_notifications_read", "read"), diff --git a/backend/app/models/osint_item.py b/backend/app/models/osint_item.py index b8cea0a..f181264 100644 --- a/backend/app/models/osint_item.py +++ b/backend/app/models/osint_item.py @@ -1,37 +1,58 @@ """OSINT enrichment items — CVEs, blogs, PoCs, and advisories linked to techniques.""" +# Import uuid import uuid -from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, func -from sqlalchemy.dialects.postgresql import UUID, JSONB + +# Import Boolean, Column, DateTime, ForeignKey, String, ... 
from sqlalchemy +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, Text, func + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class OsintItem class OsintItem(Base): - """Represents an OSINT data point (CVE, blog, PoC, advisory) associated - with a MITRE ATT&CK technique. + """Represents an OSINT data point (CVE, blog, PoC, advisory) associated with a MITRE ATT&CK technique. Used by the enrichment pipeline to surface relevant threat intelligence for each technique, flagging those that need review. """ + # Assign __tablename__ = "osint_items" __tablename__ = "osint_items" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign technique_id = Column( technique_id = Column( UUID(as_uuid=True), ForeignKey("techniques.id"), + # Keyword argument: nullable nullable=False, + # Keyword argument: index index=True, ) + # Assign source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory" source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory" + # Assign source_url = Column(Text, nullable=False) source_url = Column(Text, nullable=False) + # Assign title = Column(String(500), nullable=False) title = Column(String(500), nullable=False) + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, U... severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, UNKNOWN + # Assign discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable... 
discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + # Assign reviewed = Column(Boolean, default=False) reviewed = Column(Boolean, default=False) + # Assign metadata_ = Column("metadata", JSONB, default={}) metadata_ = Column("metadata", JSONB, default={}) # ── Relationships ───────────────────────────────────────────────── diff --git a/backend/app/models/scoring_config.py b/backend/app/models/scoring_config.py index ff8133d..74c3505 100644 --- a/backend/app/models/scoring_config.py +++ b/backend/app/models/scoring_config.py @@ -1,25 +1,43 @@ """ScoringConfig — single-row table for persisted scoring weights.""" +# Import uuid import uuid -from sqlalchemy import Column, Float, DateTime, ForeignKey, func +# Import Column, DateTime, Float, ForeignKey, func from sqlalchemy +from sqlalchemy import Column, DateTime, Float, ForeignKey, func + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID +# Import Base from app.database from app.database import Base +# Define class ScoringConfig class ScoringConfig(Base): + """Single-row table persisting the active scoring weight configuration.""" + + # Assign __tablename__ = "scoring_config" __tablename__ = "scoring_config" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign weight_tests = Column(Float, nullable=False, default=40.0) weight_tests = Column(Float, nullable=False, default=40.0) + # Assign weight_detection_rules = Column(Float, nullable=False, default=25.0) weight_detection_rules = Column(Float, nullable=False, default=25.0) + # Assign weight_d3fend = Column(Float, nullable=False, default=15.0) weight_d3fend = Column(Float, nullable=False, default=15.0) + # Assign weight_recency = Column(Float, nullable=False, default=10.0) weight_recency = Column(Float, nullable=False, default=10.0) + # Assign weight_severity = Column(Float, 
nullable=False, default=10.0) weight_severity = Column(Float, nullable=False, default=10.0) + # Assign updated_by = Column( updated_by = Column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), + # Keyword argument: nullable nullable=True, ) + # Assign updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate... updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) diff --git a/backend/app/models/technique.py b/backend/app/models/technique.py index f2b9cb2..0e405b7 100644 --- a/backend/app/models/technique.py +++ b/backend/app/models/technique.py @@ -1,38 +1,63 @@ -import uuid -from datetime import datetime +"""SQLAlchemy model for the techniques table.""" -from sqlalchemy import Column, String, Text, Boolean, DateTime, Enum -from sqlalchemy.dialects.postgresql import UUID, JSONB +# Import uuid +import uuid + +# Import Boolean, Column, DateTime, Enum, String, Text from sqlalchemy +from sqlalchemy import Boolean, Column, DateTime, Enum, String, Text + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base + +# Import TechniqueStatus from app.models.enums from app.models.enums import TechniqueStatus +# Define class Technique class Technique(Base): - """ - MITRE ATT&CK Technique model. - + """MITRE ATT&CK Technique model. + Represents an attack technique from the MITRE ATT&CK framework, including its coverage status and associated tests. 
""" + # Assign __tablename__ = "techniques" __tablename__ = "techniques" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001" mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001" + # Assign name = Column(String, nullable=False) name = Column(String, nullable=False) + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign tactic = Column(String, nullable=True) tactic = Column(String, nullable=True) + # Assign platforms = Column(JSONB, nullable=True, default=[]) platforms = Column(JSONB, nullable=True, default=[]) + # Assign mitre_version = Column(String, nullable=True) mitre_version = Column(String, nullable=True) + # Assign mitre_last_modified = Column(DateTime, nullable=True) mitre_last_modified = Column(DateTime, nullable=True) + # Assign is_subtechnique = Column(Boolean, default=False) is_subtechnique = Column(Boolean, default=False) + # Assign parent_mitre_id = Column(String, nullable=True) parent_mitre_id = Column(String, nullable=True) + # Assign status_global = Column( status_global = Column( Enum(TechniqueStatus, name="techniquestatus"), + # Keyword argument: default default=TechniqueStatus.not_evaluated ) + # Assign review_required = Column(Boolean, default=False) review_required = Column(Boolean, default=False) + # Assign last_review_date = Column(DateTime, nullable=True) last_review_date = Column(DateTime, nullable=True) # Relationships diff --git a/backend/app/models/test.py b/backend/app/models/test.py index e200504..a1ecabe 100644 --- a/backend/app/models/test.py +++ b/backend/app/models/test.py @@ -1,80 +1,140 @@ +"""SQLAlchemy model for the tests table.""" + +# Import uuid import uuid -from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, ForeignKey, Enum, Index, func + +# Import 
from sqlalchemy +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Enum, + ForeignKey, + Index, + Integer, + String, + Text, + func, +) + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base -from app.models.enums import TestState, TestResult + +# Import TestResult, TestState from app.models.enums +from app.models.enums import TestResult, TestState +# Define class Test class Test(Base): - """ - Test model representing a security test for a MITRE ATT&CK technique. + """Test model representing a security test for a MITRE ATT&CK technique. Each test documents an attempt to validate coverage of a specific technique, including the procedure, tools used, and outcome. V2 introduces dual validation: Red Lead and Blue Lead must each approve independently. """ + # Assign __tablename__ = "tests" __tablename__ = "tests" # ── Core fields ───────────────────────────────────────────────── id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=Fa... 
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False) + # Assign name = Column(String, nullable=False) name = Column(String, nullable=False) + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign platform = Column(String, nullable=True) platform = Column(String, nullable=True) + # Assign procedure_text = Column(Text, nullable=True) procedure_text = Column(Text, nullable=True) + # Assign tool_used = Column(String, nullable=True) tool_used = Column(String, nullable=True) + # Assign execution_date = Column(DateTime, nullable=True) execution_date = Column(DateTime, nullable=True) + # Assign created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + # Assign result = Column(Enum(TestResult, name="testresult"), nullable=True) result = Column(Enum(TestResult, name="testresult"), nullable=True) + # Assign state = Column(Enum(TestState, name="teststate"), default=TestState.draft) state = Column(Enum(TestState, name="teststate"), default=TestState.draft) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) # ── Red Team fields ───────────────────────────────────────────── red_summary = Column(Text, nullable=True) + # Assign attack_success = Column(Boolean, nullable=True) attack_success = Column(Boolean, nullable=True) + # Assign red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + # Assign red_validated_at = Column(DateTime, nullable=True) red_validated_at = Column(DateTime, nullable=True) + # Assign red_validation_status = Column(String, nullable=True) # pending / approved / rejected red_validation_status = Column(String, nullable=True) # pending / approved / rejected + # 
Assign red_validation_notes = Column(Text, nullable=True) red_validation_notes = Column(Text, nullable=True) # ── Blue Team fields ──────────────────────────────────────────── blue_summary = Column(Text, nullable=True) + # Assign detection_result = Column(Enum(TestResult, name="testresult"), nullable=True) detection_result = Column(Enum(TestResult, name="testresult"), nullable=True) + # Assign blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + # Assign blue_validated_at = Column(DateTime, nullable=True) blue_validated_at = Column(DateTime, nullable=True) + # Assign blue_validation_status = Column(String, nullable=True) # pending / approved / rejected blue_validation_status = Column(String, nullable=True) # pending / approved / rejected + # Assign blue_validation_notes = Column(Text, nullable=True) blue_validation_notes = Column(Text, nullable=True) # ── Phase timing fields (for automatic Tempo worklogs) ────────── red_started_at = Column(DateTime, nullable=True) + # Assign blue_started_at = Column(DateTime, nullable=True) blue_started_at = Column(DateTime, nullable=True) blue_work_started_at = Column(DateTime, nullable=True) # when blue tech picks up (Tempo start) paused_at = Column(DateTime, nullable=True) + # Assign red_paused_seconds = Column(Integer, default=0) red_paused_seconds = Column(Integer, default=0) + # Assign blue_paused_seconds = Column(Integer, default=0) blue_paused_seconds = Column(Integer, default=0) # ── Remediation fields ─────────────────────────────────────────── remediation_steps = Column(Text, nullable=True) + # Assign remediation_status = Column(String, nullable=True) # pending / in_progress / completed ... 
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable + # Assign remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) # ── Re-test fields ──────────────────────────────────────────── retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True) + # Assign retest_count = Column(Integer, default=0) retest_count = Column(Integer, default=0) + # Assign data_classification = Column(String(20), nullable=False, server_default="internal") data_classification = Column(String(20), nullable=False, server_default="internal") # ── Relationships ─────────────────────────────────────────────── technique = relationship("Technique", back_populates="tests") + # Assign evidences = relationship("Evidence", back_populates="test") evidences = relationship("Evidence", back_populates="test") + # Assign creator = relationship("User", foreign_keys=[created_by]) creator = relationship("User", foreign_keys=[created_by]) + # Assign red_validator = relationship("User", foreign_keys=[red_validated_by]) red_validator = relationship("User", foreign_keys=[red_validated_by]) + # Assign blue_validator = relationship("User", foreign_keys=[blue_validated_by]) blue_validator = relationship("User", foreign_keys=[blue_validated_by]) + # Assign remediation_user = relationship("User", foreign_keys=[remediation_assignee]) remediation_user = relationship("User", foreign_keys=[remediation_assignee]) + # Assign original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of]) original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of]) + # Assign retests = relationship("Test", foreign_keys=[retest_of], back_populates="orig... 
retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test") + # Assign __table_args__ = ( __table_args__ = ( Index("ix_tests_technique_id", "technique_id"), Index("ix_tests_state", "state"), diff --git a/backend/app/models/test_detection_result.py b/backend/app/models/test_detection_result.py index 2897bbf..dff15b8 100644 --- a/backend/app/models/test_detection_result.py +++ b/backend/app/models/test_detection_result.py @@ -4,51 +4,79 @@ When the Blue Team evaluates a test, they mark each associated detection rule as triggered / not triggered / not applicable, along with notes. """ +# Import uuid import uuid -from datetime import datetime -from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Index, UniqueConstraint +# Import from sqlalchemy +from sqlalchemy import ( + Boolean, + Column, + DateTime, + ForeignKey, + Index, + Text, + UniqueConstraint, +) + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class TestDetectionResult class TestDetectionResult(Base): - """ - Per-test, per-rule evaluation result. + """Per-test, per-rule evaluation result. 
- ``triggered`` = True: rule detected the attack - ``triggered`` = False: rule did NOT detect the attack - ``triggered`` = None: not yet evaluated """ + # Assign __tablename__ = "test_detection_results" __tablename__ = "test_detection_results" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign test_id = Column( test_id = Column( UUID(as_uuid=True), ForeignKey("tests.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign detection_rule_id = Column( detection_rule_id = Column( UUID(as_uuid=True), ForeignKey("detection_rules.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign triggered = Column(Boolean, nullable=True) # None = not evaluated triggered = Column(Boolean, nullable=True) # None = not evaluated + # Assign notes = Column(Text, nullable=True) notes = Column(Text, nullable=True) + # Assign evaluated_by = Column( evaluated_by = Column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), + # Keyword argument: nullable nullable=True, ) + # Assign evaluated_at = Column(DateTime, nullable=True) evaluated_at = Column(DateTime, nullable=True) # Relationships test = relationship("Test") + # Assign detection_rule = relationship("DetectionRule") detection_rule = relationship("DetectionRule") + # Assign evaluator = relationship("User", foreign_keys=[evaluated_by]) evaluator = relationship("User", foreign_keys=[evaluated_by]) + # Assign __table_args__ = ( __table_args__ = ( Index('ix_tdr_test', 'test_id'), Index('ix_tdr_rule', 'detection_rule_id'), diff --git a/backend/app/models/test_template.py b/backend/app/models/test_template.py index 262034b..af87194 100644 --- a/backend/app/models/test_template.py +++ b/backend/app/models/test_template.py @@ -1,15 +1,21 @@ """TestTemplate model — predefined test catalog entries.""" +# Import uuid import uuid -from sqlalchemy import Column, 
String, Text, Boolean, DateTime, Index, func + +# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy +from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID +# Import Base from app.database from app.database import Base +# Define class TestTemplate class TestTemplate(Base): - """ - Predefined test template mapped to a MITRE ATT&CK technique. + """Predefined test template mapped to a MITRE ATT&CK technique. Templates come from several sources: - **atomic_red_team**: Atomic Red Team by Red Canary @@ -18,24 +24,41 @@ class TestTemplate(Base): Users can instantiate a real Test from a template. """ + # Assign __tablename__ = "test_templates" __tablename__ = "test_templates" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001" mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001" + # Assign name = Column(String, nullable=False) name = Column(String, nullable=False) + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign source = Column(String, nullable=False) # atomic_red_te... source = Column(String, nullable=False) # atomic_red_team / mitre / custom + # Assign source_url = Column(String, nullable=True) source_url = Column(String, nullable=True) + # Assign attack_procedure = Column(Text, nullable=True) # Suggested attack procedure attack_procedure = Column(Text, nullable=True) # Suggested attack procedure + # Assign expected_detection = Column(Text, nullable=True) # What blue team should detect expected_detection = Column(Text, nullable=True) # What blue team should detect + # Assign platform = Column(String, nullable=True) # windows / linux... 
platform = Column(String, nullable=True) # windows / linux / macos + # Assign tool_suggested = Column(String, nullable=True) tool_suggested = Column(String, nullable=True) + # Assign severity = Column(String, nullable=True) # low / medium / ... severity = Column(String, nullable=True) # low / medium / high / critical + # Assign atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team... atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo + # Assign suggested_remediation = Column(Text, nullable=True) suggested_remediation = Column(Text, nullable=True) + # Assign is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) + # Assign __table_args__ = ( __table_args__ = ( Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'), Index('ix_test_templates_source', 'source'), diff --git a/backend/app/models/test_template_detection_rule.py b/backend/app/models/test_template_detection_rule.py index 8380821..69f0b23 100644 --- a/backend/app/models/test_template_detection_rule.py +++ b/backend/app/models/test_template_detection_rule.py @@ -4,47 +4,64 @@ Enables the Blue Team to see which detection rules should fire for a given test template / attack procedure. """ +# Import uuid import uuid -from datetime import datetime -from sqlalchemy import Column, Boolean, ForeignKey, Index, UniqueConstraint +# Import Boolean, Column, ForeignKey, Index, UniqueConst... 
from sqlalchemy +from sqlalchemy import Boolean, Column, ForeignKey, Index, UniqueConstraint + +# Import UUID from sqlalchemy.dialects.postgresql from sqlalchemy.dialects.postgresql import UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class TestTemplateDetectionRule class TestTemplateDetectionRule(Base): - """ - Association between a test template and a detection rule. + """Association between a test template and a detection rule. Auto-generated by matching mitre_technique_id, or manually curated. ``is_primary`` marks rules with severity >= high as primary detections. """ + # Assign __tablename__ = "test_template_detection_rules" __tablename__ = "test_template_detection_rules" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign test_template_id = Column( test_template_id = Column( UUID(as_uuid=True), ForeignKey("test_templates.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=True, ) + # Assign detection_rule_id = Column( detection_rule_id = Column( UUID(as_uuid=True), ForeignKey("detection_rules.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign is_primary = Column(Boolean, default=False) is_primary = Column(Boolean, default=False) # Relationships test_template = relationship("TestTemplate") + # Assign detection_rule = relationship("DetectionRule") detection_rule = relationship("DetectionRule") + # Assign __table_args__ = ( __table_args__ = ( Index('ix_ttdr_template', 'test_template_id'), Index('ix_ttdr_rule', 'detection_rule_id'), UniqueConstraint( + # Literal argument value 'test_template_id', 'detection_rule_id', + # Keyword argument: name name='uq_template_detection_rule', ), ) diff --git a/backend/app/models/threat_actor.py b/backend/app/models/threat_actor.py index 5e1b6cd..1d8caff 
100644 --- a/backend/app/models/threat_actor.py +++ b/backend/app/models/threat_actor.py @@ -4,87 +4,135 @@ Stores profiles of APT groups and their associated MITRE ATT&CK techniques, imported from MITRE CTI (STIX 2.0). """ +# Import uuid import uuid + +# Import from sqlalchemy from sqlalchemy import ( - Column, String, Text, Boolean, DateTime, - ForeignKey, Index, UniqueConstraint, func, + Boolean, + Column, + DateTime, + ForeignKey, + Index, + String, + Text, + UniqueConstraint, + func, ) -from sqlalchemy.dialects.postgresql import UUID, JSONB + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class ThreatActor class ThreatActor(Base): - """ - Threat actor / APT group profile. + """Threat actor / APT group profile. Imported from MITRE CTI ``intrusion-set`` STIX objects. """ + # Assign __tablename__ = "threat_actors" __tablename__ = "threat_actors" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign mitre_id = Column(String, unique=True, nullable=True) # e.g. "G00... mitre_id = Column(String, unique=True, nullable=True) # e.g. "G0016" (APT29) + # Assign name = Column(String, nullable=False) name = Column(String, nullable=False) + # Assign aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy ... aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy Bear", "The Dukes", ...] + # Assign description = Column(Text, nullable=True) description = Column(Text, nullable=True) + # Assign country = Column(String, nullable=True) country = Column(String, nullable=True) + # Assign target_sectors = Column(JSONB, nullable=True, default=[]) # ["government",... 
target_sectors = Column(JSONB, nullable=True, default=[]) # ["government", "defense", ...] + # Assign target_regions = Column(JSONB, nullable=True, default=[]) # ["north-americ... target_regions = Column(JSONB, nullable=True, default=[]) # ["north-america", "europe", ...] + # Assign motivation = Column(String, nullable=True) # espionage ... motivation = Column(String, nullable=True) # espionage / financial / destruction / ... + # Assign sophistication = Column(String, nullable=True) # low / medium /... sophistication = Column(String, nullable=True) # low / medium / high / advanced + # Assign first_seen = Column(String, nullable=True) first_seen = Column(String, nullable=True) + # Assign last_seen = Column(String, nullable=True) last_seen = Column(String, nullable=True) + # Assign references = Column(JSONB, nullable=True, default=[]) # [{"url": "... references = Column(JSONB, nullable=True, default=[]) # [{"url": "...", "description": "..."}] + # Assign mitre_url = Column(String, nullable=True) mitre_url = Column(String, nullable=True) + # Assign is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) # Relationships techniques = relationship( + # Literal argument value "ThreatActorTechnique", + # Keyword argument: back_populates back_populates="threat_actor", + # Keyword argument: cascade cascade="all, delete-orphan", ) + # Assign __table_args__ = ( __table_args__ = ( Index('ix_threat_actors_country', 'country'), Index('ix_threat_actors_motivation', 'motivation'), ) +# Define class ThreatActorTechnique class ThreatActorTechnique(Base): - """ - Association between a threat actor and a MITRE ATT&CK technique. + """Association between a threat actor and a MITRE ATT&CK technique. 
Stores additional context about how the actor uses the technique (from the STIX ``relationship`` ``uses`` objects). """ + # Assign __tablename__ = "threat_actor_techniques" __tablename__ = "threat_actor_techniques" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign threat_actor_id = Column( threat_actor_id = Column( UUID(as_uuid=True), ForeignKey("threat_actors.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign technique_id = Column( technique_id = Column( UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), + # Keyword argument: nullable nullable=False, ) + # Assign usage_description = Column(Text, nullable=True) usage_description = Column(Text, nullable=True) + # Assign first_seen_using = Column(String, nullable=True) first_seen_using = Column(String, nullable=True) # Relationships threat_actor = relationship("ThreatActor", back_populates="techniques") + # Assign technique = relationship("Technique") technique = relationship("Technique") + # Assign __table_args__ = ( __table_args__ = ( Index('ix_threat_actor_techniques_actor', 'threat_actor_id'), Index('ix_threat_actor_techniques_technique', 'technique_id'), UniqueConstraint( + # Literal argument value 'threat_actor_id', 'technique_id', + # Keyword argument: name name='uq_actor_technique', ), ) diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 6436edc..1cc2367 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -1,14 +1,18 @@ +"""SQLAlchemy model for the users table.""" + +# Import uuid import uuid from sqlalchemy import Column, String, Boolean, DateTime, func from sqlalchemy.dialects.postgresql import UUID, JSONB +# Import Base from app.database from app.database import Base +# Define class User class User(Base): - """ - User model for authentication and authorization. 
- + """User model for authentication and authorization. + Possible roles: - admin: Full system access - red_tech: Red team technician - can create and edit tests @@ -17,16 +21,26 @@ class User(Base): - blue_lead: Blue team lead - can validate tests - viewer: Read-only access (default) """ + # Assign __tablename__ = "users" __tablename__ = "users" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign username = Column(String, unique=True, nullable=False) username = Column(String, unique=True, nullable=False) + # Assign email = Column(String, nullable=True) email = Column(String, nullable=True) + # Assign hashed_password = Column(String, nullable=False) hashed_password = Column(String, nullable=False) + # Assign role = Column(String, nullable=False, default="viewer") role = Column(String, nullable=False, default="viewer") + # Assign is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True) + # Assign must_change_password = Column(Boolean, default=True) must_change_password = Column(Boolean, default=True) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) + # Assign last_login = Column(DateTime, nullable=True) last_login = Column(DateTime, nullable=True) notification_preferences = Column(JSONB, nullable=True, server_default='{"email_on_test_validated": true, "email_on_campaign_completed": true, "email_on_new_mitre_techniques": false, "in_app_all": true}') jira_account_id = Column(String(100), nullable=True) diff --git a/backend/app/models/worklog.py b/backend/app/models/worklog.py index 439b1d3..c55518c 100644 --- a/backend/app/models/worklog.py +++ b/backend/app/models/worklog.py @@ -1,13 +1,22 @@ """Worklog model — immutable internal time-tracking records.""" +# Import uuid import uuid -from sqlalchemy import Column, String, 
Integer, DateTime, ForeignKey, Text, Index, func -from sqlalchemy.dialects.postgresql import UUID, JSONB + +# Import Column, DateTime, ForeignKey, Index, Integer, S... from sqlalchemy +from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text, func + +# Import JSONB, UUID from sqlalchemy.dialects.postgresql +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# Import relationship from sqlalchemy.orm from sqlalchemy.orm import relationship +# Import Base from app.database from app.database import Base +# Define class Worklog class Worklog(Base): """Internal worklog entry with integrity hash for audit compliance. @@ -16,25 +25,42 @@ class Worklog(Base): the immutable fields so tampering can be detected. """ + # Assign __tablename__ = "worklogs" __tablename__ = "worklogs" + # Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # Assign entity_type = Column(String(50), nullable=False) entity_type = Column(String(50), nullable=False) + # Assign entity_id = Column(UUID(as_uuid=True), nullable=False) entity_id = Column(UUID(as_uuid=True), nullable=False) + # Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) + # Assign activity_type = Column(String(100), nullable=False) activity_type = Column(String(100), nullable=False) + # Assign started_at = Column(DateTime, nullable=False) started_at = Column(DateTime, nullable=False) + # Assign ended_at = Column(DateTime) ended_at = Column(DateTime) + # Assign duration_seconds = Column(Integer, nullable=False) duration_seconds = Column(Integer, nullable=False) + # Assign description = Column(Text) description = Column(Text) + # Assign tempo_synced = Column(DateTime) tempo_synced = Column(DateTime) + # Assign tempo_worklog_id = Column(String(100)) tempo_worklog_id = Column(String(100)) + # Assign 
integrity_hash = Column(String(64)) integrity_hash = Column(String(64)) + # Assign created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now()) + # Assign extra_metadata = Column("metadata", JSONB, default={}) extra_metadata = Column("metadata", JSONB, default={}) + # Assign user = relationship("User", foreign_keys=[user_id]) user = relationship("User", foreign_keys=[user_id]) + # Assign __table_args__ = ( __table_args__ = ( Index("ix_worklogs_entity_id", "entity_id"), Index("ix_worklogs_user_id", "user_id"), diff --git a/backend/app/routers/__init__.py b/backend/app/routers/__init__.py index e69de29..2d55925 100644 --- a/backend/app/routers/__init__.py +++ b/backend/app/routers/__init__.py @@ -0,0 +1 @@ +"""FastAPI router modules — one router per feature domain.""" diff --git a/backend/app/routers/advanced_metrics.py b/backend/app/routers/advanced_metrics.py index 7308d23..1ad4b5a 100644 --- a/backend/app/routers/advanced_metrics.py +++ b/backend/app/routers/advanced_metrics.py @@ -1,50 +1,81 @@ """Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time.""" +# Import APIRouter, Depends from fastapi from fastapi import APIRouter, Depends + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user from app.dependencies.auth from app.dependencies.auth import get_current_user + +# Import User from app.models.user from app.models.user import User + +# Import advanced_metrics_service from app.services from app.services import advanced_metrics_service +# Assign router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"]) router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"]) +# Apply the @router.get decorator @router.get("/coverage-by-tactic") +# Define function coverage_by_tactic def coverage_by_tactic( + # Entry: db db: 
Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> list: """Coverage percentage broken down by MITRE ATT&CK tactic.""" + # Return advanced_metrics_service.get_coverage_by_tactic(db) return advanced_metrics_service.get_coverage_by_tactic(db) +# Apply the @router.get decorator @router.get("/never-tested") +# Define function never_tested_techniques def never_tested_techniques( + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> list: """Techniques that have never had a test created.""" + # Return advanced_metrics_service.get_never_tested_techniques(db) return advanced_metrics_service.get_never_tested_techniques(db) +# Apply the @router.get decorator @router.get("/avg-validation-time") +# Define function avg_validation_time def avg_validation_time( + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> dict: """Average time from test creation to validation, computed from audit logs. Returns overall average and per-phase averages where data is available. """ + # Return advanced_metrics_service.get_avg_validation_time(db) return advanced_metrics_service.get_avg_validation_time(db) +# Apply the @router.get decorator @router.get("/detection-rate-trend") +# Define function detection_rate_trend def detection_rate_trend( + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> list: """Monthly detection rate trend for the last 12 months.""" + # Return advanced_metrics_service.get_detection_rate_trend(db) return advanced_metrics_service.get_detection_rate_trend(db) diff --git a/backend/app/routers/analytics.py b/backend/app/routers/analytics.py index 4997888..0c84b08 100644 --- a/backend/app/routers/analytics.py +++ b/backend/app/routers/analytics.py @@ -4,52 +4,85 @@ Returns complete datasets without pagination so BI tools can ingest directly from URL. 
All endpoints require authentication. """ +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_role + +# Import User from app.models.user from app.models.user import User + +# Import analytics_service from app.services from app.services import analytics_service +# Assign router = APIRouter(prefix="/analytics", tags=["analytics"]) router = APIRouter(prefix="/analytics", tags=["analytics"]) +# Apply the @router.get decorator @router.get("/coverage") +# Define function analytics_coverage def analytics_coverage( + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> list: """Coverage per technique — flat format for BI dashboards.""" + # Return analytics_service.get_coverage_analytics(db) return analytics_service.get_coverage_analytics(db) +# Apply the @router.get decorator @router.get("/tests") +# Define function analytics_tests def analytics_tests( + # Entry: date_from date_from: str = Query(None, description="ISO date filter (>=)"), + # Entry: date_to date_to: str = Query(None, description="ISO date filter (<=)"), + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> list: """All tests with timestamps — flat format for BI dashboards.""" + # Return analytics_service.get_tests_analytics( return analytics_service.get_tests_analytics( db, date_from=date_from, date_to=date_to ) +# Apply the @router.get decorator @router.get("/trends") +# Define function analytics_trends def analytics_trends( + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> list: """Historical coverage snapshots for trend 
visualization.""" + # Return analytics_service.get_trends_analytics(db) return analytics_service.get_trends_analytics(db) +# Apply the @router.get decorator @router.get("/operators") +# Define function analytics_operators def analytics_operators( + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(require_role("admin")), -): +) -> list: """Per-operator metrics — for workload management dashboards.""" + # Return analytics_service.get_operators_analytics(db) return analytics_service.get_operators_analytics(db) diff --git a/backend/app/routers/audit.py b/backend/app/routers/audit.py index 0dd257b..d47af70 100644 --- a/backend/app/routers/audit.py +++ b/backend/app/routers/audit.py @@ -1,77 +1,127 @@ """Audit log viewer router (admin only).""" +# Import datetime from datetime from datetime import datetime + +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import require_role from app.dependencies.auth from app.dependencies.auth import require_role + +# Import User from app.models.user from app.models.user import User + +# Import AuditLogOut, AuditLogPage from app.schemas.audit from app.schemas.audit import AuditLogOut, AuditLogPage + +# Import from app.services.audit_query_service from app.services.audit_query_service import ( list_distinct_actions, list_distinct_entity_types, list_logs, ) +# Assign router = APIRouter(prefix="/audit-logs", tags=["audit"]) router = APIRouter(prefix="/audit-logs", tags=["audit"]) +# Apply the @router.get decorator @router.get("", response_model=AuditLogPage) +# Define function list_audit_logs def list_audit_logs( + # Entry: user_id user_id: Optional[str] = Query(None, description="Filter by user ID"), + # Entry: action action: Optional[str] = Query(None, 
description="Filter by action type"), + # Entry: entity_type entity_type: Optional[str] = Query(None, description="Filter by entity type"), + # Entry: start_date start_date: Optional[datetime] = Query(None, description="Filter by start date"), + # Entry: end_date end_date: Optional[datetime] = Query(None, description="Filter by end date"), + # Entry: offset offset: int = Query(0, ge=0, description="Number of records to skip"), + # Entry: limit limit: int = Query(50, ge=1, le=100, description="Max records to return"), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> AuditLogPage: """Return paginated audit logs with optional filters. **Requires admin role.** """ + # Assign result = list_logs( result = list_logs( db, + # Keyword argument: user_id user_id=user_id, + # Keyword argument: action action=action, + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: start_date start_date=start_date, + # Keyword argument: end_date end_date=end_date, + # Keyword argument: offset offset=offset, + # Keyword argument: limit limit=limit, ) + # Return AuditLogPage( return AuditLogPage( + # Keyword argument: items items=[AuditLogOut(**item) for item in result["items"]], + # Keyword argument: total total=result["total"], + # Keyword argument: offset offset=result["offset"], + # Keyword argument: limit limit=result["limit"], ) +# Apply the @router.get decorator @router.get("/actions", response_model=list[str]) +# Define function list_actions def list_actions( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> list[str]: """Return a list of distinct action types in the audit log. 
**Requires admin role.** """ + # Return list_distinct_actions(db) return list_distinct_actions(db) +# Apply the @router.get decorator @router.get("/entity-types", response_model=list[str]) +# Define function list_entity_types def list_entity_types( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> list[str]: """Return a list of distinct entity types in the audit log. **Requires admin role.** """ + # Return list_distinct_entity_types(db) return list_distinct_entity_types(db) diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index f4d1e90..ab4fcce 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -7,31 +7,68 @@ the token in the body for backwards compatibility and for clients that cannot use cookies (e.g. Swagger UI). """ +# Import os import os +# Import APIRouter, Cookie, Depends, Request, Response from fastapi from fastapi import APIRouter, Cookie, Depends, Request, Response + +# Import OAuth2PasswordRequestForm from fastapi.security from fastapi.security import OAuth2PasswordRequestForm + +# Import jwt (PyJWT) +import jwt + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session -from jose import jwt, JWTError +# Import blacklist_token, create_access_token, verify_pa... 
from app.auth +from app.auth import blacklist_token, create_access_token, verify_password -from app.auth import create_access_token, blacklist_token, verify_password +# Import settings from app.config from app.config import settings + +# Import get_db from app.database from app.database import get_db + +# Import get_current_user from app.dependencies.auth from app.dependencies.auth import get_current_user + +# Import BusinessRuleViolation, PermissionViolation from app.domain.errors from app.domain.errors import BusinessRuleViolation, PermissionViolation + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import limiter from app.limiter from app.limiter import limiter + +# Import resolve_client_ip from app.middleware.request_context from app.middleware.request_context import resolve_client_ip + +# Import User from app.models.user from app.models.user import User -from app.services.auth_service import ( - _DUMMY_HASH, - change_password as auth_change_password, -) -from app.services.audit_service import log_action + +# Import TokenResponse, UserOut from app.schemas.auth from app.schemas.auth import TokenResponse, UserOut + +# Import PasswordChange from app.schemas.user from app.schemas.user import PasswordChange +# Import log_action from app.services.audit_service +from app.services.audit_service import log_action + +# Import from app.services.auth_service +from app.services.auth_service import ( + _DUMMY_HASH, +) + +# Import from app.services.auth_service +from app.services.auth_service import ( + change_password as auth_change_password, +) + +# Assign router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"]) # SECURE_COOKIES desacopla la seguridad de la cookie del entorno de ejecucion. 
@@ -47,111 +84,182 @@ else: # "auto" — activo solo si AEGIS_ENV=production _COOKIE_NAME = "aegis_token" +# Apply the @router.post decorator @router.post("/login", response_model=TokenResponse) +# Apply the @limiter.limit decorator @limiter.limit("5/minute") +# Define function login def login( + # Entry: request request: Request, + # Entry: response response: Response, + # Entry: form_data form_data: OAuth2PasswordRequestForm = Depends(), + # Entry: db db: Session = Depends(get_db), -): +) -> TokenResponse: """Authenticate a user and return a JWT access token. Rate-limited to **5 attempts per minute per IP**. Failed and successful logins are recorded in the audit log (SEC-009). """ + # Assign user = db.query(User).filter(User.username == form_data.username).first() user = db.query(User).filter(User.username == form_data.username).first() + # Assign target_hash = user.hashed_password if user else _DUMMY_HASH target_hash = user.hashed_password if user else _DUMMY_HASH + # Assign password_valid = verify_password(form_data.password, target_hash) password_valid = verify_password(form_data.password, target_hash) + # Assign ip = resolve_client_ip(request) ip = resolve_client_ip(request) + # Check: user is None or not password_valid if user is None or not password_valid: + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, user.id if user else None, + # Literal argument value "LOGIN_FAILED", + # Literal argument value "auth", + # Literal argument value None, + # Keyword argument: details details={ + # Literal argument value "username": form_data.username, + # Literal argument value "ip": ip, + # Literal argument value "reason": "invalid_credentials", }, + # Keyword argument: ip_address ip_address=ip, ) + # Call uow.commit() uow.commit() + # Raise BusinessRuleViolation raise BusinessRuleViolation("Incorrect username or password") + # Check: not user.is_active if not user.is_active: + # Raise PermissionViolation raise 
PermissionViolation("Account is disabled. Contact an administrator.") + # Assign access_token = create_access_token(data={"sub": user.username}) access_token = create_access_token(data={"sub": user.username}) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, user.id, + # Literal argument value "LOGIN_SUCCESS", + # Literal argument value "auth", str(user.id), + # Keyword argument: details details={"username": user.username, "ip": ip}, + # Keyword argument: ip_address ip_address=ip, ) + # Call uow.commit() uow.commit() + # Call response.set_cookie() response.set_cookie( + # Keyword argument: key key=_COOKIE_NAME, + # Keyword argument: value value=access_token, + # Keyword argument: httponly httponly=True, + # Keyword argument: secure secure=_IS_HTTPS, + # Keyword argument: samesite samesite="strict", + # Keyword argument: max_age max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + # Keyword argument: path path="/", ) + # Return TokenResponse(access_token=access_token) return TokenResponse(access_token=access_token) +# Apply the @router.post decorator @router.post("/logout") +# Define function logout def logout( + # Entry: request request: Request, + # Entry: response response: Response, + # Entry: aegis_token aegis_token: str | None = Cookie(None), -): +) -> dict: """Clear the authentication cookie and revoke the current token.""" + # Assign bearer = ( bearer = ( request.headers.get("Authorization") or request.headers.get("authorization") or "" ) + # Assign bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip() bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip() + # Assign seen = set() seen: set[str] = set() + # Iterate over (aegis_token, bearer) for raw in (aegis_token, bearer): + # Check: not raw or raw in seen if not raw or raw in seen: + # Skip to the next loop iteration continue + # Call seen.add() seen.add(raw) + # Attempt the following; catch errors below try: + # Assign payload = 
jwt.decode( payload = jwt.decode( raw, settings.SECRET_KEY, + # Keyword argument: algorithms algorithms=[settings.ALGORITHM], ) + # Assign jti = payload.get("jti") jti = payload.get("jti") + # Assign exp = payload.get("exp", 0) exp = payload.get("exp", 0) + # Check: jti if jti: + # Call blacklist_token() blacklist_token(jti, float(exp)) - except JWTError: + # Handle any JWT validation error during logout (token may be expired or malformed) + except jwt.exceptions.InvalidTokenError: + # Intentional no-op placeholder pass + # Call response.delete_cookie() response.delete_cookie( + # Keyword argument: key key=_COOKIE_NAME, + # Keyword argument: httponly httponly=True, + # Keyword argument: secure secure=_IS_HTTPS, + # Keyword argument: samesite samesite="strict", + # Keyword argument: path path="/", ) + # Return {"detail": "Logged out"} return {"detail": "Logged out"} @@ -207,25 +315,38 @@ def refresh_token( @router.get("/me", response_model=UserOut) -def read_current_user(current_user: User = Depends(get_current_user)): +# Define function read_current_user +def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut: """Return the profile of the currently authenticated user.""" + # Return current_user return current_user +# Apply the @router.post decorator @router.post("/change-password") +# Define function change_password def change_password( + # Entry: body body: PasswordChange, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Change the current user's password.""" + # Call auth_change_password() auth_change_password( db, current_user, + # Keyword argument: current_password current_password=body.current_password, + # Keyword argument: new_password new_password=body.new_password, ) + # Open context manager with UnitOfWork(db) as uow: + # Call uow.commit() uow.commit() + # Return {"detail": "Password changed successfully"} return {"detail": "Password changed 
successfully"} diff --git a/backend/app/routers/campaigns.py b/backend/app/routers/campaigns.py index 5f26754..cdf0667 100644 --- a/backend/app/routers/campaigns.py +++ b/backend/app/routers/campaigns.py @@ -1,80 +1,169 @@ """Campaign endpoints — CRUD, test management, activation, and auto-generation. -Provides comprehensive campaign lifecycle management including -test ordering, progress tracking, and threat actor integration. +Provides comprehensive campaign lifecycle management including test ordering, +progress tracking, and threat actor integration. """ +# Import logging import logging + +# Import uuid import uuid from datetime import datetime from typing import Optional +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query -from sqlalchemy.orm import Session + +# Import BaseModel, Field from pydantic from pydantic import BaseModel, Field +# Import Session from sqlalchemy.orm +from sqlalchemy.orm import Session + +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_any_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_any_role + +# Import UnitOfWork from app.domain.unit_of_work +from app.domain.unit_of_work import UnitOfWork + +# Import User from app.models.user from app.models.user import User from app.models.campaign import Campaign, CampaignTest from app.models.test import Test from app.services.campaign_service import generate_campaign_from_threat_actor from app.services.campaign_crud_service import ( add_test_to_campaign as crud_add_test, - activate_campaign as crud_activate, +) + +# Import from app.services.campaign_crud_service +from app.services.campaign_crud_service import ( complete_campaign as crud_complete, +) + +# Import from app.services.campaign_crud_service +from app.services.campaign_crud_service import ( create_campaign as crud_create, delete_campaign as crud_delete, get_campaign_detail as crud_get_detail, +) 
+ +# Import from app.services.campaign_crud_service +from app.services.campaign_crud_service import ( get_campaign_history as crud_get_history, +) + +# Import from app.services.campaign_crud_service +from app.services.campaign_crud_service import ( get_campaign_progress_data as crud_get_progress, +) + +# Import from app.services.campaign_crud_service +from app.services.campaign_crud_service import ( list_campaigns as crud_list, +) + +# Import from app.services.campaign_crud_service +from app.services.campaign_crud_service import ( remove_test_from_campaign as crud_remove_test, +) + +# Import from app.services.campaign_crud_service +from app.services.campaign_crud_service import ( schedule_campaign as crud_schedule, +) + +# Import from app.services.campaign_crud_service +from app.services.campaign_crud_service import ( serialize_campaign, +) + +# Import from app.services.campaign_crud_service +from app.services.campaign_crud_service import ( update_campaign as crud_update, ) -from app.domain.unit_of_work import UnitOfWork -from app.services.audit_service import log_action + +# Import generate_campaign_from_threat_actor from app.services.campaign_service +from app.services.campaign_service import generate_campaign_from_threat_actor + +# Import notify_role from app.services.notification_service from app.services.notification_service import notify_role from app.services.webhook_service import dispatch_webhook +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign router = APIRouter(prefix="/campaigns", tags=["campaigns"]) router = APIRouter(prefix="/campaigns", tags=["campaigns"]) # ── Pydantic schemas ───────────────────────────────────────────────── class CampaignCreate(BaseModel): + """Payload for creating a new campaign.""" + + # name: str name: str + # Assign description = None description: Optional[str] = None + # Assign type = "custom" type: str = "custom" + # Assign threat_actor_id = None threat_actor_id: Optional[str] = 
None + # Assign target_platform = None target_platform: Optional[str] = None + # Assign tags = Field(default_factory=list) tags: Optional[list[str]] = Field(default_factory=list) + # Assign scheduled_at = None scheduled_at: Optional[str] = None start_date: Optional[str] = None # ISO date — campaign won't activate before this + +# Define class CampaignUpdate class CampaignUpdate(BaseModel): + """Payload for updating an existing campaign's metadata.""" + + # Assign name = None name: Optional[str] = None + # Assign description = None description: Optional[str] = None + # Assign type = None type: Optional[str] = None + # Assign target_platform = None target_platform: Optional[str] = None + # Assign tags = None tags: Optional[list[str]] = None + # Assign scheduled_at = None scheduled_at: Optional[str] = None start_date: Optional[str] = None # ISO date — can be updated while still in draft + +# Define class AddTestPayload class AddTestPayload(BaseModel): + """Payload for adding a test to a campaign.""" + + # test_id: str test_id: str + # Assign order_index = None order_index: Optional[int] = None + # Assign depends_on = None depends_on: Optional[str] = None + # Assign phase = None phase: Optional[str] = None +# Define class SchedulePayload class SchedulePayload(BaseModel): + """Payload for scheduling or rescheduling a campaign run.""" + + # is_recurring: bool is_recurring: bool + # Assign recurrence_pattern = None # weekly, monthly, quarterly recurrence_pattern: Optional[str] = None # weekly, monthly, quarterly + # Assign next_run_at = None next_run_at: Optional[str] = None @@ -83,24 +172,54 @@ class SchedulePayload(BaseModel): # --------------------------------------------------------------------------- @router.get("") +# Define function list_campaigns def list_campaigns( + # Entry: type type: Optional[str] = Query(None), + # Entry: status status: Optional[str] = Query(None), + # Entry: threat_actor_id threat_actor_id: Optional[str] = Query(None), + # Entry: search 
search: Optional[str] = Query(None), + # Entry: offset offset: int = Query(0, ge=0), + # Entry: limit limit: int = Query(50, ge=1, le=200), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """List campaigns with optional filters and pagination.""" +) -> list: + """List campaigns with optional filters and pagination. + + Args: + type (Optional[str]): Filter by campaign type (e.g. ``custom``, ``threat_actor``). + status (Optional[str]): Filter by campaign status (e.g. ``draft``, ``active``). + threat_actor_id (Optional[str]): Filter campaigns linked to a specific threat actor. + search (Optional[str]): Free-text search against campaign name. + offset (int): Number of records to skip for pagination. + limit (int): Maximum number of records to return. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + list: Serialised list of campaign summary dicts. + """ + # Return crud_list( return crud_list( db, + # Keyword argument: type type=type, + # Keyword argument: status status=status, + # Keyword argument: threat_actor_id threat_actor_id=threat_actor_id, + # Keyword argument: search search=search, + # Keyword argument: offset offset=offset, + # Keyword argument: limit limit=limit, ) @@ -110,36 +229,64 @@ def list_campaigns( # --------------------------------------------------------------------------- @router.post("", status_code=201) +# Define function create_campaign def create_campaign( + # Entry: payload payload: CampaignCreate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Create a new campaign.""" +) -> dict: + """Create a new campaign. + + Args: + payload (CampaignCreate): Fields for the new campaign (name, type, threat actor, etc.). + db (Session): SQLAlchemy database session. 
+ current_user (User): Authenticated red_lead or blue_lead creating the campaign. + + Returns: + dict: Serialised representation of the newly created campaign. + """ + # Open context manager with UnitOfWork(db) as uow: + # Assign result = crud_create( result = crud_create( db, + # Keyword argument: creator_id creator_id=current_user.id, + # Keyword argument: name name=payload.name, + # Keyword argument: description description=payload.description, + # Keyword argument: type type=payload.type, + # Keyword argument: threat_actor_id threat_actor_id=payload.threat_actor_id, + # Keyword argument: target_platform target_platform=payload.target_platform, + # Keyword argument: tags tags=payload.tags, + # Keyword argument: scheduled_at scheduled_at=payload.scheduled_at, start_date=payload.start_date, ) campaign_id = result["id"] log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="create_campaign", + # Keyword argument: entity_type entity_type="campaign", entity_id=campaign_id, details={"name": payload.name, "type": payload.type}, ) + # Call uow.commit() uow.commit() + # Return result return result @@ -148,12 +295,26 @@ def create_campaign( # --------------------------------------------------------------------------- @router.get("/{campaign_id}") +# Define function get_campaign def get_campaign( + # Entry: campaign_id campaign_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get detailed campaign info including tests and progress.""" +) -> dict: + """Get detailed campaign info including tests and progress. + + Args: + campaign_id (str): UUID string of the campaign to retrieve. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + dict: Campaign detail including associated tests and progress metrics. 
+ """ + # Return crud_get_detail(db, campaign_id) return crud_get_detail(db, campaign_id) @@ -162,32 +323,60 @@ def get_campaign( # --------------------------------------------------------------------------- @router.patch("/{campaign_id}") +# Define function update_campaign def update_campaign( + # Entry: campaign_id campaign_id: str, + # Entry: payload payload: CampaignUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Update a campaign. Only allowed in draft or active state.""" +) -> dict: + """Update a campaign. Only allowed in draft or active state. + + Args: + campaign_id (str): UUID string of the campaign to update. + payload (CampaignUpdate): Partial update payload; only set fields are applied. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead performing the update. + + Returns: + dict: Serialised representation of the updated campaign. 
+ """ + # Assign update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True) + # Open context manager with UnitOfWork(db) as uow: + # Assign result = crud_update( result = crud_update( db, campaign_id, + # Keyword argument: updater_id updater_id=current_user.id, + # Keyword argument: updater_role updater_role=current_user.role, **update_data, ) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_campaign", + # Keyword argument: entity_type entity_type="campaign", + # Keyword argument: entity_id entity_id=campaign_id, + # Keyword argument: details details={"updated_fields": list(update_data.keys())}, ) + # Call uow.commit() uow.commit() + # Return result return result @@ -227,22 +416,44 @@ def delete_campaign( # --------------------------------------------------------------------------- @router.post("/{campaign_id}/tests") +# Define function add_test_to_campaign def add_test_to_campaign( + # Entry: campaign_id campaign_id: str, + # Entry: payload payload: AddTestPayload, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Add a test to a campaign with optional ordering and dependency.""" +) -> dict: + """Add a test to a campaign with optional ordering and dependency. + + Args: + campaign_id (str): UUID string of the target campaign. + payload (AddTestPayload): Test ID plus optional order index, dependency, and phase. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead adding the test. + + Returns: + dict: The created campaign-test association record. 
+ """ + # Open context manager with UnitOfWork(db) as uow: + # Assign result = crud_add_test( result = crud_add_test( db, campaign_id, + # Keyword argument: test_id test_id=payload.test_id, + # Keyword argument: order_index order_index=payload.order_index, + # Keyword argument: depends_on depends_on=payload.depends_on, + # Keyword argument: phase phase=payload.phase, ) + # Call uow.commit() uow.commit() return result @@ -253,16 +464,35 @@ def add_test_to_campaign( # --------------------------------------------------------------------------- @router.delete("/{campaign_id}/tests/{campaign_test_id}") +# Define function remove_test_from_campaign def remove_test_from_campaign( + # Entry: campaign_id campaign_id: str, + # Entry: campaign_test_id campaign_test_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Remove a test from a campaign.""" +) -> dict: + """Remove a test from a campaign. + + Args: + campaign_id (str): UUID string of the campaign. + campaign_test_id (str): UUID string of the campaign-test association to remove. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead removing the test. + + Returns: + dict: Confirmation message with key ``detail``. 
+ """ + # Open context manager with UnitOfWork(db) as uow: + # Call crud_remove_test() crud_remove_test(db, campaign_id, campaign_test_id) + # Call uow.commit() uow.commit() + # Return {"detail": "Test removed from campaign"} return {"detail": "Test removed from campaign"} @@ -271,10 +501,13 @@ def remove_test_from_campaign( # --------------------------------------------------------------------------- @router.post("/{campaign_id}/activate") +# Define function activate_campaign def activate_campaign( + # Entry: campaign_id campaign_id: str, force: bool = Query(False, description="Activate even if start_date is in the future"), db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Activate a campaign, moving it from draft to active. @@ -303,25 +536,41 @@ def activate_campaign( ) with UnitOfWork(db) as uow: + # Assign campaign = crud_activate(db, campaign_id) campaign = crud_activate(db, campaign_id) + # Call notify_role() notify_role( db, + # Keyword argument: role role="red_tech", + # Keyword argument: type type="campaign_activated", + # Keyword argument: title title="Campaign activated", + # Keyword argument: message message=f'Campaign "{campaign.name}" has been activated.', + # Keyword argument: entity_type entity_type="campaign", + # Keyword argument: entity_id entity_id=campaign.id, ) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="activate_campaign", + # Keyword argument: entity_type entity_type="campaign", + # Keyword argument: entity_id entity_id=campaign.id, + # Keyword argument: details details={"name": campaign.name}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(campaign) # Create Jira tickets for campaign and tests at activation time (non-fatal). 
@@ -359,26 +608,50 @@ def activate_campaign( # --------------------------------------------------------------------------- @router.post("/{campaign_id}/complete") +# Define function complete_campaign def complete_campaign( + # Entry: campaign_id campaign_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "admin")), -): - """Mark a campaign as completed.""" +) -> dict: + """Mark a campaign as completed. + + Args: + campaign_id (str): UUID string of the campaign to complete. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or admin completing the campaign. + + Returns: + dict: Serialised representation of the completed campaign. + """ + # Open context manager with UnitOfWork(db) as uow: + # Assign campaign = crud_complete(db, campaign_id) campaign = crud_complete(db, campaign_id) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="complete_campaign", + # Keyword argument: entity_type entity_type="campaign", + # Keyword argument: entity_id entity_id=campaign.id, + # Keyword argument: details details={"name": campaign.name}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(campaign) dispatch_webhook("campaign.completed", {"campaign_id": str(campaign.id), "name": campaign.name}) + # Return serialize_campaign(db, campaign) return serialize_campaign(db, campaign) @@ -387,12 +660,26 @@ def complete_campaign( # --------------------------------------------------------------------------- @router.get("/{campaign_id}/progress") +# Define function get_campaign_progress_endpoint def get_campaign_progress_endpoint( + # Entry: campaign_id campaign_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get progress statistics for a campaign.""" +) -> 
dict: + """Get progress statistics for a campaign. + + Args: + campaign_id (str): UUID string of the campaign. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + dict: Progress breakdown including counts by test state and overall percentage. + """ + # Return crud_get_progress(db, campaign_id) return crud_get_progress(db, campaign_id) @@ -405,16 +692,27 @@ class GenerateFromActorPayload(BaseModel): @router.post("/from-threat-actor/{actor_id}", status_code=201) +# Define function generate_campaign_from_actor def generate_campaign_from_actor( + # Entry: actor_id actor_id: str, payload: GenerateFromActorPayload = GenerateFromActorPayload(), db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Auto-generate a campaign from a threat actor's uncovered techniques. Creates tests from the best available templates and orders them by kill chain phase. + + Args: + actor_id (str): UUID string of the threat actor to generate a campaign for. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead requesting the generation. + + Returns: + dict: Serialised representation of the newly generated campaign. 
""" start_date_parsed = ( datetime.fromisoformat(payload.start_date) if payload.start_date else None @@ -426,17 +724,26 @@ def generate_campaign_from_actor( start_date=start_date_parsed, ) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="generate_campaign", + # Keyword argument: entity_type entity_type="campaign", + # Keyword argument: entity_id entity_id=campaign.id, + # Keyword argument: details details={"actor_id": actor_id, "campaign_name": campaign.name}, ) + # Call uow.commit() uow.commit() + # Return serialize_campaign(db, campaign) return serialize_campaign(db, campaign) @@ -445,41 +752,74 @@ def generate_campaign_from_actor( # --------------------------------------------------------------------------- @router.patch("/{campaign_id}/schedule") +# Define function schedule_campaign def schedule_campaign( + # Entry: campaign_id campaign_id: str, + # Entry: payload payload: SchedulePayload, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Configure or update the recurrence schedule for a campaign. Only the campaign creator or admin can change scheduling. + + Args: + campaign_id (str): UUID string of the campaign to schedule. + payload (SchedulePayload): Recurrence flag, pattern, and next run timestamp. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead (must be owner or admin). + + Returns: + dict: Serialised representation of the campaign with updated schedule fields. 
""" + # Open context manager with UnitOfWork(db) as uow: + # Assign campaign = crud_schedule( campaign = crud_schedule( db, campaign_id, + # Keyword argument: owner_id owner_id=current_user.id, + # Keyword argument: owner_role owner_role=current_user.role, + # Keyword argument: is_recurring is_recurring=payload.is_recurring, + # Keyword argument: recurrence_pattern recurrence_pattern=payload.recurrence_pattern, + # Keyword argument: next_run_at next_run_at=payload.next_run_at, ) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="schedule_campaign", + # Keyword argument: entity_type entity_type="campaign", + # Keyword argument: entity_id entity_id=campaign.id, + # Keyword argument: details details={ + # Literal argument value "is_recurring": campaign.is_recurring, + # Literal argument value "recurrence_pattern": campaign.recurrence_pattern, + # Literal argument value "next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None, }, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(campaign) + # Return serialize_campaign(db, campaign) return serialize_campaign(db, campaign) @@ -488,12 +828,26 @@ def schedule_campaign( # --------------------------------------------------------------------------- @router.get("/{campaign_id}/history") +# Define function get_campaign_history def get_campaign_history( + # Entry: campaign_id campaign_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """List all child campaigns (execution history) of a recurring campaign.""" +) -> list: + """List all child campaigns (execution history) of a recurring campaign. + + Args: + campaign_id (str): UUID string of the parent recurring campaign. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. 
+ + Returns: + list: Serialised list of child campaign dicts ordered by creation date. + """ + # Return crud_get_history(db, campaign_id) return crud_get_history(db, campaign_id) diff --git a/backend/app/routers/compliance.py b/backend/app/routers/compliance.py index 9183f29..42b1d02 100644 --- a/backend/app/routers/compliance.py +++ b/backend/app/routers/compliance.py @@ -1,32 +1,45 @@ """Compliance endpoints — framework status, reports, and gap analysis. -Thin HTTP adapter: delegates all data logic to compliance_service. - +Thin HTTP adapter that delegates all data logic to compliance_service. Provides compliance posture assessment by mapping MITRE ATT&CK technique coverage to compliance framework controls. """ +# Import APIRouter, Depends from fastapi from fastapi import APIRouter, Depends + +# Import StreamingResponse from fastapi.responses from fastapi.responses import StreamingResponse + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_role + +# Import User from app.models.user from app.models.user import User -from app.services.compliance_service import ( - list_frameworks, - get_framework_status, - build_framework_report_csv, - get_framework_gaps, -) + +# Import from app.services.compliance_import_service from app.services.compliance_import_service import ( - import_nist_800_53_mappings, import_cis_controls_v8_mappings, import_dora_mappings, import_iso_27001_mappings, import_iso_42001_mappings, ) +# Import from app.services.compliance_service +from app.services.compliance_service import ( + build_framework_report_csv, + get_framework_gaps, + get_framework_status, + list_frameworks, +) + +# Assign router = APIRouter(prefix="/compliance", tags=["compliance"]) router = APIRouter(prefix="/compliance", tags=["compliance"]) @@ -34,11 +47,23 @@ 
router = APIRouter(prefix="/compliance", tags=["compliance"]) @router.get("/frameworks") +# Define function list_frameworks_endpoint def list_frameworks_endpoint( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """List all available compliance frameworks.""" +) -> list: + """List all available compliance frameworks. + + Args: + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + list: List of framework summary dicts containing id, name, and control counts. + """ + # Return list_frameworks(db) return list_frameworks(db) @@ -46,12 +71,26 @@ def list_frameworks_endpoint( @router.get("/frameworks/{framework_id}/status") +# Define function framework_status def framework_status( + # Entry: framework_id framework_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get compliance status for each control in a framework.""" +) -> dict: + """Get compliance status for each control in a framework. + + Args: + framework_id (str): Identifier of the compliance framework (e.g. ``nist-800-53``). + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + dict: Mapping of control IDs to their coverage status and linked techniques. + """ + # Return get_framework_status(db, framework_id) return get_framework_status(db, framework_id) @@ -59,12 +98,26 @@ def framework_status( @router.get("/frameworks/{framework_id}/report") +# Define function framework_report def framework_report( + # Entry: framework_id framework_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get the full compliance report (same as status but marked as report).""" +) -> dict: + """Get the full compliance report (same as status but marked as report). 
+ + Args: + framework_id (str): Identifier of the compliance framework. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + dict: Full compliance report with per-control coverage details. + """ + # Return get_framework_status(db, framework_id) return get_framework_status(db, framework_id) @@ -72,17 +125,35 @@ def framework_report( @router.get("/frameworks/{framework_id}/report/csv") +# Define function framework_report_csv def framework_report_csv( + # Entry: framework_id framework_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Export compliance report as CSV.""" +) -> StreamingResponse: + """Export compliance report as CSV. + + Args: + framework_id (str): Identifier of the compliance framework to export. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + StreamingResponse: CSV file attachment with compliance coverage data. + """ + # csv_bytes, filename = build_framework_report_csv(db, framework_id) csv_bytes, filename = build_framework_report_csv(db, framework_id) + # Return StreamingResponse( return StreamingResponse( iter([csv_bytes]), + # Keyword argument: media_type media_type="text/csv", + # Keyword argument: headers headers={ + # Literal argument value "Content-Disposition": f"attachment; filename={filename}", }, ) @@ -92,12 +163,26 @@ def framework_report_csv( @router.get("/frameworks/{framework_id}/gaps") +# Define function framework_gaps def framework_gaps( + # Entry: framework_id framework_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get controls with techniques that are not adequately covered.""" +) -> dict: + """Get controls with techniques that are not adequately covered. 
+ + Args: + framework_id (str): Identifier of the compliance framework to analyse. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + dict: Controls flagged as gaps, with linked technique IDs and coverage ratios. + """ + # Return get_framework_gaps(db, framework_id) return get_framework_gaps(db, framework_id) @@ -105,22 +190,49 @@ def framework_gaps( @router.post("/import/nist-800-53") +# Define function import_nist def import_nist( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): - """Import NIST 800-53 Rev 5 mappings (admin only).""" +) -> dict: + """Import NIST 800-53 Rev 5 mappings (admin only). + + Args: + db (Session): SQLAlchemy database session. + current_user (User): Authenticated admin user. + + Returns: + dict: Import result with counts of created and updated control mappings. + """ + # Assign result = import_nist_800_53_mappings(db) result = import_nist_800_53_mappings(db) + # Return result return result +# Apply the @router.post decorator @router.post("/import/cis-controls-v8") +# Define function import_cis def import_cis( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): - """Import CIS Controls v8 mappings (admin only).""" +) -> dict: + """Import CIS Controls v8 mappings (admin only). + + Args: + db (Session): SQLAlchemy database session. + current_user (User): Authenticated admin user. + + Returns: + dict: Import result with counts of created and updated control mappings. 
+ """ + # Assign result = import_cis_controls_v8_mappings(db) result = import_cis_controls_v8_mappings(db) + # Return result return result diff --git a/backend/app/routers/d3fend.py b/backend/app/routers/d3fend.py index 4afa50a..521fb30 100644 --- a/backend/app/routers/d3fend.py +++ b/backend/app/routers/d3fend.py @@ -1,26 +1,47 @@ """D3FEND endpoints — defensive technique listings, mappings, and import trigger.""" +# Import logging import logging + +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_role + +# Import User from app.models.user from app.models.user import User + +# Import from app.services.d3fend_import_service from app.services.d3fend_import_service import ( - import_d3fend_techniques, import_d3fend_mappings, + import_d3fend_techniques, ) + +# Import from app.services.d3fend_query_service +from app.services.d3fend_query_service import ( + get_defenses_for_attack_technique, + list_d3fend_tactics, +) + +# Import from app.services.d3fend_query_service from app.services.d3fend_query_service import ( list_defensive_techniques as list_defensive_techniques_svc, - list_d3fend_tactics, - get_defenses_for_attack_technique, ) +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign router = APIRouter(prefix="/d3fend", tags=["d3fend"]) router = APIRouter(prefix="/d3fend", tags=["d3fend"]) @@ -29,15 +50,23 @@ router = APIRouter(prefix="/d3fend", tags=["d3fend"]) # --------------------------------------------------------------------------- @router.get("") +# Define function list_defensive_techniques def list_defensive_techniques( + # Entry: tactic tactic: 
Optional[str] = Query(None), + # Entry: search search: Optional[str] = Query(None), + # Entry: offset offset: int = Query(0, ge=0), + # Entry: limit limit: int = Query(50, ge=1, le=200), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """List all D3FEND defensive techniques with optional filters.""" + # Return list_defensive_techniques_svc( return list_defensive_techniques_svc( db, tactic=tactic, search=search, offset=offset, limit=limit ) @@ -48,11 +77,15 @@ def list_defensive_techniques( # --------------------------------------------------------------------------- @router.get("/tactics") +# Define function list_d3fend_tactics_endpoint def list_d3fend_tactics_endpoint( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """Return a list of all D3FEND tactics with counts.""" + # Return list_d3fend_tactics(db) return list_d3fend_tactics(db) @@ -61,12 +94,17 @@ def list_d3fend_tactics_endpoint( # --------------------------------------------------------------------------- @router.get("/for-technique/{mitre_id}") +# Define function get_defenses_for_attack_technique_endpoint def get_defenses_for_attack_technique_endpoint( + # Entry: mitre_id mitre_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" + # Return get_defenses_for_attack_technique(db, mitre_id) return get_defenses_for_attack_technique(db, mitre_id) @@ -75,15 +113,23 @@ def get_defenses_for_attack_technique_endpoint( # --------------------------------------------------------------------------- @router.post("/import") +# Define function trigger_d3fend_import def trigger_d3fend_import( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = 
Depends(require_role("admin")), -): +) -> dict: """Import D3FEND techniques and ATT&CK mappings. Admin only.""" + # Assign tech_result = import_d3fend_techniques(db) tech_result = import_d3fend_techniques(db) + # Assign mapping_result = import_d3fend_mappings(db) mapping_result = import_d3fend_mappings(db) + # Return { return { + # Literal argument value "techniques": tech_result, + # Literal argument value "mappings": mapping_result, } diff --git a/backend/app/routers/data_sources.py b/backend/app/routers/data_sources.py index 95af5cf..db7d405 100644 --- a/backend/app/routers/data_sources.py +++ b/backend/app/routers/data_sources.py @@ -5,16 +5,34 @@ Provides a centralized panel for managing all external data sources including sync triggers, enable/disable toggles, and statistics. """ -from fastapi import APIRouter, Depends -from pydantic import BaseModel -from sqlalchemy.orm import Session +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends from fastapi +from fastapi import APIRouter, Depends + +# Import BaseModel from pydantic +from pydantic import BaseModel + +# Import Session from sqlalchemy.orm +from sqlalchemy.orm import Session + +# Import get_db from app.database from app.database import get_db + +# Import require_role from app.dependencies.auth from app.dependencies.auth import require_role + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import User from app.models.user from app.models.user import User + +# Import log_action from app.services.audit_service from app.services.audit_service import log_action + +# Import from app.services.data_source_service from app.services.data_source_service import ( get_source_stats, list_sources, @@ -23,18 +41,21 @@ from app.services.data_source_service import ( update_source, ) - # --------------------------------------------------------------------------- # Pydantic schemas for request validation # 
--------------------------------------------------------------------------- class DataSourceUpdate(BaseModel): """Payload for updating a data source — only allowed fields.""" + # Assign is_enabled = None is_enabled: Optional[bool] = None + # Assign sync_frequency = None sync_frequency: Optional[str] = None + # Assign config = None config: Optional[dict] = None +# Assign router = APIRouter(prefix="/data-sources", tags=["data-sources"]) router = APIRouter(prefix="/data-sources", tags=["data-sources"]) @@ -44,90 +65,137 @@ router = APIRouter(prefix="/data-sources", tags=["data-sources"]) @router.get("") +# Define function list_data_sources def list_data_sources( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> list: """List all registered data sources. **Requires** the ``admin`` role. """ + # Return list_sources(db) return list_sources(db) +# Apply the @router.patch decorator @router.patch("/{source_id}") +# Define function update_data_source def update_data_source( + # Entry: source_id source_id: str, + # Entry: body body: DataSourceUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Update a data source (enable/disable, change config). **Requires** the ``admin`` role. 
""" + # Assign update_data = body.model_dump(exclude_unset=True) update_data = body.model_dump(exclude_unset=True) + # Open context manager with UnitOfWork(db) as uow: + # Call update_source() update_source(db, source_id, **update_data) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_data_source", + # Keyword argument: entity_type entity_type="data_source", + # Keyword argument: entity_id entity_id=source_id, + # Keyword argument: details details={"updates": update_data}, ) + # Call uow.commit() uow.commit() + # Return {"message": "Data source updated", "id": source_id} return {"message": "Data source updated", "id": source_id} +# Apply the @router.post decorator @router.post("/{source_id}/sync") +# Define function sync_data_source def sync_data_source( + # Entry: source_id source_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Trigger sync/import for a specific data source. **Requires** the ``admin`` role. """ + # Return sync_source(db, source_id) return sync_source(db, source_id) +# Apply the @router.post decorator @router.post("/sync-all") +# Define function sync_all_data_sources def sync_all_data_sources( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Trigger sync for all enabled data sources (sequentially). **Requires** the ``admin`` role. 
""" + # Assign results = sync_all_sources(db) results = sync_all_sources(db) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="sync_all_data_sources", + # Keyword argument: entity_type entity_type="data_source", + # Keyword argument: entity_id entity_id=None, + # Keyword argument: details details={"results": results}, ) + # Call uow.commit() uow.commit() + # Return {"message": "Sync all complete", "results": results} return {"message": "Sync all complete", "results": results} +# Apply the @router.get decorator @router.get("/{source_id}/stats") +# Define function get_data_source_stats def get_data_source_stats( + # Entry: source_id source_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Get detailed statistics for a specific data source. **Requires** the ``admin`` role. """ + # Return get_source_stats(db, source_id) return get_source_stats(db, source_id) diff --git a/backend/app/routers/detection_rules.py b/backend/app/routers/detection_rules.py index c847db1..9578e57 100644 --- a/backend/app/routers/detection_rules.py +++ b/backend/app/routers/detection_rules.py @@ -6,36 +6,55 @@ Provides endpoints for browsing detection rules, querying rules by technique, and managing the template ↔ detection rule associations. 
""" +# Import uuid import uuid + +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import BaseModel from pydantic from pydantic import BaseModel + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db -from app.dependencies.auth import get_current_user, require_role, require_any_role -from app.models.user import User -from app.services.detection_rule_service import ( - list_rules, - get_rules_for_template, - auto_associate_rules, - get_rules_for_test, - evaluate_rule, -) +# Import get_current_user, require_any_role, require_role from app.dependencies.auth +from app.dependencies.auth import get_current_user, require_any_role, require_role + +# Import User from app.models.user +from app.models.user import User + +# Import from app.services.detection_rule_service +from app.services.detection_rule_service import ( + auto_associate_rules, + evaluate_rule, + get_rules_for_template, + get_rules_for_test, + list_rules, +) # ── Pydantic schemas for request validation ──────────────────────────── class DetectionRuleEvaluate(BaseModel): """Payload for evaluating a detection rule against a test.""" + # test_id: uuid.UUID test_id: uuid.UUID + # detection_rule_id: uuid.UUID detection_rule_id: uuid.UUID + # Assign triggered = None triggered: Optional[bool] = None + # Assign notes = None notes: Optional[str] = None +# Assign router = APIRouter(prefix="/detection-rules", tags=["detection-rules"]) router = APIRouter(prefix="/detection-rules", tags=["detection-rules"]) @@ -43,24 +62,40 @@ router = APIRouter(prefix="/detection-rules", tags=["detection-rules"]) @router.get("") +# Define function list_detection_rules def list_detection_rules( + # Entry: technique technique: Optional[str] = Query(None, description="Filter by MITRE technique ID"), + # Entry: source source: Optional[str] = 
Query(None, description="Filter by source (sigma, elastic, splunk, custom)"), + # Entry: severity severity: Optional[str] = Query(None), + # Entry: search search: Optional[str] = Query(None), + # Entry: offset offset: int = Query(0, ge=0), + # Entry: limit limit: int = Query(50, ge=1, le=200), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """List detection rules with optional filters and pagination.""" + # Return list_rules( return list_rules( db, + # Keyword argument: technique technique=technique, + # Keyword argument: source source=source, + # Keyword argument: severity severity=severity, + # Keyword argument: search search=search, + # Keyword argument: offset offset=offset, + # Keyword argument: limit limit=limit, ) @@ -69,12 +104,17 @@ def list_detection_rules( @router.get("/for-template/{template_id}") +# Define function get_detection_rules_for_template def get_detection_rules_for_template( + # Entry: template_id template_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """Get detection rules associated with a test template.""" + # Return get_rules_for_template(db, template_id) return get_rules_for_template(db, template_id) @@ -82,16 +122,20 @@ def get_detection_rules_for_template( @router.post("/auto-associate") +# Define function auto_associate_detection_rules def auto_associate_detection_rules( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Auto-associate test templates with detection rules by MITRE technique ID. For each active template, find all active detection rules for the same technique and create associations. Rules with severity >= high are marked as primary. 
""" + # Return auto_associate_rules(db) return auto_associate_rules(db) @@ -99,16 +143,21 @@ def auto_associate_detection_rules( @router.get("/for-test/{test_id}") +# Define function get_detection_rules_for_test def get_detection_rules_for_test( + # Entry: test_id test_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """Get detection rules relevant to a test, along with their evaluation results. Finds rules by matching the test's technique_id to detection rules, and returns any existing evaluation results. """ + # Return get_rules_for_test(db, test_id) return get_rules_for_test(db, test_id) @@ -116,17 +165,27 @@ def get_detection_rules_for_test( @router.post("/evaluate") +# Define function evaluate_detection_rule def evaluate_detection_rule( + # Entry: payload payload: DetectionRuleEvaluate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), -): +) -> dict: """Save or update the evaluation result for a detection rule on a test.""" + # Return evaluate_rule( return evaluate_rule( db, + # Keyword argument: test_id test_id=payload.test_id, + # Keyword argument: detection_rule_id detection_rule_id=payload.detection_rule_id, + # Keyword argument: triggered triggered=payload.triggered, + # Keyword argument: notes notes=payload.notes, + # Keyword argument: evaluator_id evaluator_id=current_user.id, ) diff --git a/backend/app/routers/evidence.py b/backend/app/routers/evidence.py index 57ac4c1..8e68fe7 100644 --- a/backend/app/routers/evidence.py +++ b/backend/app/routers/evidence.py @@ -20,30 +20,54 @@ Access Control ``validated``, or ``rejected``. """ +# Import hashlib import hashlib import logging import os + +# Import uuid import uuid as _uuid from datetime import datetime from typing import Optional +# Import APIRouter, Depends, File, Form, Query, Request,... 
from fastapi from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db -from app.domain.unit_of_work import UnitOfWork + +# Import get_current_user from app.dependencies.auth from app.dependencies.auth import get_current_user + +# Import UnitOfWork from app.domain.unit_of_work +from app.domain.unit_of_work import UnitOfWork + +# Import limiter from app.limiter +from app.limiter import limiter + +# Import TeamSide from app.models.enums from app.models.enums import TeamSide + +# Import Evidence from app.models.evidence from app.models.evidence import Evidence + +# Import User from app.models.user from app.models.user import User + +# Import EvidenceOut from app.schemas.evidence from app.schemas.evidence import EvidenceOut + +# Import log_action from app.services.audit_service from app.services.audit_service import log_action + +# Import from app.services.evidence_service from app.services.evidence_service import ( + MAX_UPLOAD_SIZE, get_evidence_or_raise, get_test_or_raise, list_evidence_for_test, - MAX_UPLOAD_SIZE, validate_delete_permission, validate_file, validate_upload_permission, @@ -53,6 +77,7 @@ from app.storage import download_file, upload_file logger = logging.getLogger(__name__) +# Assign router = APIRouter(tags=["evidence"]) router = APIRouter(tags=["evidence"]) @@ -67,13 +92,21 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut: never needs direct access to MinIO. 
""" return EvidenceOut( + # Keyword argument: id id=evidence.id, + # Keyword argument: test_id test_id=evidence.test_id, + # Keyword argument: file_name file_name=evidence.file_name, + # Keyword argument: sha256_hash sha256_hash=evidence.sha256_hash, + # Keyword argument: uploaded_by uploaded_by=evidence.uploaded_by, + # Keyword argument: uploaded_at uploaded_at=evidence.uploaded_at, + # Keyword argument: team team=evidence.team, + # Keyword argument: notes notes=evidence.notes, download_url=f"/api/v1/evidence/{evidence.id}/file", ) @@ -85,30 +118,47 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut: @router.post( + # Literal argument value "/tests/{test_id}/evidence", + # Keyword argument: response_model response_model=EvidenceOut, + # Keyword argument: status_code status_code=status.HTTP_201_CREATED, ) +# Apply the @limiter.limit decorator @limiter.limit("10/minute") +# Define async function upload_evidence async def upload_evidence( + # Entry: request request: Request, + # Entry: test_id test_id: _uuid.UUID, + # Entry: file file: UploadFile = File(...), + # Entry: team team: TeamSide = Form(TeamSide.red), + # Entry: notes notes: Optional[str] = Form(None), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> EvidenceOut: """Upload a file as evidence for the given test. The ``team`` field (sent as form data) determines whether this is Red Team (attack) or Blue Team (detection) evidence. """ + # Assign test = get_test_or_raise(db, test_id) test = get_test_or_raise(db, test_id) + # Call validate_upload_permission() validate_upload_permission(test, team, current_user.role) + # Assign file_name = file.filename or "unnamed" file_name = file.filename or "unnamed" + # Assign content = await file.read(MAX_UPLOAD_SIZE + 1) content = await file.read(MAX_UPLOAD_SIZE + 1) + # Call validate_file() validate_file(file_name, len(content)) # Hash @@ -116,6 +166,7 @@ async def upload_evidence( # 4. 
Object key (sanitise filename to prevent path traversal in storage) safe_name = os.path.basename(file_name) + # Assign key = f"{test_id}/{_uuid.uuid4()}_{safe_name}" key = f"{test_id}/{_uuid.uuid4()}_{safe_name}" # 5. Upload to MinIO @@ -123,32 +174,53 @@ async def upload_evidence( # 6. Persist metadata and audit with UnitOfWork(db) as uow: + # Assign evidence = Evidence( evidence = Evidence( + # Keyword argument: test_id test_id=test_id, + # Keyword argument: file_name file_name=safe_name, + # Keyword argument: file_path file_path=key, + # Keyword argument: sha256_hash sha256_hash=sha256, + # Keyword argument: uploaded_by uploaded_by=current_user.id, uploaded_at=datetime.utcnow(), # set explicitly — DB column has no server default team=team, + # Keyword argument: notes notes=notes, ) + # Stage new record(s) for database insertion db.add(evidence) + # Flush changes to DB without committing the transaction db.flush() # Get evidence.id for audit + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="upload_evidence", + # Keyword argument: entity_type entity_type="evidence", + # Keyword argument: entity_id entity_id=evidence.id, + # Keyword argument: details details={ + # Literal argument value "file_name": safe_name, + # Literal argument value "sha256": sha256, + # Literal argument value "test_id": str(test_id), + # Literal argument value "team": team.value, }, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(evidence) # 7. 
Attach to Jira ticket if one exists (non-fatal) @@ -194,15 +266,23 @@ def _attach_evidence_to_jira( @router.get("/tests/{test_id}/evidence", response_model=list[EvidenceOut]) +# Define function list_evidence def list_evidence( + # Entry: test_id test_id: _uuid.UUID, + # Entry: team team: Optional[str] = Query(None, description="Filter by team: red or blue"), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list[EvidenceOut]: """List all evidences for a test, optionally filtered by team.""" + # Call get_test_or_raise() get_test_or_raise(db, test_id) + # Assign evidences = list_evidence_for_test(db, test_id, team=team) evidences = list_evidence_for_test(db, test_id, team=team) + # Return [_evidence_to_out(e) for e in evidences] return [_evidence_to_out(e) for e in evidences] @@ -212,13 +292,18 @@ def list_evidence( @router.get("/evidence/{evidence_id}", response_model=EvidenceOut) +# Define function get_evidence def get_evidence( + # Entry: evidence_id evidence_id: _uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), ): """Return evidence metadata. ``download_url`` is a backend proxy URL.""" evidence = get_evidence_or_raise(db, evidence_id) + # Return _evidence_to_out(evidence) return _evidence_to_out(evidence) @@ -265,11 +350,15 @@ def download_evidence_file( @router.delete("/evidence/{evidence_id}", status_code=status.HTTP_200_OK) +# Define function delete_evidence def delete_evidence( + # Entry: evidence_id evidence_id: _uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Delete an evidence record. 
Only allowed in editable states: @@ -277,24 +366,40 @@ def delete_evidence( - Blue evidence: ``blue_evaluating`` - No deletions in ``in_review``, ``validated``, ``rejected`` """ + # Assign evidence = get_evidence_or_raise(db, evidence_id) evidence = get_evidence_or_raise(db, evidence_id) + # Assign test = get_test_or_raise(db, evidence.test_id) test = get_test_or_raise(db, evidence.test_id) + # Call validate_delete_permission() validate_delete_permission(test, evidence, current_user.role, current_user.id) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="delete_evidence", + # Keyword argument: entity_type entity_type="evidence", + # Keyword argument: entity_id entity_id=evidence.id, + # Keyword argument: details details={ + # Literal argument value "file_name": evidence.file_name, + # Literal argument value "test_id": str(evidence.test_id), + # Literal argument value "team": evidence.team.value if evidence.team else None, }, ) + # Mark record for deletion on next commit db.delete(evidence) + # Call uow.commit() uow.commit() + # Return {"detail": "Evidence deleted"} return {"detail": "Evidence deleted"} diff --git a/backend/app/routers/heatmap.py b/backend/app/routers/heatmap.py index 18ec7f4..b4830bd 100644 --- a/backend/app/routers/heatmap.py +++ b/backend/app/routers/heatmap.py @@ -5,101 +5,169 @@ No business logic lives here — only request validation and response formatting. 
""" +# Import io import io + +# Import json import json + +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import StreamingResponse from fastapi.responses from fastapi.responses import StreamingResponse + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user from app.dependencies.auth from app.dependencies.auth import get_current_user + +# Import User from app.models.user from app.models.user import User + +# Import heatmap_service from app.services from app.services import heatmap_service +# Assign router = APIRouter(prefix="/heatmap", tags=["heatmap"]) router = APIRouter(prefix="/heatmap", tags=["heatmap"]) +# Apply the @router.get decorator @router.get("/coverage") +# Define function heatmap_coverage def heatmap_coverage( + # Entry: platforms platforms: Optional[str] = Query(None, description="Comma-separated platforms"), + # Entry: tactics tactics: Optional[str] = Query(None, description="Comma-separated tactics"), + # Entry: min_score min_score: int = Query(0, ge=0, le=100), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Coverage layer — score based on status_global of each technique.""" + # Return heatmap_service.build_coverage_layer( return heatmap_service.build_coverage_layer( db, platforms=platforms, tactics=tactics, min_score=min_score, ) +# Apply the @router.get decorator @router.get("/threat-actor/{actor_id}") +# Define function heatmap_threat_actor def heatmap_threat_actor( + # Entry: actor_id actor_id: str, + # Entry: platforms platforms: Optional[str] = Query(None), + # Entry: tactics tactics: Optional[str] = Query(None), + # Entry: min_score min_score: int = Query(0, ge=0, le=100), + # Entry: db db: Session = Depends(get_db), + # Entry: 
current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Threat actor layer — techniques used by an actor with coverage color.""" + # Return heatmap_service.build_threat_actor_layer( return heatmap_service.build_threat_actor_layer( db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score, ) +# Apply the @router.get decorator @router.get("/detection-rules") +# Define function heatmap_detection_rules def heatmap_detection_rules( + # Entry: platforms platforms: Optional[str] = Query(None), + # Entry: tactics tactics: Optional[str] = Query(None), + # Entry: min_score min_score: int = Query(0, ge=0, le=100), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Detection rules layer — score based on ratio of rules available vs total.""" + # Return heatmap_service.build_detection_rules_layer( return heatmap_service.build_detection_rules_layer( db, platforms=platforms, tactics=tactics, min_score=min_score, ) +# Apply the @router.get decorator @router.get("/campaign/{campaign_id}") +# Define function heatmap_campaign def heatmap_campaign( + # Entry: campaign_id campaign_id: str, + # Entry: platforms platforms: Optional[str] = Query(None), + # Entry: tactics tactics: Optional[str] = Query(None), + # Entry: min_score min_score: int = Query(0, ge=0, le=100), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Campaign layer — only techniques in the campaign, colored by test state.""" + # Return heatmap_service.build_campaign_layer( return heatmap_service.build_campaign_layer( db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score, ) +# Apply the @router.get decorator @router.get("/export-navigator") +# Define function export_navigator def export_navigator( + # Entry: layer layer: str = Query(..., description="Layer type: coverage, threat-actor, 
detection-rules, campaign"), + # Entry: layer_id layer_id: Optional[str] = Query(None, description="Actor ID or Campaign ID (if applicable)"), + # Entry: platforms platforms: Optional[str] = Query(None), + # Entry: tactics tactics: Optional[str] = Query(None), + # Entry: min_score min_score: int = Query(0, ge=0, le=100), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> StreamingResponse: """Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator.""" + # Assign data = heatmap_service.build_navigator_export( data = heatmap_service.build_navigator_export( db, layer, layer_id=layer_id, + # Keyword argument: platforms platforms=platforms, tactics=tactics, min_score=min_score, ) + # Assign json_content = json.dumps(data, indent=2, default=str) json_content = json.dumps(data, indent=2, default=str) + # Assign buffer = io.BytesIO(json_content.encode("utf-8")) buffer = io.BytesIO(json_content.encode("utf-8")) + # Return StreamingResponse( return StreamingResponse( buffer, + # Keyword argument: media_type media_type="application/json", + # Keyword argument: headers headers={"Content-Disposition": f"attachment; filename=aegis_{layer}_layer.json"}, ) diff --git a/backend/app/routers/jira.py b/backend/app/routers/jira.py index b369b01..c55caea 100644 --- a/backend/app/routers/jira.py +++ b/backend/app/routers/jira.py @@ -1,138 +1,235 @@ """Jira integration router — link, search, sync, create issues.""" +# Import logging import logging + +# Import Optional from typing from typing import Optional + +# Import UUID from uuid from uuid import UUID +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_role from app.dependencies.auth from app.dependencies.auth import 
get_current_user, require_role + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import JiraLinkEntityType from app.models.jira_link from app.models.jira_link import JiraLinkEntityType + +# Import User from app.models.user from app.models.user import User + +# Import from app.schemas.jira_schema from app.schemas.jira_schema import ( JiraIssueResult, JiraLinkCreate, JiraLinkOut, ) -from app.services import jira_service, audit_service +# Import audit_service, jira_service from app.services +from app.services import audit_service, jira_service + +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign router = APIRouter(prefix="/jira", tags=["jira"]) router = APIRouter(prefix="/jira", tags=["jira"]) +# Apply the @router.get decorator @router.get("/search", response_model=list[JiraIssueResult]) +# Define function search_issues def search_issues( + # Entry: q q: str = Query(..., min_length=2), + # Entry: max_results max_results: int = Query(10, le=50), + # Entry: user user: User = Depends(get_current_user), -): +) -> list[JiraIssueResult]: """Search Jira issues by JQL or free text.""" + # Return jira_service.search_jira_issues(q, max_results) return jira_service.search_jira_issues(q, max_results) +# Apply the @router.post decorator @router.post("/links", response_model=JiraLinkOut, status_code=201) +# Define function create_link def create_link( + # Entry: body body: JiraLinkCreate, + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> JiraLinkOut: """Associate an Aegis entity with a Jira issue.""" + # Open context manager with UnitOfWork(db) as uow: + # Assign link = jira_service.create_link( link = jira_service.create_link( db, + # Keyword argument: entity_type entity_type=body.entity_type, + # Keyword argument: entity_id entity_id=body.entity_id, + # Keyword argument: jira_issue_key jira_issue_key=body.jira_issue_key, + # 
Keyword argument: sync_direction sync_direction=body.sync_direction, + # Keyword argument: created_by created_by=user.id, ) + # Call audit_service.log_action() audit_service.log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action="JIRA_LINK_CREATED", + # Keyword argument: entity_type entity_type="jira_link", + # Keyword argument: entity_id entity_id=str(link.id), + # Keyword argument: details details={ + # Literal argument value "linked_entity_type": body.entity_type.value, + # Literal argument value "linked_entity_id": str(body.entity_id), + # Literal argument value "jira_issue_key": body.jira_issue_key, }, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(link) + # Return link return link +# Apply the @router.get decorator @router.get("/links", response_model=list[JiraLinkOut]) +# Define function list_links def list_links( + # Entry: entity_type entity_type: Optional[JiraLinkEntityType] = None, + # Entry: entity_id entity_id: Optional[UUID] = None, entity_ids: Optional[list[UUID]] = Query(default=None, description="Filter by multiple entity IDs"), db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), ): """List Jira links, optionally filtered by entity or a list of entity IDs.""" return jira_service.list_links( db, + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: entity_id entity_id=entity_id, entity_ids=entity_ids, ) +# Apply the @router.post decorator @router.post("/links/{link_id}/sync") +# Define function sync_link def sync_link( + # Entry: link_id link_id: UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(require_role("admin")), -): +) -> dict: """Force bidirectional sync for a specific Jira link.""" + # Open context manager with UnitOfWork(db) as uow: + # Assign link = jira_service.get_link_or_raise(db, link_id) link = jira_service.get_link_or_raise(db, link_id) + # 
Call jira_service.sync_jira_to_aegis() jira_service.sync_jira_to_aegis(db, link) + # Call uow.commit() uow.commit() + # Return {"message": "Sync completed", "jira_status": link.jira_status} return {"message": "Sync completed", "jira_status": link.jira_status} +# Apply the @router.delete decorator @router.delete("/links/{link_id}", status_code=204) +# Define function delete_link def delete_link( + # Entry: link_id link_id: UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> None: """Remove a Jira link.""" + # Open context manager with UnitOfWork(db) as uow: + # Assign link = jira_service.delete_link(db, link_id) link = jira_service.delete_link(db, link_id) + # Call audit_service.log_action() audit_service.log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action="jira_link_deleted", + # Keyword argument: entity_type entity_type="jira_link", + # Keyword argument: entity_id entity_id=str(link_id), + # Keyword argument: details details={"jira_issue_key": link.jira_issue_key}, ) + # Call uow.commit() uow.commit() +# Apply the @router.post decorator @router.post("/create-issue") +# Define function create_issue_from_entity def create_issue_from_entity( + # Entry: entity_type entity_type: JiraLinkEntityType, + # Entry: entity_id entity_id: UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> dict: """Auto-create a Jira issue from an Aegis entity and link them.""" + # Open context manager with UnitOfWork(db) as uow: + # Assign result = jira_service.create_issue_and_link( result = jira_service.create_issue_and_link( db, + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: entity_id entity_id=entity_id, + # Keyword argument: created_by created_by=user.id, ) + # Call uow.commit() uow.commit() + # Return result return result diff --git a/backend/app/routers/metrics.py 
b/backend/app/routers/metrics.py index a8adcc3..e1f530f 100644 --- a/backend/app/routers/metrics.py +++ b/backend/app/routers/metrics.py @@ -7,12 +7,22 @@ validation-rate endpoints for the Red/Blue workflow. Thin HTTP adapter: delegates all data logic to metrics_query_service. """ +# Import APIRouter, Depends from fastapi from fastapi import APIRouter, Depends + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user from app.dependencies.auth from app.dependencies.auth import get_current_user + +# Import User from app.models.user from app.models.user import User + +# Import from app.schemas.metrics from app.schemas.metrics import ( CoverageSummary, RecentTestItem, @@ -21,6 +31,8 @@ from app.schemas.metrics import ( TestPipelineCounts, ValidationRate, ) + +# Import from app.services.metrics_query_service from app.services.metrics_query_service import ( get_coverage_by_tactic, get_coverage_summary, @@ -30,6 +42,7 @@ from app.services.metrics_query_service import ( get_validation_rate, ) +# Assign router = APIRouter(prefix="/metrics", tags=["metrics"]) router = APIRouter(prefix="/metrics", tags=["metrics"]) @@ -39,11 +52,15 @@ router = APIRouter(prefix="/metrics", tags=["metrics"]) @router.get("/summary", response_model=CoverageSummary) +# Define function coverage_summary def coverage_summary( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> CoverageSummary: """Return a global coverage summary across all techniques.""" + # Return get_coverage_summary(db) return get_coverage_summary(db) @@ -53,11 +70,15 @@ def coverage_summary( @router.get("/by-tactic", response_model=list[TacticCoverage]) +# Define function coverage_by_tactic def coverage_by_tactic( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> 
list[TacticCoverage]: """Return coverage breakdown grouped by tactic.""" + # Return get_coverage_by_tactic(db) return get_coverage_by_tactic(db) @@ -67,11 +88,15 @@ def coverage_by_tactic( @router.get("/test-pipeline", response_model=TestPipelineCounts) +# Define function test_pipeline def test_pipeline( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> TestPipelineCounts: """Return how many tests are in each pipeline state.""" + # Return get_test_pipeline_counts(db) return get_test_pipeline_counts(db) @@ -81,11 +106,15 @@ def test_pipeline( @router.get("/team-activity", response_model=list[TeamActivity]) +# Define function team_activity def team_activity( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list[TeamActivity]: """Return activity summary for Red and Blue teams.""" + # Return get_team_activity(db) return get_team_activity(db) @@ -95,11 +124,15 @@ def team_activity( @router.get("/validation-rate", response_model=list[ValidationRate]) +# Define function validation_rate def validation_rate( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list[ValidationRate]: """Return approval and rejection rates for Red Lead and Blue Lead.""" + # Return get_validation_rate(db) return get_validation_rate(db) @@ -109,9 +142,13 @@ def validation_rate( @router.get("/recent-tests", response_model=list[RecentTestItem]) +# Define function recent_tests def recent_tests( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list[RecentTestItem]: """Return the 10 most recently created tests.""" + # Return get_recent_tests(db, limit=10) return get_recent_tests(db, limit=10) diff --git a/backend/app/routers/notifications.py b/backend/app/routers/notifications.py index 
da42c58..759b918 100644 --- a/backend/app/routers/notifications.py +++ b/backend/app/routers/notifications.py @@ -8,23 +8,39 @@ PATCH /notifications/{id}/read — mark one notification as read POST /notifications/read-all — mark all as read """ +# Import uuid import uuid +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user from app.dependencies.auth from app.dependencies.auth import get_current_user + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import User from app.models.user from app.models.user import User + +# Import NotificationOut, UnreadCountOut from app.schemas.notification from app.schemas.notification import NotificationOut, UnreadCountOut + +# Import from app.services.notification_service from app.services.notification_service import ( - list_notifications, - mark_as_read, - mark_all_as_read, get_unread_count, + list_notifications, + mark_all_as_read, + mark_as_read, ) +# Assign router = APIRouter(prefix="/notifications", tags=["notifications"]) router = APIRouter(prefix="/notifications", tags=["notifications"]) @@ -34,13 +50,19 @@ router = APIRouter(prefix="/notifications", tags=["notifications"]) @router.get("", response_model=list[NotificationOut]) +# Define function list_notifications_endpoint def list_notifications_endpoint( + # Entry: offset offset: int = Query(0, ge=0), + # Entry: limit limit: int = Query(20, ge=1, le=100), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list[NotificationOut]: """Return paginated notifications for the current user, newest first.""" + # Return list_notifications(db, current_user.id, offset=offset, limit=limit) return list_notifications(db, current_user.id, offset=offset, 
limit=limit) @@ -50,12 +72,17 @@ def list_notifications_endpoint( @router.get("/unread-count", response_model=UnreadCountOut) +# Define function unread_count def unread_count( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> UnreadCountOut: """Return the number of unread notifications for the current user.""" + # Assign count = get_unread_count(db, current_user.id) count = get_unread_count(db, current_user.id) + # Return UnreadCountOut(unread_count=count) return UnreadCountOut(unread_count=count) @@ -65,15 +92,23 @@ def unread_count( @router.patch("/{notification_id}/read", response_model=NotificationOut) +# Define function read_notification def read_notification( + # Entry: notification_id notification_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> NotificationOut: """Mark a single notification as read.""" + # Open context manager with UnitOfWork(db) as uow: + # Assign notif = mark_as_read(db, notification_id, current_user.id) notif = mark_as_read(db, notification_id, current_user.id) + # Call uow.commit() uow.commit() + # Return notif return notif @@ -83,12 +118,19 @@ def read_notification( @router.post("/read-all") +# Define function read_all_notifications def read_all_notifications( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Mark all notifications for the current user as read.""" + # Open context manager with UnitOfWork(db) as uow: + # Assign count = mark_all_as_read(db, current_user.id) count = mark_all_as_read(db, current_user.id) + # Call uow.commit() uow.commit() + # Return {"detail": f"Marked {count} notifications as read"} return {"detail": f"Marked {count} notifications as read"} diff --git a/backend/app/routers/operational_metrics.py b/backend/app/routers/operational_metrics.py index 
344a1ed..6f14cdd 100644 --- a/backend/app/routers/operational_metrics.py +++ b/backend/app/routers/operational_metrics.py @@ -4,18 +4,28 @@ Provides operational KPIs for security teams with trend analysis and team-level breakdowns. """ +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user from app.dependencies.auth from app.dependencies.auth import get_current_user + +# Import User from app.models.user from app.models.user import User + +# Import from app.services.operational_metrics_service from app.services.operational_metrics_service import ( - get_all_operational_metrics, - get_operational_trend, get_metrics_by_team, + get_operational_trend, ) +# Assign router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"]) router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"]) @@ -23,13 +33,18 @@ router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"]) @router.get("") +# Define function operational_metrics def operational_metrics( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Get all operational metrics (MTTD, MTTR, etc.) 
— cached for 5 min.""" + # Import get_operational_metrics_cached from app.services.score_cache from app.services.score_cache import get_operational_metrics_cached + # Return get_operational_metrics_cached(db) return get_operational_metrics_cached(db) @@ -37,12 +52,17 @@ def operational_metrics( @router.get("/trend") +# Define function operational_trend def operational_trend( + # Entry: period period: str = Query("90d", pattern="^(30d|90d|1y)$"), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Get weekly trend data for operational metrics.""" + # Return get_operational_trend(db, period) return get_operational_trend(db, period) @@ -50,9 +70,13 @@ def operational_trend( @router.get("/by-team") +# Define function metrics_by_team def metrics_by_team( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Get metrics broken down by Red Team vs Blue Team.""" + # Return get_metrics_by_team(db) return get_metrics_by_team(db) diff --git a/backend/app/routers/osint.py b/backend/app/routers/osint.py index 97e0fff..1e48b37 100644 --- a/backend/app/routers/osint.py +++ b/backend/app/routers/osint.py @@ -1,26 +1,44 @@ -"""OSINT enrichment endpoints — view, review, and trigger enrichment of -OSINT items (CVEs, advisories, etc.) linked to techniques. 
-""" +"""OSINT enrichment endpoints — view, review, and trigger enrichment of OSINT items linked to techniques.""" +# Import UUID from uuid from uuid import UUID -from fastapi import APIRouter, Depends, Query, HTTPException, status +# Import APIRouter, Depends, HTTPException, Query, status from fastapi +from fastapi import APIRouter, Depends, HTTPException, Query, status + +# Import BaseModel from pydantic from pydantic import BaseModel + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_any_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_any_role + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import User from app.models.user from app.models.user import User + +# Import from app.services.osint_enrichment_service from app.services.osint_enrichment_service import ( enrich_technique_with_cves, get_osint_items_for_technique, get_osint_summary, get_technique_or_raise, - list_osint_items as service_list_osint_items, mark_osint_reviewed, ) +# Import from app.services.osint_enrichment_service +from app.services.osint_enrichment_service import ( + list_osint_items as service_list_osint_items, +) + +# Assign router = APIRouter(prefix="/osint", tags=["osint"]) router = APIRouter(prefix="/osint", tags=["osint"]) @@ -28,18 +46,34 @@ router = APIRouter(prefix="/osint", tags=["osint"]) class OsintItemOut(BaseModel): + """Serialized OSINT item returned by the API.""" + + # id: str id: str + # technique_id: str technique_id: str + # source_type: str source_type: str + # source_url: str source_url: str + # title: str title: str + # description: str | None description: str | None + # severity: str | None severity: str | None + # discovered_at: str | None discovered_at: str | None + # reviewed: bool reviewed: bool + # Assign metadata_ = None metadata_: dict | 
None = None + # Define class Config class Config: + """ORM mode configuration for SQLAlchemy model mapping.""" + + # Assign from_attributes = True from_attributes = True @@ -47,94 +81,207 @@ class OsintItemOut(BaseModel): @router.get("/items") +# Define function list_osint_items def list_osint_items( + # Entry: technique_id technique_id: UUID | None = Query(None), + # Entry: source_type source_type: str | None = Query(None), + # Entry: reviewed reviewed: bool | None = Query(None), + # Entry: offset offset: int = Query(0, ge=0), + # Entry: limit limit: int = Query(50, ge=1, le=200), + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): - """List OSINT items with optional filters.""" +) -> list: + """List OSINT items with optional filters. + + Args: + technique_id (UUID | None): Filter by the technique's UUID. + source_type (str | None): Filter by source type (e.g. ``nvd_cve``, ``advisory``). + reviewed (bool | None): Filter by review status; ``None`` returns all. + offset (int): Number of records to skip for pagination. + limit (int): Maximum number of records to return. + db (Session): SQLAlchemy database session. + user (User): Authenticated user making the request. + + Returns: + list: Serialised list of OSINT item dicts matching the filters. + """ + # Return service_list_osint_items( return service_list_osint_items( db, + # Keyword argument: technique_id technique_id=technique_id, + # Keyword argument: source_type source_type=source_type, + # Keyword argument: reviewed reviewed=reviewed, + # Keyword argument: offset offset=offset, + # Keyword argument: limit limit=limit, ) +# Apply the @router.get decorator @router.get("/summary") +# Define function osint_summary def osint_summary( + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): - """Summary statistics for OSINT items.""" +) -> dict: + """Return summary statistics for OSINT items. 
+ + Args: + db (Session): SQLAlchemy database session. + user (User): Authenticated user making the request. + + Returns: + dict: Counts of total, reviewed, and unreviewed items broken down by source type. + """ + # Return get_osint_summary(db) return get_osint_summary(db) +# Apply the @router.post decorator @router.post("/items/{item_id}/review") +# Define function review_osint_item def review_osint_item( + # Entry: item_id item_id: UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): - """Mark an OSINT item as reviewed.""" +) -> dict: + """Mark an OSINT item as reviewed. + + Args: + item_id (UUID): Primary key of the OSINT item to mark reviewed. + db (Session): SQLAlchemy database session. + user (User): Authenticated user performing the review. + + Returns: + dict: Contains ``id`` (str) and ``reviewed`` (bool ``True``). + """ + # Open context manager with UnitOfWork(db) as uow: + # Assign item = mark_osint_reviewed(db, str(item_id)) item = mark_osint_reviewed(db, str(item_id)) + # Check: not item if not item: + # Raise HTTPException raise HTTPException( + # Keyword argument: status_code status_code=status.HTTP_404_NOT_FOUND, + # Keyword argument: detail detail="OSINT item not found", ) + # Call uow.commit() uow.commit() + # Return {"id": str(item.id), "reviewed": True} return {"id": str(item.id), "reviewed": True} +# Apply the @router.post decorator @router.post("/enrich/{technique_id}") +# Define function trigger_technique_enrichment def trigger_technique_enrichment( + # Entry: technique_id technique_id: UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Manually trigger OSINT enrichment for a single technique.""" +) -> dict: + """Manually trigger OSINT enrichment for a single technique. + + Args: + technique_id (UUID): Primary key of the technique to enrich. + db (Session): SQLAlchemy database session. 
+ user (User): Authenticated red_lead or blue_lead requesting enrichment. + + Returns: + dict: Contains ``technique_id`` (str), ``mitre_id`` (str), and ``new_items`` (int). + """ + # Assign technique = get_technique_or_raise(db, technique_id) technique = get_technique_or_raise(db, technique_id) + # Assign count = enrich_technique_with_cves(db, technique) count = enrich_technique_with_cves(db, technique) + # Return { return { + # Literal argument value "technique_id": str(technique.id), + # Literal argument value "mitre_id": technique.mitre_id, + # Literal argument value "new_items": count, } +# Apply the @router.get decorator @router.get("/technique/{technique_id}") +# Define function get_technique_osint def get_technique_osint( + # Entry: technique_id technique_id: UUID, + # Entry: source_type source_type: str | None = Query(None), + # Entry: reviewed reviewed: bool | None = Query(None), + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): - """Get all OSINT items for a specific technique.""" +) -> list: + """Get all OSINT items for a specific technique. + + Args: + technique_id (UUID): Primary key of the technique. + source_type (str | None): Filter by source type (e.g. ``nvd_cve``). + reviewed (bool | None): Filter by review status; ``None`` returns all. + db (Session): SQLAlchemy database session. + user (User): Authenticated user making the request. + + Returns: + list: Dicts with OSINT item fields including source URL, severity, and review status. 
+ """ + # Assign items = get_osint_items_for_technique( items = get_osint_items_for_technique( db, str(technique_id), + # Keyword argument: source_type source_type=source_type, + # Keyword argument: reviewed reviewed=reviewed, ) + # Return [ return [ { + # Literal argument value "id": str(item.id), + # Literal argument value "source_type": item.source_type, + # Literal argument value "source_url": item.source_url, + # Literal argument value "title": item.title, + # Literal argument value "description": item.description, + # Literal argument value "severity": item.severity, + # Literal argument value "discovered_at": item.discovered_at.isoformat() if item.discovered_at else None, + # Literal argument value "reviewed": item.reviewed, + # Literal argument value "metadata": item.metadata_, } for item in items diff --git a/backend/app/routers/professional_reports.py b/backend/app/routers/professional_reports.py index 053d80d..b91ef24 100644 --- a/backend/app/routers/professional_reports.py +++ b/backend/app/routers/professional_reports.py @@ -1,118 +1,195 @@ """Professional report generation endpoints — PDF, DOCX, HTML output.""" +# Import UUID from uuid from uuid import UUID +# Import APIRouter, Depends, Query, Request from fastapi from fastapi import APIRouter, Depends, Query, Request + +# Import FileResponse from fastapi.responses from fastapi.responses import FileResponse + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_any_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_any_role -from app.models.user import User + +# Import limiter from app.limiter from app.limiter import limiter + +# Import User from app.models.user +from app.models.user import User + +# Import report_generation_service from app.services from app.services import report_generation_service +# Assign router = 
APIRouter(prefix="/reports/generate", tags=["professional-reports"]) router = APIRouter(prefix="/reports/generate", tags=["professional-reports"]) +# Assign _MEDIA_TYPES = { _MEDIA_TYPES = { + # Literal argument value "pdf": "application/pdf", + # Literal argument value "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + # Literal argument value "html": "text/html", } +# Apply the @router.get decorator @router.get("/purple-campaign/{campaign_id}") +# Apply the @limiter.limit decorator @limiter.limit("5/minute") +# Define function generate_purple_report def generate_purple_report( + # Entry: request request: Request, + # Entry: campaign_id campaign_id: UUID, + # Entry: format format: str = Query("pdf", pattern="^(pdf|docx|html)$"), + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), -): +) -> FileResponse: """Generate a Purple Team campaign assessment report.""" + # Assign filepath = report_generation_service.generate_purple_campaign_report( filepath = report_generation_service.generate_purple_campaign_report( db, str(campaign_id), output_format=format, ) + # Return FileResponse( return FileResponse( filepath, + # Keyword argument: media_type media_type=_MEDIA_TYPES[format], + # Keyword argument: filename filename=f"purple_report.{format}", ) +# Apply the @router.get decorator @router.get("/coverage-summary") +# Apply the @limiter.limit decorator @limiter.limit("5/minute") +# Define function generate_coverage_report def generate_coverage_report( + # Entry: request request: Request, + # Entry: format format: str = Query("pdf", pattern="^(pdf|docx|html)$"), + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), -): +) -> FileResponse: """Generate an organization-wide MITRE ATT&CK coverage report.""" + # Assign filepath = report_generation_service.generate_coverage_report( 
filepath = report_generation_service.generate_coverage_report( db, output_format=format, ) + # Return FileResponse( return FileResponse( filepath, + # Keyword argument: media_type media_type=_MEDIA_TYPES[format], + # Keyword argument: filename filename=f"coverage_report.{format}", ) +# Apply the @router.get decorator @router.get("/executive-summary") +# Apply the @limiter.limit decorator @limiter.limit("5/minute") +# Define function generate_executive_report def generate_executive_report( + # Entry: request request: Request, + # Entry: format format: str = Query("pdf", pattern="^(pdf|docx|html)$"), + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), -): +) -> FileResponse: """Generate an executive security summary report.""" + # Assign filepath = report_generation_service.generate_executive_summary( filepath = report_generation_service.generate_executive_summary( db, output_format=format, ) + # Return FileResponse( return FileResponse( filepath, + # Keyword argument: media_type media_type=_MEDIA_TYPES[format], + # Keyword argument: filename filename=f"executive_summary.{format}", ) +# Apply the @router.get decorator @router.get("/quarterly-summary") +# Apply the @limiter.limit decorator @limiter.limit("5/minute") +# Define function generate_quarterly_report def generate_quarterly_report( + # Entry: request request: Request, + # Entry: format format: str = Query("pdf", pattern="^(pdf|docx|html)$"), + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), -): +) -> FileResponse: """Generate a quarterly security summary report.""" + # Assign filepath = report_generation_service.generate_quarterly_summary( filepath = report_generation_service.generate_quarterly_summary( db, output_format=format, ) + # Return FileResponse( return FileResponse( filepath, + # Keyword argument: media_type 
media_type=_MEDIA_TYPES[format], + # Keyword argument: filename filename=f"quarterly_summary.{format}", ) +# Apply the @router.get decorator @router.get("/technique/{technique_id}") +# Apply the @limiter.limit decorator @limiter.limit("5/minute") +# Define function generate_technique_report def generate_technique_report( + # Entry: request request: Request, + # Entry: technique_id technique_id: UUID, + # Entry: format format: str = Query("pdf", pattern="^(pdf|docx|html)$"), + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(get_current_user), -): +) -> FileResponse: """Generate a detailed report for one MITRE technique.""" + # Assign filepath = report_generation_service.generate_technique_detail_report( filepath = report_generation_service.generate_technique_detail_report( db, str(technique_id), output_format=format, ) + # Return FileResponse( return FileResponse( filepath, + # Keyword argument: media_type media_type=_MEDIA_TYPES[format], + # Keyword argument: filename filename=f"technique_{technique_id}.{format}", ) diff --git a/backend/app/routers/reports.py b/backend/app/routers/reports.py index d065116..e64c892 100644 --- a/backend/app/routers/reports.py +++ b/backend/app/routers/reports.py @@ -10,18 +10,37 @@ GET /reports/test-results — test results report (JSON) GET /reports/remediation-status — remediation status report (JSON) """ +# Import csv import csv + +# Import io import io + +# Import datetime from datetime from datetime import datetime + +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import StreamingResponse from fastapi.responses from fastapi.responses import StreamingResponse + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user from app.dependencies.auth from app.dependencies.auth import 
get_current_user + +# Import User from app.models.user from app.models.user import User + +# Import from app.services.coverage_report_service from app.services.coverage_report_service import ( build_coverage_csv_rows, build_coverage_summary, @@ -29,61 +48,99 @@ from app.services.coverage_report_service import ( build_test_results_report, ) +# Assign router = APIRouter(prefix="/reports", tags=["reports"]) router = APIRouter(prefix="/reports", tags=["reports"]) +# Apply the @router.get decorator @router.get("/coverage-summary") +# Define function coverage_summary def coverage_summary( + # Entry: tactic tactic: Optional[str] = Query(None, description="Filter by tactic"), + # Entry: platform platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Full coverage report as JSON — technique-by-technique with test counts.""" + # Return build_coverage_summary(db, tactic=tactic, platform=platform) return build_coverage_summary(db, tactic=tactic, platform=platform) +# Apply the @router.get decorator @router.get("/coverage-csv") +# Define function coverage_csv def coverage_csv( + # Entry: tactic tactic: Optional[str] = Query(None), + # Entry: platform platform: Optional[str] = Query(None), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> StreamingResponse: """Export coverage as a downloadable CSV.""" + # Assign rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform) rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform) + # Assign output = io.StringIO() output = io.StringIO() + # Assign writer = csv.writer(output) writer = csv.writer(output) + # Iterate over rows for row in rows: + # Call writer.writerow() writer.writerow(row) + # Call output.seek() output.seek(0) + # Assign filename = 
f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv" filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv" + # Return StreamingResponse( return StreamingResponse( iter([output.getvalue()]), + # Keyword argument: media_type media_type="text/csv", + # Keyword argument: headers headers={"Content-Disposition": f"attachment; filename={filename}"}, ) +# Apply the @router.get decorator @router.get("/test-results") +# Define function test_results def test_results( + # Entry: state state: Optional[str] = Query(None), + # Entry: date_from date_from: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"), + # Entry: date_to date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Report of test results with optional filters.""" + # Return build_test_results_report(db, state=state, date_from=date_from, dat... return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to) +# Apply the @router.get decorator @router.get("/remediation-status") +# Define function remediation_status def remediation_status( + # Entry: status status: Optional[str] = Query(None, description="Filter by remediation status"), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Report of remediation status across all tests.""" + # Return build_remediation_status_report(db, status=status) return build_remediation_status_report(db, status=status) diff --git a/backend/app/routers/scores.py b/backend/app/routers/scores.py index 8f8201a..1fd91ad 100644 --- a/backend/app/routers/scores.py +++ b/backend/app/routers/scores.py @@ -3,28 +3,45 @@ Provides granular scoring with breakdowns and configurable weights. 
""" +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import BaseModel from pydantic from pydantic import BaseModel + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_role + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import User from app.models.user from app.models.user import User -from app.services.scoring_service import ( - score_technique_by_mitre_id, - score_actor_by_id, - calculate_tactic_score, - calculate_organization_score, - get_score_history, -) + +# Import from app.services.scoring_config_service from app.services.scoring_config_service import ( get_weights_dict, update_scoring_weights, ) +# Import from app.services.scoring_service +from app.services.scoring_service import ( + calculate_tactic_score, + get_score_history, + score_actor_by_id, + score_technique_by_mitre_id, +) + +# Assign router = APIRouter(prefix="/scores", tags=["scores"]) router = APIRouter(prefix="/scores", tags=["scores"]) @@ -32,12 +49,26 @@ router = APIRouter(prefix="/scores", tags=["scores"]) @router.get("/technique/{mitre_id}") +# Define function score_technique def score_technique( + # Entry: mitre_id mitre_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get detailed score with breakdown for a specific technique.""" +) -> dict: + """Get detailed score with breakdown for a specific technique. + + Args: + mitre_id (str): MITRE ATT&CK technique ID (e.g. ``T1059``). + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. 
+ + Returns: + dict: Score value and component breakdown (tests, detection rules, recency, etc.). + """ + # Return score_technique_by_mitre_id(db, mitre_id) return score_technique_by_mitre_id(db, mitre_id) @@ -45,12 +76,26 @@ def score_technique( @router.get("/tactic/{tactic}") +# Define function score_tactic def score_tactic( + # Entry: tactic tactic: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get average score for a tactic.""" +) -> dict: + """Get average score for a tactic. + + Args: + tactic (str): MITRE ATT&CK tactic slug (e.g. ``initial-access``). + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + dict: Average score and per-technique breakdown for the tactic. + """ + # Return calculate_tactic_score(tactic, db) return calculate_tactic_score(tactic, db) @@ -58,12 +103,26 @@ def score_tactic( @router.get("/threat-actor/{actor_id}") +# Define function score_threat_actor def score_threat_actor( + # Entry: actor_id actor_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get coverage score against a specific threat actor.""" +) -> dict: + """Get coverage score against a specific threat actor. + + Args: + actor_id (str): UUID string of the threat actor to score against. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + dict: Coverage score and per-technique breakdown for the threat actor. 
+ """ + # Return score_actor_by_id(db, actor_id) return score_actor_by_id(db, actor_id) @@ -71,13 +130,26 @@ def score_threat_actor( @router.get("/organization") +# Define function score_organization def score_organization( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get the overall organization security score (cached for 5 min).""" +) -> dict: + """Get the overall organization security score (cached for 5 min). + + Args: + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + dict: Aggregate organization score with tactic-level breakdowns. + """ + # Import get_organization_score_cached from app.services.score_cache from app.services.score_cache import get_organization_score_cached + # Return get_organization_score_cached(db) return get_organization_score_cached(db) @@ -85,12 +157,26 @@ def score_organization( @router.get("/history") +# Define function score_history def score_history( + # Entry: period period: str = Query("90d", pattern="^(30d|90d|1y)$"), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Get historical score data points (weekly).""" +) -> dict: + """Get historical score data points (weekly). + + Args: + period (str): Time window for history — one of ``30d``, ``90d``, or ``1y``. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + dict: Weekly score data points for the requested period. 
+ """ + # Return get_score_history(db, period) return get_score_history(db, period) @@ -98,11 +184,23 @@ def score_history( @router.get("/config") +# Define function get_scoring_config def get_scoring_config( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): - """Get current scoring weights (admin only).""" +) -> dict: + """Get current scoring weights (admin only). + + Args: + db (Session): SQLAlchemy database session. + current_user (User): Authenticated admin user. + + Returns: + dict: Current weight values for each scoring component. + """ + # Return get_weights_dict(db) return get_weights_dict(db) @@ -110,41 +208,77 @@ def get_scoring_config( class ScoringConfigUpdate(BaseModel): + """Partial update payload for the scoring weight configuration.""" + + # Assign tests = None tests: Optional[float] = None + # Assign detection_rules = None detection_rules: Optional[float] = None + # Assign d3fend = None d3fend: Optional[float] = None + # Assign recency = None recency: Optional[float] = None + # Assign severity = None severity: Optional[float] = None + # Assign freshness = None freshness: Optional[float] = None + # Assign platform_diversity = None platform_diversity: Optional[float] = None +# Apply the @router.patch decorator @router.patch("/config") +# Define function update_scoring_config def update_scoring_config( + # Entry: payload payload: ScoringConfigUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Update scoring weights (admin only). Weights are persisted in the database and survive restarts. Validation enforces that all weights are non-negative and sum to 100. + + Args: + payload (ScoringConfigUpdate): Partial weight update; only set fields are changed. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated admin user. 
+ + Returns: + dict: Confirmation message plus the full updated weight configuration. """ + # Open context manager with UnitOfWork(db) as uow: + # Assign result = update_scoring_weights( result = update_scoring_weights( db, + # Keyword argument: tests tests=payload.tests, + # Keyword argument: detection_rules detection_rules=payload.detection_rules, + # Keyword argument: d3fend d3fend=payload.d3fend, + # Keyword argument: recency recency=payload.recency, + # Keyword argument: severity severity=payload.severity, + # Keyword argument: freshness freshness=payload.freshness, + # Keyword argument: platform_diversity platform_diversity=payload.platform_diversity, + # Keyword argument: updated_by updated_by=current_user.id, ) + # Call uow.commit() uow.commit() + # Import invalidate from app.services.score_cache from app.services.score_cache import invalidate + # Call invalidate() invalidate() + # Return {"message": "Scoring config updated", **result} return {"message": "Scoring config updated", **result} diff --git a/backend/app/routers/snapshots.py b/backend/app/routers/snapshots.py index ad0a4a9..d4987b7 100644 --- a/backend/app/routers/snapshots.py +++ b/backend/app/routers/snapshots.py @@ -4,40 +4,71 @@ Provides periodic and manual snapshots of the organisation's coverage state, plus temporal comparison between any two snapshots. 
""" +# Import logging import logging + +# Import uuid import uuid + +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import BaseModel from pydantic from pydantic import BaseModel + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_any_role, require_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_any_role, require_role + +# Import BusinessRuleViolation from app.domain.errors from app.domain.errors import BusinessRuleViolation + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import User from app.models.user from app.models.user import User -from app.services.snapshot_service import ( - create_snapshot, - compare_snapshots, - cleanup_old_snapshots, - get_coverage_evolution, - serialize_snapshot_summary, - list_snapshots as list_snapshots_svc, - get_snapshot_or_raise, - get_snapshot_detail, - delete_snapshot, -) + +# Import log_action from app.services.audit_service from app.services.audit_service import log_action +# Import from app.services.snapshot_service +from app.services.snapshot_service import ( + compare_snapshots, + create_snapshot, + delete_snapshot, + get_coverage_evolution, + get_snapshot_detail, + get_snapshot_or_raise, + serialize_snapshot_summary, +) + +# Import from app.services.snapshot_service +from app.services.snapshot_service import ( + list_snapshots as list_snapshots_svc, +) + +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign router = APIRouter(prefix="/snapshots", tags=["snapshots"]) router = APIRouter(prefix="/snapshots", tags=["snapshots"]) # ── Pydantic schemas ───────────────────────────────────────────────── class SnapshotCreate(BaseModel): + """Payload for creating 
a new coverage snapshot.""" + + # Assign name = None name: Optional[str] = None @@ -46,13 +77,19 @@ class SnapshotCreate(BaseModel): # --------------------------------------------------------------------------- @router.get("") +# Define function list_snapshots def list_snapshots( + # Entry: offset offset: int = Query(0, ge=0), + # Entry: limit limit: int = Query(50, ge=1, le=200), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """List coverage snapshots ordered by creation date (newest first).""" + # Return list_snapshots_svc(db, offset=offset, limit=limit) return list_snapshots_svc(db, offset=offset, limit=limit) @@ -61,25 +98,39 @@ def list_snapshots( # --------------------------------------------------------------------------- @router.post("", status_code=201) +# Define function create_snapshot_endpoint def create_snapshot_endpoint( + # Entry: payload payload: SnapshotCreate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")), -): +) -> dict: """Create a manual coverage snapshot with an optional name.""" + # Assign snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id) snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="create_snapshot", + # Keyword argument: entity_type entity_type="snapshot", + # Keyword argument: entity_id entity_id=snapshot.id, + # Keyword argument: details details={"name": snapshot.name, "score": snapshot.organization_score}, ) + # Call uow.commit() uow.commit() + # Return serialize_snapshot_summary(snapshot) return serialize_snapshot_summary(snapshot) @@ -89,12 +140,17 @@ def create_snapshot_endpoint( @router.get("/evolution") +# Define 
function coverage_evolution def coverage_evolution( + # Entry: months months: int = Query(12, ge=1, le=36), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """Return coverage snapshots for trend charts (last *months* months).""" + # Return get_coverage_evolution(db, months=months) return get_coverage_evolution(db, months=months) @@ -103,19 +159,30 @@ def coverage_evolution( # --------------------------------------------------------------------------- @router.get("/compare") +# Define function compare_snapshots_endpoint def compare_snapshots_endpoint( + # Entry: a a: str = Query(..., description="Snapshot A ID"), + # Entry: b b: str = Query(..., description="Snapshot B ID"), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Compare two snapshots showing improved, worsened, and unchanged techniques.""" + # Attempt the following; catch errors below try: + # Assign a_id = uuid.UUID(a) a_id = uuid.UUID(a) + # Assign b_id = uuid.UUID(b) b_id = uuid.UUID(b) + # Handle ValueError except ValueError: + # Raise BusinessRuleViolation raise BusinessRuleViolation("Invalid snapshot ID format") + # Return compare_snapshots(db, a_id, b_id) return compare_snapshots(db, a_id, b_id) @@ -124,12 +191,17 @@ def compare_snapshots_endpoint( # --------------------------------------------------------------------------- @router.get("/{snapshot_id}") +# Define function get_snapshot def get_snapshot( + # Entry: snapshot_id snapshot_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Get detailed snapshot information including per-technique states.""" + # Return get_snapshot_detail(db, snapshot_id) return get_snapshot_detail(db, snapshot_id) @@ -138,24 +210,39 @@ def get_snapshot( # 
--------------------------------------------------------------------------- @router.delete("/{snapshot_id}") +# Define function delete_snapshot_endpoint def delete_snapshot_endpoint( + # Entry: snapshot_id snapshot_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Delete a snapshot (admin only).""" + # Assign snapshot = get_snapshot_or_raise(db, snapshot_id) snapshot = get_snapshot_or_raise(db, snapshot_id) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="delete_snapshot", + # Keyword argument: entity_type entity_type="snapshot", + # Keyword argument: entity_id entity_id=snapshot.id, + # Keyword argument: details details={"name": snapshot.name}, ) + # Call delete_snapshot() delete_snapshot(db, snapshot_id) + # Call uow.commit() uow.commit() + # Return {"detail": "Snapshot deleted"} return {"detail": "Snapshot deleted"} diff --git a/backend/app/routers/system.py b/backend/app/routers/system.py index 87cac15..2869285 100644 --- a/backend/app/routers/system.py +++ b/backend/app/routers/system.py @@ -8,6 +8,7 @@ Also exposes email configuration CRUD (admin only) that writes to the system_configs table so settings survive container restarts. 
""" +# Import logging import logging from typing import Optional @@ -22,10 +23,26 @@ from app.services.mitre_sync_service import sync_mitre from app.services.intel_service import scan_intel from app.services.atomic_import_service import import_atomic_red_team from app.jobs.mitre_sync_job import scheduler + +# Import limiter from app.limiter from app.limiter import limiter +# Import User from app.models.user +from app.models.user import User + +# Import import_atomic_red_team from app.services.atomic_import_service +from app.services.atomic_import_service import import_atomic_red_team + +# Import scan_intel from app.services.intel_service +from app.services.intel_service import scan_intel + +# Import sync_mitre from app.services.mitre_sync_service +from app.services.mitre_sync_service import sync_mitre + +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign router = APIRouter(prefix="/system", tags=["system"]) router = APIRouter(prefix="/system", tags=["system"]) @@ -105,8 +122,11 @@ def _bg_mitre_sync() -> None: @router.post("/sync-mitre") +# Apply the @limiter.limit decorator @limiter.limit("2/hour") +# Define function trigger_mitre_sync def trigger_mitre_sync( + # Entry: request request: Request, background_tasks: BackgroundTasks, current_user: User = Depends(require_role("admin")), @@ -127,11 +147,15 @@ def trigger_mitre_sync( } +# Apply the @router.post decorator @router.post("/run-intel-scan") +# Define function trigger_intel_scan def trigger_intel_scan( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Manually trigger a threat-intelligence scan. **Requires** the ``admin`` role. @@ -139,20 +163,30 @@ def trigger_intel_scan( Returns a JSON object with the scan summary including the count of new intel items found. 
""" + # Assign summary = scan_intel(db) summary = scan_intel(db) + # Return { return { + # Literal argument value "message": "Intel scan completed", + # Literal argument value "new_items": summary["new_items"], } +# Apply the @router.post decorator @router.post("/import-atomic-tests") +# Apply the @limiter.limit decorator @limiter.limit("2/hour") +# Define function trigger_atomic_import def trigger_atomic_import( + # Entry: request request: Request, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Trigger an import of Atomic Red Team tests as TestTemplates. **Requires** the ``admin`` role. @@ -163,37 +197,58 @@ def trigger_atomic_import( Returns a JSON object with import statistics. """ + # Attempt the following; catch errors below try: + # Assign summary = import_atomic_red_team(db) summary = import_atomic_red_team(db) + # Handle Exception except Exception as exc: + # Log error: "Atomic Red Team import failed: %s", exc, exc_info logger.error("Atomic Red Team import failed: %s", exc, exc_info=True) + # Return { return { + # Literal argument value "message": "Import failed. Check server logs for details.", } + # Return { return { + # Literal argument value "message": "Import completed", + # Literal argument value "imported": summary["created"], + # Literal argument value "skipped": summary["skipped_existing"], + # Literal argument value "total_parsed": summary["total_tests_parsed"], } +# Apply the @router.get decorator @router.get("/scheduler-status") +# Define function scheduler_status def scheduler_status( + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> dict: """Return the current state of the background scheduler. **Requires** the ``admin`` role. 
""" + # Assign jobs = scheduler.get_jobs() jobs = scheduler.get_jobs() + # Return { return { + # Literal argument value "running": scheduler.running, + # Literal argument value "jobs": [ { + # Literal argument value "id": job.id, + # Literal argument value "name": job.name, + # Literal argument value "next_run_time": str(job.next_run_time) if job.next_run_time else None, } for job in jobs diff --git a/backend/app/routers/techniques.py b/backend/app/routers/techniques.py index 9737ada..805eee7 100644 --- a/backend/app/routers/techniques.py +++ b/backend/app/routers/techniques.py @@ -5,29 +5,56 @@ for error signaling. The error_handler middleware maps domain exceptions to HTTP responses automatically. """ +# Import APIRouter, Depends, Query, status from fastapi from fastapi import APIRouter, Depends, Query, status + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db -from app.dependencies.auth import get_current_user, require_role, require_any_role + +# Import get_current_user, require_any_role, require_role from app.dependencies.auth +from app.dependencies.auth import get_current_user, require_any_role, require_role + +# Import get_technique_repository from app.dependencies.repositories from app.dependencies.repositories import get_technique_repository + +# Import TechniqueEntity from app.domain.entities.technique from app.domain.entities.technique import TechniqueEntity -from app.domain.errors import DuplicateEntityError, EntityNotFoundError + +# Import TechniqueStatus from app.domain.enums from app.domain.enums import TechniqueStatus + +# Import DuplicateEntityError, EntityNotFoundError from app.domain.errors +from app.domain.errors import DuplicateEntityError, EntityNotFoundError + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import from app.infrastructure.persistence.repositories.sa_technique_repository from 
app.infrastructure.persistence.repositories.sa_technique_repository import ( SATechniqueRepository, ) + +# Import User from app.models.user from app.models.user import User + +# Import from app.schemas.technique from app.schemas.technique import ( TechniqueCreate, TechniqueOut, TechniqueSummary, TechniqueUpdate, ) + +# Import log_action from app.services.audit_service from app.services.audit_service import log_action + +# Import get_technique_detail from app.services.technique_query_service from app.services.technique_query_service import get_technique_detail +# Assign router = APIRouter(prefix="/techniques", tags=["techniques"]) router = APIRouter(prefix="/techniques", tags=["techniques"]) @@ -37,19 +64,29 @@ router = APIRouter(prefix="/techniques", tags=["techniques"]) @router.get("", response_model=list[TechniqueSummary]) +# Define function list_techniques def list_techniques( + # Entry: tactic tactic: str | None = Query(None, description="Filter by tactic name"), + # Entry: status_global status_global: TechniqueStatus | None = Query( None, alias="status", description="Filter by global status" ), + # Entry: review_required review_required: bool | None = Query(None, description="Filter by review flag"), + # Entry: repo repo: SATechniqueRepository = Depends(get_technique_repository), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """Return a lightweight list of techniques, optionally filtered.""" + # Return repo.list_all( return repo.list_all( + # Keyword argument: tactic tactic=tactic, + # Keyword argument: status status=status_global, + # Keyword argument: review_required review_required=review_required, ) @@ -60,12 +97,17 @@ def list_techniques( @router.get("/{mitre_id}") +# Define function get_technique def get_technique( + # Entry: mitre_id mitre_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Return full details for a 
single technique, including its tests and D3FEND defenses.""" + # Return get_technique_detail(db, mitre_id) return get_technique_detail(db, mitre_id) @@ -75,40 +117,66 @@ def get_technique( @router.post( + # Literal argument value "", + # Keyword argument: response_model response_model=TechniqueOut, + # Keyword argument: status_code status_code=status.HTTP_201_CREATED, ) +# Define function create_technique def create_technique( + # Entry: payload payload: TechniqueCreate, + # Entry: db db: Session = Depends(get_db), + # Entry: repo repo: SATechniqueRepository = Depends(get_technique_repository), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> TechniqueOut: """Create a new technique manually.""" + # Check: repo.exists_by_mitre_id(payload.mitre_id) if repo.exists_by_mitre_id(payload.mitre_id): + # Raise DuplicateEntityError raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id) + # Assign entity = TechniqueEntity.create( entity = TechniqueEntity.create( + # Keyword argument: mitre_id mitre_id=payload.mitre_id, + # Keyword argument: name name=payload.name, + # Keyword argument: description description=payload.description, + # Keyword argument: tactic tactic=payload.tactic, + # Keyword argument: platforms platforms=payload.platforms, ) + # Open context manager with UnitOfWork(db) as uow: + # Assign saved = repo.save(entity) saved = repo.save(entity) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="create_technique", + # Keyword argument: entity_type entity_type="technique", + # Keyword argument: entity_id entity_id=saved.id, + # Keyword argument: details details={"mitre_id": saved.mitre_id, "name": saved.name}, ) + # Call uow.commit() uow.commit() + # Return saved return saved @@ -118,34 +186,56 @@ def create_technique( @router.patch("/{mitre_id}", response_model=TechniqueOut) +# Define function update_technique def update_technique( + 
# Entry: mitre_id mitre_id: str, + # Entry: payload payload: TechniqueUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: repo repo: SATechniqueRepository = Depends(get_technique_repository), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): +) -> TechniqueOut: """Update one or more fields of an existing technique.""" + # Assign entity = repo.find_by_mitre_id(mitre_id) entity = repo.find_by_mitre_id(mitre_id) + # Check: entity is None if entity is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", mitre_id) + # Assign update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True) + # Iterate over update_data.items() for field, value in update_data.items(): + # Call setattr() setattr(entity, field, value) + # Open context manager with UnitOfWork(db) as uow: + # Assign saved = repo.save(entity) saved = repo.save(entity) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_technique", + # Keyword argument: entity_type entity_type="technique", + # Keyword argument: entity_id entity_id=saved.id, + # Keyword argument: details details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())}, ) + # Call uow.commit() uow.commit() + # Return saved return saved @@ -155,33 +245,52 @@ def update_technique( @router.patch("/{mitre_id}/review", response_model=TechniqueOut) +# Define function review_technique def review_technique( + # Entry: mitre_id mitre_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: repo repo: SATechniqueRepository = Depends(get_technique_repository), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TechniqueOut: """Mark a technique as reviewed. Sets ``review_required`` to *False* and records the current timestamp in ``last_review_date``. 
""" + # Assign entity = repo.find_by_mitre_id(mitre_id) entity = repo.find_by_mitre_id(mitre_id) + # Check: entity is None if entity is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", mitre_id) + # Call entity.mark_reviewed() entity.mark_reviewed() + # Open context manager with UnitOfWork(db) as uow: + # Assign saved = repo.save(entity) saved = repo.save(entity) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="review_technique", + # Keyword argument: entity_type entity_type="technique", + # Keyword argument: entity_id entity_id=saved.id, + # Keyword argument: details details={"mitre_id": mitre_id}, ) + # Call uow.commit() uow.commit() + # Return saved return saved diff --git a/backend/app/routers/test_templates.py b/backend/app/routers/test_templates.py index e569570..5f32dc3 100644 --- a/backend/app/routers/test_templates.py +++ b/backend/app/routers/test_templates.py @@ -22,35 +22,69 @@ Filters (GET /test-templates) - offset / limit: pagination (default limit=50) """ +# Import uuid import uuid + +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends, Query, status from fastapi from fastapi import APIRouter, Depends, Query, status + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_any_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_any_role + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork from app.models.technique import Technique from app.models.user import User + +# Import from app.schemas.test_template from app.schemas.test_template import ( TestTemplateCreate, TestTemplateOut, TestTemplateSummary, ) + +# Import log_action from app.services.audit_service from app.services.audit_service import log_action 
+ +# Import from app.services.test_template_service from app.services.test_template_service import ( bulk_activate, - create_template as create_template_svc, get_template_or_raise, get_template_stats, - get_templates_by_technique as templates_by_technique, list_templates, soft_delete_template, +) + +# Import from app.services.test_template_service +from app.services.test_template_service import ( + create_template as create_template_svc, +) + +# Import from app.services.test_template_service +from app.services.test_template_service import ( + get_templates_by_technique as templates_by_technique, +) + +# Import from app.services.test_template_service +from app.services.test_template_service import ( toggle_template_active as toggle_template_active_svc, +) + +# Import from app.services.test_template_service +from app.services.test_template_service import ( update_template as update_template_svc, ) +# Assign router = APIRouter(prefix="/test-templates", tags=["test-templates"]) router = APIRouter(prefix="/test-templates", tags=["test-templates"]) @@ -60,28 +94,64 @@ router = APIRouter(prefix="/test-templates", tags=["test-templates"]) @router.get("", response_model=list[TestTemplateSummary]) +# Define function _list_templates_handler def _list_templates_handler( + # Entry: source source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"), + # Entry: platform platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"), + # Entry: severity severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"), + # Entry: mitre_technique_id mitre_technique_id: Optional[str] = Query(None, description="Filter by MITRE technique ID"), + # Entry: search search: Optional[str] = Query(None, description="Search in name and description"), + # Entry: is_active is_active: Optional[bool] = Query(None, description="Filter by active status (true/false). 
Omit to return all."), + # Entry: offset offset: int = Query(0, ge=0), + # Entry: limit limit: int = Query(50, ge=1, le=200), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Return a paginated, filterable list of test templates.""" +) -> list: + """Return a paginated, filterable list of test templates. + + Args: + source (Optional[str]): Filter by source (``atomic_red_team``, ``mitre``, ``custom``). + platform (Optional[str]): Filter by platform (``windows``, ``linux``, ``macos``). + severity (Optional[str]): Filter by severity (``low``, ``medium``, ``high``, ``critical``). + mitre_technique_id (Optional[str]): Filter by MITRE technique ID string. + search (Optional[str]): Full-text search across name and description. + is_active (Optional[bool]): Filter by active status; omit to return all. + offset (int): Number of records to skip for pagination. + limit (int): Maximum number of records to return. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + list: Serialised list of :class:`TestTemplateSummary` objects. 
+ """ + # Return list_templates( return list_templates( db, + # Keyword argument: source source=source, + # Keyword argument: platform platform=platform, + # Keyword argument: severity severity=severity, + # Keyword argument: mitre_technique_id mitre_technique_id=mitre_technique_id, + # Keyword argument: search search=search, + # Keyword argument: is_active is_active=is_active, + # Keyword argument: offset offset=offset, + # Keyword argument: limit limit=limit, ) @@ -92,11 +162,23 @@ def _list_templates_handler( @router.get("/stats") +# Define function template_stats def template_stats( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Return catalog statistics: active, by_source, by_platform.""" +) -> dict: + """Return catalog statistics: active, by_source, by_platform. + + Args: + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead. + + Returns: + dict: Counts of active templates broken down by source and platform. + """ + # Return get_template_stats(db) return get_template_stats(db) @@ -106,27 +188,53 @@ def template_stats( @router.patch("/bulk-activate") +# Define function bulk_activate_templates def bulk_activate_templates( + # Entry: activate activate: bool = Query(True, description="True to activate all, False to deactivate all"), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Set all templates to active or inactive.""" +) -> dict: + """Set all templates to active or inactive. + + Args: + activate (bool): ``True`` to activate all templates, ``False`` to deactivate all. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead. + + Returns: + dict: Confirmation message with ``affected`` count and the applied ``is_active`` flag. 
+ """ + # Assign count = bulk_activate(db, activate=activate) count = bulk_activate(db, activate=activate) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="bulk_activate_templates" if activate else "bulk_deactivate_templates", + # Keyword argument: entity_type entity_type="test_template", + # Keyword argument: entity_id entity_id=None, + # Keyword argument: details details={"affected": count, "is_active": activate}, ) + # Call uow.commit() uow.commit() + # Return { return { + # Literal argument value "detail": f"{'Activated' if activate else 'Deactivated'} {count} templates", + # Literal argument value "affected": count, + # Literal argument value "is_active": activate, } @@ -137,12 +245,26 @@ def bulk_activate_templates( @router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary]) +# Define function _templates_by_technique_handler def _templates_by_technique_handler( + # Entry: mitre_id mitre_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Return all active templates mapped to a specific MITRE technique.""" +) -> list: + """Return all active templates mapped to a specific MITRE technique. + + Args: + mitre_id (str): MITRE ATT&CK technique ID (e.g. ``T1059.001``). + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + list: Serialised list of :class:`TestTemplateSummary` objects for the technique. 
+ """ + # Return templates_by_technique(db, mitre_id) return templates_by_technique(db, mitre_id) @@ -152,12 +274,26 @@ def _templates_by_technique_handler( @router.get("/{template_id}", response_model=TestTemplateOut) +# Define function get_template def get_template( + # Entry: template_id template_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Return full details for a single test template.""" +) -> TestTemplateOut: + """Return full details for a single test template. + + Args: + template_id (uuid.UUID): Primary key of the template to retrieve. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + TestTemplateOut: Full template detail including all fields. + """ + # Return get_template_or_raise(db, template_id) return get_template_or_raise(db, template_id) @@ -167,17 +303,35 @@ def get_template( @router.post( + # Literal argument value "", + # Keyword argument: response_model response_model=TestTemplateOut, + # Keyword argument: status_code status_code=status.HTTP_201_CREATED, ) +# Define function create_template def create_template( + # Entry: payload payload: TestTemplateCreate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Create a custom test template.""" +) -> TestTemplateOut: + """Create a custom test template. + + Args: + payload (TestTemplateCreate): All fields for the new template. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead creating the template. + + Returns: + TestTemplateOut: The newly created template with all fields populated. 
+ """ + # Assign template = create_template_svc(db, **payload.model_dump()) template = create_template_svc(db, **payload.model_dump()) + # Open context manager with UnitOfWork(db) as uow: # Flag the associated technique for review — new template available if template.mitre_technique_id: @@ -190,19 +344,30 @@ def create_template( technique.review_required = True log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="create_test_template", + # Keyword argument: entity_type entity_type="test_template", + # Keyword argument: entity_id entity_id=template.id, + # Keyword argument: details details={ + # Literal argument value "name": template.name, + # Literal argument value "source": template.source, + # Literal argument value "mitre_technique_id": template.mitre_technique_id, }, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(template) + # Return template return template @@ -212,26 +377,52 @@ def create_template( @router.patch("/{template_id}", response_model=TestTemplateOut) +# Define function update_template def update_template( + # Entry: template_id template_id: uuid.UUID, + # Entry: payload payload: TestTemplateCreate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Update fields of an existing test template.""" +) -> TestTemplateOut: + """Update fields of an existing test template. + + Args: + template_id (uuid.UUID): Primary key of the template to update. + payload (TestTemplateCreate): Fields to update; only set fields are applied. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead updating the template. + + Returns: + TestTemplateOut: The updated template with refreshed field values. + """ + # Assign template = update_template_svc(db, template_id, **payload.model_dump(exclude_u... 
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True)) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_test_template", + # Keyword argument: entity_type entity_type="test_template", + # Keyword argument: entity_id entity_id=template.id, + # Keyword argument: details details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(template) + # Return template return template @@ -241,25 +432,49 @@ def update_template( @router.patch("/{template_id}/toggle-active", response_model=TestTemplateOut) +# Define function toggle_template_active def toggle_template_active( + # Entry: template_id template_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Toggle a template between active and inactive (is_active = not is_active).""" +) -> TestTemplateOut: + """Toggle a template between active and inactive (is_active = not is_active). + + Args: + template_id (uuid.UUID): Primary key of the template to toggle. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead. + + Returns: + TestTemplateOut: The template with the updated ``is_active`` flag. 
+ """ + # Assign template = toggle_template_active_svc(db, template_id) template = toggle_template_active_svc(db, template_id) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="toggle_test_template", + # Keyword argument: entity_type entity_type="test_template", + # Keyword argument: entity_id entity_id=template.id, + # Keyword argument: details details={"name": template.name, "is_active": template.is_active}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(template) + # Return template return template @@ -269,23 +484,47 @@ def toggle_template_active( @router.delete("/{template_id}", status_code=status.HTTP_200_OK) +# Define function delete_template def delete_template( + # Entry: template_id template_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Soft-delete a test template by setting ``is_active=False``.""" +) -> dict: + """Soft-delete a test template by setting ``is_active=False``. + + Args: + template_id (uuid.UUID): Primary key of the template to delete. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead. + + Returns: + dict: Confirmation message with key ``detail``. 
+ """ + # Assign template = get_template_or_raise(db, template_id) template = get_template_or_raise(db, template_id) + # Call soft_delete_template() soft_delete_template(db, template_id) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="delete_test_template", + # Keyword argument: entity_type entity_type="test_template", + # Keyword argument: entity_id entity_id=template.id, + # Keyword argument: details details={"name": template.name}, ) + # Call uow.commit() uow.commit() + # Return {"detail": "Test template deactivated"} return {"detail": "Test template deactivated"} diff --git a/backend/app/routers/tests.py b/backend/app/routers/tests.py index 25cb973..69bc557 100644 --- a/backend/app/routers/tests.py +++ b/backend/app/routers/tests.py @@ -26,13 +26,21 @@ import uuid from datetime import datetime from typing import Any, Optional +# Import APIRouter, Depends, HTTPException, Query, Reque... 
from fastapi from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from pydantic import BaseModel from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_any_role, require_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_any_role, require_role -from app.domain.enums import DataClassification + +# Import UnitOfWork from app.domain.unit_of_work +from app.domain.unit_of_work import UnitOfWork + +# Import limiter from app.limiter from app.limiter import limiter from app.models.enums import TestState, TestResult, TeamSide from app.models.evidence import Evidence @@ -40,34 +48,79 @@ from app.storage import upload_file from app.models.technique import Technique from app.models.test import Test from app.models.user import User + +# Import from app.schemas.test from app.schemas.test import ( + TestBlueUpdate, + TestBlueValidate, + TestClassificationUpdate, TestCreate, TestOut, - TestUpdate, TestRedUpdate, - TestBlueUpdate, TestRedValidate, - TestBlueValidate, TestRemediationUpdate, - TestClassificationUpdate, + TestUpdate, ) + +# Import TestTemplateInstantiate from app.schemas.test_template from app.schemas.test_template import TestTemplateInstantiate -from app.domain.unit_of_work import UnitOfWork + +# Import log_action from app.services.audit_service from app.services.audit_service import log_action + +# Import recalculate_technique_status from app.services.status_service from app.services.status_service import recalculate_technique_status from app.services.webhook_service import dispatch_webhook from app.services.test_crud_service import ( create_test as crud_create_test, +) + +# Import from app.services.test_crud_service +from app.services.test_crud_service import ( create_test_from_template as crud_create_from_template, +) + +# Import from app.services.test_crud_service +from app.services.test_crud_service import ( 
get_test_detail as crud_get_test_detail, +) + +# Import from app.services.test_crud_service +from app.services.test_crud_service import ( get_test_or_raise as crud_get_test_or_raise, +) + +# Import from app.services.test_crud_service +from app.services.test_crud_service import ( get_test_timeline as crud_get_test_timeline, +) + +# Import from app.services.test_crud_service +from app.services.test_crud_service import ( get_test_with_technique as crud_get_test_with_technique, +) + +# Import from app.services.test_crud_service +from app.services.test_crud_service import ( list_tests as crud_list_tests, +) + +# Import from app.services.test_crud_service +from app.services.test_crud_service import ( update_test as crud_update_test, +) + +# Import from app.services.test_crud_service +from app.services.test_crud_service import ( update_test_blue as crud_update_test_blue, +) + +# Import from app.services.test_crud_service +from app.services.test_crud_service import ( update_test_red as crud_update_test_red, ) + +# Import from app.services.test_workflow_service from app.services.test_workflow_service import ( start_execution as wf_start_execution, submit_red_evidence as wf_submit_red, @@ -78,10 +131,54 @@ from app.services.test_workflow_service import ( reopen_test as wf_reopen, handle_remediation_completed as wf_handle_remediation, get_retest_chain as wf_get_retest_chain, +) + +# Import from app.services.test_workflow_service +from app.services.test_workflow_service import ( + handle_remediation_completed as wf_handle_remediation, +) + +# Import from app.services.test_workflow_service +from app.services.test_workflow_service import ( pause_timer as wf_pause_timer, +) + +# Import from app.services.test_workflow_service +from app.services.test_workflow_service import ( + reopen_test as wf_reopen, +) + +# Import from app.services.test_workflow_service +from app.services.test_workflow_service import ( resume_timer as wf_resume_timer, ) +# Import from 
app.services.test_workflow_service +from app.services.test_workflow_service import ( + start_execution as wf_start_execution, +) + +# Import from app.services.test_workflow_service +from app.services.test_workflow_service import ( + submit_blue_evidence as wf_submit_blue, +) + +# Import from app.services.test_workflow_service +from app.services.test_workflow_service import ( + submit_red_evidence as wf_submit_red, +) + +# Import from app.services.test_workflow_service +from app.services.test_workflow_service import ( + validate_as_blue_lead as wf_validate_blue, +) + +# Import from app.services.test_workflow_service +from app.services.test_workflow_service import ( + validate_as_red_lead as wf_validate_red, +) + +# Assign router = APIRouter(prefix="/tests", tags=["tests"]) router = APIRouter(prefix="/tests", tags=["tests"]) @@ -91,11 +188,17 @@ router = APIRouter(prefix="/tests", tags=["tests"]) @router.get("", response_model=list[TestOut]) +# Define function list_tests def list_tests( + # Entry: state state: Optional[str] = Query(None, description="Filter by test state"), + # Entry: technique_id technique_id: Optional[uuid.UUID] = Query(None, description="Filter by technique"), + # Entry: platform platform: Optional[str] = Query(None, description="Filter by platform"), + # Entry: created_by created_by: Optional[uuid.UUID] = Query(None, description="Filter by creator"), + # Entry: pending_validation_side pending_validation_side: Optional[str] = Query( None, description="Filter in_review tests pending validation on 'red' or 'blue' side" ), @@ -103,20 +206,46 @@ def list_tests( False, description="Only return tests not linked to any campaign" ), offset: int = Query(0, ge=0), + # Entry: limit limit: int = Query(50, ge=1, le=200), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Return a paginated list of tests, optionally filtered by state, technique, platform or creator.""" +) -> list: + 
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator. + + Args: + state (Optional[str]): Filter by test state (e.g. ``draft``, ``validated``). + technique_id (Optional[uuid.UUID]): Filter tests belonging to a specific technique. + platform (Optional[str]): Filter by target platform (e.g. ``windows``, ``linux``). + created_by (Optional[uuid.UUID]): Filter by the UUID of the creator. + pending_validation_side (Optional[str]): Filter ``in_review`` tests pending validation + on ``'red'`` or ``'blue'`` side. + offset (int): Number of records to skip for pagination. + limit (int): Maximum number of records to return. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + list: Serialised list of :class:`TestOut` objects matching the filters. + """ + # Return crud_list_tests( return crud_list_tests( db, + # Keyword argument: state state=state, + # Keyword argument: technique_id technique_id=technique_id, + # Keyword argument: platform platform=platform, + # Keyword argument: created_by created_by=created_by, + # Keyword argument: pending_validation_side pending_validation_side=pending_validation_side, not_in_any_campaign=not_in_any_campaign, offset=offset, + # Keyword argument: limit limit=limit, ) @@ -127,37 +256,67 @@ def list_tests( @router.post( + # Literal argument value "", + # Keyword argument: response_model response_model=TestOut, + # Keyword argument: status_code status_code=status.HTTP_201_CREATED, ) +# Apply the @limiter.limit decorator @limiter.limit("30/minute") +# Define function create_test def create_test( + # Entry: request request: Request, + # Entry: payload payload: TestCreate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestOut: """Create a new test linked to an existing technique. 
``created_by`` is set automatically and ``state`` defaults to *draft*. + + Args: + request (Request): FastAPI request object (used by the rate limiter). + payload (TestCreate): Fields for the new test, including ``technique_id``. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead creating the test. + + Returns: + TestOut: The newly created test with all fields populated. """ + # Open context manager with UnitOfWork(db) as uow: + # Assign test = crud_create_test( test = crud_create_test( db, + # Keyword argument: technique_id technique_id=payload.technique_id, + # Keyword argument: creator_id creator_id=current_user.id, **payload.model_dump(exclude={"technique_id"}), ) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="create_test", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={"name": test.name, "technique_id": str(test.technique_id)}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) # Auto-create Jira ticket (non-fatal — any failure is logged, not raised) @@ -177,26 +336,49 @@ def create_test( @router.post( + # Literal argument value "/from-template", + # Keyword argument: response_model response_model=TestOut, + # Keyword argument: status_code status_code=status.HTTP_201_CREATED, ) +# Apply the @limiter.limit decorator @limiter.limit("30/minute") +# Define function create_test_from_template def create_test_from_template( + # Entry: request request: Request, + # Entry: payload payload: TestTemplateInstantiate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestOut: """Instantiate a real Test from an existing TestTemplate. 
The template's fields are copied into the new test as starting data. + + Args: + request (Request): FastAPI request object (used by the rate limiter). + payload (TestTemplateInstantiate): Contains ``template_id`` and target ``technique_id``. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead creating the test. + + Returns: + TestOut: The newly created test populated from the template. """ + # Open context manager with UnitOfWork(db) as uow: + # Assign test = crud_create_from_template( test = crud_create_from_template( db, + # Keyword argument: template_id template_id=payload.template_id, + # Keyword argument: technique_id_or_mitre technique_id_or_mitre=payload.technique_id, + # Keyword argument: creator_id creator_id=current_user.id, name_override=payload.name, description_override=payload.description, @@ -204,19 +386,30 @@ def create_test_from_template( procedure_text_override=payload.procedure_text, tool_used_override=payload.tool_used, ) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="create_test_from_template", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={ + # Literal argument value "name": test.name, + # Literal argument value "template_id": str(payload.template_id), + # Literal argument value "technique_id": str(test.technique_id), }, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) # Auto-create Jira ticket (non-fatal) @@ -236,12 +429,26 @@ def create_test_from_template( @router.get("/{test_id}", response_model=TestOut) +# Define function get_test def get_test( + # Entry: test_id test_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Return full details for a single test, including 
its evidences.""" +) -> TestOut: + """Return full details for a single test, including its evidences. + + Args: + test_id (uuid.UUID): Primary key of the test to retrieve. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + TestOut: Full test detail including split red/blue evidence lists. + """ + # Return crud_get_test_detail(db, test_id) return crud_get_test_detail(db, test_id) @@ -251,37 +458,65 @@ def get_test( @router.patch("/{test_id}", response_model=TestOut) +# Define function update_test def update_test( + # Entry: test_id test_id: uuid.UUID, + # Entry: payload payload: TestUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestOut: """Update one or more fields of an existing test. Only leads or admins can update general test fields. The test must be in ``draft`` or ``rejected`` state. + + Args: + test_id (uuid.UUID): Primary key of the test to update. + payload (TestUpdate): Partial update payload; only set fields are applied. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead performing the update. + + Returns: + TestOut: The updated test with refreshed field values. 
""" + # Assign update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = crud_update_test( test = crud_update_test( db, test_id, + # Keyword argument: updater_id updater_id=current_user.id, + # Keyword argument: updater_role updater_role=current_user.role, **update_data, ) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_test", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={"updated_fields": list(update_data.keys())}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -291,27 +526,55 @@ def update_test( @router.patch("/{test_id}/classification", response_model=TestOut) +# Define function update_test_classification def update_test_classification( + # Entry: test_id test_id: uuid.UUID, + # Entry: payload payload: TestClassificationUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): - """Update the data classification label for a test (admin only).""" +) -> TestOut: + """Update the data classification label for a test (admin only). + + Args: + test_id (uuid.UUID): Primary key of the test to classify. + payload (TestClassificationUpdate): Contains the new ``data_classification`` value. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated admin user. + + Returns: + TestOut: The test with the updated ``data_classification`` field. 
+ """ + # Open context manager with UnitOfWork(db) as uow: + # Assign test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id) + # Assign test.data_classification = payload.data_classification.value test.data_classification = payload.data_classification.value + # Flush changes to DB without committing the transaction db.flush() + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_test_classification", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={"data_classification": payload.data_classification.value}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -321,27 +584,54 @@ def update_test_classification( @router.patch("/{test_id}/red", response_model=TestOut) +# Define function update_test_red def update_test_red( + # Entry: test_id test_id: uuid.UUID, + # Entry: payload payload: TestRedUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_tech", "red_lead")), -): - """Red Team updates their fields (allowed in ``draft`` and ``red_executing``).""" +) -> TestOut: + """Red Team updates their fields (allowed in ``draft`` and ``red_executing``). + + Args: + test_id (uuid.UUID): Primary key of the test to update. + payload (TestRedUpdate): Red-team-specific fields to update. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_tech or red_lead. + + Returns: + TestOut: The updated test with refreshed red-team field values. 
+ """ + # Assign update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = crud_update_test_red(db, test_id, **update_data) test = crud_update_test_red(db, test_id, **update_data) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_test_red", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={"updated_fields": list(update_data.keys())}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -351,27 +641,54 @@ def update_test_red( @router.patch("/{test_id}/blue", response_model=TestOut) +# Define function update_test_blue def update_test_blue( + # Entry: test_id test_id: uuid.UUID, + # Entry: payload payload: TestBlueUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), -): - """Blue Team updates their fields (allowed only in ``blue_evaluating``).""" +) -> TestOut: + """Blue Team updates their fields (allowed only in ``blue_evaluating``). + + Args: + test_id (uuid.UUID): Primary key of the test to update. + payload (TestBlueUpdate): Blue-team-specific fields to update. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated blue_tech or blue_lead. + + Returns: + TestOut: The updated test with refreshed blue-team field values. 
+ """ + # Assign update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = crud_update_test_blue(db, test_id, **update_data) test = crud_update_test_blue(db, test_id, **update_data) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_test_blue", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={"updated_fields": list(update_data.keys())}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -381,17 +698,36 @@ def update_test_blue( @router.post("/{test_id}/start-execution", response_model=TestOut) +# Define function start_execution def start_execution( + # Entry: test_id test_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_tech", "red_lead")), -): - """Move a test from ``draft`` to ``red_executing``.""" +) -> TestOut: + """Move a test from ``draft`` to ``red_executing``. + + Args: + test_id (uuid.UUID): Primary key of the test to start. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_tech or red_lead initiating execution. + + Returns: + TestOut: The updated test in ``red_executing`` state. 
+ """ + # Assign test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = wf_start_execution(db, test, current_user) test = wf_start_execution(db, test, current_user) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -401,17 +737,36 @@ def start_execution( @router.post("/{test_id}/submit-red", response_model=TestOut) +# Define function submit_red def submit_red( + # Entry: test_id test_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_tech", "red_lead")), -): - """Red Team finalises — move from ``red_executing`` to ``blue_evaluating``.""" +) -> TestOut: + """Red Team finalises — move from ``red_executing`` to ``blue_evaluating``. + + Args: + test_id (uuid.UUID): Primary key of the test to submit. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_tech or red_lead submitting red evidence. + + Returns: + TestOut: The updated test in ``blue_evaluating`` state. 
+ """ + # Assign test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = wf_submit_red(db, test, current_user) test = wf_submit_red(db, test, current_user) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -421,17 +776,36 @@ def submit_red( @router.post("/{test_id}/submit-blue", response_model=TestOut) +# Define function submit_blue def submit_blue( + # Entry: test_id test_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), -): - """Blue Team finalises — move from ``blue_evaluating`` to ``in_review``.""" +) -> TestOut: + """Blue Team finalises — move from ``blue_evaluating`` to ``in_review``. + + Args: + test_id (uuid.UUID): Primary key of the test to submit. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated blue_tech or blue_lead submitting blue evidence. + + Returns: + TestOut: The updated test in ``in_review`` state. 
+ """ + # Assign test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = wf_submit_blue(db, test, current_user) test = wf_submit_blue(db, test, current_user) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -461,17 +835,36 @@ def start_blue_work( @router.post("/{test_id}/pause-timer", response_model=TestOut) +# Define function pause_timer def pause_timer( + # Entry: test_id test_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), -): - """Pause the running timer for the current phase (red_executing or blue_evaluating).""" +) -> TestOut: + """Pause the running timer for the current phase (red_executing or blue_evaluating). + + Args: + test_id (uuid.UUID): Primary key of the test whose timer should be paused. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated team member in the active phase. + + Returns: + TestOut: The updated test with the phase timer paused. 
+ """ + # Assign test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = wf_pause_timer(db, test, current_user) test = wf_pause_timer(db, test, current_user) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -481,17 +874,36 @@ def pause_timer( @router.post("/{test_id}/resume-timer", response_model=TestOut) +# Define function resume_timer def resume_timer( + # Entry: test_id test_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), -): - """Resume the paused timer for the current phase.""" +) -> TestOut: + """Resume the paused timer for the current phase. + + Args: + test_id (uuid.UUID): Primary key of the test whose timer should be resumed. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated team member in the active phase. + + Returns: + TestOut: The updated test with the phase timer running again. 
+ """ + # Assign test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = wf_resume_timer(db, test, current_user) test = wf_resume_timer(db, test, current_user) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -501,26 +913,49 @@ def resume_timer( @router.post("/{test_id}/validate-red", response_model=TestOut) +# Define function validate_red def validate_red( + # Entry: test_id test_id: uuid.UUID, + # Entry: payload payload: TestRedValidate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead")), -): - """Red Lead approves or rejects the red side of a test.""" +) -> TestOut: + """Red Lead approves or rejects the red side of a test. + + Args: + test_id (uuid.UUID): Primary key of the test to validate. + payload (TestRedValidate): Validation status and optional notes from the Red Lead. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead performing the validation. + + Returns: + TestOut: The updated test reflecting the red validation decision. 
+ """ + # Assign test = crud_get_test_with_technique(db, test_id) test = crud_get_test_with_technique(db, test_id) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = wf_validate_red( test = wf_validate_red( db, test, current_user, + # Keyword argument: validation_status validation_status=payload.red_validation_status, + # Keyword argument: notes notes=payload.red_validation_notes, ) + # Check: test.state in (TestState.validated, TestState.rejected) if test.state in (TestState.validated, TestState.rejected): + # Call recalculate_technique_status() recalculate_technique_status(db, test.technique) # Flag technique for review — coverage changed if test.technique: test.technique.review_required = True uow.commit() + # Reload ORM object attributes from the database db.refresh(test) if test.state == TestState.validated: dispatch_webhook("test.validated", {"test_id": str(test.id), "technique_id": str(test.technique_id), "result": test.result.value if test.result else None}) @@ -535,26 +970,49 @@ def validate_red( @router.post("/{test_id}/validate-blue", response_model=TestOut) +# Define function validate_blue def validate_blue( + # Entry: test_id test_id: uuid.UUID, + # Entry: payload payload: TestBlueValidate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("blue_lead")), -): - """Blue Lead approves or rejects the blue side of a test.""" +) -> TestOut: + """Blue Lead approves or rejects the blue side of a test. + + Args: + test_id (uuid.UUID): Primary key of the test to validate. + payload (TestBlueValidate): Validation status and optional notes from the Blue Lead. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated blue_lead performing the validation. + + Returns: + TestOut: The updated test reflecting the blue validation decision. 
+ """ + # Assign test = crud_get_test_with_technique(db, test_id) test = crud_get_test_with_technique(db, test_id) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = wf_validate_blue( test = wf_validate_blue( db, test, current_user, + # Keyword argument: validation_status validation_status=payload.blue_validation_status, + # Keyword argument: notes notes=payload.blue_validation_notes, ) + # Check: test.state in (TestState.validated, TestState.rejected) if test.state in (TestState.validated, TestState.rejected): + # Call recalculate_technique_status() recalculate_technique_status(db, test.technique) # Flag technique for review — coverage changed if test.technique: test.technique.review_required = True uow.commit() + # Reload ORM object attributes from the database db.refresh(test) if test.state == TestState.validated: dispatch_webhook("test.validated", {"test_id": str(test.id), "technique_id": str(test.technique_id), "result": test.result.value if test.result else None}) @@ -569,17 +1027,36 @@ def validate_blue( @router.post("/{test_id}/reopen", response_model=TestOut) +# Define function reopen def reopen( + # Entry: test_id test_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): - """Reopen a rejected test, moving it back to ``draft``.""" +) -> TestOut: + """Reopen a rejected test, moving it back to ``draft``. + + Args: + test_id (uuid.UUID): Primary key of the rejected test to reopen. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead reopening the test. + + Returns: + TestOut: The updated test in ``draft`` state. 
+ """ + # Assign test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id) + # Open context manager with UnitOfWork(db) as uow: + # Assign test = wf_reopen(db, test, current_user) test = wf_reopen(db, test, current_user) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -589,42 +1066,74 @@ def reopen( @router.patch("/{test_id}/remediation", response_model=TestOut) +# Define function update_remediation def update_remediation( + # Entry: test_id test_id: uuid.UUID, + # Entry: payload payload: TestRemediationUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestOut: """Update remediation fields on a test. When ``remediation_status`` transitions to ``'completed'``, an automatic re-test is created (subject to ``MAX_RETEST_COUNT``). + + Args: + test_id (uuid.UUID): Primary key of the test to update. + payload (TestRemediationUpdate): Remediation fields to update (status, notes, etc.). + db (Session): SQLAlchemy database session. + current_user (User): Authenticated red_lead or blue_lead updating remediation. + + Returns: + TestOut: The updated test with refreshed remediation fields. 
""" + # Assign test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id) + # Assign old_remediation_status = test.remediation_status old_remediation_status = test.remediation_status + # Assign update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True) + # Iterate over update_data.items() for field, value in update_data.items(): + # Call setattr() setattr(test, field, value) + # Open context manager with UnitOfWork(db) as uow: + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_remediation", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={"updated_fields": list(update_data.keys())}, ) + # Assign new_status = update_data.get("remediation_status") new_status = update_data.get("remediation_status") + # Check: new_status == "completed" and old_remediation_status != "completed" if new_status == "completed" and old_remediation_status != "completed": + # Call wf_handle_remediation() wf_handle_remediation(db, test, current_user) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(test) + # Return test return test @@ -634,12 +1143,26 @@ def update_remediation( @router.get("/{test_id}/timeline") +# Define function get_test_timeline def get_test_timeline( + # Entry: test_id test_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Return the chronological audit-log history for a test.""" +) -> list: + """Return the chronological audit-log history for a test. + + Args: + test_id (uuid.UUID): Primary key of the test whose timeline is requested. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. 
+ + Returns: + list: Chronological list of audit-log entries for the test. + """ + # Return crud_get_test_timeline(db, test_id) return crud_get_test_timeline(db, test_id) @@ -649,26 +1172,53 @@ def get_test_timeline( @router.get("/{test_id}/retest-chain") +# Define function get_retest_chain def get_retest_chain( + # Entry: test_id test_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): - """Return the full chain of retests (original + all retests) for a test.""" +) -> list: + """Return the full chain of retests (original + all retests) for a test. + + Args: + test_id (uuid.UUID): Primary key of any test in the retest chain. + db (Session): SQLAlchemy database session. + current_user (User): Authenticated user making the request. + + Returns: + list: Ordered list of dicts describing each test in the chain, + including state, result, remediation status, and retest metadata. + """ + # Assign chain = wf_get_retest_chain(db, test_id) chain = wf_get_retest_chain(db, test_id) + # Check: not chain if not chain: + # Raise HTTPException raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found") + # Return [ return [ { + # Literal argument value "id": str(t.id), + # Literal argument value "name": t.name, + # Literal argument value "state": t.state.value if t.state else None, + # Literal argument value "retest_of": str(t.retest_of) if t.retest_of else None, + # Literal argument value "retest_count": t.retest_count, + # Literal argument value "result": t.result.value if t.result else None, + # Literal argument value "detection_result": t.detection_result.value if t.detection_result else None, + # Literal argument value "remediation_status": t.remediation_status, + # Literal argument value "created_at": t.created_at.isoformat() if t.created_at else None, } for t in chain diff --git a/backend/app/routers/threat_actors.py b/backend/app/routers/threat_actors.py index 
733112a..b26e91c 100644 --- a/backend/app/routers/threat_actors.py +++ b/backend/app/routers/threat_actors.py @@ -4,15 +4,28 @@ Provides listing, detail, coverage analysis, and gap analysis for threat actor profiles imported from MITRE CTI. """ +# Import logging import logging + +# Import Optional from typing from typing import Optional +# Import APIRouter, Depends, Query from fastapi from fastapi import APIRouter, Depends, Query + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user from app.dependencies.auth from app.dependencies.auth import get_current_user + +# Import User from app.models.user from app.models.user import User + +# Import from app.services.threat_actor_service from app.services.threat_actor_service import ( get_actor_coverage, get_actor_detail, @@ -20,58 +33,90 @@ from app.services.threat_actor_service import ( list_actors, ) +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign router = APIRouter(prefix="/threat-actors", tags=["threat-actors"]) router = APIRouter(prefix="/threat-actors", tags=["threat-actors"]) +# Apply the @router.get decorator @router.get("") +# Define function list_threat_actors def list_threat_actors( + # Entry: search search: Optional[str] = Query(None), + # Entry: country country: Optional[str] = Query(None), + # Entry: motivation motivation: Optional[str] = Query(None), + # Entry: sophistication sophistication: Optional[str] = Query(None), + # Entry: target_sectors target_sectors: Optional[str] = Query(None), + # Entry: offset offset: int = Query(0, ge=0), + # Entry: limit limit: int = Query(50, ge=1, le=200), + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """List threat actors with optional filters and pagination. **Requires** authentication (any role). 
""" + # Return list_actors( return list_actors( db, + # Keyword argument: search search=search, + # Keyword argument: country country=country, + # Keyword argument: motivation motivation=motivation, + # Keyword argument: sophistication sophistication=sophistication, + # Keyword argument: target_sectors target_sectors=target_sectors, + # Keyword argument: offset offset=offset, + # Keyword argument: limit limit=limit, ) +# Apply the @router.get decorator @router.get("/{actor_id}") +# Define function get_threat_actor def get_threat_actor( + # Entry: actor_id actor_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Get detailed info about a threat actor including techniques. **Requires** authentication (any role). """ + # Return get_actor_detail(db, actor_id) return get_actor_detail(db, actor_id) +# Apply the @router.get decorator @router.get("/{actor_id}/coverage") +# Define function get_threat_actor_coverage def get_threat_actor_coverage( + # Entry: actor_id actor_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> dict: """Calculate coverage percentage against a specific threat actor. **Requires** authentication (any role). @@ -79,19 +124,26 @@ def get_threat_actor_coverage( Returns the percentage of the actor's techniques that have been validated or partially validated, along with a breakdown. """ + # Return get_actor_coverage(db, actor_id) return get_actor_coverage(db, actor_id) +# Apply the @router.get decorator @router.get("/{actor_id}/gaps") +# Define function get_threat_actor_gaps def get_threat_actor_gaps( + # Entry: actor_id actor_id: str, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(get_current_user), -): +) -> list: """Identify techniques of this actor that are NOT fully validated. **Requires** authentication (any role). 
Returns list of gap techniques with available templates. """ + # Return get_actor_gaps(db, actor_id) return get_actor_gaps(db, actor_id) diff --git a/backend/app/routers/users.py b/backend/app/routers/users.py index cfcb941..ea441a5 100644 --- a/backend/app/routers/users.py +++ b/backend/app/routers/users.py @@ -1,17 +1,30 @@ """User management router (admin only).""" +# Import uuid import uuid +# Import APIRouter, Depends, status from fastapi from fastapi import APIRouter, Depends, status + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import require_role from app.dependencies.auth from app.dependencies.auth import require_role + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import User from app.models.user from app.models.user import User from app.dependencies.auth import get_current_user from app.schemas.user import UserCreate, UserUpdate, UserOut, UserPreferencesUpdate from app.services.audit_service import log_action + +# Import from app.services.user_service from app.services.user_service import ( create_user, get_user_or_raise, @@ -19,6 +32,7 @@ from app.services.user_service import ( update_user, ) +# Assign router = APIRouter(prefix="/users", tags=["users"]) router = APIRouter(prefix="/users", tags=["users"]) @@ -69,11 +83,15 @@ def get_me( @router.get("", response_model=list[UserOut]) +# Define function list_users_route def list_users_route( + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): - """Return a list of all users. **Requires admin role.**""" +) -> list[UserOut]: + """Return a list of all users. 
**Requires admin role**.""" + # Return list_users(db) return list_users(db) @@ -83,31 +101,50 @@ def list_users_route( @router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED) +# Define function create_user_route def create_user_route( + # Entry: payload payload: UserCreate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): - """Create a new user. **Requires admin role.**""" +) -> UserOut: + """Create a new user. **Requires admin role**.""" + # Open context manager with UnitOfWork(db) as uow: + # Assign user = create_user( user = create_user( db, + # Keyword argument: username username=payload.username, + # Keyword argument: email email=payload.email, + # Keyword argument: password password=payload.password, + # Keyword argument: role role=payload.role, ) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="create_user", + # Keyword argument: entity_type entity_type="user", + # Keyword argument: entity_id entity_id=user.id, + # Keyword argument: details details={"username": user.username, "role": user.role}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(user) + # Return user return user @@ -117,12 +154,17 @@ def create_user_route( @router.get("/{user_id}", response_model=UserOut) +# Define function get_user def get_user( + # Entry: user_id user_id: uuid.UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): - """Return a single user by ID. **Requires admin role.**""" +) -> UserOut: + """Return a single user by ID. 
**Requires admin role**.""" + # Return get_user_or_raise(db, user_id) return get_user_or_raise(db, user_id) @@ -132,25 +174,42 @@ def get_user( @router.patch("/{user_id}", response_model=UserOut) +# Define function update_user_route def update_user_route( + # Entry: user_id user_id: uuid.UUID, + # Entry: payload payload: UserUpdate, + # Entry: db db: Session = Depends(get_db), + # Entry: current_user current_user: User = Depends(require_role("admin")), -): - """Update one or more fields of an existing user. **Requires admin role.**""" +) -> UserOut: + """Update one or more fields of an existing user. **Requires admin role**.""" + # Assign update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True) + # Open context manager with UnitOfWork(db) as uow: + # Assign user = update_user(db, user_id, **update_data) user = update_user(db, user_id, **update_data) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=current_user.id, + # Keyword argument: action action="update_user", + # Keyword argument: entity_type entity_type="user", + # Keyword argument: entity_id entity_id=user.id, + # Keyword argument: details details={"updated_fields": list(update_data.keys())}, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(user) + # Return user return user diff --git a/backend/app/routers/worklogs.py b/backend/app/routers/worklogs.py index b7a65d8..31cb8b5 100644 --- a/backend/app/routers/worklogs.py +++ b/backend/app/routers/worklogs.py @@ -1,19 +1,39 @@ """Worklog router — internal time-tracking records with integrity verification.""" +# Import datetime from datetime from datetime import datetime + +# Import Optional from typing from typing import Optional + +# Import UUID from uuid from uuid import UUID -from fastapi import APIRouter, Depends, Query +# Import APIRouter, Depends from fastapi +from fastapi import APIRouter, Depends + +# Import BaseModel, Field 
from pydantic from pydantic import BaseModel, Field + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import get_db from app.database from app.database import get_db + +# Import get_current_user, require_any_role from app.dependencies.auth from app.dependencies.auth import get_current_user, require_any_role + +# Import UnitOfWork from app.domain.unit_of_work from app.domain.unit_of_work import UnitOfWork + +# Import User from app.models.user from app.models.user import User + +# Import worklog_service from app.services from app.services import worklog_service +# Assign router = APIRouter(prefix="/worklogs", tags=["worklogs"]) router = APIRouter(prefix="/worklogs", tags=["worklogs"]) @@ -21,30 +41,58 @@ router = APIRouter(prefix="/worklogs", tags=["worklogs"]) class WorklogCreate(BaseModel): + """Payload for logging a work session against an entity.""" + + # Assign entity_type = Field(..., max_length=50) entity_type: str = Field(..., max_length=50) + # entity_id: UUID entity_id: UUID + # Assign activity_type = Field(..., max_length=100) activity_type: str = Field(..., max_length=100) + # started_at: datetime started_at: datetime + # Assign ended_at = None ended_at: Optional[datetime] = None + # Assign duration_seconds = Field(..., gt=0) duration_seconds: int = Field(..., gt=0) + # Assign description = None description: Optional[str] = None +# Define class WorklogOut class WorklogOut(BaseModel): + """Serialized worklog entry returned by the API.""" + + # id: UUID id: UUID + # entity_type: str entity_type: str + # entity_id: UUID entity_id: UUID + # user_id: UUID user_id: UUID + # activity_type: str activity_type: str + # started_at: datetime started_at: datetime + # Assign ended_at = None ended_at: Optional[datetime] = None + # duration_seconds: int duration_seconds: int + # Assign description = None description: Optional[str] = None + # Assign tempo_synced = None tempo_synced: Optional[datetime] = None + # Assign integrity_hash = None 
integrity_hash: Optional[str] = None + # created_at: datetime created_at: datetime + # Define class Config class Config: + """ORM mode configuration for SQLAlchemy model mapping.""" + + # Assign from_attributes = True from_attributes = True @@ -52,65 +100,146 @@ class WorklogOut(BaseModel): @router.post("", response_model=WorklogOut, status_code=201) +# Define function create def create( + # Entry: body body: WorklogCreate, + # Entry: db db: Session = Depends(get_db), + # Entry: user user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), -): - """Create a manually-logged worklog entry.""" +) -> WorklogOut: + """Create a manually-logged worklog entry. + + Args: + body (WorklogCreate): Worklog fields including entity, activity type, and duration. + db (Session): SQLAlchemy database session. + user (User): Authenticated team member creating the worklog. + + Returns: + WorklogOut: The newly created worklog with integrity hash and all fields. + """ + # Open context manager with UnitOfWork(db) as uow: + # Assign wl = worklog_service.create_worklog( wl = worklog_service.create_worklog( db, + # Keyword argument: entity_type entity_type=body.entity_type, + # Keyword argument: entity_id entity_id=body.entity_id, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: activity_type activity_type=body.activity_type, + # Keyword argument: started_at started_at=body.started_at, + # Keyword argument: ended_at ended_at=body.ended_at, + # Keyword argument: duration_seconds duration_seconds=body.duration_seconds, + # Keyword argument: description description=body.description, ) + # Call uow.commit() uow.commit() + # Reload ORM object attributes from the database db.refresh(wl) + # Return wl return wl +# Apply the @router.get decorator @router.get("", response_model=list[WorklogOut]) +# Define function list_all def list_all( + # Entry: entity_type entity_type: Optional[str] = None, + # Entry: entity_id entity_id: Optional[UUID] = None, 
+ # Entry: user_id user_id: Optional[UUID] = None, + # Entry: db db: Session = Depends(get_db), + # Entry: _user _user: User = Depends(get_current_user), -): - """List worklogs with optional filters.""" +) -> list[WorklogOut]: + """List worklogs with optional filters. + + Args: + entity_type (Optional[str]): Filter by entity type (e.g. ``test``, ``campaign``). + entity_id (Optional[UUID]): Filter by the UUID of the associated entity. + user_id (Optional[UUID]): Filter by the UUID of the worklog author. + db (Session): SQLAlchemy database session. + _user (User): Authenticated user making the request (unused, enforces auth). + + Returns: + list[WorklogOut]: Serialised list of worklog entries matching the filters. + """ + # Return worklog_service.list_worklogs( return worklog_service.list_worklogs( db, + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: entity_id entity_id=entity_id, + # Keyword argument: user_id user_id=user_id, ) +# Apply the @router.get decorator @router.get("/{worklog_id}", response_model=WorklogOut) +# Define function get_one def get_one( + # Entry: worklog_id worklog_id: UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: _user _user: User = Depends(get_current_user), -): - """Get a single worklog by ID.""" +) -> WorklogOut: + """Get a single worklog by ID. + + Args: + worklog_id (UUID): Primary key of the worklog to retrieve. + db (Session): SQLAlchemy database session. + _user (User): Authenticated user making the request (unused, enforces auth). + + Returns: + WorklogOut: Full worklog detail including integrity hash. 
+ """ + # Return worklog_service.get_worklog_or_raise(db, worklog_id) return worklog_service.get_worklog_or_raise(db, worklog_id) +# Apply the @router.get decorator @router.get("/{worklog_id}/verify") +# Define function verify_integrity def verify_integrity( + # Entry: worklog_id worklog_id: UUID, + # Entry: db db: Session = Depends(get_db), + # Entry: _user _user: User = Depends(get_current_user), -): - """Check whether a worklog's integrity hash is still valid.""" +) -> dict: + """Check whether a worklog's integrity hash is still valid. + + Args: + worklog_id (UUID): Primary key of the worklog to verify. + db (Session): SQLAlchemy database session. + _user (User): Authenticated user making the request (unused, enforces auth). + + Returns: + dict: Contains ``worklog_id`` (str) and ``integrity_valid`` (bool). + """ + # Assign wl = worklog_service.get_worklog_or_raise(db, worklog_id) wl = worklog_service.get_worklog_or_raise(db, worklog_id) + # Return { return { + # Literal argument value "worklog_id": str(wl.id), + # Literal argument value "integrity_valid": worklog_service.verify_worklog_integrity(wl), } diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py index 62b98f1..f8b2f62 100644 --- a/backend/app/schemas/__init__.py +++ b/backend/app/schemas/__init__.py @@ -1,7 +1,12 @@ """Pydantic schemas — re-exported for convenient imports.""" +# Import LoginRequest, TokenResponse, UserOut from app.schemas.auth from app.schemas.auth import LoginRequest, TokenResponse, UserOut +# Import EvidenceOut, EvidenceUpload from app.schemas.evidence +from app.schemas.evidence import EvidenceOut, EvidenceUpload + +# Import from app.schemas.technique from app.schemas.technique import ( TechniqueCreate, TechniqueOut, @@ -9,51 +14,68 @@ from app.schemas.technique import ( TechniqueUpdate, ) +# Import from app.schemas.test from app.schemas.test import ( + TestBlueUpdate, + TestBlueValidate, TestCreate, TestOut, + TestRedUpdate, + TestRedValidate, TestUpdate, 
TestValidate, - TestRedUpdate, - TestBlueUpdate, - TestRedValidate, - TestBlueValidate, ) -from app.schemas.evidence import EvidenceOut, EvidenceUpload - +# Import from app.schemas.test_template from app.schemas.test_template import ( - TestTemplateOut, TestTemplateCreate, - TestTemplateSummary, TestTemplateInstantiate, + TestTemplateOut, + TestTemplateSummary, ) +# Assign __all__ = [ __all__ = [ # Auth "LoginRequest", + # Literal argument value "TokenResponse", + # Literal argument value "UserOut", # Technique "TechniqueCreate", + # Literal argument value "TechniqueOut", + # Literal argument value "TechniqueSummary", + # Literal argument value "TechniqueUpdate", # Test "TestCreate", + # Literal argument value "TestOut", + # Literal argument value "TestUpdate", + # Literal argument value "TestValidate", + # Literal argument value "TestRedUpdate", + # Literal argument value "TestBlueUpdate", + # Literal argument value "TestRedValidate", + # Literal argument value "TestBlueValidate", # Evidence "EvidenceOut", + # Literal argument value "EvidenceUpload", # Test Template "TestTemplateOut", + # Literal argument value "TestTemplateCreate", + # Literal argument value "TestTemplateSummary", + # Literal argument value "TestTemplateInstantiate", ] diff --git a/backend/app/schemas/audit.py b/backend/app/schemas/audit.py index e0acb7a..faaa4ed 100644 --- a/backend/app/schemas/audit.py +++ b/backend/app/schemas/audit.py @@ -1,31 +1,48 @@ """Pydantic schemas for Audit Log endpoints.""" +# Import uuid import uuid + +# Import datetime from datetime from datetime import datetime from typing import Any, Optional +# Import BaseModel, ConfigDict from pydantic from pydantic import BaseModel, ConfigDict +# Define class AuditLogOut class AuditLogOut(BaseModel): """Complete representation of an audit log entry.""" + # id: uuid.UUID id: uuid.UUID + # Assign user_id = None user_id: uuid.UUID | None = None + # Assign username = None # Populated from user relationship username: str | None = 
None # Populated from user relationship + # action: str action: str + # Assign entity_type = None entity_type: str | None = None + # Assign entity_id = None entity_id: str | None = None timestamp: Optional[datetime] = None details: dict[str, Any] | None = None + # Assign model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True) +# Define class AuditLogPage class AuditLogPage(BaseModel): """Paginated response for audit logs.""" + # items: list[AuditLogOut] items: list[AuditLogOut] + # total: int total: int + # offset: int offset: int + # limit: int limit: int diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py index 14224cb..4c4bc08 100644 --- a/backend/app/schemas/auth.py +++ b/backend/app/schemas/auth.py @@ -1,34 +1,56 @@ """Pydantic schemas for authentication endpoints.""" +# Import uuid import uuid +# Import BaseModel from pydantic from pydantic import BaseModel +# Define class LoginRequest class LoginRequest(BaseModel): - """Body for the login endpoint (unused directly — we rely on - ``OAuth2PasswordRequestForm``, but kept for documentation / testing).""" + """Body for the login endpoint. + Unused directly — we rely on ``OAuth2PasswordRequestForm``, but kept for + documentation and testing purposes. 
+ """ + + # username: str username: str + # password: str password: str +# Define class TokenResponse class TokenResponse(BaseModel): """Response returned after a successful login.""" + # access_token: str access_token: str + # Assign token_type = "bearer" token_type: str = "bearer" +# Define class UserOut class UserOut(BaseModel): """Public representation of a user (no password hash).""" + # id: uuid.UUID id: uuid.UUID + # username: str username: str + # Assign email = None email: str | None = None + # role: str role: str + # is_active: bool is_active: bool + # Assign must_change_password = True must_change_password: bool = True + # Define class Config class Config: + """ORM mode configuration for SQLAlchemy model mapping.""" + + # Assign from_attributes = True from_attributes = True diff --git a/backend/app/schemas/evidence.py b/backend/app/schemas/evidence.py index a26dd88..942ec3c 100644 --- a/backend/app/schemas/evidence.py +++ b/backend/app/schemas/evidence.py @@ -1,34 +1,53 @@ """Pydantic schemas for Evidence endpoints.""" +# Import uuid import uuid + +# Import datetime from datetime from datetime import datetime +# Import BaseModel, ConfigDict from pydantic from pydantic import BaseModel, ConfigDict +# Import TeamSide from app.models.enums from app.models.enums import TeamSide +# Define class EvidenceOut class EvidenceOut(BaseModel): """Representation of an evidence record returned by the API. ``download_url`` is a presigned URL generated at response time. 
""" + # id: uuid.UUID id: uuid.UUID + # test_id: uuid.UUID test_id: uuid.UUID + # file_name: str file_name: str + # sha256_hash: str sha256_hash: str + # Assign uploaded_by = None uploaded_by: uuid.UUID | None = None + # Assign uploaded_at = None uploaded_at: datetime | None = None + # Assign team = TeamSide.red team: TeamSide = TeamSide.red + # Assign notes = None notes: str | None = None + # Assign download_url = None download_url: str | None = None + # Assign model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True) +# Define class EvidenceUpload class EvidenceUpload(BaseModel): """Metadata sent alongside an evidence file upload.""" + # team: TeamSide team: TeamSide + # Assign notes = None notes: str | None = None diff --git a/backend/app/schemas/jira_schema.py b/backend/app/schemas/jira_schema.py index 489544c..d040f3d 100644 --- a/backend/app/schemas/jira_schema.py +++ b/backend/app/schemas/jira_schema.py @@ -1,46 +1,91 @@ """Pydantic schemas for Jira integration endpoints.""" +# Import datetime from datetime from datetime import datetime + +# Import Optional from typing from typing import Optional + +# Import UUID from uuid from uuid import UUID +# Import BaseModel, Field from pydantic from pydantic import BaseModel, Field +# Import JiraLinkEntityType, JiraSyncDirection from app.models.jira_link from app.models.jira_link import JiraLinkEntityType, JiraSyncDirection +# Define class JiraLinkCreate class JiraLinkCreate(BaseModel): + """Payload for linking an Aegis entity to an existing Jira issue.""" + + # entity_type: JiraLinkEntityType entity_type: JiraLinkEntityType + # entity_id: UUID entity_id: UUID + # Assign jira_issue_key = Field(..., pattern=r"^[A-Z][A-Z0-9]+-\d+$") jira_issue_key: str = Field(..., pattern=r"^[A-Z][A-Z0-9]+-\d+$") + # Assign sync_direction = JiraSyncDirection.bidirectional sync_direction: JiraSyncDirection = JiraSyncDirection.bidirectional +# Define class JiraLinkOut class JiraLinkOut(BaseModel): + 
"""Full representation of a Jira link returned by the API.""" + + # id: UUID id: UUID + # entity_type: JiraLinkEntityType entity_type: JiraLinkEntityType + # entity_id: UUID entity_id: UUID + # jira_issue_key: str jira_issue_key: str + # Assign jira_issue_id = None jira_issue_id: Optional[str] = None + # Assign jira_project_key = None jira_project_key: Optional[str] = None + # Assign jira_status = None jira_status: Optional[str] = None + # Assign jira_priority = None jira_priority: Optional[str] = None + # Assign jira_assignee = None jira_assignee: Optional[str] = None + # Assign jira_story_points = None jira_story_points: Optional[str] = None + # Assign last_synced_at = None last_synced_at: Optional[datetime] = None + # created_at: datetime created_at: datetime + # Define class Config class Config: + """ORM mode configuration for SQLAlchemy model mapping.""" + + # Assign from_attributes = True from_attributes = True +# Define class JiraIssueSearch class JiraIssueSearch(BaseModel): + """Payload for searching Jira issues by free-text query.""" + + # query: str query: str +# Define class JiraIssueResult class JiraIssueResult(BaseModel): + """Lightweight Jira issue representation returned by search results.""" + + # issue_key: str issue_key: str + # summary: str summary: str + # status: str status: str + # Assign assignee = None assignee: Optional[str] = None + # Assign priority = None priority: Optional[str] = None diff --git a/backend/app/schemas/metrics.py b/backend/app/schemas/metrics.py index 1d06c36..9bfe689 100644 --- a/backend/app/schemas/metrics.py +++ b/backend/app/schemas/metrics.py @@ -1,31 +1,49 @@ """Pydantic schemas for coverage-metrics endpoints.""" +# Import datetime from datetime from datetime import datetime +# Import BaseModel, ConfigDict from pydantic from pydantic import BaseModel, ConfigDict +# Define class CoverageSummary class CoverageSummary(BaseModel): """Global coverage summary across all MITRE ATT&CK techniques.""" + # total_techniques: 
int total_techniques: int + # validated: int validated: int + # partial: int partial: int + # not_covered: int not_covered: int + # in_progress: int in_progress: int + # not_evaluated: int not_evaluated: int + # coverage_percentage: float # (validated + partial) / total * 100 coverage_percentage: float # (validated + partial) / total * 100 +# Define class TacticCoverage class TacticCoverage(BaseModel): """Coverage breakdown for a single tactic.""" + # tactic: str tactic: str + # total: int total: int + # validated: int validated: int + # partial: int partial: int + # not_covered: int not_covered: int + # not_evaluated: int not_evaluated: int + # in_progress: int in_progress: int @@ -35,12 +53,19 @@ class TacticCoverage(BaseModel): class TestPipelineCounts(BaseModel): """Counters per state in the test pipeline.""" + # Assign draft = 0 draft: int = 0 + # Assign red_executing = 0 red_executing: int = 0 + # Assign blue_evaluating = 0 blue_evaluating: int = 0 + # Assign in_review = 0 in_review: int = 0 + # Assign validated = 0 validated: int = 0 + # Assign rejected = 0 rejected: int = 0 + # Assign total = 0 total: int = 0 @@ -50,9 +75,13 @@ class TestPipelineCounts(BaseModel): class TeamActivity(BaseModel): """Activity summary for a team (Red or Blue).""" + # team: str team: str + # Assign tests_completed = 0 tests_completed: int = 0 + # Assign tests_pending = 0 tests_pending: int = 0 + # Assign avg_completion_hours = None avg_completion_hours: float | None = None @@ -62,10 +91,15 @@ class TeamActivity(BaseModel): class ValidationRate(BaseModel): """Approval / rejection rate for a manager role.""" + # role: str # "red_lead" or "blue_lead" role: str # "red_lead" or "blue_lead" + # Assign total_reviewed = 0 total_reviewed: int = 0 + # Assign approved = 0 approved: int = 0 + # Assign rejected = 0 rejected: int = 0 + # Assign approval_rate = 0.0 # percentage approval_rate: float = 0.0 # percentage @@ -75,11 +109,18 @@ class ValidationRate(BaseModel): class 
RecentTestItem(BaseModel): """Lightweight test entry for the recent-tests widget.""" + # id: str id: str + # name: str name: str + # state: str state: str + # Assign technique_mitre_id = None technique_mitre_id: str | None = None + # Assign technique_name = None technique_name: str | None = None + # Assign created_at = None created_at: datetime | None = None + # Assign model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/schemas/notification.py b/backend/app/schemas/notification.py index 436cb00..1f50d57 100644 --- a/backend/app/schemas/notification.py +++ b/backend/app/schemas/notification.py @@ -1,28 +1,45 @@ """Pydantic schemas for Notification endpoints.""" +# Import uuid import uuid + +# Import datetime from datetime from datetime import datetime +# Import BaseModel, ConfigDict from pydantic from pydantic import BaseModel, ConfigDict +# Define class NotificationOut class NotificationOut(BaseModel): """Notification returned by the API.""" + # id: uuid.UUID id: uuid.UUID + # user_id: uuid.UUID user_id: uuid.UUID + # type: str type: str + # title: str title: str + # Assign message = None message: str | None = None + # Assign entity_type = None entity_type: str | None = None + # Assign entity_id = None entity_id: uuid.UUID | None = None + # Assign read = False read: bool = False + # Assign created_at = None created_at: datetime | None = None + # Assign model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True) +# Define class UnreadCountOut class UnreadCountOut(BaseModel): """Simple counter response.""" + # unread_count: int unread_count: int diff --git a/backend/app/schemas/technique.py b/backend/app/schemas/technique.py index 9cb8042..4544e15 100644 --- a/backend/app/schemas/technique.py +++ b/backend/app/schemas/technique.py @@ -1,22 +1,31 @@ """Pydantic schemas for Technique endpoints.""" +# Import uuid import uuid + +# Import datetime from datetime from 
datetime import datetime +# Import BaseModel, ConfigDict from pydantic from pydantic import BaseModel, ConfigDict +# Import TechniqueStatus from app.models.enums from app.models.enums import TechniqueStatus - # ── Create ────────────────────────────────────────────────────────── class TechniqueCreate(BaseModel): """Payload for creating a new technique.""" + # mitre_id: str mitre_id: str + # name: str name: str + # Assign description = None description: str | None = None + # Assign tactic = None tactic: str | None = None + # Assign platforms = None platforms: list[str] | None = None @@ -24,12 +33,19 @@ class TechniqueCreate(BaseModel): class TechniqueUpdate(BaseModel): """Payload for partially updating an existing technique. - Every field is optional so callers send only what changed.""" + Every field is optional so callers send only what changed. + """ + + # Assign name = None name: str | None = None + # Assign description = None description: str | None = None + # Assign tactic = None tactic: str | None = None + # Assign platforms = None platforms: list[str] | None = None + # Assign status_global = None status_global: TechniqueStatus | None = None @@ -38,20 +54,34 @@ class TechniqueUpdate(BaseModel): class TechniqueOut(BaseModel): """Complete representation returned by the API.""" + # id: uuid.UUID id: uuid.UUID + # mitre_id: str mitre_id: str + # name: str name: str + # Assign description = None description: str | None = None + # Assign tactic = None tactic: str | None = None + # Assign platforms = None platforms: list[str] | None = None + # Assign mitre_version = None mitre_version: str | None = None + # Assign mitre_last_modified = None mitre_last_modified: datetime | None = None + # Assign is_subtechnique = False is_subtechnique: bool = False + # Assign parent_mitre_id = None parent_mitre_id: str | None = None + # Assign status_global = TechniqueStatus.not_evaluated status_global: TechniqueStatus = TechniqueStatus.not_evaluated + # Assign review_required = 
False review_required: bool = False + # Assign last_review_date = None last_review_date: datetime | None = None + # Assign model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True) @@ -60,10 +90,16 @@ class TechniqueOut(BaseModel): class TechniqueSummary(BaseModel): """Lightweight representation used in list endpoints.""" + # id: uuid.UUID id: uuid.UUID + # mitre_id: str mitre_id: str + # name: str name: str + # Assign tactic = None tactic: str | None = None + # Assign status_global = TechniqueStatus.not_evaluated status_global: TechniqueStatus = TechniqueStatus.not_evaluated + # Assign model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/schemas/test.py b/backend/app/schemas/test.py index ef68534..b3a0ad4 100644 --- a/backend/app/schemas/test.py +++ b/backend/app/schemas/test.py @@ -1,14 +1,20 @@ """Pydantic schemas for Test endpoints.""" +# Import uuid import uuid + +# Import datetime from datetime from datetime import datetime from pydantic import BaseModel, ConfigDict, model_validator +# Import DataClassification from app.domain.enums from app.domain.enums import DataClassification from app.models.enums import TestResult, TestState from app.schemas.evidence import EvidenceOut +# Import TestResult, TestState from app.models.enums +from app.models.enums import TestResult, TestState # ── Create ────────────────────────────────────────────────────────── @@ -16,11 +22,17 @@ from app.schemas.evidence import EvidenceOut class TestCreate(BaseModel): """Payload for creating a new test.""" + # technique_id: uuid.UUID technique_id: uuid.UUID + # name: str name: str + # Assign description = None description: str | None = None + # Assign platform = None platform: str | None = None + # Assign procedure_text = None procedure_text: str | None = None + # Assign tool_used = None tool_used: str | None = None @@ -30,18 +42,28 @@ class TestCreate(BaseModel): class 
TestClassificationUpdate(BaseModel): """Admin-only payload for changing data classification.""" + # data_classification: DataClassification data_classification: DataClassification +# Define class TestUpdate class TestUpdate(BaseModel): """Payload for partially updating an existing test. - Every field is optional so callers send only what changed.""" + Every field is optional so callers send only what changed. + """ + + # Assign name = None name: str | None = None + # Assign description = None description: str | None = None + # Assign platform = None platform: str | None = None + # Assign procedure_text = None procedure_text: str | None = None + # Assign tool_used = None tool_used: str | None = None + # Assign result = None result: TestResult | None = None @@ -51,11 +73,17 @@ class TestUpdate(BaseModel): class TestRedUpdate(BaseModel): """Fields that Red Team fills in during the red_executing phase.""" + # Assign name = None name: str | None = None + # Assign description = None description: str | None = None + # Assign procedure_text = None procedure_text: str | None = None + # Assign tool_used = None tool_used: str | None = None + # Assign attack_success = None attack_success: bool | None = None + # Assign red_summary = None red_summary: str | None = None @@ -65,7 +93,9 @@ class TestRedUpdate(BaseModel): class TestBlueUpdate(BaseModel): """Fields that Blue Team fills in during the blue_evaluating phase.""" + # Assign detection_result = None detection_result: TestResult | None = None + # Assign blue_summary = None blue_summary: str | None = None @@ -75,7 +105,9 @@ class TestBlueUpdate(BaseModel): class TestRedValidate(BaseModel): """Payload sent by Red Lead to approve/reject the red side.""" + # red_validation_status: str # "approved" or "rejected" red_validation_status: str # "approved" or "rejected" + # Assign red_validation_notes = None red_validation_notes: str | None = None @@ -85,7 +117,9 @@ class TestRedValidate(BaseModel): class TestBlueValidate(BaseModel): 
"""Payload sent by Blue Lead to approve/reject the blue side.""" + # blue_validation_status: str # "approved" or "rejected" blue_validation_status: str # "approved" or "rejected" + # Assign blue_validation_notes = None blue_validation_notes: str | None = None @@ -95,8 +129,11 @@ class TestBlueValidate(BaseModel): class TestRemediationUpdate(BaseModel): """Payload for updating remediation fields.""" + # Assign remediation_steps = None remediation_steps: str | None = None + # Assign remediation_status = None # pending / in_progress / completed / not_applicable remediation_status: str | None = None # pending / in_progress / completed / not_applicable + # Assign remediation_assignee = None remediation_assignee: uuid.UUID | None = None @@ -106,7 +143,9 @@ class TestRemediationUpdate(BaseModel): class TestValidate(BaseModel): """Payload sent by a reviewer to validate / reject a test.""" + # result: TestResult result: TestResult + # Assign comments = None comments: str | None = None @@ -116,55 +155,85 @@ class TestValidate(BaseModel): class TestOut(BaseModel): """Complete representation returned by the API.""" + # id: uuid.UUID id: uuid.UUID + # technique_id: uuid.UUID technique_id: uuid.UUID + # name: str name: str + # Assign description = None description: str | None = None + # Assign platform = None platform: str | None = None + # Assign procedure_text = None procedure_text: str | None = None + # Assign tool_used = None tool_used: str | None = None + # Assign execution_date = None execution_date: datetime | None = None + # Assign created_by = None created_by: uuid.UUID | None = None + # Assign result = None result: TestResult | None = None + # Assign state = TestState.draft state: TestState = TestState.draft + # Assign created_at = None created_at: datetime | None = None # Red Team fields red_summary: str | None = None + # Assign attack_success = None attack_success: bool | None = None + # Assign red_validated_by = None red_validated_by: uuid.UUID | None = None + # 
Assign red_validated_at = None red_validated_at: datetime | None = None + # Assign red_validation_status = None red_validation_status: str | None = None + # Assign red_validation_notes = None red_validation_notes: str | None = None # Blue Team fields blue_summary: str | None = None + # Assign detection_result = None detection_result: TestResult | None = None + # Assign blue_validated_by = None blue_validated_by: uuid.UUID | None = None + # Assign blue_validated_at = None blue_validated_at: datetime | None = None + # Assign blue_validation_status = None blue_validation_status: str | None = None + # Assign blue_validation_notes = None blue_validation_notes: str | None = None # Phase timing fields (for Tempo worklogs) red_started_at: datetime | None = None + # Assign blue_started_at = None blue_started_at: datetime | None = None blue_work_started_at: datetime | None = None paused_at: datetime | None = None + # Assign red_paused_seconds = 0 red_paused_seconds: int = 0 + # Assign blue_paused_seconds = 0 blue_paused_seconds: int = 0 # Remediation fields remediation_steps: str | None = None + # Assign remediation_status = None remediation_status: str | None = None + # Assign remediation_assignee = None remediation_assignee: uuid.UUID | None = None # Re-test fields retest_of: uuid.UUID | None = None + # Assign retest_count = 0 retest_count: int = 0 + # Assign data_classification = "internal" data_classification: str = "internal" # Technique info (populated when joined) technique_mitre_id: str | None = None + # Assign technique_name = None technique_name: str | None = None # Evidences split by team (populated from the ORM relationship) diff --git a/backend/app/schemas/test_template.py b/backend/app/schemas/test_template.py index 18f7753..7d6aed6 100644 --- a/backend/app/schemas/test_template.py +++ b/backend/app/schemas/test_template.py @@ -1,33 +1,52 @@ """Pydantic schemas for TestTemplate endpoints.""" +# Import uuid import uuid + +# Import datetime from datetime from 
datetime import datetime +# Import BaseModel, ConfigDict from pydantic from pydantic import BaseModel, ConfigDict - # ── Full output ───────────────────────────────────────────────────── class TestTemplateOut(BaseModel): """Complete representation of a test template.""" + # id: uuid.UUID id: uuid.UUID + # mitre_technique_id: str mitre_technique_id: str + # name: str name: str + # Assign description = None description: str | None = None + # source: str source: str + # Assign source_url = None source_url: str | None = None + # Assign attack_procedure = None attack_procedure: str | None = None + # Assign expected_detection = None expected_detection: str | None = None + # Assign platform = None platform: str | None = None + # Assign tool_suggested = None tool_suggested: str | None = None + # Assign severity = None severity: str | None = None + # Assign atomic_test_id = None atomic_test_id: str | None = None + # Assign suggested_remediation = None suggested_remediation: str | None = None + # Assign is_active = True is_active: bool = True + # Assign created_at = None created_at: datetime | None = None + # Assign model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True) @@ -37,17 +56,29 @@ class TestTemplateOut(BaseModel): class TestTemplateCreate(BaseModel): """Payload for creating a custom test template.""" + # mitre_technique_id: str mitre_technique_id: str + # name: str name: str + # Assign description = None description: str | None = None + # Assign source = "custom" source: str = "custom" + # Assign source_url = None source_url: str | None = None + # Assign attack_procedure = None attack_procedure: str | None = None + # Assign expected_detection = None expected_detection: str | None = None + # Assign platform = None platform: str | None = None + # Assign tool_suggested = None tool_suggested: str | None = None + # Assign severity = None severity: str | None = None + # Assign atomic_test_id = None atomic_test_id: str | None = None 
+ # Assign suggested_remediation = None suggested_remediation: str | None = None @@ -57,14 +88,22 @@ class TestTemplateCreate(BaseModel): class TestTemplateSummary(BaseModel): """Lightweight representation for listing templates.""" + # id: uuid.UUID id: uuid.UUID + # mitre_technique_id: str mitre_technique_id: str + # name: str name: str + # source: str source: str + # Assign platform = None platform: str | None = None + # Assign severity = None severity: str | None = None + # Assign is_active = True is_active: bool = True + # Assign model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True) @@ -77,7 +116,9 @@ class TestTemplateInstantiate(BaseModel): Optional override fields take precedence over the template values when provided. """ + # template_id: uuid.UUID template_id: uuid.UUID + # technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001") technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001") # User-editable overrides (if omitted the template value is used) diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index 870befa..9ffa9f9 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -1,7 +1,12 @@ """Pydantic schemas for User management endpoints.""" +# Import re import re + +# Import uuid import uuid + +# Import datetime from datetime from datetime import datetime from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator, model_validator @@ -10,21 +15,39 @@ from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator, mo # ── Username policy ───────────────────────────────────────────────── _USERNAME_RE = re.compile(r"^[a-zA-Z0-9_-]{3,50}$") +# Assign _RESERVED_USERNAMES = frozenset({ _RESERVED_USERNAMES = frozenset({ + # Literal argument value "admin", "root", "system", "api", "null", "undefined", + # Literal argument value "administrator", "superuser", "aegis", }) +# Define function _validate_username def 
_validate_username(username: str) -> str: - """Validate username format and reject reserved names.""" + """Validate username format and reject reserved names. + + Args: + username (str): The username string to validate. + + Returns: + str: The validated username, unchanged. + """ + # Check: not _USERNAME_RE.match(username) if not _USERNAME_RE.match(username): + # Raise ValueError raise ValueError( + # Literal argument value "Username must be 3-50 characters, containing only " + # Literal argument value "letters, digits, underscores, and hyphens" ) + # Check: username.lower() in _RESERVED_USERNAMES if username.lower() in _RESERVED_USERNAMES: + # Raise ValueError raise ValueError(f"Username '{username}' is reserved") + # Return username return username @@ -32,6 +55,7 @@ def _validate_username(username: str) -> str: _MIN_PASSWORD_LENGTH = 12 +# Assign _PASSWORD_RULES = [ _PASSWORD_RULES: list[tuple[str, str]] = [ (r"[A-Z]", "at least one uppercase letter"), (r"[a-z]", "at least one lowercase letter"), @@ -40,30 +64,48 @@ _PASSWORD_RULES: list[tuple[str, str]] = [ ] +# Define function _validate_password_strength def _validate_password_strength(password: str) -> str: """Check that *password* satisfies the complexity policy. Rules: + - Minimum 12 characters - At least one uppercase letter - At least one lowercase letter - At least one digit - At least one special character + + Args: + password (str): The plaintext password to validate. + + Returns: + str: The validated password, unchanged. 
""" + # Assign errors = [] errors: list[str] = [] + # Check: len(password) < _MIN_PASSWORD_LENGTH if len(password) < _MIN_PASSWORD_LENGTH: + # Call errors.append() errors.append(f"must be at least {_MIN_PASSWORD_LENGTH} characters long") + # Iterate over _PASSWORD_RULES for pattern, description in _PASSWORD_RULES: + # Check: not re.search(pattern, password) if not re.search(pattern, password): + # Call errors.append() errors.append(description) + # Check: errors if errors: + # Raise ValueError raise ValueError( + # Literal argument value "Password does not meet complexity requirements: " + "; ".join(errors) ) + # Return password return password @@ -72,19 +114,47 @@ def _validate_password_strength(password: str) -> str: class UserCreate(BaseModel): """Payload for creating a new user.""" + # username: str username: str + # Assign email = None email: str | None = None + # password: str password: str + # Assign role = "viewer" role: str = "viewer" + # Apply the @field_validator decorator @field_validator("username") + # Apply the @classmethod decorator @classmethod + # Define function username_format def username_format(cls, v: str) -> str: + """Validate the username field against the platform policy. + + Args: + v (str): Raw username value from the request body. + + Returns: + str: The validated username. + """ + # Return _validate_username(v) return _validate_username(v) + # Apply the @field_validator decorator @field_validator("password") + # Apply the @classmethod decorator @classmethod + # Define function password_strength def password_strength(cls, v: str) -> str: + """Validate the password field against the complexity policy. + + Args: + v (str): Raw password value from the request body. + + Returns: + str: The validated password. + """ + # Return _validate_password_strength(v) return _validate_password_strength(v) @@ -92,18 +162,38 @@ class UserCreate(BaseModel): class UserUpdate(BaseModel): """Payload for partially updating an existing user. 
- Every field is optional so callers send only what changed.""" + Every field is optional so callers send only what changed. + """ + + # Assign email = None email: str | None = None + # Assign role = None role: str | None = None + # Assign is_active = None is_active: bool | None = None + # Assign password = None password: str | None = None + # Apply the @field_validator decorator @field_validator("password") + # Apply the @classmethod decorator @classmethod + # Define function password_strength def password_strength(cls, v: str | None) -> str | None: + """Validate the password field when provided. + + Args: + v (str | None): Raw password value, or ``None`` when unchanged. + + Returns: + str | None: The validated password, or ``None``. + """ + # Check: v is not None if v is not None: + # Return _validate_password_strength(v) return _validate_password_strength(v) + # Return v return v @@ -112,12 +202,26 @@ class UserUpdate(BaseModel): class PasswordChange(BaseModel): """Payload for changing the current user's password.""" + # current_password: str current_password: str + # new_password: str new_password: str + # Apply the @field_validator decorator @field_validator("new_password") + # Apply the @classmethod decorator @classmethod + # Define function new_password_strength def new_password_strength(cls, v: str) -> str: + """Validate the new password against the complexity policy. + + Args: + v (str): Raw new-password value from the request body. + + Returns: + str: The validated new password. 
+ """ + # Return _validate_password_strength(v) return _validate_password_strength(v) @@ -140,13 +244,21 @@ class UserPreferencesUpdate(BaseModel): class UserOut(BaseModel): """Complete representation returned by the API.""" + # id: uuid.UUID id: uuid.UUID + # username: str username: str + # Assign email = None email: str | None = None + # role: str role: str + # is_active: bool is_active: bool + # Assign must_change_password = True must_change_password: bool = True + # Assign created_at = None created_at: datetime | None = None + # Assign last_login = None last_login: datetime | None = None notification_preferences: dict | None = None jira_account_id: str | None = None @@ -158,6 +270,7 @@ class UserOut(BaseModel): jira_token_set: bool = False tempo_token_set: bool = False + # Assign model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True) @model_validator(mode="after") diff --git a/backend/app/seed.py b/backend/app/seed.py index 7ca16e9..7c5b34d 100644 --- a/backend/app/seed.py +++ b/backend/app/seed.py @@ -1,5 +1,4 @@ -""" -Seed script — creates the initial admin user if it does not already exist. +"""Seed script — creates the initial admin user if it does not already exist. On first run the admin credentials are generated securely: - Username is read from ``ADMIN_USERNAME`` env var (default: ``admin``). 
@@ -11,23 +10,36 @@ Usage: python -m app.seed """ +# Import os import os + +# Import secrets import secrets + +# Import string import string +# Import hash_password from app.auth from app.auth import hash_password + +# Import SessionLocal from app.database from app.database import SessionLocal + +# Import User from app.models.user from app.models.user import User # Characters for auto-generated passwords (alphanumeric + safe symbols) _PW_ALPHABET = string.ascii_letters + string.digits + "!@#$%&*-_+" +# Define function _generate_password def _generate_password(length: int = 16) -> str: """Return a cryptographically random password of *length* characters.""" + # Return "".join(secrets.choice(_PW_ALPHABET) for _ in range(length)) return "".join(secrets.choice(_PW_ALPHABET) for _ in range(length)) +# Define function seed_admin def seed_admin() -> None: """Create the initial admin user when it is missing. @@ -35,49 +47,85 @@ def seed_admin() -> None: If ``ADMIN_PASSWORD`` is empty or unset a secure random password is generated and displayed in the logs. 
""" + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign admin_username = os.environ.get("ADMIN_USERNAME", "admin").strip() or "admin" admin_username = os.environ.get("ADMIN_USERNAME", "admin").strip() or "admin" + # Assign existing = db.query(User).filter(User.username == admin_username).first() existing = db.query(User).filter(User.username == admin_username).first() + # Check: existing if existing: + # Call print() print(f"Admin user '{admin_username}' already exists — skipping.") + # Return control to caller return + # Assign admin_password = os.environ.get("ADMIN_PASSWORD", "").strip() admin_password = os.environ.get("ADMIN_PASSWORD", "").strip() + # Assign password_was_generated = False password_was_generated = False + # Check: not admin_password if not admin_password: + # Assign admin_password = _generate_password() admin_password = _generate_password() + # Assign password_was_generated = True password_was_generated = True + # Assign admin = User( admin = User( + # Keyword argument: username username=admin_username, + # Keyword argument: hashed_password hashed_password=hash_password(admin_password), + # Keyword argument: role role="admin", ) + # Stage new record(s) for database insertion db.add(admin) + # Commit all pending changes to the database db.commit() # ── Display credentials in startup logs ────────────────────── print() + # Call print() print("=" * 60) + # Call print() print(" AEGIS — Initial Admin User Created") + # Call print() print("=" * 60) + # Call print() print(f" Username : {admin_username}") + # Check: password_was_generated if password_was_generated: + # Call print() print(f" Password : {admin_password}") + # Call print() print() + # Call print() print(" ** This password was auto-generated because") + # Call print() print(" ADMIN_PASSWORD was not set in the environment. **") + # Call print() print(" ** Save it now — it will NOT be shown again. 
**") + # Fallback: handle remaining cases else: + # Call print() print(" Password : (set via ADMIN_PASSWORD env var)") + # Call print() print("=" * 60) + # Call print() print() + # Always execute this cleanup block finally: + # Close the database session db.close() +# Check: __name__ == "__main__" if __name__ == "__main__": + # Call seed_admin() seed_admin() diff --git a/backend/app/seed_data_sources.py b/backend/app/seed_data_sources.py index 0c4b723..a229653 100644 --- a/backend/app/seed_data_sources.py +++ b/backend/app/seed_data_sources.py @@ -1,15 +1,19 @@ -""" -Seed script — registers all known data sources in the data_sources table. +"""Seed script — registers all known data sources in the data_sources table. Usage: python -m app.seed_data_sources """ +# Import logging import logging +# Import SessionLocal from app.database from app.database import SessionLocal + +# Import DataSource from app.models.data_source from app.models.data_source import DataSource +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -18,164 +22,287 @@ logger = logging.getLogger(__name__) INITIAL_SOURCES = [ { + # Literal argument value "name": "atomic_red_team", + # Literal argument value "display_name": "Atomic Red Team", + # Literal argument value "type": "attack_procedure", + # Literal argument value "url": "https://github.com/redcanaryco/atomic-red-team", + # Literal argument value "description": "Open-source library of atomic tests mapped to MITRE ATT&CK. 
" + # Literal argument value "Each test is a small, self-contained procedure for validating " + # Literal argument value "detection of a specific technique.", + # Literal argument value "sync_frequency": "weekly", + # Literal argument value "config": { + # Literal argument value "zip_url": "https://github.com/redcanaryco/atomic-red-team/archive/refs/heads/master.zip", + # Literal argument value "root_prefix": "atomic-red-team-master", + # Literal argument value "atomics_dir": "atomics", }, }, { + # Literal argument value "name": "sigma", + # Literal argument value "display_name": "SigmaHQ Rules", + # Literal argument value "type": "detection_rule", + # Literal argument value "url": "https://github.com/SigmaHQ/sigma", + # Literal argument value "description": "Generic SIEM detection rules in YAML format. " + # Literal argument value "3 000+ rules with MITRE ATT&CK mappings.", + # Literal argument value "sync_frequency": "weekly", + # Literal argument value "config": { + # Literal argument value "zip_url": "https://github.com/SigmaHQ/sigma/archive/refs/heads/master.zip", + # Literal argument value "root_prefix": "sigma-master", + # Literal argument value "rules_dir": "rules", }, }, { + # Literal argument value "name": "lolbas", + # Literal argument value "display_name": "LOLBAS (Windows)", + # Literal argument value "type": "attack_procedure", + # Literal argument value "url": "https://github.com/LOLBAS-Project/LOLBAS", + # Literal argument value "description": "Living Off The Land Binaries, Scripts, and Libraries — " + # Literal argument value "legitimate Windows binaries that can be abused for attacks.", + # Literal argument value "sync_frequency": "monthly", + # Literal argument value "config": { + # Literal argument value "zip_url": "https://github.com/LOLBAS-Project/LOLBAS/archive/refs/heads/master.zip", + # Literal argument value "root_prefix": "LOLBAS-master", + # Literal argument value "yaml_dirs": ["yml/OSBinaries", "yml/OSLibraries", "yml/OSScripts"], }, }, 
{ + # Literal argument value "name": "gtfobins", + # Literal argument value "display_name": "GTFOBins (Linux)", + # Literal argument value "type": "attack_procedure", + # Literal argument value "url": "https://gtfobins.github.io/", + # Literal argument value "description": "Unix/Linux binaries that can be exploited for file transfer, " + # Literal argument value "shell escape, privilege escalation, and more.", + # Literal argument value "sync_frequency": "monthly", + # Literal argument value "config": { + # Literal argument value "zip_url": "https://github.com/GTFOBins/GTFOBins.github.io/archive/refs/heads/master.zip", + # Literal argument value "root_prefix": "GTFOBins.github.io-master", + # Literal argument value "gtfobins_dir": "_gtfobins", }, }, { + # Literal argument value "name": "caldera", + # Literal argument value "display_name": "MITRE CALDERA", + # Literal argument value "type": "attack_procedure", + # Literal argument value "url": "https://github.com/mitre/stockpile", + # Literal argument value "description": "Automated adversary emulation platform by MITRE. " + # Literal argument value "400+ abilities (executable actions) mapped to ATT&CK " + # Literal argument value "(via the Stockpile plugin).", + # Literal argument value "sync_frequency": "monthly", + # Literal argument value "config": { + # Literal argument value "zip_url": "https://github.com/mitre/stockpile/archive/refs/heads/master.zip", + # Literal argument value "root_prefix": "stockpile-master", + # Literal argument value "abilities_dir": "data/abilities", }, }, { + # Literal argument value "name": "elastic_rules", + # Literal argument value "display_name": "Elastic Detection Rules", + # Literal argument value "type": "detection_rule", + # Literal argument value "url": "https://github.com/elastic/detection-rules", + # Literal argument value "description": "Open-source detection rules for Elastic SIEM. 
" + # Literal argument value "1 000+ rules in KQL with MITRE ATT&CK mappings.", + # Literal argument value "sync_frequency": "weekly", + # Literal argument value "config": { + # Literal argument value "zip_url": "https://github.com/elastic/detection-rules/archive/refs/heads/main.zip", + # Literal argument value "root_prefix": "detection-rules-main", + # Literal argument value "rules_dir": "rules", }, }, { + # Literal argument value "name": "d3fend", + # Literal argument value "display_name": "MITRE D3FEND", + # Literal argument value "type": "defensive_technique", + # Literal argument value "url": "https://d3fend.mitre.org/", + # Literal argument value "description": "MITRE framework of defensive countermeasures. " + # Literal argument value "200+ defensive techniques mapped to ATT&CK.", + # Literal argument value "sync_frequency": "monthly", + # Literal argument value "config": {}, }, { + # Literal argument value "name": "mitre_cti", + # Literal argument value "display_name": "MITRE CTI (Groups & Software)", + # Literal argument value "type": "threat_intel", + # Literal argument value "url": "https://github.com/mitre/cti", + # Literal argument value "description": "MITRE ATT&CK STIX 2.0 data — threat actor groups, " + # Literal argument value "software, and campaigns with TTP mappings.", + # Literal argument value "sync_frequency": "monthly", + # Literal argument value "config": { + # Literal argument value "zip_url": "https://github.com/mitre/cti/archive/refs/heads/master.zip", + # Literal argument value "root_prefix": "cti-master", + # Literal argument value "enterprise_dir": "enterprise-attack", }, }, ] +# Define function seed_data_sources def seed_data_sources() -> dict: """Register all known data sources. 
Existing entries are skipped.""" + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Assign created = 0 created = 0 + # Assign skipped = 0 skipped = 0 + # Assign existing_names = { existing_names = { row[0] for row in db.query(DataSource.name).all() } + # Iterate over INITIAL_SOURCES for src in INITIAL_SOURCES: + # Check: src["name"] in existing_names if src["name"] in existing_names: + # Assign skipped = 1 skipped += 1 + # Skip to the next loop iteration continue + # Assign ds = DataSource( ds = DataSource( + # Keyword argument: name name=src["name"], + # Keyword argument: display_name display_name=src["display_name"], + # Keyword argument: type type=src["type"], + # Keyword argument: url url=src.get("url"), + # Keyword argument: description description=src.get("description"), + # Keyword argument: sync_frequency sync_frequency=src.get("sync_frequency", "manual"), + # Keyword argument: config config=src.get("config"), + # Keyword argument: is_enabled is_enabled=True, ) + # Stage new record(s) for database insertion db.add(ds) + # Assign created = 1 created += 1 + # Commit all pending changes to the database db.commit() + # Assign summary = {"created": created, "skipped": skipped} summary = {"created": created, "skipped": skipped} + # Log info: "Data sources seed: %s", summary logger.info("Data sources seed: %s", summary) + # Return summary return summary + # Handle Exception except Exception: + # Roll back all uncommitted changes db.rollback() + # raise raise + # Always execute this cleanup block finally: + # Close the database session db.close() +# Check: __name__ == "__main__" if __name__ == "__main__": + # Call logging.basicConfig() logging.basicConfig( + # Keyword argument: level level=logging.INFO, + # Keyword argument: format format="%(asctime)s %(levelname)-8s %(name)s — %(message)s", ) + # Assign result = seed_data_sources() result = seed_data_sources() + # Call print() print(f"\nData sources seed 
complete: {result}") diff --git a/backend/app/seed_demo.py b/backend/app/seed_demo.py index 4e18009..a1d8b4a 100644 --- a/backend/app/seed_demo.py +++ b/backend/app/seed_demo.py @@ -1,5 +1,4 @@ -""" -Seed script — generates a realistic volume of demo data for V3 validation. +"""Seed script — generates a realistic volume of demo data for V3 validation. Usage: python -m app.seed_demo @@ -11,22 +10,52 @@ Running twice is safe — the script detects existing demo data (by username prefix ``demo_``) and deletes it before re-creating, ensuring idempotency. """ +# Import logging import logging + +# Import random import random + +# Import uuid import uuid + +# Import datetime, timedelta from datetime from datetime import datetime, timedelta -from app.auth import hash_password -from app.database import SessionLocal -from app.models.user import User -from app.models.technique import Technique -from app.models.test import Test -from app.models.test_template import TestTemplate -from app.models.evidence import Evidence -from app.models.audit import AuditLog -from app.models.notification import Notification -from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide +# Import Session from sqlalchemy.orm +from sqlalchemy.orm import Session +# Import hash_password from app.auth +from app.auth import hash_password + +# Import SessionLocal from app.database +from app.database import SessionLocal + +# Import AuditLog from app.models.audit +from app.models.audit import AuditLog + +# Import TeamSide, TechniqueStatus, TestResult, TestState from app.models.enums +from app.models.enums import TeamSide, TechniqueStatus, TestResult, TestState + +# Import Evidence from app.models.evidence +from app.models.evidence import Evidence + +# Import Notification from app.models.notification +from app.models.notification import Notification + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import Test from app.models.test +from 
app.models.test import Test + +# Import TestTemplate from app.models.test_template +from app.models.test_template import TestTemplate + +# Import User from app.models.user +from app.models.user import User + +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -35,8 +64,10 @@ logger = logging.getLogger(__name__) DEMO_PREFIX = "demo_" +# Assign ROLES = ["red_tech", "blue_tech", "red_lead", "blue_lead", "admin"] ROLES = ["red_tech", "blue_tech", "red_lead", "blue_lead", "admin"] +# Assign TECHNIQUE_STATUSES = [ TECHNIQUE_STATUSES = [ TechniqueStatus.validated, TechniqueStatus.partial, @@ -45,6 +76,7 @@ TECHNIQUE_STATUSES = [ TechniqueStatus.not_evaluated, ] +# Assign TEST_STATES = [ TEST_STATES = [ TestState.draft, TestState.red_executing, @@ -54,45 +86,75 @@ TEST_STATES = [ TestState.rejected, ] +# Assign TEST_RESULTS = [ TEST_RESULTS = [ TestResult.detected, TestResult.not_detected, TestResult.partially_detected, ] +# Assign NOTIFICATION_TYPES = [ NOTIFICATION_TYPES = [ + # Literal argument value "test_assigned", + # Literal argument value "validation_needed", + # Literal argument value "test_rejected", + # Literal argument value "test_validated", + # Literal argument value "test_state_changed", ] +# Assign AUDIT_ACTIONS = [ AUDIT_ACTIONS = [ + # Literal argument value "create_test", + # Literal argument value "update_test", + # Literal argument value "validate_technique", + # Literal argument value "upload_evidence", + # Literal argument value "create_user", + # Literal argument value "import_atomic_red_team", + # Literal argument value "sync_mitre", + # Literal argument value "login", + # Literal argument value "reject_test", + # Literal argument value "approve_test", ] +# Assign PLATFORMS = ["windows", "linux", "macos"] PLATFORMS = ["windows", "linux", "macos"] +# Assign TEMPLATE_NAMES = [ TEMPLATE_NAMES = [ + # Literal argument value "Manual 
Credential Dumping Test", + # Literal argument value "Custom Phishing Payload Delivery", + # Literal argument value "Lateral Movement via RDP", + # Literal argument value "Persistence via Registry Run Keys", + # Literal argument value "Data Exfiltration over DNS", + # Literal argument value "Process Injection via DLL", + # Literal argument value "Privilege Escalation with Token Impersonation", + # Literal argument value "Custom C2 Beacon Communication Test", + # Literal argument value "Kerberoasting Attack Procedure", + # Literal argument value "Living Off The Land Binaries Test", ] @@ -102,12 +164,14 @@ TEMPLATE_NAMES = [ # --------------------------------------------------------------------------- -def _cleanup_demo_data(db) -> None: +def _cleanup_demo_data(db: Session) -> None: """Remove all previously seeded demo data.""" # Delete in order to respect FK constraints demo_users = db.query(User).filter(User.username.like(f"{DEMO_PREFIX}%")).all() + # Assign demo_user_ids = [u.id for u in demo_users] demo_user_ids = [u.id for u in demo_users] + # Check: demo_user_ids if demo_user_ids: # Notifications for demo users db.query(Notification).filter( @@ -123,13 +187,17 @@ def _cleanup_demo_data(db) -> None: demo_tests = db.query(Test).filter( Test.created_by.in_(demo_user_ids) ).all() + # Assign demo_test_ids = [t.id for t in demo_tests] demo_test_ids = [t.id for t in demo_tests] + # Check: demo_test_ids if demo_test_ids: + # Begin database query db.query(Evidence).filter( Evidence.test_id.in_(demo_test_ids) ).delete(synchronize_session=False) + # Begin database query db.query(Test).filter( Test.id.in_(demo_test_ids) ).delete(synchronize_session=False) @@ -141,11 +209,14 @@ def _cleanup_demo_data(db) -> None: # Delete demo users if demo_user_ids: + # Begin database query db.query(User).filter( User.id.in_(demo_user_ids) ).delete(synchronize_session=False) + # Commit all pending changes to the database db.commit() + # Log info: "Cleaned up existing demo data." 
logger.info("Cleaned up existing demo data.") @@ -154,215 +225,384 @@ def _cleanup_demo_data(db) -> None: # --------------------------------------------------------------------------- -def _seed_users(db) -> list[User]: +def _seed_users(db: Session) -> list[User]: """Create 5 users per role (25 total).""" + # Assign users = [] users = [] + # Iterate over ROLES for role in ROLES: + # Iterate over range(1, 6) for i in range(1, 6): + # Assign user = User( user = User( + # Keyword argument: username username=f"{DEMO_PREFIX}{role}_{i}", + # Keyword argument: email email=f"{DEMO_PREFIX}{role}_{i}@aegis-demo.local", + # Keyword argument: hashed_password hashed_password=hash_password("demo123"), + # Keyword argument: role role=role, + # Keyword argument: is_active is_active=True, ) + # Stage new record(s) for database insertion db.add(user) + # Call users.append() users.append(user) + # Flush changes to DB without committing the transaction db.flush() + # Log info: "Created %d demo users.", len(users logger.info("Created %d demo users.", len(users)) + # Return users return users -def _seed_technique_statuses(db, count: int = 50) -> list[Technique]: +# Define function _seed_technique_statuses +def _seed_technique_statuses(db: Session, count: int = 50) -> list[Technique]: """Set varied statuses on up to *count* techniques.""" + # Assign techniques = db.query(Technique).limit(count).all() techniques = db.query(Technique).limit(count).all() + # Check: not techniques if not techniques: + # Log warning: "No techniques found — run MITRE sync first!" 
logger.warning("No techniques found — run MITRE sync first!") + # Return [] return [] + # Iterate over techniques for tech in techniques: + # Assign tech.status_global = random.choice(TECHNIQUE_STATUSES) tech.status_global = random.choice(TECHNIQUE_STATUSES) + # Check: tech.status_global == TechniqueStatus.validated if tech.status_global == TechniqueStatus.validated: + # Assign tech.last_review_date = datetime.utcnow() - timedelta( tech.last_review_date = datetime.utcnow() - timedelta( + # Keyword argument: days days=random.randint(1, 30) ) + # Flush changes to DB without committing the transaction db.flush() + # Log info: "Updated status on %d techniques.", len(techniques logger.info("Updated status on %d techniques.", len(techniques)) + # Return techniques return techniques -def _seed_tests(db, users: list[User], techniques: list[Technique], count: int = 100) -> list[Test]: +# Define function _seed_tests +def _seed_tests(db: Session, users: list[User], techniques: list[Technique], count: int = 100) -> list[Test]: """Create *count* tests in various pipeline states.""" + # Check: not techniques if not techniques: + # Log warning: "No techniques available — skipping test seeding." 
logger.warning("No techniques available — skipping test seeding.") + # Return [] return [] + # Assign red_techs = [u for u in users if u.role == "red_tech"] red_techs = [u for u in users if u.role == "red_tech"] + # Assign blue_techs = [u for u in users if u.role == "blue_tech"] blue_techs = [u for u in users if u.role == "blue_tech"] + # Assign red_leads = [u for u in users if u.role == "red_lead"] red_leads = [u for u in users if u.role == "red_lead"] + # Assign blue_leads = [u for u in users if u.role == "blue_lead"] blue_leads = [u for u in users if u.role == "blue_lead"] + # Assign tests = [] tests = [] + # Iterate over range(count) for i in range(count): + # Assign technique = random.choice(techniques) technique = random.choice(techniques) + # Assign state = random.choice(TEST_STATES) state = random.choice(TEST_STATES) + # Assign creator = random.choice(red_techs + blue_techs) creator = random.choice(red_techs + blue_techs) + # Assign test = Test( test = Test( + # Keyword argument: technique_id technique_id=technique.id, + # Keyword argument: name name=f"Demo Test {i + 1} — {technique.name[:40]}", + # Keyword argument: description description=f"Automated demo test #{i + 1} for {technique.mitre_id}.", + # Keyword argument: platform platform=random.choice(PLATFORMS), - procedure_text=f"Step 1: Prepare environment.\nStep 2: Execute {technique.mitre_id} procedure.\nStep 3: Observe results.", + # Keyword argument: procedure_text + procedure_text=( + f"Step 1: Prepare environment.\n" + f"Step 2: Execute {technique.mitre_id} procedure.\n" + f"Step 3: Observe results." 
+ ), + # Keyword argument: tool_used tool_used=random.choice(["powershell", "bash", "cmd", "python", "caldera", "metasploit"]), + # Keyword argument: execution_date execution_date=datetime.utcnow() - timedelta(days=random.randint(0, 60)), + # Keyword argument: created_by created_by=creator.id, + # Keyword argument: result result=random.choice(TEST_RESULTS) if state not in (TestState.draft, TestState.red_executing) else None, + # Keyword argument: state state=state, + # Keyword argument: created_at created_at=datetime.utcnow() - timedelta(days=random.randint(0, 90)), ) # Populate team fields based on state if state in (TestState.blue_evaluating, TestState.in_review, TestState.validated, TestState.rejected): + # Assign test.red_summary = f"Attack executed successfully using {test.tool_used}." test.red_summary = f"Attack executed successfully using {test.tool_used}." + # Assign test.attack_success = random.choice([True, True, True, False]) test.attack_success = random.choice([True, True, True, False]) + # Check: state in (TestState.in_review, TestState.validated, TestState.rejec... if state in (TestState.in_review, TestState.validated, TestState.rejected): + # Assign test.blue_summary = "Detection observed in SIEM. Alert fired." test.blue_summary = "Detection observed in SIEM. Alert fired." 
+ # Assign test.detection_result = random.choice(TEST_RESULTS) test.detection_result = random.choice(TEST_RESULTS) + # Check: state == TestState.validated if state == TestState.validated: + # Assign rv = random.choice(red_leads) rv = random.choice(red_leads) + # Assign bv = random.choice(blue_leads) bv = random.choice(blue_leads) + # Assign test.red_validated_by = rv.id test.red_validated_by = rv.id + # Assign test.red_validated_at = datetime.utcnow() - timedelta(days=random.randint(0, 10)) test.red_validated_at = datetime.utcnow() - timedelta(days=random.randint(0, 10)) + # Assign test.red_validation_status = "approved" test.red_validation_status = "approved" + # Assign test.blue_validated_by = bv.id test.blue_validated_by = bv.id + # Assign test.blue_validated_at = datetime.utcnow() - timedelta(days=random.randint(0, 10)) test.blue_validated_at = datetime.utcnow() - timedelta(days=random.randint(0, 10)) + # Assign test.blue_validation_status = "approved" test.blue_validation_status = "approved" + # Check: state == TestState.rejected if state == TestState.rejected: + # Assign rejector = random.choice(red_leads + blue_leads) rejector = random.choice(red_leads + blue_leads) + # Check: rejector.role == "red_lead" if rejector.role == "red_lead": + # Assign test.red_validated_by = rejector.id test.red_validated_by = rejector.id + # Assign test.red_validated_at = datetime.utcnow() - timedelta(days=random.randint(0, 5)) test.red_validated_at = datetime.utcnow() - timedelta(days=random.randint(0, 5)) + # Assign test.red_validation_status = "rejected" test.red_validation_status = "rejected" + # Assign test.red_validation_notes = "Insufficient evidence of attack success." test.red_validation_notes = "Insufficient evidence of attack success." 
+ # Fallback: handle remaining cases else: + # Assign test.blue_validated_by = rejector.id test.blue_validated_by = rejector.id + # Assign test.blue_validated_at = datetime.utcnow() - timedelta(days=random.randint(0, 5)) test.blue_validated_at = datetime.utcnow() - timedelta(days=random.randint(0, 5)) + # Assign test.blue_validation_status = "rejected" test.blue_validation_status = "rejected" + # Assign test.blue_validation_notes = "Detection evidence not conclusive." test.blue_validation_notes = "Detection evidence not conclusive." + # Stage new record(s) for database insertion db.add(test) + # Call tests.append() tests.append(test) + # Flush changes to DB without committing the transaction db.flush() + # Log info: "Created %d demo tests.", len(tests logger.info("Created %d demo tests.", len(tests)) + # Return tests return tests -def _seed_evidences(db, tests: list[Test], users: list[User], count: int = 50) -> list[Evidence]: +# Define function _seed_evidences +def _seed_evidences(db: Session, tests: list[Test], users: list[User], count: int = 50) -> list[Evidence]: """Create *count* dummy evidence records.""" + # Check: not tests if not tests: + # Return [] return [] # Pick tests that are past draft state eligible = [t for t in tests if t.state != TestState.draft] + # Check: not eligible if not eligible: + # Assign eligible = tests eligible = tests + # Assign evidences = [] evidences = [] + # Assign red_blue = [u for u in users if u.role in ("red_tech", "blue_tech")] red_blue = [u for u in users if u.role in ("red_tech", "blue_tech")] + # Iterate over range(count) for i in range(count): + # Assign test = random.choice(eligible) test = random.choice(eligible) + # Assign uploader = random.choice(red_blue) uploader = random.choice(red_blue) + # Assign team = TeamSide.red if uploader.role == "red_tech" else TeamSide.blue team = TeamSide.red if uploader.role == "red_tech" else TeamSide.blue + # Assign ext = random.choice(["png", "log", "pcap", "csv", "txt", "json"]) 
ext = random.choice(["png", "log", "pcap", "csv", "txt", "json"]) + # Assign fname = f"evidence_{i + 1}.{ext}" fname = f"evidence_{i + 1}.{ext}" + # Assign evidence = Evidence( evidence = Evidence( + # Keyword argument: test_id test_id=test.id, + # Keyword argument: file_name file_name=fname, + # Keyword argument: file_path file_path=f"{test.id}/{uuid.uuid4()}_{fname}", + # Keyword argument: sha256_hash sha256_hash=uuid.uuid4().hex + uuid.uuid4().hex, # dummy hash + # Keyword argument: uploaded_by uploaded_by=uploader.id, + # Keyword argument: uploaded_at uploaded_at=datetime.utcnow() - timedelta(days=random.randint(0, 30)), + # Keyword argument: team team=team, + # Keyword argument: notes notes=f"Auto-generated demo evidence #{i + 1}.", ) + # Stage new record(s) for database insertion db.add(evidence) + # Call evidences.append() evidences.append(evidence) + # Flush changes to DB without committing the transaction db.flush() + # Log info: "Created %d demo evidences.", len(evidences logger.info("Created %d demo evidences.", len(evidences)) + # Return evidences return evidences -def _seed_audit_logs(db, users: list[User], count: int = 20) -> None: +# Define function _seed_audit_logs +def _seed_audit_logs(db: Session, users: list[User], count: int = 20) -> None: """Create *count* varied audit log entries.""" + # Iterate over range(count) for i in range(count): + # Assign user = random.choice(users) user = random.choice(users) + # Assign log = AuditLog( log = AuditLog( + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action=random.choice(AUDIT_ACTIONS), + # Keyword argument: entity_type entity_type=random.choice(["test", "technique", "user", "test_template"]), + # Keyword argument: entity_id entity_id=str(uuid.uuid4()), + # Keyword argument: timestamp timestamp=datetime.utcnow() - timedelta(days=random.randint(0, 60)), + # Keyword argument: details details={"demo": True, "index": i}, ) + # Stage new record(s) for database insertion 
db.add(log) + # Flush changes to DB without committing the transaction db.flush() + # Log info: "Created %d demo audit logs.", count logger.info("Created %d demo audit logs.", count) -def _seed_notifications(db, users: list[User], count: int = 30) -> None: +# Define function _seed_notifications +def _seed_notifications(db: Session, users: list[User], count: int = 30) -> None: """Create *count* notifications spread across demo users.""" + # Iterate over range(count) for i in range(count): + # Assign user = random.choice(users) user = random.choice(users) + # Assign ntype = random.choice(NOTIFICATION_TYPES) ntype = random.choice(NOTIFICATION_TYPES) + # Assign notif = Notification( notif = Notification( + # Keyword argument: user_id user_id=user.id, + # Keyword argument: type type=ntype, + # Keyword argument: title title=f"Demo notification: {ntype.replace('_', ' ').title()} #{i + 1}", + # Keyword argument: message message=f"This is an auto-generated demo notification ({ntype}).", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=uuid.uuid4(), + # Keyword argument: read read=random.choice([True, False]), + # Keyword argument: created_at created_at=datetime.utcnow() - timedelta(days=random.randint(0, 30)), ) + # Stage new record(s) for database insertion db.add(notif) + # Flush changes to DB without committing the transaction db.flush() + # Log info: "Created %d demo notifications.", count logger.info("Created %d demo notifications.", count) -def _seed_templates(db, techniques: list[Technique], count: int = 10) -> None: +# Define function _seed_templates +def _seed_templates(db: Session, techniques: list[Technique], count: int = 10) -> None: """Create *count* manual demo templates.""" + # Check: not techniques if not techniques: + # Return control to caller return + # Iterate over enumerate(TEMPLATE_NAMES[ for i, name in enumerate(TEMPLATE_NAMES[:count]): + # Assign technique = techniques[i % len(techniques)] technique = 
techniques[i % len(techniques)] + # Assign template = TestTemplate( template = TestTemplate( + # Keyword argument: mitre_technique_id mitre_technique_id=technique.mitre_id, + # Keyword argument: name name=name, + # Keyword argument: description description=f"Demo template: {name}. Targets {technique.mitre_id} ({technique.name}).", + # Keyword argument: source source="demo", + # Keyword argument: source_url source_url=None, - attack_procedure=f"1. Set up environment for {technique.mitre_id}.\n2. Execute the procedure.\n3. Record observations.", + # Keyword argument: attack_procedure + attack_procedure=( + f"1. Set up environment for {technique.mitre_id}.\n" + # Literal argument value + "2. Execute the procedure.\n" + # Literal argument value + "3. Record observations." + ), + # Keyword argument: expected_detection expected_detection=f"SIEM should alert on {technique.mitre_id} indicators.", + # Keyword argument: platform platform=random.choice(PLATFORMS), + # Keyword argument: tool_suggested tool_suggested=random.choice(["powershell", "cmd", "bash", "python"]), + # Keyword argument: severity severity=random.choice(["low", "medium", "high", "critical"]), + # Keyword argument: is_active is_active=True, ) + # Stage new record(s) for database insertion db.add(template) + # Flush changes to DB without committing the transaction db.flush() + # Log info: "Created %d demo templates.", count logger.info("Created %d demo templates.", count) @@ -373,8 +613,11 @@ def _seed_templates(db, techniques: list[Technique], count: int = 10) -> None: def seed_demo() -> dict: """Generate all demo data. 
Returns a summary dict.""" + # Assign db = SessionLocal() db = SessionLocal() + # Attempt the following; catch errors below try: + # Log info: "=== Starting V3 demo seed ===" logger.info("=== Starting V3 demo seed ===") # Step 0: cleanup previous run @@ -401,31 +644,53 @@ def seed_demo() -> dict: # Step 7: templates _seed_templates(db, techniques, count=10) + # Commit all pending changes to the database db.commit() + # Assign summary = { summary = { + # Literal argument value "users": len(users), + # Literal argument value "techniques_updated": len(techniques), + # Literal argument value "tests": len(tests), + # Literal argument value "evidences": len(evidences), + # Literal argument value "audit_logs": 20, + # Literal argument value "notifications": 30, + # Literal argument value "templates": 10, } + # Log info: "=== Demo seed complete: %s ===", summary logger.info("=== Demo seed complete: %s ===", summary) + # Return summary return summary + # Handle Exception except Exception: + # Roll back all uncommitted changes db.rollback() + # raise raise + # Always execute this cleanup block finally: + # Close the database session db.close() +# Check: __name__ == "__main__" if __name__ == "__main__": + # Call logging.basicConfig() logging.basicConfig( + # Keyword argument: level level=logging.INFO, + # Keyword argument: format format="%(asctime)s %(levelname)-8s %(name)s — %(message)s", ) + # Assign result = seed_demo() result = seed_demo() + # Call print() print(f"\nSeed complete: {result}") diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index e69de29..7509b6f 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -0,0 +1 @@ +"""Service layer — business logic orchestrating domain entities and persistence.""" diff --git a/backend/app/services/advanced_metrics_service.py b/backend/app/services/advanced_metrics_service.py index 31bb2f4..f4e4f18 100644 --- a/backend/app/services/advanced_metrics_service.py +++ 
b/backend/app/services/advanced_metrics_service.py @@ -1,19 +1,31 @@ """Advanced metrics service — coverage by tactic, never-tested, avg validation time, detection trend.""" +# Enable future language features for compatibility from __future__ import annotations +# Import datetime, timedelta from datetime from datetime import datetime, timedelta +# Import case, func from sqlalchemy from sqlalchemy import case, func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session -from app.models.technique import Technique -from app.models.test import Test +# Import TestResult from app.models.enums from app.models.enums import TestResult +# Import Technique from app.models.technique +from app.models.technique import Technique +# Import Test from app.models.test +from app.models.test import Test + + +# Define function get_coverage_by_tactic def get_coverage_by_tactic(db: Session) -> list[dict]: """Coverage percentage broken down by MITRE ATT&CK tactic.""" + # Assign results = ( results = ( db.query( Technique.tactic, @@ -31,134 +43,211 @@ def get_coverage_by_tactic(db: Session) -> list[dict]: case((Technique.status_global == "in_progress", 1), else_=0) ).label("in_progress"), ) + # Chain .group_by() call .group_by(Technique.tactic) + # Chain .order_by() call .order_by(Technique.tactic) + # Chain .all() call .all() ) + # Return [ return [ { + # Literal argument value "tactic": r[0] or "Unknown", + # Literal argument value "total": r[1], + # Literal argument value "validated": int(r[2]), + # Literal argument value "partial": int(r[3]), + # Literal argument value "not_covered": int(r[4]), + # Literal argument value "in_progress": int(r[5]), + # Literal argument value "coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0, } for r in results ] +# Define function get_never_tested_techniques def get_never_tested_techniques(db: Session) -> list[dict]: """Techniques that have never had a test created.""" + # Assign tested_ids = [ tested_ids = [ row[0] 
for row in db.query(Test.technique_id) + # Chain .filter() call .filter(Test.technique_id.isnot(None)) + # Chain .distinct() call .distinct() + # Chain .all() call .all() ] + # Assign query = db.query(Technique) query = db.query(Technique) + # Check: tested_ids if tested_ids: + # Assign query = query.filter(~Technique.id.in_(tested_ids)) query = query.filter(~Technique.id.in_(tested_ids)) + # Assign techniques = query.order_by(Technique.mitre_id).all() techniques = query.order_by(Technique.mitre_id).all() + # Return [ return [ { + # Literal argument value "mitre_id": t.mitre_id, + # Literal argument value "name": t.name, + # Literal argument value "tactic": t.tactic, + # Literal argument value "is_subtechnique": t.is_subtechnique, } for t in techniques ] +# Define function get_avg_validation_time def get_avg_validation_time(db: Session) -> dict: """Average time from test creation to validation, computed from validated tests. Returns overall average and per-phase averages where data is available. 
""" + # Assign validated_tests = ( validated_tests = ( db.query(Test) + # Chain .filter() call .filter(Test.state == "validated") + # Chain .all() call .all() ) + # Check: not validated_tests if not validated_tests: + # Return { return { + # Literal argument value "total_validated": 0, + # Literal argument value "avg_total_hours": 0, + # Literal argument value "avg_red_phase_hours": 0, + # Literal argument value "avg_blue_phase_hours": 0, } + # Assign total_durations = [] total_durations = [] + # Assign red_durations = [] red_durations = [] + # Assign blue_durations = [] blue_durations = [] + # Iterate over validated_tests for test in validated_tests: + # Check: test.created_at and test.red_validated_at if test.created_at and test.red_validated_at: + # Assign total_seconds = (test.red_validated_at - test.created_at).total_seconds() total_seconds = (test.red_validated_at - test.created_at).total_seconds() + # Call total_durations.append() total_durations.append(total_seconds) + # Check: test.red_started_at and test.blue_started_at if test.red_started_at and test.blue_started_at: + # Assign red_sec = (test.blue_started_at - test.red_started_at).total_seconds() red_sec = (test.blue_started_at - test.red_started_at).total_seconds() + # Assign red_paused = test.red_paused_seconds or 0 red_paused = test.red_paused_seconds or 0 + # Call red_durations.append() red_durations.append(max(red_sec - red_paused, 0)) + # Check: test.blue_started_at and test.blue_validated_at if test.blue_started_at and test.blue_validated_at: + # Assign blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds() blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds() + # Assign blue_paused = test.blue_paused_seconds or 0 blue_paused = test.blue_paused_seconds or 0 + # Call blue_durations.append() blue_durations.append(max(blue_sec - blue_paused, 0)) + # Define function avg_hours def avg_hours(durations: list[float]) -> float: + # Check: not durations if not 
durations: + # Return 0 return 0 + # Return round(sum(durations) / len(durations) / 3600, 2) return round(sum(durations) / len(durations) / 3600, 2) + # Return { return { + # Literal argument value "total_validated": len(validated_tests), + # Literal argument value "avg_total_hours": avg_hours(total_durations), + # Literal argument value "avg_red_phase_hours": avg_hours(red_durations), + # Literal argument value "avg_blue_phase_hours": avg_hours(blue_durations), } +# Define function get_detection_rate_trend def get_detection_rate_trend(db: Session) -> list[dict]: """Monthly detection rate trend for the last 12 months.""" + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign months = [] months = [] + # Iterate over range(11, -1, -1) for i in range(11, -1, -1): + # Assign month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30) month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30) + # Assign month_end = month_start + timedelta(days=30) month_end = month_start + timedelta(days=30) + # Assign validated = ( validated = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter( Test.state == "validated", Test.created_at >= month_start, Test.created_at < month_end, ) + # Chain .scalar() call .scalar() or 0 ) + # Assign detected = ( detected = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter( Test.state == "validated", Test.detection_result == TestResult.detected, Test.created_at >= month_start, Test.created_at < month_end, ) + # Chain .scalar() call .scalar() or 0 ) + # Call months.append() months.append({ + # Literal argument value "month": month_start.strftime("%Y-%m"), + # Literal argument value "validated": validated, + # Literal argument value "detected": detected, + # Literal argument value "detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0, }) + # Return months return months diff --git a/backend/app/services/analytics_service.py 
b/backend/app/services/analytics_service.py index 7b03721..3ed3219 100644 --- a/backend/app/services/analytics_service.py +++ b/backend/app/services/analytics_service.py @@ -1,28 +1,50 @@ """Analytics service — flat JSON optimized for PowerBI / BI tools.""" +# Enable future language features for compatibility from __future__ import annotations +# Import func from sqlalchemy from sqlalchemy import func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import CoverageSnapshot from app.models.coverage_snapshot from app.models.coverage_snapshot import CoverageSnapshot + +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import Test from app.models.test from app.models.test import Test + +# Import User from app.models.user from app.models.user import User +# Define function get_coverage_analytics def get_coverage_analytics(db: Session) -> list[dict]: """Coverage per technique — flat format for BI dashboards.""" + # Assign techniques = db.query(Technique).all() techniques = db.query(Technique).all() + # Return [ return [ { + # Literal argument value "mitre_id": t.mitre_id, + # Literal argument value "name": t.name, + # Literal argument value "tactic": t.tactic, + # Literal argument value "status": t.status_global.value if t.status_global else "not_evaluated", + # Literal argument value "is_subtechnique": t.is_subtechnique, + # Literal argument value "test_count": len(t.tests) if t.tests else 0, + # Literal argument value "review_required": t.review_required, + # Literal argument value "last_review_date": ( t.last_review_date.isoformat() if t.last_review_date else None ), @@ -31,76 +53,117 @@ def get_coverage_analytics(db: Session) -> list[dict]: ] +# Define function get_tests_analytics def get_tests_analytics( + # Entry: db db: Session, *, + # Entry: date_from date_from: str | None = None, + # Entry: date_to date_to: str | None = None, ) -> list[dict]: """All tests with timestamps — flat format for BI 
dashboards.""" + # Assign query = db.query(Test) query = db.query(Test) + # Check: date_from if date_from: + # Assign query = query.filter(Test.created_at >= date_from) query = query.filter(Test.created_at >= date_from) + # Check: date_to if date_to: + # Assign query = query.filter(Test.created_at <= date_to) query = query.filter(Test.created_at <= date_to) + # Assign tests = query.all() tests = query.all() + # Return [ return [ { + # Literal argument value "id": str(t.id), + # Literal argument value "technique_id": str(t.technique_id), + # Literal argument value "name": t.name, + # Literal argument value "state": t.state.value if t.state else None, + # Literal argument value "result": t.result.value if t.result else None, + # Literal argument value "detection_result": ( t.detection_result.value if t.detection_result else None ), + # Literal argument value "created_at": t.created_at.isoformat() if t.created_at else None, + # Literal argument value "execution_date": ( t.execution_date.isoformat() if t.execution_date else None ), + # Literal argument value "platform": t.platform, + # Literal argument value "tool_used": t.tool_used, + # Literal argument value "attack_success": t.attack_success, + # Literal argument value "remediation_status": t.remediation_status, } for t in tests ] +# Define function get_trends_analytics def get_trends_analytics(db: Session) -> list[dict]: """Historical coverage snapshots for trend visualization.""" + # Assign snapshots = ( snapshots = ( db.query(CoverageSnapshot) + # Chain .order_by() call .order_by(CoverageSnapshot.created_at) + # Chain .all() call .all() ) + # Return [ return [ { + # Literal argument value "date": s.created_at.isoformat() if s.created_at else None, + # Literal argument value "name": s.name, + # Literal argument value "total_techniques": s.total_techniques, + # Literal argument value "validated_count": s.validated_count, + # Literal argument value "partial_count": s.partial_count, + # Literal argument value 
"not_covered_count": s.not_covered_count, + # Literal argument value "organization_score": s.organization_score, } for s in snapshots ] +# Define function get_operators_analytics def get_operators_analytics(db: Session) -> list[dict]: """Per-operator metrics — for workload management dashboards.""" + # Assign results = ( results = ( db.query( User.username, User.role, func.count(Test.id).label("test_count"), ) + # Chain .outerjoin() call .outerjoin(Test, Test.created_by == User.id) + # Chain .group_by() call .group_by(User.id, User.username, User.role) + # Chain .all() call .all() ) + # Return [ return [ {"username": r[0], "role": r[1], "test_count": r[2]} for r in results diff --git a/backend/app/services/atomic_import_service.py b/backend/app/services/atomic_import_service.py index 870980d..85796fd 100644 --- a/backend/app/services/atomic_import_service.py +++ b/backend/app/services/atomic_import_service.py @@ -22,22 +22,39 @@ Running the import twice does **not** create duplicates. Existing templates are identified by their ``atomic_test_id`` and simply skipped. 
""" +# Import io import io + +# Import logging import logging -import os + +# Import shutil import shutil + +# Import tempfile import tempfile + +# Import zipfile import zipfile + +# Import Path from pathlib from pathlib import Path +# Import requests import requests as _requests + +# Import yaml import yaml + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import TestTemplate from app.models.test_template from app.models.test_template import TestTemplate from app.models.technique import Technique from app.services.audit_service import log_action +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -45,7 +62,9 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- ATOMIC_RT_ZIP_URL = ( + # Literal argument value "https://github.com/redcanaryco/atomic-red-team" + # Literal argument value "/archive/refs/heads/master.zip" ) @@ -55,6 +74,11 @@ _DOWNLOAD_TIMEOUT = 300 # Top-level directory name inside the ZIP _ZIP_ROOT_PREFIX = "atomic-red-team-master" +# Safety limits for ZIP extraction — prevent zip-bomb DoS +_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB +# Assign _MAX_ENTRIES = 50_000 +_MAX_ENTRIES = 50_000 + # --------------------------------------------------------------------------- # Internal helpers @@ -63,14 +87,21 @@ _ZIP_ROOT_PREFIX = "atomic-red-team-master" def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes: """Download the Atomic Red Team ZIP and return its raw bytes.""" + # Log info: "Downloading Atomic Red Team ZIP from %s …", url logger.info("Downloading Atomic Red Team ZIP from %s …", url) + # Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) + # Call resp.raise_for_status() resp.raise_for_status() + # Assign content = resp.content content = resp.content 
+ # Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024 logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024)) + # Return content return content +# Define function _safe_extract_zip def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None: """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection. @@ -78,51 +109,66 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None: directory (path traversal / Zip Slip) or if the archive exceeds the safety limits. """ - # Maximum uncompressed size: 500 MB — prevents zip-bomb DoS - _MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 - # Maximum number of entries - _MAX_ENTRIES = 50_000 - + # Assign dest_path = Path(dest).resolve() dest_path = Path(dest).resolve() + # Open context manager with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: + # Assign entries = zf.infolist() entries = zf.infolist() + # Check: len(entries) > _MAX_ENTRIES if len(entries) > _MAX_ENTRIES: + # Raise ValueError raise ValueError( f"ZIP archive contains {len(entries)} entries " f"(limit: {_MAX_ENTRIES}) — possible zip bomb" ) + # Assign total_size = sum(info.file_size for info in entries) total_size = sum(info.file_size for info in entries) + # Check: total_size > _MAX_UNCOMPRESSED_SIZE if total_size > _MAX_UNCOMPRESSED_SIZE: + # Raise ValueError raise ValueError( f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB " f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB" ) + # Iterate over entries for member in entries: + # Assign target = (dest_path / member.filename).resolve() target = (dest_path / member.filename).resolve() + # Check: not target.is_relative_to(dest_path) if not target.is_relative_to(dest_path): + # Raise ValueError raise ValueError( f"Zip Slip detected — member '{member.filename}' " f"resolves outside target directory" ) + # Call zf.extractall() zf.extractall(dest) +# Define function _extract_zip def _extract_zip(zip_bytes: bytes, dest: str) -> Path: """Extract *zip_bytes* into *dest* 
and return the path to the atomics/ dir.""" + # Call _safe_extract_zip() _safe_extract_zip(zip_bytes, dest) + # Assign atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics" atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics" + # Check: not atomics_dir.is_dir() if not atomics_dir.is_dir(): + # Raise FileNotFoundError raise FileNotFoundError( f"Expected atomics directory not found at {atomics_dir}" ) + # Return atomics_dir return atomics_dir +# Define function _parse_yaml_files def _parse_yaml_files(atomics_dir: Path) -> list[dict]: """Walk the atomics directory and parse all technique YAML files. @@ -132,51 +178,84 @@ def _parse_yaml_files(atomics_dir: Path) -> list[dict]: technique_id, index, name, description, platforms, executor_type, command, source_url """ + # Assign results = [] results: list[dict] = [] + # Assign yaml_files = sorted(atomics_dir.glob("T*/T*.yaml")) yaml_files = sorted(atomics_dir.glob("T*/T*.yaml")) + # Log info: "Found %d YAML files to parse", len(yaml_files logger.info("Found %d YAML files to parse", len(yaml_files)) + # Iterate over yaml_files for yaml_path in yaml_files: + # Assign technique_id = yaml_path.stem # e.g. "T1059.001" technique_id = yaml_path.stem # e.g. 
"T1059.001" + # Attempt the following; catch errors below try: + # Open context manager with open(yaml_path, "r", encoding="utf-8") as fh: + # Assign data = yaml.safe_load(fh) data = yaml.safe_load(fh) + # Handle Exception except Exception as exc: + # Log warning: "Failed to parse %s: %s", yaml_path, exc logger.warning("Failed to parse %s: %s", yaml_path, exc) + # Skip to the next loop iteration continue + # Check: not data or "atomic_tests" not in data if not data or "atomic_tests" not in data: + # Skip to the next loop iteration continue + # Iterate over enumerate(data["atomic_tests"]) for idx, test in enumerate(data["atomic_tests"]): + # Assign name = test.get("name", "").strip() name = test.get("name", "").strip() + # Assign description = test.get("description", "").strip() description = test.get("description", "").strip() + # Assign platforms = test.get("supported_platforms", []) platforms = test.get("supported_platforms", []) + # Assign executor = test.get("executor", {}) executor = test.get("executor", {}) + # Assign executor_type = executor.get("name", "") if isinstance(executor, dict) else "" executor_type = executor.get("name", "") if isinstance(executor, dict) else "" + # Assign command = executor.get("command", "") if isinstance(executor, dict) else "" command = executor.get("command", "") if isinstance(executor, dict) else "" # Build an atomic_test_id in the format "T1059.001-0" atomic_test_id = f"{technique_id}-{idx}" + # Assign source_url = ( source_url = ( f"https://github.com/redcanaryco/atomic-red-team/blob/master" f"/atomics/{technique_id}/{technique_id}.yaml" ) + # Call results.append() results.append({ + # Literal argument value "technique_id": technique_id, + # Literal argument value "index": idx, + # Literal argument value "atomic_test_id": atomic_test_id, + # Literal argument value "name": name, + # Literal argument value "description": description, + # Literal argument value "platforms": ", ".join(platforms) if isinstance(platforms, list) 
else str(platforms), + # Literal argument value "executor_type": executor_type, + # Literal argument value "command": command[:4000] if command else None, # cap at 4k chars + # Literal argument value "source_url": source_url, }) + # Log info: "Parsed %d atomic tests total", len(results logger.info("Parsed %d atomic tests total", len(results)) + # Return results return results @@ -193,52 +272,80 @@ def import_atomic_red_team(db: Session) -> dict: db : Session Active SQLAlchemy database session. - Returns + Returns: ------- dict Summary with keys ``created``, ``skipped_existing``, ``yaml_files_parsed``, ``total_tests_parsed``. """ + # Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_") tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_") + # Attempt the following; catch errors below try: + # Assign zip_bytes = _download_zip() zip_bytes = _download_zip() + # Assign atomics_dir = _extract_zip(zip_bytes, tmp_dir) atomics_dir = _extract_zip(zip_bytes, tmp_dir) + # Assign parsed_tests = _parse_yaml_files(atomics_dir) parsed_tests = _parse_yaml_files(atomics_dir) + # Always execute this cleanup block finally: # Always clean up shutil.rmtree(tmp_dir, ignore_errors=True) + # Log info: "Cleaned up temp directory %s", tmp_dir logger.info("Cleaned up temp directory %s", tmp_dir) # Pre-load existing atomic_test_ids for dedup existing_ids: set[str] = { row[0] for row in db.query(TestTemplate.atomic_test_id) + # Chain .filter() call .filter(TestTemplate.atomic_test_id.isnot(None)) + # Chain .all() call .all() } + # Assign created = 0 created = 0 + # Assign skipped = 0 skipped = 0 new_technique_ids: set[str] = set() + # Iterate over parsed_tests for item in parsed_tests: + # Check: item["atomic_test_id"] in existing_ids if item["atomic_test_id"] in existing_ids: + # Assign skipped = 1 skipped += 1 + # Skip to the next loop iteration continue + # Assign template = TestTemplate( template = TestTemplate( + # Keyword argument: mitre_technique_id 
mitre_technique_id=item["technique_id"], + # Keyword argument: name name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}", + # Keyword argument: description description=item["description"][:2000] if item["description"] else None, + # Keyword argument: source source="atomic_red_team", + # Keyword argument: source_url source_url=item["source_url"], + # Keyword argument: attack_procedure attack_procedure=item["command"], + # Keyword argument: platform platform=item["platforms"], + # Keyword argument: tool_suggested tool_suggested=item["executor_type"] if item["executor_type"] else None, + # Keyword argument: atomic_test_id atomic_test_id=item["atomic_test_id"], + # Keyword argument: is_active is_active=True, ) + # Stage new record(s) for database insertion db.add(template) + # Call existing_ids.add() existing_ids.add(item["atomic_test_id"]) new_technique_ids.add(item["technique_id"]) created += 1 @@ -253,15 +360,23 @@ def import_atomic_red_team(db: Session) -> dict: # Count distinct YAML files by technique_id yaml_files_count = len({t["technique_id"] for t in parsed_tests}) + # Assign summary = { summary = { + # Literal argument value "created": created, + # Literal argument value "skipped_existing": skipped, + # Literal argument value "yaml_files_parsed": yaml_files_count, + # Literal argument value "total_tests_parsed": len(parsed_tests), } + # Log info: logger.info( + # Literal argument value "Atomic Red Team import complete — created=%d, skipped=%d, " + # Literal argument value "yaml_files=%d, total_tests=%d", created, skipped, yaml_files_count, len(parsed_tests), ) @@ -269,12 +384,19 @@ def import_atomic_red_team(db: Session) -> dict: # Audit log (system action) log_action( db, + # Keyword argument: user_id user_id=None, + # Keyword argument: action action="import_atomic_red_team", + # Keyword argument: entity_type entity_type="test_template", + # Keyword argument: entity_id entity_id=None, + # Keyword argument: details 
details=summary, ) + # Commit all pending changes to the database db.commit() + # Return summary return summary diff --git a/backend/app/services/audit_query_service.py b/backend/app/services/audit_query_service.py index 6b3faaa..dfd6d83 100644 --- a/backend/app/services/audit_query_service.py +++ b/backend/app/services/audit_query_service.py @@ -4,24 +4,37 @@ Provides paginated logs and distinct action/entity-type lists. No FastAPI imports. """ +# Enable future language features for compatibility from __future__ import annotations +# Import datetime from datetime from datetime import datetime +# Import Session, joinedload from sqlalchemy.orm from sqlalchemy.orm import Session, joinedload +# Import AuditLog from app.models.audit from app.models.audit import AuditLog +# Define function list_logs def list_logs( + # Entry: db db: Session, *, + # Entry: user_id user_id: str | None = None, + # Entry: action action: str | None = None, + # Entry: entity_type entity_type: str | None = None, + # Entry: start_date start_date: datetime | None = None, + # Entry: end_date end_date: datetime | None = None, + # Entry: offset offset: int = 0, + # Entry: limit limit: int = 50, ) -> dict: """Return paginated audit logs with optional filters. @@ -30,64 +43,104 @@ def list_logs( Each item is a dict with: id, user_id, username, action, entity_type, entity_id, timestamp, details. 
""" + # Assign query = db.query(AuditLog).options(joinedload(AuditLog.user)) query = db.query(AuditLog).options(joinedload(AuditLog.user)) + # Check: user_id if user_id: + # Assign query = query.filter(AuditLog.user_id == user_id) query = query.filter(AuditLog.user_id == user_id) + # Check: action if action: + # Assign query = query.filter(AuditLog.action == action) query = query.filter(AuditLog.action == action) + # Check: entity_type if entity_type: + # Assign query = query.filter(AuditLog.entity_type == entity_type) query = query.filter(AuditLog.entity_type == entity_type) + # Check: start_date if start_date: + # Assign query = query.filter(AuditLog.timestamp >= start_date) query = query.filter(AuditLog.timestamp >= start_date) + # Check: end_date if end_date: + # Assign query = query.filter(AuditLog.timestamp <= end_date) query = query.filter(AuditLog.timestamp <= end_date) + # Assign total = query.count() total = query.count() + # Assign logs = ( logs = ( query + # Chain .order_by() call .order_by(AuditLog.timestamp.desc()) + # Chain .offset() call .offset(offset) + # Chain .limit() call .limit(limit) + # Chain .all() call .all() ) + # Assign items = [ items = [ { + # Literal argument value "id": log.id, + # Literal argument value "user_id": log.user_id, + # Literal argument value "username": log.user.username if log.user else None, + # Literal argument value "action": log.action, + # Literal argument value "entity_type": log.entity_type, + # Literal argument value "entity_id": log.entity_id, + # Literal argument value "timestamp": log.timestamp, + # Literal argument value "details": log.details, } for log in logs ] + # Return {"items": items, "total": total, "offset": offset, "limit": limit} return {"items": items, "total": total, "offset": offset, "limit": limit} +# Define function list_distinct_actions def list_distinct_actions(db: Session) -> list[str]: """Return a list of distinct action types in the audit log.""" + # Assign actions = ( actions = ( 
db.query(AuditLog.action) + # Chain .distinct() call .distinct() + # Chain .order_by() call .order_by(AuditLog.action) + # Chain .all() call .all() ) + # Return [a[0] for a in actions] return [a[0] for a in actions] +# Define function list_distinct_entity_types def list_distinct_entity_types(db: Session) -> list[str]: """Return a list of distinct entity types in the audit log.""" + # Assign types = ( types = ( db.query(AuditLog.entity_type) + # Chain .filter() call .filter(AuditLog.entity_type.isnot(None)) + # Chain .distinct() call .distinct() + # Chain .order_by() call .order_by(AuditLog.entity_type) + # Chain .all() call .all() ) + # Return [t[0] for t in types] return [t[0] for t in types] diff --git a/backend/app/services/audit_service.py b/backend/app/services/audit_service.py index 22cc8d0..0ae93c3 100644 --- a/backend/app/services/audit_service.py +++ b/backend/app/services/audit_service.py @@ -1,66 +1,116 @@ """Audit logging with request context and integrity hashing.""" +# Enable future language features for compatibility from __future__ import annotations +# Import hashlib import hashlib + +# Import datetime, timezone from datetime from datetime import datetime, timezone +# Import UUID from uuid +from uuid import UUID + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import request_ip, request_user_agent from app.middleware.request_context from app.middleware.request_context import request_ip, request_user_agent + +# Import AuditLog from app.models.audit from app.models.audit import AuditLog +# Define function _integrity_payload def _integrity_payload(entry: AuditLog) -> str: + # Assign ts = entry.timestamp ts = entry.timestamp + # Check: ts is None if ts is None: + # Assign ts = datetime.now(timezone.utc) ts = datetime.now(timezone.utc) + # Assign user_part = str(entry.user_id) if entry.user_id else "" user_part = str(entry.user_id) if entry.user_id else "" + # Assign entity_type = entry.entity_type or "" entity_type = 
entry.entity_type or "" + # Assign entity_id = entry.entity_id or "" entity_id = entry.entity_id or "" + # Return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoforma... return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoformat()}" +# Define function compute_integrity_hash def compute_integrity_hash(entry: AuditLog) -> str: """Return the SHA-256 hex digest for an audit log entry.""" + # Return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest() return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest() +# Define function verify_audit_integrity def verify_audit_integrity(entry: AuditLog) -> bool: """Return whether the stored hash matches the entry's current fields.""" + # Check: not entry.integrity_hash if not entry.integrity_hash: + # Return False return False + # Return entry.integrity_hash == compute_integrity_hash(entry) return entry.integrity_hash == compute_integrity_hash(entry) +# Define function log_action def log_action( + # Entry: db db: Session, - user_id, + # Entry: user_id + user_id: UUID | None, + # Entry: action action: str, + # Entry: entity_type entity_type: str | None = None, + # Entry: entity_id entity_id: str | None = None, + # Entry: details details: dict | None = None, *, + # Entry: ip_address ip_address: str | None = None, + # Entry: user_agent user_agent: str | None = None, + # Entry: session_id session_id: str | None = None, ) -> AuditLog: """Record an audit event. 
Does not commit — the caller owns the transaction.""" + # Assign ip = ip_address if ip_address is not None else request_ip.get("") ip = ip_address if ip_address is not None else request_ip.get("") + # Assign ua = user_agent if user_agent is not None else request_user_agent.get("") ua = user_agent if user_agent is not None else request_user_agent.get("") + # Assign entry = AuditLog( entry = AuditLog( + # Keyword argument: user_id user_id=user_id, + # Keyword argument: action action=action, + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: entity_id entity_id=str(entity_id) if entity_id else None, + # Keyword argument: details details=details, + # Keyword argument: ip_address ip_address=ip or None, + # Keyword argument: user_agent user_agent=ua or None, + # Keyword argument: session_id session_id=session_id, timestamp=datetime.now(timezone.utc), ) + # Stage new record(s) for database insertion db.add(entry) + # Flush changes to DB without committing the transaction db.flush() + # Assign entry.integrity_hash = compute_integrity_hash(entry) entry.integrity_hash = compute_integrity_hash(entry) + # Return entry return entry diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index fcb9798..a34f6d5 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -1,15 +1,24 @@ """Authentication service — credential validation and password management.""" +# Enable future language features for compatibility from __future__ import annotations +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import hash_password, verify_password from app.auth from app.auth import hash_password, verify_password + +# Import BusinessRuleViolation, PermissionViolation from app.domain.errors from app.domain.errors import BusinessRuleViolation, PermissionViolation + +# Import User from app.models.user from app.models.user import User +# Assign _DUMMY_HASH = 
"$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy" _DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy" +# Define function authenticate_user def authenticate_user(db: Session, *, username: str, password: str) -> User: """Validate credentials and return the User. @@ -17,33 +26,49 @@ def authenticate_user(db: Session, *, username: str, password: str) -> User: Raises PermissionViolation for disabled account. Uses constant-time comparison to prevent timing attacks. """ + # Assign user = db.query(User).filter(User.username == username).first() user = db.query(User).filter(User.username == username).first() + # Assign hashed = user.hashed_password if user else _DUMMY_HASH hashed = user.hashed_password if user else _DUMMY_HASH + # Assign password_valid = verify_password(password, hashed) password_valid = verify_password(password, hashed) + # Check: user is None or not password_valid if user is None or not password_valid: + # Raise BusinessRuleViolation raise BusinessRuleViolation("Incorrect username or password") + # Check: not user.is_active if not user.is_active: + # Raise PermissionViolation raise PermissionViolation("Account is disabled. Contact an administrator.") + # Return user return user +# Define function change_password def change_password( + # Entry: db db: Session, + # Entry: user user: User, *, + # Entry: current_password current_password: str, + # Entry: new_password new_password: str, ) -> None: """Change a user's password. Does NOT commit. Raises BusinessRuleViolation if current password is wrong. 
""" + # Check: not verify_password(current_password, user.hashed_password) if not verify_password(current_password, user.hashed_password): + # Raise BusinessRuleViolation raise BusinessRuleViolation("Current password is incorrect") if verify_password(new_password, user.hashed_password): raise BusinessRuleViolation( "New password must be different from the current password" ) user.hashed_password = hash_password(new_password) + # Assign user.must_change_password = False user.must_change_password = False diff --git a/backend/app/services/caldera_import_service.py b/backend/app/services/caldera_import_service.py index cdea97f..43dccda 100644 --- a/backend/app/services/caldera_import_service.py +++ b/backend/app/services/caldera_import_service.py @@ -21,23 +21,42 @@ templates are identified by ``source = "caldera"`` + ``atomic_test_id`` (the CALDERA ability ``id``). """ +# Import io import io + +# Import logging import logging + +# Import shutil import shutil + +# Import tempfile import tempfile + +# Import zipfile import zipfile + +# Import datetime from datetime from datetime import datetime + +# Import Path from pathlib from pathlib import Path +# Import requests import requests as _requests + +# Import yaml import yaml + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session -from app.models.test_template import TestTemplate +# Import DataSource from app.models.data_source from app.models.data_source import DataSource from app.models.technique import Technique from app.services.audit_service import log_action +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -45,11 +64,15 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- CALDERA_ZIP_URL = ( + # Literal argument value "https://github.com/mitre/stockpile" + # Literal argument value "/archive/refs/heads/master.zip" ) +# 
Assign _DOWNLOAD_TIMEOUT = 300 _DOWNLOAD_TIMEOUT = 300 +# Assign _ZIP_ROOT_PREFIX = "stockpile-master" _ZIP_ROOT_PREFIX = "stockpile-master" @@ -60,26 +83,40 @@ _ZIP_ROOT_PREFIX = "stockpile-master" def _download_zip(url: str = CALDERA_ZIP_URL) -> bytes: """Download the CALDERA ZIP and return raw bytes.""" + # Log info: "Downloading CALDERA ZIP from %s …", url logger.info("Downloading CALDERA ZIP from %s …", url) + # Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) + # Call resp.raise_for_status() resp.raise_for_status() + # Assign content = resp.content content = resp.content + # Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024 logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024)) + # Return content return content +# Define function _extract_zip def _extract_zip(zip_bytes: bytes, dest: str) -> Path: """Extract *zip_bytes* into *dest* and return abilities dir.""" + # Open context manager with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: + # Call zf.extractall() zf.extractall(dest) + # Assign abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities" abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities" + # Check: not abilities_dir.is_dir() if not abilities_dir.is_dir(): + # Raise FileNotFoundError raise FileNotFoundError( f"Expected abilities directory not found at {abilities_dir}" ) + # Return abilities_dir return abilities_dir +# Define function _extract_commands def _extract_commands(platforms_dict: dict) -> str: """Extract executor commands from CALDERA platforms dict. @@ -95,116 +132,192 @@ def _extract_commands(platforms_dict: dict) -> str: Returns a formatted string with all commands. 
""" + # Assign lines = [] lines = [] + # Check: not isinstance(platforms_dict, dict) if not isinstance(platforms_dict, dict): + # Return "" return "" + # Iterate over platforms_dict.items() for os_name, executors in platforms_dict.items(): + # Check: not isinstance(executors, dict) if not isinstance(executors, dict): + # Skip to the next loop iteration continue + # Iterate over executors.items() for executor_name, executor_data in executors.items(): + # Check: isinstance(executor_data, dict) if isinstance(executor_data, dict): + # Assign cmd = executor_data.get("command", "") cmd = executor_data.get("command", "") + # Check: cmd if cmd: + # Call lines.append() lines.append(f"[{os_name}/{executor_name}]\n{cmd}") + # Alternative: isinstance(executor_data, str) elif isinstance(executor_data, str): + # Call lines.append() lines.append(f"[{os_name}/{executor_name}]\n{executor_data}") + # Return "\n\n".join(lines) return "\n\n".join(lines) +# Define function _extract_platforms def _extract_platforms(platforms_dict: dict) -> str: """Extract platform names from CALDERA platforms dict.""" + # Check: not isinstance(platforms_dict, dict) if not isinstance(platforms_dict, dict): + # Return "" return "" + # Assign platform_names = [] platform_names = [] + # Iterate over platforms_dict for os_name in platforms_dict: + # Assign normalized = str(os_name).lower().strip() normalized = str(os_name).lower().strip() + # Check: normalized in ("windows", "linux", "darwin", "macos") if normalized in ("windows", "linux", "darwin", "macos"): + # Check: normalized == "darwin" if normalized == "darwin": + # Assign normalized = "macos" normalized = "macos" + # Check: normalized not in platform_names if normalized not in platform_names: + # Call platform_names.append() platform_names.append(normalized) + # Return ", ".join(platform_names) return ", ".join(platform_names) +# Define function _parse_abilities def _parse_abilities(abilities_dir: Path) -> list[dict]: """Walk abilities directories 
and parse all YAML files. Returns a flat list of dicts, each representing one ability. """ + # Assign results = [] results: list[dict] = [] + # Assign yaml_files = sorted(abilities_dir.rglob("*.yml")) yaml_files = sorted(abilities_dir.rglob("*.yml")) + # Log info: "Found %d ability YAML files", len(yaml_files logger.info("Found %d ability YAML files", len(yaml_files)) + # Iterate over yaml_files for yaml_path in yaml_files: + # Attempt the following; catch errors below try: + # Open context manager with open(yaml_path, "r", encoding="utf-8") as fh: + # Assign data_list = list(yaml.safe_load_all(fh)) data_list = list(yaml.safe_load_all(fh)) + # Handle Exception except Exception as exc: + # Log debug: "Failed to parse %s: %s", yaml_path, exc logger.debug("Failed to parse %s: %s", yaml_path, exc) + # Skip to the next loop iteration continue # Stockpile YAML files may contain YAML lists of abilities # (e.g. [- id: ..., - id: ...]) or single-document dicts. # Flatten everything into individual ability dicts. 
abilities: list[dict] = [] + # Iterate over data_list for data in data_list: + # Check: isinstance(data, dict) if isinstance(data, dict): + # Call abilities.append() abilities.append(data) + # Alternative: isinstance(data, list) elif isinstance(data, list): + # Call abilities.extend() abilities.extend(d for d in data if isinstance(d, dict)) + # Iterate over abilities for data in abilities: + # Assign ability_id = data.get("id", "") ability_id = data.get("id", "") + # Check: not ability_id if not ability_id: + # Skip to the next loop iteration continue + # Assign name = data.get("name", "").strip() name = data.get("name", "").strip() + # Assign description = data.get("description", "").strip() description = data.get("description", "").strip() + # Assign tactic = data.get("tactic", "").strip() tactic = data.get("tactic", "").strip() # Extract technique info technique = data.get("technique", {}) + # Check: isinstance(technique, dict) if isinstance(technique, dict): + # Assign attack_id = technique.get("attack_id", "") attack_id = technique.get("attack_id", "") + # Fallback: handle remaining cases else: + # Assign attack_id = "" attack_id = "" + # Check: not attack_id if not attack_id: + # Skip to the next loop iteration continue # Normalise technique ID attack_id = str(attack_id).strip().upper() + # Check: not attack_id.startswith("T") if not attack_id.startswith("T"): + # Skip to the next loop iteration continue # Extract platforms and commands platforms_dict = data.get("platforms", {}) + # Assign commands = _extract_commands(platforms_dict) commands = _extract_commands(platforms_dict) + # Assign platform_str = _extract_platforms(platforms_dict) platform_str = _extract_platforms(platforms_dict) # Determine executor type executors = set() + # Check: isinstance(platforms_dict, dict) if isinstance(platforms_dict, dict): + # Iterate over platforms_dict.values() for os_executors in platforms_dict.values(): + # Check: isinstance(os_executors, dict) if 
isinstance(os_executors, dict): + # Call executors.update() executors.update(os_executors.keys()) + # Assign executor_str = ", ".join(sorted(executors)) if executors else None executor_str = ", ".join(sorted(executors)) if executors else None + # Call results.append() results.append({ + # Literal argument value "mitre_technique_id": attack_id, + # Literal argument value "name": f"CALDERA: {name}"[:500] if name else f"CALDERA ability {ability_id}"[:500], + # Literal argument value "description": f"{description}\n\nTactic: {tactic}".strip()[:2000] if description else None, + # Literal argument value "source": "caldera", + # Literal argument value "platform": platform_str, + # Literal argument value "tool_suggested": executor_str, + # Literal argument value "attack_procedure": commands[:4000] if commands else None, + # Literal argument value "atomic_test_id": f"caldera:{ability_id}", + # Literal argument value "source_url": f"https://github.com/mitre/stockpile/tree/master/data/abilities/{tactic}", }) + # Log info: "Parsed %d CALDERA abilities total", len(results logger.info("Parsed %d CALDERA abilities total", len(results)) + # Return results return results @@ -218,46 +331,76 @@ def sync(db: Session) -> dict: Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``. 
""" + # Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_caldera_") tmp_dir = tempfile.mkdtemp(prefix="aegis_caldera_") + # Attempt the following; catch errors below try: + # Assign zip_bytes = _download_zip() zip_bytes = _download_zip() + # Assign abilities_dir = _extract_zip(zip_bytes, tmp_dir) abilities_dir = _extract_zip(zip_bytes, tmp_dir) + # Assign parsed = _parse_abilities(abilities_dir) parsed = _parse_abilities(abilities_dir) + # Always execute this cleanup block finally: + # Call shutil.rmtree() shutil.rmtree(tmp_dir, ignore_errors=True) + # Log info: "Cleaned up temp directory %s", tmp_dir logger.info("Cleaned up temp directory %s", tmp_dir) # Pre-load existing for dedup existing_ids: set[str] = { row[0] for row in db.query(TestTemplate.atomic_test_id) + # Chain .filter() call .filter(TestTemplate.source == "caldera") + # Chain .filter() call .filter(TestTemplate.atomic_test_id.isnot(None)) + # Chain .all() call .all() } + # Assign created = 0 created = 0 + # Assign skipped = 0 skipped = 0 new_technique_ids: set[str] = set() + # Iterate over parsed for item in parsed: + # Check: item["atomic_test_id"] in existing_ids if item["atomic_test_id"] in existing_ids: + # Assign skipped = 1 skipped += 1 + # Skip to the next loop iteration continue + # Assign template = TestTemplate( template = TestTemplate( + # Keyword argument: mitre_technique_id mitre_technique_id=item["mitre_technique_id"], + # Keyword argument: name name=item["name"], + # Keyword argument: description description=item["description"], + # Keyword argument: source source=item["source"], + # Keyword argument: source_url source_url=item["source_url"], + # Keyword argument: attack_procedure attack_procedure=item["attack_procedure"], + # Keyword argument: platform platform=item["platform"], + # Keyword argument: tool_suggested tool_suggested=item["tool_suggested"], + # Keyword argument: atomic_test_id atomic_test_id=item["atomic_test_id"], + # Keyword argument: is_active is_active=True, ) + # Stage 
new record(s) for database insertion db.add(template) + # Call existing_ids.add() existing_ids.add(item["atomic_test_id"]) new_technique_ids.add(item["mitre_technique_id"]) created += 1 @@ -269,22 +412,36 @@ def sync(db: Session) -> dict: db.commit() + # Assign summary = { summary = { + # Literal argument value "created": created, + # Literal argument value "skipped_existing": skipped, + # Literal argument value "total_parsed": len(parsed), } # Update DataSource record ds = db.query(DataSource).filter(DataSource.name == "caldera").first() + # Check: ds if ds: + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_status = "success" ds.last_sync_status = "success" + # Assign ds.last_sync_stats = summary ds.last_sync_stats = summary + # Commit all pending changes to the database db.commit() + # Log info: "CALDERA import complete — %s", summary logger.info("CALDERA import complete — %s", summary) + # Call log_action() log_action(db, user_id=None, action="import_caldera", + # Keyword argument: entity_type entity_type="test_template", entity_id=None, details=summary) + # Commit all pending changes to the database db.commit() + # Return summary return summary diff --git a/backend/app/services/campaign_crud_service.py b/backend/app/services/campaign_crud_service.py index f350aba..8ec30d5 100644 --- a/backend/app/services/campaign_crud_service.py +++ b/backend/app/services/campaign_crud_service.py @@ -4,112 +4,191 @@ Framework-agnostic; uses domain exceptions from app.domain.errors. The router is responsible for HTTP concerns, auth, audit logging, and commit. 
""" +# Import uuid import uuid + +# Import datetime from datetime from datetime import datetime + +# Import Optional from typing from typing import Optional +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import from app.domain.errors from app.domain.errors import ( BusinessRuleViolation, EntityNotFoundError, PermissionViolation, ) + +# Import Campaign, CampaignTest from app.models.campaign from app.models.campaign import Campaign, CampaignTest -from app.models.test import Test + +# Import Technique from app.models.technique from app.models.technique import Technique -from app.utils import escape_like -from app.services.campaign_service import ( - get_campaign_progress, - validate_no_circular_dependency, - TACTIC_TO_PHASE, -) + +# Import Test from app.models.test +from app.models.test import Test + +# Import calculate_next_run from app.services.campaign_scheduler_service from app.services.campaign_scheduler_service import calculate_next_run from app.services.status_service import recalculate_technique_status +# Import from app.services.campaign_service +from app.services.campaign_service import ( + TACTIC_TO_PHASE, + get_campaign_progress, + validate_no_circular_dependency, +) + +# Import escape_like from app.utils +from app.utils import escape_like # ── Serialization helpers ──────────────────────────────────────────────── def serialize_campaign(db: Session, campaign: Campaign) -> dict: """Serialize a campaign with its tests and progress.""" + # Assign progress = get_campaign_progress(db, campaign.id) progress = get_campaign_progress(db, campaign.id) + # Assign campaign_tests = ( campaign_tests = ( db.query(CampaignTest) + # Chain .filter() call .filter(CampaignTest.campaign_id == campaign.id) + # Chain .order_by() call .order_by(CampaignTest.order_index) + # Chain .all() call .all() ) + # Assign tests = [] tests = [] + # Iterate over campaign_tests for ct in campaign_tests: + # Assign test = ct.test test = ct.test + # Assign technique 
= db.query(Technique).filter(Technique.id == test.technique_id).first... technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None + # Call tests.append() tests.append({ + # Literal argument value "id": str(ct.id), + # Literal argument value "test_id": str(ct.test_id), + # Literal argument value "order_index": ct.order_index, + # Literal argument value "depends_on": str(ct.depends_on) if ct.depends_on else None, + # Literal argument value "phase": ct.phase, + # Literal argument value "test_name": test.name if test else None, + # Literal argument value "test_state": test.state.value if test and test.state else None, + # Literal argument value "test_result": test.result.value if test and test.result else None, + # Literal argument value "technique_mitre_id": technique.mitre_id if technique else None, + # Literal argument value "technique_name": technique.name if technique else None, + # Literal argument value "platform": test.platform if test else None, }) + # Assign actor = campaign.threat_actor actor = campaign.threat_actor + # Return { return { + # Literal argument value "id": str(campaign.id), + # Literal argument value "name": campaign.name, + # Literal argument value "description": campaign.description, + # Literal argument value "type": campaign.type, + # Literal argument value "status": campaign.status, + # Literal argument value "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None, + # Literal argument value "threat_actor_name": actor.name if actor else None, + # Literal argument value "created_by": str(campaign.created_by) if campaign.created_by else None, "start_date": campaign.start_date.isoformat() if campaign.start_date else None, "scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None, + # Literal argument value "completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None, + # Literal argument value "target_platform": 
campaign.target_platform, + # Literal argument value "tags": campaign.tags or [], + # Literal argument value "created_at": campaign.created_at.isoformat() if campaign.created_at else None, + # Literal argument value "is_recurring": campaign.is_recurring or False, + # Literal argument value "recurrence_pattern": campaign.recurrence_pattern, + # Literal argument value "next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None, + # Literal argument value "last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None, + # Literal argument value "parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None, + # Literal argument value "tests": tests, + # Literal argument value "progress": progress, } +# Define function serialize_campaign_summary def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict: """Lightweight campaign serialization for list views.""" + # Assign progress = get_campaign_progress(db, campaign.id) progress = get_campaign_progress(db, campaign.id) + # Assign actor = campaign.threat_actor actor = campaign.threat_actor + # Return { return { + # Literal argument value "id": str(campaign.id), + # Literal argument value "name": campaign.name, + # Literal argument value "description": campaign.description, + # Literal argument value "type": campaign.type, + # Literal argument value "status": campaign.status, + # Literal argument value "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None, + # Literal argument value "threat_actor_name": actor.name if actor else None, "start_date": campaign.start_date.isoformat() if campaign.start_date else None, "target_platform": campaign.target_platform, + # Literal argument value "tags": campaign.tags or [], + # Literal argument value "created_at": campaign.created_at.isoformat() if campaign.created_at else None, + # Literal argument value "is_recurring": campaign.is_recurring or False, + # Literal 
argument value "recurrence_pattern": campaign.recurrence_pattern, + # Literal argument value "next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None, + # Literal argument value "last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None, + # Literal argument value "test_count": progress["total"], + # Literal argument value "completion_pct": progress["completion_pct"], } @@ -118,122 +197,198 @@ def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict: def list_campaigns( + # Entry: db db: Session, *, + # Entry: type type: Optional[str] = None, + # Entry: status status: Optional[str] = None, + # Entry: threat_actor_id threat_actor_id: Optional[str] = None, + # Entry: search search: Optional[str] = None, + # Entry: offset offset: int = 0, + # Entry: limit limit: int = 50, ) -> dict: """Return a paginated list of campaigns with optional filters.""" + # Assign query = db.query(Campaign) query = db.query(Campaign) + # Check: type if type: + # Assign query = query.filter(Campaign.type == type) query = query.filter(Campaign.type == type) + # Check: status if status: + # Assign query = query.filter(Campaign.status == status) query = query.filter(Campaign.status == status) + # Check: threat_actor_id if threat_actor_id: + # Assign query = query.filter(Campaign.threat_actor_id == threat_actor_id) query = query.filter(Campaign.threat_actor_id == threat_actor_id) + # Check: search if search: + # Assign pattern = f"%{escape_like(search)}%" pattern = f"%{escape_like(search)}%" + # Assign query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.il... query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern)) + # Assign total = query.count() total = query.count() + # Assign campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(lim... 
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all() + # Return { return { + # Literal argument value "total": total, + # Literal argument value "offset": offset, + # Literal argument value "limit": limit, + # Literal argument value "items": [serialize_campaign_summary(db, c) for c in campaigns], } +# Define function create_campaign def create_campaign( + # Entry: db db: Session, *, + # Entry: creator_id creator_id: uuid.UUID, + # Entry: name name: str, + # Entry: description description: Optional[str] = None, + # Entry: type type: str = "custom", + # Entry: threat_actor_id threat_actor_id: Optional[str] = None, + # Entry: target_platform target_platform: Optional[str] = None, + # Entry: tags tags: Optional[list[str]] = None, + # Entry: scheduled_at scheduled_at: Optional[str] = None, start_date: Optional[str] = None, ) -> dict: """Create a new campaign. Does not commit; caller commits.""" + # Assign campaign = Campaign( campaign = Campaign( + # Keyword argument: name name=name, + # Keyword argument: description description=description, + # Keyword argument: type type=type, + # Keyword argument: threat_actor_id threat_actor_id=uuid.UUID(threat_actor_id) if threat_actor_id else None, + # Keyword argument: target_platform target_platform=target_platform, + # Keyword argument: tags tags=tags or [], + # Keyword argument: created_by created_by=creator_id, + # Keyword argument: scheduled_at scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None, start_date=datetime.fromisoformat(start_date) if start_date else None, ) + # Stage new record(s) for database insertion db.add(campaign) + # Flush changes to DB without committing the transaction db.flush() + # Return serialize_campaign(db, campaign) return serialize_campaign(db, campaign) +# Define function get_campaign_detail def get_campaign_detail(db: Session, campaign_id: str) -> dict: """Get detailed campaign info including tests and progress. 
Raises EntityNotFoundError if campaign not found. """ + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Return serialize_campaign(db, campaign) return serialize_campaign(db, campaign) +# Define function update_campaign def update_campaign( + # Entry: db db: Session, + # Entry: campaign_id campaign_id: str, *, + # Entry: updater_id updater_id: uuid.UUID, + # Entry: updater_role updater_role: str, - **fields, + **fields: object, ) -> dict: """Update a campaign. Only allowed in draft or active state. Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation. Does not commit; caller commits. """ + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Check: campaign.status not in ("draft", "active") if campaign.status not in ("draft", "active"): + # Raise BusinessRuleViolation raise BusinessRuleViolation("Can only update draft or active campaigns") + # Check: str(campaign.created_by) != str(updater_id) and updater_role != "ad... 
if str(campaign.created_by) != str(updater_id) and updater_role != "admin": + # Raise PermissionViolation raise PermissionViolation("Only the creator or admin can update this campaign") + # Check: "scheduled_at" in fields and fields["scheduled_at"] if "scheduled_at" in fields and fields["scheduled_at"]: + # Assign fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"]) fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"]) if "start_date" in fields and fields["start_date"]: fields["start_date"] = datetime.fromisoformat(fields["start_date"]) + # Iterate over fields.items() for field, value in fields.items(): + # Call setattr() setattr(campaign, field, value) + # Flush changes to DB without committing the transaction db.flush() + # Return serialize_campaign(db, campaign) return serialize_campaign(db, campaign) +# Define function add_test_to_campaign def add_test_to_campaign( + # Entry: db db: Session, + # Entry: campaign_id campaign_id: str, *, + # Entry: test_id test_id: str, + # Entry: order_index order_index: Optional[int] = None, + # Entry: depends_on depends_on: Optional[str] = None, + # Entry: phase phase: Optional[str] = None, ) -> dict: """Add a test to a campaign with optional ordering and dependency. @@ -242,60 +397,101 @@ def add_test_to_campaign( Raises BusinessRuleViolation for invalid state or circular dependency. Does not commit; caller commits. 
""" + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Check: campaign.status not in ("draft", "active") if campaign.status not in ("draft", "active"): + # Raise BusinessRuleViolation raise BusinessRuleViolation("Can only add tests to draft or active campaigns") + # Assign test = db.query(Test).filter(Test.id == test_id).first() test = db.query(Test).filter(Test.id == test_id).first() + # Check: not test if not test: + # Raise EntityNotFoundError raise EntityNotFoundError("Test", test_id) + # Check: order_index is not None if order_index is not None: + # Assign final_order_index = order_index final_order_index = order_index + # Fallback: handle remaining cases else: + # Assign max_order = ( max_order = ( db.query(CampaignTest.order_index) + # Chain .filter() call .filter(CampaignTest.campaign_id == campaign_id) + # Chain .order_by() call .order_by(CampaignTest.order_index.desc()) + # Chain .first() call .first() ) + # Assign final_order_index = (max_order[0] + 1) if max_order else 0 final_order_index = (max_order[0] + 1) if max_order else 0 + # Assign depends_on_uuid = uuid.UUID(depends_on) if depends_on else None depends_on_uuid = uuid.UUID(depends_on) if depends_on else None + # Assign ct_id = uuid.uuid4() ct_id = uuid.uuid4() + # Check: depends_on_uuid if depends_on_uuid: + # Call validate_no_circular_dependency() validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on_uuid) + # Check: not phase and test.technique_id if not phase and test.technique_id: + # Assign technique = db.query(Technique).filter(Technique.id == test.technique_id).first() technique = db.query(Technique).filter(Technique.id == test.technique_id).first() + # Check: technique and technique.tactic if technique and technique.tactic: + # Assign 
phase = TACTIC_TO_PHASE.get(technique.tactic, None) phase = TACTIC_TO_PHASE.get(technique.tactic, None) + # Assign campaign_test = CampaignTest( campaign_test = CampaignTest( + # Keyword argument: id id=ct_id, + # Keyword argument: campaign_id campaign_id=campaign_id, + # Keyword argument: test_id test_id=test_id, + # Keyword argument: order_index order_index=final_order_index, + # Keyword argument: depends_on depends_on=depends_on_uuid, + # Keyword argument: phase phase=phase, ) + # Stage new record(s) for database insertion db.add(campaign_test) + # Flush changes to DB without committing the transaction db.flush() + # Return { return { + # Literal argument value "id": str(campaign_test.id), + # Literal argument value "campaign_id": str(campaign_test.campaign_id), + # Literal argument value "test_id": str(campaign_test.test_id), + # Literal argument value "order_index": campaign_test.order_index, + # Literal argument value "depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None, + # Literal argument value "phase": campaign_test.phase, } +# Define function remove_test_from_campaign def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None: """Remove a test from a campaign. @@ -303,27 +499,41 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s Raises BusinessRuleViolation for invalid state. Does not commit; caller commits. 
""" + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Check: campaign.status not in ("draft", "active") if campaign.status not in ("draft", "active"): + # Raise BusinessRuleViolation raise BusinessRuleViolation("Can only modify draft or active campaigns") + # Assign ct = ( ct = ( db.query(CampaignTest) + # Chain .filter() call .filter( CampaignTest.id == campaign_test_id, CampaignTest.campaign_id == campaign_id, ) + # Chain .first() call .first() ) + # Check: not ct if not ct: + # Raise EntityNotFoundError raise EntityNotFoundError("CampaignTest", campaign_test_id) + # Assign dep_id = uuid.UUID(campaign_test_id) dep_id = uuid.UUID(campaign_test_id) + # Assign dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all() dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all() + # Iterate over dependents for dep in dependents: + # Assign dep.depends_on = None dep.depends_on = None # Keep a reference to the underlying test before deleting the join record @@ -334,6 +544,7 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s technique_id = test_obj.technique_id db.delete(ct) + # Flush changes to DB without committing the transaction db.flush() # Also delete the actual test record (it was created for this campaign) @@ -349,72 +560,110 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s db.flush() +# Define function activate_campaign def activate_campaign(db: Session, campaign_id: str) -> Campaign: """Activate a campaign, moving it from draft to active. Raises EntityNotFoundError, BusinessRuleViolation. Does not commit; caller commits. 
""" + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Check: campaign.status != "draft" if campaign.status != "draft": + # Raise BusinessRuleViolation raise BusinessRuleViolation("Only draft campaigns can be activated") + # Assign test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_... test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count() + # Check: test_count == 0 if test_count == 0: + # Raise BusinessRuleViolation raise BusinessRuleViolation("Campaign must have at least one test to activate") + # Assign campaign.status = "active" campaign.status = "active" + # Flush changes to DB without committing the transaction db.flush() + # Return campaign return campaign +# Define function complete_campaign def complete_campaign(db: Session, campaign_id: str) -> Campaign: """Mark a campaign as completed. Raises EntityNotFoundError, BusinessRuleViolation. Does not commit; caller commits. 
""" + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Check: campaign.status != "active" if campaign.status != "active": + # Raise BusinessRuleViolation raise BusinessRuleViolation("Only active campaigns can be completed") + # Assign campaign.status = "completed" campaign.status = "completed" + # Assign campaign.completed_at = datetime.utcnow() campaign.completed_at = datetime.utcnow() + # Flush changes to DB without committing the transaction db.flush() + # Return campaign return campaign +# Define function get_campaign_progress_data def get_campaign_progress_data(db: Session, campaign_id: str) -> dict: """Get progress statistics for a campaign. Raises EntityNotFoundError if campaign not found. """ + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Assign progress = get_campaign_progress(db, uuid.UUID(campaign_id)) progress = get_campaign_progress(db, uuid.UUID(campaign_id)) + # Return { return { + # Literal argument value "campaign_id": str(campaign.id), + # Literal argument value "campaign_name": campaign.name, **progress, } +# Define function schedule_campaign def schedule_campaign( + # Entry: db db: Session, + # Entry: campaign_id campaign_id: str, *, + # Entry: owner_id owner_id: uuid.UUID, + # Entry: owner_role owner_role: str, + # Entry: is_recurring is_recurring: bool, + # Entry: recurrence_pattern recurrence_pattern: Optional[str] = None, + # Entry: next_run_at next_run_at: Optional[str] = None, ) -> Campaign: """Configure or update the recurrence schedule for a campaign. 
@@ -422,32 +671,52 @@ def schedule_campaign( Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation. Does not commit; caller commits. """ + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Check: str(campaign.created_by) != str(owner_id) and owner_role != "admin" if str(campaign.created_by) != str(owner_id) and owner_role != "admin": + # Raise PermissionViolation raise PermissionViolation("Only the creator or admin can configure scheduling") + # Assign campaign.is_recurring = is_recurring campaign.is_recurring = is_recurring + # Check: is_recurring if is_recurring: + # Check: recurrence_pattern not in ("weekly", "monthly", "quarterly") if recurrence_pattern not in ("weekly", "monthly", "quarterly"): + # Raise BusinessRuleViolation raise BusinessRuleViolation( + # Literal argument value "recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'" ) + # Assign campaign.recurrence_pattern = recurrence_pattern campaign.recurrence_pattern = recurrence_pattern + # Check: next_run_at if next_run_at: + # Assign campaign.next_run_at = datetime.fromisoformat( campaign.next_run_at = datetime.fromisoformat( next_run_at.replace("Z", "+00:00").replace("+00:00", "") ) + # Alternative: not campaign.next_run_at elif not campaign.next_run_at: + # Assign campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern) campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern) + # Fallback: handle remaining cases else: + # Assign campaign.recurrence_pattern = None campaign.recurrence_pattern = None + # Assign campaign.next_run_at = None campaign.next_run_at = None + # Flush changes to DB without committing the transaction db.flush() + # Return campaign return campaign @@ -522,29 +791,48 @@ def 
get_campaign_history(db: Session, campaign_id: str) -> dict: Raises EntityNotFoundError if campaign not found. """ + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Assign campaign_uuid = uuid.UUID(campaign_id) campaign_uuid = uuid.UUID(campaign_id) + # Assign children = ( children = ( db.query(Campaign) + # Chain .filter() call .filter(Campaign.parent_campaign_id == campaign_uuid) + # Chain .order_by() call .order_by(Campaign.created_at.desc()) + # Chain .all() call .all() ) + # Return { return { + # Literal argument value "campaign_id": str(campaign.id), + # Literal argument value "campaign_name": campaign.name, + # Literal argument value "items": [ { + # Literal argument value "id": str(child.id), + # Literal argument value "name": child.name, + # Literal argument value "status": child.status, + # Literal argument value "test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(), + # Literal argument value "completion_pct": get_campaign_progress(db, child.id)["completion_pct"], + # Literal argument value "created_at": child.created_at.isoformat() if child.created_at else None, + # Literal argument value "completed_at": child.completed_at.isoformat() if child.completed_at else None, } for child in children diff --git a/backend/app/services/campaign_scheduler_service.py b/backend/app/services/campaign_scheduler_service.py index 17aa134..ed21179 100644 --- a/backend/app/services/campaign_scheduler_service.py +++ b/backend/app/services/campaign_scheduler_service.py @@ -4,19 +4,34 @@ Handles checking which recurring campaigns are due, cloning them with fresh tests, and computing the next run date. 
""" +# Import logging import logging -import uuid + +# Import datetime, timedelta from datetime from datetime import datetime, timedelta +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import Campaign, CampaignTest from app.models.campaign from app.models.campaign import Campaign, CampaignTest -from app.models.test import Test + +# Import TestState from app.models.enums from app.models.enums import TestState -from app.services.notification_service import create_notification -from app.services.audit_service import log_action + +# Import Test from app.models.test +from app.models.test import Test + +# Import User from app.models.user from app.models.user import User +# Import log_action from app.services.audit_service +from app.services.audit_service import log_action + +# Import create_notification from app.services.notification_service +from app.services.notification_service import create_notification + +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) @@ -33,11 +48,16 @@ def calculate_next_run(current_date: datetime, pattern: str) -> datetime: - ``monthly`` : +30 days - ``quarterly``: +90 days """ + # Assign offsets = { offsets = { + # Literal argument value "weekly": timedelta(days=7), + # Literal argument value "monthly": timedelta(days=30), + # Literal argument value "quarterly": timedelta(days=90), } + # Return current_date + offsets.get(pattern, timedelta(days=30)) return current_date + offsets.get(pattern, timedelta(days=30)) @@ -54,59 +74,99 @@ def _clone_campaign(db: Session, original: Campaign) -> Campaign: with the same base data (in ``draft`` state) and link it. 3. Activate the new campaign. 
""" + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign run_label = now.strftime("%Y-%m-%d") run_label = now.strftime("%Y-%m-%d") + # Assign child = Campaign( child = Campaign( + # Keyword argument: name name=f"{original.name} (Run {run_label})", + # Keyword argument: description description=original.description, + # Keyword argument: type type=original.type, + # Keyword argument: threat_actor_id threat_actor_id=original.threat_actor_id, + # Keyword argument: status status="active", + # Keyword argument: created_by created_by=original.created_by, + # Keyword argument: target_platform target_platform=original.target_platform, + # Keyword argument: tags tags=original.tags or [], + # Keyword argument: parent_campaign_id parent_campaign_id=original.id, ) + # Stage new record(s) for database insertion db.add(child) + # Flush changes to DB without committing the transaction db.flush() # get child.id # Clone each campaign_test with a fresh Test original_cts = ( db.query(CampaignTest) + # Chain .filter() call .filter(CampaignTest.campaign_id == original.id) + # Chain .order_by() call .order_by(CampaignTest.order_index) + # Chain .all() call .all() ) + # Iterate over original_cts for ct in original_cts: + # Assign src_test = ct.test src_test = ct.test + # Check: not src_test if not src_test: + # Skip to the next loop iteration continue + # Assign new_test = Test( new_test = Test( + # Keyword argument: technique_id technique_id=src_test.technique_id, + # Keyword argument: name name=src_test.name, + # Keyword argument: description description=src_test.description, + # Keyword argument: platform platform=src_test.platform, + # Keyword argument: procedure_text procedure_text=src_test.procedure_text, + # Keyword argument: tool_used tool_used=src_test.tool_used, + # Keyword argument: created_by created_by=original.created_by, + # Keyword argument: state state=TestState.draft, ) + # Stage new record(s) for database insertion db.add(new_test) + # Flush changes to 
DB without committing the transaction db.flush() # get new_test.id + # Assign new_ct = CampaignTest( new_ct = CampaignTest( + # Keyword argument: campaign_id campaign_id=child.id, + # Keyword argument: test_id test_id=new_test.id, + # Keyword argument: order_index order_index=ct.order_index, + # Keyword argument: phase phase=ct.phase, # depends_on is not copied — would need ID remapping ) + # Stage new record(s) for database insertion db.add(new_ct) + # Flush changes to DB without committing the transaction db.flush() + # Return child return child @@ -120,75 +180,119 @@ def check_and_run_recurring_campaigns(db: Session) -> int: Returns the number of campaigns spawned. """ + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign due_campaigns = ( due_campaigns = ( db.query(Campaign) + # Chain .filter() call .filter( Campaign.is_recurring == True, # noqa: E712 Campaign.next_run_at <= now, ) + # Chain .all() call .all() ) + # Assign spawned = 0 spawned = 0 + # Iterate over due_campaigns for campaign in due_campaigns: + # Attempt the following; catch errors below try: + # Assign child = _clone_campaign(db, campaign) child = _clone_campaign(db, campaign) # Update the original's scheduling fields campaign.last_run_at = now + # Assign campaign.next_run_at = calculate_next_run(now, campaign.recurrence_pattern or "monthly") campaign.next_run_at = calculate_next_run(now, campaign.recurrence_pattern or "monthly") + # Commit all pending changes to the database db.commit() + # Reload ORM object attributes from the database db.refresh(child) # Audit log_action( db, + # Keyword argument: user_id user_id=campaign.created_by, + # Keyword argument: action action="recurring_campaign_run", + # Keyword argument: entity_type entity_type="campaign", + # Keyword argument: entity_id entity_id=child.id, + # Keyword argument: details details={ + # Literal argument value "parent_campaign_id": str(campaign.id), + # Literal argument value "child_campaign_name": child.name, + # 
Literal argument value "pattern": campaign.recurrence_pattern, }, ) + # Commit all pending changes to the database db.commit() # Notify if campaign.created_by: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=campaign.created_by, + # Keyword argument: type type="recurring_campaign_run", + # Keyword argument: title title="Recurring campaign executed", - message=f'Campaign "{child.name}" was automatically created from recurring template "{campaign.name}".', + # Keyword argument: message + message=( + f'Campaign "{child.name}" was automatically created ' + f'from recurring template "{campaign.name}".' + ), + # Keyword argument: entity_type entity_type="campaign", + # Keyword argument: entity_id entity_id=child.id, ) # Notify red_tech users red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712 + # Iterate over red_techs for user in red_techs: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: type type="campaign_activated", + # Keyword argument: title title="New recurring campaign active", + # Keyword argument: message message=f'Campaign "{child.name}" is now active and ready for execution.', + # Keyword argument: entity_type entity_type="campaign", + # Keyword argument: entity_id entity_id=child.id, ) + # Assign spawned = 1 spawned += 1 + # Log info: "Spawned child campaign '%s' from parent '%s'", ch logger.info("Spawned child campaign '%s' from parent '%s'", child.name, campaign.name) + # Handle Exception except Exception: + # Roll back all uncommitted changes db.rollback() + # Log exception: "Failed to run recurring campaign '%s'", campaign. 
logger.exception("Failed to run recurring campaign '%s'", campaign.name) + # Return spawned return spawned diff --git a/backend/app/services/campaign_service.py b/backend/app/services/campaign_service.py index cb2b0dc..956d382 100644 --- a/backend/app/services/campaign_service.py +++ b/backend/app/services/campaign_service.py @@ -4,108 +4,183 @@ Handles circular dependency validation, campaign generation from threat actors, and progress calculation. """ +# Import logging import logging import uuid from datetime import datetime from typing import Optional +# Import uuid +import uuid + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import EntityNotFoundError, InvalidOperationError from app.domain.exceptions from app.domain.exceptions import EntityNotFoundError, InvalidOperationError -from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES -from app.models.test import Test -from app.models.test_template import TestTemplate -from app.models.technique import Technique -from app.models.threat_actor import ThreatActor, ThreatActorTechnique + +# Import Campaign, CampaignTest from app.models.campaign +from app.models.campaign import Campaign, CampaignTest + +# Import TechniqueStatus, TestState from app.models.enums from app.models.enums import TechniqueStatus, TestState -from app.services.notification_service import create_notification + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import Test from app.models.test +from app.models.test import Test + +# Import TestTemplate from app.models.test_template +from app.models.test_template import TestTemplate + +# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor +from app.models.threat_actor import ThreatActor, ThreatActorTechnique + +# Import User from app.models.user from app.models.user import User +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # Mapping from ATT&CK tactics to 
kill chain phases TACTIC_TO_PHASE: dict[str, str] = { + # Literal argument value "reconnaissance": "reconnaissance", + # Literal argument value "resource-development": "resource_development", + # Literal argument value "initial-access": "initial_access", + # Literal argument value "execution": "execution", + # Literal argument value "persistence": "persistence", + # Literal argument value "privilege-escalation": "privilege_escalation", + # Literal argument value "defense-evasion": "defense_evasion", + # Literal argument value "credential-access": "credential_access", + # Literal argument value "discovery": "discovery", + # Literal argument value "lateral-movement": "lateral_movement", + # Literal argument value "collection": "collection", + # Literal argument value "command-and-control": "command_and_control", + # Literal argument value "exfiltration": "exfiltration", + # Literal argument value "impact": "impact", } +# Define function validate_no_circular_dependency def validate_no_circular_dependency( + # Entry: db db: Session, + # Entry: campaign_id campaign_id: uuid.UUID, + # Entry: test_id test_id: uuid.UUID, + # Entry: depends_on_id depends_on_id: uuid.UUID | None, ) -> None: """Walk the depends_on chain and verify no cycle is formed. Raises :class:`InvalidOperationError` if a circular dependency is detected. 
""" + # Check: depends_on_id is None if depends_on_id is None: + # Return control to caller return + # Assign visited = set() visited: set[uuid.UUID] = set() + # Assign current = depends_on_id current = depends_on_id + # Loop while current is not None while current is not None: + # Check: current in visited or current == test_id if current in visited or current == test_id: + # Raise InvalidOperationError raise InvalidOperationError( + # Literal argument value "Circular dependency detected in campaign test chain" ) + # Call visited.add() visited.add(current) + # Assign parent = db.query(CampaignTest).filter_by(id=current).first() parent = db.query(CampaignTest).filter_by(id=current).first() + # Assign current = parent.depends_on if parent else None current = parent.depends_on if parent else None +# Define function get_campaign_progress def get_campaign_progress(db: Session, campaign_id: uuid.UUID) -> dict: """Calculate progress statistics for a campaign. Returns counts of tests by state, plus total and completion percentage. 
""" + # Assign campaign_tests = ( campaign_tests = ( db.query(CampaignTest) + # Chain .filter() call .filter(CampaignTest.campaign_id == campaign_id) + # Chain .all() call .all() ) + # Check: not campaign_tests if not campaign_tests: + # Return { return { + # Literal argument value "total": 0, + # Literal argument value "by_state": {}, + # Literal argument value "completion_pct": 0.0, } + # Assign by_state = {} by_state: dict[str, int] = {} + # Iterate over campaign_tests for ct in campaign_tests: + # Assign test = ct.test test = ct.test + # Assign state = test.state.value if test and test.state else "unknown" state = test.state.value if test and test.state else "unknown" + # Assign by_state[state] = by_state.get(state, 0) + 1 by_state[state] = by_state.get(state, 0) + 1 + # Assign total = len(campaign_tests) total = len(campaign_tests) + # Assign completed = by_state.get("validated", 0) completed = by_state.get("validated", 0) + # Assign completion_pct = round(completed / total * 100, 1) if total > 0 else 0.0 completion_pct = round(completed / total * 100, 1) if total > 0 else 0.0 + # Return { return { + # Literal argument value "total": total, + # Literal argument value "by_state": by_state, + # Literal argument value "completion_pct": completion_pct, } +# Define function generate_campaign_from_threat_actor def generate_campaign_from_threat_actor( + # Entry: db db: Session, + # Entry: actor_id actor_id: uuid.UUID, + # Entry: user user: User, *, start_date: Optional[datetime] = None, @@ -119,75 +194,111 @@ def generate_campaign_from_threat_actor( 4. Create a campaign with tests ordered by kill chain phase 5. 
Return the campaign """ + # Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + # Check: not actor if not actor: + # Raise EntityNotFoundError raise EntityNotFoundError("ThreatActor", str(actor_id)) # Get unvalidated techniques for this actor gap_techniques = ( db.query(Technique, ThreatActorTechnique) + # Chain .join() call .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) + # Chain .filter() call .filter(ThreatActorTechnique.threat_actor_id == actor_id) + # Chain .filter() call .filter(Technique.status_global != TechniqueStatus.validated) + # Chain .order_by() call .order_by(Technique.tactic, Technique.mitre_id) + # Chain .all() call .all() ) + # Check: not gap_techniques if not gap_techniques: + # Raise InvalidOperationError raise InvalidOperationError( f"No uncovered techniques found for {actor.name}" ) # Create the campaign campaign = Campaign( + # Keyword argument: name name=f"APT Emulation: {actor.name}", + # Keyword argument: description description=f"Auto-generated campaign to test coverage against {actor.name} " f"({actor.mitre_id or 'unknown'}). 
" f"Covers {len(gap_techniques)} uncovered technique(s).", + # Keyword argument: type type="apt_emulation", + # Keyword argument: threat_actor_id threat_actor_id=actor_id, + # Keyword argument: status status="draft", + # Keyword argument: created_by created_by=user.id, + # Keyword argument: tags tags=[actor.name, "auto-generated"], start_date=start_date, ) + # Stage new record(s) for database insertion db.add(campaign) + # Flush changes to DB without committing the transaction db.flush() # Get campaign.id + # Assign order_index = 0 order_index = 0 + # Iterate over gap_techniques for tech, _at in gap_techniques: # Find best template for this technique template = ( db.query(TestTemplate) + # Chain .filter() call .filter( TestTemplate.mitre_technique_id == tech.mitre_id, TestTemplate.is_active == True, # noqa: E712 ) + # Chain .order_by() call .order_by( # Prioritize by severity: critical > high > medium > low TestTemplate.severity.desc(), TestTemplate.name, ) + # Chain .first() call .first() ) + # Check: not template if not template: + # continue # Skip techniques without templates continue # Skip techniques without templates # Create a test from the template test = Test( + # Keyword argument: technique_id technique_id=tech.id, + # Keyword argument: name name=f"[Campaign] {template.name}", + # Keyword argument: description description=template.description, + # Keyword argument: platform platform=template.platform, + # Keyword argument: procedure_text procedure_text=template.attack_procedure, + # Keyword argument: tool_used tool_used=template.tool_suggested, + # Keyword argument: created_by created_by=user.id, + # Keyword argument: state state=TestState.draft, created_at=datetime.utcnow(), ) + # Stage new record(s) for database insertion db.add(test) + # Flush changes to DB without committing the transaction db.flush() # Get test.id # Determine kill chain phase from the technique's tactic @@ -195,22 +306,33 @@ def generate_campaign_from_threat_actor( # Add to campaign 
campaign_test = CampaignTest( + # Keyword argument: campaign_id campaign_id=campaign.id, + # Keyword argument: test_id test_id=test.id, + # Keyword argument: order_index order_index=order_index, + # Keyword argument: phase phase=phase, ) + # Stage new record(s) for database insertion db.add(campaign_test) + # Assign order_index = 1 order_index += 1 + # Commit all pending changes to the database db.commit() + # Reload ORM object attributes from the database db.refresh(campaign) + # Log info: logger.info( + # Literal argument value "Generated campaign '%s' with %d tests for actor %s", campaign.name, order_index, actor.name, ) + # Return campaign return campaign diff --git a/backend/app/services/compliance_import_service.py b/backend/app/services/compliance_import_service.py index 1d92ddd..0e151b9 100644 --- a/backend/app/services/compliance_import_service.py +++ b/backend/app/services/compliance_import_service.py @@ -5,31 +5,431 @@ Defense's attack_to_nist_mapping repository to create ComplianceFramework, ComplianceControl, and ComplianceControlMapping records. 
""" +# Import logging import logging -import json -import re -from typing import Optional +# Import re +import re + +# Import requests import requests + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import from app.models.compliance from app.models.compliance import ( - ComplianceFramework, ComplianceControl, ComplianceControlMapping, + ComplianceFramework, ) + +# Import Technique from app.models.technique from app.models.technique import Technique +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# ── Module-level control definitions (avoids N806 / uppercase-in-function) ──── + +_NIST_SAMPLE_CONTROLS = [ + { + # Literal argument value + "control_id": "AC-2", + # Literal argument value + "title": "Account Management", + # Literal argument value + "category": "Access Control", + # Literal argument value + "techniques": ["T1078", "T1136", "T1098", "T1087", "T1069"], + }, + { + # Literal argument value + "control_id": "AC-3", + # Literal argument value + "title": "Access Enforcement", + # Literal argument value + "category": "Access Control", + # Literal argument value + "techniques": ["T1078", "T1548", "T1134"], + }, + { + # Literal argument value + "control_id": "AC-4", + # Literal argument value + "title": "Information Flow Enforcement", + # Literal argument value + "category": "Access Control", + # Literal argument value + "techniques": ["T1048", "T1041", "T1572"], + }, + { + # Literal argument value + "control_id": "AC-6", + # Literal argument value + "title": "Least Privilege", + # Literal argument value + "category": "Access Control", + # Literal argument value + "techniques": ["T1078", "T1548", "T1134"], + }, + { + # Literal argument value + "control_id": "AU-2", + # Literal argument value + "title": "Event Logging", + # Literal argument value + "category": "Audit and Accountability", + # Literal argument value + "techniques": ["T1562", "T1070"], + }, + { + # Literal argument value + 
"control_id": "AU-6", + # Literal argument value + "title": "Audit Record Review", + # Literal argument value + "category": "Audit and Accountability", + # Literal argument value + "techniques": ["T1562", "T1070", "T1027"], + }, + { + # Literal argument value + "control_id": "CA-7", + # Literal argument value + "title": "Continuous Monitoring", + # Literal argument value + "category": "Assessment, Authorization, and Monitoring", + # Literal argument value + "techniques": ["T1059", "T1053"], + }, + { + # Literal argument value + "control_id": "CM-2", + # Literal argument value + "title": "Baseline Configuration", + # Literal argument value + "category": "Configuration Management", + # Literal argument value + "techniques": ["T1574", "T1546"], + }, + { + # Literal argument value + "control_id": "CM-6", + # Literal argument value + "title": "Configuration Settings", + # Literal argument value + "category": "Configuration Management", + # Literal argument value + "techniques": ["T1574", "T1546", "T1112"], + }, + { + # Literal argument value + "control_id": "CM-7", + # Literal argument value + "title": "Least Functionality", + # Literal argument value + "category": "Configuration Management", + # Literal argument value + "techniques": ["T1059", "T1218"], + }, + { + # Literal argument value + "control_id": "IA-2", + # Literal argument value + "title": "Identification and Authentication", + # Literal argument value + "category": "Identification and Authentication", + # Literal argument value + "techniques": ["T1078", "T1110"], + }, + { + # Literal argument value + "control_id": "IA-5", + # Literal argument value + "title": "Authenticator Management", + # Literal argument value + "category": "Identification and Authentication", + # Literal argument value + "techniques": ["T1078", "T1110", "T1003"], + }, + { + # Literal argument value + "control_id": "IR-4", + # Literal argument value + "title": "Incident Handling", + # Literal argument value + "category": "Incident 
Response", + # Literal argument value + "techniques": ["T1059", "T1547"], + }, + { + # Literal argument value + "control_id": "RA-5", + # Literal argument value + "title": "Vulnerability Monitoring and Scanning", + # Literal argument value + "category": "Risk Assessment", + # Literal argument value + "techniques": ["T1190", "T1203"], + }, + { + # Literal argument value + "control_id": "SC-7", + # Literal argument value + "title": "Boundary Protection", + # Literal argument value + "category": "System and Communications Protection", + # Literal argument value + "techniques": ["T1048", "T1041", "T1071"], + }, + { + # Literal argument value + "control_id": "SC-28", + # Literal argument value + "title": "Protection of Information at Rest", + # Literal argument value + "category": "System and Communications Protection", + # Literal argument value + "techniques": ["T1005", "T1114"], + }, + { + # Literal argument value + "control_id": "SI-3", + # Literal argument value + "title": "Malicious Code Protection", + # Literal argument value + "category": "System and Information Integrity", + # Literal argument value + "techniques": ["T1059", "T1204", "T1566"], + }, + { + # Literal argument value + "control_id": "SI-4", + # Literal argument value + "title": "System Monitoring", + # Literal argument value + "category": "System and Information Integrity", + # Literal argument value + "techniques": ["T1059", "T1053", "T1547"], + }, + { + # Literal argument value + "control_id": "SI-7", + # Literal argument value + "title": "Software, Firmware, and Information Integrity", + # Literal argument value + "category": "System and Information Integrity", + # Literal argument value + "techniques": ["T1195", "T1553"], + }, + { + # Literal argument value + "control_id": "PM-16", + # Literal argument value + "title": "Threat Awareness Program", + # Literal argument value + "category": "Program Management", + # Literal argument value + "techniques": ["T1566", "T1204"], + }, +] + +# Assign 
_CIS_CONTROLS = [ +_CIS_CONTROLS = [ + { + # Literal argument value + "control_id": "CIS-1", + # Literal argument value + "title": "Inventory and Control of Enterprise Assets", + # Literal argument value + "category": "IG1 — Basic", + # Literal argument value + "techniques": ["T1595", "T1590", "T1018", "T1082"], + }, + { + # Literal argument value + "control_id": "CIS-2", + # Literal argument value + "title": "Inventory and Control of Software Assets", + # Literal argument value + "category": "IG1 — Basic", + # Literal argument value + "techniques": ["T1518", "T1072", "T1195"], + }, + { + # Literal argument value + "control_id": "CIS-3", + # Literal argument value + "title": "Data Protection", + # Literal argument value + "category": "IG1 — Basic", + # Literal argument value + "techniques": ["T1005", "T1114", "T1560", "T1048", "T1041"], + }, + { + # Literal argument value + "control_id": "CIS-4", + # Literal argument value + "title": "Secure Configuration of Enterprise Assets and Software", + # Literal argument value + "category": "IG1 — Basic", + # Literal argument value + "techniques": ["T1574", "T1546", "T1112", "T1543"], + }, + { + # Literal argument value + "control_id": "CIS-5", + # Literal argument value + "title": "Account Management", + # Literal argument value + "category": "IG1 — Basic", + # Literal argument value + "techniques": ["T1078", "T1136", "T1098", "T1087"], + }, + { + # Literal argument value + "control_id": "CIS-6", + # Literal argument value + "title": "Access Control Management", + # Literal argument value + "category": "IG1 — Basic", + # Literal argument value + "techniques": ["T1078", "T1548", "T1134", "T1021"], + }, + { + # Literal argument value + "control_id": "CIS-7", + # Literal argument value + "title": "Continuous Vulnerability Management", + # Literal argument value + "category": "IG2 — Foundational", + # Literal argument value + "techniques": ["T1190", "T1203", "T1068", "T1210"], + }, + { + # Literal argument value + "control_id": 
"CIS-8", + # Literal argument value + "title": "Audit Log Management", + # Literal argument value + "category": "IG2 — Foundational", + # Literal argument value + "techniques": ["T1562", "T1070", "T1059"], + }, + { + # Literal argument value + "control_id": "CIS-9", + # Literal argument value + "title": "Email and Web Browser Protections", + # Literal argument value + "category": "IG2 — Foundational", + # Literal argument value + "techniques": ["T1566", "T1204", "T1189", "T1598"], + }, + { + # Literal argument value + "control_id": "CIS-10", + # Literal argument value + "title": "Malware Defenses", + # Literal argument value + "category": "IG2 — Foundational", + # Literal argument value + "techniques": ["T1059", "T1204", "T1027", "T1140", "T1497"], + }, + { + # Literal argument value + "control_id": "CIS-11", + # Literal argument value + "title": "Data Recovery", + # Literal argument value + "category": "IG1 — Basic", + # Literal argument value + "techniques": ["T1486", "T1490", "T1561"], + }, + { + # Literal argument value + "control_id": "CIS-12", + # Literal argument value + "title": "Network Infrastructure Management", + # Literal argument value + "category": "IG2 — Foundational", + # Literal argument value + "techniques": ["T1557", "T1071", "T1572", "T1571"], + }, + { + # Literal argument value + "control_id": "CIS-13", + # Literal argument value + "title": "Network Monitoring and Defense", + # Literal argument value + "category": "IG2 — Foundational", + # Literal argument value + "techniques": ["T1071", "T1048", "T1041", "T1105", "T1572"], + }, + { + # Literal argument value + "control_id": "CIS-14", + # Literal argument value + "title": "Security Awareness and Skills Training", + # Literal argument value + "category": "IG1 — Basic", + # Literal argument value + "techniques": ["T1566", "T1204", "T1598"], + }, + { + # Literal argument value + "control_id": "CIS-15", + # Literal argument value + "title": "Service Provider Management", + # Literal argument value 
+ "category": "IG2 — Foundational", + # Literal argument value + "techniques": ["T1199", "T1195"], + }, + { + # Literal argument value + "control_id": "CIS-16", + # Literal argument value + "title": "Application Software Security", + # Literal argument value + "category": "IG2 — Foundational", + # Literal argument value + "techniques": ["T1190", "T1059", "T1203"], + }, + { + # Literal argument value + "control_id": "CIS-17", + # Literal argument value + "title": "Incident Response Management", + # Literal argument value + "category": "IG2 — Foundational", + # Literal argument value + "techniques": ["T1059", "T1547", "T1053"], + }, + { + # Literal argument value + "control_id": "CIS-18", + # Literal argument value + "title": "Penetration Testing", + # Literal argument value + "category": "IG3 — Organizational", + # Literal argument value + "techniques": ["T1595", "T1046", "T1190", "T1059"], + }, +] + # URL for the NIST 800-53 Rev 5 to ATT&CK mapping # This is the JSON STIX bundle that contains the relationships NIST_MAPPING_URL = ( + # Literal argument value "https://raw.githubusercontent.com/center-for-threat-informed-defense/" + # Literal argument value "attack_to_nist_mapping/main/data/attack-to-nist-rev5.json" ) +# Define function import_nist_800_53_mappings def import_nist_800_53_mappings(db: Session) -> dict: """Import NIST 800-53 Rev 5 mappings from MITRE CTI repository. @@ -45,30 +445,56 @@ def import_nist_800_53_mappings(db: Session) -> dict: # ── 1. 
Create or get framework ──────────────────────────────── framework = ( db.query(ComplianceFramework) + # Chain .filter() call .filter(ComplianceFramework.name == "NIST 800-53 Rev 5") + # Chain .first() call .first() ) + # Check: not framework if not framework: + # Assign framework = ComplianceFramework( framework = ComplianceFramework( + # Keyword argument: name name="NIST 800-53 Rev 5", + # Keyword argument: version version="5", - description="National Institute of Standards and Technology Special Publication 800-53 Revision 5 — Security and Privacy Controls for Information Systems and Organizations", + # Keyword argument: description + description=( + # Literal argument value + "National Institute of Standards and Technology " + # Literal argument value + "Special Publication 800-53 Revision 5 — " + # Literal argument value + "Security and Privacy Controls for Information Systems and Organizations" + ), + # Keyword argument: url url="https://csrc.nist.gov/publications/detail/sp/800-53/rev-5/final", + # Keyword argument: is_active is_active=True, ) + # Stage new record(s) for database insertion db.add(framework) + # Flush changes to DB without committing the transaction db.flush() + # Log info: "Created NIST 800-53 Rev 5 framework" logger.info("Created NIST 800-53 Rev 5 framework") + # Fallback: handle remaining cases else: + # Log info: "NIST 800-53 Rev 5 framework already exists" logger.info("NIST 800-53 Rev 5 framework already exists") # ── 2. 
Download STIX bundle ─────────────────────────────────── try: + # Assign response = requests.get(NIST_MAPPING_URL, timeout=30) response = requests.get(NIST_MAPPING_URL, timeout=30) + # Call response.raise_for_status() response.raise_for_status() + # Assign stix_bundle = response.json() stix_bundle = response.json() + # Handle requests.RequestException except requests.RequestException as e: + # Log warning: f"Failed to download STIX bundle: {e}" logger.warning(f"Failed to download STIX bundle: {e}") # Fallback: create a sample set of well-known NIST controls return _import_sample_nist_mappings(db, framework) @@ -79,87 +505,139 @@ def import_nist_800_53_mappings(db: Session) -> dict: # Build lookup maps # STIX IDs -> control info control_map = {} # stix_id -> {control_id, title, category} + # Assign technique_map = {} # stix_id -> mitre_technique_id technique_map = {} # stix_id -> mitre_technique_id + # Assign relationships = [] # (source_ref, target_ref) for "mitigates" relationships relationships = [] # (source_ref, target_ref) for "mitigates" relationships + # Iterate over objects for obj in objects: + # Assign obj_type = obj.get("type", "") obj_type = obj.get("type", "") + # Check: obj_type == "course-of-action" if obj_type == "course-of-action": # This is a NIST control name = obj.get("name", "") + # Assign desc = obj.get("description", "") desc = obj.get("description", "") + # Assign stix_id = obj.get("id", "") stix_id = obj.get("id", "") # Extract control ID from name (e.g., "AC-2 Account Management") match = re.match(r"^([A-Z]{2}-\d+(?:\.\d+)?)\s*(.*)", name) + # Check: match if match: + # Assign control_id = match.group(1) control_id = match.group(1) + # Assign title = match.group(2) or name title = match.group(2) or name + # Fallback: handle remaining cases else: + # Assign control_id = name control_id = name + # Assign title = name title = name # Extract category from control family category_match = re.match(r"^([A-Z]{2})", control_id) + # Assign category 
= _get_nist_category(category_match.group(1)) if category_match else ... category = _get_nist_category(category_match.group(1)) if category_match else None + # Assign control_map[stix_id] = { control_map[stix_id] = { + # Literal argument value "control_id": control_id, + # Literal argument value "title": title, + # Literal argument value "description": desc[:500] if desc else None, + # Literal argument value "category": category, } + # Alternative: obj_type == "attack-pattern" elif obj_type == "attack-pattern": # This is an ATT&CK technique stix_id = obj.get("id", "") + # Assign ext_refs = obj.get("external_references", []) ext_refs = obj.get("external_references", []) + # Iterate over ext_refs for ref in ext_refs: + # Check: ref.get("source_name") == "mitre-attack" if ref.get("source_name") == "mitre-attack": + # Assign technique_map[stix_id] = ref.get("external_id", "") technique_map[stix_id] = ref.get("external_id", "") + # Exit the loop early break + # Alternative: obj_type == "relationship" elif obj_type == "relationship": + # Assign rel_type = obj.get("relationship_type", "") rel_type = obj.get("relationship_type", "") + # Check: rel_type == "mitigates" if rel_type == "mitigates": + # Assign source_ref = obj.get("source_ref", "") source_ref = obj.get("source_ref", "") + # Assign target_ref = obj.get("target_ref", "") target_ref = obj.get("target_ref", "") + # Call relationships.append() relationships.append((source_ref, target_ref)) # ── 4. 
Create controls ──────────────────────────────────────── controls_created = 0 + # Assign controls_existing = 0 controls_existing = 0 + # Assign control_db_map = {} # control_id -> ComplianceControl control_db_map = {} # control_id -> ComplianceControl # Load existing controls for this framework existing_controls = { c.control_id: c for c in db.query(ComplianceControl) + # Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) + # Chain .all() call .all() } + # Iterate over control_map.items() for stix_id, info in control_map.items(): + # Assign cid = info["control_id"] cid = info["control_id"] + # Check: cid in existing_controls if cid in existing_controls: + # Assign control_db_map[stix_id] = existing_controls[cid] control_db_map[stix_id] = existing_controls[cid] + # Assign controls_existing = 1 controls_existing += 1 + # Fallback: handle remaining cases else: + # Assign ctrl = ComplianceControl( ctrl = ComplianceControl( + # Keyword argument: framework_id framework_id=framework.id, + # Keyword argument: control_id control_id=cid, + # Keyword argument: title title=info["title"], + # Keyword argument: description description=info["description"], + # Keyword argument: category category=info["category"], ) + # Stage new record(s) for database insertion db.add(ctrl) + # Flush changes to DB without committing the transaction db.flush() + # Assign control_db_map[stix_id] = ctrl control_db_map[stix_id] = ctrl + # Assign controls_created = 1 controls_created += 1 # ── 5. 
Create mappings ──────────────────────────────────────── mappings_created = 0 + # Assign mappings_skipped = 0 mappings_skipped = 0 # Build technique DB lookup (mitre_id -> Technique) @@ -167,168 +645,205 @@ def import_nist_800_53_mappings(db: Session) -> dict: # Load existing mappings existing_mappings = set() + # Iterate over db.query(ComplianceControlMapping).all() for m in db.query(ComplianceControlMapping).all(): + # Call existing_mappings.add() existing_mappings.add((str(m.compliance_control_id), str(m.technique_id))) + # Iterate over relationships for source_ref, target_ref in relationships: + # Assign control = control_db_map.get(source_ref) control = control_db_map.get(source_ref) + # Assign mitre_id = technique_map.get(target_ref) mitre_id = technique_map.get(target_ref) + # Check: not control or not mitre_id if not control or not mitre_id: + # Assign mappings_skipped = 1 mappings_skipped += 1 + # Skip to the next loop iteration continue + # Assign technique = all_techniques.get(mitre_id) technique = all_techniques.get(mitre_id) + # Check: not technique if not technique: + # Assign mappings_skipped = 1 mappings_skipped += 1 + # Skip to the next loop iteration continue + # Assign key = (str(control.id), str(technique.id)) key = (str(control.id), str(technique.id)) + # Check: key in existing_mappings if key in existing_mappings: + # Assign mappings_skipped = 1 mappings_skipped += 1 + # Skip to the next loop iteration continue + # Assign mapping = ComplianceControlMapping( mapping = ComplianceControlMapping( + # Keyword argument: compliance_control_id compliance_control_id=control.id, + # Keyword argument: technique_id technique_id=technique.id, ) + # Stage new record(s) for database insertion db.add(mapping) + # Call existing_mappings.add() existing_mappings.add(key) + # Assign mappings_created = 1 mappings_created += 1 + # Commit all pending changes to the database db.commit() + # Assign summary = { summary = { + # Literal argument value "framework": 
framework.name, + # Literal argument value "controls_created": controls_created, + # Literal argument value "controls_existing": controls_existing, + # Literal argument value "mappings_created": mappings_created, + # Literal argument value "mappings_skipped": mappings_skipped, + # Literal argument value "total_controls": controls_created + controls_existing, + # Literal argument value "total_relationships_found": len(relationships), } + # Log info: f"NIST 800-53 import complete: {summary}" logger.info(f"NIST 800-53 import complete: {summary}") + # Return summary return summary +# Define function _import_sample_nist_mappings def _import_sample_nist_mappings(db: Session, framework: ComplianceFramework) -> dict: """Import a curated sample of NIST 800-53 controls when the download fails. This ensures the feature works even without network access. """ - SAMPLE_CONTROLS = [ - {"control_id": "AC-2", "title": "Account Management", "category": "Access Control", - "techniques": ["T1078", "T1136", "T1098", "T1087", "T1069"]}, - {"control_id": "AC-3", "title": "Access Enforcement", "category": "Access Control", - "techniques": ["T1078", "T1548", "T1134"]}, - {"control_id": "AC-4", "title": "Information Flow Enforcement", "category": "Access Control", - "techniques": ["T1048", "T1041", "T1572"]}, - {"control_id": "AC-6", "title": "Least Privilege", "category": "Access Control", - "techniques": ["T1078", "T1548", "T1134"]}, - {"control_id": "AU-2", "title": "Event Logging", "category": "Audit and Accountability", - "techniques": ["T1562", "T1070"]}, - {"control_id": "AU-6", "title": "Audit Record Review", "category": "Audit and Accountability", - "techniques": ["T1562", "T1070", "T1027"]}, - {"control_id": "CA-7", "title": "Continuous Monitoring", "category": "Assessment, Authorization, and Monitoring", - "techniques": ["T1059", "T1053"]}, - {"control_id": "CM-2", "title": "Baseline Configuration", "category": "Configuration Management", - "techniques": ["T1574", "T1546"]}, - 
{"control_id": "CM-6", "title": "Configuration Settings", "category": "Configuration Management", - "techniques": ["T1574", "T1546", "T1112"]}, - {"control_id": "CM-7", "title": "Least Functionality", "category": "Configuration Management", - "techniques": ["T1059", "T1218"]}, - {"control_id": "IA-2", "title": "Identification and Authentication", "category": "Identification and Authentication", - "techniques": ["T1078", "T1110"]}, - {"control_id": "IA-5", "title": "Authenticator Management", "category": "Identification and Authentication", - "techniques": ["T1078", "T1110", "T1003"]}, - {"control_id": "IR-4", "title": "Incident Handling", "category": "Incident Response", - "techniques": ["T1059", "T1547"]}, - {"control_id": "RA-5", "title": "Vulnerability Monitoring and Scanning", "category": "Risk Assessment", - "techniques": ["T1190", "T1203"]}, - {"control_id": "SC-7", "title": "Boundary Protection", "category": "System and Communications Protection", - "techniques": ["T1048", "T1041", "T1071"]}, - {"control_id": "SC-28", "title": "Protection of Information at Rest", "category": "System and Communications Protection", - "techniques": ["T1005", "T1114"]}, - {"control_id": "SI-3", "title": "Malicious Code Protection", "category": "System and Information Integrity", - "techniques": ["T1059", "T1204", "T1566"]}, - {"control_id": "SI-4", "title": "System Monitoring", "category": "System and Information Integrity", - "techniques": ["T1059", "T1053", "T1547"]}, - {"control_id": "SI-7", "title": "Software, Firmware, and Information Integrity", "category": "System and Information Integrity", - "techniques": ["T1195", "T1553"]}, - {"control_id": "PM-16", "title": "Threat Awareness Program", "category": "Program Management", - "techniques": ["T1566", "T1204"]}, - ] - # Build technique lookup all_techniques = {t.mitre_id: t for t in db.query(Technique).all()} + # Assign existing_controls = { existing_controls = { c.control_id: c for c in db.query(ComplianceControl) + # 
Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) + # Chain .all() call .all() } + # Assign existing_mappings = set() existing_mappings = set() + # Iterate over db.query(ComplianceControlMapping).all() for m in db.query(ComplianceControlMapping).all(): + # Call existing_mappings.add() existing_mappings.add((str(m.compliance_control_id), str(m.technique_id))) + # Assign controls_created = 0 controls_created = 0 + # Assign mappings_created = 0 mappings_created = 0 - for sample in SAMPLE_CONTROLS: + # Iterate over _NIST_SAMPLE_CONTROLS + for sample in _NIST_SAMPLE_CONTROLS: # Create or get control if sample["control_id"] in existing_controls: + # Assign control = existing_controls[sample["control_id"]] control = existing_controls[sample["control_id"]] + # Fallback: handle remaining cases else: + # Assign control = ComplianceControl( control = ComplianceControl( + # Keyword argument: framework_id framework_id=framework.id, + # Keyword argument: control_id control_id=sample["control_id"], + # Keyword argument: title title=sample["title"], + # Keyword argument: category category=sample["category"], ) + # Stage new record(s) for database insertion db.add(control) + # Flush changes to DB without committing the transaction db.flush() + # Assign existing_controls[sample["control_id"]] = control existing_controls[sample["control_id"]] = control + # Assign controls_created = 1 controls_created += 1 # Create mappings for mitre_id in sample["techniques"]: + # Assign technique = all_techniques.get(mitre_id) technique = all_techniques.get(mitre_id) + # Check: not technique if not technique: # Try with subtechnique prefix for key, tech in all_techniques.items(): + # Check: key.startswith(mitre_id) if key.startswith(mitre_id): + # Assign technique = tech technique = tech + # Exit the loop early break + # Check: not technique if not technique: + # Skip to the next loop iteration continue + # Assign key = (str(control.id), str(technique.id)) key = 
(str(control.id), str(technique.id)) + # Check: key in existing_mappings if key in existing_mappings: + # Skip to the next loop iteration continue + # Assign mapping = ComplianceControlMapping( mapping = ComplianceControlMapping( + # Keyword argument: compliance_control_id compliance_control_id=control.id, + # Keyword argument: technique_id technique_id=technique.id, ) + # Stage new record(s) for database insertion db.add(mapping) + # Call existing_mappings.add() existing_mappings.add(key) + # Assign mappings_created = 1 mappings_created += 1 + # Commit all pending changes to the database db.commit() + # Return { return { + # Literal argument value "framework": framework.name, + # Literal argument value "controls_created": controls_created, + # Literal argument value "controls_existing": len(existing_controls) - controls_created, + # Literal argument value "mappings_created": mappings_created, + # Literal argument value "mappings_skipped": 0, + # Literal argument value "total_controls": len(existing_controls), + # Literal argument value "source": "sample_data", } +# Define function import_cis_controls_v8_mappings def import_cis_controls_v8_mappings(db: Session) -> dict: """Import CIS Controls v8 with ATT&CK technique mappings. @@ -340,23 +855,43 @@ def import_cis_controls_v8_mappings(db: Session) -> dict: # ── 1. 
Create or get framework ──────────────────────────────── framework = ( db.query(ComplianceFramework) + # Chain .filter() call .filter(ComplianceFramework.name == "CIS Controls v8") + # Chain .first() call .first() ) + # Check: not framework if not framework: + # Assign framework = ComplianceFramework( framework = ComplianceFramework( + # Keyword argument: name name="CIS Controls v8", + # Keyword argument: version version="8", - description="Center for Internet Security Critical Security Controls Version 8 — " - "a prioritized set of 18 security safeguards organized by Implementation Groups (IG1, IG2, IG3).", + # Keyword argument: description + description=( + # Literal argument value + "Center for Internet Security Critical Security Controls Version 8 — " + # Literal argument value + "a prioritized set of 18 security safeguards " + # Literal argument value + "organized by Implementation Groups (IG1, IG2, IG3)." + ), + # Keyword argument: url url="https://www.cisecurity.org/controls/v8", + # Keyword argument: is_active is_active=True, ) + # Stage new record(s) for database insertion db.add(framework) + # Flush changes to DB without committing the transaction db.flush() + # Log info: "Created CIS Controls v8 framework" logger.info("Created CIS Controls v8 framework") + # Fallback: handle remaining cases else: + # Log info: "CIS Controls v8 framework already exists" logger.info("CIS Controls v8 framework already exists") # ── 2. 
Control definitions with ATT&CK mappings ─────────────── @@ -626,65 +1161,111 @@ def import_cis_controls_v8_mappings(db: Session) -> dict: # Build technique lookup all_techniques = {t.mitre_id: t for t in db.query(Technique).all()} + # Assign existing_controls = { existing_controls = { c.control_id: c for c in db.query(ComplianceControl) + # Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) + # Chain .all() call .all() } + # Assign existing_mappings = set() existing_mappings = set() + # for m in ( for m in ( db.query(ComplianceControlMapping) + # Chain .join() call .join(ComplianceControl) + # Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) + # Chain .all() call .all() ): + # Call existing_mappings.add() existing_mappings.add((str(m.compliance_control_id), str(m.technique_id))) + # Assign controls_created = 0 controls_created = 0 + # Assign mappings_created = 0 mappings_created = 0 - for item in CIS_CONTROLS: + # Iterate over _CIS_CONTROLS + for item in _CIS_CONTROLS: + # Check: item["control_id"] in existing_controls if item["control_id"] in existing_controls: + # Assign control = existing_controls[item["control_id"]] control = existing_controls[item["control_id"]] + # Fallback: handle remaining cases else: + # Assign control = ComplianceControl( control = ComplianceControl( + # Keyword argument: framework_id framework_id=framework.id, + # Keyword argument: control_id control_id=item["control_id"], + # Keyword argument: title title=item["title"], + # Keyword argument: category category=item["category"], ) + # Stage new record(s) for database insertion db.add(control) + # Flush changes to DB without committing the transaction db.flush() + # Assign existing_controls[item["control_id"]] = control existing_controls[item["control_id"]] = control + # Assign controls_created = 1 controls_created += 1 + # Iterate over item["techniques"] for mitre_id in item["techniques"]: + # Assign technique = 
all_techniques.get(mitre_id) technique = all_techniques.get(mitre_id) + # Check: not technique if not technique: + # Skip to the next loop iteration continue + # Assign key = (str(control.id), str(technique.id)) key = (str(control.id), str(technique.id)) + # Check: key in existing_mappings if key in existing_mappings: + # Skip to the next loop iteration continue + # Assign mapping = ComplianceControlMapping( mapping = ComplianceControlMapping( + # Keyword argument: compliance_control_id compliance_control_id=control.id, + # Keyword argument: technique_id technique_id=technique.id, ) + # Stage new record(s) for database insertion db.add(mapping) + # Call existing_mappings.add() existing_mappings.add(key) + # Assign mappings_created = 1 mappings_created += 1 + # Commit all pending changes to the database db.commit() + # Assign summary = { summary = { + # Literal argument value "framework": framework.name, + # Literal argument value "controls_created": controls_created, + # Literal argument value "controls_existing": len(existing_controls) - controls_created, + # Literal argument value "mappings_created": mappings_created, + # Literal argument value "total_controls": len(existing_controls), } + # Log info: f"CIS Controls v8 import complete: {summary}" logger.info(f"CIS Controls v8 import complete: {summary}") + # Return summary return summary @@ -2205,26 +2786,48 @@ def _import_curated_framework( def _get_nist_category(family_code: str) -> str: """Map NIST 800-53 family code to category name.""" + # Assign categories = { categories = { + # Literal argument value "AC": "Access Control", + # Literal argument value "AT": "Awareness and Training", + # Literal argument value "AU": "Audit and Accountability", + # Literal argument value "CA": "Assessment, Authorization, and Monitoring", + # Literal argument value "CM": "Configuration Management", + # Literal argument value "CP": "Contingency Planning", + # Literal argument value "IA": "Identification and Authentication", + # 
Literal argument value "IR": "Incident Response", + # Literal argument value "MA": "Maintenance", + # Literal argument value "MP": "Media Protection", + # Literal argument value "PE": "Physical and Environmental Protection", + # Literal argument value "PL": "Planning", + # Literal argument value "PM": "Program Management", + # Literal argument value "PS": "Personnel Security", + # Literal argument value "PT": "Personally Identifiable Information Processing and Transparency", + # Literal argument value "RA": "Risk Assessment", + # Literal argument value "SA": "System and Services Acquisition", + # Literal argument value "SC": "System and Communications Protection", + # Literal argument value "SI": "System and Information Integrity", + # Literal argument value "SR": "Supply Chain Risk Management", } + # Return categories.get(family_code, "Unknown") return categories.get(family_code, "Unknown") diff --git a/backend/app/services/compliance_service.py b/backend/app/services/compliance_service.py index 15742ec..2340c22 100644 --- a/backend/app/services/compliance_service.py +++ b/backend/app/services/compliance_service.py @@ -6,111 +6,184 @@ that the router remains a thin HTTP adapter. This module is framework-agnostic: no FastAPI imports. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import csv import csv + +# Import io import io + +# Import Any from typing from typing import Any +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError + +# Import from app.models.compliance from app.models.compliance import ( - ComplianceFramework, ComplianceControl, ComplianceControlMapping, + ComplianceFramework, ) -from app.models.technique import Technique -from app.models.test_template import TestTemplate -from app.models.threat_actor import ThreatActorTechnique -from app.services.scoring_service import calculate_technique_score +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import TestTemplate from app.models.test_template +from app.models.test_template import TestTemplate + +# Import ThreatActorTechnique from app.models.threat_actor +from app.models.threat_actor import ThreatActorTechnique + +# Import calculate_technique_score from app.services.scoring_service +from app.services.scoring_service import calculate_technique_score # ── Helpers ─────────────────────────────────────────────────────────── def _classify_control(technique_scores: list[float]) -> str: """Classify a control status based on its technique scores.""" + # Check: not technique_scores if not technique_scores: + # Return "not_evaluated" return "not_evaluated" + # Assign all_above_70 = all(s >= 70 for s in technique_scores) all_above_70 = all(s >= 70 for s in technique_scores) + # Assign any_above_30 = any(s >= 30 for s in technique_scores) any_above_30 = any(s >= 30 for s in technique_scores) + # Assign all_below_30 = all(s < 30 for s in technique_scores) all_below_30 = all(s < 30 for s in technique_scores) + # Assign all_zero = all(s == 0 for s in technique_scores) all_zero = all(s == 0 for s in technique_scores) + # Check: 
all_zero if all_zero: + # Return "not_evaluated" return "not_evaluated" + # Check: all_above_70 if all_above_70: + # Return "covered" return "covered" + # Check: all_below_30 if all_below_30: + # Return "not_covered" return "not_covered" + # Check: any_above_30 if any_above_30: + # Return "partially_covered" return "partially_covered" + # Return "not_covered" return "not_covered" +# Define function _get_control_status def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, Any]: """Compute the status and score for a single control.""" + # Assign mappings = ( mappings = ( db.query(ComplianceControlMapping) + # Chain .filter() call .filter(ComplianceControlMapping.compliance_control_id == control.id) + # Chain .all() call .all() ) + # Check: not mappings if not mappings: + # Return { return { + # Literal argument value "control_id": control.control_id, + # Literal argument value "title": control.title, "description": control.description, "category": control.category, + # Literal argument value "status": "not_evaluated", + # Literal argument value "score": 0, + # Literal argument value "techniques_count": 0, + # Literal argument value "techniques_covered": 0, + # Literal argument value "techniques": [], } + # Assign technique_ids = [m.technique_id for m in mappings] technique_ids = [m.technique_id for m in mappings] + # Assign techniques = ( techniques = ( db.query(Technique) + # Chain .filter() call .filter(Technique.id.in_(technique_ids)) + # Chain .all() call .all() ) + # Assign tech_details = [] tech_details = [] + # Assign scores = [] scores = [] + # Assign covered_count = 0 covered_count = 0 + # Iterate over techniques for tech in techniques: + # Assign result = calculate_technique_score(tech, db) result = calculate_technique_score(tech, db) + # Assign score = result["total_score"] score = result["total_score"] + # Call scores.append() scores.append(score) + # Check: score >= 50 if score >= 50: + # Assign covered_count = 1 covered_count += 
1 + # Call tech_details.append() tech_details.append({ + # Literal argument value "mitre_id": tech.mitre_id, + # Literal argument value "name": tech.name, + # Literal argument value "score": score, + # Literal argument value "status": tech.status_global.value if tech.status_global else "not_evaluated", }) # Sort techniques by score ascending (worst first for priority) tech_details.sort(key=lambda t: t["score"]) + # Assign avg_score = round(sum(scores) / len(scores), 1) if scores else 0 avg_score = round(sum(scores) / len(scores), 1) if scores else 0 + # Assign status = _classify_control(scores) status = _classify_control(scores) + # Return { return { + # Literal argument value "control_id": control.control_id, + # Literal argument value "title": control.title, "description": control.description, "category": control.category, + # Literal argument value "status": status, + # Literal argument value "score": avg_score, + # Literal argument value "techniques_count": len(techniques), + # Literal argument value "techniques_covered": covered_count, + # Literal argument value "techniques": tech_details, } @@ -120,95 +193,150 @@ def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, An def list_frameworks(db: Session) -> list[dict[str, Any]]: """List all available compliance frameworks with control counts.""" + # Assign frameworks = ( frameworks = ( db.query(ComplianceFramework) + # Chain .filter() call .filter(ComplianceFramework.is_active == True) + # Chain .all() call .all() ) + # Assign result = [] result = [] + # Iterate over frameworks for fw in frameworks: + # Assign control_count = ( control_count = ( db.query(ComplianceControl) + # Chain .filter() call .filter(ComplianceControl.framework_id == fw.id) + # Chain .count() call .count() ) + # Call result.append() result.append({ + # Literal argument value "id": str(fw.id), + # Literal argument value "name": fw.name, + # Literal argument value "version": fw.version, + # Literal argument value 
"description": fw.description, + # Literal argument value "url": fw.url, + # Literal argument value "is_active": fw.is_active, + # Literal argument value "controls_count": control_count, }) + # Return result return result +# Define function get_framework def get_framework(db: Session, framework_id: str) -> ComplianceFramework | None: """Get a framework by ID, or None if not found.""" + # Return ( return ( db.query(ComplianceFramework) + # Chain .filter() call .filter(ComplianceFramework.id == framework_id) + # Chain .first() call .first() ) +# Define function get_framework_status def get_framework_status(db: Session, framework_id: str) -> dict[str, Any]: """Get compliance status for each control in a framework. Raises EntityNotFoundError if the framework does not exist. """ + # Assign framework = get_framework(db, framework_id) framework = get_framework(db, framework_id) + # Check: not framework if not framework: + # Raise EntityNotFoundError raise EntityNotFoundError("Framework", framework_id) + # Assign controls = ( controls = ( db.query(ComplianceControl) + # Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) + # Chain .order_by() call .order_by(ComplianceControl.control_id) + # Chain .all() call .all() ) + # Assign control_statuses = [] control_statuses = [] + # Assign summary = { summary = { + # Literal argument value "total_controls": len(controls), + # Literal argument value "covered": 0, + # Literal argument value "partially_covered": 0, + # Literal argument value "not_covered": 0, + # Literal argument value "not_evaluated": 0, } + # Iterate over controls for control in controls: + # Assign status_data = _get_control_status(control, db) status_data = _get_control_status(control, db) + # Call control_statuses.append() control_statuses.append(status_data) + # Assign status = status_data["status"] status = status_data["status"] + # Check: status in summary if status in summary: + # Assign summary[status] = 1 summary[status] += 1 # 
Compliance percentage: (covered + partially_covered*0.5) / total * 100 total = summary["total_controls"] + # Check: total > 0 if total > 0: + # Assign compliance_pct = round( compliance_pct = round( (summary["covered"] + summary["partially_covered"] * 0.5) / total * 100, + # Literal argument value 1, ) + # Fallback: handle remaining cases else: + # Assign compliance_pct = 0 compliance_pct = 0 + # Assign summary["compliance_percentage"] = compliance_pct summary["compliance_percentage"] = compliance_pct + # Return { return { + # Literal argument value "framework": {"id": str(framework.id), "name": framework.name}, + # Literal argument value "summary": summary, + # Literal argument value "controls": control_statuses, } +# Define function build_framework_report_csv def build_framework_report_csv( + # Entry: db db: Session, + # Entry: framework_id framework_id: str, ) -> tuple[bytes, str]: """Build the compliance report CSV content and filename. @@ -217,33 +345,55 @@ def build_framework_report_csv( Raises EntityNotFoundError if the framework does not exist. 
""" + # Assign framework = get_framework(db, framework_id) framework = get_framework(db, framework_id) + # Check: not framework if not framework: + # Raise EntityNotFoundError raise EntityNotFoundError("Framework", framework_id) + # Assign controls = ( controls = ( db.query(ComplianceControl) + # Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) + # Chain .order_by() call .order_by(ComplianceControl.control_id) + # Chain .all() call .all() ) + # Assign output = io.StringIO() output = io.StringIO() + # Assign writer = csv.writer(output) writer = csv.writer(output) + # Call writer.writerow() writer.writerow([ + # Literal argument value "control_id", + # Literal argument value "title", + # Literal argument value "category", + # Literal argument value "status", + # Literal argument value "score", + # Literal argument value "techniques_total", + # Literal argument value "techniques_covered", + # Literal argument value "technique_ids", ]) + # Iterate over controls for control in controls: + # Assign status_data = _get_control_status(control, db) status_data = _get_control_status(control, db) + # Assign technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"]) technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"]) + # Call writer.writerow() writer.writerow([ status_data["control_id"], status_data["title"], @@ -255,75 +405,116 @@ def build_framework_report_csv( technique_ids, ]) + # Call output.seek() output.seek(0) + # Assign filename = f"compliance_{framework.name.replace(' ', '_')}.csv" filename = f"compliance_{framework.name.replace(' ', '_')}.csv" + # Return output.getvalue().encode("utf-8"), filename return output.getvalue().encode("utf-8"), filename +# Define function get_framework_gaps def get_framework_gaps(db: Session, framework_id: str) -> dict[str, Any]: """Get controls with techniques that are not adequately covered. Raises EntityNotFoundError if the framework does not exist. 
""" + # Assign framework = get_framework(db, framework_id) framework = get_framework(db, framework_id) + # Check: not framework if not framework: + # Raise EntityNotFoundError raise EntityNotFoundError("Framework", framework_id) + # Assign controls = ( controls = ( db.query(ComplianceControl) + # Chain .filter() call .filter(ComplianceControl.framework_id == framework.id) + # Chain .order_by() call .order_by(ComplianceControl.control_id) + # Chain .all() call .all() ) + # Assign gaps = [] gaps = [] + # Iterate over controls for control in controls: + # Assign status_data = _get_control_status(control, db) status_data = _get_control_status(control, db) + # Check: status_data["status"] in ("not_covered", "partially_covered") if status_data["status"] in ("not_covered", "partially_covered"): # Find uncovered techniques uncovered_techniques = [] + # Iterate over status_data["techniques"] for tech_info in status_data["techniques"]: + # Check: tech_info["score"] < 70 if tech_info["score"] < 70: # Count available templates template_count = ( db.query(TestTemplate) + # Chain .filter() call .filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"]) + # Chain .count() call .count() ) # Count threat actors using this technique technique = ( db.query(Technique) + # Chain .filter() call .filter(Technique.mitre_id == tech_info["mitre_id"]) + # Chain .first() call .first() ) + # Assign actor_count = 0 actor_count = 0 + # Check: technique if technique: + # Assign actor_count = ( actor_count = ( db.query(ThreatActorTechnique) + # Chain .filter() call .filter(ThreatActorTechnique.technique_id == technique.id) + # Chain .count() call .count() ) + # Call uncovered_techniques.append() uncovered_techniques.append({ **tech_info, + # Literal argument value "templates_available": template_count, + # Literal argument value "threat_actors_using": actor_count, }) + # Check: uncovered_techniques if uncovered_techniques: + # Call gaps.append() gaps.append({ + # Literal argument value 
"control_id": status_data["control_id"], + # Literal argument value "title": status_data["title"], + # Literal argument value "category": status_data["category"], + # Literal argument value "status": status_data["status"], + # Literal argument value "score": status_data["score"], + # Literal argument value "uncovered_techniques": uncovered_techniques, }) + # Return { return { + # Literal argument value "framework": {"id": str(framework.id), "name": framework.name}, + # Literal argument value "total_gaps": len(gaps), + # Literal argument value "gaps": gaps, } diff --git a/backend/app/services/coverage_report_service.py b/backend/app/services/coverage_report_service.py index c33f1df..64cb6b8 100644 --- a/backend/app/services/coverage_report_service.py +++ b/backend/app/services/coverage_report_service.py @@ -7,120 +7,202 @@ technique/test-count pattern by using a single grouped query. This module is framework-agnostic: no FastAPI imports. """ +# Enable future language features for compatibility from __future__ import annotations +# Import datetime from datetime from datetime import datetime +# Import func from sqlalchemy from sqlalchemy import func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import Test from app.models.test from app.models.test import Test + +# Import escape_like from app.utils from app.utils import escape_like +# Define function _technique_test_counts def _technique_test_counts( + # Entry: db db: Session, + # Entry: technique_ids technique_ids: list, ) -> dict: """Return ``{technique_id: {state_str: count}}`` in a single query.""" + # Check: not technique_ids if not technique_ids: + # Return {} return {} + # Assign rows = ( rows = ( db.query(Test.technique_id, Test.state, func.count(Test.id)) + # Chain .filter() call .filter(Test.technique_id.in_(technique_ids)) + # Chain .group_by() call .group_by(Test.technique_id, 
Test.state) + # Chain .all() call .all() ) + # Assign result = {} result: dict = {} + # Iterate over rows for tid, state, count in rows: + # Call result.setdefault() result.setdefault(tid, {})[str(state)] = count + # Return result return result +# Define function build_coverage_summary def build_coverage_summary( + # Entry: db db: Session, *, + # Entry: tactic tactic: str | None = None, + # Entry: platform platform: str | None = None, ) -> dict: """Build the full coverage summary report as a dict.""" + # Assign query = db.query(Technique) query = db.query(Technique) + # Check: tactic if tactic: + # Assign query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%")) query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%")) + # Assign techniques = query.order_by(Technique.mitre_id).all() techniques = query.order_by(Technique.mitre_id).all() + # Assign counts_map = _technique_test_counts(db, [t.id for t in techniques]) counts_map = _technique_test_counts(db, [t.id for t in techniques]) + # Assign rows = [] rows = [] + # Iterate over techniques for t in techniques: + # Check: platform and platform.lower() not in [p.lower() for p in (t.platfor... 
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]: + # Skip to the next loop iteration continue + # Assign counts = counts_map.get(t.id, {}) counts = counts_map.get(t.id, {}) + # Call rows.append() rows.append({ + # Literal argument value "mitre_id": t.mitre_id, + # Literal argument value "name": t.name, + # Literal argument value "tactic": t.tactic, + # Literal argument value "platforms": t.platforms, + # Literal argument value "status_global": t.status_global, + # Literal argument value "total_tests": sum(counts.values()), + # Literal argument value "tests_by_state": counts, }) + # Assign total = len(rows) total = len(rows) + # Assign validated = sum(1 for r in rows if r["status_global"] == "validated") validated = sum(1 for r in rows if r["status_global"] == "validated") + # Assign partial = sum(1 for r in rows if r["status_global"] == "partial") partial = sum(1 for r in rows if r["status_global"] == "partial") + # Assign not_covered = sum(1 for r in rows if r["status_global"] == "not_covered") not_covered = sum(1 for r in rows if r["status_global"] == "not_covered") + # Assign in_progress = sum(1 for r in rows if r["status_global"] == "in_progress") in_progress = sum(1 for r in rows if r["status_global"] == "in_progress") + # Assign not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated") not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated") + # Return { return { + # Literal argument value "generated_at": datetime.utcnow().isoformat(), + # Literal argument value "summary": { + # Literal argument value "total_techniques": total, + # Literal argument value "validated": validated, + # Literal argument value "partial": partial, + # Literal argument value "not_covered": not_covered, + # Literal argument value "in_progress": in_progress, + # Literal argument value "not_evaluated": not_evaluated, + # Literal argument value "coverage_percentage": round((validated / total * 100) if total > 0 
else 0, 1), }, + # Literal argument value "techniques": rows, } +# Define function build_coverage_csv_rows def build_coverage_csv_rows( + # Entry: db db: Session, *, + # Entry: tactic tactic: str | None = None, + # Entry: platform platform: str | None = None, ) -> list[list]: """Build rows for a CSV coverage export (header + data).""" + # Assign query = db.query(Technique) query = db.query(Technique) + # Check: tactic if tactic: + # Assign query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%")) query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%")) + # Assign techniques = query.order_by(Technique.mitre_id).all() techniques = query.order_by(Technique.mitre_id).all() + # Assign counts_map = _technique_test_counts(db, [t.id for t in techniques]) counts_map = _technique_test_counts(db, [t.id for t in techniques]) + # Assign header = [ header = [ + # Literal argument value "MITRE ID", "Name", "Tactic", "Platforms", "Status", + # Literal argument value "Total Tests", "Validated", "In Progress", "Not Covered", ] + # Assign rows = [header] rows = [header] + # Assign in_progress_states = {"draft", "red_executing", "blue_evaluating", "in_review"} in_progress_states = {"draft", "red_executing", "blue_evaluating", "in_review"} + # Iterate over techniques for t in techniques: + # Check: platform and platform.lower() not in [p.lower() for p in (t.platfor... 
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]: + # Skip to the next loop iteration continue + # Assign counts = counts_map.get(t.id, {}) counts = counts_map.get(t.id, {}) + # Call rows.append() rows.append([ t.mitre_id, t.name, t.tactic, + # Literal argument value ", ".join(t.platforms or []), t.status_global, sum(counts.values()), @@ -129,65 +211,111 @@ def build_coverage_csv_rows( counts.get("rejected", 0), ]) + # Return rows return rows +# Define function build_test_results_report def build_test_results_report( + # Entry: db db: Session, *, + # Entry: state state: str | None = None, + # Entry: date_from date_from: str | None = None, + # Entry: date_to date_to: str | None = None, ) -> dict: """Build a test results report with optional filters.""" + # Assign query = db.query(Test) query = db.query(Test) + # Check: state if state: + # Assign query = query.filter(Test.state == state) query = query.filter(Test.state == state) + # Check: date_from if date_from: + # Attempt the following; catch errors below try: + # Assign query = query.filter(Test.created_at >= datetime.fromisoformat(date_from)) query = query.filter(Test.created_at >= datetime.fromisoformat(date_from)) + # Handle ValueError except ValueError: + # Intentional no-op placeholder pass + # Check: date_to if date_to: + # Attempt the following; catch errors below try: + # Assign query = query.filter(Test.created_at <= datetime.fromisoformat(date_to)) query = query.filter(Test.created_at <= datetime.fromisoformat(date_to)) + # Handle ValueError except ValueError: + # Intentional no-op placeholder pass + # Assign tests = query.order_by(Test.created_at.desc()).all() tests = query.order_by(Test.created_at.desc()).all() + # Assign by_state = {} by_state: dict[str, int] = {} + # Assign by_result = {} by_result: dict[str, int] = {} + # Iterate over tests for t in tests: + # Assign s = t.state.value if hasattr(t.state, "value") else str(t.state) s = t.state.value if 
hasattr(t.state, "value") else str(t.state) + # Assign by_state[s] = by_state.get(s, 0) + 1 by_state[s] = by_state.get(s, 0) + 1 + # Check: t.detection_result if t.detection_result: + # Assign r = t.detection_result.value if hasattr(t.detection_result, "value") el... r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result) + # Assign by_result[r] = by_result.get(r, 0) + 1 by_result[r] = by_result.get(r, 0) + 1 + # Return { return { + # Literal argument value "generated_at": datetime.utcnow().isoformat(), + # Literal argument value "filters": {"state": state, "date_from": date_from, "date_to": date_to}, + # Literal argument value "summary": { + # Literal argument value "total_tests": len(tests), + # Literal argument value "by_state": by_state, + # Literal argument value "by_detection_result": by_result, }, + # Literal argument value "tests": [ { + # Literal argument value "id": str(t.id), + # Literal argument value "name": t.name, + # Literal argument value "technique_id": str(t.technique_id), + # Literal argument value "state": t.state.value if hasattr(t.state, "value") else str(t.state), + # Literal argument value "platform": t.platform, + # Literal argument value "attack_success": t.attack_success, + # Literal argument value "detection_result": ( t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value") else str(t.detection_result) if t.detection_result else None ), + # Literal argument value "red_validation_status": t.red_validation_status, + # Literal argument value "blue_validation_status": t.blue_validation_status, + # Literal argument value "created_at": t.created_at.isoformat() if t.created_at else None, } for t in tests @@ -195,38 +323,62 @@ def build_test_results_report( } +# Define function build_remediation_status_report def build_remediation_status_report( + # Entry: db db: Session, *, + # Entry: status status: str | None = None, ) -> dict: """Build a remediation status report.""" + 
# Assign query = db.query(Test).filter(Test.remediation_steps.isnot(None)) query = db.query(Test).filter(Test.remediation_steps.isnot(None)) + # Check: status if status: + # Assign query = query.filter(Test.remediation_status == status) query = query.filter(Test.remediation_status == status) + # Assign tests = query.order_by(Test.created_at.desc()).all() tests = query.order_by(Test.created_at.desc()).all() + # Assign by_status = {} by_status: dict[str, int] = {} + # Iterate over tests for t in tests: + # Assign s = t.remediation_status or "unset" s = t.remediation_status or "unset" + # Assign by_status[s] = by_status.get(s, 0) + 1 by_status[s] = by_status.get(s, 0) + 1 + # Return { return { + # Literal argument value "generated_at": datetime.utcnow().isoformat(), + # Literal argument value "summary": { + # Literal argument value "total_with_remediation": len(tests), + # Literal argument value "by_status": by_status, }, + # Literal argument value "tests": [ { + # Literal argument value "id": str(t.id), + # Literal argument value "name": t.name, + # Literal argument value "technique_id": str(t.technique_id), + # Literal argument value "state": t.state.value if hasattr(t.state, "value") else str(t.state), + # Literal argument value "remediation_status": t.remediation_status, + # Literal argument value "remediation_steps": t.remediation_steps, + # Literal argument value "remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None, } for t in tests diff --git a/backend/app/services/d3fend_import_service.py b/backend/app/services/d3fend_import_service.py index d0f2f91..ed831c0 100644 --- a/backend/app/services/d3fend_import_service.py +++ b/backend/app/services/d3fend_import_service.py @@ -1,148 +1,270 @@ -"""D3FEND import service — fetches MITRE D3FEND data and creates -DefensiveTechnique records plus ATT&CK → D3FEND mappings. 
+"""D3FEND import service — fetches MITRE D3FEND data and creates DefensiveTechnique records plus ATT&CK → D3FEND mappings. Uses the D3FEND public API: - https://d3fend.mitre.org/api/technique/api-all.json (all defensive techniques) - https://d3fend.mitre.org/api/offensive-technique/{attack_id}.json (mappings per ATT&CK technique) """ +# Import logging import logging -import uuid + +# Import Any from typing from typing import Any +# Import UUID from uuid +from uuid import UUID + +# Import httpx import httpx + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session -from app.models.technique import Technique +# Import DefensiveTechnique, DefensiveTechniqueMapping from app.models.defensive_technique from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign D3FEND_TACTIC_URL = "https://d3fend.mitre.org/api/tactic/d3f:{tactic}.json" D3FEND_TACTIC_URL = "https://d3fend.mitre.org/api/tactic/d3f:{tactic}.json" +# Assign D3FEND_MAPPING_URL = "https://d3fend.mitre.org/api/offensive-technique/{attack_id}.json" D3FEND_MAPPING_URL = "https://d3fend.mitre.org/api/offensive-technique/{attack_id}.json" +# Assign D3FEND_BASE_URL = "https://d3fend.mitre.org/technique/d3f:{iri}" D3FEND_BASE_URL = "https://d3fend.mitre.org/technique/d3f:{iri}" +# Assign D3FEND_TACTICS = ["Detect", "Harden", "Isolate", "Deceive", "Evict", "Model"] D3FEND_TACTICS = ["Detect", "Harden", "Isolate", "Deceive", "Evict", "Model"] # ── Import all D3FEND techniques ───────────────────────────────────── -def _to_str(v: Any) -> str: - """Coerce an RDF value (str, dict with @value, or list) to a plain string.""" +def _to_str(v: Any) -> str: # noqa: ANN401 + """Coerce an RDF value (str, dict with @value, or list) to a plain string. 
+ + Args: + v (Any): RDF node value — may be a plain string, a dict containing + a ``@value`` key, or a list of such values. + + Returns: + str: Plain string representation; ``"; "``-joined for list inputs. + """ + # Check: isinstance(v, dict) if isinstance(v, dict): + # Return v.get("@value", str(v)) return v.get("@value", str(v)) + # Check: isinstance(v, list) if isinstance(v, list): + # Return "; ".join(_to_str(x) for x in v) return "; ".join(_to_str(x) for x in v) + # Return str(v) if v else "" return str(v) if v else "" +# Define function _fetch_techniques_from_tactic_apis def _fetch_techniques_from_tactic_apis() -> list[dict[str, Any]]: """Fetch all defensive techniques via D3FEND tactic APIs. Uses ``/api/tactic/d3f:{tactic}.json`` which is reliable and returns full metadata including the ontology IRI for each technique. + + Returns: + list[dict[str, Any]]: Deduplicated list of technique dicts, each + containing ``d3fend_id``, ``iri``, ``name``, ``description``, + and ``tactic``. """ + # Assign all_techniques = [] all_techniques: list[dict[str, Any]] = [] + # Assign seen = set() seen: set[str] = set() + # Open context manager with httpx.Client(timeout=60.0) as client: + # Iterate over D3FEND_TACTICS for tactic in D3FEND_TACTICS: + # Assign url = D3FEND_TACTIC_URL.format(tactic=tactic) url = D3FEND_TACTIC_URL.format(tactic=tactic) + # Attempt the following; catch errors below try: + # Assign resp = client.get(url) resp = client.get(url) + # Call resp.raise_for_status() resp.raise_for_status() + # Assign data = resp.json() data = resp.json() + # Handle Exception except Exception as e: + # Log warning: "Failed to fetch D3FEND tactic %s: %s", tactic, e logger.warning("Failed to fetch D3FEND tactic %s: %s", tactic, e) + # Skip to the next loop iteration continue + # Assign graph = data.get("techniques", {}).get("@graph", []) graph = data.get("techniques", {}).get("@graph", []) + # Iterate over graph for node in graph: + # Assign nid = node.get("@id", "") nid = 
node.get("@id", "") + # Assign d3id = _to_str(node.get("d3f:d3fend-id", "")) d3id = _to_str(node.get("d3f:d3fend-id", "")) + # Assign label = _to_str(node.get("rdfs:label", "")) label = _to_str(node.get("rdfs:label", "")) + # Assign defn = _to_str(node.get("d3f:definition", "")) defn = _to_str(node.get("d3f:definition", "")) + # Check: not defn if not defn: + # Assign defn = _to_str(node.get("rdfs:comment", "")) defn = _to_str(node.get("rdfs:comment", "")) + # Assign iri = nid.replace("d3f:", "") if nid.startswith("d3f:") else nid iri = nid.replace("d3f:", "") if nid.startswith("d3f:") else nid + # Check: d3id and label and d3id not in seen if d3id and label and d3id not in seen: + # Call seen.add() seen.add(d3id) + # Call all_techniques.append() all_techniques.append({ + # Literal argument value "d3fend_id": d3id, + # Literal argument value "iri": iri, + # Literal argument value "name": label, + # Literal argument value "description": defn[:500] if defn else None, + # Literal argument value "tactic": tactic, }) + # Log info: "D3FEND tactic %s: %d techniques", tactic, len(gra logger.info("D3FEND tactic %s: %d techniques", tactic, len(graph)) + # Return all_techniques return all_techniques +# Define function _upsert_techniques def _upsert_techniques(db: Session, techniques: list[dict[str, Any]]) -> dict[str, int]: - """Upsert a list of technique dicts into the DefensiveTechnique table.""" + """Upsert a list of technique dicts into the DefensiveTechnique table. + + Args: + db (Session): Active SQLAlchemy database session. + techniques (list[dict[str, Any]]): List of technique data dicts, each + containing ``d3fend_id``, ``name``, and optionally ``description``, + ``tactic``, and ``iri``. + + Returns: + dict[str, int]: Contains ``created``, ``updated``, and ``total`` + counts after the upsert. 
+ """ + # Assign created = 0 created = 0 + # Assign updated = 0 updated = 0 + # Iterate over techniques for tech_data in techniques: + # Assign existing = ( existing = ( db.query(DefensiveTechnique) + # Chain .filter() call .filter(DefensiveTechnique.d3fend_id == tech_data["d3fend_id"]) + # Chain .first() call .first() ) + # Assign iri = tech_data.get("iri") or tech_data["name"].replace(" ", "") iri = tech_data.get("iri") or tech_data["name"].replace(" ", "") + # Assign d3fend_url = D3FEND_BASE_URL.format(iri=iri) d3fend_url = D3FEND_BASE_URL.format(iri=iri) + # Check: existing if existing: + # Assign existing.name = tech_data["name"] existing.name = tech_data["name"] + # Assign existing.description = tech_data.get("description") existing.description = tech_data.get("description") + # Assign existing.tactic = tech_data.get("tactic") existing.tactic = tech_data.get("tactic") + # Assign existing.d3fend_url = d3fend_url existing.d3fend_url = d3fend_url + # Assign updated = 1 updated += 1 + # Fallback: handle remaining cases else: + # Assign new_tech = DefensiveTechnique( new_tech = DefensiveTechnique( + # Keyword argument: d3fend_id d3fend_id=tech_data["d3fend_id"], + # Keyword argument: name name=tech_data["name"], + # Keyword argument: description description=tech_data.get("description"), + # Keyword argument: tactic tactic=tech_data.get("tactic"), + # Keyword argument: d3fend_url d3fend_url=d3fend_url, ) + # Stage new record(s) for database insertion db.add(new_tech) + # Assign created = 1 created += 1 + # Commit all pending changes to the database db.commit() + # Assign total = db.query(DefensiveTechnique).count() total = db.query(DefensiveTechnique).count() + # Return {"created": created, "updated": updated, "total": total} return {"created": created, "updated": updated, "total": total} +# Define function import_d3fend_techniques def import_d3fend_techniques(db: Session) -> dict[str, int]: """Fetch all D3FEND defensive techniques and upsert into DB. 
Uses the tactic-level APIs which are reliable and provide full metadata including ontology IRIs for correct URL generation. - Returns a dict with counts: {created, updated, total}. + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict[str, int]: Contains ``created``, ``updated``, and ``total`` + counts; falls back to curated list when the API returns fewer + than 50 techniques. """ + # Log info: "Fetching D3FEND techniques from tactic APIs" logger.info("Fetching D3FEND techniques from tactic APIs") + # Attempt the following; catch errors below try: + # Assign techniques = _fetch_techniques_from_tactic_apis() techniques = _fetch_techniques_from_tactic_apis() + # Handle Exception except Exception as e: + # Log error: "Failed to fetch D3FEND techniques from tactic API logger.error("Failed to fetch D3FEND techniques from tactic APIs: %s", e) + # Assign techniques = [] techniques = [] + # Check: len(techniques) >= 50 if len(techniques) >= 50: + # Log info: "Fetched %d D3FEND techniques from tactic APIs", l logger.info("Fetched %d D3FEND techniques from tactic APIs", len(techniques)) + # Assign result = _upsert_techniques(db, techniques) result = _upsert_techniques(db, techniques) + # Log info: "D3FEND import done: %d created, %d updated, %d to logger.info("D3FEND import done: %d created, %d updated, %d total", result["created"], result["updated"], result["total"]) + # Return result return result # Fallback: use a curated list of well-known D3FEND techniques logger.warning("Tactic APIs returned too few techniques (%d), using fallback", len(techniques)) + # Return _import_d3fend_fallback(db) return _import_d3fend_fallback(db) @@ -228,9 +350,20 @@ _FALLBACK_TECHNIQUES: list[dict[str, str | None]] = [ ] +# Define function _import_d3fend_fallback def _import_d3fend_fallback(db: Session) -> dict[str, int]: - """Import curated D3FEND techniques when the tactic APIs are unreachable.""" + """Import curated D3FEND techniques when the tactic APIs are 
unreachable. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict[str, int]: Contains ``created``, ``updated``, and ``total`` + counts from upserting the fallback technique list. + """ + # Log info: "Using fallback D3FEND technique list (%d entries logger.info("Using fallback D3FEND technique list (%d entries)", len(_FALLBACK_TECHNIQUES)) + # Return _upsert_techniques(db, _FALLBACK_TECHNIQUES) # type: ignore[arg-type] return _upsert_techniques(db, _FALLBACK_TECHNIQUES) # type: ignore[arg-type] @@ -239,217 +372,399 @@ def _import_d3fend_fallback(db: Session) -> dict[str, int]: # Curated ATT&CK → D3FEND mapping for common techniques _ATTACK_TO_D3FEND: dict[str, list[str]] = { + # Literal argument value "T1059": ["D3-PSA", "D3-SCA", "D3-PA", "D3-EAW", "D3-EDL", "D3-PLA"], + # Literal argument value "T1059.001": ["D3-PSA", "D3-SCA", "D3-PA", "D3-EAW", "D3-EDL"], + # Literal argument value "T1059.003": ["D3-PSA", "D3-SCA", "D3-PA", "D3-EAW"], + # Literal argument value "T1059.005": ["D3-PSA", "D3-SCA", "D3-EAW"], + # Literal argument value "T1059.007": ["D3-PSA", "D3-SCA", "D3-EAW"], + # Literal argument value "T1055": ["D3-PA", "D3-PSA", "D3-HBPI", "D3-PMAD", "D3-PLA"], + # Literal argument value "T1055.001": ["D3-PA", "D3-PMAD", "D3-HBPI"], + # Literal argument value "T1055.002": ["D3-PA", "D3-PMAD", "D3-HBPI"], + # Literal argument value "T1003": ["D3-CH", "D3-CR", "D3-MFA", "D3-PMAD"], + # Literal argument value "T1003.001": ["D3-CH", "D3-CR", "D3-PMAD"], + # Literal argument value "T1078": ["D3-MFA", "D3-UBA", "D3-UGLPA", "D3-CH"], + # Literal argument value "T1078.001": ["D3-MFA", "D3-UBA", "D3-CH"], + # Literal argument value "T1566": ["D3-EAL", "D3-FA", "D3-FH", "D3-UA", "D3-EHR"], + # Literal argument value "T1566.001": ["D3-EAL", "D3-FA", "D3-FH", "D3-EHR"], + # Literal argument value "T1566.002": ["D3-UA", "D3-EAL", "D3-EHR"], + # Literal argument value "T1071": ["D3-AL", "D3-NTA", "D3-PM", "D3-CT"], + # Literal argument value 
"T1071.001": ["D3-AL", "D3-NTA", "D3-PM"], + # Literal argument value "T1053": ["D3-PSA", "D3-PA", "D3-SCHE", "D3-SSA"], + # Literal argument value "T1053.005": ["D3-PSA", "D3-SCHE", "D3-SSA"], + # Literal argument value "T1543": ["D3-SMRA", "D3-SSA", "D3-SBAN"], + # Literal argument value "T1543.003": ["D3-SMRA", "D3-SSA", "D3-SBAN"], + # Literal argument value "T1547": ["D3-SICA", "D3-SSA", "D3-RRID"], + # Literal argument value "T1547.001": ["D3-SICA", "D3-SSA", "D3-RRID"], + # Literal argument value "T1021": ["D3-RTSD", "D3-RPA", "D3-NTA", "D3-MFA"], + # Literal argument value "T1021.001": ["D3-RTSD", "D3-NTA", "D3-MFA"], + # Literal argument value "T1021.002": ["D3-RTSD", "D3-NTA", "D3-NI"], + # Literal argument value "T1560": ["D3-FA", "D3-FCA", "D3-ORA"], + # Literal argument value "T1560.001": ["D3-FA", "D3-FCA"], + # Literal argument value "T1048": ["D3-ORA", "D3-NTA", "D3-OTF"], + # Literal argument value "T1048.003": ["D3-ORA", "D3-NTA", "D3-OTF"], + # Literal argument value "T1105": ["D3-IRA", "D3-NTA", "D3-FA", "D3-FH"], + # Literal argument value "T1036": ["D3-FCA", "D3-FH", "D3-FA", "D3-SWI"], + # Literal argument value "T1036.005": ["D3-FCA", "D3-FH", "D3-FA"], + # Literal argument value "T1140": ["D3-FA", "D3-DA", "D3-SCA"], + # Literal argument value "T1070": ["D3-SSA", "D3-LOGA", "D3-SYSM"], + # Literal argument value "T1070.004": ["D3-SSA", "D3-FAPA"], + # Literal argument value "T1562": ["D3-SSA", "D3-SYSM", "D3-SMRA"], + # Literal argument value "T1562.001": ["D3-SSA", "D3-SYSM", "D3-SMRA"], + # Literal argument value "T1027": ["D3-DA", "D3-FA", "D3-RE"], + # Literal argument value "T1027.002": ["D3-DA", "D3-FA"], + # Literal argument value "T1110": ["D3-MFA", "D3-UBA", "D3-CH"], + # Literal argument value "T1110.001": ["D3-MFA", "D3-UBA", "D3-CH"], + # Literal argument value "T1082": ["D3-PSA", "D3-PA", "D3-SYSM"], + # Literal argument value "T1083": ["D3-FAPA", "D3-PA"], + # Literal argument value "T1497": ["D3-DA", "D3-SE"], + # Literal 
argument value "T1218": ["D3-PSA", "D3-PLA", "D3-EAW"], + # Literal argument value "T1218.011": ["D3-PSA", "D3-PLA", "D3-EAW"], + # Literal argument value "T1569": ["D3-SMRA", "D3-PSA", "D3-PA"], + # Literal argument value "T1569.002": ["D3-SMRA", "D3-PSA"], + # Literal argument value "T1012": ["D3-RRID", "D3-PA"], + # Literal argument value "T1112": ["D3-RRID", "D3-PA", "D3-REGG"], + # Literal argument value "T1057": ["D3-PA", "D3-PSA"], + # Literal argument value "T1518": ["D3-SYSM", "D3-PA"], + # Literal argument value "T1049": ["D3-NTA", "D3-PA"], + # Literal argument value "T1016": ["D3-NTA", "D3-PA", "D3-SYSM"], + # Literal argument value "T1033": ["D3-PA", "D3-UBA"], + # Literal argument value "T1087": ["D3-UBA", "D3-PA", "D3-SSA"], + # Literal argument value "T1087.001": ["D3-UBA", "D3-PA"], + # Literal argument value "T1087.002": ["D3-UBA", "D3-PA"], + # Literal argument value "T1018": ["D3-NTA", "D3-PA"], + # Literal argument value "T1047": ["D3-RPA", "D3-PSA", "D3-PA"], + # Literal argument value "T1190": ["D3-ISVA", "D3-NTA", "D3-AL"], + # Literal argument value "T1133": ["D3-NTA", "D3-MFA", "D3-RTSD"], + # Literal argument value "T1486": ["D3-BKUP", "D3-FBKP", "D3-ANTR", "D3-FA"], + # Literal argument value "T1490": ["D3-BKUP", "D3-FBKP", "D3-SSA"], + # Literal argument value "T1489": ["D3-SMRA", "D3-SSA"], + # Literal argument value "T1098": ["D3-UBA", "D3-SSA", "D3-PGOV"], + # Literal argument value "T1136": ["D3-UBA", "D3-SSA", "D3-UACM"], + # Literal argument value "T1136.001": ["D3-UBA", "D3-SSA", "D3-UACM"], + # Literal argument value "T1068": ["D3-SU", "D3-VULM", "D3-HBPI"], + # Literal argument value "T1548": ["D3-PSEP", "D3-PSA", "D3-PA"], + # Literal argument value "T1548.002": ["D3-PSEP", "D3-PSA"], + # Literal argument value "T1134": ["D3-PA", "D3-PSA", "D3-PSEP"], + # Literal argument value "T1134.001": ["D3-PA", "D3-PSA"], + # Literal argument value "T1574": ["D3-SWI", "D3-FCA", "D3-PLA"], + # Literal argument value "T1574.001": 
["D3-SWI", "D3-FCA"], + # Literal argument value "T1204": ["D3-EAL", "D3-FA", "D3-UA"], + # Literal argument value "T1204.001": ["D3-UA", "D3-EAL"], + # Literal argument value "T1204.002": ["D3-FA", "D3-EAL", "D3-DA"], + # Literal argument value "T1071.004": ["D3-DPM", "D3-DNSSM", "D3-NTA"], + # Literal argument value "T1571": ["D3-NTA", "D3-PM", "D3-AL"], + # Literal argument value "T1572": ["D3-NTA", "D3-AL", "D3-PM"], + # Literal argument value "T1041": ["D3-ORA", "D3-NTA"], + # Literal argument value "T1005": ["D3-FAPA", "D3-PA"], + # Literal argument value "T1113": ["D3-PA", "D3-PSA"], + # Literal argument value "T1056": ["D3-PA", "D3-PSA", "D3-HBPI"], + # Literal argument value "T1056.001": ["D3-PA", "D3-PSA"], + # Literal argument value "T1560.003": ["D3-FA", "D3-ORA"], + # Literal argument value "T1583": ["D3-IPMR", "D3-DNSRA"], + # Literal argument value "T1584": ["D3-IPMR", "D3-DNSRA"], + # Literal argument value "T1595": ["D3-IRA", "D3-NTA"], + # Literal argument value "T1589": ["D3-UBA", "D3-THRT"], + # Literal argument value "T1590": ["D3-NTA", "D3-THRT"], + # Literal argument value "T1591": ["D3-THRT"], + # Literal argument value "T1592": ["D3-THRT"], } +# Define function import_d3fend_mappings def import_d3fend_mappings(db: Session) -> dict[str, int]: """Create ATT&CK → D3FEND mappings. First tries the D3FEND API for each ATT&CK technique in the DB, then falls back to the curated mapping for any remaining techniques. - Returns a dict with counts: {created, skipped, total}. + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict[str, int]: Contains ``created``, ``skipped``, and ``total`` + mapping counts. 
""" + # Assign created = 0 created = 0 + # Assign skipped = 0 skipped = 0 # Get all ATT&CK techniques from the DB attack_techniques = db.query(Technique).all() + # Assign technique_map = {t.mitre_id: t for t in attack_techniques} technique_map = {t.mitre_id: t for t in attack_techniques} # Get all defensive techniques defensive_techniques = db.query(DefensiveTechnique).all() + # Assign d3fend_map = {dt.d3fend_id: dt for dt in defensive_techniques} d3fend_map = {dt.d3fend_id: dt for dt in defensive_techniques} + # Check: not d3fend_map if not d3fend_map: + # Log warning: "No D3FEND techniques in DB — run import_d3fend_te logger.warning("No D3FEND techniques in DB — run import_d3fend_techniques first") + # Return {"created": 0, "skipped": 0, "total": 0} return {"created": 0, "skipped": 0, "total": 0} # Use the curated mapping for now (API per-technique is very slow for 700+ techniques) for mitre_id, d3fend_ids in _ATTACK_TO_D3FEND.items(): + # Assign attack_tech = technique_map.get(mitre_id) attack_tech = technique_map.get(mitre_id) + # Check: not attack_tech if not attack_tech: + # Skip to the next loop iteration continue + # Iterate over d3fend_ids for d3fend_id in d3fend_ids: + # Assign def_tech = d3fend_map.get(d3fend_id) def_tech = d3fend_map.get(d3fend_id) + # Check: not def_tech if not def_tech: + # Skip to the next loop iteration continue # Check if mapping already exists existing = ( db.query(DefensiveTechniqueMapping) + # Chain .filter() call .filter( DefensiveTechniqueMapping.attack_technique_id == attack_tech.id, DefensiveTechniqueMapping.defensive_technique_id == def_tech.id, ) + # Chain .first() call .first() ) + # Check: existing if existing: + # Assign skipped = 1 skipped += 1 + # Skip to the next loop iteration continue + # Assign mapping = DefensiveTechniqueMapping( mapping = DefensiveTechniqueMapping( + # Keyword argument: attack_technique_id attack_technique_id=attack_tech.id, + # Keyword argument: defensive_technique_id 
defensive_technique_id=def_tech.id, ) + # Stage new record(s) for database insertion db.add(mapping) + # Assign created = 1 created += 1 + # Commit all pending changes to the database db.commit() + # Assign total = db.query(DefensiveTechniqueMapping).count() total = db.query(DefensiveTechniqueMapping).count() + # Log info: "D3FEND mappings: %d created, %d skipped, %d total logger.info("D3FEND mappings: %d created, %d skipped, %d total", created, skipped, total) + # Return {"created": created, "skipped": skipped, "total": total} return {"created": created, "skipped": skipped, "total": total} +# Define function sync def sync(db: Session) -> dict: """Sync D3FEND techniques and ATT&CK mappings. Called by the Data Sources router when the user clicks Sync for D3FEND. - Returns a flat summary dict suitable for ``last_sync_stats``. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Flat summary dict suitable for ``last_sync_stats``, containing + ``techniques_created``, ``techniques_updated``, + ``techniques_total``, ``mappings_created``, + ``mappings_skipped``, and ``mappings_total``. 
""" - from app.models.data_source import DataSource + # Import datetime from datetime from datetime import datetime + # Import DataSource from app.models.data_source + from app.models.data_source import DataSource + + # Assign tech_result = import_d3fend_techniques(db) tech_result = import_d3fend_techniques(db) + # Assign mapping_result = import_d3fend_mappings(db) mapping_result = import_d3fend_mappings(db) + # Assign summary = { summary = { + # Literal argument value "techniques_created": tech_result.get("created", 0), + # Literal argument value "techniques_updated": tech_result.get("updated", 0), + # Literal argument value "techniques_total": tech_result.get("total", 0), + # Literal argument value "mappings_created": mapping_result.get("created", 0), + # Literal argument value "mappings_skipped": mapping_result.get("skipped", 0), + # Literal argument value "mappings_total": mapping_result.get("total", 0), } # Update DataSource record ds = db.query(DataSource).filter(DataSource.name == "d3fend").first() + # Check: ds if ds: + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_status = "success" ds.last_sync_status = "success" + # Assign ds.last_sync_stats = summary ds.last_sync_stats = summary + # Commit all pending changes to the database db.commit() + # Log info: "D3FEND sync complete — %s", summary logger.info("D3FEND sync complete — %s", summary) + # Return summary return summary -def get_defenses_for_technique(db: Session, technique_id) -> list[dict]: - """Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" +# Define function get_defenses_for_technique +def get_defenses_for_technique(db: Session, technique_id: UUID) -> list[dict]: + """Return all D3FEND defensive techniques mapped to a given ATT&CK technique. + + Args: + db (Session): Active SQLAlchemy database session. + technique_id (UUID): UUID of the ATT&CK technique to look up. 
+ + Returns: + list[dict]: List of defensive technique dicts, each containing + ``id``, ``d3fend_id``, ``name``, ``description``, ``tactic``, + and ``d3fend_url``. + """ + # Assign mappings = ( mappings = ( db.query(DefensiveTechniqueMapping) + # Chain .filter() call .filter(DefensiveTechniqueMapping.attack_technique_id == technique_id) + # Chain .all() call .all() ) + # Assign results = [] results = [] + # Iterate over mappings for m in mappings: + # Assign dt = m.defensive_technique dt = m.defensive_technique + # Call results.append() results.append({ + # Literal argument value "id": str(dt.id), + # Literal argument value "d3fend_id": dt.d3fend_id, + # Literal argument value "name": dt.name, + # Literal argument value "description": dt.description, + # Literal argument value "tactic": dt.tactic, + # Literal argument value "d3fend_url": dt.d3fend_url, }) + # Return results return results diff --git a/backend/app/services/d3fend_query_service.py b/backend/app/services/d3fend_query_service.py index dc433ee..1139a67 100644 --- a/backend/app/services/d3fend_query_service.py +++ b/backend/app/services/d3fend_query_service.py @@ -1,53 +1,92 @@ """D3FEND query service — framework-agnostic queries for defensive techniques.""" +# Enable future language features for compatibility from __future__ import annotations +# Import Optional from typing from typing import Optional +# Import func from sqlalchemy from sqlalchemy import func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError + +# Import DefensiveTechnique from app.models.defensive_technique from app.models.defensive_technique import DefensiveTechnique + +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import get_defenses_for_technique from app.services.d3fend_import_service from app.services.d3fend_import_service import get_defenses_for_technique + 
+# Import escape_like from app.utils from app.utils import escape_like +# Define function list_defensive_techniques def list_defensive_techniques( + # Entry: db db: Session, *, + # Entry: tactic tactic: Optional[str] = None, + # Entry: search search: Optional[str] = None, + # Entry: offset offset: int = 0, + # Entry: limit limit: int = 50, ) -> dict: """List D3FEND defensive techniques with optional filters.""" + # Assign query = db.query(DefensiveTechnique) query = db.query(DefensiveTechnique) + # Check: tactic if tactic: + # Assign query = query.filter(DefensiveTechnique.tactic == tactic) query = query.filter(DefensiveTechnique.tactic == tactic) + # Check: search if search: + # Assign pattern = f"%{escape_like(search)}%" pattern = f"%{escape_like(search)}%" + # Assign query = query.filter( query = query.filter( DefensiveTechnique.name.ilike(pattern) | DefensiveTechnique.d3fend_id.ilike(pattern) ) + # Assign total = query.count() total = query.count() + # Assign items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(l... 
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all() + # Return { return { + # Literal argument value "total": total, + # Literal argument value "offset": offset, + # Literal argument value "limit": limit, + # Literal argument value "items": [ { + # Literal argument value "id": str(dt.id), + # Literal argument value "d3fend_id": dt.d3fend_id, + # Literal argument value "name": dt.name, + # Literal argument value "description": dt.description, + # Literal argument value "tactic": dt.tactic, + # Literal argument value "d3fend_url": dt.d3fend_url, } for dt in items @@ -55,28 +94,44 @@ def list_defensive_techniques( } +# Define function list_d3fend_tactics def list_d3fend_tactics(db: Session) -> list[dict]: """Return a list of all D3FEND tactics with counts.""" + # Assign rows = ( rows = ( db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id)) + # Chain .group_by() call .group_by(DefensiveTechnique.tactic) + # Chain .order_by() call .order_by(DefensiveTechnique.tactic) + # Chain .all() call .all() ) + # Return [{"tactic": tactic or "Unknown", "count": count} for tactic, count ... 
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows] +# Define function get_defenses_for_attack_technique def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict: """Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" + # Assign technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first() technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first() + # Check: technique is None if technique is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", mitre_id) + # Assign defenses = get_defenses_for_technique(db, technique.id) defenses = get_defenses_for_technique(db, technique.id) + # Return { return { + # Literal argument value "mitre_id": mitre_id, + # Literal argument value "technique_name": technique.name, + # Literal argument value "defenses": defenses, + # Literal argument value "total": len(defenses), } diff --git a/backend/app/services/data_source_service.py b/backend/app/services/data_source_service.py index 25ad0ca..da449b2 100644 --- a/backend/app/services/data_source_service.py +++ b/backend/app/services/data_source_service.py @@ -4,61 +4,99 @@ Provides list, update, sync, and stats. Sync operations commit internally since they are long-running and self-contained. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import logging import logging + +# Import datetime from datetime from datetime import datetime +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import BusinessRuleViolation, EntityNotFoundError from app.domain.errors from app.domain.errors import BusinessRuleViolation, EntityNotFoundError + +# Import get_import_handler from app.domain.ports.import_service from app.domain.ports.import_service import get_import_handler + +# Import DataSource from app.models.data_source from app.models.data_source import DataSource +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Define function list_sources def list_sources(db: Session) -> list[dict]: """Return all registered data sources as a list of dicts.""" + # Assign sources = db.query(DataSource).order_by(DataSource.name).all() sources = db.query(DataSource).order_by(DataSource.name).all() + # Return [ return [ { + # Literal argument value "id": str(s.id), + # Literal argument value "name": s.name, + # Literal argument value "display_name": s.display_name, + # Literal argument value "type": s.type, + # Literal argument value "url": s.url, + # Literal argument value "description": s.description, + # Literal argument value "is_enabled": s.is_enabled, + # Literal argument value "last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None, + # Literal argument value "last_sync_status": s.last_sync_status, + # Literal argument value "last_sync_stats": s.last_sync_stats, + # Literal argument value "sync_frequency": s.sync_frequency, + # Literal argument value "config": s.config, + # Literal argument value "created_at": s.created_at.isoformat() if s.created_at else None, } for s in sources ] +# Define function update_source def update_source(db: Session, source_id: str, **fields: object) -> None: """Update a data source's fields (is_enabled, sync_frequency, config). 
Raises EntityNotFoundError if source does not exist. Does not commit; the router handles that. """ + # Assign ds = db.query(DataSource).filter(DataSource.id == source_id).first() ds = db.query(DataSource).filter(DataSource.id == source_id).first() + # Check: not ds if not ds: + # Raise EntityNotFoundError raise EntityNotFoundError("Data source", source_id) + # Check: "is_enabled" in fields if "is_enabled" in fields: + # Assign ds.is_enabled = fields["is_enabled"] ds.is_enabled = fields["is_enabled"] + # Check: "sync_frequency" in fields if "sync_frequency" in fields: + # Assign ds.sync_frequency = fields["sync_frequency"] ds.sync_frequency = fields["sync_frequency"] + # Check: "config" in fields if "config" in fields: + # Assign ds.config = fields["config"] ds.config = fields["config"] +# Define function sync_source def sync_source(db: Session, source_id: str) -> dict: """Trigger sync for a specific data source. @@ -67,131 +105,221 @@ def sync_source(db: Session, source_id: str) -> dict: Commits internally (long-running, self-contained operation). Returns dict with message, source, stats. 
""" + # Assign ds = db.query(DataSource).filter(DataSource.id == source_id).first() ds = db.query(DataSource).filter(DataSource.id == source_id).first() + # Check: not ds if not ds: + # Raise EntityNotFoundError raise EntityNotFoundError("Data source", source_id) + # Assign handler = get_import_handler(ds.name) handler = get_import_handler(ds.name) + # Check: handler is None if handler is None: + # Raise BusinessRuleViolation raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'") + # Assign ds.last_sync_status = "in_progress" ds.last_sync_status = "in_progress" + # Commit all pending changes to the database db.commit() + # Attempt the following; catch errors below try: + # Assign summary = handler(db) summary = handler(db) + # Handle Exception except Exception as exc: + # Log error: "Sync failed for %s: %s", ds.name, exc, exc_info=T logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True) + # Assign ds.last_sync_status = "error" ds.last_sync_status = "error" + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_stats = {"error": str(exc)} ds.last_sync_stats = {"error": str(exc)} + # Commit all pending changes to the database db.commit() + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Sync failed for '{ds.display_name}'. Check server logs for details." ) + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_status = "success" ds.last_sync_status = "success" + # Assign ds.last_sync_stats = summary ds.last_sync_stats = summary + # Commit all pending changes to the database db.commit() + # Return { return { + # Literal argument value "message": f"Sync complete for {ds.display_name}", + # Literal argument value "source": ds.name, + # Literal argument value "stats": summary, } +# Define function sync_all_sources def sync_all_sources(db: Session) -> list[dict]: """Trigger sync for all enabled data sources (sequentially). 
Commits internally (long-running, self-contained operation). Returns list of result dicts with source, status, stats/detail. """ + # Assign enabled_sources = ( enabled_sources = ( db.query(DataSource) + # Chain .filter() call .filter(DataSource.is_enabled == True) + # Chain .order_by() call .order_by(DataSource.name) + # Chain .all() call .all() ) + # Assign results = [] results = [] + # Iterate over enabled_sources for ds in enabled_sources: + # Assign handler = get_import_handler(ds.name) handler = get_import_handler(ds.name) + # Check: handler is None if handler is None: + # Call results.append() results.append({ + # Literal argument value "source": ds.name, + # Literal argument value "status": "skipped", + # Literal argument value "detail": "No sync handler available", }) + # Skip to the next loop iteration continue + # Assign ds.last_sync_status = "in_progress" ds.last_sync_status = "in_progress" + # Commit all pending changes to the database db.commit() + # Attempt the following; catch errors below try: + # Assign summary = handler(db) summary = handler(db) + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_status = "success" ds.last_sync_status = "success" + # Assign ds.last_sync_stats = summary ds.last_sync_stats = summary + # Commit all pending changes to the database db.commit() + # Call results.append() results.append({ + # Literal argument value "source": ds.name, + # Literal argument value "status": "success", + # Literal argument value "stats": summary, }) + # Handle Exception except Exception as exc: + # Log error: "Sync failed for %s: %s", ds.name, exc, exc_info=T logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True) + # Assign ds.last_sync_status = "error" ds.last_sync_status = "error" + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_stats = {"error": str(exc)} ds.last_sync_stats = {"error": str(exc)} + # Commit all 
pending changes to the database db.commit() + # Call results.append() results.append({ + # Literal argument value "source": ds.name, + # Literal argument value "status": "error", + # Literal argument value "detail": "Sync failed. Check server logs for details.", }) + # Return results return results +# Define function get_source_stats def get_source_stats(db: Session, source_id: str) -> dict: """Return detailed statistics for a data source. Raises EntityNotFoundError if source does not exist. """ + # Assign ds = db.query(DataSource).filter(DataSource.id == source_id).first() ds = db.query(DataSource).filter(DataSource.id == source_id).first() + # Check: not ds if not ds: + # Raise EntityNotFoundError raise EntityNotFoundError("Data source", source_id) - from app.models.test_template import TestTemplate + # Import DetectionRule from app.models.detection_rule from app.models.detection_rule import DetectionRule + # Import TestTemplate from app.models.test_template + from app.models.test_template import TestTemplate + + # Assign template_count = 0 template_count = 0 + # Assign rule_count = 0 rule_count = 0 + # Check: ds.type == "attack_procedure" if ds.type == "attack_procedure": + # Assign template_count = ( template_count = ( db.query(TestTemplate) + # Chain .filter() call .filter(TestTemplate.source == ds.name) + # Chain .count() call .count() ) + # Alternative: ds.type == "detection_rule" elif ds.type == "detection_rule": + # Assign rule_count = ( rule_count = ( db.query(DetectionRule) + # Chain .filter() call .filter(DetectionRule.source == ds.name) + # Chain .count() call .count() ) + # Return { return { + # Literal argument value "id": str(ds.id), + # Literal argument value "name": ds.name, + # Literal argument value "display_name": ds.display_name, + # Literal argument value "type": ds.type, + # Literal argument value "is_enabled": ds.is_enabled, + # Literal argument value "last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None, + # Literal 
argument value "last_sync_status": ds.last_sync_status, + # Literal argument value "last_sync_stats": ds.last_sync_stats, + # Literal argument value "total_templates": template_count, + # Literal argument value "total_rules": rule_count, } diff --git a/backend/app/services/detection_rule_service.py b/backend/app/services/detection_rule_service.py index 18c4d80..ec08384 100644 --- a/backend/app/services/detection_rule_service.py +++ b/backend/app/services/detection_rule_service.py @@ -6,76 +6,136 @@ that the router remains a thin HTTP adapter. This module is framework-agnostic: no FastAPI imports. """ +# Enable future language features for compatibility from __future__ import annotations +# Import datetime from datetime from datetime import datetime + +# Import Any from typing from typing import Any +# Import UUID from uuid +from uuid import UUID + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError -from app.models.detection_rule import DetectionRule -from app.models.test import Test -from app.models.test_template import TestTemplate -from app.models.test_template_detection_rule import TestTemplateDetectionRule -from app.models.test_detection_result import TestDetectionResult -from app.models.technique import Technique -from app.utils import escape_like +# Import DetectionRule from app.models.detection_rule +from app.models.detection_rule import DetectionRule + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import Test from app.models.test +from app.models.test import Test + +# Import TestDetectionResult from app.models.test_detection_result +from app.models.test_detection_result import TestDetectionResult + +# Import TestTemplate from app.models.test_template +from app.models.test_template import TestTemplate + +# Import TestTemplateDetectionRule from app.models.test_template_detection_rule 
+from app.models.test_template_detection_rule import TestTemplateDetectionRule + +# Import escape_like from app.utils +from app.utils import escape_like # ── Public service functions ────────────────────────────────────────── def list_rules( + # Entry: db db: Session, *, + # Entry: technique technique: str | None = None, + # Entry: source source: str | None = None, + # Entry: severity severity: str | None = None, + # Entry: search search: str | None = None, + # Entry: offset offset: int = 0, + # Entry: limit limit: int = 50, ) -> dict[str, Any]: """List detection rules with optional filters and pagination.""" + # Assign query = db.query(DetectionRule).filter(DetectionRule.is_active == True) query = db.query(DetectionRule).filter(DetectionRule.is_active == True) + # Check: technique if technique: + # Assign query = query.filter(DetectionRule.mitre_technique_id == technique) query = query.filter(DetectionRule.mitre_technique_id == technique) + # Check: source if source: + # Assign query = query.filter(DetectionRule.source == source) query = query.filter(DetectionRule.source == source) + # Check: severity if severity: + # Assign query = query.filter(DetectionRule.severity == severity) query = query.filter(DetectionRule.severity == severity) + # Check: search if search: + # Assign pattern = f"%{escape_like(search)}%" pattern = f"%{escape_like(search)}%" + # Assign query = query.filter( query = query.filter( DetectionRule.title.ilike(pattern) | DetectionRule.description.ilike(pattern) ) + # Assign total = query.count() total = query.count() + # Assign items = ( items = ( query.order_by(DetectionRule.mitre_technique_id, DetectionRule.title) + # Chain .offset() call .offset(offset) + # Chain .limit() call .limit(limit) + # Chain .all() call .all() ) + # Return { return { + # Literal argument value "total": total, + # Literal argument value "offset": offset, + # Literal argument value "limit": limit, + # Literal argument value "items": [ { + # Literal argument value "id": 
str(r.id), + # Literal argument value "mitre_technique_id": r.mitre_technique_id, + # Literal argument value "title": r.title, + # Literal argument value "description": r.description, + # Literal argument value "source": r.source, + # Literal argument value "source_url": r.source_url, + # Literal argument value "rule_format": r.rule_format, + # Literal argument value "severity": r.severity, + # Literal argument value "platforms": r.platforms or [], + # Literal argument value "log_sources": r.log_sources, + # Literal argument value "is_active": r.is_active, } for r in items @@ -83,48 +143,78 @@ def list_rules( } +# Define function get_rules_for_template def get_rules_for_template(db: Session, template_id: str) -> dict[str, Any]: """Get detection rules associated with a test template. Raises EntityNotFoundError if the template does not exist. """ + # Assign template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() + # Check: not template if not template: + # Raise EntityNotFoundError raise EntityNotFoundError("Test template", template_id) + # Assign associations = ( associations = ( db.query(TestTemplateDetectionRule) + # Chain .filter() call .filter(TestTemplateDetectionRule.test_template_id == template_id) + # Chain .all() call .all() ) + # Assign rules = [] rules = [] + # Iterate over associations for assoc in associations: + # Assign r = assoc.detection_rule r = assoc.detection_rule + # Call rules.append() rules.append({ + # Literal argument value "id": str(r.id), + # Literal argument value "mitre_technique_id": r.mitre_technique_id, + # Literal argument value "title": r.title, + # Literal argument value "description": r.description, + # Literal argument value "source": r.source, + # Literal argument value "source_url": r.source_url, + # Literal argument value "rule_content": r.rule_content, + # Literal argument value "rule_format": r.rule_format, + # Literal 
argument value "severity": r.severity, + # Literal argument value "platforms": r.platforms or [], + # Literal argument value "log_sources": r.log_sources, + # Literal argument value "is_primary": assoc.is_primary, }) + # Return { return { + # Literal argument value "template_id": str(template.id), + # Literal argument value "template_name": template.name, + # Literal argument value "mitre_technique_id": template.mitre_technique_id, + # Literal argument value "rules": rules, + # Literal argument value "total": len(rules), } +# Define function auto_associate_rules def auto_associate_rules(db: Session) -> dict[str, Any]: """Auto-associate test templates with detection rules by MITRE technique ID. @@ -132,188 +222,316 @@ def auto_associate_rules(db: Session) -> dict[str, Any]: technique and creates associations. Rules with severity high/critical are marked as primary. Performs commit internally. """ + # Assign templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all() templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all() + # Assign rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all() rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all() + # Assign rules_by_technique = {} rules_by_technique: dict[str, list] = {} + # Iterate over rules for rule in rules: + # Assign tid = rule.mitre_technique_id tid = rule.mitre_technique_id + # Check: tid not in rules_by_technique if tid not in rules_by_technique: + # Assign rules_by_technique[tid] = [] rules_by_technique[tid] = [] + # rules_by_technique[tid].append(rule) rules_by_technique[tid].append(rule) + # Assign created = 0 created = 0 + # Assign skipped = 0 skipped = 0 + # Assign high_severities = {"high", "critical"} high_severities = {"high", "critical"} + # Iterate over templates for template in templates: + # Assign matching_rules = rules_by_technique.get(template.mitre_technique_id, []) matching_rules = 
rules_by_technique.get(template.mitre_technique_id, []) + # Iterate over matching_rules for rule in matching_rules: + # Assign existing = ( existing = ( db.query(TestTemplateDetectionRule) + # Chain .filter() call .filter( TestTemplateDetectionRule.test_template_id == template.id, TestTemplateDetectionRule.detection_rule_id == rule.id, ) + # Chain .first() call .first() ) + # Check: existing if existing: + # Assign skipped = 1 skipped += 1 + # Skip to the next loop iteration continue + # Assign is_primary = (rule.severity or "").lower() in high_severities is_primary = (rule.severity or "").lower() in high_severities + # Assign assoc = TestTemplateDetectionRule( assoc = TestTemplateDetectionRule( + # Keyword argument: test_template_id test_template_id=template.id, + # Keyword argument: detection_rule_id detection_rule_id=rule.id, + # Keyword argument: is_primary is_primary=is_primary, ) + # Stage new record(s) for database insertion db.add(assoc) + # Assign created = 1 created += 1 + # Commit all pending changes to the database db.commit() + # Assign total = db.query(TestTemplateDetectionRule).count() total = db.query(TestTemplateDetectionRule).count() + # Return { return { + # Literal argument value "created": created, + # Literal argument value "skipped": skipped, + # Literal argument value "total_associations": total, } +# Define function get_rules_for_test def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]: """Get detection rules relevant to a test, along with their evaluation results. Finds rules by matching the test's technique to detection rules. Raises EntityNotFoundError if the test or its technique does not exist. 
""" + # Assign test = db.query(Test).filter(Test.id == test_id).first() test = db.query(Test).filter(Test.id == test_id).first() + # Check: not test if not test: + # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(test_id)) + # Assign technique = db.query(Technique).filter(Technique.id == test.technique_id).first() technique = db.query(Technique).filter(Technique.id == test.technique_id).first() + # Check: not technique if not technique: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", str(test.technique_id)) + # Assign rules = ( rules = ( db.query(DetectionRule) + # Chain .filter() call .filter( DetectionRule.mitre_technique_id == technique.mitre_id, DetectionRule.is_active == True, ) + # Chain .order_by() call .order_by(DetectionRule.severity.desc(), DetectionRule.title) + # Chain .all() call .all() ) + # Assign existing_results = ( existing_results = ( db.query(TestDetectionResult) + # Chain .filter() call .filter(TestDetectionResult.test_id == test_id) + # Chain .all() call .all() ) + # Assign results_map = {str(r.detection_rule_id): r for r in existing_results} results_map = {str(r.detection_rule_id): r for r in existing_results} + # Assign items = [] items = [] + # Assign triggered_count = 0 triggered_count = 0 + # Assign evaluated_count = 0 evaluated_count = 0 + # Iterate over rules for rule in rules: + # Assign result = results_map.get(str(rule.id)) result = results_map.get(str(rule.id)) + # Assign triggered = result.triggered if result else None triggered = result.triggered if result else None + # Assign notes = result.notes if result else None notes = result.notes if result else None + # Assign evaluated_at = result.evaluated_at.isoformat() if result and result.evaluated_at e... 
evaluated_at = result.evaluated_at.isoformat() if result and result.evaluated_at else None + # Check: triggered is not None if triggered is not None: + # Assign evaluated_count = 1 evaluated_count += 1 + # Check: triggered if triggered: + # Assign triggered_count = 1 triggered_count += 1 + # Call items.append() items.append({ + # Literal argument value "id": str(rule.id), + # Literal argument value "mitre_technique_id": rule.mitre_technique_id, + # Literal argument value "title": rule.title, + # Literal argument value "description": rule.description, + # Literal argument value "source": rule.source, + # Literal argument value "source_url": rule.source_url, + # Literal argument value "rule_content": rule.rule_content, + # Literal argument value "rule_format": rule.rule_format, + # Literal argument value "severity": rule.severity, + # Literal argument value "platforms": rule.platforms or [], + # Literal argument value "log_sources": rule.log_sources, + # Literal argument value "triggered": triggered, + # Literal argument value "notes": notes, + # Literal argument value "evaluated_at": evaluated_at, + # Literal argument value "result_id": str(result.id) if result else None, }) + # Return { return { + # Literal argument value "test_id": str(test.id), + # Literal argument value "mitre_technique_id": technique.mitre_id, + # Literal argument value "rules": items, + # Literal argument value "total": len(items), + # Literal argument value "evaluated": evaluated_count, + # Literal argument value "triggered": triggered_count, + # Literal argument value "detection_rate": round(triggered_count / evaluated_count * 100, 1) if evaluated_count > 0 else 0, } +# Define function evaluate_rule def evaluate_rule( + # Entry: db db: Session, *, - test_id: Any, - detection_rule_id: Any, + # Entry: test_id + test_id: UUID, + # Entry: detection_rule_id + detection_rule_id: UUID, + # Entry: triggered triggered: bool | None, + # Entry: notes notes: str | None, - evaluator_id: Any, + # Entry: 
evaluator_id + evaluator_id: UUID, ) -> dict[str, Any]: """Save or update the evaluation result for a detection rule on a test. Raises EntityNotFoundError if the test or detection rule does not exist. """ + # Assign test = db.query(Test).filter(Test.id == test_id).first() test = db.query(Test).filter(Test.id == test_id).first() + # Check: not test if not test: + # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(test_id)) + # Assign rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_i... rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first() + # Check: not rule if not rule: + # Raise EntityNotFoundError raise EntityNotFoundError("Detection rule", str(detection_rule_id)) + # Assign existing = ( existing = ( db.query(TestDetectionResult) + # Chain .filter() call .filter( TestDetectionResult.test_id == test_id, TestDetectionResult.detection_rule_id == detection_rule_id, ) + # Chain .first() call .first() ) + # Check: existing if existing: + # Assign existing.triggered = triggered existing.triggered = triggered + # Assign existing.notes = notes existing.notes = notes + # Assign existing.evaluated_by = evaluator_id existing.evaluated_by = evaluator_id + # Assign existing.evaluated_at = datetime.utcnow() existing.evaluated_at = datetime.utcnow() + # Commit all pending changes to the database db.commit() + # Reload ORM object attributes from the database db.refresh(existing) + # Return { return { + # Literal argument value "id": str(existing.id), + # Literal argument value "triggered": existing.triggered, + # Literal argument value "notes": existing.notes, + # Literal argument value "evaluated_at": existing.evaluated_at.isoformat() if existing.evaluated_at else None, } + # Fallback: handle remaining cases else: + # Assign result = TestDetectionResult( result = TestDetectionResult( + # Keyword argument: test_id test_id=test_id, + # Keyword argument: detection_rule_id detection_rule_id=detection_rule_id, + # 
Keyword argument: triggered triggered=triggered, + # Keyword argument: notes notes=notes, + # Keyword argument: evaluated_by evaluated_by=evaluator_id, + # Keyword argument: evaluated_at evaluated_at=datetime.utcnow(), ) + # Stage new record(s) for database insertion db.add(result) + # Commit all pending changes to the database db.commit() + # Reload ORM object attributes from the database db.refresh(result) + # Return { return { + # Literal argument value "id": str(result.id), + # Literal argument value "triggered": result.triggered, + # Literal argument value "notes": result.notes, + # Literal argument value "evaluated_at": result.evaluated_at.isoformat() if result.evaluated_at else None, } diff --git a/backend/app/services/elastic_import_service.py b/backend/app/services/elastic_import_service.py index b5799ed..8707340 100644 --- a/backend/app/services/elastic_import_service.py +++ b/backend/app/services/elastic_import_service.py @@ -21,22 +21,39 @@ rules are identified by ``source = "elastic"`` + ``source_id`` (the TOML filename). 
""" +# Import io import io + +# Import logging import logging + +# Import shutil import shutil + +# Import tempfile import tempfile + +# Import zipfile import zipfile + +# Import datetime from datetime from datetime import datetime + +# Import Path from pathlib from pathlib import Path +# Import requests import requests as _requests + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session -from app.models.detection_rule import DetectionRule +# Import DataSource from app.models.data_source from app.models.data_source import DataSource from app.models.technique import Technique from app.services.audit_service import log_action +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -44,19 +61,33 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- ELASTIC_ZIP_URL = ( + # Literal argument value "https://github.com/elastic/detection-rules" + # Literal argument value "/archive/refs/heads/main.zip" ) +# Assign _DOWNLOAD_TIMEOUT = 300 _DOWNLOAD_TIMEOUT = 300 +# Assign _ZIP_ROOT_PREFIX = "detection-rules-main" _ZIP_ROOT_PREFIX = "detection-rules-main" +# Safety limits for ZIP extraction — prevent zip-bomb DoS +_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB +# Assign _MAX_ENTRIES = 50_000 +_MAX_ENTRIES = 50_000 + # Severity normalisation _SEVERITY_MAP = { + # Literal argument value "informational": "informational", + # Literal argument value "low": "low", + # Literal argument value "medium": "medium", + # Literal argument value "high": "high", + # Literal argument value "critical": "critical", } @@ -68,14 +99,21 @@ _SEVERITY_MAP = { def _download_zip(url: str = ELASTIC_ZIP_URL) -> bytes: """Download the Elastic Detection Rules ZIP and return raw bytes.""" + # Log info: "Downloading Elastic Detection Rules ZIP from %s … logger.info("Downloading Elastic Detection Rules ZIP from %s 
…", url) + # Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) + # Call resp.raise_for_status() resp.raise_for_status() + # Assign content = resp.content content = resp.content + # Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024 logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024)) + # Return content return content +# Define function _safe_extract_zip def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None: """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection. @@ -83,62 +121,85 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None: directory (path traversal / Zip Slip) or if the archive exceeds the safety limits. """ - # Maximum uncompressed size: 500 MB — prevents zip-bomb DoS - _MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 - # Maximum number of entries - _MAX_ENTRIES = 50_000 - + # Assign dest_path = Path(dest).resolve() dest_path = Path(dest).resolve() + # Open context manager with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: + # Assign entries = zf.infolist() entries = zf.infolist() + # Check: len(entries) > _MAX_ENTRIES if len(entries) > _MAX_ENTRIES: + # Raise ValueError raise ValueError( f"ZIP archive contains {len(entries)} entries " f"(limit: {_MAX_ENTRIES}) — possible zip bomb" ) + # Assign total_size = sum(info.file_size for info in entries) total_size = sum(info.file_size for info in entries) + # Check: total_size > _MAX_UNCOMPRESSED_SIZE if total_size > _MAX_UNCOMPRESSED_SIZE: + # Raise ValueError raise ValueError( f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB " f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB" ) + # Iterate over entries for member in entries: + # Assign target = (dest_path / member.filename).resolve() target = (dest_path / member.filename).resolve() + # Check: not target.is_relative_to(dest_path) if not target.is_relative_to(dest_path): + # Raise ValueError raise 
ValueError( f"Zip Slip detected — member '{member.filename}' " f"resolves outside target directory" ) + # Call zf.extractall() zf.extractall(dest) +# Define function _extract_zip def _extract_zip(zip_bytes: bytes, dest: str) -> Path: """Extract *zip_bytes* into *dest* and return rules/ dir.""" + # Call _safe_extract_zip() _safe_extract_zip(zip_bytes, dest) + # Assign rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules" rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules" + # Check: not rules_dir.is_dir() if not rules_dir.is_dir(): + # Raise FileNotFoundError raise FileNotFoundError( f"Expected rules directory not found at {rules_dir}" ) + # Return rules_dir return rules_dir +# Define function _parse_toml_safe def _parse_toml_safe(path: Path) -> dict | None: """Parse a TOML file. Uses the ``toml`` library.""" + # Attempt the following; catch errors below try: + # Import toml import toml + # Open context manager with open(path, "r", encoding="utf-8") as fh: + # Return toml.load(fh) return toml.load(fh) + # Handle Exception except Exception as exc: + # Log debug: "Failed to parse %s: %s", path, exc logger.debug("Failed to parse %s: %s", path, exc) + # Return None return None +# Define function _extract_mitre_techniques def _extract_mitre_techniques(threat_list: list) -> list[str]: """Extract MITRE technique IDs from Elastic's ``rule.threat`` array. 
@@ -156,82 +217,132 @@ def _extract_mitre_techniques(threat_list: list) -> list[str]: name = "LSASS Memory" id = "T1003.001" """ + # Assign technique_ids = [] technique_ids = [] + # Check: not isinstance(threat_list, list) if not isinstance(threat_list, list): + # Return technique_ids return technique_ids + # Iterate over threat_list for threat_entry in threat_list: + # Check: not isinstance(threat_entry, dict) if not isinstance(threat_entry, dict): + # Skip to the next loop iteration continue # Skip non-MITRE frameworks framework = threat_entry.get("framework", "") + # Check: "MITRE" not in str(framework).upper() if "MITRE" not in str(framework).upper(): + # Skip to the next loop iteration continue + # Assign techniques = threat_entry.get("technique", []) techniques = threat_entry.get("technique", []) + # Check: not isinstance(techniques, list) if not isinstance(techniques, list): + # Skip to the next loop iteration continue + # Iterate over techniques for tech in techniques: + # Check: not isinstance(tech, dict) if not isinstance(tech, dict): + # Skip to the next loop iteration continue + # Assign tech_id = tech.get("id", "") tech_id = tech.get("id", "") + # Check: tech_id and str(tech_id).upper().startswith("T") if tech_id and str(tech_id).upper().startswith("T"): + # Call technique_ids.append() technique_ids.append(str(tech_id).upper()) # Check subtechniques subtechniques = tech.get("subtechnique", []) + # Check: isinstance(subtechniques, list) if isinstance(subtechniques, list): + # Iterate over subtechniques for subtech in subtechniques: + # Check: isinstance(subtech, dict) if isinstance(subtech, dict): + # Assign sub_id = subtech.get("id", "") sub_id = subtech.get("id", "") + # Check: sub_id and str(sub_id).upper().startswith("T") if sub_id and str(sub_id).upper().startswith("T"): + # Call technique_ids.append() technique_ids.append(str(sub_id).upper()) + # Return list(set(technique_ids)) return list(set(technique_ids)) +# Define function 
_parse_elastic_rules def _parse_elastic_rules(rules_dir: Path) -> list[dict]: """Walk the rules directory and parse all TOML files. Returns a flat list of dicts, one per (rule, technique) combination. """ + # Assign results = [] results: list[dict] = [] + # Assign toml_files = sorted(rules_dir.rglob("*.toml")) toml_files = sorted(rules_dir.rglob("*.toml")) + # Log info: "Found %d TOML files to parse", len(toml_files logger.info("Found %d TOML files to parse", len(toml_files)) + # Iterate over toml_files for toml_path in toml_files: + # Assign data = _parse_toml_safe(toml_path) data = _parse_toml_safe(toml_path) + # Check: not data if not data: + # Skip to the next loop iteration continue + # Assign rule = data.get("rule", {}) rule = data.get("rule", {}) + # Check: not isinstance(rule, dict) if not isinstance(rule, dict): + # Skip to the next loop iteration continue + # Assign name = rule.get("name", "").strip() name = rule.get("name", "").strip() + # Check: not name if not name: + # Skip to the next loop iteration continue # Extract MITRE technique IDs threat_list = rule.get("threat", []) + # Assign technique_ids = _extract_mitre_techniques(threat_list) technique_ids = _extract_mitre_techniques(threat_list) + # Check: not technique_ids if not technique_ids: + # Skip to the next loop iteration continue + # Assign description = rule.get("description", "") description = rule.get("description", "") + # Assign query = rule.get("query", "") query = rule.get("query", "") + # Assign severity = _SEVERITY_MAP.get(str(rule.get("severity", "")).lower()) severity = _SEVERITY_MAP.get(str(rule.get("severity", "")).lower()) + # Assign rule_type = rule.get("type", "query") # query, eql, threshold, etc. rule_type = rule.get("type", "query") # query, eql, threshold, etc. 
# Determine rule format based on type if rule_type == "eql": + # Assign rule_format = "eql" rule_format = "eql" + # Alternative: rule_type == "esql" elif rule_type == "esql": + # Assign rule_format = "esql" rule_format = "esql" + # Fallback: handle remaining cases else: + # Assign rule_format = "kql" rule_format = "kql" # Use filename as source_id @@ -239,51 +350,79 @@ def _parse_elastic_rules(rules_dir: Path) -> list[dict]: # Read raw content try: + # Open context manager with open(toml_path, "r", encoding="utf-8") as fh: + # Assign raw_content = fh.read() raw_content = fh.read() + # Handle Exception except Exception: + # Assign raw_content = query or str(data) raw_content = query or str(data) # Build source URL relative = str(toml_path.relative_to(rules_dir.parent)).replace("\\", "/") + # Assign source_url = ( source_url = ( f"https://github.com/elastic/detection-rules/blob/main/{relative}" ) # One entry per technique for tech_id in technique_ids: + # Call results.append() results.append({ + # Literal argument value "mitre_technique_id": tech_id, + # Literal argument value "title": name[:500], + # Literal argument value "description": str(description)[:2000] if description else None, + # Literal argument value "source_id": source_id, + # Literal argument value "source_url": source_url, + # Literal argument value "rule_content": query[:50000] if query else raw_content[:50000], + # Literal argument value "rule_format": rule_format, + # Literal argument value "severity": severity, + # Literal argument value "platforms": _infer_platforms(rules_dir, toml_path), }) + # Log info: "Parsed %d (rule, technique) pairs total", len(res logger.info("Parsed %d (rule, technique) pairs total", len(results)) + # Return results return results +# Define function _infer_platforms def _infer_platforms(rules_dir: Path, toml_path: Path) -> list[str] | None: """Infer platforms from the rule's directory structure. Elastic organizes rules by OS: rules/windows/, rules/linux/, etc. 
""" + # Assign relative = toml_path.relative_to(rules_dir) relative = toml_path.relative_to(rules_dir) + # Assign parts = [p.lower() for p in relative.parts] parts = [p.lower() for p in relative.parts] + # Assign platforms = [] platforms = [] + # Check: "windows" in parts if "windows" in parts: + # Call platforms.append() platforms.append("windows") + # Check: "linux" in parts if "linux" in parts: + # Call platforms.append() platforms.append("linux") + # Check: "macos" in parts if "macos" in parts: + # Call platforms.append() platforms.append("macos") + # Return platforms if platforms else None return platforms if platforms else None @@ -297,47 +436,78 @@ def sync(db: Session) -> dict: Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``. """ + # Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_elastic_") tmp_dir = tempfile.mkdtemp(prefix="aegis_elastic_") + # Attempt the following; catch errors below try: + # Assign zip_bytes = _download_zip() zip_bytes = _download_zip() + # Assign rules_dir = _extract_zip(zip_bytes, tmp_dir) rules_dir = _extract_zip(zip_bytes, tmp_dir) + # Assign parsed_rules = _parse_elastic_rules(rules_dir) parsed_rules = _parse_elastic_rules(rules_dir) + # Always execute this cleanup block finally: + # Call shutil.rmtree() shutil.rmtree(tmp_dir, ignore_errors=True) + # Log info: "Cleaned up temp directory %s", tmp_dir logger.info("Cleaned up temp directory %s", tmp_dir) # Pre-load existing source_ids for dedup existing_ids: set[str] = { row[0] for row in db.query(DetectionRule.source_id) + # Chain .filter() call .filter(DetectionRule.source == "elastic") + # Chain .filter() call .filter(DetectionRule.source_id.isnot(None)) + # Chain .all() call .all() } + # Assign created = 0 created = 0 + # Assign skipped = 0 skipped = 0 new_technique_ids: set[str] = set() + # Iterate over parsed_rules for item in parsed_rules: + # Check: item["source_id"] in existing_ids if item["source_id"] in existing_ids: + # Assign skipped = 1 
skipped += 1 + # Skip to the next loop iteration continue + # Assign rule = DetectionRule( rule = DetectionRule( + # Keyword argument: mitre_technique_id mitre_technique_id=item["mitre_technique_id"], + # Keyword argument: title title=item["title"], + # Keyword argument: description description=item["description"], + # Keyword argument: source source="elastic", + # Keyword argument: source_id source_id=item["source_id"], + # Keyword argument: source_url source_url=item["source_url"], + # Keyword argument: rule_content rule_content=item["rule_content"], + # Keyword argument: rule_format rule_format=item["rule_format"], + # Keyword argument: severity severity=item["severity"], + # Keyword argument: platforms platforms=item["platforms"], + # Keyword argument: is_active is_active=True, ) + # Stage new record(s) for database insertion db.add(rule) + # Call existing_ids.add() existing_ids.add(item["source_id"]) new_technique_ids.add(item["mitre_technique_id"]) created += 1 @@ -350,22 +520,36 @@ def sync(db: Session) -> dict: db.commit() + # Assign summary = { summary = { + # Literal argument value "created": created, + # Literal argument value "skipped_existing": skipped, + # Literal argument value "total_parsed": len(parsed_rules), } # Update DataSource record ds = db.query(DataSource).filter(DataSource.name == "elastic_rules").first() + # Check: ds if ds: + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_status = "success" ds.last_sync_status = "success" + # Assign ds.last_sync_stats = summary ds.last_sync_stats = summary + # Commit all pending changes to the database db.commit() + # Log info: "Elastic import complete — %s", summary logger.info("Elastic import complete — %s", summary) + # Call log_action() log_action(db, user_id=None, action="import_elastic_rules", + # Keyword argument: entity_type entity_type="detection_rule", entity_id=None, details=summary) + # Commit all pending changes to the database 
db.commit() + # Return summary return summary diff --git a/backend/app/services/evidence_service.py b/backend/app/services/evidence_service.py index afb5841..70393aa 100644 --- a/backend/app/services/evidence_service.py +++ b/backend/app/services/evidence_service.py @@ -5,20 +5,32 @@ The router is responsible for HTTP concerns, file I/O, MinIO upload, audit logging, and response formatting. """ +# Enable future language features for compatibility from __future__ import annotations +# Import os import os + +# Import uuid import uuid +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import from app.domain.errors from app.domain.errors import ( BusinessRuleViolation, EntityNotFoundError, PermissionViolation, ) + +# Import TeamSide, TestState from app.models.enums from app.models.enums import TeamSide, TestState + +# Import Evidence from app.models.evidence from app.models.evidence import Evidence + +# Import Test from app.models.test from app.models.test import Test # States where red evidence can be uploaded / deleted @@ -31,19 +43,30 @@ MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # Allowed file extensions (lowercase, with leading dot) ALLOWED_EXTENSIONS: frozenset[str] = frozenset({ + # Literal argument value ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", + # Literal argument value ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt", + # Literal argument value ".md", ".rtf", ".odt", ".ods", + # Literal argument value ".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml", + # Literal argument value ".yaml", ".yml", ".toml", + # Literal argument value ".zip", ".tar", ".gz", ".7z", + # Literal argument value ".har", ".eml", ".msg", }) +# Define function validate_upload_permission def validate_upload_permission( + # Entry: test test: Test, + # Entry: team team: TeamSide, + # Entry: user_role user_role: str, ) -> None: """Validate that the user can upload evidence for the given team in the current state. 
@@ -52,35 +75,56 @@ def validate_upload_permission( PermissionViolation: If user lacks role to upload for this team. BusinessRuleViolation: If test state does not allow uploading for this team. """ + # Check: user_role == "admin" if user_role == "admin": + # Return control to caller return + # Check: team == TeamSide.red if team == TeamSide.red: + # Check: user_role not in ("red_tech", "red_lead") if user_role not in ("red_tech", "red_lead"): + # Raise PermissionViolation raise PermissionViolation( + # Literal argument value "Only red_tech, red_lead or admin can upload red evidence" ) + # Check: test.state not in RED_EDITABLE_STATES if test.state not in RED_EDITABLE_STATES: + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot upload red evidence in '{test.state.value}' state " + # Literal argument value "(allowed in: draft, red_executing)" ) + # Alternative: team == TeamSide.blue elif team == TeamSide.blue: + # Check: user_role not in ("blue_tech", "blue_lead") if user_role not in ("blue_tech", "blue_lead"): + # Raise PermissionViolation raise PermissionViolation( + # Literal argument value "Only blue_tech, blue_lead or admin can upload blue evidence" ) + # Check: test.state not in BLUE_EDITABLE_STATES if test.state not in BLUE_EDITABLE_STATES: + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot upload blue evidence in '{test.state.value}' state " + # Literal argument value "(allowed in: blue_evaluating)" ) +# Define function validate_delete_permission def validate_delete_permission( + # Entry: test test: Test, + # Entry: evidence evidence: Evidence, + # Entry: user_role user_role: str, + # Entry: user_id user_id: uuid.UUID, ) -> None: """Validate that the user can delete this evidence in the current state. @@ -88,80 +132,125 @@ def validate_delete_permission( Raises: PermissionViolation: If user cannot delete in this state or lacks permission. """ + # Check: test.state in (TestState.in_review, TestState.validated, TestState.... 
if test.state in (TestState.in_review, TestState.validated, TestState.rejected): + # Raise PermissionViolation raise PermissionViolation( f"Cannot delete evidence when test is in '{test.state.value}' state" ) + # Check: user_role == "admin" if user_role == "admin": + # Return control to caller return + # Assign ev_team = evidence.team ev_team = evidence.team + # Check: ev_team == TeamSide.red if ev_team == TeamSide.red: + # Check: test.state not in RED_EDITABLE_STATES if test.state not in RED_EDITABLE_STATES: + # Raise PermissionViolation raise PermissionViolation( + # Literal argument value "Cannot delete red evidence outside draft/red_executing" ) + # Check: user_role not in ("red_tech", "red_lead") and evidence.uploaded_by ... if user_role not in ("red_tech", "red_lead") and evidence.uploaded_by != user_id: + # Raise PermissionViolation raise PermissionViolation( + # Literal argument value "Not enough permissions to delete this evidence" ) + # Alternative: ev_team == TeamSide.blue elif ev_team == TeamSide.blue: + # Check: test.state not in BLUE_EDITABLE_STATES if test.state not in BLUE_EDITABLE_STATES: + # Raise PermissionViolation raise PermissionViolation( + # Literal argument value "Cannot delete blue evidence outside blue_evaluating" ) + # Check: user_role not in ("blue_tech", "blue_lead") and evidence.uploaded_b... if user_role not in ("blue_tech", "blue_lead") and evidence.uploaded_by != user_id: + # Raise PermissionViolation raise PermissionViolation( + # Literal argument value "Not enough permissions to delete this evidence" ) +# Define function validate_file def validate_file(file_name: str, content_size: int) -> None: """Validate file extension and size. Raises: BusinessRuleViolation: If extension is not allowed or file exceeds size limit. 
""" + # _, ext = os.path.splitext(file_name) _, ext = os.path.splitext(file_name) + # Assign ext_lower = ext.lower() if ext else "" ext_lower = ext.lower() if ext else "" + # Check: ext_lower not in ALLOWED_EXTENSIONS if ext_lower not in ALLOWED_EXTENSIONS: + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"File type '{ext}' is not allowed. " f"Permitted types: {', '.join(sorted(ALLOWED_EXTENSIONS))}" ) + # Check: content_size > MAX_UPLOAD_SIZE if content_size > MAX_UPLOAD_SIZE: + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"File exceeds maximum upload size of {MAX_UPLOAD_SIZE // (1024 * 1024)} MB" ) +# Define function list_evidence_for_test def list_evidence_for_test( + # Entry: db db: Session, + # Entry: test_id test_id: uuid.UUID, *, + # Entry: team team: TeamSide | str | None = None, ) -> list[Evidence]: """Return evidence for a test, optionally filtered by team.""" + # Assign query = db.query(Evidence).filter(Evidence.test_id == test_id) query = db.query(Evidence).filter(Evidence.test_id == test_id) + # Check: team is not None if team is not None: + # Assign team_enum = TeamSide(team) if isinstance(team, str) else team team_enum = TeamSide(team) if isinstance(team, str) else team + # Assign query = query.filter(Evidence.team == team_enum) query = query.filter(Evidence.team == team_enum) + # Return query.order_by(Evidence.uploaded_at.desc()).all() return query.order_by(Evidence.uploaded_at.desc()).all() +# Define function get_evidence_or_raise def get_evidence_or_raise(db: Session, evidence_id: uuid.UUID) -> Evidence: """Fetch evidence by ID. 
Raises EntityNotFoundError if not found.""" + # Assign evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first() evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first() + # Check: evidence is None if evidence is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Evidence", str(evidence_id)) + # Return evidence return evidence +# Define function get_test_or_raise def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test: """Fetch test by ID. Raises EntityNotFoundError if not found.""" + # Assign test = db.query(Test).filter(Test.id == test_id).first() test = db.query(Test).filter(Test.id == test_id).first() + # Check: test is None if test is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(test_id)) + # Return test return test diff --git a/backend/app/services/heatmap_service.py b/backend/app/services/heatmap_service.py index d39c657..0671440 100644 --- a/backend/app/services/heatmap_service.py +++ b/backend/app/services/heatmap_service.py @@ -7,32 +7,59 @@ This module is framework-agnostic: no FastAPI imports, no HTTPException, no ``db.commit()``. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import json import json -from typing import Optional +# Import Callable from collections.abc +from collections.abc import Callable + +# Import func, or_ from sqlalchemy from sqlalchemy import func, or_ -from sqlalchemy.orm import Session +# Import Query, Session from sqlalchemy.orm +from sqlalchemy.orm import Query, Session + +# Import BusinessRuleViolation, EntityNotFoundError from app.domain.errors from app.domain.errors import BusinessRuleViolation, EntityNotFoundError + +# Import Campaign, CampaignTest from app.models.campaign from app.models.campaign import Campaign, CampaignTest + +# Import DetectionRule from app.models.detection_rule from app.models.detection_rule import DetectionRule -from app.models.defensive_technique import DefensiveTechniqueMapping + +# Import TechniqueStatus, TestState from app.models.enums from app.models.enums import TechniqueStatus, TestState + +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import Test from app.models.test from app.models.test import Test + +# Import TestDetectionResult from app.models.test_detection_result from app.models.test_detection_result import TestDetectionResult + +# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor from app.models.threat_actor import ThreatActor, ThreatActorTechnique + +# Import escape_like from app.utils from app.utils import escape_like # ── Constants ───────────────────────────────────────────────────────── ATTACK_VERSION = "15" +# Assign NAVIGATOR_VERSION = "5.0" NAVIGATOR_VERSION = "5.0" +# Assign LAYER_VERSION = "4.5" LAYER_VERSION = "4.5" +# Assign DOMAIN = "enterprise-attack" DOMAIN = "enterprise-attack" +# Assign STATUS_SCORE_MAP = { STATUS_SCORE_MAP: dict[TechniqueStatus, int] = { TechniqueStatus.validated: 100, TechniqueStatus.partial: 60, @@ -42,6 +69,7 @@ STATUS_SCORE_MAP: dict[TechniqueStatus, int] = { 
TechniqueStatus.review_required: 10, } +# Assign TEST_STATE_SCORE = { TEST_STATE_SCORE: dict[TestState, int] = { TestState.validated: 100, TestState.in_review: 70, @@ -56,74 +84,169 @@ TEST_STATE_SCORE: dict[TestState, int] = { def _score_to_color(score: int) -> str: - """Map a 0-100 score to a red-yellow-green colour hex.""" + """Map a 0-100 score to a red-yellow-green colour hex. + + Args: + score (int): Coverage score between 0 and 100 inclusive. + + Returns: + str: Hex colour string representing the score tier. + """ + # Check: score <= 0 if score <= 0: + # Return "#d3d3d3" return "#d3d3d3" + # Check: score <= 25 if score <= 25: + # Return "#ff6666" return "#ff6666" + # Check: score <= 50 if score <= 50: + # Return "#ff9933" return "#ff9933" + # Check: score <= 75 if score <= 75: + # Return "#ffff66" return "#ffff66" + # Return "#66ff66" return "#66ff66" +# Define function _build_layer_skeleton def _build_layer_skeleton( + # Entry: name name: str, + # Entry: description description: str, + # Entry: gradient_colors gradient_colors: list[str] | None = None, ) -> dict: - """Return a base layer dict compatible with ATT&CK Navigator.""" + """Return a base layer dict compatible with ATT&CK Navigator. + + Args: + name (str): Human-readable name for the layer. + description (str): Description text embedded in the layer metadata. + gradient_colors (list[str] | None): Optional list of hex colour stops + for the gradient; defaults to red-yellow-green if omitted. + + Returns: + dict: Skeleton layer dictionary with versions, domain, and empty + techniques list. 
+ """ + # Return { return { + # Literal argument value "name": name, + # Literal argument value "versions": { + # Literal argument value "attack": ATTACK_VERSION, + # Literal argument value "navigator": NAVIGATOR_VERSION, + # Literal argument value "layer": LAYER_VERSION, }, + # Literal argument value "domain": DOMAIN, + # Literal argument value "description": description, + # Literal argument value "filters": {"platforms": ["windows", "linux", "macos"]}, + # Literal argument value "gradient": { + # Literal argument value "colors": gradient_colors or ["#ff6666", "#ffff66", "#66ff66"], + # Literal argument value "minValue": 0, + # Literal argument value "maxValue": 100, }, + # Literal argument value "techniques": [], } +# Define function _apply_filters def _apply_filters( - query, - model, + # Entry: query + query: Query, # type: ignore[type-arg] + # Entry: model + model: type, + # Entry: platforms platforms: list[str] | None = None, + # Entry: tactics tactics: list[str] | None = None, -): - """Apply common platform and tactic filters to a technique query.""" +) -> Query: # type: ignore[type-arg] + """Apply common platform and tactic filters to a technique query. + + Args: + query (Query): Base SQLAlchemy query targeting a technique-like model. + model (type): The SQLAlchemy model class that owns ``platforms`` and + ``tactic`` columns. + platforms (list[str] | None): Optional list of platform names to + filter by (OR-joined). + tactics (list[str] | None): Optional list of tactic strings to + filter by (OR-joined, case-insensitive substring match). + + Returns: + Query: The query with platform and tactic filters applied. 
+ """ + # Check: platforms if platforms: + # Assign platform_filters = [ platform_filters = [ model.platforms.op("@>")(json.dumps([p])) for p in platforms ] + # Assign query = query.filter(or_(*platform_filters)) query = query.filter(or_(*platform_filters)) + # Check: tactics if tactics: + # Assign tactic_filters = [ tactic_filters = [ model.tactic.ilike(f"%{escape_like(t)}%") for t in tactics ] + # Assign query = query.filter(or_(*tactic_filters)) query = query.filter(or_(*tactic_filters)) + # Return query return query +# Define function _format_tactic def _format_tactic(tactic_str: str | None) -> str: - """Normalize tactic string to ATT&CK Navigator format (kebab-case).""" + """Normalize tactic string to ATT&CK Navigator format (kebab-case). + + Args: + tactic_str (str | None): Raw tactic string, possibly comma-separated + or mixed-case. + + Returns: + str: First tactic value lowercased and trimmed, or empty string if + the input is falsy. + """ + # Check: not tactic_str if not tactic_str: + # Return "" return "" + # Return tactic_str.split(",")[0].strip().lower() return tactic_str.split(",")[0].strip().lower() +# Define function _parse_csv def _parse_csv(value: str | None) -> list[str] | None: - """Split a comma-separated string into a trimmed list, or ``None``.""" + """Split a comma-separated string into a trimmed list, or ``None``. + + Args: + value (str | None): Comma-separated string to split, or ``None``. + + Returns: + list[str] | None: Non-empty trimmed tokens, or ``None`` if the input + is falsy or produces no tokens. 
+ """ + # Check: not value if not value: + # Return None return None + # Return [v.strip() for v in value.split(",") if v.strip()] return [v.strip() for v in value.split(",") if v.strip()] @@ -131,132 +254,224 @@ def _parse_csv(value: str | None) -> list[str] | None: def build_coverage_layer( + # Entry: db db: Session, *, + # Entry: platforms platforms: str | None = None, + # Entry: tactics tactics: str | None = None, + # Entry: min_score min_score: int = 0, ) -> dict: - """Coverage layer -- score based on ``status_global`` of each technique.""" + """Coverage layer -- score based on ``status_global`` of each technique. + + Args: + db (Session): Active SQLAlchemy database session. + platforms (str | None): Optional comma-separated platform names to + filter techniques. + tactics (str | None): Optional comma-separated tactic names to filter + techniques. + min_score (int): Minimum score threshold; techniques below this are + omitted from the layer. + + Returns: + dict: ATT&CK Navigator-compatible layer dictionary. + """ + # Assign layer = _build_layer_skeleton("Aegis Coverage", "Coverage layer generated b... 
layer = _build_layer_skeleton("Aegis Coverage", "Coverage layer generated by Aegis") + # Assign query = _apply_filters( query = _apply_filters( db.query(Technique), Technique, _parse_csv(platforms), _parse_csv(tactics), ) + # Assign techniques = query.all() techniques = query.all() # Bulk-fetch test counts and rule counts to avoid N+1 tech_ids = [t.id for t in techniques] + # Assign mitre_ids = [t.mitre_id for t in techniques] mitre_ids = [t.mitre_id for t in techniques] + # Assign test_counts = dict( test_counts = dict( db.query(Test.technique_id, func.count(Test.id)) + # Chain .filter() call .filter(Test.technique_id.in_(tech_ids), Test.state == TestState.validated) + # Chain .group_by() call .group_by(Test.technique_id) + # Chain .all() call .all() ) if tech_ids else {} + # Assign rule_counts = dict( rule_counts = dict( db.query(DetectionRule.mitre_technique_id, func.count(DetectionRule.id)) + # Chain .filter() call .filter(DetectionRule.mitre_technique_id.in_(mitre_ids)) + # Chain .group_by() call .group_by(DetectionRule.mitre_technique_id) + # Chain .all() call .all() ) if mitre_ids else {} + # Iterate over techniques for tech in techniques: + # Assign score = STATUS_SCORE_MAP.get(tech.status_global, 0) score = STATUS_SCORE_MAP.get(tech.status_global, 0) + # Check: score < min_score if score < min_score: + # Skip to the next loop iteration continue + # Assign tc = test_counts.get(tech.id, 0) tc = test_counts.get(tech.id, 0) + # Assign rc = rule_counts.get(tech.mitre_id, 0) rc = rule_counts.get(tech.mitre_id, 0) + # Assign metadata = [ metadata = [ {"name": "tests_count", "value": str(tc)}, {"name": "detection_rules", "value": str(rc)}, ] + # Check: tech.last_review_date if tech.last_review_date: + # Call metadata.append() metadata.append( {"name": "last_validated", "value": tech.last_review_date.strftime("%Y-%m-%d")} ) + # Assign comment_parts = [ comment_parts = [ f"Status: {tech.status_global.value}", f"{tc} tests validated", f"{rc} detection rules", ] + # 
layer["techniques"].append({ layer["techniques"].append({ + # Literal argument value "techniqueID": tech.mitre_id, + # Literal argument value "tactic": _format_tactic(tech.tactic), + # Literal argument value "color": _score_to_color(score), + # Literal argument value "score": score, + # Literal argument value "comment": " - ".join(comment_parts), + # Literal argument value "enabled": True, + # Literal argument value "metadata": metadata, }) + # Return layer return layer +# Define function build_threat_actor_layer def build_threat_actor_layer( + # Entry: db db: Session, + # Entry: actor_id actor_id: str, *, + # Entry: platforms platforms: str | None = None, + # Entry: tactics tactics: str | None = None, + # Entry: min_score min_score: int = 0, ) -> dict: """Threat actor layer -- techniques used by an actor with coverage colour. Raises :class:`EntityNotFoundError` if the actor does not exist. + + Args: + db (Session): Active SQLAlchemy database session. + actor_id (str): UUID string identifying the threat actor. + platforms (str | None): Optional comma-separated platform names to + filter techniques. + tactics (str | None): Optional comma-separated tactic names to filter + techniques. + min_score (int): Minimum score threshold for actor techniques. + + Returns: + dict: ATT&CK Navigator-compatible layer dictionary coloured by + coverage status for the specified actor. 
""" + # Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + # Check: not actor if not actor: + # Raise EntityNotFoundError raise EntityNotFoundError("ThreatActor", actor_id) + # Assign layer = _build_layer_skeleton( layer = _build_layer_skeleton( f"Threat Actor: {actor.name}", f"Techniques used by {actor.name} with coverage overlay", + # Keyword argument: gradient_colors gradient_colors=["#808080", "#ff6666", "#66ff66"], ) + # Assign actor_technique_ids = { actor_technique_ids = { row.technique_id for row in db.query(ThreatActorTechnique.technique_id) + # Chain .filter() call .filter(ThreatActorTechnique.threat_actor_id == actor.id) + # Chain .all() call .all() } + # Check: not actor_technique_ids if not actor_technique_ids: + # Return layer return layer + # Assign query = _apply_filters( query = _apply_filters( db.query(Technique), Technique, _parse_csv(platforms), _parse_csv(tactics), ) + # Assign techniques = query.all() techniques = query.all() # Bulk-fetch metadata for actor techniques only test_counts = dict( db.query(Test.technique_id, func.count(Test.id)) + # Chain .filter() call .filter(Test.technique_id.in_(actor_technique_ids), Test.state == TestState.validated) + # Chain .group_by() call .group_by(Test.technique_id) + # Chain .all() call .all() ) + # Assign actor_mitre_ids = [t.mitre_id for t in techniques if t.id in actor_technique_ids] actor_mitre_ids = [t.mitre_id for t in techniques if t.id in actor_technique_ids] + # Assign rule_counts = dict( rule_counts = dict( db.query(DetectionRule.mitre_technique_id, func.count(DetectionRule.id)) + # Chain .filter() call .filter(DetectionRule.mitre_technique_id.in_(actor_mitre_ids)) + # Chain .group_by() call .group_by(DetectionRule.mitre_technique_id) + # Chain .all() call .all() ) if actor_mitre_ids else {} + # Iterate over techniques for tech in techniques: + # Assign is_actor_technique = tech.id in 
actor_technique_ids is_actor_technique = tech.id in actor_technique_ids + # Assign score = STATUS_SCORE_MAP.get(tech.status_global, 0) if is_actor_technique e... score = STATUS_SCORE_MAP.get(tech.status_global, 0) if is_actor_technique else 0 + # Check: is_actor_technique and score < min_score if is_actor_technique and score < min_score: + # Skip to the next loop iteration continue # Only include techniques actually used by this actor — skip the rest @@ -284,14 +499,20 @@ def build_threat_actor_layer( "metadata": metadata, }) + # Return layer return layer +# Define function build_detection_rules_layer def build_detection_rules_layer( + # Entry: db db: Session, *, + # Entry: platforms platforms: str | None = None, + # Entry: tactics tactics: str | None = None, + # Entry: min_score min_score: int = 0, ) -> dict: """Detection rules layer -- score based on absolute rule count per technique. @@ -305,28 +526,40 @@ def build_detection_rules_layer( 4+ rules → green (score 100) """ layer = _build_layer_skeleton( + # Literal argument value "Detection Rules Coverage", "Number of active detection rules per technique", ) + # Assign query = _apply_filters( query = _apply_filters( db.query(Technique), Technique, _parse_csv(platforms), _parse_csv(tactics), ) + # Assign techniques = query.all() techniques = query.all() + # Assign rule_counts = dict( rule_counts = dict( db.query(DetectionRule.mitre_technique_id, func.count(DetectionRule.id)) + # Chain .filter() call .filter(DetectionRule.is_active == True) # noqa: E712 + # Chain .group_by() call .group_by(DetectionRule.mitre_technique_id) + # Chain .all() call .all() ) + # Assign evaluated_counts = dict( evaluated_counts = dict( db.query(DetectionRule.mitre_technique_id, func.count(TestDetectionResult.id)) + # Chain .join() call .join(TestDetectionResult, TestDetectionResult.detection_rule_id == DetectionRule.id) + # Chain .filter() call .filter(TestDetectionResult.triggered.isnot(None)) + # Chain .group_by() call 
.group_by(DetectionRule.mitre_technique_id) + # Chain .all() call .all() ) @@ -334,12 +567,16 @@ def build_detection_rules_layer( RULES_FOR_FULL_COVERAGE = 4 for tech in techniques: + # Assign total_rules = rule_counts.get(tech.mitre_id, 0) total_rules = rule_counts.get(tech.mitre_id, 0) + # Assign evaluated_rules = evaluated_counts.get(tech.mitre_id, 0) evaluated_rules = evaluated_counts.get(tech.mitre_id, 0) score = min(int((total_rules / RULES_FOR_FULL_COVERAGE) * 100), 100) + # Check: score < min_score if score < min_score: + # Skip to the next loop iteration continue rule_word = "rule" if total_rules == 1 else "rules" @@ -347,113 +584,194 @@ def build_detection_rules_layer( comment = f"{total_rules} active {rule_word}{eval_note}" layer["techniques"].append({ + # Literal argument value "techniqueID": tech.mitre_id, + # Literal argument value "tactic": _format_tactic(tech.tactic), + # Literal argument value "color": _score_to_color(score), + # Literal argument value "score": score, "comment": comment, "enabled": True, + # Literal argument value "metadata": [ {"name": "total_rules", "value": str(total_rules)}, {"name": "evaluated_rules", "value": str(evaluated_rules)}, ], }) + # Return layer return layer +# Define function build_campaign_layer def build_campaign_layer( + # Entry: db db: Session, + # Entry: campaign_id campaign_id: str, *, + # Entry: platforms platforms: str | None = None, + # Entry: tactics tactics: str | None = None, + # Entry: min_score min_score: int = 0, ) -> dict: """Campaign layer -- techniques in a campaign, coloured by test state. Raises :class:`EntityNotFoundError` if the campaign does not exist. + + Args: + db (Session): Active SQLAlchemy database session. + campaign_id (str): UUID string identifying the campaign. + platforms (str | None): Optional comma-separated platform names to + filter techniques. + tactics (str | None): Optional comma-separated tactic names to filter + techniques. 
+ min_score (int): Minimum score threshold for techniques in the layer. + + Returns: + dict: ATT&CK Navigator-compatible layer dictionary where each + technique colour reflects the best test state within the campaign. """ + # Assign campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Assign layer = _build_layer_skeleton( layer = _build_layer_skeleton( f"Campaign: {campaign.name}", f"Progress of campaign '{campaign.name}'", ) + # Assign campaign_tests = ( campaign_tests = ( db.query(CampaignTest) + # Chain .filter() call .filter(CampaignTest.campaign_id == campaign.id) + # Chain .all() call .all() ) + # Check: not campaign_tests if not campaign_tests: + # Return layer return layer + # Assign test_ids = [ct.test_id for ct in campaign_tests] test_ids = [ct.test_id for ct in campaign_tests] + # Assign tests = db.query(Test).filter(Test.id.in_(test_ids)).all() tests = db.query(Test).filter(Test.id.in_(test_ids)).all() + # Assign test_map = {t.id: t for t in tests} test_map = {t.id: t for t in tests} + # Assign technique_ids = {t.technique_id for t in tests if t.technique_id} technique_ids = {t.technique_id for t in tests if t.technique_id} + # Assign techniques = db.query(Technique).filter(Technique.id.in_(technique_ids)).all() techniques = db.query(Technique).filter(Technique.id.in_(technique_ids)).all() + # Assign tech_map = {t.id: t for t in techniques} tech_map = {t.id: t for t in techniques} # Group tests by technique, keeping the best state score tech_scores: dict = {} + # Iterate over campaign_tests for ct in campaign_tests: + # Assign test = test_map.get(ct.test_id) test = test_map.get(ct.test_id) + # Check: not test if not test: + # Skip to the next loop iteration continue + # Assign tech = tech_map.get(test.technique_id) tech = 
tech_map.get(test.technique_id) + # Check: not tech if not tech: + # Skip to the next loop iteration continue + # Assign state_score = TEST_STATE_SCORE.get(test.state, 0) state_score = TEST_STATE_SCORE.get(test.state, 0) + # Check: tech.mitre_id not in tech_scores if tech.mitre_id not in tech_scores: + # Assign tech_scores[tech.mitre_id] = { tech_scores[tech.mitre_id] = { + # Literal argument value "technique": tech, + # Literal argument value "max_score": state_score, + # Literal argument value "tests": [], } + # Fallback: handle remaining cases else: + # Assign tech_scores[tech.mitre_id]["max_score"] = max( tech_scores[tech.mitre_id]["max_score"] = max( tech_scores[tech.mitre_id]["max_score"], state_score, ) + # tech_scores[tech.mitre_id]["tests"].append(test) tech_scores[tech.mitre_id]["tests"].append(test) + # Assign platform_list = _parse_csv(platforms) platform_list = _parse_csv(platforms) + # Assign tactic_list = _parse_csv(tactics) tactic_list = _parse_csv(tactics) + # Iterate over tech_scores.items() for mitre_id, info in tech_scores.items(): + # Assign tech = info["technique"] tech = info["technique"] + # Assign score = info["max_score"] score = info["max_score"] + # Check: platform_list if platform_list: + # Assign tech_platforms = tech.platforms or [] tech_platforms = tech.platforms or [] + # Check: not any(p in tech_platforms for p in platform_list) if not any(p in tech_platforms for p in platform_list): + # Skip to the next loop iteration continue + # Check: tactic_list if tactic_list: + # Assign tech_tactics = [t.strip() for t in (tech.tactic or "").lower().split(",")] tech_tactics = [t.strip() for t in (tech.tactic or "").lower().split(",")] + # Check: not any(t in tech_tactics for t in tactic_list) if not any(t in tech_tactics for t in tactic_list): + # Skip to the next loop iteration continue + # Check: score < min_score if score < min_score: + # Skip to the next loop iteration continue + # Assign test_states = [t.state.value for t in 
info["tests"]] test_states = [t.state.value for t in info["tests"]] + # layer["techniques"].append({ layer["techniques"].append({ + # Literal argument value "techniqueID": mitre_id, + # Literal argument value "tactic": _format_tactic(tech.tactic), + # Literal argument value "color": _score_to_color(score), + # Literal argument value "score": score, + # Literal argument value "comment": f"Campaign tests: {', '.join(test_states)}", + # Literal argument value "enabled": True, + # Literal argument value "metadata": [ {"name": "campaign_tests", "value": str(len(info["tests"]))}, {"name": "best_state", "value": max(test_states) if test_states else "none"}, ], }) + # Return layer return layer @@ -470,67 +788,143 @@ def build_campaign_layer( class _LayerRegistry: """Extensible registry that maps layer type names to builder functions.""" + # Assign __slots__ = ("_simple", "_with_id") __slots__ = ("_simple", "_with_id") + # Define function __init__ def __init__(self) -> None: + # Assign self._simple = {} self._simple: dict[str, object] = {} + # Assign self._with_id = {} self._with_id: dict[str, object] = {} - def register(self, name: str, builder, *, requires_id: bool = False) -> None: + # Define function register + def register(self, name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None: + """Register a builder function under *name*. + + Args: + name (str): Unique layer type identifier. + builder (Callable[..., dict]): Layer builder function. + requires_id (bool): Whether the builder needs a positional + ``layer_id`` argument. + """ + # Assign target = self._with_id if requires_id else self._simple target = self._with_id if requires_id else self._simple + # Assign target[name] = builder target[name] = builder + # Apply the @property decorator @property + # Define function supported_types def supported_types(self) -> set[str]: + """Return the set of all registered layer type names. 
+ + Returns: + set[str]: Union of simple and entity-bound layer type names. + """ + # Return set(self._simple) | set(self._with_id) return set(self._simple) | set(self._with_id) + # Define function build def build( self, + # Entry: db db: Session, + # Entry: layer_type layer_type: str, *, + # Entry: layer_id layer_id: str | None = None, + # Entry: platforms platforms: str | None = None, + # Entry: tactics tactics: str | None = None, + # Entry: min_score min_score: int = 0, ) -> dict: + """Dispatch to the registered builder for *layer_type*. + + Args: + db (Session): Active SQLAlchemy database session. + layer_type (str): Registered layer type name. + layer_id (str | None): Entity UUID for entity-bound layer types. + platforms (str | None): Optional comma-separated platform filter. + tactics (str | None): Optional comma-separated tactic filter. + min_score (int): Minimum score threshold. + + Returns: + dict: ATT&CK Navigator-compatible layer dictionary. + """ + # Assign kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score) kwargs = dict(platforms=platforms, tactics=tactics, min_score=min_score) + # Check: layer_type in self._simple if layer_type in self._simple: + # Return self._simple[layer_type](db, **kwargs) return self._simple[layer_type](db, **kwargs) + # Check: layer_type in self._with_id if layer_type in self._with_id: + # Check: not layer_id if not layer_id: + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"layer_id is required for '{layer_type}' layer" ) + # Return self._with_id[layer_type](db, layer_id, **kwargs) return self._with_id[layer_type](db, layer_id, **kwargs) + # Raise BusinessRuleViolation raise BusinessRuleViolation(f"Unknown layer type: {layer_type}") +# Assign LAYER_REGISTRY = _LayerRegistry() LAYER_REGISTRY = _LayerRegistry() +# Call LAYER_REGISTRY.register() LAYER_REGISTRY.register("coverage", build_coverage_layer) +# Call LAYER_REGISTRY.register() LAYER_REGISTRY.register("detection-rules", 
build_detection_rules_layer) +# Call LAYER_REGISTRY.register() LAYER_REGISTRY.register("threat-actor", build_threat_actor_layer, requires_id=True) +# Call LAYER_REGISTRY.register() LAYER_REGISTRY.register("campaign", build_campaign_layer, requires_id=True) +# Assign SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types # snapshot of built-in types SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types # snapshot of built-in types -def register_layer(name: str, builder, *, requires_id: bool = False) -> None: - """Public API to register a new heatmap layer type at import time.""" +# Define function register_layer +def register_layer(name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None: + """Register a new heatmap layer type at import time. + + Args: + name (str): Unique identifier for the layer type used in API requests. + builder (Callable[..., dict]): Function that builds the layer dict; + must accept ``(db, *, platforms, tactics, min_score)`` and + optionally a positional ``layer_id`` when ``requires_id`` is + ``True``. + requires_id (bool): Set to ``True`` when the builder needs a + ``layer_id`` argument (e.g. threat-actor, campaign layers). + """ + # Call LAYER_REGISTRY.register() LAYER_REGISTRY.register(name, builder, requires_id=requires_id) +# Define function build_navigator_export def build_navigator_export( + # Entry: db db: Session, + # Entry: layer_type layer_type: str, *, + # Entry: layer_id layer_id: str | None = None, + # Entry: platforms platforms: str | None = None, + # Entry: tactics tactics: str | None = None, + # Entry: min_score min_score: int = 0, ) -> dict: """Build a heatmap layer dict by type name. @@ -539,8 +933,23 @@ def build_navigator_export( missing ``layer_id``. Raises :class:`EntityNotFoundError` when an entity-bound layer (threat-actor, campaign) references a non-existent record. + + Args: + db (Session): Active SQLAlchemy database session. + layer_type (str): Registered layer type name (e.g. 
``"coverage"``, + ``"threat-actor"``). + layer_id (str | None): Entity UUID required for entity-bound layer + types such as ``"threat-actor"`` and ``"campaign"``. + platforms (str | None): Optional comma-separated platform filter. + tactics (str | None): Optional comma-separated tactic filter. + min_score (int): Minimum score; techniques below this are excluded. + + Returns: + dict: ATT&CK Navigator-compatible layer dictionary. """ + # Return LAYER_REGISTRY.build( return LAYER_REGISTRY.build( db, layer_type, + # Keyword argument: layer_id layer_id=layer_id, platforms=platforms, tactics=tactics, min_score=min_score, ) diff --git a/backend/app/services/intel_service.py b/backend/app/services/intel_service.py index c75b36d..947d35d 100644 --- a/backend/app/services/intel_service.py +++ b/backend/app/services/intel_service.py @@ -9,18 +9,34 @@ RSS feeds and parses them with the standard-library :mod:`xml.etree` parser. No LLMs or paid APIs are used. """ +# Import logging import logging + +# Import re import re -import defusedxml.ElementTree as ET + +# Import datetime from datetime from datetime import datetime +# Import defusedxml.ElementTree +import defusedxml.ElementTree as ET # noqa: N817 — ET is the universal stdlib alias for ElementTree + +# Import requests import requests as _requests + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import IntelItem from app.models.intel from app.models.intel import IntelItem + +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import log_action from app.services.audit_service from app.services.audit_service import log_action +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -29,7 +45,9 @@ logger = logging.getLogger(__name__) RSS_FEEDS: list[dict[str, str]] = [ { + # Literal argument value "name": "CISA Alerts", + # Literal argument value "url": 
"https://www.cisa.gov/cybersecurity-advisories/all.xml", }, { @@ -37,19 +55,27 @@ RSS_FEEDS: list[dict[str, str]] = [ "url": "https://feeds.feedburner.com/Securityweek", }, { + # Literal argument value "name": "SANS ISC", + # Literal argument value "url": "https://isc.sans.edu/rssfeed.xml", }, { + # Literal argument value "name": "BleepingComputer", + # Literal argument value "url": "https://www.bleepingcomputer.com/feed/", }, { + # Literal argument value "name": "The Hacker News", + # Literal argument value "url": "https://feeds.feedburner.com/TheHackersNews", }, { + # Literal argument value "name": "Krebs on Security", + # Literal argument value "url": "https://krebsonsecurity.com/feed/", }, ] @@ -73,49 +99,81 @@ def _fetch_feed(url: str) -> list[dict[str, str]]: Each entry is a dict with keys ``title``, ``link``, and ``description``. Returns an empty list on any error so the scan can continue. """ + # Attempt the following; catch errors below try: + # Assign resp = _requests.get(url, timeout=_FEED_TIMEOUT, headers={ resp = _requests.get(url, timeout=_FEED_TIMEOUT, headers={ + # Literal argument value "User-Agent": "AegisPlatform/1.0 IntelScan", }) + # Call resp.raise_for_status() resp.raise_for_status() + # Handle Exception except Exception as exc: + # Log warning: "Failed to fetch feed %s: %s", url, exc logger.warning("Failed to fetch feed %s: %s", url, exc) + # Return [] return [] + # Attempt the following; catch errors below try: + # Assign root = ET.fromstring(resp.content) root = ET.fromstring(resp.content) + # Handle ET.ParseError except ET.ParseError as exc: + # Log warning: "Failed to parse feed %s: %s", url, exc logger.warning("Failed to parse feed %s: %s", url, exc) + # Return [] return [] + # Assign entries = [] entries: list[dict[str, str]] = [] # RSS 2.0 format: ... 
for item in root.iter("item"): + # Assign title_el = item.find("title") title_el = item.find("title") + # Assign link_el = item.find("link") link_el = item.find("link") + # Assign desc_el = item.find("description") desc_el = item.find("description") + # Call entries.append() entries.append({ + # Literal argument value "title": title_el.text.strip() if title_el is not None and title_el.text else "", + # Literal argument value "link": link_el.text.strip() if link_el is not None and link_el.text else "", + # Literal argument value "description": desc_el.text.strip() if desc_el is not None and desc_el.text else "", }) # Atom format: ... ns = {"atom": "http://www.w3.org/2005/Atom"} + # Iterate over root.iter("{http for entry in root.iter("{http://www.w3.org/2005/Atom}entry"): + # Assign title_el = entry.find("atom:title", ns) title_el = entry.find("atom:title", ns) + # Assign link_el = entry.find("atom:link", ns) link_el = entry.find("atom:link", ns) + # Assign summary_el = entry.find("atom:summary", ns) summary_el = entry.find("atom:summary", ns) + # Assign link_href = "" link_href = "" + # Check: link_el is not None if link_el is not None: + # Assign link_href = link_el.get("href", "") link_href = link_el.get("href", "") + # Call entries.append() entries.append({ + # Literal argument value "title": title_el.text.strip() if title_el is not None and title_el.text else "", + # Literal argument value "link": link_href.strip(), + # Literal argument value "description": summary_el.text.strip() if summary_el is not None and summary_el.text else "", }) + # Return entries return entries @@ -147,6 +205,7 @@ def _entry_matches( name_patterns: list[re.Pattern], ) -> bool: """Return True if any pattern matches the entry's title or description.""" + # Assign text = f"{entry.get('title', '')} {entry.get('description', '')}" text = f"{entry.get('title', '')} {entry.get('description', '')}" return any(p.search(text) for p in id_patterns + name_patterns) @@ -164,20 +223,23 @@ def 
scan_intel(db: Session) -> dict: db : Session Active SQLAlchemy database session. - Returns + Returns: ------- dict Summary with keys ``new_items``, ``duplicates_skipped``, ``techniques_flagged``, ``feeds_checked``. """ + # Log info: "Intel scan starting..." logger.info("Intel scan starting...") # 1. Load all active techniques techniques = ( db.query(Technique) + # Chain .order_by() call .order_by(Technique.mitre_id) .all() ) + # Log info: "Scanning %d techniques against %d feeds", len(tec logger.info("Scanning %d techniques against %d feeds", len(techniques), len(RSS_FEEDS)) # 2. Pre-load all existing intel URLs for dedup @@ -187,24 +249,36 @@ def scan_intel(db: Session) -> dict: # 3. Fetch all feeds once all_entries: list[tuple[str, dict[str, str]]] = [] # (feed_name, entry) + # Assign feeds_ok = 0 feeds_ok = 0 + # Iterate over RSS_FEEDS for feed in RSS_FEEDS: + # Assign entries = _fetch_feed(feed["url"]) entries = _fetch_feed(feed["url"]) + # Check: entries if entries: + # Assign feeds_ok = 1 feeds_ok += 1 + # Iterate over entries for entry in entries: + # Call all_entries.append() all_entries.append((feed["name"], entry)) + # Log info: "Fetched %d entries from %d/%d feeds", len(all_ent logger.info("Fetched %d entries from %d/%d feeds", len(all_entries), feeds_ok, len(RSS_FEEDS)) # 4. 
Match entries to techniques new_items = 0 + # Assign duplicates_skipped = 0 duplicates_skipped = 0 + # Assign techniques_flagged = set() techniques_flagged: set[str] = set() + # Iterate over techniques for technique in techniques: id_patterns, name_patterns = _build_patterns(technique) + # Iterate over all_entries for feed_name, entry in all_entries: if not _entry_matches(entry, id_patterns, name_patterns): continue @@ -213,45 +287,69 @@ def scan_intel(db: Session) -> dict: if not entry.get("title", "").strip(): continue + # Assign url = entry.get("link", "").strip() url = entry.get("link", "").strip() + # Check: not url if not url: + # Skip to the next loop iteration continue # Dedup if url in existing_urls: + # Assign duplicates_skipped = 1 duplicates_skipped += 1 + # Skip to the next loop iteration continue # Create IntelItem intel_item = IntelItem( + # Keyword argument: technique_id technique_id=technique.id, + # Keyword argument: url url=url, + # Keyword argument: title title=entry.get("title", "")[:500], + # Keyword argument: source source=feed_name, + # Keyword argument: detected_at detected_at=datetime.utcnow(), + # Keyword argument: reviewed reviewed=False, ) + # Stage new record(s) for database insertion db.add(intel_item) + # Call existing_urls.add() existing_urls.add(url) + # Assign new_items = 1 new_items += 1 # Flag technique for review if not technique.review_required: + # Assign technique.review_required = True technique.review_required = True + # Call techniques_flagged.add() techniques_flagged.add(technique.mitre_id) # 5. 
Single commit db.commit() + # Assign summary = { summary = { + # Literal argument value "new_items": new_items, + # Literal argument value "duplicates_skipped": duplicates_skipped, + # Literal argument value "techniques_flagged": len(techniques_flagged), + # Literal argument value "feeds_checked": feeds_ok, } + # Log info: logger.info( + # Literal argument value "Intel scan complete — new=%d, duplicates_skipped=%d, " + # Literal argument value "techniques_flagged=%d, feeds_checked=%d", new_items, duplicates_skipped, len(techniques_flagged), feeds_ok, ) @@ -259,12 +357,19 @@ def scan_intel(db: Session) -> dict: # 6. Audit log log_action( db, + # Keyword argument: user_id user_id=None, + # Keyword argument: action action="intel_scan", + # Keyword argument: entity_type entity_type="intel_item", + # Keyword argument: entity_id entity_id=None, + # Keyword argument: details details=summary, ) + # Commit all pending changes to the database db.commit() + # Return summary return summary diff --git a/backend/app/services/jira_service.py b/backend/app/services/jira_service.py index fe1819e..ff96312 100644 --- a/backend/app/services/jira_service.py +++ b/backend/app/services/jira_service.py @@ -28,22 +28,44 @@ creates the Jira ticket and stores the link. 
from __future__ import annotations +# Import logging import logging + +# Import datetime from datetime from datetime import datetime -from typing import Optional + +# Import Any, Optional from typing +from typing import Any, Optional + +# Import UUID from uuid from uuid import UUID +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import settings from app.config from app.config import settings + +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError + +# Import InvalidOperationError from app.domain.exceptions from app.domain.exceptions import InvalidOperationError + +# Import Campaign from app.models.campaign from app.models.campaign import Campaign + +# Import JiraLink, JiraLinkEntityType, JiraSyncDirection from app.models.jira_link from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection + +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import Test from app.models.test from app.models.test import Test from app.models.user import User +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -624,6 +646,7 @@ def get_jira_client(): Prefer ``get_user_jira_client()`` for new code. 
""" if not settings.JIRA_ENABLED: + # Raise InvalidOperationError raise InvalidOperationError("Jira integration is not enabled") if not settings.JIRA_URL or not settings.JIRA_USERNAME or not settings.JIRA_API_TOKEN: raise InvalidOperationError( @@ -639,75 +662,121 @@ def get_jira_client(): ) +# Define function search_jira_issues def search_jira_issues(query: str, max_results: int = 10) -> list[dict]: """Search Jira issues by JQL or free text (uses global credentials).""" jira = get_jira_client() + # Assign jql = query if "=" in query or "~" in query else f'summary ~ "{query}"' jql = query if "=" in query or "~" in query else f'summary ~ "{query}"' + # Assign results = jira.jql(jql, limit=max_results) results = jira.jql(jql, limit=max_results) + # Return [ return [ { + # Literal argument value "issue_key": issue["key"], + # Literal argument value "summary": issue["fields"]["summary"], + # Literal argument value "status": issue["fields"]["status"]["name"], + # Literal argument value "assignee": (issue["fields"].get("assignee") or {}).get("displayName"), + # Literal argument value "priority": (issue["fields"].get("priority") or {}).get("name"), } for issue in results.get("issues", []) ] +# Define function create_jira_issue def create_jira_issue( + # Entry: project_key project_key: str, + # Entry: summary summary: str, + # Entry: description description: str, + # Entry: issue_type issue_type: str = "Task", + # Entry: labels labels: Optional[list[str]] = None, + # Entry: custom_fields custom_fields: Optional[dict] = None, ) -> dict: """Create a Jira issue and return its key + id (uses global credentials).""" jira = get_jira_client() + # Assign fields = { fields: dict = { + # Literal argument value "project": {"key": project_key}, + # Literal argument value "summary": summary, + # Literal argument value "description": description, + # Literal argument value "issuetype": {"name": issue_type}, } + # Check: labels if labels: + # Assign fields["labels"] = labels 
fields["labels"] = labels + # Check: custom_fields if custom_fields: + # Call fields.update() fields.update(custom_fields) + # Assign result = jira.issue_create(fields=fields) result = jira.issue_create(fields=fields) + # Return {"issue_key": result["key"], "issue_id": result["id"]} return {"issue_key": result["key"], "issue_id": result["id"]} +# Define function sync_jira_to_aegis def sync_jira_to_aegis(db: Session, link: JiraLink) -> None: """Pull current status from Jira into the local link record (global creds).""" jira = get_jira_client() + # Assign issue = jira.issue(link.jira_issue_key) issue = jira.issue(link.jira_issue_key) + # Assign fields = issue.get("fields", {}) fields = issue.get("fields", {}) + # Assign link.jira_status = fields.get("status", {}).get("name") link.jira_status = fields.get("status", {}).get("name") + # Assign link.jira_priority = (fields.get("priority") or {}).get("name") link.jira_priority = (fields.get("priority") or {}).get("name") + # Assign link.jira_assignee = (fields.get("assignee") or {}).get("displayName") link.jira_assignee = (fields.get("assignee") or {}).get("displayName") + # Assign link.jira_story_points = str(fields.get("customfield_10016", "")) link.jira_story_points = str(fields.get("customfield_10016", "")) + # Assign link.last_synced_at = datetime.utcnow() link.last_synced_at = datetime.utcnow() + # Flush changes to DB without committing the transaction db.flush() +# Define function sync_aegis_to_jira def sync_aegis_to_jira(db: Session, link: JiraLink, entity_data: dict) -> None: """Push an Aegis status update as a Jira comment (global creds).""" jira = get_jira_client() + # Assign comment_body = _build_sync_comment(entity_data) comment_body = _build_sync_comment(entity_data) + # Call jira.issue_add_comment() jira.issue_add_comment(link.jira_issue_key, comment_body) + # Assign link.last_synced_at = datetime.utcnow() link.last_synced_at = datetime.utcnow() + # Flush changes to DB without committing the transaction 
db.flush() +# Define function _build_sync_comment def _build_sync_comment(data: dict) -> str: lines = ["h3. Aegis Sync Update", ""] + # Iterate over data.items() for key, value in data.items(): + # Call lines.append() lines.append(f"*{key}:* {value}") + # Call lines.append() lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_") + # Return "\n".join(lines) return "\n".join(lines) @@ -715,60 +784,94 @@ def _build_sync_comment(data: dict) -> str: def create_link( + # Entry: db db: Session, *, + # Entry: entity_type entity_type: JiraLinkEntityType, + # Entry: entity_id entity_id: UUID, + # Entry: jira_issue_key jira_issue_key: str, + # Entry: sync_direction sync_direction: JiraSyncDirection, + # Entry: created_by created_by: UUID, ) -> JiraLink: link = JiraLink( + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: entity_id entity_id=entity_id, + # Keyword argument: jira_issue_key jira_issue_key=jira_issue_key, + # Keyword argument: sync_direction sync_direction=sync_direction, + # Keyword argument: created_by created_by=created_by, ) + # Stage new record(s) for database insertion db.add(link) + # Flush changes to DB without committing the transaction db.flush() + # Check: settings.JIRA_ENABLED if settings.JIRA_ENABLED: + # Attempt the following; catch errors below try: + # Call sync_jira_to_aegis() sync_jira_to_aegis(db, link) + # Handle Exception except Exception as e: + # Log warning: "Initial Jira sync failed for %s: %s", jira_issue_ logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e) + # Return link return link +# Define function list_links def list_links( + # Entry: db db: Session, *, + # Entry: entity_type entity_type: Optional[JiraLinkEntityType] = None, + # Entry: entity_id entity_id: Optional[UUID] = None, entity_ids: Optional[list[UUID]] = None, ) -> list[JiraLink]: query = db.query(JiraLink) + # Check: entity_type if entity_type: + # Assign query = query.filter(JiraLink.entity_type == entity_type) 
query = query.filter(JiraLink.entity_type == entity_type) + # Check: entity_id if entity_id: + # Assign query = query.filter(JiraLink.entity_id == entity_id) query = query.filter(JiraLink.entity_id == entity_id) elif entity_ids: query = query.filter(JiraLink.entity_id.in_(entity_ids)) return query.order_by(JiraLink.created_at.desc()).all() +# Define function get_link_or_raise def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink: link = db.query(JiraLink).filter(JiraLink.id == link_id).first() + # Check: not link if not link: + # Raise EntityNotFoundError raise EntityNotFoundError("JiraLink", str(link_id)) + # Return link return link +# Define function delete_link def delete_link(db: Session, link_id: UUID) -> JiraLink: link = get_link_or_raise(db, link_id) + # Mark record for deletion on next commit db.delete(link) + # Return link return link @@ -776,43 +879,64 @@ def build_issue_data( db: Session, entity_type: JiraLinkEntityType, entity_id: UUID ) -> tuple[str, str]: """Build Jira issue summary and description from an Aegis entity.""" + # Check: entity_type == JiraLinkEntityType.test if entity_type == JiraLinkEntityType.test: + # Assign entity = db.query(Test).filter(Test.id == entity_id).first() entity = db.query(Test).filter(Test.id == entity_id).first() + # Check: not entity if not entity: + # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(entity_id)) technique = db.query(Technique).filter(Technique.id == entity.technique_id).first() return ( f"[Aegis] {technique.mitre_id if technique else 'N/A'} — {entity.name}", _build_test_description(entity, technique), ) + # Alternative: entity_type == JiraLinkEntityType.campaign elif entity_type == JiraLinkEntityType.campaign: + # Assign entity = db.query(Campaign).filter(Campaign.id == entity_id).first() entity = db.query(Campaign).filter(Campaign.id == entity_id).first() + # Check: not entity if not entity: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", str(entity_id)) + # 
Return ( return ( f"[Aegis Campaign] {entity.name}", f"Campaign: {entity.name}\nType: {entity.type}\nStatus: {entity.status}\n" f"Description: {entity.description or 'N/A'}", ) + # Alternative: entity_type == JiraLinkEntityType.technique elif entity_type == JiraLinkEntityType.technique: + # Assign entity = db.query(Technique).filter(Technique.id == entity_id).first() entity = db.query(Technique).filter(Technique.id == entity_id).first() + # Check: not entity if not entity: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", str(entity_id)) + # Return ( return ( f"[Aegis Technique] {entity.mitre_id} - {entity.name}", f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n" f"Tactic: {entity.tactic or 'N/A'}\n" f"Description: {entity.description or 'N/A'}", ) + # Fallback: handle remaining cases else: + # Return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}" return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}" +# Define function create_issue_and_link def create_issue_and_link( + # Entry: db db: Session, *, + # Entry: entity_type entity_type: JiraLinkEntityType, + # Entry: entity_id entity_id: UUID, + # Entry: created_by created_by: UUID, ) -> dict: """Create a Jira issue from an Aegis entity and link them (global creds).""" @@ -821,16 +945,25 @@ def create_issue_and_link( result = create_jira_issue( project_key=project_key, summary=summary, + # Keyword argument: description description=description, + # Keyword argument: labels labels=["aegis", entity_type.value], ) + # Assign link = JiraLink( link = JiraLink( + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: entity_id entity_id=entity_id, + # Keyword argument: jira_issue_key jira_issue_key=result["issue_key"], + # Keyword argument: jira_issue_id jira_issue_id=result["issue_id"], jira_project_key=project_key, created_by=created_by, ) + # Stage new record(s) for database insertion db.add(link) + # Return {"issue_key": 
result["issue_key"], "link_id": str(link.id)} return {"issue_key": result["issue_key"], "link_id": str(link.id)} diff --git a/backend/app/services/lolbas_import_service.py b/backend/app/services/lolbas_import_service.py index bac41c0..d4d80b4 100644 --- a/backend/app/services/lolbas_import_service.py +++ b/backend/app/services/lolbas_import_service.py @@ -24,24 +24,45 @@ Deduplication keys: - GTFOBins: ``source + binary_name + function`` → stored in ``atomic_test_id`` """ +# Import io import io + +# Import logging import logging + +# Import re import re + +# Import shutil import shutil + +# Import tempfile import tempfile + +# Import zipfile import zipfile + +# Import datetime from datetime from datetime import datetime + +# Import Path from pathlib from pathlib import Path +# Import requests import requests as _requests + +# Import yaml import yaml + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session -from app.models.test_template import TestTemplate +# Import DataSource from app.models.data_source from app.models.data_source import DataSource from app.models.technique import Technique from app.services.audit_service import log_action +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -49,34 +70,57 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- LOLBAS_ZIP_URL = ( + # Literal argument value "https://github.com/LOLBAS-Project/LOLBAS" + # Literal argument value "/archive/refs/heads/master.zip" ) +# Assign GTFOBINS_ZIP_URL = ( GTFOBINS_ZIP_URL = ( + # Literal argument value "https://github.com/GTFOBins/GTFOBins.github.io" + # Literal argument value "/archive/refs/heads/master.zip" ) +# Assign _DOWNLOAD_TIMEOUT = 300 _DOWNLOAD_TIMEOUT = 300 # GTFOBins function → MITRE technique mapping _GTFOBINS_FUNCTION_MAP: dict[str, str] = { + # Literal argument value "shell": 
"T1059", + # Literal argument value "command": "T1059", + # Literal argument value "reverse-shell": "T1059", + # Literal argument value "non-interactive-reverse-shell": "T1059", + # Literal argument value "bind-shell": "T1059", + # Literal argument value "non-interactive-bind-shell": "T1059", + # Literal argument value "file-upload": "T1105", + # Literal argument value "file-download": "T1105", + # Literal argument value "upload": "T1105", + # Literal argument value "download": "T1105", + # Literal argument value "file-write": "T1105", + # Literal argument value "file-read": "T1005", + # Literal argument value "library-load": "T1129", + # Literal argument value "sudo": "T1548.003", + # Literal argument value "suid": "T1548.001", + # Literal argument value "capabilities": "T1548", + # Literal argument value "limited-suid": "T1548.001", } @@ -88,18 +132,28 @@ _GTFOBINS_FUNCTION_MAP: dict[str, str] = { def _download_zip(url: str) -> bytes: """Download a ZIP from *url* and return raw bytes.""" + # Log info: "Downloading ZIP from %s …", url logger.info("Downloading ZIP from %s …", url) + # Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) + # Call resp.raise_for_status() resp.raise_for_status() + # Assign content = resp.content content = resp.content + # Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024 logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024)) + # Return content return content +# Define function _extract_zip def _extract_zip(zip_bytes: bytes, dest: str) -> Path: """Extract *zip_bytes* into *dest* and return the root directory.""" + # Open context manager with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: + # Call zf.extractall() zf.extractall(dest) + # Return Path(dest) return Path(dest) @@ -110,83 +164,141 @@ def _extract_zip(zip_bytes: bytes, dest: str) -> Path: def _parse_lolbas(root_dir: Path) -> list[dict]: """Parse LOLBAS YAML files and return 
template dicts.""" + # Assign results = [] results: list[dict] = [] + # Assign lolbas_root = root_dir / "LOLBAS-master" lolbas_root = root_dir / "LOLBAS-master" + # Assign yaml_dirs = [ yaml_dirs = [ lolbas_root / "yml" / "OSBinaries", lolbas_root / "yml" / "OSLibraries", lolbas_root / "yml" / "OSScripts", ] + # Assign yaml_files = [] yaml_files = [] + # Iterate over yaml_dirs for d in yaml_dirs: + # Check: d.is_dir() if d.is_dir(): + # Call yaml_files.extend() yaml_files.extend(sorted(d.rglob("*.yml"))) + # Log info: "LOLBAS: Found %d YAML files", len(yaml_files logger.info("LOLBAS: Found %d YAML files", len(yaml_files)) + # Iterate over yaml_files for yaml_path in yaml_files: + # Attempt the following; catch errors below try: + # Open context manager with open(yaml_path, "r", encoding="utf-8") as fh: + # Assign data = yaml.safe_load(fh) data = yaml.safe_load(fh) + # Handle Exception except Exception as exc: + # Log debug: "Failed to parse %s: %s", yaml_path, exc logger.debug("Failed to parse %s: %s", yaml_path, exc) + # Skip to the next loop iteration continue + # Check: not isinstance(data, dict) if not isinstance(data, dict): + # Skip to the next loop iteration continue + # Assign binary_name = data.get("Name", "").strip() binary_name = data.get("Name", "").strip() + # Check: not binary_name if not binary_name: + # Skip to the next loop iteration continue + # Assign description = data.get("Description", "") description = data.get("Description", "") + # Assign commands = data.get("Commands", []) commands = data.get("Commands", []) + # Check: not isinstance(commands, list) if not isinstance(commands, list): + # Skip to the next loop iteration continue + # Iterate over commands for cmd_entry in commands: + # Check: not isinstance(cmd_entry, dict) if not isinstance(cmd_entry, dict): + # Skip to the next loop iteration continue + # Assign mitre_id = cmd_entry.get("MitreID") mitre_id = cmd_entry.get("MitreID") + # Check: not mitre_id if not mitre_id: + # Skip to the 
next loop iteration continue # Normalise the MITRE ID mitre_id = str(mitre_id).strip().upper() + # Check: not mitre_id.startswith("T") if not mitre_id.startswith("T"): + # Skip to the next loop iteration continue + # Assign command = cmd_entry.get("Command", "") command = cmd_entry.get("Command", "") + # Assign usecase = cmd_entry.get("Usecase", "") usecase = cmd_entry.get("Usecase", "") + # Assign cmd_description = cmd_entry.get("Description", "") cmd_description = cmd_entry.get("Description", "") # Dedup key dedup_key = f"lolbas:{binary_name}:{mitre_id}" + # Assign procedure = [] procedure = [] + # Check: cmd_description if cmd_description: + # Call procedure.append() procedure.append(f"Description: {cmd_description}") + # Check: usecase if usecase: + # Call procedure.append() procedure.append(f"Use case: {usecase}") + # Check: command if command: + # Call procedure.append() procedure.append(f"Command: {command}") + # Call results.append() results.append({ + # Literal argument value "mitre_technique_id": mitre_id, + # Literal argument value "name": f"LOLBAS: {binary_name} — {usecase or cmd_description or mitre_id}"[:500], - "description": f"{description}\n\n{cmd_description}".strip()[:2000] if description else cmd_description[:2000] if cmd_description else None, + # Literal argument value + "description": ( + f"{description}\n\n{cmd_description}".strip()[:2000] + if description + else cmd_description[:2000] if cmd_description else None + ), + # Literal argument value "source": "lolbas", + # Literal argument value "platform": "windows", + # Literal argument value "tool_suggested": binary_name, + # Literal argument value "attack_procedure": "\n".join(procedure)[:4000] if procedure else None, + # Literal argument value "atomic_test_id": dedup_key, + # Literal argument value "source_url": f"https://lolbas-project.github.io/lolbas/Binaries/{binary_name}/", }) + # Log info: "LOLBAS: Parsed %d templates", len(results logger.info("LOLBAS: Parsed %d templates", 
len(results)) + # Return results return results @@ -197,85 +309,138 @@ def _parse_lolbas(root_dir: Path) -> list[dict]: def _parse_gtfobins(root_dir: Path) -> list[dict]: """Parse GTFOBins markdown files and return template dicts.""" + # Assign results = [] results: list[dict] = [] + # Assign gtfobins_root = root_dir / "GTFOBins.github.io-master" / "_gtfobins" gtfobins_root = root_dir / "GTFOBins.github.io-master" / "_gtfobins" + # Check: not gtfobins_root.is_dir() if not gtfobins_root.is_dir(): + # Log warning: "GTFOBins directory not found at %s", gtfobins_roo logger.warning("GTFOBins directory not found at %s", gtfobins_root) + # Return results return results + # Assign md_files = sorted( md_files = sorted( f for f in gtfobins_root.iterdir() if f.is_file() and f.suffix in (".md", "") ) + # Log info: "GTFOBins: Found %d files", len(md_files logger.info("GTFOBins: Found %d files", len(md_files)) + # Iterate over md_files for md_path in md_files: + # Assign binary_name = md_path.stem # e.g. "awk" binary_name = md_path.stem # e.g. 
"awk" + # Attempt the following; catch errors below try: + # Open context manager with open(md_path, "r", encoding="utf-8") as fh: + # Assign content = fh.read() content = fh.read() + # Handle Exception except Exception as exc: + # Log debug: "Failed to read %s: %s", md_path, exc logger.debug("Failed to read %s: %s", md_path, exc) + # Skip to the next loop iteration continue # Extract YAML front-matter front_matter = _extract_front_matter(content) + # Check: not front_matter if not front_matter: + # Skip to the next loop iteration continue + # Assign functions = front_matter.get("functions", {}) functions = front_matter.get("functions", {}) + # Check: not isinstance(functions, dict) if not isinstance(functions, dict): + # Skip to the next loop iteration continue + # Iterate over functions.items() for func_name, func_data in functions.items(): # Map function to MITRE technique mitre_id = _GTFOBINS_FUNCTION_MAP.get(func_name.lower()) + # Check: not mitre_id if not mitre_id: + # Skip to the next loop iteration continue # Extract code examples from function data examples = [] + # Check: isinstance(func_data, list) if isinstance(func_data, list): + # Iterate over func_data for entry in func_data: + # Check: isinstance(entry, dict) if isinstance(entry, dict): + # Assign code = entry.get("code", "") code = entry.get("code", "") + # Check: code if code: + # Call examples.append() examples.append(str(code)) + # Alternative: isinstance(entry, str) elif isinstance(entry, str): + # Call examples.append() examples.append(entry) + # Assign procedure = "\n\n".join(examples) if examples else None procedure = "\n\n".join(examples) if examples else None + # Assign dedup_key = f"gtfobins:{binary_name}:{func_name}" dedup_key = f"gtfobins:{binary_name}:{func_name}" + # Call results.append() results.append({ + # Literal argument value "mitre_technique_id": mitre_id, + # Literal argument value "name": f"GTFOBins: {binary_name} — {func_name}"[:500], + # Literal argument value 
"description": f"Abuse {binary_name} binary for {func_name} on Linux/Unix."[:2000], + # Literal argument value "source": "gtfobins", + # Literal argument value "platform": "linux", + # Literal argument value "tool_suggested": binary_name, + # Literal argument value "attack_procedure": procedure[:4000] if procedure else None, + # Literal argument value "atomic_test_id": dedup_key, + # Literal argument value "source_url": f"https://gtfobins.github.io/gtfobins/{binary_name}/", }) + # Log info: "GTFOBins: Parsed %d templates", len(results logger.info("GTFOBins: Parsed %d templates", len(results)) + # Return results return results +# Define function _extract_front_matter def _extract_front_matter(content: str) -> dict | None: """Extract YAML front-matter from a markdown/GTFOBins file. Supports both ``---/---`` (standard front-matter) and ``---/...`` (YAML document-end marker used by GTFOBins). """ + # Assign match = re.match(r"^---\s*\n(.*?)\n(?:---|\.\.\.)", content, re.DOTALL) match = re.match(r"^---\s*\n(.*?)\n(?:---|\.\.\.)", content, re.DOTALL) + # Check: not match if not match: + # Return None return None + # Attempt the following; catch errors below try: + # Return yaml.safe_load(match.group(1)) return yaml.safe_load(match.group(1)) + # Handle Exception except Exception: + # Return None return None @@ -286,36 +451,59 @@ def _extract_front_matter(content: str) -> dict | None: def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict: """Insert templates, skipping existing ones by atomic_test_id.""" + # Assign existing_ids = { existing_ids: set[str] = { row[0] for row in db.query(TestTemplate.atomic_test_id) + # Chain .filter() call .filter(TestTemplate.source == source_name) + # Chain .filter() call .filter(TestTemplate.atomic_test_id.isnot(None)) + # Chain .all() call .all() } + # Assign created = 0 created = 0 + # Assign skipped = 0 skipped = 0 new_technique_ids: set[str] = set() + # Iterate over items for item in items: + # Check: 
item["atomic_test_id"] in existing_ids if item["atomic_test_id"] in existing_ids: + # Assign skipped = 1 skipped += 1 + # Skip to the next loop iteration continue + # Assign template = TestTemplate( template = TestTemplate( + # Keyword argument: mitre_technique_id mitre_technique_id=item["mitre_technique_id"], + # Keyword argument: name name=item["name"], + # Keyword argument: description description=item["description"], + # Keyword argument: source source=item["source"], + # Keyword argument: source_url source_url=item.get("source_url"), + # Keyword argument: attack_procedure attack_procedure=item.get("attack_procedure"), + # Keyword argument: platform platform=item["platform"], + # Keyword argument: tool_suggested tool_suggested=item.get("tool_suggested"), + # Keyword argument: atomic_test_id atomic_test_id=item["atomic_test_id"], + # Keyword argument: is_active is_active=True, ) + # Stage new record(s) for database insertion db.add(template) + # Call existing_ids.add() existing_ids.add(item["atomic_test_id"]) new_technique_ids.add(item["mitre_technique_id"]) created += 1 @@ -326,6 +514,7 @@ def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict: ).update({"review_required": True}, synchronize_session=False) db.commit() + # Return {"created": created, "skipped_existing": skipped, "total_parsed": l... return {"created": created, "skipped_existing": skipped, "total_parsed": len(items)} @@ -339,56 +528,93 @@ def sync(db: Session) -> dict: Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``. 
""" + # Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_lolbas_") tmp_dir = tempfile.mkdtemp(prefix="aegis_lolbas_") + # Attempt the following; catch errors below try: + # Assign zip_bytes = _download_zip(LOLBAS_ZIP_URL) zip_bytes = _download_zip(LOLBAS_ZIP_URL) + # Assign root_dir = _extract_zip(zip_bytes, tmp_dir) root_dir = _extract_zip(zip_bytes, tmp_dir) + # Assign parsed = _parse_lolbas(root_dir) parsed = _parse_lolbas(root_dir) + # Always execute this cleanup block finally: + # Call shutil.rmtree() shutil.rmtree(tmp_dir, ignore_errors=True) + # Assign summary = _upsert_templates(db, parsed, "lolbas") summary = _upsert_templates(db, parsed, "lolbas") # Update DataSource record ds = db.query(DataSource).filter(DataSource.name == "lolbas").first() + # Check: ds if ds: + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_status = "success" ds.last_sync_status = "success" + # Assign ds.last_sync_stats = summary ds.last_sync_stats = summary + # Commit all pending changes to the database db.commit() + # Log info: "LOLBAS import complete — %s", summary logger.info("LOLBAS import complete — %s", summary) + # Call log_action() log_action(db, user_id=None, action="import_lolbas", + # Keyword argument: entity_type entity_type="test_template", entity_id=None, details=summary) + # Commit all pending changes to the database db.commit() + # Return summary return summary +# Define function sync_gtfobins def sync_gtfobins(db: Session) -> dict: """Import GTFOBins templates. Returns a summary dict with ``created``, ``skipped_existing``, ``total_parsed``. 
""" + # Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_gtfobins_") tmp_dir = tempfile.mkdtemp(prefix="aegis_gtfobins_") + # Attempt the following; catch errors below try: + # Assign zip_bytes = _download_zip(GTFOBINS_ZIP_URL) zip_bytes = _download_zip(GTFOBINS_ZIP_URL) + # Assign root_dir = _extract_zip(zip_bytes, tmp_dir) root_dir = _extract_zip(zip_bytes, tmp_dir) + # Assign parsed = _parse_gtfobins(root_dir) parsed = _parse_gtfobins(root_dir) + # Always execute this cleanup block finally: + # Call shutil.rmtree() shutil.rmtree(tmp_dir, ignore_errors=True) + # Assign summary = _upsert_templates(db, parsed, "gtfobins") summary = _upsert_templates(db, parsed, "gtfobins") # Update DataSource record ds = db.query(DataSource).filter(DataSource.name == "gtfobins").first() + # Check: ds if ds: + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_status = "success" ds.last_sync_status = "success" + # Assign ds.last_sync_stats = summary ds.last_sync_stats = summary + # Commit all pending changes to the database db.commit() + # Log info: "GTFOBins import complete — %s", summary logger.info("GTFOBins import complete — %s", summary) + # Call log_action() log_action(db, user_id=None, action="import_gtfobins", + # Keyword argument: entity_type entity_type="test_template", entity_id=None, details=summary) + # Commit all pending changes to the database db.commit() + # Return summary return summary diff --git a/backend/app/services/metrics_query_service.py b/backend/app/services/metrics_query_service.py index cb7f711..debf38a 100644 --- a/backend/app/services/metrics_query_service.py +++ b/backend/app/services/metrics_query_service.py @@ -7,16 +7,28 @@ of MITRE ATT&CK technique coverage for dashboards and reporting. This module is framework-agnostic: no FastAPI imports. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import defaultdict from collections from collections import defaultdict +# Import func from sqlalchemy from sqlalchemy import func + +# Import Session, joinedload from sqlalchemy.orm from sqlalchemy.orm import Session, joinedload +# Import TechniqueStatus, TestState from app.models.enums from app.models.enums import TechniqueStatus, TestState + +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import Test from app.models.test from app.models.test import Test + +# Import from app.schemas.metrics from app.schemas.metrics import ( CoverageSummary, RecentTestItem, @@ -27,40 +39,60 @@ from app.schemas.metrics import ( ) +# Define function get_coverage_summary def get_coverage_summary(db: Session) -> CoverageSummary: """Return a global coverage summary across all techniques.""" + # Assign rows = ( rows = ( db.query( Technique.status_global, func.count(Technique.id).label("cnt"), ) + # Chain .group_by() call .group_by(Technique.status_global) + # Chain .all() call .all() ) + # Assign counts = {s.value: 0 for s in TechniqueStatus} counts: dict[str, int] = {s.value: 0 for s in TechniqueStatus} + # Iterate over rows for status, cnt in rows: + # Assign counts[status.value] = cnt counts[status.value] = cnt + # Assign total = sum(counts.values()) total = sum(counts.values()) + # Assign validated = counts["validated"] validated = counts["validated"] + # Assign partial = counts["partial"] partial = counts["partial"] + # Assign coverage_pct = ( coverage_pct = ( round((validated + partial) / total * 100, 2) if total > 0 else 0.0 ) + # Return CoverageSummary( return CoverageSummary( + # Keyword argument: total_techniques total_techniques=total, + # Keyword argument: validated validated=validated, + # Keyword argument: partial partial=partial, + # Keyword argument: not_covered not_covered=counts["not_covered"], + # Keyword argument: in_progress 
in_progress=counts["in_progress"], + # Keyword argument: not_evaluated not_evaluated=counts["not_evaluated"], + # Keyword argument: coverage_percentage coverage_percentage=coverage_pct, ) +# Define function get_coverage_by_tactic def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]: """Return coverage breakdown grouped by tactic. @@ -68,6 +100,7 @@ def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]: comma-separated string), the technique is counted once per tactic it belongs to. """ + # Assign techniques = db.query( techniques = db.query( Technique.tactic, Technique.status_global ).all() @@ -75,179 +108,270 @@ def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]: # Accumulate per-tactic counters. A technique with tactic # "persistence, privilege-escalation" is counted in both. tactic_data: dict[str, dict[str, int]] = defaultdict( + # Entry: lambda lambda: {s.value: 0 for s in TechniqueStatus} ) + # Iterate over techniques for tactic_str, status in techniques: + # Check: not tactic_str if not tactic_str: + # Assign tactics = ["unknown"] tactics = ["unknown"] + # Fallback: handle remaining cases else: + # Assign tactics = [t.strip() for t in tactic_str.split(",")] tactics = [t.strip() for t in tactic_str.split(",")] + # Iterate over tactics for tactic in tactics: + # Assign tactic_data[tactic][status.value] = 1 tactic_data[tactic][status.value] += 1 + # Assign result = [] result = [] + # Iterate over sorted(tactic_data) for tactic in sorted(tactic_data): + # Assign counts = tactic_data[tactic] counts = tactic_data[tactic] + # Assign total = sum(counts.values()) total = sum(counts.values()) + # Call result.append() result.append( TacticCoverage( + # Keyword argument: tactic tactic=tactic, + # Keyword argument: total total=total, + # Keyword argument: validated validated=counts["validated"], + # Keyword argument: partial partial=counts["partial"], + # Keyword argument: not_covered not_covered=counts["not_covered"], + # Keyword 
argument: not_evaluated not_evaluated=counts["not_evaluated"], + # Keyword argument: in_progress in_progress=counts["in_progress"], ) ) + # Return result return result +# Define function get_test_pipeline_counts def get_test_pipeline_counts(db: Session) -> TestPipelineCounts: """Return how many tests are in each pipeline state.""" + # Assign rows = ( rows = ( db.query(Test.state, func.count(Test.id).label("cnt")) + # Chain .group_by() call .group_by(Test.state) + # Chain .all() call .all() ) + # Assign state_counts = {s.value: 0 for s in TestState} state_counts: dict[str, int] = {s.value: 0 for s in TestState} + # Iterate over rows for state, cnt in rows: + # Assign state_counts[state.value] = cnt state_counts[state.value] = cnt + # Assign total = sum(state_counts.values()) total = sum(state_counts.values()) + # Return TestPipelineCounts( return TestPipelineCounts( + # Keyword argument: draft draft=state_counts["draft"], + # Keyword argument: red_executing red_executing=state_counts["red_executing"], + # Keyword argument: blue_evaluating blue_evaluating=state_counts["blue_evaluating"], + # Keyword argument: in_review in_review=state_counts["in_review"], + # Keyword argument: validated validated=state_counts["validated"], + # Keyword argument: rejected rejected=state_counts["rejected"], + # Keyword argument: total total=total, ) +# Define function get_team_activity def get_team_activity(db: Session) -> list[TeamActivity]: """Return activity summary for Red and Blue teams.""" # Red Team: completed = tests past red_executing; pending = draft + red_executing red_completed = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state.in_([ TestState.blue_evaluating, TestState.in_review, TestState.validated, TestState.rejected, ])) + # Chain .scalar() call .scalar() ) or 0 + # Assign red_pending = ( red_pending = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state.in_([TestState.draft, TestState.red_executing])) + # Chain .scalar() 
call .scalar() ) or 0 # Blue Team: completed = tests past blue_evaluating; pending = blue_evaluating blue_completed = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state.in_([ TestState.in_review, TestState.validated, TestState.rejected, ])) + # Chain .scalar() call .scalar() ) or 0 + # Assign blue_pending = ( blue_pending = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state == TestState.blue_evaluating) + # Chain .scalar() call .scalar() ) or 0 + # Return [ return [ TeamActivity( + # Keyword argument: team team="Red Team", + # Keyword argument: tests_completed tests_completed=red_completed, + # Keyword argument: tests_pending tests_pending=red_pending, ), TeamActivity( + # Keyword argument: team team="Blue Team", + # Keyword argument: tests_completed tests_completed=blue_completed, + # Keyword argument: tests_pending tests_pending=blue_pending, ), ] +# Define function get_validation_rate def get_validation_rate(db: Session) -> list[ValidationRate]: """Return approval and rejection rates for Red Lead and Blue Lead.""" # Red Lead validations red_approved = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.red_validation_status == "approved") + # Chain .scalar() call .scalar() ) or 0 + # Assign red_rejected = ( red_rejected = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.red_validation_status == "rejected") + # Chain .scalar() call .scalar() ) or 0 + # Assign red_total = red_approved + red_rejected red_total = red_approved + red_rejected + # Assign red_rate = round(red_approved / red_total * 100, 1) if red_total > 0 else 0.0 red_rate = round(red_approved / red_total * 100, 1) if red_total > 0 else 0.0 # Blue Lead validations blue_approved = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.blue_validation_status == "approved") + # Chain .scalar() call .scalar() ) or 0 + # Assign blue_rejected = ( blue_rejected = ( db.query(func.count(Test.id)) + # Chain 
.filter() call .filter(Test.blue_validation_status == "rejected") + # Chain .scalar() call .scalar() ) or 0 + # Assign blue_total = blue_approved + blue_rejected blue_total = blue_approved + blue_rejected + # Assign blue_rate = round(blue_approved / blue_total * 100, 1) if blue_total > 0 else 0.0 blue_rate = round(blue_approved / blue_total * 100, 1) if blue_total > 0 else 0.0 + # Return [ return [ ValidationRate( + # Keyword argument: role role="red_lead", + # Keyword argument: total_reviewed total_reviewed=red_total, + # Keyword argument: approved approved=red_approved, + # Keyword argument: rejected rejected=red_rejected, + # Keyword argument: approval_rate approval_rate=red_rate, ), ValidationRate( + # Keyword argument: role role="blue_lead", + # Keyword argument: total_reviewed total_reviewed=blue_total, + # Keyword argument: approved approved=blue_approved, + # Keyword argument: rejected rejected=blue_rejected, + # Keyword argument: approval_rate approval_rate=blue_rate, ), ] +# Define function get_recent_tests def get_recent_tests(db: Session, *, limit: int = 10) -> list[RecentTestItem]: """Return the most recently created tests.""" from sqlalchemy import nullslast tests = ( db.query(Test) + # Chain .options() call .options(joinedload(Test.technique)) .order_by(nullslast(Test.created_at.desc())) .limit(limit) + # Chain .all() call .all() ) + # Return [ return [ RecentTestItem( + # Keyword argument: id id=str(t.id), + # Keyword argument: name name=t.name, + # Keyword argument: state state=t.state.value, + # Keyword argument: technique_mitre_id technique_mitre_id=t.technique.mitre_id if t.technique else None, + # Keyword argument: technique_name technique_name=t.technique.name if t.technique else None, + # Keyword argument: created_at created_at=t.created_at, ) for t in tests diff --git a/backend/app/services/mitre_sync_service.py b/backend/app/services/mitre_sync_service.py index 73335a5..aa4fe98 100644 --- a/backend/app/services/mitre_sync_service.py +++ 
b/backend/app/services/mitre_sync_service.py @@ -6,123 +6,197 @@ ATT&CK collection, and upserts attack-pattern objects into the local when the TAXII server is unreachable. """ +# Import logging import logging + +# Import datetime from datetime from datetime import datetime +# Import requests import requests as _requests + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session + +# Import Server as TaxiiServer from taxii2client.v20 from taxii2client.v20 import Server as TaxiiServer -from app.models.technique import Technique +# Import TechniqueStatus from app.models.enums from app.models.enums import TechniqueStatus + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import log_action from app.services.audit_service from app.services.audit_service import log_action +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign TAXII_SERVER_URL = "https://cti-taxii.mitre.org/taxii/" TAXII_SERVER_URL = "https://cti-taxii.mitre.org/taxii/" +# Assign MITRE_SOURCE_NAME = "mitre-attack" MITRE_SOURCE_NAME = "mitre-attack" +# Assign GITHUB_ENTERPRISE_URL = ( GITHUB_ENTERPRISE_URL = ( + # Literal argument value "https://raw.githubusercontent.com/mitre/cti/master/" + # Literal argument value "enterprise-attack/enterprise-attack.json" ) +# Define function _extract_mitre_id def _extract_mitre_id(external_references: list) -> str | None: """Return the MITRE ATT&CK ID (e.g. 
``T1059.001``) from external_references.""" + # Check: not external_references if not external_references: + # Return None return None + # Iterate over external_references for ref in external_references: + # Check: ref.get("source_name") == MITRE_SOURCE_NAME if ref.get("source_name") == MITRE_SOURCE_NAME: + # Return ref.get("external_id") return ref.get("external_id") + # Return None return None +# Define function _extract_tactics def _extract_tactics(kill_chain_phases: list) -> str | None: """Return a comma-separated string of tactic phase names.""" + # Check: not kill_chain_phases if not kill_chain_phases: + # Return None return None + # Assign tactics = [ tactics = [ phase.get("phase_name") for phase in kill_chain_phases if phase.get("kill_chain_name") == "mitre-attack" ] + # Return ", ".join(tactics) if tactics else None return ", ".join(tactics) if tactics else None +# Define function _extract_platforms def _extract_platforms(stix_object: dict) -> list: """Return the list of platforms from the STIX object.""" + # Return stix_object.get("x_mitre_platforms", []) return stix_object.get("x_mitre_platforms", []) +# Define function _extract_version def _extract_version(stix_object: dict) -> str | None: """Return the MITRE ATT&CK version string.""" + # Return stix_object.get("x_mitre_version") return stix_object.get("x_mitre_version") +# Define function _extract_last_modified def _extract_last_modified(stix_object: dict) -> datetime | None: """Return the ``modified`` timestamp as a datetime, or None.""" + # Assign modified = stix_object.get("modified") modified = stix_object.get("modified") + # Check: modified is None if modified is None: + # Return None return None + # Check: isinstance(modified, datetime) if isinstance(modified, datetime): + # Return modified return modified + # Attempt the following; catch errors below try: + # Return datetime.fromisoformat(modified.replace("Z", "+00:00")) return datetime.fromisoformat(modified.replace("Z", "+00:00")) + # Handle 
(ValueError, AttributeError) except (ValueError, AttributeError): + # Return None return None +# Define function _fetch_attack_patterns_taxii def _fetch_attack_patterns_taxii() -> list[dict]: """Connect to the MITRE TAXII server and return all attack-pattern objects.""" + # Log info: "Connecting to MITRE TAXII server at %s", TAXII_SE logger.info("Connecting to MITRE TAXII server at %s", TAXII_SERVER_URL) + # Assign server = TaxiiServer(TAXII_SERVER_URL) server = TaxiiServer(TAXII_SERVER_URL) + # Assign api_root = server.api_roots[0] api_root = server.api_roots[0] + # Assign collection = api_root.collections[0] # Enterprise ATT&CK collection = api_root.collections[0] # Enterprise ATT&CK + # Log info: logger.info( + # Literal argument value "Fetching objects from collection '%s' (id=%s)", collection.title, collection.id, ) + # Assign bundle = collection.get_objects() bundle = collection.get_objects() + # Assign objects = bundle.get("objects", []) objects = bundle.get("objects", []) + # Assign attack_patterns = [ attack_patterns = [ obj for obj in objects if obj.get("type") == "attack-pattern" ] + # Log info: "Retrieved %d attack-pattern objects via TAXII", l logger.info("Retrieved %d attack-pattern objects via TAXII", len(attack_patterns)) + # Return attack_patterns return attack_patterns +# Define function _fetch_attack_patterns_github def _fetch_attack_patterns_github() -> list[dict]: """Fallback: fetch Enterprise ATT&CK bundle from the MITRE CTI GitHub repo.""" + # Log info: "Fetching Enterprise ATT&CK bundle from GitHub (%s logger.info("Fetching Enterprise ATT&CK bundle from GitHub (%s)", GITHUB_ENTERPRISE_URL) + # Assign resp = _requests.get(GITHUB_ENTERPRISE_URL, timeout=120) resp = _requests.get(GITHUB_ENTERPRISE_URL, timeout=120) + # Call resp.raise_for_status() resp.raise_for_status() + # Assign bundle = resp.json() bundle = resp.json() + # Assign objects = bundle.get("objects", []) objects = bundle.get("objects", []) + # Assign attack_patterns = [ 
attack_patterns = [ obj for obj in objects if obj.get("type") == "attack-pattern" ] + # Log info: "Retrieved %d attack-pattern objects via GitHub", logger.info("Retrieved %d attack-pattern objects via GitHub", len(attack_patterns)) + # Return attack_patterns return attack_patterns +# Define function _fetch_attack_patterns def _fetch_attack_patterns() -> list[dict]: """Return all attack-pattern objects, trying TAXII first then GitHub.""" + # Attempt the following; catch errors below try: + # Return _fetch_attack_patterns_taxii() return _fetch_attack_patterns_taxii() + # Handle Exception except Exception as exc: + # Log warning: logger.warning( + # Literal argument value "TAXII server unavailable (%s), falling back to GitHub mirror", exc, ) + # Return _fetch_attack_patterns_github() return _fetch_attack_patterns_github() +# Define function sync_mitre def sync_mitre(db: Session) -> dict: """Synchronize MITRE ATT&CK techniques into the local database. @@ -131,11 +205,12 @@ def sync_mitre(db: Session) -> dict: db : Session Active SQLAlchemy database session. - Returns + Returns: ------- dict Summary with keys ``created``, ``updated``, ``unchanged``, ``skipped``. 
""" + # Assign attack_patterns = _fetch_attack_patterns() attack_patterns = _fetch_attack_patterns() # Pre-load existing techniques keyed by mitre_id for fast lookup @@ -143,90 +218,149 @@ def sync_mitre(db: Session) -> dict: t.mitre_id: t for t in db.query(Technique).all() } + # Assign created = 0 created = 0 + # Assign updated = 0 updated = 0 + # Assign unchanged = 0 unchanged = 0 + # Assign skipped = 0 skipped = 0 + # Iterate over attack_patterns for obj in attack_patterns: # ------------------------------------------------------------------ # Skip revoked / deprecated objects # ------------------------------------------------------------------ if obj.get("revoked", False) or obj.get("x_mitre_deprecated", False): + # Assign skipped = 1 skipped += 1 + # Skip to the next loop iteration continue + # Assign mitre_id = _extract_mitre_id(obj.get("external_references", [])) mitre_id = _extract_mitre_id(obj.get("external_references", [])) + # Check: not mitre_id if not mitre_id: + # Assign skipped = 1 skipped += 1 + # Skip to the next loop iteration continue + # Assign name = obj.get("name", "") name = obj.get("name", "") + # Assign description = obj.get("description", "") description = obj.get("description", "") + # Assign tactic = _extract_tactics(obj.get("kill_chain_phases", [])) tactic = _extract_tactics(obj.get("kill_chain_phases", [])) + # Assign platforms = _extract_platforms(obj) platforms = _extract_platforms(obj) + # Assign version = _extract_version(obj) version = _extract_version(obj) + # Assign last_modified = _extract_last_modified(obj) last_modified = _extract_last_modified(obj) + # Assign is_subtechnique = "." in mitre_id is_subtechnique = "." 
in mitre_id + # Assign parent_mitre_id = mitre_id.split(".")[0] if is_subtechnique else None parent_mitre_id = mitre_id.split(".")[0] if is_subtechnique else None + # Assign existing = existing_techniques.get(mitre_id) existing = existing_techniques.get(mitre_id) + # Check: existing is None if existing is None: # ---- Create new technique ---- technique = Technique( + # Keyword argument: mitre_id mitre_id=mitre_id, + # Keyword argument: name name=name, + # Keyword argument: description description=description, + # Keyword argument: tactic tactic=tactic, + # Keyword argument: platforms platforms=platforms, + # Keyword argument: mitre_version mitre_version=version, + # Keyword argument: mitre_last_modified mitre_last_modified=last_modified, + # Keyword argument: is_subtechnique is_subtechnique=is_subtechnique, + # Keyword argument: parent_mitre_id parent_mitre_id=parent_mitre_id, + # Keyword argument: status_global status_global=TechniqueStatus.not_evaluated, + # Keyword argument: review_required review_required=False, ) + # Stage new record(s) for database insertion db.add(technique) + # Assign existing_techniques[mitre_id] = technique existing_techniques[mitre_id] = technique + # Assign created = 1 created += 1 + # Fallback: handle remaining cases else: # ---- Update if name or description changed ---- changes = False + # Check: existing.name != name if existing.name != name: + # Assign existing.name = name existing.name = name + # Assign changes = True changes = True + # Check: (existing.description or "") != (description or "") if (existing.description or "") != (description or ""): + # Assign existing.description = description existing.description = description + # Assign changes = True changes = True # Always keep metadata up-to-date (does not trigger review) existing.tactic = tactic + # Assign existing.platforms = platforms existing.platforms = platforms + # Assign existing.mitre_version = version existing.mitre_version = version + # Assign 
existing.mitre_last_modified = last_modified existing.mitre_last_modified = last_modified + # Assign existing.is_subtechnique = is_subtechnique existing.is_subtechnique = is_subtechnique + # Assign existing.parent_mitre_id = parent_mitre_id existing.parent_mitre_id = parent_mitre_id + # Check: changes if changes: + # Assign existing.review_required = True existing.review_required = True + # Assign updated = 1 updated += 1 + # Fallback: handle remaining cases else: + # Assign unchanged = 1 unchanged += 1 # Single commit for the whole batch db.commit() + # Assign summary = { summary = { + # Literal argument value "created": created, + # Literal argument value "updated": updated, + # Literal argument value "unchanged": unchanged, + # Literal argument value "skipped": skipped, } + # Log info: logger.info( + # Literal argument value "MITRE sync complete — created=%d, updated=%d, unchanged=%d, skipped=%d", created, updated, @@ -237,12 +371,19 @@ def sync_mitre(db: Session) -> dict: # Audit log (system action → user_id=None) log_action( db, + # Keyword argument: user_id user_id=None, + # Keyword argument: action action="mitre_sync", + # Keyword argument: entity_type entity_type="technique", + # Keyword argument: entity_id entity_id=None, + # Keyword argument: details details=summary, ) + # Commit all pending changes to the database db.commit() + # Return summary return summary diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index b78e7cc..c457b84 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -7,16 +7,29 @@ Functions in this module stage changes via ``db.add()`` / ``db.flush()`` but do **not** commit. The caller is responsible for committing. 
""" +# Import uuid import uuid + +# Import datetime, timedelta from datetime from datetime import datetime, timedelta -from sqlalchemy.orm import Session +# Import func from sqlalchemy from sqlalchemy import func -from app.domain.errors import EntityNotFoundError -from app.models.notification import Notification -from app.models.user import User +# Import Session from sqlalchemy.orm +from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.errors +from app.domain.errors import EntityNotFoundError + +# Import Notification from app.models.notification +from app.models.notification import Notification + +# Import Test from app.models.test +from app.models.test import Test + +# Import User from app.models.user +from app.models.user import User # --------------------------------------------------------------------------- # Core CRUD @@ -24,132 +37,209 @@ from app.models.user import User def list_notifications( + # Entry: db db: Session, + # Entry: user_id user_id: uuid.UUID, *, + # Entry: offset offset: int = 0, + # Entry: limit limit: int = 20, ) -> list[Notification]: """Return paginated notifications for a user, newest first.""" + # Return ( return ( db.query(Notification) + # Chain .filter() call .filter(Notification.user_id == user_id) + # Chain .order_by() call .order_by(Notification.created_at.desc()) + # Chain .offset() call .offset(offset) + # Chain .limit() call .limit(limit) + # Chain .all() call .all() ) +# Define function get_notification_or_raise def get_notification_or_raise( + # Entry: db db: Session, + # Entry: notification_id notification_id: uuid.UUID, + # Entry: user_id user_id: uuid.UUID, ) -> Notification: """Fetch a notification by ID and user, or raise EntityNotFoundError.""" + # Assign notif = ( notif = ( db.query(Notification) + # Chain .filter() call .filter( Notification.id == notification_id, Notification.user_id == user_id, ) + # Chain .first() call .first() ) + # Check: notif is None if notif is None: + # Raise 
EntityNotFoundError raise EntityNotFoundError("Notification", str(notification_id)) + # Return notif return notif +# Define function notify_role def notify_role( + # Entry: db db: Session, *, + # Entry: role role: str, + # Entry: type type: str, + # Entry: title title: str, + # Entry: message message: str, + # Entry: entity_type entity_type: str, + # Entry: entity_id entity_id: uuid.UUID, ) -> None: """Send notifications to all active users with a given role.""" + # Assign users = ( users = ( db.query(User) + # Chain .filter() call .filter(User.role == role, User.is_active == True) # noqa: E712 + # Chain .all() call .all() ) + # Iterate over users for user in users: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: type type=type, + # Keyword argument: title title=title, + # Keyword argument: message message=message, + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: entity_id entity_id=entity_id, ) +# Define function create_notification def create_notification( + # Entry: db db: Session, + # Entry: user_id user_id: uuid.UUID, + # Entry: type type: str, + # Entry: title title: str, + # Entry: message message: str | None = None, + # Entry: entity_type entity_type: str | None = None, + # Entry: entity_id entity_id: uuid.UUID | None = None, ) -> Notification: """Create a single notification for a user.""" + # Assign notif = Notification( notif = Notification( + # Keyword argument: user_id user_id=user_id, + # Keyword argument: type type=type, + # Keyword argument: title title=title, + # Keyword argument: message message=message, + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: entity_id entity_id=entity_id, ) + # Stage new record(s) for database insertion db.add(notif) + # Flush changes to DB without committing the transaction db.flush() + # Return notif return notif +# Define function mark_as_read def mark_as_read( + # Entry: db db: 
Session, notification_id: uuid.UUID, user_id: uuid.UUID ) -> Notification: """Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found.""" + # Assign notif = get_notification_or_raise(db, notification_id, user_id) notif = get_notification_or_raise(db, notification_id, user_id) + # Assign notif.read = True notif.read = True + # Return notif return notif +# Define function mark_all_as_read def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int: """Mark all unread notifications for a user as read. Returns count updated.""" + # Assign count = ( count = ( db.query(Notification) + # Chain .filter() call .filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712 + # Chain .update() call .update({"read": True}) ) + # Return count return count +# Define function get_unread_count def get_unread_count(db: Session, user_id: uuid.UUID) -> int: """Return the number of unread notifications for a user.""" + # Return ( return ( db.query(func.count(Notification.id)) + # Chain .filter() call .filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712 + # Chain .scalar() call .scalar() ) or 0 +# Define function cleanup_old_notifications def cleanup_old_notifications(db: Session, days: int = 90) -> int: """Delete read notifications older than *days*. 
Returns count deleted.""" + # Assign cutoff = datetime.utcnow() - timedelta(days=days) cutoff = datetime.utcnow() - timedelta(days=days) + # Assign count = ( count = ( db.query(Notification) + # Chain .filter() call .filter( Notification.read == True, # noqa: E712 Notification.created_at < cutoff, ) + # Chain .delete() call .delete() ) + # Return count return count @@ -204,71 +294,118 @@ def notify_test_state_change(db: Session, test, new_state: str) -> None: - rejected -> notify creator - validated -> notify creator """ + # Assign test_name = test.name test_name = test.name + # Assign test_id = test.id test_id = test.id + # Assign creator_id = test.created_by creator_id = test.created_by + # Check: new_state == "red_executing" and creator_id if new_state == "red_executing" and creator_id: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=creator_id, + # Keyword argument: type type="test_state_changed", + # Keyword argument: title title="Test execution started", + # Keyword argument: message message=f'Your test "{test_name}" has moved to execution phase.', + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test_id, ) + # Alternative: new_state == "blue_evaluating" elif new_state == "blue_evaluating": # Notify all blue_tech users blue_users = db.query(User).filter(User.role == "blue_tech", User.is_active == True).all() # noqa: E712 + # Iterate over blue_users for user in blue_users: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: type type="test_assigned", + # Keyword argument: title title="New test ready for blue evaluation", + # Keyword argument: message message=f'Test "{test_name}" needs blue team evaluation.', + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test_id, ) + # Alternative: new_state == "in_review" elif new_state == "in_review": # Notify 
red_lead and blue_lead users managers = ( db.query(User) + # Chain .filter() call .filter(User.role.in_(["red_lead", "blue_lead"]), User.is_active == True) # noqa: E712 + # Chain .all() call .all() ) + # Iterate over managers for user in managers: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: type type="validation_needed", + # Keyword argument: title title="Test ready for validation", + # Keyword argument: message message=f'Test "{test_name}" is awaiting your review.', + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test_id, ) + # Alternative: new_state == "rejected" and creator_id elif new_state == "rejected" and creator_id: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=creator_id, + # Keyword argument: type type="test_rejected", + # Keyword argument: title title="Test rejected", + # Keyword argument: message message=f'Your test "{test_name}" has been rejected. 
Please review and resubmit.', + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test_id, ) + # Alternative: new_state == "validated" and creator_id elif new_state == "validated" and creator_id: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=creator_id, + # Keyword argument: type type="test_validated", + # Keyword argument: title title="Test validated", + # Keyword argument: message message=f'Your test "{test_name}" has been validated successfully.', + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test_id, ) diff --git a/backend/app/services/operational_metrics_service.py b/backend/app/services/operational_metrics_service.py index 1a29d00..71035a6 100644 --- a/backend/app/services/operational_metrics_service.py +++ b/backend/app/services/operational_metrics_service.py @@ -3,26 +3,45 @@ Calculates security operations KPIs from test data and audit logs. 
""" +# Import datetime, timedelta from datetime from datetime import datetime, timedelta + +# Import Optional from typing from typing import Optional -from sqlalchemy import func, case, and_, or_, extract +# Import func from sqlalchemy +from sqlalchemy import func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session -from app.models.test import Test -from app.models.technique import Technique -from app.models.test_detection_result import TestDetectionResult +# Import AuditLog from app.models.audit from app.models.audit import AuditLog -from app.models.enums import TestState, TestResult + +# Import TestResult, TestState from app.models.enums +from app.models.enums import TestResult, TestState + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import Test from app.models.test +from app.models.test import Test + +# Import TestDetectionResult from app.models.test_detection_result +from app.models.test_detection_result import TestDetectionResult +# Define function _safe_stats def _safe_stats(values: list[float]) -> dict: """Compute mean, median, min, max from a list of floats (in hours). 
For sub-hour averages, mean_hours is stored as minutes to avoid rounding to 0.0 which is falsy in JavaScript.""" if not values: + # Return None return None + # Assign sorted_vals = sorted(values) sorted_vals = sorted(values) + # Assign n = len(sorted_vals) n = len(sorted_vals) mean = sum(sorted_vals) / n # Use minutes for sub-hour values to avoid JS falsy 0.0 @@ -31,8 +50,11 @@ def _safe_stats(values: list[float]) -> dict: "mean_hours": mean_display, "unit": "min" if mean < 1 else "hrs", "median_hours": round(sorted_vals[n // 2], 1), + # Literal argument value "min_hours": round(sorted_vals[0], 1), + # Literal argument value "max_hours": round(sorted_vals[-1], 1), + # Literal argument value "sample_size": n, } @@ -59,6 +81,7 @@ def calculate_mttd(db: Session) -> Optional[dict]: .all() ) + # Assign detection_times = [] detection_times = [] for t in tests: gross_secs = (t.blue_started_at - t.red_started_at).total_seconds() @@ -66,6 +89,7 @@ def calculate_mttd(db: Session) -> Optional[dict]: if net_secs > 0: detection_times.append(net_secs / 3600) + # Return _safe_stats(detection_times) return _safe_stats(detection_times) @@ -83,14 +107,17 @@ def calculate_mttr(db: Session) -> Optional[dict]: """ tests = ( db.query(Test) + # Chain .filter() call .filter( Test.state == TestState.validated, Test.red_started_at.isnot(None), Test.blue_validated_at.isnot(None), ) + # Chain .all() call .all() ) + # Assign response_times = [] response_times = [] for t in tests: gross_secs = (t.blue_validated_at - t.red_started_at).total_seconds() @@ -99,6 +126,7 @@ def calculate_mttr(db: Session) -> Optional[dict]: if net_secs > 0: response_times.append(net_secs / 3600) + # Return _safe_stats(response_times) return _safe_stats(response_times) @@ -106,34 +134,63 @@ def calculate_mttr(db: Session) -> Optional[dict]: def calculate_detection_efficacy(db: Session) -> dict: - """Calculate detection efficacy: detected / total validated tests.""" + """Calculate detection efficacy: detected / total 
validated tests. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Contains ``percentage``, ``detected``, ``partially``, + ``not_detected``, and ``total``. + """ + # Assign validated_tests = ( validated_tests = ( db.query(Test) + # Chain .filter() call .filter(Test.state == TestState.validated) + # Chain .all() call .all() ) + # Assign total = len(validated_tests) total = len(validated_tests) + # Check: total == 0 if total == 0: + # Return { return { + # Literal argument value "percentage": 0, + # Literal argument value "detected": 0, + # Literal argument value "partially": 0, + # Literal argument value "not_detected": 0, + # Literal argument value "total": 0, } + # Assign detected = len([t for t in validated_tests if t.detection_result == TestResult... detected = len([t for t in validated_tests if t.detection_result == TestResult.detected]) + # Assign partially = len([t for t in validated_tests if t.detection_result == TestResult... partially = len([t for t in validated_tests if t.detection_result == TestResult.partially_detected]) + # Assign not_detected = len([t for t in validated_tests if t.detection_result == TestResult... not_detected = len([t for t in validated_tests if t.detection_result == TestResult.not_detected]) + # Assign percentage = round((detected / total) * 100, 1) if total > 0 else 0 percentage = round((detected / total) * 100, 1) if total > 0 else 0 + # Return { return { + # Literal argument value "percentage": percentage, + # Literal argument value "detected": detected, + # Literal argument value "partially": partially, + # Literal argument value "not_detected": not_detected, + # Literal argument value "total": total, } @@ -142,25 +199,45 @@ def calculate_detection_efficacy(db: Session) -> dict: def calculate_alert_fidelity(db: Session) -> dict: - """Calculate alert fidelity: ratio of triggered detection rules.""" + """Calculate alert fidelity: ratio of triggered detection rules. 
+ + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Contains ``percentage``, ``triggered``, ``not_triggered``, + and ``total_evaluated``. + """ + # Assign total_evaluated = ( total_evaluated = ( db.query(func.count(TestDetectionResult.id)) + # Chain .filter() call .filter(TestDetectionResult.triggered.isnot(None)) + # Chain .scalar() call .scalar() ) or 0 + # Assign triggered = ( triggered = ( db.query(func.count(TestDetectionResult.id)) + # Chain .filter() call .filter(TestDetectionResult.triggered == True) + # Chain .scalar() call .scalar() ) or 0 + # Assign not_triggered = total_evaluated - triggered not_triggered = total_evaluated - triggered + # Return { return { + # Literal argument value "percentage": round((triggered / total_evaluated) * 100, 1) if total_evaluated > 0 else 0, + # Literal argument value "triggered": triggered, + # Literal argument value "not_triggered": not_triggered, + # Literal argument value "total_evaluated": total_evaluated, } @@ -169,46 +246,78 @@ def calculate_alert_fidelity(db: Session) -> dict: def calculate_coverage_velocity(db: Session) -> dict: - """Calculate techniques validated per week.""" + """Calculate techniques validated per week. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Contains ``techniques_per_week`` (float average over the last + 12 weeks) and ``trend`` (``"improving"``, ``"stable"``, or + ``"declining"``). 
+ """ # Count techniques that changed to validated/partial in the last 12 weeks twelve_weeks_ago = datetime.utcnow() - timedelta(weeks=12) + # Assign weekly_counts = ( weekly_counts = ( db.query( func.date_trunc("week", Technique.last_review_date).label("week"), func.count(Technique.id).label("count"), ) + # Chain .filter() call .filter( Technique.last_review_date >= twelve_weeks_ago, Technique.last_review_date.isnot(None), ) + # Chain .group_by() call .group_by(func.date_trunc("week", Technique.last_review_date)) + # Chain .order_by() call .order_by("week") + # Chain .all() call .all() ) + # Check: weekly_counts if weekly_counts: + # Assign counts = [row.count for row in weekly_counts] counts = [row.count for row in weekly_counts] + # Assign avg_per_week = round(sum(counts) / len(counts), 1) avg_per_week = round(sum(counts) / len(counts), 1) # Trend: compare last 4 weeks vs previous 4 weeks recent = counts[-4:] if len(counts) >= 4 else counts + # Assign earlier = counts[-8:-4] if len(counts) >= 8 else counts[:len(counts) // 2] if... 
earlier = counts[-8:-4] if len(counts) >= 8 else counts[:len(counts) // 2] if counts else [] + # Assign recent_avg = sum(recent) / len(recent) if recent else 0 recent_avg = sum(recent) / len(recent) if recent else 0 + # Assign earlier_avg = sum(earlier) / len(earlier) if earlier else 0 earlier_avg = sum(earlier) / len(earlier) if earlier else 0 + # Check: recent_avg > earlier_avg * 1.1 if recent_avg > earlier_avg * 1.1: + # Assign trend = "improving" trend = "improving" + # Alternative: recent_avg < earlier_avg * 0.9 elif recent_avg < earlier_avg * 0.9: + # Assign trend = "declining" trend = "declining" + # Fallback: handle remaining cases else: + # Assign trend = "stable" trend = "stable" + # Fallback: handle remaining cases else: + # Assign avg_per_week = 0 avg_per_week = 0 + # Assign trend = "stable" trend = "stable" + # Return { return { + # Literal argument value "techniques_per_week": avg_per_week, + # Literal argument value "trend": trend, } @@ -264,6 +373,7 @@ def calculate_validation_throughput(db: Session) -> dict: else: trend = "stable" + # Return { return { "tests_per_week": conversion_rate, # reuse key for API compat "conversion_rate": conversion_rate, @@ -278,51 +388,84 @@ def calculate_validation_throughput(db: Session) -> dict: def calculate_rejection_rate(db: Session) -> dict: - """Calculate rejection rate, broken down by red_lead and blue_lead.""" + """Calculate rejection rate, broken down by red_lead and blue_lead. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Contains ``percentage`` (overall rejection rate), ``by_red_lead`` + (red-lead rejection percentage), and ``by_blue_lead`` + (blue-lead rejection percentage). 
+ """ + # Assign validated_count = ( validated_count = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state == TestState.validated) + # Chain .scalar() call .scalar() ) or 0 + # Assign rejected_count = ( rejected_count = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state == TestState.rejected) + # Chain .scalar() call .scalar() ) or 0 + # Assign total = validated_count + rejected_count total = validated_count + rejected_count + # Assign overall_pct = round((rejected_count / total) * 100, 1) if total > 0 else 0 overall_pct = round((rejected_count / total) * 100, 1) if total > 0 else 0 # By red_lead (red_validation_status == "rejected") red_rejected = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.red_validation_status == "rejected") + # Chain .scalar() call .scalar() ) or 0 + # Assign red_total = ( red_total = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.red_validation_status.in_(["approved", "rejected"])) + # Chain .scalar() call .scalar() ) or 0 + # Assign red_pct = round((red_rejected / red_total) * 100, 1) if red_total > 0 else 0 red_pct = round((red_rejected / red_total) * 100, 1) if red_total > 0 else 0 # By blue_lead blue_rejected = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.blue_validation_status == "rejected") + # Chain .scalar() call .scalar() ) or 0 + # Assign blue_total = ( blue_total = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.blue_validation_status.in_(["approved", "rejected"])) + # Chain .scalar() call .scalar() ) or 0 + # Assign blue_pct = round((blue_rejected / blue_total) * 100, 1) if blue_total > 0 else 0 blue_pct = round((blue_rejected / blue_total) * 100, 1) if blue_total > 0 else 0 + # Return { return { + # Literal argument value "percentage": overall_pct, + # Literal argument value "by_red_lead": red_pct, + # Literal argument value "by_blue_lead": blue_pct, } @@ -331,14 +474,31 @@ def 
calculate_rejection_rate(db: Session) -> dict: def get_all_operational_metrics(db: Session) -> dict: - """Get all operational metrics in a single response.""" + """Return all operational metrics combined in a single response. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Contains ``mttd``, ``mttr``, ``detection_efficacy``, + ``alert_fidelity``, ``coverage_velocity``, + ``validation_throughput``, and ``rejection_rate`` keys. + """ + # Return { return { + # Literal argument value "mttd": calculate_mttd(db), + # Literal argument value "mttr": calculate_mttr(db), + # Literal argument value "detection_efficacy": calculate_detection_efficacy(db), + # Literal argument value "alert_fidelity": calculate_alert_fidelity(db), + # Literal argument value "coverage_velocity": calculate_coverage_velocity(db), + # Literal argument value "validation_throughput": calculate_validation_throughput(db), + # Literal argument value "rejection_rate": calculate_rejection_rate(db), } @@ -347,44 +507,77 @@ def get_all_operational_metrics(db: Session) -> dict: def get_operational_trend(db: Session, period: str = "90d") -> list: - """Get weekly trend data for operational metrics.""" + """Return weekly trend data for operational metrics. + + Args: + db (Session): Active SQLAlchemy database session. + period (str): Lookback period; one of ``"30d"``, ``"90d"`` + (default), or ``"1y"``. + + Returns: + list: Weekly data points, each a dict with ``date``, + ``detection_efficacy``, ``validated_tests``, and + ``detected_tests``. 
+ """ + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Check: period == "30d" if period == "30d": + # Assign start = now - timedelta(days=30) start = now - timedelta(days=30) + # Alternative: period == "1y" elif period == "1y": + # Assign start = now - timedelta(days=365) start = now - timedelta(days=365) + # Fallback: handle remaining cases else: + # Assign start = now - timedelta(days=90) start = now - timedelta(days=90) # Build weekly data points data_points = [] + # Assign current = start current = start + # Loop while current < now while current < now: + # Assign week_end = min(current + timedelta(days=7), now) week_end = min(current + timedelta(days=7), now) # Detection efficacy for tests validated up to this week validated_up_to = ( db.query(Test) + # Chain .filter() call .filter( Test.state == TestState.validated, Test.red_validated_at <= week_end, ) + # Chain .all() call .all() ) + # Assign total = len(validated_up_to) total = len(validated_up_to) + # Assign detected = len([t for t in validated_up_to if t.detection_result == TestResult... detected = len([t for t in validated_up_to if t.detection_result == TestResult.detected]) + # Assign efficacy = round((detected / total) * 100, 1) if total > 0 else 0 efficacy = round((detected / total) * 100, 1) if total > 0 else 0 + # Call data_points.append() data_points.append({ + # Literal argument value "date": current.strftime("%Y-%m-%d"), + # Literal argument value "detection_efficacy": efficacy, + # Literal argument value "validated_tests": total, + # Literal argument value "detected_tests": detected, }) + # Assign current = week_end current = week_end + # Return data_points return data_points @@ -392,20 +585,33 @@ def get_operational_trend(db: Session, period: str = "90d") -> list: def get_metrics_by_team(db: Session) -> dict: - """Get metrics broken down by Red vs Blue team.""" + """Return metrics broken down by Red vs Blue team. + + Args: + db (Session): Active SQLAlchemy database session. 
+ + Returns: + dict: Contains ``red_team`` and ``blue_team`` sub-dicts, each with + ``tests_completed``, ``avg_completion_hours``, and + ``rejection_rate``. + """ # Red team metrics red_tests_completed = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state.in_([ TestState.blue_evaluating, TestState.in_review, TestState.validated, TestState.rejected, ])) + # Chain .scalar() call .scalar() ) or 0 + # Assign red_avg_time = None red_avg_time = None + # Assign red_times = [] red_times = [] # Red team avg execution time: red_started_at → blue_started_at (net of paused) tests_with_red = ( @@ -416,6 +622,7 @@ def get_metrics_by_team(db: Session) -> dict: ) .all() ) + # Iterate over tests_with_red for t in tests_with_red: gross = (t.blue_started_at - t.red_started_at).total_seconds() net = gross - (t.red_paused_seconds or 0) @@ -429,11 +636,13 @@ def get_metrics_by_team(db: Session) -> dict: # Blue team: count tests that reached the blue evaluation phase blue_tests_completed = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state.in_([ TestState.in_review, TestState.validated, TestState.rejected, ])) + # Chain .scalar() call .scalar() ) or 0 @@ -441,15 +650,20 @@ def get_metrics_by_team(db: Session) -> dict: # Prefer blue_work_started_at (actual pick-up) → blue_validated_at. # Fall back to blue_started_at if blue_work_started_at is not set. 
blue_avg_time = None + # Assign blue_times = [] blue_times = [] + # Assign tests_with_blue = ( tests_with_blue = ( db.query(Test) + # Chain .filter() call .filter( Test.blue_started_at.isnot(None), Test.blue_validated_at.isnot(None), ) + # Chain .all() call .all() ) + # Iterate over tests_with_blue for t in tests_with_blue: phase_start = t.blue_work_started_at or t.blue_started_at gross = (t.blue_validated_at - phase_start).total_seconds() @@ -463,15 +677,22 @@ def get_metrics_by_team(db: Session) -> dict: red_avg_raw = sum(red_times) / len(red_times) if red_times else None blue_avg_raw = sum(blue_times) / len(blue_times) if blue_times else None + # Return { return { + # Literal argument value "red_team": { + # Literal argument value "tests_completed": red_tests_completed, + # Literal argument value "avg_completion_hours": red_avg_time, "avg_unit": "min" if (red_avg_raw is not None and red_avg_raw < 1) else "hrs", "rejection_rate": calculate_rejection_rate(db)["by_red_lead"], }, + # Literal argument value "blue_team": { + # Literal argument value "tests_completed": blue_tests_completed, + # Literal argument value "avg_completion_hours": blue_avg_time, "avg_unit": "min" if (blue_avg_raw is not None and blue_avg_raw < 1) else "hrs", "rejection_rate": calculate_rejection_rate(db)["by_blue_lead"], diff --git a/backend/app/services/osint_enrichment_service.py b/backend/app/services/osint_enrichment_service.py index fcf2b4b..48d1693 100644 --- a/backend/app/services/osint_enrichment_service.py +++ b/backend/app/services/osint_enrichment_service.py @@ -1,237 +1,433 @@ -"""OSINT enrichment service — automatically discovers CVEs, advisories, and -related intelligence for MITRE ATT&CK techniques using the NVD API. +"""OSINT enrichment service — discovers CVEs, advisories, and threat intel for ATT&CK techniques via the NVD API. Designed to run as a weekly background job. Respects NVD rate limits (5 requests per 30 seconds without an API key, 50/30s with a key). 
""" +# Import logging import logging + +# Import time import time + +# Import Optional from typing from typing import Optional + +# Import UUID from uuid from uuid import UUID +# Import requests import requests + +# Import func from sqlalchemy from sqlalchemy import func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import settings from app.config from app.config import settings + +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError + +# Import OsintItem from app.models.osint_item from app.models.osint_item import OsintItem + +# Import Technique from app.models.technique from app.models.technique import Technique +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0" NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0" +# Assign NVD_RATE_LIMIT_BATCH = 5 NVD_RATE_LIMIT_BATCH = 5 +# Assign NVD_RATE_LIMIT_WAIT = 31 # seconds to wait after each batch NVD_RATE_LIMIT_WAIT = 31 # seconds to wait after each batch +# Define function enrich_technique_with_cves def enrich_technique_with_cves(db: Session, technique: Technique) -> int: """Search for CVEs related to a technique via the NVD API. Uses the technique name as a keyword search. Deduplicates against existing OsintItems so re-runs are safe. - Returns the number of new CVEs added. + Args: + db (Session): Active SQLAlchemy database session. + technique (Technique): The ATT&CK technique to enrich. + + Returns: + int: Number of new CVE items added to the database. 
""" + # Attempt the following; catch errors below try: + # Assign headers = {} headers = {} + # Check: getattr(settings, "NVD_API_KEY", "") if getattr(settings, "NVD_API_KEY", ""): + # Assign headers["apiKey"] = settings.NVD_API_KEY headers["apiKey"] = settings.NVD_API_KEY + # Assign params = { params = { + # Literal argument value "keywordSearch": technique.name, + # Literal argument value "resultsPerPage": 10, } + # Assign resp = requests.get( resp = requests.get( NVD_API_BASE, + # Keyword argument: params params=params, + # Keyword argument: headers headers=headers, + # Keyword argument: timeout timeout=30, ) + # Check: resp.status_code != 200 if resp.status_code != 200: + # Log warning: logger.warning( + # Literal argument value "NVD API error for %s: HTTP %d", technique.mitre_id, resp.status_code, ) + # Return 0 return 0 + # Assign data = resp.json() data = resp.json() + # Assign count = 0 count = 0 + # Iterate over data.get("vulnerabilities", []) for vuln in data.get("vulnerabilities", []): + # Assign cve = vuln.get("cve", {}) cve = vuln.get("cve", {}) + # Assign cve_id = cve.get("id") cve_id = cve.get("id") + # Check: not cve_id if not cve_id: + # Skip to the next loop iteration continue # Deduplicate exists = ( db.query(OsintItem.id) + # Chain .filter() call .filter( OsintItem.technique_id == technique.id, OsintItem.source_url.contains(cve_id), ) + # Chain .first() call .first() ) + # Check: exists if exists: + # Skip to the next loop iteration continue + # Assign descriptions = cve.get("descriptions", []) descriptions = cve.get("descriptions", []) + # Assign desc = next( desc = next( (d["value"] for d in descriptions if d["lang"] == "en"), "" ) # Extract CVSS severity metrics = cve.get("metrics", {}) + # Assign cvss_v31 = metrics.get("cvssMetricV31", []) cvss_v31 = metrics.get("cvssMetricV31", []) + # Assign cvss_v30 = metrics.get("cvssMetricV30", []) cvss_v30 = metrics.get("cvssMetricV30", []) + # Assign cvss_entry = (cvss_v31[0] if cvss_v31 else 
cvss_v30[0]) if (cvss_v31 or cvss_v30... cvss_entry = (cvss_v31[0] if cvss_v31 else cvss_v30[0]) if (cvss_v31 or cvss_v30) else {} + # Assign cvss_data = cvss_entry.get("cvssData", {}) if cvss_entry else {} cvss_data = cvss_entry.get("cvssData", {}) if cvss_entry else {} + # Assign severity = cvss_data.get("baseSeverity", "UNKNOWN") severity = cvss_data.get("baseSeverity", "UNKNOWN") + # Assign score = cvss_data.get("baseScore") score = cvss_data.get("baseScore") + # Assign item = OsintItem( item = OsintItem( + # Keyword argument: technique_id technique_id=technique.id, + # Keyword argument: source_type source_type="cve", + # Keyword argument: source_url source_url=f"https://nvd.nist.gov/vuln/detail/{cve_id}", + # Keyword argument: title title=cve_id, + # Keyword argument: description description=desc[:500] if desc else None, + # Keyword argument: severity severity=severity, + # Keyword argument: metadata_ metadata_={"cvss_score": score, "cve_id": cve_id}, ) + # Stage new record(s) for database insertion db.add(item) + # Assign count = 1 count += 1 + # Check: count > 0 if count > 0: + # Assign technique.review_required = True technique.review_required = True + # Commit all pending changes to the database db.commit() + # Log info: "Added %d CVEs for %s", count, technique.mitre_id logger.info("Added %d CVEs for %s", count, technique.mitre_id) + # Return count return count + # Handle requests.RequestException except requests.RequestException as e: + # Log error: logger.error( + # Literal argument value "HTTP error during OSINT enrichment for %s: %s", technique.mitre_id, e, ) + # Return 0 return 0 + # Handle Exception except Exception as e: + # Log error: logger.error( + # Literal argument value "OSINT enrichment failed for %s: %s", technique.mitre_id, e, + # Keyword argument: exc_info exc_info=True, ) + # Return 0 return 0 +# Define function enrich_all_techniques def enrich_all_techniques(db: Session) -> int: """Enrich all techniques with CVE data from NVD. 
Rate-limited: processes *NVD_RATE_LIMIT_BATCH* techniques, then sleeps for *NVD_RATE_LIMIT_WAIT* seconds to stay under NVD limits. - Returns total number of new OSINT items added. + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + int: Total number of new OSINT items added across all techniques. """ + # Assign techniques = db.query(Technique).order_by(Technique.mitre_id).all() techniques = db.query(Technique).order_by(Technique.mitre_id).all() + # Assign total = 0 total = 0 + # Log info: logger.info( + # Literal argument value "Starting OSINT enrichment for %d techniques...", len(techniques), ) + # Iterate over enumerate(techniques) for i, tech in enumerate(techniques): + # Assign total = enrich_technique_with_cves(db, tech) total += enrich_technique_with_cves(db, tech) # Rate limiting: wait after every batch if (i + 1) % NVD_RATE_LIMIT_BATCH == 0 and (i + 1) < len(techniques): + # Log debug: logger.debug( + # Literal argument value "Rate limit pause after %d techniques (%ds)...", i + 1, NVD_RATE_LIMIT_WAIT, ) + # Call time.sleep() time.sleep(NVD_RATE_LIMIT_WAIT) + # Log info: logger.info( + # Literal argument value "OSINT enrichment complete — %d new items across %d techniques", total, len(techniques), ) + # Return total return total +# Define function get_osint_items_for_technique def get_osint_items_for_technique( + # Entry: db db: Session, + # Entry: technique_id technique_id: str, + # Entry: source_type source_type: str | None = None, + # Entry: reviewed reviewed: bool | None = None, ) -> list[OsintItem]: - """Retrieve OSINT items for a technique with optional filters.""" + """Retrieve OSINT items for a technique with optional filters. + + Args: + db (Session): Active SQLAlchemy database session. + technique_id (str): UUID string of the technique to query. + source_type (str | None): Optional filter by source type (e.g. + ``"cve"``). 
+ reviewed (bool | None): Optional filter; ``True`` for reviewed items + only, ``False`` for unreviewed, ``None`` for all. + + Returns: + list[OsintItem]: Matching OSINT items ordered by discovery date + descending. + """ + # Assign query = db.query(OsintItem).filter(OsintItem.technique_id == technique_id) query = db.query(OsintItem).filter(OsintItem.technique_id == technique_id) + # Check: source_type if source_type: + # Assign query = query.filter(OsintItem.source_type == source_type) query = query.filter(OsintItem.source_type == source_type) + # Check: reviewed is not None if reviewed is not None: + # Assign query = query.filter(OsintItem.reviewed == reviewed) query = query.filter(OsintItem.reviewed == reviewed) + # Return query.order_by(OsintItem.discovered_at.desc()).all() return query.order_by(OsintItem.discovered_at.desc()).all() +# Define function mark_osint_reviewed def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None: - """Mark an OSINT item as reviewed. Does not commit; caller uses UnitOfWork.""" + """Mark an OSINT item as reviewed. Does not commit; caller uses UnitOfWork. + + Args: + db (Session): Active SQLAlchemy database session. + item_id (str): UUID string of the OSINT item to mark. + + Returns: + OsintItem | None: The updated item, or ``None`` if not found. + """ + # Assign item = db.query(OsintItem).filter(OsintItem.id == item_id).first() item = db.query(OsintItem).filter(OsintItem.id == item_id).first() + # Check: item if item: + # Assign item.reviewed = True item.reviewed = True + # Return item return item +# Define function get_unreviewed_count def get_unreviewed_count(db: Session) -> int: - """Return the total number of unreviewed OSINT items.""" + """Return the total number of unreviewed OSINT items. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + int: Count of OSINT items where ``reviewed`` is ``False``. + """ + # Return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # ... 
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712 +# Define function list_osint_items def list_osint_items( + # Entry: db db: Session, *, + # Entry: technique_id technique_id: Optional[UUID] = None, + # Entry: source_type source_type: Optional[str] = None, + # Entry: reviewed reviewed: Optional[bool] = None, + # Entry: offset offset: int = 0, + # Entry: limit limit: int = 50, ) -> dict: - """List OSINT items with optional filters and pagination.""" + """List OSINT items with optional filters and pagination. + + Args: + db (Session): Active SQLAlchemy database session. + technique_id (Optional[UUID]): Filter by technique UUID. + source_type (Optional[str]): Filter by source type string (e.g. + ``"cve"``). + reviewed (Optional[bool]): Filter by reviewed status; ``None`` + returns all. + offset (int): Number of records to skip for pagination. + limit (int): Maximum number of records to return. + + Returns: + dict: Contains ``total`` count and ``items`` list of serialized OSINT + item dicts. 
+ """ + # Assign query = db.query(OsintItem) query = db.query(OsintItem) + # Check: technique_id if technique_id: + # Assign query = query.filter(OsintItem.technique_id == technique_id) query = query.filter(OsintItem.technique_id == technique_id) + # Check: source_type if source_type: + # Assign query = query.filter(OsintItem.source_type == source_type) query = query.filter(OsintItem.source_type == source_type) + # Check: reviewed is not None if reviewed is not None: + # Assign query = query.filter(OsintItem.reviewed == reviewed) query = query.filter(OsintItem.reviewed == reviewed) + # Assign total = query.count() total = query.count() + # Assign items = ( items = ( query.order_by(OsintItem.discovered_at.desc()) + # Chain .offset() call .offset(offset) + # Chain .limit() call .limit(limit) + # Chain .all() call .all() ) + # Return { return { + # Literal argument value "total": total, + # Literal argument value "items": [ { + # Literal argument value "id": str(item.id), + # Literal argument value "technique_id": str(item.technique_id), + # Literal argument value "source_type": item.source_type, + # Literal argument value "source_url": item.source_url, + # Literal argument value "title": item.title, + # Literal argument value "description": item.description, + # Literal argument value "severity": item.severity, + # Literal argument value "discovered_at": item.discovered_at.isoformat() if item.discovered_at else None, + # Literal argument value "reviewed": item.reviewed, + # Literal argument value "metadata": item.metadata_, } for item in items @@ -239,39 +435,76 @@ def list_osint_items( } +# Define function get_osint_summary def get_osint_summary(db: Session) -> dict: - """Summary statistics for OSINT items.""" + """Return summary statistics for OSINT items. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Contains ``total_items``, ``unreviewed``, + ``techniques_with_items``, ``by_severity``, and ``by_type``. 
+ """ + # Assign total = db.query(func.count(OsintItem.id)).scalar() or 0 total = db.query(func.count(OsintItem.id)).scalar() or 0 + # Assign unreviewed = get_unreviewed_count(db) unreviewed = get_unreviewed_count(db) + # Assign by_severity = dict( by_severity = dict( db.query(OsintItem.severity, func.count(OsintItem.id)) + # Chain .group_by() call .group_by(OsintItem.severity) + # Chain .all() call .all() ) + # Assign by_type = dict( by_type = dict( db.query(OsintItem.source_type, func.count(OsintItem.id)) + # Chain .group_by() call .group_by(OsintItem.source_type) + # Chain .all() call .all() ) + # Assign techniques_with_items = ( techniques_with_items = ( db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0 ) + # Return { return { + # Literal argument value "total_items": total, + # Literal argument value "unreviewed": unreviewed, + # Literal argument value "techniques_with_items": techniques_with_items, + # Literal argument value "by_severity": by_severity, + # Literal argument value "by_type": by_type, } +# Define function get_technique_or_raise def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique: - """Get a technique by ID or raise EntityNotFoundError.""" + """Return a technique by ID or raise EntityNotFoundError. + + Args: + db (Session): Active SQLAlchemy database session. + technique_id (UUID): UUID of the technique to retrieve. + + Returns: + Technique: The matching technique ORM object. 
+ """ + # Assign technique = db.query(Technique).filter(Technique.id == technique_id).first() technique = db.query(Technique).filter(Technique.id == technique_id).first() + # Check: not technique if not technique: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", str(technique_id)) + # Return technique return technique diff --git a/backend/app/services/report_engine.py b/backend/app/services/report_engine.py index 4e31e30..66933c4 100644 --- a/backend/app/services/report_engine.py +++ b/backend/app/services/report_engine.py @@ -3,95 +3,151 @@ Uses WeasyPrint for PDF generation and docxtpl for DOCX. """ -import os -import uuid +# Import logging import logging + +# Import os +import os + +# Import uuid +import uuid + +# Import datetime from datetime from datetime import datetime +# Import Environment, FileSystemLoader from jinja2 from jinja2 import Environment, FileSystemLoader +# Import settings from app.config from app.config import settings +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Define class ReportEngine class ReportEngine: """Template-based report generator supporting PDF, DOCX, and HTML output.""" + # Define function __init__ def __init__(self) -> None: + """Initialise the Jinja2 environment and ensure the output directory exists.""" + # Assign self.jinja_env = Environment( self.jinja_env = Environment( + # Keyword argument: loader loader=FileSystemLoader(settings.REPORT_TEMPLATES_DIR), + # Keyword argument: autoescape autoescape=True, ) + # Call os.makedirs() os.makedirs(settings.REPORT_OUTPUT_DIR, exist_ok=True) + # Define function render_html def render_html(self, template_name: str, context: dict) -> str: """Render a Jinja2 template to an HTML string.""" + # Assign template = self.jinja_env.get_template(f"{template_name}.html") template = self.jinja_env.get_template(f"{template_name}.html") + # Call context.setdefault() context.setdefault("company_name", settings.COMPANY_NAME) + # Call 
context.setdefault() context.setdefault("generated_at", datetime.utcnow().strftime("%B %d, %Y %H:%M UTC")) + # Return template.render(context) return template.render(context) + # Define function generate_pdf def generate_pdf(self, template_name: str, context: dict) -> str: """Render HTML and convert to PDF with WeasyPrint.""" - from weasyprint import HTML, CSS + # Import CSS, HTML from weasyprint + from weasyprint import CSS, HTML + # Assign html_content = self.render_html(template_name, context) html_content = self.render_html(template_name, context) + # Assign css_path = os.path.join(settings.REPORT_TEMPLATES_DIR, "styles", "report.css") css_path = os.path.join(settings.REPORT_TEMPLATES_DIR, "styles", "report.css") + # Assign output_path = os.path.join( output_path = os.path.join( settings.REPORT_OUTPUT_DIR, f"{template_name}_{uuid.uuid4().hex[:8]}.pdf", ) + # Assign stylesheets = [] stylesheets = [] + # Check: os.path.exists(css_path) if os.path.exists(css_path): + # Call stylesheets.append() stylesheets.append(CSS(filename=css_path)) + # Call HTML() HTML( + # Keyword argument: string string=html_content, + # Keyword argument: base_url base_url=settings.REPORT_TEMPLATES_DIR, ).write_pdf(output_path, stylesheets=stylesheets) + # Log info: "PDF generated: %s", output_path logger.info("PDF generated: %s", output_path) + # Return output_path return output_path + # Define function generate_docx def generate_docx(self, template_name: str, context: dict) -> str: """Render a .docx template with docxtpl.""" + # Import DocxTemplate from docxtpl from docxtpl import DocxTemplate + # Assign template_path = os.path.join( template_path = os.path.join( settings.REPORT_TEMPLATES_DIR, f"{template_name}.docx" ) + # Assign output_path = os.path.join( output_path = os.path.join( settings.REPORT_OUTPUT_DIR, f"{template_name}_{uuid.uuid4().hex[:8]}.docx", ) + # Assign doc = DocxTemplate(template_path) doc = DocxTemplate(template_path) + # Call context.setdefault() 
context.setdefault("company_name", settings.COMPANY_NAME) + # Call context.setdefault() context.setdefault("generated_at", datetime.utcnow().strftime("%B %d, %Y")) + # Call doc.render() doc.render(context) + # Call doc.save() doc.save(output_path) + # Log info: "DOCX generated: %s", output_path logger.info("DOCX generated: %s", output_path) + # Return output_path return output_path + # Define function generate_html def generate_html(self, template_name: str, context: dict) -> str: """Render and save a standalone HTML report (alias for spec compatibility).""" + # Return self.generate_html_file(template_name, context) return self.generate_html_file(template_name, context) + # Define function generate_html_file def generate_html_file(self, template_name: str, context: dict) -> str: """Render and save a standalone HTML report.""" + # Assign html_content = self.render_html(template_name, context) html_content = self.render_html(template_name, context) + # Assign output_path = os.path.join( output_path = os.path.join( settings.REPORT_OUTPUT_DIR, f"{template_name}_{uuid.uuid4().hex[:8]}.html", ) + # Open context manager with open(output_path, "w", encoding="utf-8") as f: + # Call f.write() f.write(html_content) + # Log info: "HTML report generated: %s", output_path logger.info("HTML report generated: %s", output_path) + # Return output_path return output_path +# Assign report_engine = ReportEngine() report_engine = ReportEngine() diff --git a/backend/app/services/report_generation_service.py b/backend/app/services/report_generation_service.py index 1c6ec64..d1da7ba 100644 --- a/backend/app/services/report_generation_service.py +++ b/backend/app/services/report_generation_service.py @@ -1,115 +1,195 @@ """High-level report generation — collects domain data and delegates to ReportEngine.""" +# Import logging import logging + +# Import datetime, timedelta from datetime from datetime import datetime, timedelta + +# Import UUID from uuid from uuid import UUID +# Import Session 
from sqlalchemy.orm from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.exceptions from app.domain.exceptions import EntityNotFoundError + +# Import Campaign, CampaignTest from app.models.campaign from app.models.campaign import Campaign, CampaignTest + +# Import CoverageSnapshot from app.models.coverage_snapshot from app.models.coverage_snapshot import CoverageSnapshot + +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import Test from app.models.test from app.models.test import Test + +# Import ThreatActor from app.models.threat_actor from app.models.threat_actor import ThreatActor + +# Import report_engine from app.services.report_engine from app.services.report_engine import report_engine +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Define function generate_purple_campaign_report def generate_purple_campaign_report( + # Entry: db db: Session, + # Entry: campaign_id campaign_id: str, + # Entry: output_format output_format: str = "pdf", ) -> str: """Generate the full Purple Team campaign report.""" + # Assign cid = campaign_id if isinstance(campaign_id, UUID) else UUID(str(campaign... 
cid = campaign_id if isinstance(campaign_id, UUID) else UUID(str(campaign_id)) + # Assign campaign = db.query(Campaign).filter(Campaign.id == cid).first() campaign = db.query(Campaign).filter(Campaign.id == cid).first() + # Check: not campaign if not campaign: + # Raise EntityNotFoundError raise EntityNotFoundError("Campaign", campaign_id) + # Assign campaign_tests = ( campaign_tests = ( db.query(Test) + # Chain .join() call .join(CampaignTest, CampaignTest.test_id == Test.id) + # Chain .filter() call .filter(CampaignTest.campaign_id == cid) + # Chain .all() call .all() ) + # Assign tests_data = [] tests_data = [] + # Iterate over campaign_tests for test in campaign_tests: + # Assign technique = db.query(Technique).filter(Technique.id == test.technique_id).first() technique = db.query(Technique).filter(Technique.id == test.technique_id).first() + # Call tests_data.append() tests_data.append({ + # Literal argument value "technique_mitre_id": technique.mitre_id if technique else "N/A", + # Literal argument value "name": test.name, + # Literal argument value "tactic": technique.tactic if technique else "N/A", + # Literal argument value "state": test.state.value if test.state else "draft", + # Literal argument value "detection_result": ( test.detection_result.value if test.detection_result else "pending" ), }) + # Assign validated = [t for t in campaign_tests if t.state and t.state.value == "validat... 
validated = [t for t in campaign_tests if t.state and t.state.value == "validated"] + # Assign detected = [ detected = [ t for t in validated if t.detection_result and t.detection_result.value == "detected" ] + # Assign not_detected = [ not_detected = [ t for t in validated if t.detection_result and t.detection_result.value == "not_detected" ] + # Assign critical_findings = [ critical_findings = [ { + # Literal argument value "technique_id": t["technique_mitre_id"], + # Literal argument value "name": t["name"], + # Literal argument value "severity": "critical", + # Literal argument value "description": "Technique was not detected during campaign execution.", + # Literal argument value "recommendation": "Implement detection rule or review existing SIEM/EDR configuration.", } for t in tests_data if t["detection_result"] == "not_detected" ] + # Assign org_score = _safe_org_score(db) org_score = _safe_org_score(db) + # Assign threat_actors = [] threat_actors = [] + # Check: campaign.threat_actor_id if campaign.threat_actor_id: + # Assign actor = db.query(ThreatActor).filter(ThreatActor.id == campaign.threat_acto... 
actor = db.query(ThreatActor).filter(ThreatActor.id == campaign.threat_actor_id).first() + # Check: actor if actor: + # Assign threat_actors = [{"name": actor.name}] threat_actors = [{"name": actor.name}] + # Assign context = { context = { + # Literal argument value "campaign": campaign, + # Literal argument value "tests": tests_data, + # Literal argument value "tests_validated": len(validated), + # Literal argument value "tests_detected": len(detected), + # Literal argument value "tests_not_detected": len(not_detected), + # Literal argument value "critical_findings": critical_findings, + # Literal argument value "org_score": org_score.get("overall", 0), + # Literal argument value "tactics": list({t["tactic"] for t in tests_data}), + # Literal argument value "threat_actors": threat_actors, } + # Return _generate(output_format, "purple_campaign", context) return _generate(output_format, "purple_campaign", context) +# Define function generate_coverage_report def generate_coverage_report( + # Entry: db db: Session, + # Entry: output_format output_format: str = "pdf", ) -> str: """Generate an organization-wide MITRE ATT&CK coverage report.""" - from sqlalchemy import func, case + # Import case, func from sqlalchemy + from sqlalchemy import case, func + # Assign org_score = _safe_org_score(db) org_score = _safe_org_score(db) + # Assign techniques = db.query(Technique).all() techniques = db.query(Technique).all() + # Assign status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, ... 
status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, "not_evaluated": 0} + # Iterate over techniques for t in techniques: + # Assign s = t.status_global.value if t.status_global else "not_evaluated" s = t.status_global.value if t.status_global else "not_evaluated" + # Check: s in status_counts if s in status_counts: + # Assign status_counts[s] = 1 status_counts[s] += 1 + # Assign summary = { summary = { + # Literal argument value "total_techniques": len(techniques), **status_counts, } @@ -121,14 +201,21 @@ def generate_coverage_report( func.count(Technique.id).label("total"), func.sum(case((Technique.status_global == "validated", 1), else_=0)).label("validated"), ) + # Chain .group_by() call .group_by(Technique.tactic) + # Chain .all() call .all() ) + # Assign tactics_coverage = [ tactics_coverage = [ { + # Literal argument value "tactic": r[0] or "Unknown", + # Literal argument value "total": r[1], + # Literal argument value "validated": int(r[2]), + # Literal argument value "coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0, } for r in tactic_rows @@ -136,213 +223,326 @@ def generate_coverage_report( # Never-tested techniques tested_ids = {t.technique_id for t in db.query(Test.technique_id).distinct().all()} + # Assign never_tested = [ never_tested = [ {"mitre_id": t.mitre_id, "name": t.name, "tactic": t.tactic} for t in techniques if t.id not in tested_ids ] + # Assign context = { context = { + # Literal argument value "org_score": org_score, + # Literal argument value "summary": summary, + # Literal argument value "tactics_coverage": tactics_coverage, + # Literal argument value "never_tested": never_tested[:50], } + # Return _generate(output_format, "coverage_report", context) return _generate(output_format, "coverage_report", context) +# Define function generate_executive_summary def generate_executive_summary( + # Entry: db db: Session, + # Entry: output_format output_format: str = "pdf", ) -> str: """Generate 
an executive summary report.""" + # Import func from sqlalchemy from sqlalchemy import func + # Assign org_score = _safe_org_score(db) org_score = _safe_org_score(db) + # Assign techniques = db.query(Technique).all() techniques = db.query(Technique).all() + # Assign status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, ... status_counts = {"validated": 0, "partial": 0, "not_covered": 0, "in_progress": 0, "not_evaluated": 0} + # Iterate over techniques for t in techniques: + # Assign s = t.status_global.value if t.status_global else "not_evaluated" s = t.status_global.value if t.status_global else "not_evaluated" + # Check: s in status_counts if s in status_counts: + # Assign status_counts[s] = 1 status_counts[s] += 1 + # Assign summary = {"total_techniques": len(techniques), **status_counts} summary = {"total_techniques": len(techniques), **status_counts} + # Assign total_tests = db.query(func.count(Test.id)).scalar() or 0 total_tests = db.query(func.count(Test.id)).scalar() or 0 + # Assign active_campaigns = ( active_campaigns = ( db.query(func.count(Campaign.id)).filter(Campaign.status == "active").scalar() or 0 ) + # Assign quarter_ago = datetime.utcnow() - timedelta(days=90) quarter_ago = datetime.utcnow() - timedelta(days=90) + # Assign tests_this_quarter = ( tests_this_quarter = ( db.query(func.count(Test.id)).filter(Test.created_at >= quarter_ago).scalar() or 0 ) + # Assign open_remediations = ( open_remediations = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.remediation_status.in_(["pending", "in_progress"])) + # Chain .scalar() call .scalar() or 0 ) # Detection rate among validated tests validated_count = status_counts["validated"] + # Assign detected_count = ( detected_count = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state == "validated", Test.detection_result == "detected") + # Chain .scalar() call .scalar() or 0 ) + # Assign detection_rate = round((detected_count / 
validated_count) * 100, 1) if validated_cou... detection_rate = round((detected_count / validated_count) * 100, 1) if validated_count > 0 else 0 # Top gaps — lowest coverage tactics from sqlalchemy import case as sql_case + # Assign tactic_rows = ( tactic_rows = ( db.query( Technique.tactic, func.count(Technique.id).label("total"), func.sum(sql_case((Technique.status_global == "validated", 1), else_=0)).label("validated"), ) + # Chain .group_by() call .group_by(Technique.tactic) + # Chain .all() call .all() ) + # Assign tactic_coverage = [ tactic_coverage = [ { + # Literal argument value "tactic": r[0] or "Unknown", + # Literal argument value "coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0, } for r in tactic_rows ] + # Assign top_gaps = sorted(tactic_coverage, key=lambda x: x["coverage_pct"])[:5] top_gaps = sorted(tactic_coverage, key=lambda x: x["coverage_pct"])[:5] + # Assign context = { context = { + # Literal argument value "org_score": org_score, + # Literal argument value "summary": summary, + # Literal argument value "total_tests": total_tests, + # Literal argument value "active_campaigns": active_campaigns, + # Literal argument value "tests_this_quarter": tests_this_quarter, + # Literal argument value "open_remediations": open_remediations, + # Literal argument value "detection_rate": detection_rate, + # Literal argument value "top_gaps": top_gaps, } + # Return _generate(output_format, "executive_summary", context) return _generate(output_format, "executive_summary", context) +# Define function generate_quarterly_summary def generate_quarterly_summary( + # Entry: db db: Session, + # Entry: output_format output_format: str = "pdf", ) -> str: """Quarterly summary — reuses executive metrics plus snapshot trend rows.""" - from sqlalchemy import case as sql_case, func + # Import case as sql_case from sqlalchemy + from sqlalchemy import case as sql_case + # Import func from sqlalchemy + from sqlalchemy import func + + # Assign org_score = 
_safe_org_score(db) org_score = _safe_org_score(db) + # Assign quarter_ago = datetime.utcnow() - timedelta(days=90) quarter_ago = datetime.utcnow() - timedelta(days=90) + # Assign tests_this_quarter = ( tests_this_quarter = ( db.query(func.count(Test.id)).filter(Test.created_at >= quarter_ago).scalar() or 0 ) + # Assign techniques = db.query(Technique).all() techniques = db.query(Technique).all() + # Assign validated_count = sum( validated_count = sum( + # Literal argument value 1 for t in techniques if t.status_global and t.status_global.value == "validated" ) + # Assign detected_count = ( detected_count = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.state == "validated", Test.detection_result == "detected") + # Chain .scalar() call .scalar() or 0 ) + # Assign detection_rate = ( detection_rate = ( round((detected_count / validated_count) * 100, 1) if validated_count > 0 else 0 ) + # Assign tactic_rows = ( tactic_rows = ( db.query( Technique.tactic, func.count(Technique.id).label("total"), func.sum(sql_case((Technique.status_global == "validated", 1), else_=0)).label( + # Literal argument value "validated", ), ) + # Chain .group_by() call .group_by(Technique.tactic) + # Chain .all() call .all() ) + # Assign top_gaps = sorted( top_gaps = sorted( [ { + # Literal argument value "tactic": r[0] or "Unknown", + # Literal argument value "coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0, } for r in tactic_rows ], + # Keyword argument: key key=lambda x: x["coverage_pct"], )[:5] + # Assign snapshots = ( snapshots = ( db.query(CoverageSnapshot) + # Chain .filter() call .filter(CoverageSnapshot.created_at >= quarter_ago) + # Chain .order_by() call .order_by(CoverageSnapshot.created_at) + # Chain .all() call .all() ) + # Assign trend_rows = [ trend_rows = [ { + # Literal argument value "date": s.created_at.strftime("%Y-%m-%d") if s.created_at else "", + # Literal argument value "validated_count": s.validated_count, + # Literal 
argument value "total_techniques": s.total_techniques, + # Literal argument value "organization_score": round(s.organization_score, 1), } for s in snapshots ] + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign quarter_label = f"Q{((now.month - 1) // 3) + 1} {now.year}" quarter_label = f"Q{((now.month - 1) // 3) + 1} {now.year}" + # Assign context = { context = { + # Literal argument value "quarter_label": quarter_label, + # Literal argument value "org_score": org_score, + # Literal argument value "tests_this_quarter": tests_this_quarter, + # Literal argument value "detection_rate": detection_rate, + # Literal argument value "trend_rows": trend_rows, + # Literal argument value "top_gaps": top_gaps, } + # Return _generate(output_format, "quarterly_summary", context) return _generate(output_format, "quarterly_summary", context) +# Define function generate_technique_detail_report def generate_technique_detail_report( + # Entry: db db: Session, + # Entry: technique_id technique_id: str, + # Entry: output_format output_format: str = "pdf", ) -> str: """Detailed report for a single MITRE technique and its tests.""" + # Assign tid = technique_id if isinstance(technique_id, UUID) else UUID(str(techni... 
tid = technique_id if isinstance(technique_id, UUID) else UUID(str(technique_id)) + # Assign technique = db.query(Technique).filter(Technique.id == tid).first() technique = db.query(Technique).filter(Technique.id == tid).first() + # Check: not technique if not technique: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", str(technique_id)) + # Assign related_tests = ( related_tests = ( db.query(Test) + # Chain .filter() call .filter(Test.technique_id == tid) + # Chain .order_by() call .order_by(Test.created_at.desc()) + # Chain .all() call .all() ) + # Assign tests_data = [ tests_data = [ { + # Literal argument value "name": t.name, + # Literal argument value "state": t.state.value if t.state else "draft", + # Literal argument value "detection_result": ( t.detection_result.value if t.detection_result else "pending" ), + # Literal argument value "created_at": t.created_at.strftime("%Y-%m-%d") if t.created_at else "", } for t in related_tests ] + # Assign context = { context = { + # Literal argument value "technique": technique, + # Literal argument value "technique_status": ( technique.status_global.value if technique.status_global else "not_evaluated" ), + # Literal argument value "tests": tests_data, } + # Return _generate(output_format, "technique_detail", context) return _generate(output_format, "technique_detail", context) @@ -351,19 +551,32 @@ def generate_technique_detail_report( def _safe_org_score(db: Session) -> dict: """Safely call the scoring service; return empty dict on failure.""" + # Attempt the following; catch errors below try: + # Import calculate_organization_score from app.services.scoring_service from app.services.scoring_service import calculate_organization_score + # Return calculate_organization_score(db) return calculate_organization_score(db) + # Handle Exception except Exception as e: + # Log warning: "Scoring service unavailable: %s", e logger.warning("Scoring service unavailable: %s", e) + # Return {"overall": 0, 
"coverage": 0, "detection_maturity": 0} return {"overall": 0, "coverage": 0, "detection_maturity": 0} +# Define function _generate def _generate(output_format: str, template_name: str, context: dict) -> str: """Dispatch to the correct ReportEngine method.""" + # Check: output_format == "pdf" if output_format == "pdf": + # Return report_engine.generate_pdf(template_name, context) return report_engine.generate_pdf(template_name, context) + # Alternative: output_format == "docx" elif output_format == "docx": + # Return report_engine.generate_docx(template_name, context) return report_engine.generate_docx(template_name, context) + # Fallback: handle remaining cases else: + # Return report_engine.generate_html_file(template_name, context) return report_engine.generate_html_file(template_name, context) diff --git a/backend/app/services/score_cache.py b/backend/app/services/score_cache.py index 67218f7..f27f381 100644 --- a/backend/app/services/score_cache.py +++ b/backend/app/services/score_cache.py @@ -7,78 +7,123 @@ Thread-safe: each worker process has its own dict, and the TTL ensures stale data does not persist longer than ``CACHE_TTL`` seconds. 
""" +# Import time import time + +# Import Any, Optional from typing from typing import Any, Optional +# Import Session from sqlalchemy.orm +from sqlalchemy.orm import Session + +# Assign CACHE_TTL = 300 # 5 minutes CACHE_TTL = 300 # 5 minutes +# Assign _cache = {} _cache: dict[str, dict[str, Any]] = {} -def get(key: str) -> Optional[Any]: +# Define function get +def get(key: str) -> Optional[Any]: # noqa: ANN401 # generic cache returns whatever was stored """Return cached value if present and not expired, else None.""" + # Assign entry = _cache.get(key) entry = _cache.get(key) + # Check: entry is None if entry is None: + # Return None return None + # Check: time.time() - entry["ts"] > CACHE_TTL if time.time() - entry["ts"] > CACHE_TTL: + # Call _cache.pop() _cache.pop(key, None) + # Return None return None + # Return entry["data"] return entry["data"] -def put(key: str, data: Any) -> None: +# Define function put +def put(key: str, data: Any) -> None: # noqa: ANN401 # generic cache accepts any serialisable value """Store *data* under *key* with the current timestamp.""" + # Assign _cache[key] = {"data": data, "ts": time.time()} _cache[key] = {"data": data, "ts": time.time()} +# Define function invalidate def invalidate(key: Optional[str] = None) -> None: """Remove one key or clear the whole cache.""" + # Check: key is None if key is None: + # Call _cache.clear() _cache.clear() + # Fallback: handle remaining cases else: + # Call _cache.pop() _cache.pop(key, None) # ── High-level helpers ──────────────────────────────────────────────── -def get_organization_score_cached(db): +def get_organization_score_cached(db: Session) -> dict: """Cached wrapper around ``calculate_organization_score``.""" + # Import calculate_organization_score from app.services.scoring_service from app.services.scoring_service import calculate_organization_score + # Assign cached = get("org_score") cached = get("org_score") + # Check: cached is not None if cached is not None: + # Return cached 
return cached + # Assign result = calculate_organization_score(db) result = calculate_organization_score(db) + # Call put() put("org_score", result) + # Return result return result -def get_operational_metrics_cached(db): +# Define function get_operational_metrics_cached +def get_operational_metrics_cached(db: Session) -> dict: """Cached wrapper around operational metrics (MTTD, MTTR, efficacy).""" + # Import from app.services.operational_metrics_service from app.services.operational_metrics_service import ( - calculate_mttd, - calculate_mttr, - calculate_detection_efficacy, calculate_alert_fidelity, calculate_coverage_velocity, - calculate_validation_throughput, + calculate_detection_efficacy, + calculate_mttd, + calculate_mttr, calculate_rejection_rate, + calculate_validation_throughput, ) + # Assign cached = get("op_metrics") cached = get("op_metrics") + # Check: cached is not None if cached is not None: + # Return cached return cached + # Assign result = { result = { + # Literal argument value "mttd": calculate_mttd(db), + # Literal argument value "mttr": calculate_mttr(db), + # Literal argument value "detection_efficacy": calculate_detection_efficacy(db), + # Literal argument value "alert_fidelity": calculate_alert_fidelity(db), + # Literal argument value "coverage_velocity": calculate_coverage_velocity(db), + # Literal argument value "validation_throughput": calculate_validation_throughput(db), + # Literal argument value "rejection_rate": calculate_rejection_rate(db), } + # Call put() put("op_metrics", result) + # Return result return result diff --git a/backend/app/services/scoring_config_service.py b/backend/app/services/scoring_config_service.py index fd54198..719da63 100644 --- a/backend/app/services/scoring_config_service.py +++ b/backend/app/services/scoring_config_service.py @@ -1,121 +1,202 @@ """Scoring configuration persistence service.""" +# Enable future language features for compatibility from __future__ import annotations +# Import uuid import 
uuid + +# Import Any from typing from typing import Any +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import settings from app.config from app.config import settings + +# Import ScoringWeights from app.domain.value_objects.scoring_weights from app.domain.value_objects.scoring_weights import ScoringWeights + +# Import ScoringConfig from app.models.scoring_config from app.models.scoring_config import ScoringConfig +# Define function _row_recency def _row_recency(row: ScoringConfig) -> float: + # Return float(getattr(row, "weight_recency", None) or getattr(row, "weight_... return float(getattr(row, "weight_recency", None) or getattr(row, "weight_freshness", 10.0)) +# Define function _row_severity def _row_severity(row: ScoringConfig) -> float: + # Return float( return float( getattr(row, "weight_severity", None) or getattr(row, "weight_platform_diversity", 10.0) ) +# Define function get_scoring_weights def get_scoring_weights(db: Session) -> ScoringWeights: """Return the active scoring weights from the database or env defaults.""" + # Assign row = db.query(ScoringConfig).first() row = db.query(ScoringConfig).first() + # Check: row is not None if row is not None: + # Return ScoringWeights( return ScoringWeights( + # Keyword argument: tests tests=row.weight_tests, + # Keyword argument: detection_rules detection_rules=row.weight_detection_rules, + # Keyword argument: d3fend d3fend=row.weight_d3fend, + # Keyword argument: recency recency=_row_recency(row), + # Keyword argument: severity severity=_row_severity(row), ) + # Return ScoringWeights( return ScoringWeights( + # Keyword argument: tests tests=float(settings.SCORING_WEIGHT_TESTS), + # Keyword argument: detection_rules detection_rules=float(settings.SCORING_WEIGHT_DETECTION_RULES), + # Keyword argument: d3fend d3fend=float(settings.SCORING_WEIGHT_D3FEND), + # Keyword argument: recency recency=float( getattr(settings, "SCORING_WEIGHT_RECENCY", settings.SCORING_WEIGHT_FRESHNESS) ), + # 
Keyword argument: severity severity=float( getattr(settings, "SCORING_WEIGHT_SEVERITY", settings.SCORING_WEIGHT_PLATFORM_DIVERSITY) ), ) +# Define function update_scoring_weights def update_scoring_weights( + # Entry: db db: Session, *, + # Entry: tests tests: float | None = None, + # Entry: detection_rules detection_rules: float | None = None, + # Entry: d3fend d3fend: float | None = None, + # Entry: recency recency: float | None = None, + # Entry: severity severity: float | None = None, + # Entry: freshness freshness: float | None = None, + # Entry: platform_diversity platform_diversity: float | None = None, + # Entry: updated_by updated_by: uuid.UUID | None = None, ) -> dict[str, Any]: """Upsert scoring weights. Does not commit.""" + # Check: freshness is not None and recency is None if freshness is not None and recency is None: + # Assign recency = freshness recency = freshness + # Check: platform_diversity is not None and severity is None if platform_diversity is not None and severity is None: + # Assign severity = platform_diversity severity = platform_diversity + # Assign current = get_scoring_weights(db) current = get_scoring_weights(db) + # Assign new = ScoringWeights( new = ScoringWeights( + # Keyword argument: tests tests=tests if tests is not None else current.tests, + # Keyword argument: detection_rules detection_rules=detection_rules if detection_rules is not None else current.detection_rules, + # Keyword argument: d3fend d3fend=d3fend if d3fend is not None else current.d3fend, + # Keyword argument: recency recency=recency if recency is not None else current.recency, + # Keyword argument: severity severity=severity if severity is not None else current.severity, ) + # Assign row = db.query(ScoringConfig).first() row = db.query(ScoringConfig).first() + # Check: row is None if row is None: + # Assign row = ScoringConfig() row = ScoringConfig() + # Stage new record(s) for database insertion db.add(row) + # Assign row.weight_tests = new.tests 
row.weight_tests = new.tests + # Assign row.weight_detection_rules = new.detection_rules row.weight_detection_rules = new.detection_rules + # Assign row.weight_d3fend = new.d3fend row.weight_d3fend = new.d3fend + # Check: hasattr(row, "weight_recency") if hasattr(row, "weight_recency"): + # Assign row.weight_recency = new.recency row.weight_recency = new.recency + # Alternative: hasattr(row, "weight_freshness") elif hasattr(row, "weight_freshness"): + # Assign row.weight_freshness = new.recency row.weight_freshness = new.recency + # Check: hasattr(row, "weight_severity") if hasattr(row, "weight_severity"): + # Assign row.weight_severity = new.severity row.weight_severity = new.severity + # Alternative: hasattr(row, "weight_platform_diversity") elif hasattr(row, "weight_platform_diversity"): + # Assign row.weight_platform_diversity = new.severity row.weight_platform_diversity = new.severity + # Check: updated_by is not None and hasattr(row, "updated_by") if updated_by is not None and hasattr(row, "updated_by"): + # Assign row.updated_by = updated_by row.updated_by = updated_by + # Return _weights_dict(new) return _weights_dict(new) +# Define function get_weights_dict def get_weights_dict(db: Session) -> dict[str, Any]: """Return current weights as a serialisable dict.""" + # Return _weights_dict(get_scoring_weights(db)) return _weights_dict(get_scoring_weights(db)) +# Define function _weights_dict def _weights_dict(w: ScoringWeights) -> dict[str, Any]: + # Assign weights = { weights = { + # Literal argument value "tests": w.tests, + # Literal argument value "detection_rules": w.detection_rules, + # Literal argument value "d3fend": w.d3fend, + # Literal argument value "recency": w.recency, + # Literal argument value "severity": w.severity, # Legacy keys for older clients "freshness": w.recency, + # Literal argument value "platform_diversity": w.severity, } + # Return { return { + # Literal argument value "weights": weights, + # Literal argument value "total": sum( 
[w.tests, w.detection_rules, w.d3fend, w.recency, w.severity] ), diff --git a/backend/app/services/scoring_service.py b/backend/app/services/scoring_service.py index a3cf003..c30bcc7 100644 --- a/backend/app/services/scoring_service.py +++ b/backend/app/services/scoring_service.py @@ -9,75 +9,160 @@ fixed number of aggregated queries so that organisation-wide calculations never produce N+1 traffic. """ +# Import datetime, timedelta, timezone from datetime from datetime import datetime, timedelta, timezone -from typing import Optional +# Import case, func from sqlalchemy from sqlalchemy import case, func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError -from app.models.technique import Technique -from app.models.test import Test -from app.models.detection_rule import DetectionRule -from app.models.test_detection_result import TestDetectionResult + +# Import DefensiveTechniqueMapping from app.models.defensive_technique from app.models.defensive_technique import DefensiveTechniqueMapping + +# Import DetectionRule from app.models.detection_rule +from app.models.detection_rule import DetectionRule + +# Import TestResult, TestState from app.models.enums +from app.models.enums import TestResult, TestState + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import Test from app.models.test +from app.models.test import Test + +# Import TestDetectionResult from app.models.test_detection_result +from app.models.test_detection_result import TestDetectionResult + +# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor from app.models.threat_actor import ThreatActor, ThreatActorTechnique -from app.models.enums import TestState, TestResult + +# Import get_scoring_weights from app.services.scoring_config_service from app.services.scoring_config_service import get_scoring_weights +# Assign 
_SEVERITY_FACTORS = { _SEVERITY_FACTORS: dict[str, float] = { + # Literal argument value "critical": 1.0, + # Literal argument value "high": 0.85, + # Literal argument value "medium": 0.65, + # Literal argument value "low": 0.5, } +# Define function _recency_factor def _recency_factor(last_tested: datetime | None) -> float: - """Decay factor: 1.0 when recent, decreasing over time.""" + """Return a recency decay factor: 1.0 when recent, decreasing over time. + + Args: + last_tested (datetime | None): Datetime of the most recent validated + test, or ``None`` if the technique has never been tested. + + Returns: + float: A multiplier between 0.0 and 1.0; 0.0 when untested, 1.0 + when tested within the last 90 days. + """ + # Check: not last_tested if not last_tested: + # Return 0.0 return 0.0 + # Assign now = datetime.now(timezone.utc) now = datetime.now(timezone.utc) + # Assign tested = last_tested tested = last_tested + # Check: tested.tzinfo is None if tested.tzinfo is None: + # Assign tested = tested.replace(tzinfo=timezone.utc) tested = tested.replace(tzinfo=timezone.utc) + # Assign days_ago = (now - tested).days days_ago = (now - tested).days + # Check: days_ago <= 90 if days_ago <= 90: + # Return 1.0 return 1.0 + # Check: days_ago <= 180 if days_ago <= 180: + # Return 0.8 return 0.8 + # Check: days_ago <= 365 if days_ago <= 365: + # Return 0.5 return 0.5 + # Return 0.2 return 0.2 +# Define function _severity_factor def _severity_factor(severity_label: str | None) -> float: - """Map template severity to a 0–1 multiplier.""" + """Map template severity to a 0–1 multiplier. + + Args: + severity_label (str | None): Severity string from the test template + (e.g. ``"critical"``, ``"high"``). Case-insensitive. + + Returns: + float: A multiplier between 0.5 and 1.0; defaults to 0.7 for + unknown or missing labels. 
+ """ + # Check: not severity_label if not severity_label: + # Return 0.7 return 0.7 + # Return _SEVERITY_FACTORS.get(severity_label.lower(), 0.7) return _SEVERITY_FACTORS.get(severity_label.lower(), 0.7) +# Define function _max_severity_by_mitre def _max_severity_by_mitre(db: Session) -> dict[str, str]: - """Highest severity label per MITRE id from active test templates.""" + """Return the highest severity label per MITRE ID from active test templates. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict[str, str]: Mapping of MITRE technique ID to the highest severity + label (``"critical"`` > ``"high"`` > ``"medium"`` > ``"low"``) + found among active test templates for that technique. + """ + # Import TestTemplate from app.models.test_template from app.models.test_template import TestTemplate + # Assign order = {"critical": 4, "high": 3, "medium": 2, "low": 1} order = {"critical": 4, "high": 3, "medium": 2, "low": 1} + # Assign rows = ( rows = ( db.query(TestTemplate.mitre_technique_id, TestTemplate.severity) + # Chain .filter() call .filter( TestTemplate.is_active == True, # noqa: E712 TestTemplate.severity.isnot(None), ) + # Chain .all() call .all() ) + # Assign best = {} best: dict[str, str] = {} + # Iterate over rows for mitre_id, severity in rows: + # Check: not mitre_id or not severity if not mitre_id or not severity: + # Skip to the next loop iteration continue + # Assign current = best.get(mitre_id) current = best.get(mitre_id) + # Check: current is None or order.get(severity.lower(), 0) > order.get(curre... if current is None or order.get(severity.lower(), 0) > order.get(current.lower(), 0): + # Assign best[mitre_id] = severity best[mitre_id] = severity + # Return best return best @@ -96,14 +181,22 @@ def bulk_technique_scores(db: Session) -> dict: Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``. 
""" + # Assign w = get_scoring_weights(db) w = get_scoring_weights(db) + # Assign w_tests = w.tests w_tests = w.tests + # Assign w_detection = w.detection_rules w_detection = w.detection_rules + # Assign w_d3fend = w.d3fend w_d3fend = w.d3fend + # Assign w_recency = w.recency w_recency = w.recency + # Assign w_severity = w.severity w_severity = w.severity + # Assign severity_by_mitre = _max_severity_by_mitre(db) severity_by_mitre = _max_severity_by_mitre(db) + # Assign last_validated = func.coalesce( last_validated = func.coalesce( Test.blue_validated_at, Test.red_validated_at, @@ -120,16 +213,25 @@ def bulk_technique_scores(db: Session) -> dict: ).label("detected_count"), func.max(last_validated).label("latest_validated_at"), ) + # Chain .filter() call .filter(Test.state == TestState.validated) + # Chain .group_by() call .group_by(Test.technique_id) + # Chain .all() call .all() ) + # Assign test_stats = {} test_stats: dict = {} + # Iterate over test_rows for row in test_rows: + # Assign test_stats[row.technique_id] = { test_stats[row.technique_id] = { + # Literal argument value "validated": row.validated_count, + # Literal argument value "detected": row.detected_count, + # Literal argument value "latest_validated_at": row.latest_validated_at, } @@ -139,10 +241,14 @@ def bulk_technique_scores(db: Session) -> dict: DetectionRule.mitre_technique_id, func.count(DetectionRule.id).label("total"), ) + # Chain .filter() call .filter(DetectionRule.is_active == True) # noqa: E712 + # Chain .group_by() call .group_by(DetectionRule.mitre_technique_id) + # Chain .all() call .all() ) + # Assign rules_by_mitre = {r.mitre_technique_id: r.total for r in rule_rows} rules_by_mitre: dict[str, int] = {r.mitre_technique_id: r.total for r in rule_rows} # Q3: triggered rules per mitre_id @@ -151,11 +257,16 @@ def bulk_technique_scores(db: Session) -> dict: DetectionRule.mitre_technique_id, func.count(TestDetectionResult.id).label("triggered"), ) + # Chain .join() call 
.join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id) + # Chain .filter() call .filter(TestDetectionResult.triggered == True) # noqa: E712 + # Chain .group_by() call .group_by(DetectionRule.mitre_technique_id) + # Chain .all() call .all() ) + # Assign triggered_by_mitre = { triggered_by_mitre: dict[str, int] = { r.mitre_technique_id: r.triggered for r in triggered_rows } @@ -166,34 +277,53 @@ def bulk_technique_scores(db: Session) -> dict: DefensiveTechniqueMapping.attack_technique_id, func.count(DefensiveTechniqueMapping.id).label("total"), ) + # Chain .group_by() call .group_by(DefensiveTechniqueMapping.attack_technique_id) + # Chain .all() call .all() ) + # Assign d3fend_by_tech = {r.attack_technique_id: r.total for r in d3fend_rows} d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows} # Q5: all techniques techniques = db.query(Technique).all() + # Assign results = {} results: dict = {} + # Iterate over techniques for tech in techniques: + # Assign ts = test_stats.get(tech.id, {}) ts = test_stats.get(tech.id, {}) + # Assign validated = ts.get("validated", 0) validated = ts.get("validated", 0) + # Assign detected = ts.get("detected", 0) detected = ts.get("detected", 0) + # Assign latest_at = ts.get("latest_validated_at") latest_at = ts.get("latest_validated_at") + # Assign breakdown = {} breakdown = {} # 1. 
Tests validated with detection if validated > 0: + # Assign test_ratio = detected / validated test_ratio = detected / validated + # Assign test_score = round(test_ratio * w_tests, 1) test_score = round(test_ratio * w_tests, 1) + # Fallback: handle remaining cases else: + # Assign test_ratio = 0 test_ratio = 0 + # Assign test_score = 0 test_score = 0 + # Assign breakdown["tests_validated"] = { breakdown["tests_validated"] = { + # Literal argument value "score": test_score, + # Literal argument value "max": w_tests, + # Literal argument value "detail": ( f"{detected}/{validated} tests detected" if validated else "No validated tests" @@ -202,16 +332,27 @@ def bulk_technique_scores(db: Session) -> dict: # 2. Detection rules total_rules = rules_by_mitre.get(tech.mitre_id, 0) + # Assign triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0) triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0) + # Check: total_rules > 0 if total_rules > 0: + # Assign detection_ratio = min(triggered_rules / total_rules, 1.0) detection_ratio = min(triggered_rules / total_rules, 1.0) + # Assign detection_score = round(detection_ratio * w_detection, 1) detection_score = round(detection_ratio * w_detection, 1) + # Fallback: handle remaining cases else: + # Assign detection_ratio = 0 detection_ratio = 0 + # Assign detection_score = 0 detection_score = 0 + # Assign breakdown["detection_rules"] = { breakdown["detection_rules"] = { + # Literal argument value "score": detection_score, + # Literal argument value "max": w_detection, + # Literal argument value "detail": ( f"{triggered_rules}/{total_rules} rules triggered" if total_rules > 0 else "No detection rules available" @@ -220,15 +361,25 @@ def bulk_technique_scores(db: Session) -> dict: # 3. 
D3FEND coverage total_cm = d3fend_by_tech.get(tech.id, 0) + # Check: total_cm > 0 and detected > 0 if total_cm > 0 and detected > 0: + # Assign verified_cm = min(detected, total_cm) verified_cm = min(detected, total_cm) + # Assign d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1) d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1) + # Fallback: handle remaining cases else: + # Assign verified_cm = 0 verified_cm = 0 + # Assign d3fend_score = 0 d3fend_score = 0 + # Assign breakdown["d3fend_coverage"] = { breakdown["d3fend_coverage"] = { + # Literal argument value "score": d3fend_score, + # Literal argument value "max": w_d3fend, + # Literal argument value "detail": ( f"{verified_cm}/{total_cm} countermeasures" if total_cm > 0 else "No D3FEND mappings" @@ -237,29 +388,49 @@ def bulk_technique_scores(db: Session) -> dict: # 4. Recency decay recency_mult = _recency_factor(latest_at) + # Assign recency_score = round(recency_mult * w_recency, 1) recency_score = round(recency_mult * w_recency, 1) + # Check: latest_at if latest_at: + # Assign tested = latest_at tested = latest_at + # Check: tested.tzinfo is None if tested.tzinfo is None: + # Assign days_ago = (datetime.utcnow() - tested).days days_ago = (datetime.utcnow() - tested).days + # Fallback: handle remaining cases else: + # Assign days_ago = (datetime.now(timezone.utc) - tested.astimezone(timezone.utc)).days days_ago = (datetime.now(timezone.utc) - tested.astimezone(timezone.utc)).days + # Assign recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})" recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})" + # Fallback: handle remaining cases else: + # Assign recency_detail = "No validated tests" recency_detail = "No validated tests" + # Assign breakdown["recency"] = { breakdown["recency"] = { + # Literal argument value "score": recency_score, + # Literal argument value "max": w_recency, + # Literal argument value "detail": recency_detail, } # 5. 
Severity / criticality (template-driven) sev_label = severity_by_mitre.get(tech.mitre_id) + # Assign sev_mult = _severity_factor(sev_label) sev_mult = _severity_factor(sev_label) + # Assign severity_score = round(sev_mult * w_severity, 1) severity_score = round(sev_mult * w_severity, 1) + # Assign breakdown["severity"] = { breakdown["severity"] = { + # Literal argument value "score": severity_score, + # Literal argument value "max": w_severity, + # Literal argument value "detail": ( f"Template severity: {sev_label} (factor {sev_mult})" if sev_label @@ -267,18 +438,26 @@ def bulk_technique_scores(db: Session) -> dict: ), } + # Assign total = min( total = min( test_score + detection_score + d3fend_score + recency_score + severity_score, + # Literal argument value 100, ) + # Assign results[tech.id] = { results[tech.id] = { + # Literal argument value "total_score": round(total, 1), + # Literal argument value "breakdown": breakdown, + # Literal argument value "mitre_id": tech.mitre_id, + # Literal argument value "tactic": tech.tactic, } + # Return results return results @@ -286,65 +465,128 @@ def bulk_technique_scores(db: Session) -> dict: def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict: - """Get detailed score with breakdown for a technique by MITRE ID.""" + """Return detailed score with breakdown for a technique by MITRE ID. + + Args: + db (Session): Active SQLAlchemy database session. + mitre_id (str): MITRE ATT&CK technique identifier (e.g. ``"T1059"``). + + Returns: + dict: Scoring result containing ``mitre_id``, ``name``, ``tactic``, + ``status_global``, ``total_score``, and ``breakdown``. 
+ """ + # Assign technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first() technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first() + # Check: not technique if not technique: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", mitre_id) + # Assign result = calculate_technique_score(technique, db) result = calculate_technique_score(technique, db) + # Return { return { + # Literal argument value "mitre_id": technique.mitre_id, + # Literal argument value "name": technique.name, + # Literal argument value "tactic": technique.tactic, + # Literal argument value "status_global": technique.status_global.value if technique.status_global else None, **result, } +# Define function score_actor_by_id def score_actor_by_id(db: Session, actor_id: str) -> dict: - """Get coverage score for a threat actor by ID.""" + """Return coverage score for a threat actor by ID. + + Args: + db (Session): Active SQLAlchemy database session. + actor_id (str): UUID string identifying the threat actor. + + Returns: + dict: Coverage score dictionary from + :func:`calculate_actor_coverage_score`. + """ + # Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + # Check: not actor if not actor: + # Raise EntityNotFoundError raise EntityNotFoundError("ThreatActor", actor_id) + # Return calculate_actor_coverage_score(actor_id, db) return calculate_actor_coverage_score(actor_id, db) +# Define function calculate_technique_score def calculate_technique_score(technique: Technique, db: Session) -> dict: """Calculate a 0-100 score for a technique with detailed breakdown. Weights are read from the ``scoring_config`` table (or env defaults). + + Args: + technique (Technique): The technique ORM object to score. + db (Session): Active SQLAlchemy database session. 
+ + Returns: + dict: Dictionary with ``total_score`` (float) and ``breakdown`` + (dict mapping component name to score, max, and detail string). """ + # Assign w = get_scoring_weights(db) w = get_scoring_weights(db) + # Assign w_tests = w.tests w_tests = w.tests + # Assign w_detection = w.detection_rules w_detection = w.detection_rules + # Assign w_d3fend = w.d3fend w_d3fend = w.d3fend + # Assign w_recency = w.recency w_recency = w.recency + # Assign w_severity = w.severity w_severity = w.severity + # Assign severity_by_mitre = _max_severity_by_mitre(db) severity_by_mitre = _max_severity_by_mitre(db) + # Assign breakdown = {} breakdown = {} # ── 1. Tests validated with detection ────────────────────────── all_tests = ( db.query(Test) + # Chain .filter() call .filter(Test.technique_id == technique.id) + # Chain .all() call .all() ) + # Assign validated_tests = [t for t in all_tests if t.state == TestState.validated] validated_tests = [t for t in all_tests if t.state == TestState.validated] + # Assign detected_tests = [ detected_tests = [ t for t in validated_tests if t.detection_result == TestResult.detected ] + # Check: validated_tests if validated_tests: + # Assign test_ratio = len(detected_tests) / len(validated_tests) test_ratio = len(detected_tests) / len(validated_tests) + # Assign test_score = round(test_ratio * w_tests, 1) test_score = round(test_ratio * w_tests, 1) + # Fallback: handle remaining cases else: + # Assign test_ratio = 0 test_ratio = 0 + # Assign test_score = 0 test_score = 0 + # Assign breakdown["tests_validated"] = { breakdown["tests_validated"] = { + # Literal argument value "score": test_score, + # Literal argument value "max": w_tests, + # Literal argument value "detail": f"{len(detected_tests)}/{len(validated_tests)} tests detected" if validated_tests else "No validated tests", @@ -353,37 +595,54 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: # ── 2. 
Detection rules coverage ─────────────────────────────── total_rules = ( db.query(func.count(DetectionRule.id)) + # Chain .filter() call .filter( DetectionRule.mitre_technique_id == technique.mitre_id, DetectionRule.is_active == True, # noqa: E712 ) + # Chain .scalar() call .scalar() ) or 0 + # Assign triggered_rules = 0 triggered_rules = 0 + # Check: total_rules > 0 if total_rules > 0: + # Assign triggered_rules = ( triggered_rules = ( db.query(func.count(TestDetectionResult.id)) + # Chain .join() call .join( DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id, ) + # Chain .filter() call .filter( DetectionRule.mitre_technique_id == technique.mitre_id, TestDetectionResult.triggered == True, # noqa: E712 ) + # Chain .scalar() call .scalar() ) or 0 + # Assign detection_ratio = min(triggered_rules / total_rules, 1.0) detection_ratio = min(triggered_rules / total_rules, 1.0) + # Assign detection_score = round(detection_ratio * w_detection, 1) detection_score = round(detection_ratio * w_detection, 1) + # Fallback: handle remaining cases else: + # Assign detection_ratio = 0 detection_ratio = 0 + # Assign detection_score = 0 detection_score = 0 + # Assign breakdown["detection_rules"] = { breakdown["detection_rules"] = { + # Literal argument value "score": detection_score, + # Literal argument value "max": w_detection, + # Literal argument value "detail": f"{triggered_rules}/{total_rules} rules triggered" if total_rules > 0 else "No detection rules available", @@ -392,22 +651,36 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: # ── 3. 
D3FEND coverage ──────────────────────────────────────── total_countermeasures = ( db.query(func.count(DefensiveTechniqueMapping.id)) + # Chain .filter() call .filter(DefensiveTechniqueMapping.attack_technique_id == technique.id) + # Chain .scalar() call .scalar() ) or 0 + # Assign verified_countermeasures = 0 verified_countermeasures = 0 + # Check: total_countermeasures > 0 and len(detected_tests) > 0 if total_countermeasures > 0 and len(detected_tests) > 0: + # Assign verified_countermeasures = min(len(detected_tests), total_countermeasures) verified_countermeasures = min(len(detected_tests), total_countermeasures) + # Assign d3fend_ratio = verified_countermeasures / total_countermeasures d3fend_ratio = verified_countermeasures / total_countermeasures + # Assign d3fend_score = round(d3fend_ratio * w_d3fend, 1) d3fend_score = round(d3fend_ratio * w_d3fend, 1) + # Fallback: handle remaining cases else: + # Assign d3fend_ratio = 0 d3fend_ratio = 0 + # Assign d3fend_score = 0 d3fend_score = 0 + # Assign breakdown["d3fend_coverage"] = { breakdown["d3fend_coverage"] = { + # Literal argument value "score": d3fend_score, + # Literal argument value "max": w_d3fend, + # Literal argument value "detail": f"{verified_countermeasures}/{total_countermeasures} countermeasures" if total_countermeasures > 0 else "No D3FEND mappings", @@ -415,14 +688,22 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: # ── 4. Recency ──────────────────────────────────────────────── most_recent_test = None + # Iterate over validated_tests for t in validated_tests: + # Assign candidate = t.blue_validated_at or t.red_validated_at or t.created_at candidate = t.blue_validated_at or t.red_validated_at or t.created_at + # Check: candidate and (most_recent_test is None or candidate > most_recent_... 
if candidate and (most_recent_test is None or candidate > most_recent_test): + # Assign most_recent_test = candidate most_recent_test = candidate + # Assign recency_mult = _recency_factor(most_recent_test) recency_mult = _recency_factor(most_recent_test) + # Assign recency_score = round(recency_mult * w_recency, 1) recency_score = round(recency_mult * w_recency, 1) + # Check: most_recent_test if most_recent_test: + # Assign days_ago = ( days_ago = ( datetime.now(timezone.utc) - ( most_recent_test.replace(tzinfo=timezone.utc) @@ -430,23 +711,36 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: else most_recent_test.astimezone(timezone.utc) ) ).days + # Assign recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})" recency_detail = f"Last validated {days_ago} days ago (factor {recency_mult})" + # Fallback: handle remaining cases else: + # Assign recency_detail = "No validated tests" recency_detail = "No validated tests" + # Assign breakdown["recency"] = { breakdown["recency"] = { + # Literal argument value "score": recency_score, + # Literal argument value "max": w_recency, + # Literal argument value "detail": recency_detail, } # ── 5. 
Severity ─────────────────────────────────────────────── sev_label = severity_by_mitre.get(technique.mitre_id) + # Assign sev_mult = _severity_factor(sev_label) sev_mult = _severity_factor(sev_label) + # Assign severity_score = round(sev_mult * w_severity, 1) severity_score = round(sev_mult * w_severity, 1) + # Assign breakdown["severity"] = { breakdown["severity"] = { + # Literal argument value "score": severity_score, + # Literal argument value "max": w_severity, + # Literal argument value "detail": ( f"Template severity: {sev_label} (factor {sev_mult})" if sev_label @@ -457,11 +751,15 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: # ── Total ───────────────────────────────────────────────────── total = min( test_score + detection_score + d3fend_score + recency_score + severity_score, + # Literal argument value 100, ) + # Return { return { + # Literal argument value "total_score": round(total, 1), + # Literal argument value "breakdown": breakdown, } @@ -470,19 +768,36 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict: def calculate_tactic_score(tactic: str, db: Session) -> dict: - """Calculate average score for all techniques in a tactic.""" + """Calculate average score for all techniques in a tactic. + + Args: + tactic (str): Tactic name used for case-insensitive substring matching + against technique tactic fields. + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Contains ``tactic``, ``average_score``, ``techniques_count``, + and ``techniques_scored`` keys. 
+ """ + # Assign scores_map = bulk_technique_scores(db) scores_map = bulk_technique_scores(db) + # Assign matching = [ matching = [ v["total_score"] for v in scores_map.values() if v.get("tactic") and tactic.lower() in v["tactic"].lower() ] + # Return { return { + # Literal argument value "tactic": tactic, + # Literal argument value "average_score": round(sum(matching) / len(matching), 1) if matching else 0, + # Literal argument value "techniques_count": len(matching), + # Literal argument value "techniques_scored": len([s for s in matching if s > 0]), } @@ -491,53 +806,100 @@ def calculate_tactic_score(tactic: str, db: Session) -> dict: def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict: - """Calculate coverage score for a specific threat actor's techniques.""" + """Calculate coverage score for a specific threat actor's techniques. + + Args: + actor_id (str): UUID string identifying the threat actor. + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Contains ``actor_id``, ``actor_name``, ``total_score``, + ``techniques_count``, ``techniques_covered``, and + ``techniques_detail`` keys. 
+ """ + # Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + # Check: not actor if not actor: + # Return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0} return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0} + # Assign actor_techniques = ( actor_techniques = ( db.query(ThreatActorTechnique) + # Chain .filter() call .filter(ThreatActorTechnique.threat_actor_id == actor.id) + # Chain .all() call .all() ) + # Assign technique_ids = {at.technique_id for at in actor_techniques} technique_ids = {at.technique_id for at in actor_techniques} + # Check: not technique_ids if not technique_ids: + # Return { return { + # Literal argument value "actor_id": str(actor.id), + # Literal argument value "actor_name": actor.name, + # Literal argument value "total_score": 0, + # Literal argument value "techniques_count": 0, + # Literal argument value "techniques_covered": 0, + # Literal argument value "techniques_detail": [], } + # Assign scores_map = bulk_technique_scores(db) scores_map = bulk_technique_scores(db) + # Assign scores = [] scores = [] + # Assign details = [] details = [] + # Iterate over technique_ids for tid in technique_ids: + # Assign entry = scores_map.get(tid) entry = scores_map.get(tid) + # Check: not entry if not entry: + # Skip to the next loop iteration continue + # Assign score = entry["total_score"] score = entry["total_score"] + # Call scores.append() scores.append(score) + # Call details.append() details.append({ + # Literal argument value "mitre_id": entry["mitre_id"], + # Literal argument value "name": entry.get("name", ""), + # Literal argument value "score": score, + # Literal argument value "breakdown": entry["breakdown"], }) + # Assign avg_score = round(sum(scores) / len(scores), 1) if scores else 0 avg_score = round(sum(scores) / len(scores), 1) if scores else 0 + # Return { return { + # Literal argument value 
"actor_id": str(actor.id), + # Literal argument value "actor_name": actor.name, + # Literal argument value "total_score": avg_score, + # Literal argument value "techniques_count": len(technique_ids), + # Literal argument value "techniques_covered": len([s for s in scores if s > 50]), + # Literal argument value "techniques_detail": details, } @@ -550,25 +912,49 @@ def calculate_organization_score(db: Session) -> dict: Uses ``bulk_technique_scores`` to compute all technique scores in 5 aggregated queries instead of N*5. + + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + dict: Contains ``overall_score``, ``total_coverage``, + ``critical_coverage``, ``detection_maturity``, + ``response_readiness``, ``techniques_evaluated``, and + ``techniques_total``. """ + # Assign scores_map = bulk_technique_scores(db) scores_map = bulk_technique_scores(db) + # Assign total_count = len(scores_map) total_count = len(scores_map) + # Check: total_count == 0 if total_count == 0: + # Return { return { + # Literal argument value "overall_score": 0, + # Literal argument value "total_coverage": 0, + # Literal argument value "critical_coverage": 0, + # Literal argument value "detection_maturity": 0, + # Literal argument value "response_readiness": 0, + # Literal argument value "techniques_evaluated": 0, + # Literal argument value "techniques_total": 0, } + # Assign all_scores = [v["total_score"] for v in scores_map.values()] all_scores = [v["total_score"] for v in scores_map.values()] + # Assign evaluated_scores = [s for s in all_scores if s > 0] evaluated_scores = [s for s in all_scores if s > 0] + # Assign evaluated_count = len(evaluated_scores) evaluated_count = len(evaluated_scores) + # Assign total_coverage = ( total_coverage = ( round(sum(evaluated_scores) / len(evaluated_scores), 1) if evaluated_scores else 0 @@ -577,19 +963,25 @@ def calculate_organization_score(db: Session) -> dict: # Critical coverage: techniques with high/critical severity templates from 
app.models.test_template import TestTemplate + # Assign critical_mitre_ids = set( critical_mitre_ids = set( row[0] for row in db.query(TestTemplate.mitre_technique_id) + # Chain .filter() call .filter(TestTemplate.severity.in_(["high", "critical"])) + # Chain .distinct() call .distinct() + # Chain .all() call .all() ) + # Assign critical_scores = [ critical_scores = [ v["total_score"] for v in scores_map.values() if v.get("mitre_id") in critical_mitre_ids ] + # Assign critical_coverage = ( critical_coverage = ( round(sum(critical_scores) / len(critical_scores), 1) if critical_scores else 0 @@ -598,53 +990,76 @@ def calculate_organization_score(db: Session) -> dict: # Detection maturity (2 scalar queries — already efficient) total_rules = ( db.query(func.count(DetectionRule.id)) + # Chain .filter() call .filter(DetectionRule.is_active == True) # noqa: E712 + # Chain .scalar() call .scalar() ) or 0 + # Assign triggered_total = ( triggered_total = ( db.query(func.count(TestDetectionResult.id)) + # Chain .filter() call .filter(TestDetectionResult.triggered == True) # noqa: E712 + # Chain .scalar() call .scalar() ) or 0 + # Assign detection_maturity = ( detection_maturity = ( round((triggered_total / total_rules) * 100, 1) if total_rules > 0 else 0 ) + # Assign detection_maturity = min(detection_maturity, 100) detection_maturity = min(detection_maturity, 100) # Response readiness (2 scalar queries — already efficient) remediation_total = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.remediation_status.isnot(None)) + # Chain .scalar() call .scalar() ) or 0 + # Assign remediation_completed = ( remediation_completed = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter(Test.remediation_status == "completed") + # Chain .scalar() call .scalar() ) or 0 + # Assign response_readiness = ( response_readiness = ( round((remediation_completed / remediation_total) * 100, 1) if remediation_total > 0 else 0 ) + # Assign overall = round( overall = 
round( total_coverage * 0.4 + critical_coverage * 0.25 + detection_maturity * 0.2 + response_readiness * 0.15, + # Literal argument value 1, ) + # Return { return { + # Literal argument value "overall_score": overall, + # Literal argument value "total_coverage": total_coverage, + # Literal argument value "critical_coverage": critical_coverage, + # Literal argument value "detection_maturity": detection_maturity, + # Literal argument value "response_readiness": response_readiness, + # Literal argument value "techniques_evaluated": evaluated_count, + # Literal argument value "techniques_total": total_count, } @@ -653,38 +1068,57 @@ def calculate_organization_score(db: Session) -> dict: def get_score_history(db: Session, period: str = "90d") -> list: - """Get historical score snapshots. + """Return historical score snapshots approximated from test dates. - Since we don't have a dedicated history table, we approximate by - computing scores based on test dates within time windows. - Returns a list of weekly data points. + Since there is no dedicated history table, scores are approximated by + counting validated tests within weekly time windows. + + Args: + db (Session): Active SQLAlchemy database session. + period (str): Lookback period; one of ``"30d"``, ``"90d"`` + (default), or ``"1y"``. + + Returns: + list: Weekly data points, each a dict with ``date``, ``score``, + and ``validated_tests``. 
""" - from app.models.audit import AuditLog - + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Check: period == "30d" if period == "30d": + # Assign start = now - timedelta(days=30) start = now - timedelta(days=30) + # Alternative: period == "1y" elif period == "1y": + # Assign start = now - timedelta(days=365) start = now - timedelta(days=365) + # else: # 90d default else: # 90d default + # Assign start = now - timedelta(days=90) start = now - timedelta(days=90) # Group validated tests by week weeks = [] + # Assign current = start current = start + # Loop while current < now while current < now: + # Assign week_end = min(current + timedelta(days=7), now) week_end = min(current + timedelta(days=7), now) # Count validated tests up to this week validated_up_to = ( db.query(func.count(Test.id)) + # Chain .filter() call .filter( Test.state == TestState.validated, Test.red_validated_at <= week_end, ) + # Chain .scalar() call .scalar() ) or 0 + # Assign total_techniques = ( total_techniques = ( db.query(func.count(Technique.id)).scalar() ) or 1 @@ -692,12 +1126,18 @@ def get_score_history(db: Session, period: str = "90d") -> list: # Simple approximation: coverage percentage as score proxy score_approx = round((validated_up_to / total_techniques) * 100, 1) + # Call weeks.append() weeks.append({ + # Literal argument value "date": current.strftime("%Y-%m-%d"), + # Literal argument value "score": min(score_approx, 100), + # Literal argument value "validated_tests": validated_up_to, }) + # Assign current = week_end current = week_end + # Return weeks return weeks diff --git a/backend/app/services/sigma_import_service.py b/backend/app/services/sigma_import_service.py index e67a5a1..96f213d 100644 --- a/backend/app/services/sigma_import_service.py +++ b/backend/app/services/sigma_import_service.py @@ -22,24 +22,45 @@ rules are identified by ``source = "sigma"`` + ``source_id`` (relative file path) and simply skipped. 
""" +# Import io import io + +# Import logging import logging + +# Import re import re + +# Import shutil import shutil + +# Import tempfile import tempfile + +# Import zipfile import zipfile + +# Import datetime from datetime from datetime import datetime + +# Import Path from pathlib from pathlib import Path +# Import requests import requests as _requests + +# Import yaml import yaml + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session -from app.models.detection_rule import DetectionRule +# Import DataSource from app.models.data_source from app.models.data_source import DataSource from app.models.technique import Technique from app.services.audit_service import log_action +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -47,22 +68,35 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- SIGMA_ZIP_URL = ( + # Literal argument value "https://github.com/SigmaHQ/sigma/archive/refs/heads/master.zip" ) +# Assign _DOWNLOAD_TIMEOUT = 300 _DOWNLOAD_TIMEOUT = 300 +# Assign _ZIP_ROOT_PREFIX = "sigma-master" _ZIP_ROOT_PREFIX = "sigma-master" +# Safety limits for ZIP extraction — prevent zip-bomb DoS +_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 # 500 MB +# Assign _MAX_ENTRIES = 50_000 +_MAX_ENTRIES = 50_000 + # Regex to extract MITRE ATT&CK technique IDs from Sigma tags # e.g. 
"attack.t1059.001" → "T1059.001" _ATTACK_TAG_RE = re.compile(r"attack\.(t\d{4}(?:\.\d{3})?)", re.IGNORECASE) # Sigma severity levels _SEVERITY_MAP = { + # Literal argument value "informational": "informational", + # Literal argument value "low": "low", + # Literal argument value "medium": "medium", + # Literal argument value "high": "high", + # Literal argument value "critical": "critical", } @@ -74,14 +108,21 @@ _SEVERITY_MAP = { def _download_zip(url: str = SIGMA_ZIP_URL) -> bytes: """Download the SigmaHQ ZIP and return raw bytes.""" + # Log info: "Downloading SigmaHQ ZIP from %s …", url logger.info("Downloading SigmaHQ ZIP from %s …", url) + # Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) + # Call resp.raise_for_status() resp.raise_for_status() + # Assign content = resp.content content = resp.content + # Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024 logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024)) + # Return content return content +# Define function _safe_extract_zip def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None: """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection. @@ -89,165 +130,249 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None: directory (path traversal / Zip Slip) or if the archive exceeds the safety limits. 
""" - # Maximum uncompressed size: 500 MB — prevents zip-bomb DoS - _MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024 - # Maximum number of entries - _MAX_ENTRIES = 50_000 - + # Assign dest_path = Path(dest).resolve() dest_path = Path(dest).resolve() + # Open context manager with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: + # Assign entries = zf.infolist() entries = zf.infolist() + # Check: len(entries) > _MAX_ENTRIES if len(entries) > _MAX_ENTRIES: + # Raise ValueError raise ValueError( f"ZIP archive contains {len(entries)} entries " f"(limit: {_MAX_ENTRIES}) — possible zip bomb" ) + # Assign total_size = sum(info.file_size for info in entries) total_size = sum(info.file_size for info in entries) + # Check: total_size > _MAX_UNCOMPRESSED_SIZE if total_size > _MAX_UNCOMPRESSED_SIZE: + # Raise ValueError raise ValueError( f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB " f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB" ) + # Iterate over entries for member in entries: + # Assign target = (dest_path / member.filename).resolve() target = (dest_path / member.filename).resolve() + # Check: not target.is_relative_to(dest_path) if not target.is_relative_to(dest_path): + # Raise ValueError raise ValueError( f"Zip Slip detected — member '{member.filename}' " f"resolves outside target directory" ) + # Call zf.extractall() zf.extractall(dest) +# Define function _extract_zip def _extract_zip(zip_bytes: bytes, dest: str) -> Path: """Extract *zip_bytes* into *dest* and return the path to rules/ dir.""" + # Call _safe_extract_zip() _safe_extract_zip(zip_bytes, dest) + # Assign rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules" rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules" + # Check: not rules_dir.is_dir() if not rules_dir.is_dir(): + # Raise FileNotFoundError raise FileNotFoundError( f"Expected rules directory not found at {rules_dir}" ) + # Return rules_dir return rules_dir +# Define function _extract_attack_tags def _extract_attack_tags(tags: 
list) -> list[str]: """Extract MITRE technique IDs from Sigma tag list. Example input: ["attack.defense_evasion", "attack.t1059.001", "cve.2021.44228"] Example output: ["T1059.001"] """ + # Assign technique_ids = [] technique_ids = [] + # Iterate over tags for tag in tags: + # Assign m = _ATTACK_TAG_RE.match(str(tag).strip()) m = _ATTACK_TAG_RE.match(str(tag).strip()) + # Check: m if m: + # Call technique_ids.append() technique_ids.append(m.group(1).upper()) + # Return list(set(technique_ids)) return list(set(technique_ids)) +# Define function _parse_sigma_rules def _parse_sigma_rules(rules_dir: Path) -> list[dict]: """Walk the rules directory and parse all Sigma YAML files. Returns a flat list of dicts, one per (rule, technique) combination. A single Sigma rule tagged with N techniques produces N entries. """ + # Assign results = [] results: list[dict] = [] + # Assign yaml_files = sorted(rules_dir.rglob("*.yml")) yaml_files = sorted(rules_dir.rglob("*.yml")) + # Log info: "Found %d YAML files to parse", len(yaml_files logger.info("Found %d YAML files to parse", len(yaml_files)) + # Iterate over yaml_files for yaml_path in yaml_files: + # Assign relative_path = str(yaml_path.relative_to(rules_dir.parent)) relative_path = str(yaml_path.relative_to(rules_dir.parent)) + # Attempt the following; catch errors below try: + # Open context manager with open(yaml_path, "r", encoding="utf-8") as fh: + # Assign data = yaml.safe_load(fh) data = yaml.safe_load(fh) + # Handle Exception except Exception as exc: + # Log debug: "Failed to parse %s: %s", yaml_path, exc logger.debug("Failed to parse %s: %s", yaml_path, exc) + # Skip to the next loop iteration continue + # Check: not isinstance(data, dict) if not isinstance(data, dict): + # Skip to the next loop iteration continue + # Assign title = data.get("title", "").strip() title = data.get("title", "").strip() + # Check: not title if not title: + # Skip to the next loop iteration continue # Extract ATT&CK technique IDs from tags 
tags = data.get("tags", []) + # Check: not isinstance(tags, list) if not isinstance(tags, list): + # Skip to the next loop iteration continue + # Assign technique_ids = _extract_attack_tags(tags) technique_ids = _extract_attack_tags(tags) + # Check: not technique_ids if not technique_ids: + # continue # Skip rules without ATT&CK mapping continue # Skip rules without ATT&CK mapping + # Assign description = data.get("description", "") description = data.get("description", "") + # Assign level = str(data.get("level", "")).lower() level = str(data.get("level", "")).lower() + # Assign severity = _SEVERITY_MAP.get(level) severity = _SEVERITY_MAP.get(level) # Extract logsource logsource = data.get("logsource", {}) + # Check: not isinstance(logsource, dict) if not isinstance(logsource, dict): + # Assign logsource = {} logsource = {} # Read full YAML content for storage try: + # Open context manager with open(yaml_path, "r", encoding="utf-8") as fh: + # Assign raw_content = fh.read() raw_content = fh.read() + # Handle Exception except Exception: + # Assign raw_content = yaml.dump(data, default_flow_style=False) raw_content = yaml.dump(data, default_flow_style=False) # False positive assessment falsepositives = data.get("falsepositives", []) + # Check: isinstance(falsepositives, list) and len(falsepositives) > 3 if isinstance(falsepositives, list) and len(falsepositives) > 3: + # Assign fp_rate = "high" fp_rate = "high" + # Alternative: isinstance(falsepositives, list) and len(falsepositives) > 1 elif isinstance(falsepositives, list) and len(falsepositives) > 1: + # Assign fp_rate = "medium" fp_rate = "medium" + # Fallback: handle remaining cases else: + # Assign fp_rate = "low" fp_rate = "low" # Create one entry per technique for tech_id in technique_ids: + # Assign source_url = ( source_url = ( f"https://github.com/SigmaHQ/sigma/blob/master/" f"{relative_path.replace(chr(92), '/')}" ) + # Call results.append() results.append({ + # Literal argument value 
"mitre_technique_id": tech_id, + # Literal argument value "title": title[:500], + # Literal argument value "description": str(description)[:2000] if description else None, + # Literal argument value "source_id": relative_path, + # Literal argument value "source_url": source_url, + # Literal argument value "rule_content": raw_content, + # Literal argument value "severity": severity, + # Literal argument value "log_sources": logsource if logsource else None, + # Literal argument value "false_positive_rate": fp_rate, + # Literal argument value "platforms": _platforms_from_logsource(logsource), }) + # Log info: "Parsed %d (rule, technique) pairs total", len(res logger.info("Parsed %d (rule, technique) pairs total", len(results)) + # Return results return results +# Define function _platforms_from_logsource def _platforms_from_logsource(logsource: dict) -> list[str]: """Infer platform list from Sigma logsource.""" + # Assign platforms = [] platforms = [] + # Assign product = str(logsource.get("product", "")).lower() product = str(logsource.get("product", "")).lower() + # Assign service = str(logsource.get("service", "")).lower() service = str(logsource.get("service", "")).lower() + # Check: "windows" in product or "windows" in service if "windows" in product or "windows" in service: + # Call platforms.append() platforms.append("windows") + # Check: "linux" in product or "linux" in service if "linux" in product or "linux" in service: + # Call platforms.append() platforms.append("linux") + # Check: "macos" in product or "macos" in service if "macos" in product or "macos" in service: + # Call platforms.append() platforms.append("macos") # Sysmon → Windows if "sysmon" in service and "windows" not in platforms: + # Call platforms.append() platforms.append("windows") + # Return platforms if platforms else None return platforms if platforms else None @@ -264,59 +389,88 @@ def sync(db: Session) -> dict: db : Session Active SQLAlchemy database session. 
- Returns + Returns: ------- dict Summary with ``created``, ``skipped_existing``, ``total_parsed``. """ + # Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_sigma_") tmp_dir = tempfile.mkdtemp(prefix="aegis_sigma_") + # Attempt the following; catch errors below try: + # Assign zip_bytes = _download_zip() zip_bytes = _download_zip() + # Assign rules_dir = _extract_zip(zip_bytes, tmp_dir) rules_dir = _extract_zip(zip_bytes, tmp_dir) + # Assign parsed_rules = _parse_sigma_rules(rules_dir) parsed_rules = _parse_sigma_rules(rules_dir) + # Always execute this cleanup block finally: + # Call shutil.rmtree() shutil.rmtree(tmp_dir, ignore_errors=True) + # Log info: "Cleaned up temp directory %s", tmp_dir logger.info("Cleaned up temp directory %s", tmp_dir) # Pre-load existing source_ids for dedup existing_ids: set[str] = { row[0] for row in db.query(DetectionRule.source_id) + # Chain .filter() call .filter(DetectionRule.source == "sigma") + # Chain .filter() call .filter(DetectionRule.source_id.isnot(None)) + # Chain .all() call .all() } + # Assign created = 0 created = 0 + # Assign skipped = 0 skipped = 0 new_technique_ids: set[str] = set() + # Iterate over parsed_rules for item in parsed_rules: - # Dedup key: source_id (relative path). A rule file may produce - # multiple entries (one per technique), but we deduplicate by - # source_id so re-runs are safe. For multi-technique rules we - # only skip if the exact same source_id is already present. - dedup_key = f"{item['source_id']}::{item['mitre_technique_id']}" + # Deduplicate by source_id: one rule file may map to multiple techniques, + # but we skip insertion if this source_id was already imported. 
if item["source_id"] in existing_ids: + # Assign skipped = 1 skipped += 1 + # Skip to the next loop iteration continue + # Assign rule = DetectionRule( rule = DetectionRule( + # Keyword argument: mitre_technique_id mitre_technique_id=item["mitre_technique_id"], + # Keyword argument: title title=item["title"], + # Keyword argument: description description=item["description"], + # Keyword argument: source source="sigma", + # Keyword argument: source_id source_id=item["source_id"], + # Keyword argument: source_url source_url=item["source_url"], + # Keyword argument: rule_content rule_content=item["rule_content"], + # Keyword argument: rule_format rule_format="sigma_yaml", + # Keyword argument: severity severity=item["severity"], + # Keyword argument: platforms platforms=item["platforms"], + # Keyword argument: log_sources log_sources=item["log_sources"], + # Keyword argument: false_positive_rate false_positive_rate=item["false_positive_rate"], + # Keyword argument: is_active is_active=True, ) + # Stage new record(s) for database insertion db.add(rule) + # Call existing_ids.add() existing_ids.add(item["source_id"]) new_technique_ids.add(item["mitre_technique_id"]) created += 1 @@ -329,30 +483,48 @@ def sync(db: Session) -> dict: db.commit() + # Assign summary = { summary = { + # Literal argument value "created": created, + # Literal argument value "skipped_existing": skipped, + # Literal argument value "total_parsed": len(parsed_rules), } # Update DataSource record ds = db.query(DataSource).filter(DataSource.name == "sigma").first() + # Check: ds if ds: + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_status = "success" ds.last_sync_status = "success" + # Assign ds.last_sync_stats = summary ds.last_sync_stats = summary + # Commit all pending changes to the database db.commit() + # Log info: "Sigma import complete — %s", summary logger.info("Sigma import complete — %s", summary) + # Call log_action() log_action( 
db, + # Keyword argument: user_id user_id=None, + # Keyword argument: action action="import_sigma_rules", + # Keyword argument: entity_type entity_type="detection_rule", + # Keyword argument: entity_id entity_id=None, + # Keyword argument: details details=summary, ) + # Commit all pending changes to the database db.commit() + # Return summary return summary diff --git a/backend/app/services/snapshot_service.py b/backend/app/services/snapshot_service.py index 9a2e305..30e7b5f 100644 --- a/backend/app/services/snapshot_service.py +++ b/backend/app/services/snapshot_service.py @@ -7,25 +7,59 @@ Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed number of SQL queries regardless of technique count. """ +# Import logging import logging + +# Import uuid import uuid + +# Import defaultdict from collections from collections import defaultdict + +# Import datetime, timedelta, timezone from datetime from datetime import datetime, timedelta, timezone +# Import func from sqlalchemy from sqlalchemy import func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError -from app.models.technique import Technique + +# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState + +# Import TechniqueStatus from app.models.enums from app.models.enums import TechniqueStatus + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import from app.services.scoring_service from app.services.scoring_service import ( bulk_technique_scores, calculate_organization_score, ) +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Coverage status ordering for snapshot delta comparisons (higher = better coverage) +_STATUS_ORDER: dict[str, int] = { + # Literal argument value + 
"not_evaluated": 0, + # Literal argument value + "not_covered": 1, + # Literal argument value + "in_progress": 2, + # Literal argument value + "partial": 3, + # Literal argument value + "validated": 4, +} + # --------------------------------------------------------------------------- # Serialization and queries @@ -33,97 +67,207 @@ logger = logging.getLogger(__name__) def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict: - """Lightweight serialization for list views.""" + """Return a lightweight serialization of a snapshot for list views. + + Args: + snap (CoverageSnapshot): The snapshot ORM object to serialize. + + Returns: + dict: Flat dictionary with summary fields (counts, scores, tactic + breakdown) suitable for paginated list responses. + """ + # Return { return { + # Literal argument value "id": str(snap.id), + # Literal argument value "name": snap.name, + # Literal argument value "organization_score": snap.organization_score, + # Literal argument value "total_techniques": snap.total_techniques, + # Literal argument value "validated_count": snap.validated_count, + # Literal argument value "partial_count": snap.partial_count, + # Literal argument value "not_covered_count": snap.not_covered_count, + # Literal argument value "in_progress_count": snap.in_progress_count, + # Literal argument value "not_evaluated_count": snap.not_evaluated_count, + # Literal argument value "coverage_percentage": getattr(snap, "coverage_percentage", 0.0), + # Literal argument value "by_tactic": getattr(snap, "by_tactic", None) or {}, + # Literal argument value "by_status": getattr(snap, "by_status", None) or {}, + # Literal argument value "stale_count": getattr(snap, "stale_count", 0), + # Literal argument value "never_tested_count": getattr(snap, "never_tested_count", 0), + # Literal argument value "created_by": str(snap.created_by) if snap.created_by else None, + # Literal argument value "created_at": snap.created_at.isoformat() if snap.created_at else None, } +# Define 
function serialize_snapshot_detail def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict: - """Full serialization including technique states.""" + """Return full serialization of a snapshot including per-technique states. + + Args: + db (Session): Active SQLAlchemy database session. + snap (CoverageSnapshot): The snapshot ORM object to serialize. + + Returns: + dict: Summary fields merged with a ``technique_states`` list, each + entry containing ``mitre_id``, ``technique_id``, ``status``, + and ``score``. + """ + # Assign base = serialize_snapshot_summary(snap) base = serialize_snapshot_summary(snap) + # Assign technique_states = ( technique_states = ( db.query(SnapshotTechniqueState) + # Chain .filter() call .filter(SnapshotTechniqueState.snapshot_id == snap.id) + # Chain .order_by() call .order_by(SnapshotTechniqueState.mitre_id) + # Chain .all() call .all() ) + # Assign base["technique_states"] = [ base["technique_states"] = [ { + # Literal argument value "mitre_id": s.mitre_id, + # Literal argument value "technique_id": str(s.technique_id), + # Literal argument value "status": s.status, + # Literal argument value "score": s.score, } for s in technique_states ] + # Return base return base +# Define function list_snapshots def list_snapshots( + # Entry: db db: Session, *, + # Entry: offset offset: int = 0, + # Entry: limit limit: int = 50, ) -> dict: - """List coverage snapshots ordered by creation date (newest first).""" + """List coverage snapshots ordered by creation date (newest first). + + Args: + db (Session): Active SQLAlchemy database session. + offset (int): Number of records to skip for pagination. + limit (int): Maximum number of records to return. + + Returns: + dict: Contains ``total``, ``offset``, ``limit``, and ``items`` (list + of serialized snapshot summaries). 
+ """ + # Assign query = db.query(CoverageSnapshot) query = db.query(CoverageSnapshot) + # Assign total = query.count() total = query.count() + # Assign snapshots = ( snapshots = ( query + # Chain .order_by() call .order_by(CoverageSnapshot.created_at.desc()) + # Chain .offset() call .offset(offset) + # Chain .limit() call .limit(limit) + # Chain .all() call .all() ) + # Return { return { + # Literal argument value "total": total, + # Literal argument value "offset": offset, + # Literal argument value "limit": limit, + # Literal argument value "items": [serialize_snapshot_summary(s) for s in snapshots], } +# Define function get_snapshot_or_raise def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot: - """Fetch snapshot by ID or raise EntityNotFoundError.""" + """Fetch snapshot by ID or raise EntityNotFoundError. + + Args: + db (Session): Active SQLAlchemy database session. + snapshot_id (str): UUID string of the snapshot to retrieve. + + Returns: + CoverageSnapshot: The matching snapshot ORM object. + """ + # Attempt the following; catch errors below try: + # Assign sid = uuid.UUID(snapshot_id) sid = uuid.UUID(snapshot_id) + # Handle (ValueError, TypeError) except (ValueError, TypeError): + # Raise EntityNotFoundError raise EntityNotFoundError("Snapshot", snapshot_id) + # Assign snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first() snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first() + # Check: snapshot is None if snapshot is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Snapshot", snapshot_id) + # Return snapshot return snapshot +# Define function get_snapshot_detail def get_snapshot_detail(db: Session, snapshot_id: str) -> dict: - """Get detailed snapshot including per-technique states.""" + """Return detailed snapshot data including per-technique states. + + Args: + db (Session): Active SQLAlchemy database session. 
+ snapshot_id (str): UUID string of the snapshot to retrieve. + + Returns: + dict: Full snapshot serialization from + :func:`serialize_snapshot_detail`. + """ + # Assign snapshot = get_snapshot_or_raise(db, snapshot_id) snapshot = get_snapshot_or_raise(db, snapshot_id) + # Return serialize_snapshot_detail(db, snapshot) return serialize_snapshot_detail(db, snapshot) +# Define function delete_snapshot def delete_snapshot(db: Session, snapshot_id: str) -> None: - """Delete a snapshot. Does not commit — caller must commit.""" + """Delete a snapshot. Does not commit — caller must commit. + + Args: + db (Session): Active SQLAlchemy database session. + snapshot_id (str): UUID string of the snapshot to delete. + """ + # Assign snapshot = get_snapshot_or_raise(db, snapshot_id) snapshot = get_snapshot_or_raise(db, snapshot_id) + # Mark record for deletion on next commit db.delete(snapshot) @@ -133,8 +277,11 @@ def delete_snapshot(db: Session, snapshot_id: str) -> None: def create_snapshot( + # Entry: db db: Session, + # Entry: name name: str | None = None, + # Entry: user_id user_id: uuid.UUID | None = None, ) -> CoverageSnapshot: """Capture the current coverage state into a new snapshot. @@ -144,121 +291,215 @@ def create_snapshot( 3. Compute the org score from the same bulk data. 4. Persist a ``CoverageSnapshot`` with normalised ``SnapshotTechniqueState`` rows. + + Args: + db (Session): Active SQLAlchemy database session. + name (str | None): Optional human-readable label for the snapshot. + user_id (uuid.UUID | None): UUID of the user creating the snapshot, + stored for auditing. + + Returns: + CoverageSnapshot: The newly created and committed snapshot ORM object. 
""" + # Assign scores_map = bulk_technique_scores(db) scores_map = bulk_technique_scores(db) + # Assign techniques = db.query(Technique).all() techniques = db.query(Technique).all() + # Assign validated_count = 0 validated_count = 0 + # Assign partial_count = 0 partial_count = 0 + # Assign not_covered_count = 0 not_covered_count = 0 + # Assign in_progress_count = 0 in_progress_count = 0 + # Assign not_evaluated_count = 0 not_evaluated_count = 0 + # Assign stale_count = 0 stale_count = 0 + # Assign never_tested_count = 0 never_tested_count = 0 + # Assign by_tactic = defaultdict( by_tactic: dict[str, dict] = defaultdict( + # Entry: lambda lambda: {"total": 0, "validated": 0, "partial": 0, "score_sum": 0.0} ) + # Assign by_status = defaultdict(int) by_status: dict[str, int] = defaultdict(int) + # Assign technique_rows = [] technique_rows: list[dict] = [] + # Iterate over techniques for tech in techniques: + # Assign status_value = ( status_value = ( tech.status_global.value if isinstance(tech.status_global, TechniqueStatus) else (tech.status_global or "not_evaluated") ) + # Check: status_value == "validated" if status_value == "validated": + # Assign validated_count = 1 validated_count += 1 + # Alternative: status_value == "partial" elif status_value == "partial": + # Assign partial_count = 1 partial_count += 1 + # Alternative: status_value == "not_covered" elif status_value == "not_covered": + # Assign not_covered_count = 1 not_covered_count += 1 + # Alternative: status_value == "in_progress" elif status_value == "in_progress": + # Assign in_progress_count = 1 in_progress_count += 1 + # Fallback: handle remaining cases else: + # Assign not_evaluated_count = 1 not_evaluated_count += 1 + # Assign entry = scores_map.get(tech.id, {}) entry = scores_map.get(tech.id, {}) + # Assign score = entry.get("total_score", 0) score = entry.get("total_score", 0) + # Call technique_rows.append() technique_rows.append({ + # Literal argument value "technique_id": tech.id, + # Literal 
argument value "mitre_id": tech.mitre_id, + # Literal argument value "status": status_value, + # Literal argument value "score": score, }) + # Assign by_status[status_value] = 1 by_status[status_value] += 1 + # Assign tactic_key = tech.tactic or "unknown" tactic_key = tech.tactic or "unknown" + # Assign bucket = by_tactic[tactic_key] bucket = by_tactic[tactic_key] + # Assign bucket["total"] = 1 bucket["total"] += 1 + # Assign bucket["score_sum"] = score bucket["score_sum"] += score + # Check: status_value == "validated" if status_value == "validated": + # Assign bucket["validated"] = 1 bucket["validated"] += 1 + # Alternative: status_value == "partial" elif status_value == "partial": + # Assign bucket["partial"] = 1 bucket["partial"] += 1 + # Check: status_value == "not_evaluated" if status_value == "not_evaluated": + # Assign never_tested_count = 1 never_tested_count += 1 + # Check: tech.review_required if tech.review_required: + # Assign stale_count = 1 stale_count += 1 + # Assign org_data = calculate_organization_score(db) org_data = calculate_organization_score(db) + # Assign org_score = org_data.get("overall_score", 0) org_score = org_data.get("overall_score", 0) + # Assign total_techniques = len(techniques) or 1 total_techniques = len(techniques) or 1 + # Assign coverage_pct = round((validated_count / total_techniques) * 100, 1) coverage_pct = round((validated_count / total_techniques) * 100, 1) + # Assign by_tactic_out = { by_tactic_out = { + # Entry: tactic tactic: { + # Literal argument value "total": data["total"], + # Literal argument value "validated": data["validated"], + # Literal argument value "partial": data["partial"], + # Literal argument value "average_score": round(data["score_sum"] / data["total"], 1) if data["total"] else 0, } for tactic, data in by_tactic.items() } + # Assign snapshot = CoverageSnapshot( snapshot = CoverageSnapshot( + # Keyword argument: name name=name, + # Keyword argument: organization_score organization_score=org_score, + 
# Keyword argument: total_techniques total_techniques=len(techniques), + # Keyword argument: validated_count validated_count=validated_count, + # Keyword argument: partial_count partial_count=partial_count, + # Keyword argument: not_covered_count not_covered_count=not_covered_count, + # Keyword argument: in_progress_count in_progress_count=in_progress_count, + # Keyword argument: not_evaluated_count not_evaluated_count=not_evaluated_count, + # Keyword argument: coverage_percentage coverage_percentage=coverage_pct, + # Keyword argument: by_tactic by_tactic=by_tactic_out, + # Keyword argument: by_status by_status=dict(by_status), + # Keyword argument: stale_count stale_count=stale_count, + # Keyword argument: never_tested_count never_tested_count=never_tested_count, + # Keyword argument: created_by created_by=user_id, ) + # Stage new record(s) for database insertion db.add(snapshot) + # Flush changes to DB without committing the transaction db.flush() + # Iterate over technique_rows for row in technique_rows: + # Assign state = SnapshotTechniqueState( state = SnapshotTechniqueState( + # Keyword argument: snapshot_id snapshot_id=snapshot.id, + # Keyword argument: technique_id technique_id=row["technique_id"], + # Keyword argument: mitre_id mitre_id=row["mitre_id"], + # Keyword argument: status status=row["status"], + # Keyword argument: score score=row["score"], ) + # Stage new record(s) for database insertion db.add(state) + # Commit all pending changes to the database db.commit() + # Reload ORM object attributes from the database db.refresh(snapshot) + # Log info: logger.info( + # Literal argument value "Snapshot '%s' created — %d techniques, org score %.1f", snapshot.name or snapshot.id, len(techniques), org_score, ) + # Return snapshot return snapshot @@ -268,99 +509,160 @@ def create_snapshot( def compare_snapshots( + # Entry: db db: Session, + # Entry: snapshot_a_id snapshot_a_id: uuid.UUID, + # Entry: snapshot_b_id snapshot_b_id: uuid.UUID, ) -> dict: 
"""Compare two snapshots and return deltas. Returns improved/worsened technique lists plus aggregate statistics. + + Args: + db (Session): Active SQLAlchemy database session. + snapshot_a_id (uuid.UUID): UUID of the baseline (older) snapshot. + snapshot_b_id (uuid.UUID): UUID of the comparison (newer) snapshot. + + Returns: + dict: Contains ``snapshot_a``, ``snapshot_b``, ``score_delta``, + ``improved``, ``worsened``, ``unchanged_count``, and ``summary`` + keys. """ + # Assign snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a... snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first() + # Assign snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b... snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first() + # Check: not snap_a or not snap_b if not snap_a or not snap_b: + # Raise EntityNotFoundError raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}") # Build lookup dicts: mitre_id -> {status, score} states_a = { s.mitre_id: {"status": s.status, "score": s.score or 0} for s in db.query(SnapshotTechniqueState) + # Chain .filter() call .filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id) + # Chain .all() call .all() } + # Assign states_b = { states_b = { s.mitre_id: {"status": s.status, "score": s.score or 0} for s in db.query(SnapshotTechniqueState) + # Chain .filter() call .filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id) + # Chain .all() call .all() } - # Status priority for comparison - STATUS_ORDER = { - "not_evaluated": 0, - "not_covered": 1, - "in_progress": 2, - "partial": 3, - "validated": 4, - } - + # Assign improved = [] improved = [] + # Assign worsened = [] worsened = [] + # Assign unchanged_count = 0 unchanged_count = 0 + # Assign all_mitre_ids = set(states_a.keys()) | set(states_b.keys()) all_mitre_ids = set(states_a.keys()) | set(states_b.keys()) + # Iterate over sorted(all_mitre_ids) for 
mitre_id in sorted(all_mitre_ids): + # Assign a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0}) a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0}) + # Assign b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0}) b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0}) - a_order = STATUS_ORDER.get(a["status"], 0) - b_order = STATUS_ORDER.get(b["status"], 0) + # Assign a_order = _STATUS_ORDER.get(a["status"], 0) + a_order = _STATUS_ORDER.get(a["status"], 0) + # Assign b_order = _STATUS_ORDER.get(b["status"], 0) + b_order = _STATUS_ORDER.get(b["status"], 0) + # Check: b_order > a_order or (b_order == a_order and b["score"] > a["score"]) if b_order > a_order or (b_order == a_order and b["score"] > a["score"]): + # Call improved.append() improved.append({ + # Literal argument value "mitre_id": mitre_id, + # Literal argument value "old_status": a["status"], + # Literal argument value "new_status": b["status"], + # Literal argument value "old_score": a["score"], + # Literal argument value "new_score": b["score"], }) + # Alternative: b_order < a_order or (b_order == a_order and b["score"] < a["score"]) elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]): + # Call worsened.append() worsened.append({ + # Literal argument value "mitre_id": mitre_id, + # Literal argument value "old_status": a["status"], + # Literal argument value "new_status": b["status"], + # Literal argument value "old_score": a["score"], + # Literal argument value "new_score": b["score"], }) + # Fallback: handle remaining cases else: + # Assign unchanged_count = 1 unchanged_count += 1 + # Define function _snap_summary def _snap_summary(snap: CoverageSnapshot) -> dict: + # Return { return { + # Literal argument value "id": str(snap.id), + # Literal argument value "name": snap.name, + # Literal argument value "organization_score": snap.organization_score, + # Literal argument value "total_techniques": snap.total_techniques, + # 
Literal argument value "validated_count": snap.validated_count, + # Literal argument value "partial_count": snap.partial_count, + # Literal argument value "not_covered_count": snap.not_covered_count, + # Literal argument value "in_progress_count": snap.in_progress_count, + # Literal argument value "not_evaluated_count": snap.not_evaluated_count, + # Literal argument value "created_at": snap.created_at.isoformat() if snap.created_at else None, } + # Return { return { + # Literal argument value "snapshot_a": _snap_summary(snap_a), + # Literal argument value "snapshot_b": _snap_summary(snap_b), + # Literal argument value "score_delta": round(snap_b.organization_score - snap_a.organization_score, 1), + # Literal argument value "improved": improved, + # Literal argument value "worsened": worsened, + # Literal argument value "unchanged_count": unchanged_count, + # Literal argument value "summary": { + # Literal argument value "improved_count": len(improved), + # Literal argument value "worsened_count": len(worsened), + # Literal argument value "new_count": len(states_b.keys() - states_a.keys()), }, } @@ -372,25 +674,53 @@ def compare_snapshots( def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]: - """Return snapshot trend points for the last *months* months.""" + """Return snapshot trend points for the last *months* months. + + Args: + db (Session): Active SQLAlchemy database session. + months (int): Number of months to look back; defaults to 12. + + Returns: + list[dict]: Snapshot trend entries ordered by creation date ascending, + each containing ``date``, ``name``, ``org_score``, + ``coverage_pct``, ``by_tactic``, ``by_status``, + ``stale_count``, ``never_tested_count``, ``validated_count``, + and ``total_techniques``. 
+ """ + # Assign cutoff = datetime.now(timezone.utc) - timedelta(days=months * 30) cutoff = datetime.now(timezone.utc) - timedelta(days=months * 30) + # Assign snapshots = ( snapshots = ( db.query(CoverageSnapshot) + # Chain .filter() call .filter(CoverageSnapshot.created_at >= cutoff) + # Chain .order_by() call .order_by(CoverageSnapshot.created_at.asc()) + # Chain .all() call .all() ) + # Return [ return [ { + # Literal argument value "date": snap.created_at.isoformat() if snap.created_at else None, + # Literal argument value "name": snap.name, + # Literal argument value "org_score": snap.organization_score, + # Literal argument value "coverage_pct": getattr(snap, "coverage_percentage", 0.0), + # Literal argument value "by_tactic": getattr(snap, "by_tactic", None) or {}, + # Literal argument value "by_status": getattr(snap, "by_status", None) or {}, + # Literal argument value "stale_count": getattr(snap, "stale_count", 0), + # Literal argument value "never_tested_count": getattr(snap, "never_tested_count", 0), + # Literal argument value "validated_count": snap.validated_count, + # Literal argument value "total_techniques": snap.total_techniques, } for snap in snapshots @@ -405,25 +735,46 @@ def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]: def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int: """Delete oldest snapshots, keeping the most recent *keep_last*. - Returns the number of snapshots deleted. + Args: + db (Session): Active SQLAlchemy database session. + keep_last (int): Number of most-recent snapshots to retain; defaults + to 52 (one year of weekly snapshots). + + Returns: + int: Number of snapshots deleted. 
""" + # Assign total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0 total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0 + # Check: total <= keep_last if total <= keep_last: + # Return 0 return 0 + # Assign to_delete = total - keep_last to_delete = total - keep_last + # Assign old_snapshots = ( old_snapshots = ( db.query(CoverageSnapshot) + # Chain .order_by() call .order_by(CoverageSnapshot.created_at.asc()) + # Chain .limit() call .limit(to_delete) + # Chain .all() call .all() ) + # Assign deleted = 0 deleted = 0 + # Iterate over old_snapshots for snap in old_snapshots: + # Mark record for deletion on next commit db.delete(snap) + # Assign deleted = 1 deleted += 1 + # Commit all pending changes to the database db.commit() + # Log info: "Snapshot cleanup — deleted %d old snapshots (kept logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last) + # Return deleted return deleted diff --git a/backend/app/services/stale_detection_service.py b/backend/app/services/stale_detection_service.py index bbd8374..28689b0 100644 --- a/backend/app/services/stale_detection_service.py +++ b/backend/app/services/stale_detection_service.py @@ -1,26 +1,41 @@ -"""Stale coverage detection — marks techniques whose last validated test -is older than a configurable threshold. +"""Stale coverage detection — marks techniques whose last validated test is older than a configurable threshold. This is the simple version. The full Decay Engine (Fase 8) will replace this with a multi-factor, configurable decay model with confidence scores. 
""" +# Import logging import logging + +# Import datetime, timedelta, timezone from datetime from datetime import datetime, timedelta, timezone +# Import func from sqlalchemy from sqlalchemy import func + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import settings from app.config from app.config import settings + +# Import TechniqueStatus, TestState from app.models.enums from app.models.enums import TechniqueStatus, TestState + +# Import Technique from app.models.technique from app.models.technique import Technique + +# Import Test from app.models.test from app.models.test import Test +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Assign STALE_THRESHOLD_DAYS = settings.STALE_THRESHOLD_DAYS STALE_THRESHOLD_DAYS = settings.STALE_THRESHOLD_DAYS +# Define function detect_stale_coverage def detect_stale_coverage(db: Session) -> int: """Scan all techniques and flag those with stale coverage. @@ -30,10 +45,17 @@ def detect_stale_coverage(db: Session) -> int: - It has never had a validated test (but has been manually marked as covered/partial). - Returns the number of newly-flagged techniques. + Args: + db (Session): Active SQLAlchemy database session. + + Returns: + int: Number of techniques newly flagged as stale (``review_required`` + set to ``True``) in this run. 
""" + # Assign cutoff = datetime.now(timezone.utc) - timedelta(days=STALE_THRESHOLD_DAYS) cutoff = datetime.now(timezone.utc) - timedelta(days=STALE_THRESHOLD_DAYS) + # Assign last_validated = func.coalesce( last_validated = func.coalesce( Test.blue_validated_at, Test.red_validated_at, @@ -46,40 +68,60 @@ def detect_stale_coverage(db: Session) -> int: Test.technique_id, func.max(last_validated).label("last_tested"), ) + # Chain .filter() call .filter(Test.state == TestState.validated) + # Chain .group_by() call .group_by(Test.technique_id) + # Chain .subquery() call .subquery() ) # Find techniques that are stale stale_techniques = ( db.query(Technique) + # Chain .outerjoin() call .outerjoin(latest_test, Technique.id == latest_test.c.technique_id) + # Chain .filter() call .filter( # Either tested before cutoff, or never tested at all (latest_test.c.last_tested < cutoff) | (latest_test.c.last_tested.is_(None)) ) + # Chain .filter() call .filter( # Only flag techniques that have a real status (not never-evaluated ones) Technique.status_global != TechniqueStatus.not_evaluated ) + # Chain .all() call .all() ) + # Assign count = 0 count = 0 + # Iterate over stale_techniques for tech in stale_techniques: + # Check: not tech.review_required if not tech.review_required: + # Assign tech.review_required = True tech.review_required = True + # Assign count = 1 count += 1 + # Log info: "Marked %s as stale coverage", tech.mitre_id logger.info("Marked %s as stale coverage", tech.mitre_id) + # Check: count > 0 if count > 0: + # Commit all pending changes to the database db.commit() + # Log info: logger.info( + # Literal argument value "Stale coverage detection complete — %d techniques flagged", count ) + # Fallback: handle remaining cases else: + # Log info: "Stale coverage detection complete — no new stale logger.info("Stale coverage detection complete — no new stale techniques") + # Return count return count diff --git a/backend/app/services/status_service.py 
b/backend/app/services/status_service.py index dab0044..2550599 100644 --- a/backend/app/services/status_service.py +++ b/backend/app/services/status_service.py @@ -10,21 +10,30 @@ The function mutates the technique but does **not** commit. The caller is responsible for committing the session. """ +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import TechniqueEntity from app.domain.entities.technique from app.domain.entities.technique import TechniqueEntity + +# Import Technique from app.models.technique from app.models.technique import Technique +# Define function recalculate_technique_status def recalculate_technique_status(db: Session, technique: Technique) -> None: """Recompute ``technique.status_global`` from its tests. ``db`` is accepted for backward compatibility but is not used directly — test data comes from the ORM relationship. """ + # Assign entity = TechniqueEntity.from_orm(technique) entity = TechniqueEntity.from_orm(technique) + # Assign test_snapshots = [ test_snapshots = [ (t.state, t.detection_result) for t in technique.tests ] + # Call entity.recalculate_status() entity.recalculate_status(test_snapshots) + # Assign technique.status_global = entity.status_global technique.status_global = entity.status_global diff --git a/backend/app/services/technique_query_service.py b/backend/app/services/technique_query_service.py index 43b7777..9e8c4d1 100644 --- a/backend/app/services/technique_query_service.py +++ b/backend/app/services/technique_query_service.py @@ -1,9 +1,14 @@ """Technique query service — framework-agnostic queries for technique details.""" +# Enable future language features for compatibility from __future__ import annotations +# Import Session, joinedload from sqlalchemy.orm from sqlalchemy.orm import Session, joinedload +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError + +# Import Technique from app.models.technique from app.models.technique import 
Technique from app.models.detection_rule import DetectionRule from app.models.intel import IntelItem @@ -13,15 +18,21 @@ from app.services.d3fend_import_service import get_defenses_for_technique _SEVERITY_ORDER = {"critical": 0, "high": 1, "medium": 2, "low": 3, "informational": 4, None: 5} +# Define function get_technique_detail def get_technique_detail(db: Session, mitre_id: str) -> dict: """Fetch full technique details including tests, detection rules, and D3FEND defenses.""" technique = ( db.query(Technique) + # Chain .options() call .options(joinedload(Technique.tests)) + # Chain .filter() call .filter(Technique.mitre_id == mitre_id) + # Chain .first() call .first() ) + # Check: technique is None if technique is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", mitre_id) defenses = get_defenses_for_technique(db, technique.id) @@ -49,26 +60,46 @@ def get_technique_detail(db: Session, mitre_id: str) -> dict: ) return { + # Literal argument value "id": str(technique.id), + # Literal argument value "mitre_id": technique.mitre_id, + # Literal argument value "name": technique.name, + # Literal argument value "description": technique.description, + # Literal argument value "tactic": technique.tactic, + # Literal argument value "platforms": technique.platforms or [], + # Literal argument value "mitre_version": technique.mitre_version, + # Literal argument value "mitre_last_modified": technique.mitre_last_modified, + # Literal argument value "is_subtechnique": technique.is_subtechnique, + # Literal argument value "parent_mitre_id": technique.parent_mitre_id, + # Literal argument value "status_global": technique.status_global.value if technique.status_global else "not_evaluated", + # Literal argument value "review_required": technique.review_required, + # Literal argument value "last_review_date": technique.last_review_date, + # Literal argument value "tests": [ { + # Literal argument value "id": str(t.id), + # Literal argument value "name": 
t.name, + # Literal argument value "state": t.state.value if t.state else None, + # Literal argument value "result": t.result.value if t.result else None, + # Literal argument value "platform": t.platform, + # Literal argument value "created_at": t.created_at.isoformat() if t.created_at else None, } for t in technique.tests diff --git a/backend/app/services/tempo_service.py b/backend/app/services/tempo_service.py index 7a1b1f4..b29a43f 100644 --- a/backend/app/services/tempo_service.py +++ b/backend/app/services/tempo_service.py @@ -18,15 +18,31 @@ blue_work_started_at) to when they submit, so it reflects actual working time rather than queue time. """ +# Import logging import logging -from typing import Optional +# Import Any, Optional from typing +from typing import Any, Optional + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import settings from app.config from app.config import settings + +# Import InvalidOperationError from app.domain.exceptions from app.domain.exceptions import InvalidOperationError + +# Import JiraLink, JiraLinkEntityType from app.models.jira_link from app.models.jira_link import JiraLink, JiraLinkEntityType +# Import Test from app.models.test +from app.models.test import Test + +# Import User from app.models.user +from app.models.user import User + +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # Only red team execution time goes to Tempo. @@ -85,23 +101,31 @@ def get_user_tempo_client(user, db=None): "Add it in Settings → Profile → Tempo Integration." ) try: + # Import client_v4 as tempo_client from tempoapiclient from tempoapiclient import client_v4 as tempo_client base_url = _get_tempo_base_url(db) logger.debug("Using Tempo base URL: %s", base_url) return tempo_client.Tempo(auth_token=token, base_url=base_url) except ImportError: + # Raise InvalidOperationError raise InvalidOperationError( + # Literal argument value "tempo-api-python-client is not installed. 
" "Run: pip install tempo-api-python-client" ) +# Define function log_worklog def log_worklog( user, jira_issue_id: int, + # Entry: author_account_id author_account_id: str, + # Entry: date date: str, + # Entry: time_spent_seconds time_spent_seconds: int, + # Entry: description description: str, db=None, ) -> dict: @@ -128,10 +152,15 @@ def log_worklog( raise RuntimeError(f"Tempo API error: {exc}") from exc +# Define function auto_log_test_worklog def auto_log_test_worklog( + # Entry: db db: Session, - test, - user, + # Entry: test + test: Test, + # Entry: user + user: User, + # Entry: activity_type activity_type: str, duration_seconds: int, ) -> Optional[dict]: @@ -156,6 +185,7 @@ def auto_log_test_worklog( # Global kill-switch if not settings.TEMPO_ENABLED: + # Return None return None if duration_seconds <= 0: @@ -183,15 +213,20 @@ def auto_log_test_worklog( # Need a Jira link with a numeric issue ID link = ( db.query(JiraLink) + # Chain .filter() call .filter( JiraLink.entity_id == test.id, JiraLink.entity_type == JiraLinkEntityType.test, ) + # Chain .first() call .first() ) + # Check: not link or not link.jira_issue_id if not link or not link.jira_issue_id: + # Log debug: "No Jira link for test %s, skipping Tempo worklog" logger.debug("No Jira link for test %s, skipping Tempo worklog", test.id) + # Return None return None jira_account_id = (getattr(user, "jira_account_id", "") or "").strip() @@ -202,6 +237,7 @@ def auto_log_test_worklog( ) return None + # Attempt the following; catch errors below try: # Use the phase start timestamp as the worklog date so it matches when # work actually happened (not the submission timestamp). 
@@ -231,6 +267,7 @@ def auto_log_test_worklog( test.id, getattr(user, "username", user), duration_seconds, work_date, ) return result + # Handle Exception except Exception as e: logger.warning( "Tempo worklog failed for test %s (user %s): %s", diff --git a/backend/app/services/test_crud_service.py b/backend/app/services/test_crud_service.py index 03e61d1..527e834 100644 --- a/backend/app/services/test_crud_service.py +++ b/backend/app/services/test_crud_service.py @@ -4,12 +4,15 @@ Framework-agnostic; uses domain exceptions from app.domain.errors. The router is responsible for HTTP concerns, auth, audit logging, and commit. """ +# Import uuid import uuid from datetime import datetime from typing import Any +# Import Session, joinedload from sqlalchemy.orm from sqlalchemy.orm import Session, joinedload +# Import from app.domain.errors from app.domain.errors import ( BusinessRuleViolation, EntityNotFoundError, @@ -21,19 +24,41 @@ from app.models.test import Test from app.models.test_template import TestTemplate from app.models.campaign import Campaign, CampaignTest from app.models.audit import AuditLog + +# Import TestState from app.models.enums +from app.models.enums import TestState + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import Test from app.models.test +from app.models.test import Test + +# Import TestTemplate from app.models.test_template +from app.models.test_template import TestTemplate + +# Import escape_like from app.utils from app.utils import escape_like +# Define function list_tests def list_tests( + # Entry: db db: Session, *, + # Entry: state state: str | None = None, + # Entry: technique_id technique_id: uuid.UUID | None = None, + # Entry: platform platform: str | None = None, + # Entry: created_by created_by: uuid.UUID | None = None, + # Entry: pending_validation_side pending_validation_side: str | None = None, not_in_any_campaign: bool = False, offset: int = 0, + # Entry: limit limit: int = 50, 
) -> list[Test]: """Return a paginated list of tests with optional filters. @@ -44,20 +69,32 @@ def list_tests( """ query = db.query(Test).options(joinedload(Test.technique)) + # Check: state if state: + # Assign query = query.filter(Test.state == state) query = query.filter(Test.state == state) + # Check: technique_id if technique_id: + # Assign query = query.filter(Test.technique_id == technique_id) query = query.filter(Test.technique_id == technique_id) + # Check: platform if platform: + # Assign query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%")) query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%")) + # Check: created_by if created_by: + # Assign query = query.filter(Test.created_by == created_by) query = query.filter(Test.created_by == created_by) + # Check: pending_validation_side == "red" if pending_validation_side == "red": + # Assign query = query.filter( query = query.filter( Test.state == TestState.in_review, Test.red_validation_status.in_(["pending", None]), ) + # Alternative: pending_validation_side == "blue" elif pending_validation_side == "blue": + # Assign query = query.filter( query = query.filter( Test.state == TestState.in_review, Test.blue_validation_status.in_(["pending", None]), @@ -82,42 +119,72 @@ def list_tests( ) query = query.filter(~Test.id.in_(future_draft_tests)) + # Return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).... return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all() +# Define function create_test def create_test( + # Entry: db db: Session, *, + # Entry: technique_id technique_id: uuid.UUID, + # Entry: creator_id creator_id: uuid.UUID, - **fields: Any, + **fields: object, ) -> Test: """Create a new test linked to an existing technique. Raises EntityNotFoundError if the technique does not exist. Does not commit; caller uses UnitOfWork. + + Args: + db (Session): Active SQLAlchemy database session. 
+ technique_id (uuid.UUID): UUID of the technique this test covers. + creator_id (uuid.UUID): UUID of the user creating the test. + **fields (object): Additional keyword arguments set as attributes on + the new test (e.g. ``name``, ``platform``, ``description``). + + Returns: + Test: The newly created test ORM object, flushed but not committed. """ + # Assign technique = db.query(Technique).filter(Technique.id == technique_id).first() technique = db.query(Technique).filter(Technique.id == technique_id).first() + # Check: technique is None if technique is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", str(technique_id)) + # Assign test = Test( test = Test( + # Keyword argument: technique_id technique_id=technique_id, + # Keyword argument: created_by created_by=creator_id, + # Keyword argument: state state=TestState.draft, created_at=datetime.utcnow(), # explicit — DB column has no server default **fields, ) + # Stage new record(s) for database insertion db.add(test) + # Flush changes to DB without committing the transaction db.flush() + # Return test return test +# Define function create_test_from_template def create_test_from_template( + # Entry: db db: Session, *, + # Entry: template_id template_id: uuid.UUID, + # Entry: technique_id_or_mitre technique_id_or_mitre: str, + # Entry: creator_id creator_id: uuid.UUID, # Optional user-edited overrides (take priority over template values) name_override: str | None = None, @@ -132,27 +199,53 @@ def create_test_from_template( Override fields, when provided, take precedence over the template's values. Raises EntityNotFoundError if template or technique not found. Does not commit; caller uses UnitOfWork. + + Args: + db (Session): Active SQLAlchemy database session. + template_id (uuid.UUID): UUID of the template to instantiate. + technique_id_or_mitre (str): UUID string or MITRE technique ID + (e.g. ``"T1059.001"``) identifying the target technique. 
+ creator_id (uuid.UUID): UUID of the user creating the test. + + Returns: + Test: The newly created test populated from template fields, flushed + but not committed. """ + # Assign template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() + # Check: template is None if template is None: + # Raise EntityNotFoundError raise EntityNotFoundError("TestTemplate", str(template_id)) + # Assign technique = None technique = None + # Attempt the following; catch errors below try: + # Assign technique_uuid = uuid.UUID(technique_id_or_mitre) technique_uuid = uuid.UUID(technique_id_or_mitre) + # Assign technique = db.query(Technique).filter(Technique.id == technique_uuid).first() technique = db.query(Technique).filter(Technique.id == technique_uuid).first() + # Handle ValueError except ValueError: + # Intentional no-op placeholder pass + # Check: technique is None if technique is None: + # Assign technique = db.query(Technique).filter( technique = db.query(Technique).filter( Technique.mitre_id == technique_id_or_mitre ).first() + # Check: technique is None if technique is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Technique", technique_id_or_mitre) + # Assign test = Test( test = Test( + # Keyword argument: technique_id technique_id=technique.id, name=name_override if name_override is not None else template.name, description=description_override if description_override is not None else template.description, @@ -160,59 +253,111 @@ def create_test_from_template( procedure_text=procedure_text_override if procedure_text_override is not None else template.attack_procedure, tool_used=tool_used_override if tool_used_override is not None else template.tool_suggested, remediation_steps=template.suggested_remediation, + # Keyword argument: created_by created_by=creator_id, + # Keyword argument: state state=TestState.draft, created_at=datetime.utcnow(), # explicit — 
DB column has no server default ) + # Stage new record(s) for database insertion db.add(test) + # Flush changes to DB without committing the transaction db.flush() + # Return test return test +# Define function get_test_detail def get_test_detail(db: Session, test_id: uuid.UUID) -> Test: """Fetch a test with evidences and technique eager-loaded. Raises EntityNotFoundError if the test does not exist. + + Args: + db (Session): Active SQLAlchemy database session. + test_id (uuid.UUID): UUID of the test to retrieve. + + Returns: + Test: The test ORM object with ``evidences`` relationship loaded. """ + # Assign test = ( test = ( db.query(Test) .options(joinedload(Test.evidences), joinedload(Test.technique)) .filter(Test.id == test_id) + # Chain .first() call .first() ) + # Check: test is None if test is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(test_id)) + # Return test return test +# Define function get_test_or_raise def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test: - """Fetch a test by ID. Raises EntityNotFoundError if not found.""" + """Fetch a test by ID. Raises EntityNotFoundError if not found. + + Args: + db (Session): Active SQLAlchemy database session. + test_id (uuid.UUID): UUID of the test to retrieve. + + Returns: + Test: The matching test ORM object. + """ + # Assign test = db.query(Test).filter(Test.id == test_id).first() test = db.query(Test).filter(Test.id == test_id).first() + # Check: test is None if test is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(test_id)) + # Return test return test +# Define function get_test_with_technique def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test: - """Fetch a test with technique joined. Raises EntityNotFoundError if not found.""" + """Fetch a test with technique joined. Raises EntityNotFoundError if not found. + + Args: + db (Session): Active SQLAlchemy database session. + test_id (uuid.UUID): UUID of the test to retrieve. 
+ + Returns: + Test: The test ORM object with ``technique`` relationship loaded. + """ + # Assign test = ( test = ( db.query(Test) + # Chain .options() call .options(joinedload(Test.technique)) + # Chain .filter() call .filter(Test.id == test_id) + # Chain .first() call .first() ) + # Check: test is None if test is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(test_id)) + # Return test return test +# Define function update_test def update_test( + # Entry: db db: Session, + # Entry: test_id test_id: uuid.UUID, *, + # Entry: updater_id updater_id: uuid.UUID, + # Entry: updater_role updater_role: str, - **fields: Any, + **fields: object, ) -> Test: """Update general test fields (draft or rejected only). @@ -220,93 +365,170 @@ def update_test( Raises BusinessRuleViolation if state is not draft or rejected. Raises EntityNotFoundError if test not found. Does not commit; caller uses UnitOfWork. + + Args: + db (Session): Active SQLAlchemy database session. + test_id (uuid.UUID): UUID of the test to update. + updater_id (uuid.UUID): UUID of the user performing the update. + updater_role (str): Role of the updater; ``"admin"`` bypasses the + creator-only restriction. + **fields (object): Keyword arguments mapped directly onto test + attributes. + + Returns: + Test: The updated test ORM object, flushed but not committed. 
""" + # Assign test = get_test_or_raise(db, test_id) test = get_test_or_raise(db, test_id) + # Check: updater_role != "admin" and test.created_by != updater_id if updater_role != "admin" and test.created_by != updater_id: + # Raise PermissionViolation raise PermissionViolation( + # Literal argument value "Only the test creator or an admin can update this test" ) + # Check: test.state not in (TestState.draft, TestState.rejected) if test.state not in (TestState.draft, TestState.rejected): + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)" ) + # Iterate over fields.items() for field, value in fields.items(): + # Call setattr() setattr(test, field, value) + # Flush changes to DB without committing the transaction db.flush() + # Return test return test -def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test: +# Define function update_test_red +def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test: """Update Red Team fields (draft or red_executing only). Raises BusinessRuleViolation if state not in (draft, red_executing). Raises EntityNotFoundError if test not found. Does not commit; caller uses UnitOfWork. + + Args: + db (Session): Active SQLAlchemy database session. + test_id (uuid.UUID): UUID of the test to update. + **fields (object): Red-team field names and their new values. + + Returns: + Test: The updated test ORM object, flushed but not committed. 
""" + # Assign test = get_test_or_raise(db, test_id) test = get_test_or_raise(db, test_id) + # Check: test.state not in (TestState.draft, TestState.red_executing) if test.state not in (TestState.draft, TestState.red_executing): + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot update red fields in '{test.state.value}' state " + # Literal argument value "(must be draft or red_executing)" ) + # Iterate over fields.items() for field, value in fields.items(): + # Call setattr() setattr(test, field, value) + # Flush changes to DB without committing the transaction db.flush() + # Return test return test -def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test: +# Define function update_test_blue +def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test: """Update Blue Team fields (blue_evaluating only). Raises BusinessRuleViolation if state is not blue_evaluating. Raises EntityNotFoundError if test not found. Does not commit; caller uses UnitOfWork. + + Args: + db (Session): Active SQLAlchemy database session. + test_id (uuid.UUID): UUID of the test to update. + **fields (object): Blue-team field names and their new values. + + Returns: + Test: The updated test ORM object, flushed but not committed. 
""" + # Assign test = get_test_or_raise(db, test_id) test = get_test_or_raise(db, test_id) + # Check: test.state != TestState.blue_evaluating if test.state != TestState.blue_evaluating: + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot update blue fields in '{test.state.value}' state " + # Literal argument value "(must be blue_evaluating)" ) + # Iterate over fields.items() for field, value in fields.items(): + # Call setattr() setattr(test, field, value) + # Flush changes to DB without committing the transaction db.flush() + # Return test return test +# Define function get_test_timeline def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]: """Return chronological audit-log history for a test. Raises EntityNotFoundError if the test does not exist. + + Args: + db (Session): Active SQLAlchemy database session. + test_id (uuid.UUID): UUID of the test whose history is requested. + + Returns: + list[dict[str, Any]]: Audit-log entries ordered by timestamp ascending, + each containing ``id``, ``action``, ``user_id``, ``timestamp``, + and ``details``. 
""" + # Call get_test_or_raise() get_test_or_raise(db, test_id) + # Assign logs = ( logs = ( db.query(AuditLog) + # Chain .filter() call .filter( AuditLog.entity_type == "test", AuditLog.entity_id == str(test_id), ) + # Chain .order_by() call .order_by(AuditLog.timestamp.asc()) + # Chain .all() call .all() ) + # Return [ return [ { + # Literal argument value "id": str(log.id), + # Literal argument value "action": log.action, + # Literal argument value "user_id": str(log.user_id) if log.user_id else None, + # Literal argument value "timestamp": log.timestamp.isoformat() if log.timestamp else None, + # Literal argument value "details": log.details, } for log in logs diff --git a/backend/app/services/test_template_service.py b/backend/app/services/test_template_service.py index 8f4fa30..0340328 100644 --- a/backend/app/services/test_template_service.py +++ b/backend/app/services/test_template_service.py @@ -1,43 +1,77 @@ """Test template service — framework-agnostic CRUD and queries.""" +# Enable future language features for compatibility from __future__ import annotations +# Import uuid import uuid +# Import func, or_ from sqlalchemy from sqlalchemy import func, or_ + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError + +# Import TestTemplate from app.models.test_template from app.models.test_template import TestTemplate + +# Import escape_like from app.utils from app.utils import escape_like +# Define function list_templates def list_templates( + # Entry: db db: Session, *, + # Entry: source source: str | None = None, + # Entry: platform platform: str | None = None, + # Entry: severity severity: str | None = None, + # Entry: mitre_technique_id mitre_technique_id: str | None = None, + # Entry: search search: str | None = None, + # Entry: is_active is_active: bool | None = None, + # Entry: offset offset: int = 0, + # Entry: limit limit: int = 
50, ) -> list: """Return paginated, filterable list of test templates.""" + # Assign query = db.query(TestTemplate) query = db.query(TestTemplate) + # Check: is_active is not None if is_active is not None: + # Assign query = query.filter(TestTemplate.is_active == is_active) query = query.filter(TestTemplate.is_active == is_active) + # Check: source if source: + # Assign query = query.filter(TestTemplate.source == source) query = query.filter(TestTemplate.source == source) + # Check: platform if platform: + # Assign query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}... query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%")) + # Check: severity if severity: + # Assign query = query.filter(TestTemplate.severity == severity) query = query.filter(TestTemplate.severity == severity) + # Check: mitre_technique_id if mitre_technique_id: + # Assign query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id) query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id) + # Check: search if search: + # Assign pattern = f"%{escape_like(search)}%" pattern = f"%{escape_like(search)}%" + # Assign query = query.filter( query = query.filter( or_( TestTemplate.name.ilike(pattern), @@ -45,106 +79,166 @@ def list_templates( ) ) + # Assign templates = ( templates = ( query + # Chain .order_by() call .order_by(TestTemplate.mitre_technique_id, TestTemplate.name) + # Chain .offset() call .offset(offset) + # Chain .limit() call .limit(limit) + # Chain .all() call .all() ) + # Return templates return templates +# Define function get_template_stats def get_template_stats(db: Session) -> dict: """Return catalog statistics: totals by source, platform, active/inactive.""" + # Assign total = db.query(func.count(TestTemplate.id)).scalar() or 0 total = db.query(func.count(TestTemplate.id)).scalar() or 0 + # Assign active = ( active = ( db.query(func.count(TestTemplate.id)) + # Chain .filter() call 
.filter(TestTemplate.is_active == True) # noqa: E712 + # Chain .scalar() call .scalar() ) or 0 + # Assign inactive = total - active inactive = total - active + # Assign source_rows = ( source_rows = ( db.query(TestTemplate.source, func.count(TestTemplate.id)) + # Chain .filter() call .filter(TestTemplate.is_active == True) # noqa: E712 + # Chain .group_by() call .group_by(TestTemplate.source) + # Chain .all() call .all() ) + # Assign by_source = {source: cnt for source, cnt in source_rows} by_source = {source: cnt for source, cnt in source_rows} + # Assign platform_rows = ( platform_rows = ( db.query(TestTemplate.platform, func.count(TestTemplate.id)) + # Chain .filter() call .filter(TestTemplate.is_active == True) # noqa: E712 + # Chain .group_by() call .group_by(TestTemplate.platform) + # Chain .all() call .all() ) + # Assign by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows} by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows} + # Return { return { + # Literal argument value "total": total, + # Literal argument value "active": active, + # Literal argument value "inactive": inactive, + # Literal argument value "by_source": by_source, + # Literal argument value "by_platform": by_platform, } +# Define function bulk_activate def bulk_activate(db: Session, *, activate: bool) -> int: """Set all templates to active or inactive. Returns count of affected. 
Does NOT commit.""" + # Assign count = ( count = ( db.query(TestTemplate) + # Chain .filter() call .filter(TestTemplate.is_active != activate) + # Chain .update() call .update({TestTemplate.is_active: activate}) ) + # Return count return count +# Define function get_templates_by_technique def get_templates_by_technique(db: Session, mitre_id: str) -> list: """Return all active templates mapped to a specific MITRE technique.""" + # Return ( return ( db.query(TestTemplate) + # Chain .filter() call .filter( TestTemplate.mitre_technique_id == mitre_id, TestTemplate.is_active == True, # noqa: E712 ) + # Chain .order_by() call .order_by(TestTemplate.name) + # Chain .all() call .all() ) +# Define function get_template_or_raise def get_template_or_raise(db: Session, template_id: uuid.UUID) -> TestTemplate: """Return a template by ID. Raises EntityNotFoundError if not found.""" + # Assign template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() + # Check: template is None if template is None: + # Raise EntityNotFoundError raise EntityNotFoundError("Test template", str(template_id)) + # Return template return template +# Define function create_template def create_template(db: Session, **fields: object) -> TestTemplate: """Create a test template from keyword args (e.g. payload.model_dump()). Does NOT commit.""" + # Assign template = TestTemplate(**fields) template = TestTemplate(**fields) + # Stage new record(s) for database insertion db.add(template) + # Return template return template +# Define function update_template def update_template(db: Session, template_id: uuid.UUID, **fields: object) -> TestTemplate: """Update an existing template. Raises EntityNotFoundError if not found. 
Does NOT commit.""" + # Assign template = get_template_or_raise(db, template_id) template = get_template_or_raise(db, template_id) + # Iterate over fields.items() for field, value in fields.items(): + # Check: hasattr(template, field) if hasattr(template, field): + # Call setattr() setattr(template, field, value) + # Return template return template +# Define function toggle_template_active def toggle_template_active(db: Session, template_id: uuid.UUID) -> TestTemplate: """Toggle template active/inactive. Does NOT commit.""" + # Assign template = get_template_or_raise(db, template_id) template = get_template_or_raise(db, template_id) + # Assign template.is_active = not template.is_active template.is_active = not template.is_active + # Return template return template +# Define function soft_delete_template def soft_delete_template(db: Session, template_id: uuid.UUID) -> None: """Soft-delete a template by setting is_active=False. Does NOT commit.""" + # Assign template = get_template_or_raise(db, template_id) template = get_template_or_raise(db, template_id) + # Assign template.is_active = False template.is_active = False diff --git a/backend/app/services/test_workflow_service.py b/backend/app/services/test_workflow_service.py index 330fd80..5003756 100644 --- a/backend/app/services/test_workflow_service.py +++ b/backend/app/services/test_workflow_service.py @@ -12,11 +12,19 @@ an audit-log entry. The caller (router) is responsible for committing the session via the Unit of Work pattern. 
""" +# Import logging import logging + +# Import uuid +import uuid + +# Import datetime from datetime from datetime import datetime +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import settings from app.config from app.config import settings from app.domain.exceptions import InvalidOperationError, InvalidTransitionError from app.domain.test_entity import TestEntity @@ -27,6 +35,31 @@ from app.models.user import User from app.services.audit_service import log_action from app.services.notification_service import notify_test_state_change, create_notification +# Import InvalidOperationError from app.domain.exceptions +from app.domain.exceptions import InvalidOperationError + +# Import TestEntity from app.domain.test_entity +from app.domain.test_entity import TestEntity + +# Import TestState from app.models.enums +from app.models.enums import TestState + +# Import Test from app.models.test +from app.models.test import Test + +# Import User from app.models.user +from app.models.user import User + +# Import log_action from app.services.audit_service +from app.services.audit_service import log_action + +# Import from app.services.notification_service +from app.services.notification_service import ( + create_notification, + notify_test_state_change, +) + +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -50,18 +83,35 @@ VALID_TRANSITIONS: dict[TestState, list[TestState]] = { def can_transition(test: Test, target_state: TestState) -> bool: - """Return *True* if moving *test* to *target_state* is allowed.""" + """Return *True* if moving *test* to *target_state* is allowed. + + Args: + test (Test): The test whose current state is being checked. + target_state (TestState): The state to transition to. + + Returns: + bool: ``True`` if the transition is permitted by ``VALID_TRANSITIONS``. 
+ """ + # Assign current = test.state if isinstance(test.state, TestState) else TestState(test... current = test.state if isinstance(test.state, TestState) else TestState(test.state) + # Return target_state in VALID_TRANSITIONS.get(current, []) return target_state in VALID_TRANSITIONS.get(current, []) +# Define function transition_state def transition_state( + # Entry: db db: Session, + # Entry: test test: Test, + # Entry: target_state target_state: TestState, + # Entry: user user: User, *, + # Entry: action_name action_name: str = "transition_state", + # Entry: extra_details extra_details: dict | None = None, ) -> Test: """Validate and perform a state transition, log it, and flush. @@ -71,36 +121,71 @@ def transition_state( when the transition is illegal. The entity is authoritative for which transitions are valid; the module-level ``VALID_TRANSITIONS`` dict is kept temporarily for backward compatibility of ``can_transition()``. + + Args: + db (Session): Active SQLAlchemy database session. + test (Test): The test ORM object to transition. + target_state (TestState): Desired next state. + user (User): The user performing the transition (logged in audit). + action_name (str): Audit log action label; defaults to + ``"transition_state"``. + extra_details (dict | None): Optional extra key-value pairs merged + into the audit log details. + + Returns: + Test: The mutated test ORM object (state updated, flushed). 
""" + # Assign entity = TestEntity.from_orm(test) entity = TestEntity.from_orm(test) + # Assign previous_state = entity.transition_to(target_state) previous_state = entity.transition_to(target_state) + # Assign test.state = entity.state test.state = entity.state + # Flush changes to DB without committing the transaction db.flush() + # Assign details = { details: dict = { + # Literal argument value "previous_state": previous_state, + # Literal argument value "new_state": target_state.value, + # Literal argument value "test_name": test.name, + # Literal argument value "technique_id": str(test.technique_id), } + # Check: extra_details if extra_details: + # Call details.update() details.update(extra_details) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action=action_name, + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details=details, ) + # Attempt the following; catch errors below try: + # Call notify_test_state_change() notify_test_state_change(db, test, target_state.value) + # Handle Exception except Exception as e: + # Log warning: "Notification failed for test %s: %s", test.id, e logger.warning("Notification failed for test %s: %s", test.id, e, exc_info=True) + # Return test return test @@ -112,27 +197,44 @@ def transition_state( def start_execution(db: Session, test: Test, user: User) -> Test: """Move from ``draft`` → ``red_executing``.""" entity = TestEntity.from_orm(test) + # Call entity.start_execution() entity.start_execution() + # Call entity.apply_to() entity.apply_to(test) + # Flush changes to DB without committing the transaction db.flush() + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action="start_execution", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword 
argument: details details={ + # Literal argument value "previous_state": "draft", + # Literal argument value "new_state": test.state.value, + # Literal argument value "test_name": test.name, + # Literal argument value "technique_id": str(test.technique_id), }, ) + # Attempt the following; catch errors below try: + # Call notify_test_state_change() notify_test_state_change(db, test, test.state.value) + # Handle Exception except Exception as e: + # Log warning: "Notification failed for test %s: %s", test.id, e logger.warning("Notification failed for test %s: %s", test.id, e, exc_info=True) try: @@ -144,6 +246,7 @@ def start_execution(db: Session, test: Test, user: User) -> Test: return test +# Define function submit_red_evidence def submit_red_evidence(db: Session, test: Test, user: User) -> Test: """Move from ``red_executing`` → ``blue_evaluating``. @@ -151,6 +254,14 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test: Requires at least one Red Team evidence file to be uploaded. Stops the Red Team timer and creates an automatic worklog. Starts the Blue Team timer by recording ``blue_started_at``. + + Args: + db (Session): Active SQLAlchemy database session. + test (Test): The test whose red-team evidence is being submitted. + user (User): The red-team user submitting the evidence. + + Returns: + Test: The mutated test with state advanced and blue timer started. 
""" # Evidence is mandatory before submitting red_evidence_count = ( @@ -167,29 +278,42 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test: # Auto-resume if paused paused_extra = 0 + # Check: test.paused_at is not None if test.paused_at is not None: + # Assign paused_extra = max(int((now - test.paused_at).total_seconds()), 0) paused_extra = max(int((now - test.paused_at).total_seconds()), 0) + # Assign test.paused_at = None test.paused_at = None + # Assign test = transition_state( test = transition_state( db, test, TestState.blue_evaluating, user, + # Keyword argument: action_name action_name="submit_red_evidence", ) # Create automatic worklog for Red Team phase (subtract paused time) _create_phase_worklog( db, + # Keyword argument: test test=test, + # Keyword argument: user user=user, + # Keyword argument: phase_started_at phase_started_at=test.red_started_at, + # Keyword argument: phase_ended_at phase_ended_at=now, + # Keyword argument: paused_seconds paused_seconds=(test.red_paused_seconds or 0) + paused_extra, + # Keyword argument: activity_type activity_type="red_team_execution", + # Keyword argument: description description=f"Red Team execution: {test.name}", ) # Start Blue Team timer test.blue_started_at = now + # Assign test.blue_paused_seconds = 0 test.blue_paused_seconds = 0 try: @@ -234,6 +358,7 @@ def start_blue_work(db: Session, test: Test, user: User) -> Test: return test +# Define function submit_blue_evidence def submit_blue_evidence(db: Session, test: Test, user: User) -> Test: """Move from ``blue_evaluating`` → ``in_review``. 
@@ -258,12 +383,17 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test: # Auto-resume if paused paused_extra = 0 + # Check: test.paused_at is not None if test.paused_at is not None: + # Assign paused_extra = max(int((now - test.paused_at).total_seconds()), 0) paused_extra = max(int((now - test.paused_at).total_seconds()), 0) + # Assign test.paused_at = None test.paused_at = None + # Assign test = transition_state( test = transition_state( db, test, TestState.in_review, user, + # Keyword argument: action_name action_name="submit_blue_evidence", ) @@ -272,12 +402,17 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test: # Tempo worklog reflects real working time, not just queue time. _create_phase_worklog( db, + # Keyword argument: test test=test, + # Keyword argument: user user=user, phase_started_at=test.blue_work_started_at or test.blue_started_at, phase_ended_at=now, + # Keyword argument: paused_seconds paused_seconds=(test.blue_paused_seconds or 0) + paused_extra, + # Keyword argument: activity_type activity_type="blue_team_evaluation", + # Keyword argument: description description=f"Blue Team evaluation: {test.name}", ) @@ -290,69 +425,125 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test: return test +# Define function pause_timer def pause_timer(db: Session, test: Test, user: User) -> Test: """Pause the active phase timer. Can only be called when the test is in ``red_executing`` or ``blue_evaluating`` and is not already paused. + + Args: + db (Session): Active SQLAlchemy database session. + test (Test): The currently active test. + user (User): The user pausing the timer. + + Returns: + Test: The mutated test with ``paused_at`` set to the current UTC time. 
""" + # Check: test.state not in (TestState.red_executing, TestState.blue_evaluating) if test.state not in (TestState.red_executing, TestState.blue_evaluating): + # Raise InvalidOperationError raise InvalidOperationError( f"Cannot pause timer in '{test.state.value}' state" ) + # Check: test.paused_at is not None if test.paused_at is not None: + # Raise InvalidOperationError raise InvalidOperationError("Timer is already paused") + # Assign test.paused_at = datetime.utcnow() test.paused_at = datetime.utcnow() + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action="pause_timer", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={"state": test.state.value}, ) + # Return test return test +# Define function resume_timer def resume_timer(db: Session, test: Test, user: User) -> Test: """Resume a paused phase timer. Accumulates the paused duration into the appropriate counter so it is subtracted from the final worklog. + + Args: + db (Session): Active SQLAlchemy database session. + test (Test): The paused test to resume. + user (User): The user resuming the timer. + + Returns: + Test: The mutated test with ``paused_at`` cleared and accumulated + pause seconds updated. 
""" + # Check: test.paused_at is None if test.paused_at is None: + # Raise InvalidOperationError raise InvalidOperationError("Timer is not paused") + # Assign now = datetime.utcnow() now = datetime.utcnow() + # Assign paused_seconds = max(int((now - test.paused_at).total_seconds()), 0) paused_seconds = max(int((now - test.paused_at).total_seconds()), 0) + # Check: test.state == TestState.red_executing if test.state == TestState.red_executing: + # Assign test.red_paused_seconds = (test.red_paused_seconds or 0) + paused_seconds test.red_paused_seconds = (test.red_paused_seconds or 0) + paused_seconds + # Alternative: test.state == TestState.blue_evaluating elif test.state == TestState.blue_evaluating: + # Assign test.blue_paused_seconds = (test.blue_paused_seconds or 0) + paused_seconds test.blue_paused_seconds = (test.blue_paused_seconds or 0) + paused_seconds + # Assign test.paused_at = None test.paused_at = None + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action="resume_timer", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={"paused_seconds": paused_seconds, "state": test.state.value}, ) + # Return test return test +# Define function _create_phase_worklog def _create_phase_worklog( + # Entry: db db: Session, *, + # Entry: test test: Test, + # Entry: user user: User, + # Entry: phase_started_at phase_started_at: datetime | None, + # Entry: phase_ended_at phase_ended_at: datetime, + # Entry: paused_seconds paused_seconds: int = 0, + # Entry: activity_type activity_type: str, + # Entry: description description: str, ) -> None: """Create an automatic, integrity-hashed worklog for a completed phase. @@ -360,32 +551,64 @@ def _create_phase_worklog( Subtracts accumulated *paused_seconds* from the gross elapsed time so the worklog reflects only active working time. 
Also triggers Tempo sync if the test has a Jira link. + + Args: + db (Session): Active SQLAlchemy database session. + test (Test): The test for which the worklog is being created. + user (User): The user attributed to the worklog. + phase_started_at (datetime | None): Timestamp when the phase began; + if ``None`` the worklog is skipped with a warning. + phase_ended_at (datetime): Timestamp when the phase ended. + paused_seconds (int): Accumulated paused time in seconds to subtract + from gross elapsed time. + activity_type (str): Worklog activity type label (e.g. + ``"red_team_execution"``). + description (str): Human-readable description for the worklog. """ + # Check: not phase_started_at if not phase_started_at: + # Log warning: logger.warning( + # Literal argument value "No phase start timestamp for test %s (%s), skipping worklog", test.id, activity_type, ) + # Return control to caller return + # Assign gross_seconds = int((phase_ended_at - phase_started_at).total_seconds()) gross_seconds = int((phase_ended_at - phase_started_at).total_seconds()) + # Assign duration_seconds = max(gross_seconds - paused_seconds, 1) duration_seconds = max(gross_seconds - paused_seconds, 1) + # Attempt the following; catch errors below try: + # Import create_worklog from app.services.worklog_service from app.services.worklog_service import create_worklog + # Assign wl = create_worklog( wl = create_worklog( db, + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: activity_type activity_type=activity_type, + # Keyword argument: started_at started_at=phase_started_at, + # Keyword argument: ended_at ended_at=phase_ended_at, + # Keyword argument: duration_seconds duration_seconds=duration_seconds, + # Keyword argument: description description=description, ) + # Log info: logger.info( + # Literal argument value "Auto-worklog created for test %s: %s, %ds (worklog %s)", 
test.id, activity_type, duration_seconds, wl.id, ) @@ -393,6 +616,7 @@ def _create_phase_worklog( # Sync to Tempo: only red_team_execution, using the already-computed # duration so the Tempo entry is identical to the Aegis worklog. try: + # Import auto_log_test_worklog from app.services.tempo_service from app.services.tempo_service import auto_log_test_worklog tempo_result = auto_log_test_worklog(db, test, user, activity_type, duration_seconds) if tempo_result and isinstance(tempo_result, dict): @@ -400,17 +624,26 @@ def _create_phase_worklog( wl.tempo_worklog_id = str(tempo_result.get("tempoWorklogId", "")) db.flush() except Exception as e: + # Log warning: "Tempo sync failed for worklog: %s", e, exc_info=T logger.warning("Tempo sync failed for worklog: %s", e, exc_info=True) + # Handle Exception except Exception as e: + # Log error: "Failed to create auto-worklog for test %s: %s", t logger.error("Failed to create auto-worklog for test %s: %s", test.id, e, exc_info=True) +# Define function validate_as_red_lead def validate_as_red_lead( + # Entry: db db: Session, + # Entry: test test: Test, + # Entry: user user: User, + # Entry: validation_status validation_status: str, + # Entry: notes notes: str | None = None, ) -> Test: """Record Red Lead's validation decision. @@ -418,21 +651,45 @@ def validate_as_red_lead( Delegates validation rules and state mutation entirely to :meth:`TestEntity.validate_red`. If both leads have voted the entity will also advance the test to ``validated`` or ``rejected``. + + Args: + db (Session): Active SQLAlchemy database session. + test (Test): The test being reviewed. + user (User): The red-lead user casting their vote. + validation_status (str): Validation decision, e.g. ``"approved"`` or + ``"rejected"``. + notes (str | None): Optional freeform notes explaining the decision. + + Returns: + Test: The mutated test with red-lead validation fields set. 
""" + # Assign entity = TestEntity.from_orm(test) entity = TestEntity.from_orm(test) + # Call entity.validate_red() entity.validate_red(validation_status, by=user.id, notes=notes) + # Call entity.apply_to() entity.apply_to(test) + # Flush changes to DB without committing the transaction db.flush() + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action="validate_as_red_lead", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={ + # Literal argument value "validation_status": validation_status, + # Literal argument value "notes": notes, + # Literal argument value "technique_id": str(test.technique_id), }, ) @@ -441,11 +698,17 @@ def validate_as_red_lead( return test +# Define function validate_as_blue_lead def validate_as_blue_lead( + # Entry: db db: Session, + # Entry: test test: Test, + # Entry: user user: User, + # Entry: validation_status validation_status: str, + # Entry: notes notes: str | None = None, ) -> Test: """Record Blue Lead's validation decision. @@ -453,21 +716,45 @@ def validate_as_blue_lead( Delegates validation rules and state mutation entirely to :meth:`TestEntity.validate_blue`. If both leads have voted the entity will also advance the test to ``validated`` or ``rejected``. + + Args: + db (Session): Active SQLAlchemy database session. + test (Test): The test being reviewed. + user (User): The blue-lead user casting their vote. + validation_status (str): Validation decision, e.g. ``"approved"`` or + ``"rejected"``. + notes (str | None): Optional freeform notes explaining the decision. + + Returns: + Test: The mutated test with blue-lead validation fields set. 
""" + # Assign entity = TestEntity.from_orm(test) entity = TestEntity.from_orm(test) + # Call entity.validate_blue() entity.validate_blue(validation_status, by=user.id, notes=notes) + # Call entity.apply_to() entity.apply_to(test) + # Flush changes to DB without committing the transaction db.flush() + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action="validate_as_blue_lead", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={ + # Literal argument value "validation_status": validation_status, + # Literal argument value "notes": notes, + # Literal argument value "technique_id": str(test.technique_id), }, ) @@ -476,35 +763,61 @@ def validate_as_blue_lead( return test +# Define function check_dual_validation def check_dual_validation(db: Session, test: Test) -> Test: """Evaluate both leads' decisions and advance the test if both have voted. All state mutation is delegated to :meth:`TestEntity.check_dual_validation`. This function never assigns ``test.state`` directly. + + Args: + db (Session): Active SQLAlchemy database session. + test (Test): The test to evaluate. + + Returns: + Test: The mutated test, potentially with state advanced to + ``validated`` or ``rejected``. 
""" + # Assign entity = TestEntity.from_orm(test) entity = TestEntity.from_orm(test) + # Call entity.check_dual_validation() entity.check_dual_validation() + # Call entity.apply_to() entity.apply_to(test) + # Call _dispatch_dual_validation_effects() _dispatch_dual_validation_effects(db, test, entity) + # Return test return test +# Define function _dispatch_dual_validation_effects def _dispatch_dual_validation_effects( db: Session, test: Test, entity: TestEntity, actor: User | None = None ) -> None: """Dispatch side effects (notifications, cache, Jira) based on domain events.""" for event in entity.events: + # Check: event.name == "dual_validation_approved" if event.name == "dual_validation_approved": + # Attempt the following; catch errors below try: + # Import invalidate from app.services.score_cache from app.services.score_cache import invalidate + # Call invalidate() invalidate() + # Handle Exception except Exception as e: + # Log warning: "Score cache invalidation failed: %s", e, exc_info logger.warning("Score cache invalidation failed: %s", e, exc_info=True) + # Attempt the following; catch errors below try: + # Call notify_test_state_change() notify_test_state_change(db, test, "validated") + # Handle Exception except Exception as e: + # Log warning: logger.warning( + # Literal argument value "Notification failed for test %s (validated): %s", test.id, e, exc_info=True, ) @@ -516,10 +829,15 @@ def _dispatch_dual_validation_effects( logger.warning("Jira push failed for test %s: %s", test.id, e, exc_info=True) elif event.name == "dual_validation_rejected": + # Attempt the following; catch errors below try: + # Call notify_test_state_change() notify_test_state_change(db, test, "rejected") + # Handle Exception except Exception as e: + # Log warning: logger.warning( + # Literal argument value "Notification failed for test %s (rejected): %s", test.id, e, exc_info=True, ) @@ -585,6 +903,7 @@ def _notify_validation_conflict(db: Session, test: Test, actor: User | None) 
-> ) +# Define function handle_remediation_completed def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None: """Create a re-test when remediation is completed. @@ -594,121 +913,199 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | Prevents infinite loops by enforcing ``MAX_RETEST_COUNT``. - Returns the new retest or *None* if the limit was reached. + Args: + db (Session): Active SQLAlchemy database session. + test (Test): The test whose remediation was completed. + user (User): The user triggering the remediation completion. + + Returns: + Test | None: The newly created retest, or ``None`` if the maximum + retest count has been reached. """ # Always reference the original test, not an intermediate retest original_test_id = test.retest_of or test.id + # Check: test.retest_count >= settings.MAX_RETEST_COUNT if test.retest_count >= settings.MAX_RETEST_COUNT: # Max retests reached — notify and bail out if test.created_by: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=test.created_by, + # Keyword argument: type type="max_retests_reached", + # Keyword argument: title title="Maximum retests reached", + # Keyword argument: message message=( f'Test "{test.name}" has reached the maximum of ' f'{settings.MAX_RETEST_COUNT} retests. Manual review required.' 
), + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, ) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action="max_retests_reached", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=test.id, + # Keyword argument: details details={ + # Literal argument value "retest_count": test.retest_count, + # Literal argument value "max_allowed": settings.MAX_RETEST_COUNT, + # Literal argument value "original_test_id": str(original_test_id), }, ) + # Return None return None + # Assign retest = Test( retest = Test( + # Keyword argument: technique_id technique_id=test.technique_id, + # Keyword argument: name name=f"[Retest #{test.retest_count + 1}] {test.name.replace(f'[Retest #{test.retest_count}] ', '')}", + # Keyword argument: description description=test.description, + # Keyword argument: platform platform=test.platform, + # Keyword argument: procedure_text procedure_text=test.procedure_text, + # Keyword argument: tool_used tool_used=test.tool_used, + # Keyword argument: state state=TestState.draft, + # Keyword argument: created_by created_by=test.created_by, + # Keyword argument: retest_of retest_of=original_test_id, + # Keyword argument: retest_count retest_count=test.retest_count + 1, ) + # Stage new record(s) for database insertion db.add(retest) + # Flush changes to DB without committing the transaction db.flush() + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=user.id, + # Keyword argument: action action="create_retest", + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=retest.id, + # Keyword argument: details details={ + # Literal argument value "original_test_id": str(original_test_id), + # Literal argument value "retest_number": retest.retest_count, + # Literal argument value "source_test_id": str(test.id), }, ) # Notify the test 
creator and any red_tech users if test.created_by: + # Call create_notification() create_notification( db, + # Keyword argument: user_id user_id=test.created_by, + # Keyword argument: type type="retest_created", + # Keyword argument: title title="Re-test created", + # Keyword argument: message message=( f'A re-test has been automatically created for "{test.name}" ' f'after remediation was completed.' ), + # Keyword argument: entity_type entity_type="test", + # Keyword argument: entity_id entity_id=retest.id, ) + # Flush changes to DB without committing the transaction db.flush() + # Return retest return retest -def get_retest_chain(db: Session, test_id) -> list[Test]: +# Define function get_retest_chain +def get_retest_chain(db: Session, test_id: uuid.UUID) -> list[Test]: """Return the full chain of retests for a given test. Includes the original test and all subsequent retests, ordered by retest_count. + + Args: + db (Session): Active SQLAlchemy database session. + test_id (uuid.UUID): UUID of any test in the retest chain. + + Returns: + list[Test]: The original test followed by all its retests in + ascending retest-count order. Returns an empty list if the + test is not found. """ + # Import uuid import uuid as _uuid + # Assign tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) els... 
tid = _uuid.UUID(str(test_id)) if not isinstance(test_id, _uuid.UUID) else test_id # Find the original test first test = db.query(Test).filter(Test.id == tid).first() + # Check: not test if not test: + # Return [] return [] + # Assign original_id = test.retest_of or test.id original_id = test.retest_of or test.id # Get original original = db.query(Test).filter(Test.id == original_id).first() + # Check: not original if not original: + # Return [test] return [test] # Get all retests of the original retests = ( db.query(Test) + # Chain .filter() call .filter(Test.retest_of == original_id) + # Chain .order_by() call .order_by(Test.retest_count) + # Chain .all() call .all() ) + # Return [original] + retests return [original] + retests +# Define function reopen_test def reopen_test(db: Session, test: Test, user: User) -> Test: """Move a ``rejected`` test back to ``draft`` for continued work. @@ -719,20 +1116,27 @@ def reopen_test(db: Session, test: Test, user: User) -> Test: re-validate the updated submission. Phase timing is reset so the timer starts fresh for the new execution attempt. """ + # Assign test = transition_state( test = transition_state( db, test, TestState.draft, user, + # Keyword argument: action_name action_name="reopen_test", ) # Clear validation DECISIONS — leads must re-validate the new attempt. # Rejection NOTES are intentionally kept so teams see what needs fixing. 
test.red_validation_status = None + # Assign test.red_validated_by = None test.red_validated_by = None + # Assign test.red_validated_at = None test.red_validated_at = None # test.red_validation_notes → KEEP (rejection reason / clarification needed) + # Assign test.blue_validation_status = None test.blue_validation_status = None + # Assign test.blue_validated_by = None test.blue_validated_by = None + # Assign test.blue_validated_at = None test.blue_validated_at = None # test.blue_validation_notes → KEEP (rejection reason / clarification needed) @@ -749,4 +1153,5 @@ def reopen_test(db: Session, test: Test, user: User) -> Test: except Exception as e: logger.warning("Jira push failed for test %s: %s", test.id, e, exc_info=True) + # Return test return test diff --git a/backend/app/services/threat_actor_import_service.py b/backend/app/services/threat_actor_import_service.py index 7b9de4c..e7127b3 100644 --- a/backend/app/services/threat_actor_import_service.py +++ b/backend/app/services/threat_actor_import_service.py @@ -26,23 +26,49 @@ Deduplication by ``mitre_id`` for ThreatActor and by the unique constraint ``(threat_actor_id, technique_id)`` for ThreatActorTechnique. 
""" +# Import io import io + +# Import json import json + +# Import logging import logging + +# Import shutil import shutil + +# Import tempfile import tempfile + +# Import zipfile import zipfile + +# Import datetime from datetime from datetime import datetime + +# Import Path from pathlib from pathlib import Path +# Import requests import requests as _requests + +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session -from app.models.threat_actor import ThreatActor, ThreatActorTechnique -from app.models.technique import Technique +# Import DataSource from app.models.data_source from app.models.data_source import DataSource + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor +from app.models.threat_actor import ThreatActor, ThreatActorTechnique + +# Import log_action from app.services.audit_service from app.services.audit_service import log_action +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -50,11 +76,15 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- MITRE_CTI_ZIP_URL = ( + # Literal argument value "https://github.com/mitre/cti" + # Literal argument value "/archive/refs/heads/master.zip" ) +# Assign _DOWNLOAD_TIMEOUT = 300 _DOWNLOAD_TIMEOUT = 300 +# Assign _ZIP_ROOT_PREFIX = "cti-master" _ZIP_ROOT_PREFIX = "cti-master" @@ -65,54 +95,86 @@ _ZIP_ROOT_PREFIX = "cti-master" def _download_zip(url: str = MITRE_CTI_ZIP_URL) -> bytes: """Download the MITRE CTI ZIP and return raw bytes.""" + # Log info: "Downloading MITRE CTI ZIP from %s …", url logger.info("Downloading MITRE CTI ZIP from %s …", url) + # Assign resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) + # Call 
resp.raise_for_status() resp.raise_for_status() + # Assign content = resp.content content = resp.content + # Log info: "Downloaded %.1f MB", len(content) / (1024 * 1024 logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024)) + # Return content return content +# Define function _extract_zip_and_load_bundle def _extract_zip_and_load_bundle(zip_bytes: bytes, dest: str) -> dict: """Extract ZIP and load the enterprise-attack STIX bundle.""" + # Open context manager with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: + # Call zf.extractall() zf.extractall(dest) + # Assign bundle_path = ( bundle_path = ( Path(dest) / _ZIP_ROOT_PREFIX / "enterprise-attack" / "enterprise-attack.json" ) + # Check: not bundle_path.is_file() if not bundle_path.is_file(): + # Raise FileNotFoundError raise FileNotFoundError( f"STIX bundle not found at {bundle_path}" ) + # Log info: "Loading STIX bundle from %s …", bundle_path logger.info("Loading STIX bundle from %s …", bundle_path) + # Open context manager with open(bundle_path, "r", encoding="utf-8") as fh: + # Assign bundle = json.load(fh) bundle = json.load(fh) + # Assign objects = bundle.get("objects", []) objects = bundle.get("objects", []) + # Log info: "Loaded %d STIX objects", len(objects logger.info("Loaded %d STIX objects", len(objects)) + # Return bundle return bundle +# Define function _extract_mitre_id def _extract_mitre_id(external_references: list) -> str | None: """Extract the MITRE ATT&CK ID from external_references.""" + # Check: not isinstance(external_references, list) if not isinstance(external_references, list): + # Return None return None + # Iterate over external_references for ref in external_references: + # Check: isinstance(ref, dict) and ref.get("source_name") == "mitre-attack" if isinstance(ref, dict) and ref.get("source_name") == "mitre-attack": + # Return ref.get("external_id") return ref.get("external_id") + # Return None return None +# Define function _extract_mitre_url def 
_extract_mitre_url(external_references: list) -> str | None: """Extract the MITRE ATT&CK URL from external_references.""" + # Check: not isinstance(external_references, list) if not isinstance(external_references, list): + # Return None return None + # Iterate over external_references for ref in external_references: + # Check: isinstance(ref, dict) and ref.get("source_name") == "mitre-attack" if isinstance(ref, dict) and ref.get("source_name") == "mitre-attack": + # Return ref.get("url") return ref.get("url") + # Return None return None @@ -316,25 +378,41 @@ def _infer_motivation_from_description(description: str) -> str | None: def _parse_intrusion_sets(objects: list) -> list[dict]: """Parse STIX intrusion-set objects into ThreatActor dicts.""" + # Assign actors = [] actors = [] + # Iterate over objects for obj in objects: + # Check: obj.get("type") != "intrusion-set" if obj.get("type") != "intrusion-set": + # Skip to the next loop iteration continue + # Check: obj.get("revoked") if obj.get("revoked"): + # Skip to the next loop iteration continue + # Assign ext_refs = obj.get("external_references", []) ext_refs = obj.get("external_references", []) + # Assign mitre_id = _extract_mitre_id(ext_refs) mitre_id = _extract_mitre_id(ext_refs) + # Assign mitre_url = _extract_mitre_url(ext_refs) mitre_url = _extract_mitre_url(ext_refs) + # Assign name = obj.get("name", "").strip() name = obj.get("name", "").strip() + # Check: not name if not name: + # Skip to the next loop iteration continue + # Assign aliases = obj.get("aliases", []) aliases = obj.get("aliases", []) + # Check: isinstance(aliases, list) and name in aliases if isinstance(aliases, list) and name in aliases: + # Assign aliases = [a for a in aliases if a != name] aliases = [a for a in aliases if a != name] + # Assign description = obj.get("description", "") description = obj.get("description", "") # Derive motivation: curated override > STIX field > description inference @@ -348,80 +426,129 @@ def 
_parse_intrusion_sets(objects: list) -> list[dict]: # Extract references (non-MITRE) references = [] + # Iterate over ext_refs for ref in ext_refs: + # Check: isinstance(ref, dict) and ref.get("source_name") != "mitre-attack" if isinstance(ref, dict) and ref.get("source_name") != "mitre-attack": + # Call references.append() references.append({ + # Literal argument value "source": ref.get("source_name", ""), + # Literal argument value "url": ref.get("url", ""), + # Literal argument value "description": ref.get("description", ""), }) + # Call actors.append() actors.append({ + # Literal argument value "stix_id": obj.get("id"), # e.g. "intrusion-set--abc123" + # Literal argument value "mitre_id": mitre_id, + # Literal argument value "name": name, + # Literal argument value "aliases": aliases if aliases else [], + # Literal argument value "description": description, + # Literal argument value "mitre_url": mitre_url, + # Literal argument value "references": references[:20], # cap to avoid bloat + # Literal argument value "first_seen": obj.get("first_seen"), + # Literal argument value "last_seen": obj.get("last_seen"), "motivation": motivation, "sophistication": sophistication, }) + # Log info: "Parsed %d intrusion-sets (threat actors)", len(ac logger.info("Parsed %d intrusion-sets (threat actors)", len(actors)) + # Return actors return actors +# Define function _parse_relationships def _parse_relationships(objects: list) -> list[dict]: - """Parse STIX relationship objects (type=uses) linking - intrusion-sets to attack-patterns. 
- """ + """Parse STIX relationship objects (type=uses) linking intrusion-sets to attack-patterns.""" + # Assign relationships = [] relationships = [] + # Iterate over objects for obj in objects: + # Check: obj.get("type") != "relationship" if obj.get("type") != "relationship": + # Skip to the next loop iteration continue + # Check: obj.get("relationship_type") != "uses" if obj.get("relationship_type") != "uses": + # Skip to the next loop iteration continue + # Check: obj.get("revoked") if obj.get("revoked"): + # Skip to the next loop iteration continue + # Assign source_ref = obj.get("source_ref", "") source_ref = obj.get("source_ref", "") + # Assign target_ref = obj.get("target_ref", "") target_ref = obj.get("target_ref", "") # We want intrusion-set → attack-pattern if not source_ref.startswith("intrusion-set--"): + # Skip to the next loop iteration continue + # Check: not target_ref.startswith("attack-pattern--") if not target_ref.startswith("attack-pattern--"): + # Skip to the next loop iteration continue + # Call relationships.append() relationships.append({ + # Literal argument value "source_ref": source_ref, + # Literal argument value "target_ref": target_ref, + # Literal argument value "description": obj.get("description", ""), }) + # Log info: "Parsed %d uses-relationships (actor→technique)", logger.info("Parsed %d uses-relationships (actor→technique)", len(relationships)) + # Return relationships return relationships +# Define function _build_attack_pattern_map def _build_attack_pattern_map(objects: list) -> dict[str, str]: """Build a map from STIX attack-pattern ID → MITRE technique ID. e.g. 
{"attack-pattern--abc123": "T1059.001"} """ + # Assign mapping = {} mapping = {} + # Iterate over objects for obj in objects: + # Check: obj.get("type") != "attack-pattern" if obj.get("type") != "attack-pattern": + # Skip to the next loop iteration continue + # Check: obj.get("revoked") if obj.get("revoked"): + # Skip to the next loop iteration continue + # Assign stix_id = obj.get("id", "") stix_id = obj.get("id", "") + # Assign mitre_id = _extract_mitre_id(obj.get("external_references", [])) mitre_id = _extract_mitre_id(obj.get("external_references", [])) + # Check: stix_id and mitre_id if stix_id and mitre_id: + # Assign mapping[stix_id] = mitre_id mapping[stix_id] = mitre_id + # Log info: "Built attack-pattern map with %d entries", len(ma logger.info("Built attack-pattern map with %d entries", len(mapping)) + # Return mapping return mapping @@ -435,24 +562,31 @@ def sync(db: Session) -> dict: Returns a summary dict. """ + # Assign tmp_dir = tempfile.mkdtemp(prefix="aegis_mitre_cti_") tmp_dir = tempfile.mkdtemp(prefix="aegis_mitre_cti_") + # Attempt the following; catch errors below try: + # Assign zip_bytes = _download_zip() zip_bytes = _download_zip() + # Assign bundle = _extract_zip_and_load_bundle(zip_bytes, tmp_dir) bundle = _extract_zip_and_load_bundle(zip_bytes, tmp_dir) + # Always execute this cleanup block finally: + # Call shutil.rmtree() shutil.rmtree(tmp_dir, ignore_errors=True) + # Log info: "Cleaned up temp directory %s", tmp_dir logger.info("Cleaned up temp directory %s", tmp_dir) + # Assign objects = bundle.get("objects", []) objects = bundle.get("objects", []) # Step 1: Parse data actor_dicts = _parse_intrusion_sets(objects) + # Assign relationships = _parse_relationships(objects) relationships = _parse_relationships(objects) + # Assign attack_pattern_map = _build_attack_pattern_map(objects) attack_pattern_map = _build_attack_pattern_map(objects) - # Step 2: Build STIX-ID → actor dict map - stix_to_actor = {a["stix_id"]: a for a in actor_dicts} 
- # Step 3: Load existing actors and techniques from DB existing_actors = { row.mitre_id: row @@ -460,6 +594,7 @@ def sync(db: Session) -> dict: if row.mitre_id } + # Assign technique_by_mitre_id = { technique_by_mitre_id = { row.mitre_id: row for row in db.query(Technique).all() @@ -467,22 +602,35 @@ def sync(db: Session) -> dict: # Step 4: Upsert threat actors actors_created = 0 + # Assign actors_skipped = 0 actors_skipped = 0 + # Assign stix_to_db_actor = {} stix_to_db_actor: dict[str, ThreatActor] = {} + # Iterate over actor_dicts for actor_dict in actor_dicts: + # Assign mitre_id = actor_dict["mitre_id"] mitre_id = actor_dict["mitre_id"] + # Assign stix_id = actor_dict["stix_id"] stix_id = actor_dict["stix_id"] + # Check: mitre_id and mitre_id in existing_actors if mitre_id and mitre_id in existing_actors: # Update existing actor db_actor = existing_actors[mitre_id] + # Assign db_actor.name = actor_dict["name"] db_actor.name = actor_dict["name"] + # Assign db_actor.aliases = actor_dict["aliases"] db_actor.aliases = actor_dict["aliases"] + # Assign db_actor.description = actor_dict["description"] db_actor.description = actor_dict["description"] + # Assign db_actor.mitre_url = actor_dict["mitre_url"] db_actor.mitre_url = actor_dict["mitre_url"] + # Assign db_actor.references = actor_dict["references"] db_actor.references = actor_dict["references"] + # Assign db_actor.first_seen = actor_dict.get("first_seen") db_actor.first_seen = actor_dict.get("first_seen") + # Assign db_actor.last_seen = actor_dict.get("last_seen") db_actor.last_seen = actor_dict.get("last_seen") # Update enrichment fields if available if actor_dict.get("motivation"): @@ -490,101 +638,165 @@ def sync(db: Session) -> dict: if actor_dict.get("sophistication"): db_actor.sophistication = actor_dict["sophistication"] stix_to_db_actor[stix_id] = db_actor + # Assign actors_skipped = 1 actors_skipped += 1 + # Fallback: handle remaining cases else: # Create new actor db_actor = ThreatActor( + # Keyword 
argument: mitre_id mitre_id=mitre_id, + # Keyword argument: name name=actor_dict["name"], + # Keyword argument: aliases aliases=actor_dict["aliases"], + # Keyword argument: description description=actor_dict["description"], + # Keyword argument: mitre_url mitre_url=actor_dict["mitre_url"], + # Keyword argument: references references=actor_dict["references"], + # Keyword argument: first_seen first_seen=actor_dict.get("first_seen"), + # Keyword argument: last_seen last_seen=actor_dict.get("last_seen"), motivation=actor_dict.get("motivation"), sophistication=actor_dict.get("sophistication"), is_active=True, ) + # Stage new record(s) for database insertion db.add(db_actor) + # Flush changes to DB without committing the transaction db.flush() # get the ID + # Check: mitre_id if mitre_id: + # Assign existing_actors[mitre_id] = db_actor existing_actors[mitre_id] = db_actor + # Assign stix_to_db_actor[stix_id] = db_actor stix_to_db_actor[stix_id] = db_actor + # Assign actors_created = 1 actors_created += 1 + # Flush changes to DB without committing the transaction db.flush() # Step 5: Upsert actor-technique relationships # Load existing relationships existing_rels: set[tuple] = set() + # Iterate over db.query(ThreatActorTechnique).all() for row in db.query(ThreatActorTechnique).all(): + # Call existing_rels.add() existing_rels.add((str(row.threat_actor_id), str(row.technique_id))) + # Assign rels_created = 0 rels_created = 0 + # Assign rels_skipped = 0 rels_skipped = 0 + # Iterate over relationships for rel in relationships: + # Assign source_ref = rel["source_ref"] source_ref = rel["source_ref"] + # Assign target_ref = rel["target_ref"] target_ref = rel["target_ref"] # Resolve actor db_actor = stix_to_db_actor.get(source_ref) + # Check: not db_actor if not db_actor: + # Skip to the next loop iteration continue # Resolve technique mitre_technique_id = attack_pattern_map.get(target_ref) + # Check: not mitre_technique_id if not mitre_technique_id: + # Skip to the next loop 
iteration continue + # Assign db_technique = technique_by_mitre_id.get(mitre_technique_id) db_technique = technique_by_mitre_id.get(mitre_technique_id) + # Check: not db_technique if not db_technique: + # Skip to the next loop iteration continue + # Assign rel_key = (str(db_actor.id), str(db_technique.id)) rel_key = (str(db_actor.id), str(db_technique.id)) + # Check: rel_key in existing_rels if rel_key in existing_rels: + # Assign rels_skipped = 1 rels_skipped += 1 + # Skip to the next loop iteration continue + # Assign actor_technique = ThreatActorTechnique( actor_technique = ThreatActorTechnique( + # Keyword argument: threat_actor_id threat_actor_id=db_actor.id, + # Keyword argument: technique_id technique_id=db_technique.id, + # Keyword argument: usage_description usage_description=rel["description"][:5000] if rel["description"] else None, ) + # Stage new record(s) for database insertion db.add(actor_technique) + # Call existing_rels.add() existing_rels.add(rel_key) + # Assign rels_created = 1 rels_created += 1 + # Commit all pending changes to the database db.commit() + # Assign summary = { summary = { + # Literal argument value "actors_created": actors_created, + # Literal argument value "actors_updated": actors_skipped, + # Literal argument value "relationships_created": rels_created, + # Literal argument value "relationships_skipped": rels_skipped, + # Literal argument value "total_actors_parsed": len(actor_dicts), + # Literal argument value "total_relationships_parsed": len(relationships), } # Update DataSource record ds = db.query(DataSource).filter(DataSource.name == "mitre_cti").first() + # Check: ds if ds: + # Assign ds.last_sync_at = datetime.utcnow() ds.last_sync_at = datetime.utcnow() + # Assign ds.last_sync_status = "success" ds.last_sync_status = "success" + # Assign ds.last_sync_stats = summary ds.last_sync_stats = summary + # Commit all pending changes to the database db.commit() + # Log info: "MITRE CTI threat actor import complete — %s", sum 
logger.info("MITRE CTI threat actor import complete — %s", summary) + # Call log_action() log_action( db, + # Keyword argument: user_id user_id=None, + # Keyword argument: action action="import_threat_actors", + # Keyword argument: entity_type entity_type="threat_actor", + # Keyword argument: entity_id entity_id=None, + # Keyword argument: details details=summary, ) + # Commit all pending changes to the database db.commit() + # Return summary return summary diff --git a/backend/app/services/threat_actor_service.py b/backend/app/services/threat_actor_service.py index c8aa18f..abfafd1 100644 --- a/backend/app/services/threat_actor_service.py +++ b/backend/app/services/threat_actor_service.py @@ -6,34 +6,56 @@ that the router remains a thin HTTP adapter. This module is framework-agnostic: no FastAPI imports. """ +# Enable future language features for compatibility from __future__ import annotations +# Import Any from typing from typing import Any from sqlalchemy import case, cast, func, or_, Text from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError -from app.models.enums import TechniqueStatus -from app.models.test import Test -from app.models.test_template import TestTemplate -from app.models.threat_actor import ThreatActor, ThreatActorTechnique -from app.models.technique import Technique -from app.utils import escape_like +# Import TechniqueStatus from app.models.enums +from app.models.enums import TechniqueStatus + +# Import Technique from app.models.technique +from app.models.technique import Technique + +# Import Test from app.models.test +from app.models.test import Test + +# Import TestTemplate from app.models.test_template +from app.models.test_template import TestTemplate + +# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor +from app.models.threat_actor import ThreatActor, ThreatActorTechnique + +# Import escape_like from app.utils +from app.utils import 
escape_like # ── Public service functions ────────────────────────────────────────── def list_actors( + # Entry: db db: Session, *, + # Entry: search search: str | None = None, + # Entry: country country: str | None = None, + # Entry: motivation motivation: str | None = None, + # Entry: sophistication sophistication: str | None = None, + # Entry: target_sectors target_sectors: str | None = None, + # Entry: offset offset: int = 0, + # Entry: limit limit: int = 50, ) -> dict[str, Any]: """List threat actors with optional filters, pagination, and coverage stats. @@ -41,10 +63,14 @@ def list_actors( Uses grouped subqueries to avoid N+1: technique counts and coverage counts are fetched in one query per page. """ + # Assign query = db.query(ThreatActor) query = db.query(ThreatActor) + # Check: search if search: + # Assign pattern = f"%{escape_like(search)}%" pattern = f"%{escape_like(search)}%" + # Assign query = query.filter( query = query.filter( or_( ThreatActor.name.ilike(pattern), @@ -53,35 +79,52 @@ def list_actors( ) ) + # Check: country if country: + # Assign query = query.filter(ThreatActor.country == country) query = query.filter(ThreatActor.country == country) + # Check: motivation if motivation: + # Assign query = query.filter(ThreatActor.motivation == motivation) query = query.filter(ThreatActor.motivation == motivation) + # Check: sophistication if sophistication: + # Assign query = query.filter(ThreatActor.sophistication == sophistication) query = query.filter(ThreatActor.sophistication == sophistication) + # Check: target_sectors if target_sectors: + # Assign query = query.filter( query = query.filter( cast(ThreatActor.target_sectors, Text).ilike( f"%{escape_like(target_sectors)}%" ) ) + # Assign total = query.count() total = query.count() + # Assign actors = ( actors = ( query.order_by(ThreatActor.name).offset(offset).limit(limit).all() ) + # Check: not actors if not actors: + # Return { return { + # Literal argument value "total": total, + # Literal 
argument value "offset": offset, + # Literal argument value "limit": limit, + # Literal argument value "items": [], } + # Assign actor_ids = [a.id for a in actors] actor_ids = [a.id for a in actors] # Single grouped query: tech_count and covered_count per actor @@ -96,215 +139,355 @@ def list_actors( TechniqueStatus.validated, TechniqueStatus.partial, ]), + # Literal argument value 1, ), + # Keyword argument: else_ else_=0, ) ).label("covered_count"), ) + # Chain .join() call .join(Technique, ThreatActorTechnique.technique_id == Technique.id) + # Chain .filter() call .filter(ThreatActorTechnique.threat_actor_id.in_(actor_ids)) + # Chain .group_by() call .group_by(ThreatActorTechnique.threat_actor_id) ).all() + # Assign counts_map = { counts_map = { str(row.threat_actor_id): { + # Literal argument value "tech_count": row.tech_count, + # Literal argument value "covered_count": row.covered_count or 0, } for row in counts_rows } + # Assign results = [] results = [] + # Iterate over actors for actor in actors: + # Assign cnt = counts_map.get(str(actor.id), {"tech_count": 0, "covered_count": 0}) cnt = counts_map.get(str(actor.id), {"tech_count": 0, "covered_count": 0}) + # Assign tech_count = cnt["tech_count"] tech_count = cnt["tech_count"] + # Assign covered = cnt["covered_count"] covered = cnt["covered_count"] + # Assign coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0 coverage_pct = round((covered / tech_count * 100), 1) if tech_count > 0 else 0.0 + # Call results.append() results.append({ + # Literal argument value "id": str(actor.id), + # Literal argument value "mitre_id": actor.mitre_id, + # Literal argument value "name": actor.name, + # Literal argument value "aliases": actor.aliases or [], + # Literal argument value "country": actor.country, + # Literal argument value "target_sectors": actor.target_sectors or [], + # Literal argument value "target_regions": actor.target_regions or [], + # Literal argument value "motivation": 
actor.motivation, + # Literal argument value "sophistication": actor.sophistication, + # Literal argument value "mitre_url": actor.mitre_url, + # Literal argument value "technique_count": tech_count, + # Literal argument value "coverage_pct": coverage_pct, + # Literal argument value "is_active": actor.is_active, }) + # Return { return { + # Literal argument value "total": total, + # Literal argument value "offset": offset, + # Literal argument value "limit": limit, + # Literal argument value "items": results, } +# Define function get_actor_detail def get_actor_detail(db: Session, actor_id: str) -> dict[str, Any]: """Get detailed threat actor with techniques. Raises EntityNotFoundError if the actor does not exist. """ + # Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + # Check: not actor if not actor: + # Raise EntityNotFoundError raise EntityNotFoundError("Threat actor", actor_id) + # Assign actor_techniques = ( actor_techniques = ( db.query(ThreatActorTechnique, Technique) + # Chain .join() call .join(Technique, ThreatActorTechnique.technique_id == Technique.id) + # Chain .filter() call .filter(ThreatActorTechnique.threat_actor_id == actor.id) + # Chain .order_by() call .order_by(Technique.mitre_id) + # Chain .all() call .all() ) + # Assign techniques_list = [ techniques_list = [ { + # Literal argument value "technique_id": str(tech.id), + # Literal argument value "mitre_id": tech.mitre_id, + # Literal argument value "name": tech.name, + # Literal argument value "tactic": tech.tactic, + # Literal argument value "status_global": tech.status_global.value if tech.status_global else None, + # Literal argument value "usage_description": at.usage_description, + # Literal argument value "first_seen_using": at.first_seen_using, } for at, tech in actor_techniques ] + # Return { return { + # Literal argument value "id": str(actor.id), + # Literal argument value 
"mitre_id": actor.mitre_id, + # Literal argument value "name": actor.name, + # Literal argument value "aliases": actor.aliases or [], + # Literal argument value "description": actor.description, + # Literal argument value "country": actor.country, + # Literal argument value "target_sectors": actor.target_sectors or [], + # Literal argument value "target_regions": actor.target_regions or [], + # Literal argument value "motivation": actor.motivation, + # Literal argument value "sophistication": actor.sophistication, + # Literal argument value "first_seen": actor.first_seen, + # Literal argument value "last_seen": actor.last_seen, + # Literal argument value "references": actor.references or [], + # Literal argument value "mitre_url": actor.mitre_url, + # Literal argument value "is_active": actor.is_active, + # Literal argument value "techniques": techniques_list, } +# Define function get_actor_coverage def get_actor_coverage(db: Session, actor_id: str) -> dict[str, Any]: """Calculate coverage percentage against a specific threat actor. Raises EntityNotFoundError if the actor does not exist. 
""" + # Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + # Check: not actor if not actor: + # Raise EntityNotFoundError raise EntityNotFoundError("Threat actor", actor_id) + # Assign actor_techniques = ( actor_techniques = ( db.query(Technique) + # Chain .join() call .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) + # Chain .filter() call .filter(ThreatActorTechnique.threat_actor_id == actor.id) + # Chain .all() call .all() ) + # Assign total = len(actor_techniques) total = len(actor_techniques) + # Check: total == 0 if total == 0: + # Return { return { + # Literal argument value "actor_id": str(actor.id), + # Literal argument value "actor_name": actor.name, + # Literal argument value "total_techniques": 0, + # Literal argument value "coverage_pct": 0.0, + # Literal argument value "breakdown": {}, } + # Assign breakdown = {} breakdown: dict[str, int] = {} + # Iterate over actor_techniques for tech in actor_techniques: + # Assign status = tech.status_global.value if tech.status_global else "not_evaluated" status = tech.status_global.value if tech.status_global else "not_evaluated" + # Assign breakdown[status] = breakdown.get(status, 0) + 1 breakdown[status] = breakdown.get(status, 0) + 1 + # Assign covered = breakdown.get("validated", 0) + breakdown.get("partial", 0) covered = breakdown.get("validated", 0) + breakdown.get("partial", 0) + # Assign coverage_pct = round((covered / total * 100), 1) coverage_pct = round((covered / total * 100), 1) + # Return { return { + # Literal argument value "actor_id": str(actor.id), + # Literal argument value "actor_name": actor.name, + # Literal argument value "total_techniques": total, + # Literal argument value "covered": covered, + # Literal argument value "coverage_pct": coverage_pct, + # Literal argument value "breakdown": breakdown, } +# Define function get_actor_gaps def get_actor_gaps(db: 
Session, actor_id: str) -> dict[str, Any]: """Identify techniques of this actor that are not fully validated. Raises EntityNotFoundError if the actor does not exist. """ + # Assign actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + # Check: not actor if not actor: + # Raise EntityNotFoundError raise EntityNotFoundError("Threat actor", actor_id) + # Assign gap_techniques = ( gap_techniques = ( db.query(Technique, ThreatActorTechnique) + # Chain .join() call .join(ThreatActorTechnique, ThreatActorTechnique.technique_id == Technique.id) + # Chain .filter() call .filter(ThreatActorTechnique.threat_actor_id == actor.id) + # Chain .filter() call .filter(Technique.status_global != TechniqueStatus.validated) + # Chain .order_by() call .order_by(Technique.mitre_id) + # Chain .all() call .all() ) + # Check: not gap_techniques if not gap_techniques: + # Return { return { + # Literal argument value "actor_id": str(actor.id), + # Literal argument value "actor_name": actor.name, + # Literal argument value "total_gaps": 0, + # Literal argument value "gaps": [], } + # Assign technique_ids = [tech.id for tech, _ in gap_techniques] technique_ids = [tech.id for tech, _ in gap_techniques] + # Assign mitre_ids = [tech.mitre_id for tech, _ in gap_techniques] mitre_ids = [tech.mitre_id for tech, _ in gap_techniques] # Batch template counts by mitre_technique_id template_counts = ( db.query(TestTemplate.mitre_technique_id, func.count(TestTemplate.id).label("cnt")) + # Chain .filter() call .filter(TestTemplate.mitre_technique_id.in_(mitre_ids)) + # Chain .filter() call .filter(TestTemplate.is_active == True) + # Chain .group_by() call .group_by(TestTemplate.mitre_technique_id) ).all() + # Assign template_map = {row.mitre_technique_id: row.cnt for row in template_counts} template_map = {row.mitre_technique_id: row.cnt for row in template_counts} # Batch test counts by technique_id test_counts = 
( db.query(Test.technique_id, func.count(Test.id).label("cnt")) + # Chain .filter() call .filter(Test.technique_id.in_(technique_ids)) + # Chain .group_by() call .group_by(Test.technique_id) ).all() + # Assign test_map = {str(row.technique_id): row.cnt for row in test_counts} test_map = {str(row.technique_id): row.cnt for row in test_counts} + # Assign gaps = [] gaps = [] + # Iterate over gap_techniques for tech, at in gap_techniques: + # Assign template_count = template_map.get(tech.mitre_id, 0) template_count = template_map.get(tech.mitre_id, 0) + # Assign test_count = test_map.get(str(tech.id), 0) test_count = test_map.get(str(tech.id), 0) + # Call gaps.append() gaps.append({ + # Literal argument value "technique_id": str(tech.id), + # Literal argument value "mitre_id": tech.mitre_id, + # Literal argument value "name": tech.name, + # Literal argument value "tactic": tech.tactic, + # Literal argument value "status_global": tech.status_global.value if tech.status_global else None, + # Literal argument value "usage_description": at.usage_description, + # Literal argument value "available_templates": template_count, + # Literal argument value "existing_tests": test_count, + # Literal argument value "has_templates": template_count > 0, }) + # Return { return { + # Literal argument value "actor_id": str(actor.id), + # Literal argument value "actor_name": actor.name, + # Literal argument value "total_gaps": len(gaps), + # Literal argument value "gaps": gaps, } diff --git a/backend/app/services/user_service.py b/backend/app/services/user_service.py index 004da4b..dc735c4 100644 --- a/backend/app/services/user_service.py +++ b/backend/app/services/user_service.py @@ -4,30 +4,51 @@ Uses domain exceptions from app.domain.errors. The router handles HTTP concerns, auth, audit logging, and commit. 
""" +# Enable future language features for compatibility from __future__ import annotations +# Import uuid import uuid +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import hash_password from app.auth from app.auth import hash_password -from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError + +# Import from app.domain.errors +from app.domain.errors import ( + BusinessRuleViolation, + DuplicateEntityError, + EntityNotFoundError, +) + +# Import User from app.models.user from app.models.user import User +# Assign VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"} VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"} +# Define function list_users def list_users(db: Session) -> list[User]: """Return a list of all users ordered by username.""" + # Return db.query(User).order_by(User.username).all() return db.query(User).order_by(User.username).all() +# Define function create_user def create_user( + # Entry: db db: Session, *, + # Entry: username username: str, + # Entry: email email: str | None, + # Entry: password password: str, + # Entry: role role: str, ) -> User: """Create a new user. @@ -36,33 +57,51 @@ def create_user( Raises BusinessRuleViolation if role is invalid. Does not commit; the router handles that. """ + # Assign existing = db.query(User).filter(User.username == username).first() existing = db.query(User).filter(User.username == username).first() + # Check: existing if existing: + # Raise DuplicateEntityError raise DuplicateEntityError("User", "username", username) + # Check: role not in VALID_ROLES if role not in VALID_ROLES: + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Invalid role '{role}'. 
Must be one of: {', '.join(sorted(VALID_ROLES))}" ) + # Assign user = User( user = User( + # Keyword argument: username username=username, + # Keyword argument: email email=email, + # Keyword argument: hashed_password hashed_password=hash_password(password), + # Keyword argument: role role=role, ) + # Stage new record(s) for database insertion db.add(user) + # Return user return user +# Define function get_user_or_raise def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User: """Return a user by ID or raise EntityNotFoundError.""" + # Assign user = db.query(User).filter(User.id == user_id).first() user = db.query(User).filter(User.id == user_id).first() + # Check: user is None if user is None: + # Raise EntityNotFoundError raise EntityNotFoundError("User", str(user_id)) + # Return user return user +# Define function update_user def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User: """Update one or more fields of an existing user. @@ -71,18 +110,27 @@ def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User: Handles 'password' by hashing and storing as 'hashed_password'. Does not commit; the router handles that. """ + # Assign user = get_user_or_raise(db, user_id) user = get_user_or_raise(db, user_id) + # Assign update_data = dict(fields) update_data = dict(fields) + # Check: "role" in update_data and update_data["role"] not in VALID_ROLES if "role" in update_data and update_data["role"] not in VALID_ROLES: + # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Invalid role '{update_data['role']}'. 
Must be one of: {', '.join(sorted(VALID_ROLES))}" ) + # Check: "password" in update_data if "password" in update_data: + # Assign update_data["hashed_password"] = hash_password(str(update_data.pop("password"))) update_data["hashed_password"] = hash_password(str(update_data.pop("password"))) + # Iterate over update_data.items() for field, value in update_data.items(): + # Call setattr() setattr(user, field, value) + # Return user return user diff --git a/backend/app/services/worklog_service.py b/backend/app/services/worklog_service.py index dfd4dd1..11bbb6b 100644 --- a/backend/app/services/worklog_service.py +++ b/backend/app/services/worklog_service.py @@ -1,83 +1,141 @@ """Internal worklog service — CRUD with integrity hashing.""" +# Import hashlib import hashlib + +# Import logging import logging + +# Import datetime from datetime from datetime import datetime + +# Import Optional from typing from typing import Optional + +# Import UUID from uuid from uuid import UUID +# Import Session from sqlalchemy.orm from sqlalchemy.orm import Session +# Import EntityNotFoundError from app.domain.errors from app.domain.errors import EntityNotFoundError + +# Import Worklog from app.models.worklog from app.models.worklog import Worklog +# Assign logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) +# Define function create_worklog def create_worklog( + # Entry: db db: Session, *, + # Entry: entity_type entity_type: str, + # Entry: entity_id entity_id: UUID, + # Entry: user_id user_id: UUID, + # Entry: activity_type activity_type: str, + # Entry: started_at started_at: datetime, + # Entry: duration_seconds duration_seconds: int, + # Entry: ended_at ended_at: Optional[datetime] = None, + # Entry: description description: Optional[str] = None, ) -> Worklog: """Create a worklog with an auto-computed integrity hash.""" + # Assign wl = Worklog( wl = Worklog( + # Keyword argument: entity_type entity_type=entity_type, + # Keyword argument: entity_id 
entity_id=entity_id, + # Keyword argument: user_id user_id=user_id, + # Keyword argument: activity_type activity_type=activity_type, + # Keyword argument: started_at started_at=started_at, + # Keyword argument: ended_at ended_at=ended_at, + # Keyword argument: duration_seconds duration_seconds=duration_seconds, + # Keyword argument: description description=description, ) + # Assign wl.integrity_hash = _compute_hash(wl) wl.integrity_hash = _compute_hash(wl) + # Stage new record(s) for database insertion db.add(wl) # Does not commit; caller (router) uses UnitOfWork. return wl +# Define function get_worklog_or_raise def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog: """Get a worklog by ID or raise EntityNotFoundError.""" + # Assign wl = db.query(Worklog).filter(Worklog.id == worklog_id).first() wl = db.query(Worklog).filter(Worklog.id == worklog_id).first() + # Check: not wl if not wl: + # Raise EntityNotFoundError raise EntityNotFoundError("Worklog", str(worklog_id)) + # Return wl return wl +# Define function list_worklogs def list_worklogs( + # Entry: db db: Session, *, + # Entry: entity_type entity_type: Optional[str] = None, + # Entry: entity_id entity_id: Optional[UUID] = None, + # Entry: user_id user_id: Optional[UUID] = None, ) -> list[Worklog]: """List worklogs with optional filters.""" + # Assign query = db.query(Worklog) query = db.query(Worklog) + # Check: entity_type if entity_type: + # Assign query = query.filter(Worklog.entity_type == entity_type) query = query.filter(Worklog.entity_type == entity_type) + # Check: entity_id if entity_id: + # Assign query = query.filter(Worklog.entity_id == entity_id) query = query.filter(Worklog.entity_id == entity_id) + # Check: user_id if user_id: + # Assign query = query.filter(Worklog.user_id == user_id) query = query.filter(Worklog.user_id == user_id) + # Return query.order_by(Worklog.started_at.desc()).all() return query.order_by(Worklog.started_at.desc()).all() +# Define function 
verify_worklog_integrity def verify_worklog_integrity(wl: Worklog) -> bool: """Return True if the worklog has not been tampered with.""" + # Return wl.integrity_hash == _compute_hash(wl) return wl.integrity_hash == _compute_hash(wl) +# Define function _compute_hash def _compute_hash(wl: Worklog) -> str: """SHA-256 of the immutable fields for audit integrity.""" + # Assign data = ( data = ( f"{wl.entity_type}:{wl.entity_id}:{wl.user_id}:" f"{wl.activity_type}:{wl.started_at}:{wl.duration_seconds}" ) + # Return hashlib.sha256(data.encode()).hexdigest() return hashlib.sha256(data.encode()).hexdigest() diff --git a/backend/app/storage.py b/backend/app/storage.py index 3fa80ae..832c818 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -12,9 +12,13 @@ Two clients are maintained: ``MINIO_ENDPOINT`` (backwards-compatible). """ +# Import boto3 import boto3 + +# Import ClientError from botocore.exceptions from botocore.exceptions import ClientError +# Import settings from app.config from app.config import settings # --------------------------------------------------------------------------- @@ -25,10 +29,15 @@ _scheme = "https" if settings.MINIO_SECURE else "http" # Internal client — used for uploads and bucket management _client = boto3.client( + # Literal argument value "s3", + # Keyword argument: endpoint_url endpoint_url=f"{_scheme}://{settings.MINIO_ENDPOINT}", + # Keyword argument: aws_access_key_id aws_access_key_id=settings.MINIO_ACCESS_KEY, + # Keyword argument: aws_secret_access_key aws_secret_access_key=settings.MINIO_SECRET_KEY, + # Keyword argument: region_name region_name="us-east-1", # MinIO ignores this but boto3 requires it ) @@ -51,22 +60,32 @@ _public_client = boto3.client( def ensure_bucket_exists() -> None: """Create the evidence bucket if it does not already exist.""" + # Attempt the following; catch errors below try: + # Call _client.head_bucket() _client.head_bucket(Bucket=settings.MINIO_BUCKET) + # Handle ClientError except 
ClientError: + # Call _client.create_bucket() _client.create_bucket(Bucket=settings.MINIO_BUCKET) +# Define function upload_file def upload_file(content: bytes, key: str) -> str: """Upload *content* to the evidence bucket under *key*. Returns the key that was written (same as the input). """ + # Call _client.put_object() _client.put_object( + # Keyword argument: Bucket Bucket=settings.MINIO_BUCKET, + # Keyword argument: Key Key=key, + # Keyword argument: Body Body=content, ) + # Return key return key @@ -85,6 +104,8 @@ def get_presigned_url(key: str, expiration: int = 3600) -> str: """ return _public_client.generate_presigned_url( "get_object", + # Keyword argument: Params Params={"Bucket": settings.MINIO_BUCKET, "Key": key}, + # Keyword argument: ExpiresIn ExpiresIn=expiration, ) diff --git a/backend/app/utils.py b/backend/app/utils.py index 8adaefb..a3d3790 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -1,6 +1,7 @@ """Shared utility helpers.""" +# Define function escape_like def escape_like(value: str) -> str: """Escape SQL LIKE wildcard characters (``%`` and ``_``). 
@@ -13,9 +14,13 @@ def escape_like(value: str) -> str: from app.utils import escape_like query.filter(Model.name.ilike(f"%{escape_like(term)}%")) """ + # Return ( return ( value + # Chain .replace() call .replace("\\", "\\\\") + # Chain .replace() call .replace("%", "\\%") + # Chain .replace() call .replace("_", "\\_") ) diff --git a/backend/requirements.txt b/backend/requirements.txt index 1f2043a..b975e74 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -3,14 +3,13 @@ uvicorn[standard] sqlalchemy psycopg2-binary alembic -python-jose[cryptography] +PyJWT passlib[bcrypt] bcrypt==4.0.1 boto3 apscheduler requests pyyaml -pySigma toml taxii2-client python-multipart diff --git a/backend/ruff.toml b/backend/ruff.toml index a3bdc3c..876dd28 100644 --- a/backend/ruff.toml +++ b/backend/ruff.toml @@ -1,13 +1,27 @@ +# PEP8 line length: 120 chars — the codebase uses longer identifiers and SQLAlchemy chaining +line-length = 120 + [lint] -# Ignore rules that have widespread pre-existing violations. -# These can be cleaned up incrementally in follow-up PRs. 
+# PEP8 compliance rules enforced: +# E/W — pycodestyle (core PEP8 style and warnings) +# F — pyflakes (unused imports, undefined names) +# I — isort (import ordering per PEP8 convention) +# N — pep8-naming (class/function/variable naming conventions) +# ANN — flake8-annotations (type hint enforcement) +select = ["E", "W", "F", "I", "N", "ANN", "D"] + ignore = [ - "E402", # module-level import not at top of file (app.main, some services) - "E712", # == True comparisons (required by SQLAlchemy filter syntax) - "F401", # unused imports (widespread; clean up incrementally) - "F841", # unused local variables (a few occurrences) + # SQLAlchemy filter syntax requires `== True` / `== False` comparisons + "E712", + # ANN101/ANN102 (self/cls type annotations) removed from ruff — not needed ] +[lint.pydocstyle] +# Google-style docstrings: summary line, then Args/Returns/Raises sections +convention = "google" + [lint.per-file-ignores] -# Test files may use broad exception catching and unusual import patterns -"tests/**" = ["E", "F"] +# Tests use broad exception catching and unusual import patterns +"tests/**" = ["E", "F", "N"] +# Data file: D3FEND technique descriptions contain URLs and long strings that cannot be meaningfully wrapped +"app/services/d3fend_import_service.py" = ["E501"] diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index c6de441..4377c31 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -102,7 +102,7 @@ def test_logout_revokes_token(client, admin_user): ) assert out.status_code == 200 - from jose import jwt + import jwt from app.config import settings from app.infrastructure.redis_client import get_redis_blacklist