refactor(types): add comprehensive type annotations across backend Python codebase

Enable ANN rules in ruff.toml (flake8-annotations) and resolve all 221 violations:

ANN201/ANN202 — return types on 168 public/private functions:
- All 28 FastAPI routers: endpoints annotated with dict/list/specific schema/
  StreamingResponse/FileResponse/JSONResponse as appropriate
- main.py: lifespan→AsyncGenerator[None,None], exception handlers→JSONResponse
- database.py: get_db→Generator[Session,None,None], proxy methods→correct types
- middleware/request_context.py: dispatch→Response with Callable call_next type

ANN001/ANN002/ANN003 — 32 missing argument types:
- seed_demo.py: all db parameters typed as Session
- domain/unit_of_work.py: __aexit__ exc_type/exc_val/exc_tb typed with TracebackType
- services: audit_service user_id→UUID|None, heatmap_service query/model/builder,
  notification_service test→Test, tempo_service test→Test/user→User,
  test_workflow_service test_id→UUID, campaign_crud **fields→object,
  test_crud **fields→object (4 sites)

ANN401 — 16 Any usages resolved:
- Domain entities (campaign/technique/threat_actor/test_entity): replaced Any with
  actual ORM types via TYPE_CHECKING guards to avoid circular imports
- detection_rule_service: test_id/detection_rule_id/evaluator_id→UUID
- score_cache: kept Any with # noqa: ANN401 (genuinely generic cache)
- jira_service/tempo_service: kept Any with # noqa: ANN401 (lazy optional deps)
- d3fend_import_service: _to_str(v: Any) kept with # noqa: ANN401

ANN204/ANN205/ANN206 — special/static/class methods:
- database.py proxy __call__/__getattr__: *args: object/**kwargs: object
- schemas/test.py model_validate: obj→object, **kwargs→object
- sa_technique_repository._int_type→type

All 439 unit tests pass. ruff check app/ → All checks passed!
This commit is contained in:
kitos
2026-06-09 17:04:51 +02:00
parent 8f98bdd273
commit 9ff0f04ba3
51 changed files with 267 additions and 223 deletions
+10 -7
View File
@@ -1,5 +1,8 @@
from collections.abc import Generator
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, declarative_base, sessionmaker
Base = declarative_base() Base = declarative_base()
@@ -10,7 +13,7 @@ _engine = None
_SessionLocal = None _SessionLocal = None
def _get_engine(): def _get_engine() -> Engine:
global _engine global _engine
if _engine is None: if _engine is None:
from app.config import settings from app.config import settings
@@ -28,7 +31,7 @@ def _get_engine():
return _engine return _engine
def _get_session_factory(): def _get_session_factory() -> sessionmaker:
global _SessionLocal global _SessionLocal
if _SessionLocal is None: if _SessionLocal is None:
_SessionLocal = sessionmaker( _SessionLocal = sessionmaker(
@@ -41,10 +44,10 @@ class _LazySessionLocal:
"""Proxy so ``SessionLocal()`` keeps working as before but the real """Proxy so ``SessionLocal()`` keeps working as before but the real
sessionmaker is only created on first call.""" sessionmaker is only created on first call."""
def __call__(self, *args, **kwargs): def __call__(self, *args: object, **kwargs: object) -> Session:
return _get_session_factory()(*args, **kwargs) return _get_session_factory()(*args, **kwargs)
def __getattr__(self, name): def __getattr__(self, name: str) -> object:
return getattr(_get_session_factory(), name) return getattr(_get_session_factory(), name)
@@ -53,14 +56,14 @@ SessionLocal = _LazySessionLocal()
class _EngineProxy: class _EngineProxy:
"""Thin proxy so ``from app.database import engine`` still works.""" """Thin proxy so ``from app.database import engine`` still works."""
def __getattr__(self, name): def __getattr__(self, name: str) -> object:
return getattr(_get_engine(), name) return getattr(_get_engine(), name)
engine = _EngineProxy() # type: ignore[assignment] engine = _EngineProxy() # type: ignore[assignment]
def get_db(): def get_db() -> Generator[Session, None, None]:
db = SessionLocal() db = SessionLocal()
try: try:
yield db yield db
+3 -2
View File
@@ -8,6 +8,7 @@ Provides:
(admins always pass). (admins always pass).
""" """
from collections.abc import Callable
from typing import Optional from typing import Optional
from fastapi import Cookie, Depends, HTTPException, status from fastapi import Cookie, Depends, HTTPException, status
@@ -112,7 +113,7 @@ async def require_password_changed(
return current_user return current_user
def require_role(required_role: str): def require_role(required_role: str) -> Callable[..., object]:
"""Return a FastAPI dependency that enforces *required_role*. """Return a FastAPI dependency that enforces *required_role*.
The dependency allows the request to proceed when The dependency allows the request to proceed when
@@ -133,7 +134,7 @@ def require_role(required_role: str):
return role_checker return role_checker
def require_any_role(*roles: str): def require_any_role(*roles: str) -> Callable[..., object]:
"""Return a FastAPI dependency that enforces **any** of the given *roles*. """Return a FastAPI dependency that enforces **any** of the given *roles*.
Admins always pass. Usage example:: Admins always pass. Usage example::
+5 -2
View File
@@ -8,10 +8,13 @@ from __future__ import annotations
import enum import enum
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import TYPE_CHECKING
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
if TYPE_CHECKING:
from app.models.campaign import Campaign as CampaignORM
class CampaignStatus(str, enum.Enum): class CampaignStatus(str, enum.Enum):
draft = "draft" draft = "draft"
@@ -86,7 +89,7 @@ class CampaignEntity:
) )
@classmethod @classmethod
def from_orm(cls, orm: Any) -> CampaignEntity: def from_orm(cls, orm: CampaignORM) -> CampaignEntity:
"""Build a CampaignEntity from a SQLAlchemy Campaign model.""" """Build a CampaignEntity from a SQLAlchemy Campaign model."""
test_count = len(getattr(orm, "campaign_tests", None) or []) test_count = len(getattr(orm, "campaign_tests", None) or [])
return cls( return cls(
+6 -3
View File
@@ -17,11 +17,14 @@ from __future__ import annotations
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from typing import Any from typing import TYPE_CHECKING
from app.domain.enums import TechniqueStatus, TestResult, TestState from app.domain.enums import TechniqueStatus, TestResult, TestState
from app.domain.value_objects.mitre_id import MitreId from app.domain.value_objects.mitre_id import MitreId
if TYPE_CHECKING:
from app.models.technique import Technique as TechniqueORM
@dataclass(frozen=True) @dataclass(frozen=True)
class _TestSnapshot: class _TestSnapshot:
@@ -76,7 +79,7 @@ class TechniqueEntity:
) )
@classmethod @classmethod
def from_orm(cls, model: Any) -> TechniqueEntity: def from_orm(cls, model: TechniqueORM) -> TechniqueEntity:
"""Build a TechniqueEntity from a SQLAlchemy Technique model.""" """Build a TechniqueEntity from a SQLAlchemy Technique model."""
raw_status = model.status_global raw_status = model.status_global
if raw_status is None: if raw_status is None:
@@ -101,7 +104,7 @@ class TechniqueEntity:
mitre_last_modified=getattr(model, "mitre_last_modified", None), mitre_last_modified=getattr(model, "mitre_last_modified", None),
) )
def apply_to(self, model: Any) -> None: def apply_to(self, model: TechniqueORM) -> None:
"""Copy mutable fields back onto the ORM model.""" """Copy mutable fields back onto the ORM model."""
model.status_global = self.status_global model.status_global = self.status_global
model.review_required = self.review_required model.review_required = self.review_required
+5 -2
View File
@@ -7,7 +7,10 @@ from __future__ import annotations
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.models.threat_actor import ThreatActor as ThreatActorORM
@dataclass @dataclass
@@ -63,7 +66,7 @@ class ThreatActorEntity:
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1) return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
@classmethod @classmethod
def from_orm(cls, orm: Any) -> ThreatActorEntity: def from_orm(cls, orm: ThreatActorORM) -> ThreatActorEntity:
techs: list[ThreatActorTechniqueRef] = [] techs: list[ThreatActorTechniqueRef] = []
for tat in getattr(orm, "techniques", None) or []: for tat in getattr(orm, "techniques", None) or []:
technique = getattr(tat, "technique", None) technique = getattr(tat, "technique", None)
+6 -3
View File
@@ -26,7 +26,7 @@ import enum
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from typing import Any from typing import TYPE_CHECKING, Any
from app.domain.errors import ( from app.domain.errors import (
BusinessRuleViolation, BusinessRuleViolation,
@@ -34,6 +34,9 @@ from app.domain.errors import (
InvalidStateTransition, InvalidStateTransition,
) )
if TYPE_CHECKING:
from app.models.test import Test as TestORM
# ── Value objects ──────────────────────────────────────────────────── # ── Value objects ────────────────────────────────────────────────────
@@ -103,7 +106,7 @@ class TestEntity:
# -- Factory -------------------------------------------------------- # -- Factory --------------------------------------------------------
@classmethod @classmethod
def from_orm(cls, model: Any) -> TestEntity: def from_orm(cls, model: TestORM) -> TestEntity:
"""Build a TestEntity from a SQLAlchemy ``Test`` model instance.""" """Build a TestEntity from a SQLAlchemy ``Test`` model instance."""
raw_state = model.state raw_state = model.state
state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state) state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state)
@@ -126,7 +129,7 @@ class TestEntity:
blue_paused_seconds=model.blue_paused_seconds or 0, blue_paused_seconds=model.blue_paused_seconds or 0,
) )
def apply_to(self, model: Any) -> None: def apply_to(self, model: TestORM) -> None:
"""Copy the entity's mutable fields back onto the ORM model.""" """Copy the entity's mutable fields back onto the ORM model."""
model.state = self.state model.state = self.state
model.red_validation_status = self.red_validation_status model.red_validation_status = self.red_validation_status
+8 -1
View File
@@ -22,6 +22,8 @@ Services should **never** call ``db.commit()``; they use ``db.add()`` /
from __future__ import annotations from __future__ import annotations
from types import TracebackType
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -36,7 +38,12 @@ class UnitOfWork:
def __enter__(self) -> "UnitOfWork": def __enter__(self) -> "UnitOfWork":
return self return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None: def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if exc_type is not None: if exc_type is not None:
self.rollback() self.rollback()
@@ -197,7 +197,7 @@ class SATechniqueRepository:
# -- Internal ---------------------------------------------------------- # -- Internal ----------------------------------------------------------
@staticmethod @staticmethod
def _int_type(): def _int_type() -> type:
"""Return an Integer type for CAST expressions (SQLite-compatible).""" """Return an Integer type for CAST expressions (SQLite-compatible)."""
from sqlalchemy import Integer from sqlalchemy import Integer
return Integer return Integer
+6 -5
View File
@@ -1,5 +1,6 @@
import logging import logging
import os import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status from fastapi import FastAPI, Request, status
@@ -53,7 +54,7 @@ setup_logging()
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production" _IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Startup / shutdown logic.""" """Startup / shutdown logic."""
ensure_bucket_exists() ensure_bucket_exists()
start_scheduler() start_scheduler()
@@ -124,7 +125,7 @@ app.include_router(osint_router.router, prefix="/api/v1")
@app.get("/health", include_in_schema=False) @app.get("/health", include_in_schema=False)
def health(): def health() -> dict[str, str]:
"""Minimal health check — returns only an HTTP 200 with no service metadata. """Minimal health check — returns only an HTTP 200 with no service metadata.
Access is restricted to internal networks at the Nginx level Access is restricted to internal networks at the Nginx level
@@ -149,7 +150,7 @@ def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError): async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
"""Handle validation errors with consistent format.""" """Handle validation errors with consistent format."""
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -162,7 +163,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
@app.exception_handler(SQLAlchemyError) @app.exception_handler(SQLAlchemyError)
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError): async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError) -> JSONResponse:
"""Handle database errors.""" """Handle database errors."""
logging.error(f"Database error: {exc}") logging.error(f"Database error: {exc}")
return JSONResponse( return JSONResponse(
@@ -175,7 +176,7 @@ async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError):
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception): async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Handle all unhandled exceptions.""" """Handle all unhandled exceptions."""
logging.error(f"Unhandled exception: {exc}") logging.error(f"Unhandled exception: {exc}")
return JSONResponse( return JSONResponse(
+7 -1
View File
@@ -1,9 +1,11 @@
"""Request context middleware — captures client IP and User-Agent per request.""" """Request context middleware — captures client IP and User-Agent per request."""
from collections.abc import Awaitable, Callable
from contextvars import ContextVar from contextvars import ContextVar
from fastapi import Request from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
request_ip: ContextVar[str] = ContextVar("request_ip", default="") request_ip: ContextVar[str] = ContextVar("request_ip", default="")
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="") request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
@@ -20,7 +22,11 @@ def resolve_client_ip(request: Request) -> str:
class RequestContextMiddleware(BaseHTTPMiddleware): class RequestContextMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(
self,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
request_ip.set(resolve_client_ip(request)) request_ip.set(resolve_client_ip(request))
request_user_agent.set(request.headers.get("User-Agent", "")) request_user_agent.set(request.headers.get("User-Agent", ""))
return await call_next(request) return await call_next(request)
+4 -4
View File
@@ -15,7 +15,7 @@ router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
def coverage_by_tactic( def coverage_by_tactic(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Coverage percentage broken down by MITRE ATT&CK tactic.""" """Coverage percentage broken down by MITRE ATT&CK tactic."""
return advanced_metrics_service.get_coverage_by_tactic(db) return advanced_metrics_service.get_coverage_by_tactic(db)
@@ -24,7 +24,7 @@ def coverage_by_tactic(
def never_tested_techniques( def never_tested_techniques(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Techniques that have never had a test created.""" """Techniques that have never had a test created."""
return advanced_metrics_service.get_never_tested_techniques(db) return advanced_metrics_service.get_never_tested_techniques(db)
@@ -33,7 +33,7 @@ def never_tested_techniques(
def avg_validation_time( def avg_validation_time(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> dict:
"""Average time from test creation to validation, computed from audit logs. """Average time from test creation to validation, computed from audit logs.
Returns overall average and per-phase averages where data is available. Returns overall average and per-phase averages where data is available.
@@ -45,6 +45,6 @@ def avg_validation_time(
def detection_rate_trend( def detection_rate_trend(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Monthly detection rate trend for the last 12 months.""" """Monthly detection rate trend for the last 12 months."""
return advanced_metrics_service.get_detection_rate_trend(db) return advanced_metrics_service.get_detection_rate_trend(db)
+4 -4
View File
@@ -19,7 +19,7 @@ router = APIRouter(prefix="/analytics", tags=["analytics"])
def analytics_coverage( def analytics_coverage(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Coverage per technique — flat format for BI dashboards.""" """Coverage per technique — flat format for BI dashboards."""
return analytics_service.get_coverage_analytics(db) return analytics_service.get_coverage_analytics(db)
@@ -30,7 +30,7 @@ def analytics_tests(
date_to: str = Query(None, description="ISO date filter (<=)"), date_to: str = Query(None, description="ISO date filter (<=)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""All tests with timestamps — flat format for BI dashboards.""" """All tests with timestamps — flat format for BI dashboards."""
return analytics_service.get_tests_analytics( return analytics_service.get_tests_analytics(
db, date_from=date_from, date_to=date_to db, date_from=date_from, date_to=date_to
@@ -41,7 +41,7 @@ def analytics_tests(
def analytics_trends( def analytics_trends(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Historical coverage snapshots for trend visualization.""" """Historical coverage snapshots for trend visualization."""
return analytics_service.get_trends_analytics(db) return analytics_service.get_trends_analytics(db)
@@ -50,6 +50,6 @@ def analytics_trends(
def analytics_operators( def analytics_operators(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_role("admin")), user: User = Depends(require_role("admin")),
): ) -> list:
"""Per-operator metrics — for workload management dashboards.""" """Per-operator metrics — for workload management dashboards."""
return analytics_service.get_operators_analytics(db) return analytics_service.get_operators_analytics(db)
+3 -3
View File
@@ -30,7 +30,7 @@ def list_audit_logs(
limit: int = Query(50, ge=1, le=100, description="Max records to return"), limit: int = Query(50, ge=1, le=100, description="Max records to return"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> AuditLogPage:
"""Return paginated audit logs with optional filters. """Return paginated audit logs with optional filters.
**Requires admin role.** **Requires admin role.**
@@ -57,7 +57,7 @@ def list_audit_logs(
def list_actions( def list_actions(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> list[str]:
"""Return a list of distinct action types in the audit log. """Return a list of distinct action types in the audit log.
**Requires admin role.** **Requires admin role.**
@@ -69,7 +69,7 @@ def list_actions(
def list_entity_types( def list_entity_types(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> list[str]:
"""Return a list of distinct entity types in the audit log. """Return a list of distinct entity types in the audit log.
**Requires admin role.** **Requires admin role.**
+4 -4
View File
@@ -46,7 +46,7 @@ def login(
response: Response, response: Response,
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> TokenResponse:
"""Authenticate a user and return a JWT access token. """Authenticate a user and return a JWT access token.
Rate-limited to **5 attempts per minute per IP**. Failed and successful Rate-limited to **5 attempts per minute per IP**. Failed and successful
@@ -110,7 +110,7 @@ def logout(
request: Request, request: Request,
response: Response, response: Response,
aegis_token: str | None = Cookie(None), aegis_token: str | None = Cookie(None),
): ) -> dict:
"""Clear the authentication cookie and revoke the current token.""" """Clear the authentication cookie and revoke the current token."""
bearer = ( bearer = (
request.headers.get("Authorization") request.headers.get("Authorization")
@@ -148,7 +148,7 @@ def logout(
@router.get("/me", response_model=UserOut) @router.get("/me", response_model=UserOut)
def read_current_user(current_user: User = Depends(get_current_user)): def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut:
"""Return the profile of the currently authenticated user.""" """Return the profile of the currently authenticated user."""
return current_user return current_user
@@ -158,7 +158,7 @@ def change_password(
body: PasswordChange, body: PasswordChange,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Change the current user's password.""" """Change the current user's password."""
auth_change_password( auth_change_password(
db, db,
+12 -12
View File
@@ -107,7 +107,7 @@ def list_campaigns(
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List campaigns with optional filters and pagination.""" """List campaigns with optional filters and pagination."""
return crud_list( return crud_list(
db, db,
@@ -129,7 +129,7 @@ def create_campaign(
payload: CampaignCreate, payload: CampaignCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Create a new campaign.""" """Create a new campaign."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
result = crud_create( result = crud_create(
@@ -165,7 +165,7 @@ def get_campaign(
campaign_id: str, campaign_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get detailed campaign info including tests and progress.""" """Get detailed campaign info including tests and progress."""
return crud_get_detail(db, campaign_id) return crud_get_detail(db, campaign_id)
@@ -180,7 +180,7 @@ def update_campaign(
payload: CampaignUpdate, payload: CampaignUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Update a campaign. Only allowed in draft or active state.""" """Update a campaign. Only allowed in draft or active state."""
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -214,7 +214,7 @@ def add_test_to_campaign(
payload: AddTestPayload, payload: AddTestPayload,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Add a test to a campaign with optional ordering and dependency.""" """Add a test to a campaign with optional ordering and dependency."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
result = crud_add_test( result = crud_add_test(
@@ -239,7 +239,7 @@ def remove_test_from_campaign(
campaign_test_id: str, campaign_test_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Remove a test from a campaign.""" """Remove a test from a campaign."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
crud_remove_test(db, campaign_id, campaign_test_id) crud_remove_test(db, campaign_id, campaign_test_id)
@@ -256,7 +256,7 @@ def activate_campaign(
campaign_id: str, campaign_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Activate a campaign, moving it from draft to active.""" """Activate a campaign, moving it from draft to active."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
campaign = crud_activate(db, campaign_id) campaign = crud_activate(db, campaign_id)
@@ -292,7 +292,7 @@ def complete_campaign(
campaign_id: str, campaign_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "admin")), current_user: User = Depends(require_any_role("red_lead", "admin")),
): ) -> dict:
"""Mark a campaign as completed.""" """Mark a campaign as completed."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
campaign = crud_complete(db, campaign_id) campaign = crud_complete(db, campaign_id)
@@ -319,7 +319,7 @@ def get_campaign_progress_endpoint(
campaign_id: str, campaign_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get progress statistics for a campaign.""" """Get progress statistics for a campaign."""
return crud_get_progress(db, campaign_id) return crud_get_progress(db, campaign_id)
@@ -333,7 +333,7 @@ def generate_campaign_from_actor(
actor_id: str, actor_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Auto-generate a campaign from a threat actor's uncovered techniques. """Auto-generate a campaign from a threat actor's uncovered techniques.
Creates tests from the best available templates and orders them Creates tests from the best available templates and orders them
@@ -369,7 +369,7 @@ def schedule_campaign(
payload: SchedulePayload, payload: SchedulePayload,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Configure or update the recurrence schedule for a campaign. """Configure or update the recurrence schedule for a campaign.
Only the campaign creator or admin can change scheduling. Only the campaign creator or admin can change scheduling.
@@ -411,6 +411,6 @@ def get_campaign_history(
campaign_id: str, campaign_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List all child campaigns (execution history) of a recurring campaign.""" """List all child campaigns (execution history) of a recurring campaign."""
return crud_get_history(db, campaign_id) return crud_get_history(db, campaign_id)
+7 -7
View File
@@ -34,7 +34,7 @@ router = APIRouter(prefix="/compliance", tags=["compliance"])
def list_frameworks_endpoint( def list_frameworks_endpoint(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List all available compliance frameworks.""" """List all available compliance frameworks."""
return list_frameworks(db) return list_frameworks(db)
@@ -47,7 +47,7 @@ def framework_status(
framework_id: str, framework_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get compliance status for each control in a framework.""" """Get compliance status for each control in a framework."""
return get_framework_status(db, framework_id) return get_framework_status(db, framework_id)
@@ -60,7 +60,7 @@ def framework_report(
framework_id: str, framework_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get the full compliance report (same as status but marked as report).""" """Get the full compliance report (same as status but marked as report)."""
return get_framework_status(db, framework_id) return get_framework_status(db, framework_id)
@@ -73,7 +73,7 @@ def framework_report_csv(
framework_id: str, framework_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> StreamingResponse:
"""Export compliance report as CSV.""" """Export compliance report as CSV."""
csv_bytes, filename = build_framework_report_csv(db, framework_id) csv_bytes, filename = build_framework_report_csv(db, framework_id)
return StreamingResponse( return StreamingResponse(
@@ -93,7 +93,7 @@ def framework_gaps(
framework_id: str, framework_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get controls with techniques that are not adequately covered.""" """Get controls with techniques that are not adequately covered."""
return get_framework_gaps(db, framework_id) return get_framework_gaps(db, framework_id)
@@ -105,7 +105,7 @@ def framework_gaps(
def import_nist( def import_nist(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Import NIST 800-53 Rev 5 mappings (admin only).""" """Import NIST 800-53 Rev 5 mappings (admin only)."""
result = import_nist_800_53_mappings(db) result = import_nist_800_53_mappings(db)
return result return result
@@ -115,7 +115,7 @@ def import_nist(
def import_cis( def import_cis(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Import CIS Controls v8 mappings (admin only).""" """Import CIS Controls v8 mappings (admin only)."""
result = import_cis_controls_v8_mappings(db) result = import_cis_controls_v8_mappings(db)
return result return result
+4 -4
View File
@@ -38,7 +38,7 @@ def list_defensive_techniques(
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List all D3FEND defensive techniques with optional filters.""" """List all D3FEND defensive techniques with optional filters."""
return list_defensive_techniques_svc( return list_defensive_techniques_svc(
db, tactic=tactic, search=search, offset=offset, limit=limit db, tactic=tactic, search=search, offset=offset, limit=limit
@@ -53,7 +53,7 @@ def list_defensive_techniques(
def list_d3fend_tactics_endpoint( def list_d3fend_tactics_endpoint(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return a list of all D3FEND tactics with counts.""" """Return a list of all D3FEND tactics with counts."""
return list_d3fend_tactics(db) return list_d3fend_tactics(db)
@@ -67,7 +67,7 @@ def get_defenses_for_attack_technique_endpoint(
mitre_id: str, mitre_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" """Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
return get_defenses_for_attack_technique(db, mitre_id) return get_defenses_for_attack_technique(db, mitre_id)
@@ -80,7 +80,7 @@ def get_defenses_for_attack_technique_endpoint(
def trigger_d3fend_import( def trigger_d3fend_import(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Import D3FEND techniques and ATT&CK mappings. Admin only.""" """Import D3FEND techniques and ATT&CK mappings. Admin only."""
tech_result = import_d3fend_techniques(db) tech_result = import_d3fend_techniques(db)
mapping_result = import_d3fend_mappings(db) mapping_result = import_d3fend_mappings(db)
+5 -5
View File
@@ -47,7 +47,7 @@ router = APIRouter(prefix="/data-sources", tags=["data-sources"])
def list_data_sources( def list_data_sources(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> list:
"""List all registered data sources. """List all registered data sources.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -61,7 +61,7 @@ def update_data_source(
body: DataSourceUpdate, body: DataSourceUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Update a data source (enable/disable, change config). """Update a data source (enable/disable, change config).
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -87,7 +87,7 @@ def sync_data_source(
source_id: str, source_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Trigger sync/import for a specific data source. """Trigger sync/import for a specific data source.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -99,7 +99,7 @@ def sync_data_source(
def sync_all_data_sources( def sync_all_data_sources(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Trigger sync for all enabled data sources (sequentially). """Trigger sync for all enabled data sources (sequentially).
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -125,7 +125,7 @@ def get_data_source_stats(
source_id: str, source_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Get detailed statistics for a specific data source. """Get detailed statistics for a specific data source.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
+5 -5
View File
@@ -51,7 +51,7 @@ def list_detection_rules(
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List detection rules with optional filters and pagination.""" """List detection rules with optional filters and pagination."""
return list_rules( return list_rules(
db, db,
@@ -72,7 +72,7 @@ def get_detection_rules_for_template(
template_id: str, template_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Get detection rules associated with a test template.""" """Get detection rules associated with a test template."""
return get_rules_for_template(db, template_id) return get_rules_for_template(db, template_id)
@@ -84,7 +84,7 @@ def get_detection_rules_for_template(
def auto_associate_detection_rules( def auto_associate_detection_rules(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Auto-associate test templates with detection rules by MITRE technique ID. """Auto-associate test templates with detection rules by MITRE technique ID.
For each active template, find all active detection rules for the same For each active template, find all active detection rules for the same
@@ -102,7 +102,7 @@ def get_detection_rules_for_test(
test_id: str, test_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Get detection rules relevant to a test, along with their evaluation results. """Get detection rules relevant to a test, along with their evaluation results.
Finds rules by matching the test's technique_id to detection rules, Finds rules by matching the test's technique_id to detection rules,
@@ -119,7 +119,7 @@ def evaluate_detection_rule(
payload: DetectionRuleEvaluate, payload: DetectionRuleEvaluate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
): ) -> dict:
"""Save or update the evaluation result for a detection rule on a test.""" """Save or update the evaluation result for a detection rule on a test."""
return evaluate_rule( return evaluate_rule(
db, db,
+4 -4
View File
@@ -88,7 +88,7 @@ async def upload_evidence(
notes: Optional[str] = Form(None), notes: Optional[str] = Form(None),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> EvidenceOut:
"""Upload a file as evidence for the given test. """Upload a file as evidence for the given test.
The ``team`` field (sent as form data) determines whether this is The ``team`` field (sent as form data) determines whether this is
@@ -154,7 +154,7 @@ def list_evidence(
team: Optional[str] = Query(None, description="Filter by team: red or blue"), team: Optional[str] = Query(None, description="Filter by team: red or blue"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[EvidenceOut]:
"""List all evidences for a test, optionally filtered by team.""" """List all evidences for a test, optionally filtered by team."""
get_test_or_raise(db, test_id) get_test_or_raise(db, test_id)
evidences = list_evidence_for_test(db, test_id, team=team) evidences = list_evidence_for_test(db, test_id, team=team)
@@ -171,7 +171,7 @@ def get_evidence(
evidence_id: _uuid.UUID, evidence_id: _uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> EvidenceOut:
"""Return evidence metadata together with a presigned download URL.""" """Return evidence metadata together with a presigned download URL."""
evidence = get_evidence_or_raise(db, evidence_id) evidence = get_evidence_or_raise(db, evidence_id)
return _evidence_to_out(evidence) return _evidence_to_out(evidence)
@@ -187,7 +187,7 @@ def delete_evidence(
evidence_id: _uuid.UUID, evidence_id: _uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Delete an evidence record. """Delete an evidence record.
Only allowed in editable states: Only allowed in editable states:
+5 -5
View File
@@ -28,7 +28,7 @@ def heatmap_coverage(
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Coverage layer — score based on status_global of each technique.""" """Coverage layer — score based on status_global of each technique."""
return heatmap_service.build_coverage_layer( return heatmap_service.build_coverage_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score, db, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -43,7 +43,7 @@ def heatmap_threat_actor(
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Threat actor layer — techniques used by an actor with coverage color.""" """Threat actor layer — techniques used by an actor with coverage color."""
return heatmap_service.build_threat_actor_layer( return heatmap_service.build_threat_actor_layer(
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score, db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -57,7 +57,7 @@ def heatmap_detection_rules(
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Detection rules layer — score based on ratio of rules available vs total.""" """Detection rules layer — score based on ratio of rules available vs total."""
return heatmap_service.build_detection_rules_layer( return heatmap_service.build_detection_rules_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score, db, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -72,7 +72,7 @@ def heatmap_campaign(
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Campaign layer — only techniques in the campaign, colored by test state.""" """Campaign layer — only techniques in the campaign, colored by test state."""
return heatmap_service.build_campaign_layer( return heatmap_service.build_campaign_layer(
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score, db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -88,7 +88,7 @@ def export_navigator(
min_score: int = Query(0, ge=0, le=100), min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> StreamingResponse:
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator.""" """Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
data = heatmap_service.build_navigator_export( data = heatmap_service.build_navigator_export(
db, layer, layer_id=layer_id, db, layer, layer_id=layer_id,
+6 -6
View File
@@ -29,7 +29,7 @@ def search_issues(
q: str = Query(..., min_length=2), q: str = Query(..., min_length=2),
max_results: int = Query(10, le=50), max_results: int = Query(10, le=50),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list[JiraIssueResult]:
"""Search Jira issues by JQL or free text.""" """Search Jira issues by JQL or free text."""
return jira_service.search_jira_issues(q, max_results) return jira_service.search_jira_issues(q, max_results)
@@ -39,7 +39,7 @@ def create_link(
body: JiraLinkCreate, body: JiraLinkCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> JiraLinkOut:
"""Associate an Aegis entity with a Jira issue.""" """Associate an Aegis entity with a Jira issue."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
link = jira_service.create_link( link = jira_service.create_link(
@@ -74,7 +74,7 @@ def list_links(
entity_id: Optional[UUID] = None, entity_id: Optional[UUID] = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list[JiraLinkOut]:
"""List Jira links, optionally filtered by entity.""" """List Jira links, optionally filtered by entity."""
return jira_service.list_links( return jira_service.list_links(
db, db,
@@ -88,7 +88,7 @@ def sync_link(
link_id: UUID, link_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_role("admin")), user: User = Depends(require_role("admin")),
): ) -> dict:
"""Force bidirectional sync for a specific Jira link.""" """Force bidirectional sync for a specific Jira link."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
link = jira_service.get_link_or_raise(db, link_id) link = jira_service.get_link_or_raise(db, link_id)
@@ -102,7 +102,7 @@ def delete_link(
link_id: UUID, link_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> None:
"""Remove a Jira link.""" """Remove a Jira link."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
link = jira_service.delete_link(db, link_id) link = jira_service.delete_link(db, link_id)
@@ -123,7 +123,7 @@ def create_issue_from_entity(
entity_id: UUID, entity_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> dict:
"""Auto-create a Jira issue from an Aegis entity and link them.""" """Auto-create a Jira issue from an Aegis entity and link them."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
result = jira_service.create_issue_and_link( result = jira_service.create_issue_and_link(
+6 -6
View File
@@ -42,7 +42,7 @@ router = APIRouter(prefix="/metrics", tags=["metrics"])
def coverage_summary( def coverage_summary(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> CoverageSummary:
"""Return a global coverage summary across all techniques.""" """Return a global coverage summary across all techniques."""
return get_coverage_summary(db) return get_coverage_summary(db)
@@ -56,7 +56,7 @@ def coverage_summary(
def coverage_by_tactic( def coverage_by_tactic(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[TacticCoverage]:
"""Return coverage breakdown grouped by tactic.""" """Return coverage breakdown grouped by tactic."""
return get_coverage_by_tactic(db) return get_coverage_by_tactic(db)
@@ -70,7 +70,7 @@ def coverage_by_tactic(
def test_pipeline( def test_pipeline(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> TestPipelineCounts:
"""Return how many tests are in each pipeline state.""" """Return how many tests are in each pipeline state."""
return get_test_pipeline_counts(db) return get_test_pipeline_counts(db)
@@ -84,7 +84,7 @@ def test_pipeline(
def team_activity( def team_activity(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[TeamActivity]:
"""Return activity summary for Red and Blue teams.""" """Return activity summary for Red and Blue teams."""
return get_team_activity(db) return get_team_activity(db)
@@ -98,7 +98,7 @@ def team_activity(
def validation_rate( def validation_rate(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[ValidationRate]:
"""Return approval and rejection rates for Red Lead and Blue Lead.""" """Return approval and rejection rates for Red Lead and Blue Lead."""
return get_validation_rate(db) return get_validation_rate(db)
@@ -112,6 +112,6 @@ def validation_rate(
def recent_tests( def recent_tests(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[RecentTestItem]:
"""Return the 10 most recently created tests.""" """Return the 10 most recently created tests."""
return get_recent_tests(db, limit=10) return get_recent_tests(db, limit=10)
+4 -4
View File
@@ -39,7 +39,7 @@ def list_notifications_endpoint(
limit: int = Query(20, ge=1, le=100), limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list[NotificationOut]:
"""Return paginated notifications for the current user, newest first.""" """Return paginated notifications for the current user, newest first."""
return list_notifications(db, current_user.id, offset=offset, limit=limit) return list_notifications(db, current_user.id, offset=offset, limit=limit)
@@ -53,7 +53,7 @@ def list_notifications_endpoint(
def unread_count( def unread_count(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> UnreadCountOut:
"""Return the number of unread notifications for the current user.""" """Return the number of unread notifications for the current user."""
count = get_unread_count(db, current_user.id) count = get_unread_count(db, current_user.id)
return UnreadCountOut(unread_count=count) return UnreadCountOut(unread_count=count)
@@ -69,7 +69,7 @@ def read_notification(
notification_id: uuid.UUID, notification_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> NotificationOut:
"""Mark a single notification as read.""" """Mark a single notification as read."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
notif = mark_as_read(db, notification_id, current_user.id) notif = mark_as_read(db, notification_id, current_user.id)
@@ -86,7 +86,7 @@ def read_notification(
def read_all_notifications( def read_all_notifications(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Mark all notifications for the current user as read.""" """Mark all notifications for the current user as read."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
count = mark_all_as_read(db, current_user.id) count = mark_all_as_read(db, current_user.id)
+3 -3
View File
@@ -25,7 +25,7 @@ router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
def operational_metrics( def operational_metrics(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min.""" """Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min."""
from app.services.score_cache import get_operational_metrics_cached from app.services.score_cache import get_operational_metrics_cached
@@ -40,7 +40,7 @@ def operational_trend(
period: str = Query("90d", pattern="^(30d|90d|1y)$"), period: str = Query("90d", pattern="^(30d|90d|1y)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get weekly trend data for operational metrics.""" """Get weekly trend data for operational metrics."""
return get_operational_trend(db, period) return get_operational_trend(db, period)
@@ -52,6 +52,6 @@ def operational_trend(
def metrics_by_team( def metrics_by_team(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get metrics broken down by Red Team vs Blue Team.""" """Get metrics broken down by Red Team vs Blue Team."""
return get_metrics_by_team(db) return get_metrics_by_team(db)
+5 -5
View File
@@ -57,7 +57,7 @@ def list_osint_items(
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""List OSINT items with optional filters.""" """List OSINT items with optional filters."""
return service_list_osint_items( return service_list_osint_items(
db, db,
@@ -73,7 +73,7 @@ def list_osint_items(
def osint_summary( def osint_summary(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> dict:
"""Summary statistics for OSINT items.""" """Summary statistics for OSINT items."""
return get_osint_summary(db) return get_osint_summary(db)
@@ -83,7 +83,7 @@ def review_osint_item(
item_id: UUID, item_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> dict:
"""Mark an OSINT item as reviewed.""" """Mark an OSINT item as reviewed."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
item = mark_osint_reviewed(db, str(item_id)) item = mark_osint_reviewed(db, str(item_id))
@@ -101,7 +101,7 @@ def trigger_technique_enrichment(
technique_id: UUID, technique_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead")), user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Manually trigger OSINT enrichment for a single technique.""" """Manually trigger OSINT enrichment for a single technique."""
technique = get_technique_or_raise(db, technique_id) technique = get_technique_or_raise(db, technique_id)
count = enrich_technique_with_cves(db, technique) count = enrich_technique_with_cves(db, technique)
@@ -119,7 +119,7 @@ def get_technique_osint(
reviewed: bool | None = Query(None), reviewed: bool | None = Query(None),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list:
"""Get all OSINT items for a specific technique.""" """Get all OSINT items for a specific technique."""
items = get_osint_items_for_technique( items = get_osint_items_for_technique(
db, db,
+5 -5
View File
@@ -29,7 +29,7 @@ def generate_purple_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
): ) -> FileResponse:
"""Generate a Purple Team campaign assessment report.""" """Generate a Purple Team campaign assessment report."""
filepath = report_generation_service.generate_purple_campaign_report( filepath = report_generation_service.generate_purple_campaign_report(
db, str(campaign_id), output_format=format, db, str(campaign_id), output_format=format,
@@ -48,7 +48,7 @@ def generate_coverage_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
): ) -> FileResponse:
"""Generate an organization-wide MITRE ATT&CK coverage report.""" """Generate an organization-wide MITRE ATT&CK coverage report."""
filepath = report_generation_service.generate_coverage_report( filepath = report_generation_service.generate_coverage_report(
db, output_format=format, db, output_format=format,
@@ -67,7 +67,7 @@ def generate_executive_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
): ) -> FileResponse:
"""Generate an executive security summary report.""" """Generate an executive security summary report."""
filepath = report_generation_service.generate_executive_summary( filepath = report_generation_service.generate_executive_summary(
db, output_format=format, db, output_format=format,
@@ -86,7 +86,7 @@ def generate_quarterly_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
): ) -> FileResponse:
"""Generate a quarterly security summary report.""" """Generate a quarterly security summary report."""
filepath = report_generation_service.generate_quarterly_summary( filepath = report_generation_service.generate_quarterly_summary(
db, output_format=format, db, output_format=format,
@@ -106,7 +106,7 @@ def generate_technique_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> FileResponse:
"""Generate a detailed report for one MITRE technique.""" """Generate a detailed report for one MITRE technique."""
filepath = report_generation_service.generate_technique_detail_report( filepath = report_generation_service.generate_technique_detail_report(
db, str(technique_id), output_format=format, db, str(technique_id), output_format=format,
+4 -4
View File
@@ -38,7 +38,7 @@ def coverage_summary(
platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"), platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Full coverage report as JSON — technique-by-technique with test counts.""" """Full coverage report as JSON — technique-by-technique with test counts."""
return build_coverage_summary(db, tactic=tactic, platform=platform) return build_coverage_summary(db, tactic=tactic, platform=platform)
@@ -49,7 +49,7 @@ def coverage_csv(
platform: Optional[str] = Query(None), platform: Optional[str] = Query(None),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> StreamingResponse:
"""Export coverage as a downloadable CSV.""" """Export coverage as a downloadable CSV."""
rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform) rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
@@ -74,7 +74,7 @@ def test_results(
date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"), date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Report of test results with optional filters.""" """Report of test results with optional filters."""
return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to) return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
@@ -84,6 +84,6 @@ def remediation_status(
status: Optional[str] = Query(None, description="Filter by remediation status"), status: Optional[str] = Query(None, description="Filter by remediation status"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Report of remediation status across all tests.""" """Report of remediation status across all tests."""
return build_remediation_status_report(db, status=status) return build_remediation_status_report(db, status=status)
+7 -7
View File
@@ -35,7 +35,7 @@ def score_technique(
mitre_id: str, mitre_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get detailed score with breakdown for a specific technique.""" """Get detailed score with breakdown for a specific technique."""
return score_technique_by_mitre_id(db, mitre_id) return score_technique_by_mitre_id(db, mitre_id)
@@ -48,7 +48,7 @@ def score_tactic(
tactic: str, tactic: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get average score for a tactic.""" """Get average score for a tactic."""
return calculate_tactic_score(tactic, db) return calculate_tactic_score(tactic, db)
@@ -61,7 +61,7 @@ def score_threat_actor(
actor_id: str, actor_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get coverage score against a specific threat actor.""" """Get coverage score against a specific threat actor."""
return score_actor_by_id(db, actor_id) return score_actor_by_id(db, actor_id)
@@ -73,7 +73,7 @@ def score_threat_actor(
def score_organization( def score_organization(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get the overall organization security score (cached for 5 min).""" """Get the overall organization security score (cached for 5 min)."""
from app.services.score_cache import get_organization_score_cached from app.services.score_cache import get_organization_score_cached
@@ -88,7 +88,7 @@ def score_history(
period: str = Query("90d", pattern="^(30d|90d|1y)$"), period: str = Query("90d", pattern="^(30d|90d|1y)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get historical score data points (weekly).""" """Get historical score data points (weekly)."""
return get_score_history(db, period) return get_score_history(db, period)
@@ -100,7 +100,7 @@ def score_history(
def get_scoring_config( def get_scoring_config(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Get current scoring weights (admin only).""" """Get current scoring weights (admin only)."""
return get_weights_dict(db) return get_weights_dict(db)
@@ -123,7 +123,7 @@ def update_scoring_config(
payload: ScoringConfigUpdate, payload: ScoringConfigUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Update scoring weights (admin only). """Update scoring weights (admin only).
Weights are persisted in the database and survive restarts. Weights are persisted in the database and survive restarts.
+6 -6
View File
@@ -52,7 +52,7 @@ def list_snapshots(
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List coverage snapshots ordered by creation date (newest first).""" """List coverage snapshots ordered by creation date (newest first)."""
return list_snapshots_svc(db, offset=offset, limit=limit) return list_snapshots_svc(db, offset=offset, limit=limit)
@@ -66,7 +66,7 @@ def create_snapshot_endpoint(
payload: SnapshotCreate, payload: SnapshotCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")), current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
): ) -> dict:
"""Create a manual coverage snapshot with an optional name.""" """Create a manual coverage snapshot with an optional name."""
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id) snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
@@ -94,7 +94,7 @@ def coverage_evolution(
months: int = Query(12, ge=1, le=36), months: int = Query(12, ge=1, le=36),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return coverage snapshots for trend charts (last *months* months).""" """Return coverage snapshots for trend charts (last *months* months)."""
return get_coverage_evolution(db, months=months) return get_coverage_evolution(db, months=months)
@@ -109,7 +109,7 @@ def compare_snapshots_endpoint(
b: str = Query(..., description="Snapshot B ID"), b: str = Query(..., description="Snapshot B ID"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Compare two snapshots showing improved, worsened, and unchanged techniques.""" """Compare two snapshots showing improved, worsened, and unchanged techniques."""
try: try:
a_id = uuid.UUID(a) a_id = uuid.UUID(a)
@@ -129,7 +129,7 @@ def get_snapshot(
snapshot_id: str, snapshot_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get detailed snapshot information including per-technique states.""" """Get detailed snapshot information including per-technique states."""
return get_snapshot_detail(db, snapshot_id) return get_snapshot_detail(db, snapshot_id)
@@ -143,7 +143,7 @@ def delete_snapshot_endpoint(
snapshot_id: str, snapshot_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Delete a snapshot (admin only).""" """Delete a snapshot (admin only)."""
snapshot = get_snapshot_or_raise(db, snapshot_id) snapshot = get_snapshot_or_raise(db, snapshot_id)
+4 -4
View File
@@ -30,7 +30,7 @@ def trigger_mitre_sync(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Manually trigger a MITRE ATT&CK synchronisation. """Manually trigger a MITRE ATT&CK synchronisation.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -50,7 +50,7 @@ def trigger_mitre_sync(
def trigger_intel_scan( def trigger_intel_scan(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Manually trigger a threat-intelligence scan. """Manually trigger a threat-intelligence scan.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -71,7 +71,7 @@ def trigger_atomic_import(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Trigger an import of Atomic Red Team tests as TestTemplates. """Trigger an import of Atomic Red Team tests as TestTemplates.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
@@ -101,7 +101,7 @@ def trigger_atomic_import(
@router.get("/scheduler-status") @router.get("/scheduler-status")
def scheduler_status( def scheduler_status(
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Return the current state of the background scheduler. """Return the current state of the background scheduler.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
+5 -5
View File
@@ -45,7 +45,7 @@ def list_techniques(
review_required: bool | None = Query(None, description="Filter by review flag"), review_required: bool | None = Query(None, description="Filter by review flag"),
repo: SATechniqueRepository = Depends(get_technique_repository), repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return a lightweight list of techniques, optionally filtered.""" """Return a lightweight list of techniques, optionally filtered."""
return repo.list_all( return repo.list_all(
tactic=tactic, tactic=tactic,
@@ -64,7 +64,7 @@ def get_technique(
mitre_id: str, mitre_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Return full details for a single technique, including its tests and D3FEND defenses.""" """Return full details for a single technique, including its tests and D3FEND defenses."""
return get_technique_detail(db, mitre_id) return get_technique_detail(db, mitre_id)
@@ -84,7 +84,7 @@ def create_technique(
db: Session = Depends(get_db), db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository), repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> TechniqueOut:
"""Create a new technique manually.""" """Create a new technique manually."""
if repo.exists_by_mitre_id(payload.mitre_id): if repo.exists_by_mitre_id(payload.mitre_id):
raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id) raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id)
@@ -124,7 +124,7 @@ def update_technique(
db: Session = Depends(get_db), db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository), repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> TechniqueOut:
"""Update one or more fields of an existing technique.""" """Update one or more fields of an existing technique."""
entity = repo.find_by_mitre_id(mitre_id) entity = repo.find_by_mitre_id(mitre_id)
if entity is None: if entity is None:
@@ -160,7 +160,7 @@ def review_technique(
db: Session = Depends(get_db), db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository), repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TechniqueOut:
"""Mark a technique as reviewed. """Mark a technique as reviewed.
Sets ``review_required`` to *False* and records the current timestamp Sets ``review_required`` to *False* and records the current timestamp
+9 -9
View File
@@ -78,7 +78,7 @@ def _list_templates_handler(
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return a paginated, filterable list of test templates.""" """Return a paginated, filterable list of test templates."""
return list_templates( return list_templates(
db, db,
@@ -102,7 +102,7 @@ def _list_templates_handler(
def template_stats( def template_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Return catalog statistics: active, by_source, by_platform.""" """Return catalog statistics: active, by_source, by_platform."""
return get_template_stats(db) return get_template_stats(db)
@@ -117,7 +117,7 @@ def bulk_activate_templates(
activate: bool = Query(True, description="True to activate all, False to deactivate all"), activate: bool = Query(True, description="True to activate all, False to deactivate all"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Set all templates to active or inactive.""" """Set all templates to active or inactive."""
count = bulk_activate(db, activate=activate) count = bulk_activate(db, activate=activate)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -148,7 +148,7 @@ def _templates_by_technique_handler(
mitre_id: str, mitre_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return all active templates mapped to a specific MITRE technique.""" """Return all active templates mapped to a specific MITRE technique."""
return templates_by_technique(db, mitre_id) return templates_by_technique(db, mitre_id)
@@ -163,7 +163,7 @@ def get_template(
template_id: uuid.UUID, template_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> TestTemplateOut:
"""Return full details for a single test template.""" """Return full details for a single test template."""
return get_template_or_raise(db, template_id) return get_template_or_raise(db, template_id)
@@ -182,7 +182,7 @@ def create_template(
payload: TestTemplateCreate, payload: TestTemplateCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestTemplateOut:
"""Create a custom test template.""" """Create a custom test template."""
template = create_template_svc(db, **payload.model_dump()) template = create_template_svc(db, **payload.model_dump())
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -215,7 +215,7 @@ def update_template(
payload: TestTemplateCreate, payload: TestTemplateCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestTemplateOut:
"""Update fields of an existing test template.""" """Update fields of an existing test template."""
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True)) template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -243,7 +243,7 @@ def toggle_template_active(
template_id: uuid.UUID, template_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestTemplateOut:
"""Toggle a template between active and inactive (is_active = not is_active).""" """Toggle a template between active and inactive (is_active = not is_active)."""
template = toggle_template_active_svc(db, template_id) template = toggle_template_active_svc(db, template_id)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -271,7 +271,7 @@ def delete_template(
template_id: uuid.UUID, template_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Soft-delete a test template by setting ``is_active=False``.""" """Soft-delete a test template by setting ``is_active=False``."""
template = get_template_or_raise(db, template_id) template = get_template_or_raise(db, template_id)
soft_delete_template(db, template_id) soft_delete_template(db, template_id)
+19 -19
View File
@@ -126,7 +126,7 @@ def list_tests(
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator.""" """Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
return crud_list_tests( return crud_list_tests(
db, db,
@@ -156,7 +156,7 @@ def create_test(
payload: TestCreate, payload: TestCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestOut:
"""Create a new test linked to an existing technique. """Create a new test linked to an existing technique.
``created_by`` is set automatically and ``state`` defaults to *draft*. ``created_by`` is set automatically and ``state`` defaults to *draft*.
@@ -198,7 +198,7 @@ def create_test_from_template(
payload: TestTemplateInstantiate, payload: TestTemplateInstantiate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestOut:
"""Instantiate a real Test from an existing TestTemplate. """Instantiate a real Test from an existing TestTemplate.
The template's fields are copied into the new test as starting data. The template's fields are copied into the new test as starting data.
@@ -238,7 +238,7 @@ def get_test(
test_id: uuid.UUID, test_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> TestOut:
"""Return full details for a single test, including its evidences.""" """Return full details for a single test, including its evidences."""
return crud_get_test_detail(db, test_id) return crud_get_test_detail(db, test_id)
@@ -254,7 +254,7 @@ def update_test(
payload: TestUpdate, payload: TestUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestOut:
"""Update one or more fields of an existing test. """Update one or more fields of an existing test.
Only leads or admins can update general test fields. Only leads or admins can update general test fields.
@@ -294,7 +294,7 @@ def update_test_classification(
payload: TestClassificationUpdate, payload: TestClassificationUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> TestOut:
"""Update the data classification label for a test (admin only).""" """Update the data classification label for a test (admin only)."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id)
@@ -324,7 +324,7 @@ def update_test_red(
payload: TestRedUpdate, payload: TestRedUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "red_lead")), current_user: User = Depends(require_any_role("red_tech", "red_lead")),
): ) -> TestOut:
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``).""" """Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -354,7 +354,7 @@ def update_test_blue(
payload: TestBlueUpdate, payload: TestBlueUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
): ) -> TestOut:
"""Blue Team updates their fields (allowed only in ``blue_evaluating``).""" """Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -383,7 +383,7 @@ def start_execution(
test_id: uuid.UUID, test_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "red_lead")), current_user: User = Depends(require_any_role("red_tech", "red_lead")),
): ) -> TestOut:
"""Move a test from ``draft`` to ``red_executing``.""" """Move a test from ``draft`` to ``red_executing``."""
test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -403,7 +403,7 @@ def submit_red(
test_id: uuid.UUID, test_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "red_lead")), current_user: User = Depends(require_any_role("red_tech", "red_lead")),
): ) -> TestOut:
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``.""" """Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -423,7 +423,7 @@ def submit_blue(
test_id: uuid.UUID, test_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
): ) -> TestOut:
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``.""" """Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -443,7 +443,7 @@ def pause_timer(
test_id: uuid.UUID, test_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
): ) -> TestOut:
"""Pause the running timer for the current phase (red_executing or blue_evaluating).""" """Pause the running timer for the current phase (red_executing or blue_evaluating)."""
test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -463,7 +463,7 @@ def resume_timer(
test_id: uuid.UUID, test_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
): ) -> TestOut:
"""Resume the paused timer for the current phase.""" """Resume the paused timer for the current phase."""
test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -484,7 +484,7 @@ def validate_red(
payload: TestRedValidate, payload: TestRedValidate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead")), current_user: User = Depends(require_any_role("red_lead")),
): ) -> TestOut:
"""Red Lead approves or rejects the red side of a test.""" """Red Lead approves or rejects the red side of a test."""
test = crud_get_test_with_technique(db, test_id) test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -511,7 +511,7 @@ def validate_blue(
payload: TestBlueValidate, payload: TestBlueValidate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_lead")), current_user: User = Depends(require_any_role("blue_lead")),
): ) -> TestOut:
"""Blue Lead approves or rejects the blue side of a test.""" """Blue Lead approves or rejects the blue side of a test."""
test = crud_get_test_with_technique(db, test_id) test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -537,7 +537,7 @@ def reopen(
test_id: uuid.UUID, test_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestOut:
"""Reopen a rejected test, moving it back to ``draft``.""" """Reopen a rejected test, moving it back to ``draft``."""
test = crud_get_test_or_raise(db, test_id) test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
@@ -558,7 +558,7 @@ def update_remediation(
payload: TestRemediationUpdate, payload: TestRemediationUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> TestOut:
"""Update remediation fields on a test. """Update remediation fields on a test.
When ``remediation_status`` transitions to ``'completed'``, an automatic When ``remediation_status`` transitions to ``'completed'``, an automatic
@@ -602,7 +602,7 @@ def get_test_timeline(
test_id: uuid.UUID, test_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return the chronological audit-log history for a test.""" """Return the chronological audit-log history for a test."""
return crud_get_test_timeline(db, test_id) return crud_get_test_timeline(db, test_id)
@@ -617,7 +617,7 @@ def get_retest_chain(
test_id: uuid.UUID, test_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Return the full chain of retests (original + all retests) for a test.""" """Return the full chain of retests (original + all retests) for a test."""
chain = wf_get_retest_chain(db, test_id) chain = wf_get_retest_chain(db, test_id)
if not chain: if not chain:
+4 -4
View File
@@ -36,7 +36,7 @@ def list_threat_actors(
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""List threat actors with optional filters and pagination. """List threat actors with optional filters and pagination.
**Requires** authentication (any role). **Requires** authentication (any role).
@@ -58,7 +58,7 @@ def get_threat_actor(
actor_id: str, actor_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Get detailed info about a threat actor including techniques. """Get detailed info about a threat actor including techniques.
**Requires** authentication (any role). **Requires** authentication (any role).
@@ -71,7 +71,7 @@ def get_threat_actor_coverage(
actor_id: str, actor_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> dict:
"""Calculate coverage percentage against a specific threat actor. """Calculate coverage percentage against a specific threat actor.
**Requires** authentication (any role). **Requires** authentication (any role).
@@ -87,7 +87,7 @@ def get_threat_actor_gaps(
actor_id: str, actor_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> list:
"""Identify techniques of this actor that are NOT fully validated. """Identify techniques of this actor that are NOT fully validated.
**Requires** authentication (any role). **Requires** authentication (any role).
+4 -4
View File
@@ -30,7 +30,7 @@ router = APIRouter(prefix="/users", tags=["users"])
def list_users_route( def list_users_route(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> list[UserOut]:
"""Return a list of all users. **Requires admin role.**""" """Return a list of all users. **Requires admin role.**"""
return list_users(db) return list_users(db)
@@ -45,7 +45,7 @@ def create_user_route(
payload: UserCreate, payload: UserCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> UserOut:
"""Create a new user. **Requires admin role.**""" """Create a new user. **Requires admin role.**"""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
user = create_user( user = create_user(
@@ -79,7 +79,7 @@ def get_user(
user_id: uuid.UUID, user_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> UserOut:
"""Return a single user by ID. **Requires admin role.**""" """Return a single user by ID. **Requires admin role.**"""
return get_user_or_raise(db, user_id) return get_user_or_raise(db, user_id)
@@ -95,7 +95,7 @@ def update_user_route(
payload: UserUpdate, payload: UserUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> UserOut:
"""Update one or more fields of an existing user. **Requires admin role.**""" """Update one or more fields of an existing user. **Requires admin role.**"""
update_data = payload.model_dump(exclude_unset=True) update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
+4 -4
View File
@@ -56,7 +56,7 @@ def create(
body: WorklogCreate, body: WorklogCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
): ) -> WorklogOut:
"""Create a manually-logged worklog entry.""" """Create a manually-logged worklog entry."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
wl = worklog_service.create_worklog( wl = worklog_service.create_worklog(
@@ -82,7 +82,7 @@ def list_all(
user_id: Optional[UUID] = None, user_id: Optional[UUID] = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
_user: User = Depends(get_current_user), _user: User = Depends(get_current_user),
): ) -> list[WorklogOut]:
"""List worklogs with optional filters.""" """List worklogs with optional filters."""
return worklog_service.list_worklogs( return worklog_service.list_worklogs(
db, db,
@@ -97,7 +97,7 @@ def get_one(
worklog_id: UUID, worklog_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
_user: User = Depends(get_current_user), _user: User = Depends(get_current_user),
): ) -> WorklogOut:
"""Get a single worklog by ID.""" """Get a single worklog by ID."""
return worklog_service.get_worklog_or_raise(db, worklog_id) return worklog_service.get_worklog_or_raise(db, worklog_id)
@@ -107,7 +107,7 @@ def verify_integrity(
worklog_id: UUID, worklog_id: UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
_user: User = Depends(get_current_user), _user: User = Depends(get_current_user),
): ) -> dict:
"""Check whether a worklog's integrity hash is still valid.""" """Check whether a worklog's integrity hash is still valid."""
wl = worklog_service.get_worklog_or_raise(db, worklog_id) wl = worklog_service.get_worklog_or_raise(db, worklog_id)
return { return {
+1 -1
View File
@@ -167,7 +167,7 @@ class TestOut(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@classmethod @classmethod
def model_validate(cls, obj, **kwargs): def model_validate(cls, obj: object, **kwargs: object) -> "TestOut":
"""Override to populate technique fields from the relationship.""" """Override to populate technique fields from the relationship."""
if hasattr(obj, "technique") and obj.technique is not None: if hasattr(obj, "technique") and obj.technique is not None:
obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id
+10 -8
View File
@@ -16,6 +16,8 @@ import random
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from app.auth import hash_password from app.auth import hash_password
from app.database import SessionLocal from app.database import SessionLocal
from app.models.audit import AuditLog from app.models.audit import AuditLog
@@ -102,7 +104,7 @@ TEMPLATE_NAMES = [
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _cleanup_demo_data(db) -> None: def _cleanup_demo_data(db: Session) -> None:
"""Remove all previously seeded demo data.""" """Remove all previously seeded demo data."""
# Delete in order to respect FK constraints # Delete in order to respect FK constraints
demo_users = db.query(User).filter(User.username.like(f"{DEMO_PREFIX}%")).all() demo_users = db.query(User).filter(User.username.like(f"{DEMO_PREFIX}%")).all()
@@ -154,7 +156,7 @@ def _cleanup_demo_data(db) -> None:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _seed_users(db) -> list[User]: def _seed_users(db: Session) -> list[User]:
"""Create 5 users per role (25 total).""" """Create 5 users per role (25 total)."""
users = [] users = []
for role in ROLES: for role in ROLES:
@@ -173,7 +175,7 @@ def _seed_users(db) -> list[User]:
return users return users
def _seed_technique_statuses(db, count: int = 50) -> list[Technique]: def _seed_technique_statuses(db: Session, count: int = 50) -> list[Technique]:
"""Set varied statuses on up to *count* techniques.""" """Set varied statuses on up to *count* techniques."""
techniques = db.query(Technique).limit(count).all() techniques = db.query(Technique).limit(count).all()
if not techniques: if not techniques:
@@ -192,7 +194,7 @@ def _seed_technique_statuses(db, count: int = 50) -> list[Technique]:
return techniques return techniques
def _seed_tests(db, users: list[User], techniques: list[Technique], count: int = 100) -> list[Test]: def _seed_tests(db: Session, users: list[User], techniques: list[Technique], count: int = 100) -> list[Test]:
"""Create *count* tests in various pipeline states.""" """Create *count* tests in various pipeline states."""
if not techniques: if not techniques:
logger.warning("No techniques available — skipping test seeding.") logger.warning("No techniques available — skipping test seeding.")
@@ -267,7 +269,7 @@ def _seed_tests(db, users: list[User], techniques: list[Technique], count: int =
return tests return tests
def _seed_evidences(db, tests: list[Test], users: list[User], count: int = 50) -> list[Evidence]: def _seed_evidences(db: Session, tests: list[Test], users: list[User], count: int = 50) -> list[Evidence]:
"""Create *count* dummy evidence records.""" """Create *count* dummy evidence records."""
if not tests: if not tests:
return [] return []
@@ -305,7 +307,7 @@ def _seed_evidences(db, tests: list[Test], users: list[User], count: int = 50) -
return evidences return evidences
def _seed_audit_logs(db, users: list[User], count: int = 20) -> None: def _seed_audit_logs(db: Session, users: list[User], count: int = 20) -> None:
"""Create *count* varied audit log entries.""" """Create *count* varied audit log entries."""
for i in range(count): for i in range(count):
user = random.choice(users) user = random.choice(users)
@@ -323,7 +325,7 @@ def _seed_audit_logs(db, users: list[User], count: int = 20) -> None:
logger.info("Created %d demo audit logs.", count) logger.info("Created %d demo audit logs.", count)
def _seed_notifications(db, users: list[User], count: int = 30) -> None: def _seed_notifications(db: Session, users: list[User], count: int = 30) -> None:
"""Create *count* notifications spread across demo users.""" """Create *count* notifications spread across demo users."""
for i in range(count): for i in range(count):
user = random.choice(users) user = random.choice(users)
@@ -344,7 +346,7 @@ def _seed_notifications(db, users: list[User], count: int = 30) -> None:
logger.info("Created %d demo notifications.", count) logger.info("Created %d demo notifications.", count)
def _seed_templates(db, techniques: list[Technique], count: int = 10) -> None: def _seed_templates(db: Session, techniques: list[Technique], count: int = 10) -> None:
"""Create *count* manual demo templates.""" """Create *count* manual demo templates."""
if not techniques: if not techniques:
return return
+2 -1
View File
@@ -4,6 +4,7 @@ from __future__ import annotations
import hashlib import hashlib
from datetime import datetime, timezone from datetime import datetime, timezone
from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -35,7 +36,7 @@ def verify_audit_integrity(entry: AuditLog) -> bool:
def log_action( def log_action(
db: Session, db: Session,
user_id, user_id: UUID | None,
action: str, action: str,
entity_type: str | None = None, entity_type: str | None = None,
entity_id: str | None = None, entity_id: str | None = None,
@@ -192,7 +192,7 @@ def update_campaign(
*, *,
updater_id: uuid.UUID, updater_id: uuid.UUID,
updater_role: str, updater_role: str,
**fields, **fields: object,
) -> dict: ) -> dict:
"""Update a campaign. Only allowed in draft or active state. """Update a campaign. Only allowed in draft or active state.
@@ -8,6 +8,7 @@ Uses the D3FEND public API:
import logging import logging
from typing import Any from typing import Any
from uuid import UUID
import httpx import httpx
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -26,7 +27,7 @@ D3FEND_TACTICS = ["Detect", "Harden", "Isolate", "Deceive", "Evict", "Model"]
# ── Import all D3FEND techniques ───────────────────────────────────── # ── Import all D3FEND techniques ─────────────────────────────────────
def _to_str(v: Any) -> str: def _to_str(v: Any) -> str: # noqa: ANN401
"""Coerce an RDF value (str, dict with @value, or list) to a plain string.""" """Coerce an RDF value (str, dict with @value, or list) to a plain string."""
if isinstance(v, dict): if isinstance(v, dict):
return v.get("@value", str(v)) return v.get("@value", str(v))
@@ -432,7 +433,7 @@ def sync(db: Session) -> dict:
return summary return summary
def get_defenses_for_technique(db: Session, technique_id) -> list[dict]: def get_defenses_for_technique(db: Session, technique_id: UUID) -> list[dict]:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" """Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
mappings = ( mappings = (
db.query(DefensiveTechniqueMapping) db.query(DefensiveTechniqueMapping)
@@ -10,6 +10,7 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -258,11 +259,11 @@ def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]:
def evaluate_rule( def evaluate_rule(
db: Session, db: Session,
*, *,
test_id: Any, test_id: UUID,
detection_rule_id: Any, detection_rule_id: UUID,
triggered: bool | None, triggered: bool | None,
notes: str | None, notes: str | None,
evaluator_id: Any, evaluator_id: UUID,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Save or update the evaluation result for a detection rule on a test. """Save or update the evaluation result for a detection rule on a test.
+7 -6
View File
@@ -10,9 +10,10 @@ no ``db.commit()``.
from __future__ import annotations from __future__ import annotations
import json import json
from collections.abc import Callable
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.orm import Session from sqlalchemy.orm import Query, Session
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
from app.models.campaign import Campaign, CampaignTest from app.models.campaign import Campaign, CampaignTest
@@ -92,11 +93,11 @@ def _build_layer_skeleton(
def _apply_filters( def _apply_filters(
query, query: Query, # type: ignore[type-arg]
model, model: type,
platforms: list[str] | None = None, platforms: list[str] | None = None,
tactics: list[str] | None = None, tactics: list[str] | None = None,
): ) -> Query: # type: ignore[type-arg]
"""Apply common platform and tactic filters to a technique query.""" """Apply common platform and tactic filters to a technique query."""
if platforms: if platforms:
platform_filters = [ platform_filters = [
@@ -470,7 +471,7 @@ class _LayerRegistry:
self._simple: dict[str, object] = {} self._simple: dict[str, object] = {}
self._with_id: dict[str, object] = {} self._with_id: dict[str, object] = {}
def register(self, name: str, builder, *, requires_id: bool = False) -> None: def register(self, name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None:
target = self._with_id if requires_id else self._simple target = self._with_id if requires_id else self._simple
target[name] = builder target[name] = builder
@@ -513,7 +514,7 @@ LAYER_REGISTRY.register("campaign", build_campaign_layer, requires_id=True)
SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types # snapshot of built-in types SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types # snapshot of built-in types
def register_layer(name: str, builder, *, requires_id: bool = False) -> None: def register_layer(name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None:
"""Public API to register a new heatmap layer type at import time.""" """Public API to register a new heatmap layer type at import time."""
LAYER_REGISTRY.register(name, builder, requires_id=requires_id) LAYER_REGISTRY.register(name, builder, requires_id=requires_id)
+2 -2
View File
@@ -2,7 +2,7 @@
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Any, Optional
from uuid import UUID from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
_jira_client = None _jira_client = None
def get_jira_client(): def get_jira_client() -> Any: # noqa: ANN401 # atlassian.Jira imported lazily from optional dep
"""Return a lazily-initialised Jira client, or raise if disabled.""" """Return a lazily-initialised Jira client, or raise if disabled."""
global _jira_client global _jira_client
if not settings.JIRA_ENABLED: if not settings.JIRA_ENABLED:
+2 -1
View File
@@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError from app.domain.errors import EntityNotFoundError
from app.models.notification import Notification from app.models.notification import Notification
from app.models.test import Test
from app.models.user import User from app.models.user import User
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -157,7 +158,7 @@ def cleanup_old_notifications(db: Session, days: int = 90) -> int:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def notify_test_state_change(db: Session, test, new_state: str) -> None: def notify_test_state_change(db: Session, test: Test, new_state: str) -> None:
"""Dispatch notifications based on a test's new state. """Dispatch notifications based on a test's new state.
Called by the workflow service after each state transition. Called by the workflow service after each state transition.
+6 -4
View File
@@ -10,12 +10,14 @@ stale data does not persist longer than ``CACHE_TTL`` seconds.
import time import time
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy.orm import Session
CACHE_TTL = 300 # 5 minutes CACHE_TTL = 300 # 5 minutes
_cache: dict[str, dict[str, Any]] = {} _cache: dict[str, dict[str, Any]] = {}
def get(key: str) -> Optional[Any]: def get(key: str) -> Optional[Any]: # noqa: ANN401 # generic cache returns whatever was stored
"""Return cached value if present and not expired, else None.""" """Return cached value if present and not expired, else None."""
entry = _cache.get(key) entry = _cache.get(key)
if entry is None: if entry is None:
@@ -26,7 +28,7 @@ def get(key: str) -> Optional[Any]:
return entry["data"] return entry["data"]
def put(key: str, data: Any) -> None: def put(key: str, data: Any) -> None: # noqa: ANN401 # generic cache accepts any serialisable value
"""Store *data* under *key* with the current timestamp.""" """Store *data* under *key* with the current timestamp."""
_cache[key] = {"data": data, "ts": time.time()} _cache[key] = {"data": data, "ts": time.time()}
@@ -42,7 +44,7 @@ def invalidate(key: Optional[str] = None) -> None:
# ── High-level helpers ──────────────────────────────────────────────── # ── High-level helpers ────────────────────────────────────────────────
def get_organization_score_cached(db): def get_organization_score_cached(db: Session) -> dict:
"""Cached wrapper around ``calculate_organization_score``.""" """Cached wrapper around ``calculate_organization_score``."""
from app.services.scoring_service import calculate_organization_score from app.services.scoring_service import calculate_organization_score
@@ -55,7 +57,7 @@ def get_organization_score_cached(db):
return result return result
def get_operational_metrics_cached(db): def get_operational_metrics_cached(db: Session) -> dict:
"""Cached wrapper around operational metrics (MTTD, MTTR, efficacy).""" """Cached wrapper around operational metrics (MTTD, MTTR, efficacy)."""
from app.services.operational_metrics_service import ( from app.services.operational_metrics_service import (
calculate_alert_fidelity, calculate_alert_fidelity,
+7 -5
View File
@@ -1,18 +1,20 @@
"""Tempo time-tracking integration service.""" """Tempo time-tracking integration service."""
import logging import logging
from typing import Optional from typing import Any, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import settings from app.config import settings
from app.domain.exceptions import InvalidOperationError from app.domain.exceptions import InvalidOperationError
from app.models.jira_link import JiraLink, JiraLinkEntityType from app.models.jira_link import JiraLink, JiraLinkEntityType
from app.models.test import Test
from app.models.user import User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_tempo_client(): def get_tempo_client() -> Any: # noqa: ANN401 # tempoapiclient.Tempo imported lazily from optional dep
"""Return a Tempo API client, or raise if disabled.""" """Return a Tempo API client, or raise if disabled."""
if not settings.TEMPO_ENABLED: if not settings.TEMPO_ENABLED:
raise InvalidOperationError("Tempo integration is not enabled") raise InvalidOperationError("Tempo integration is not enabled")
@@ -52,8 +54,8 @@ def log_worklog(
def auto_log_test_worklog( def auto_log_test_worklog(
db: Session, db: Session,
test, test: Test,
user, user: User,
activity_type: str, activity_type: str,
) -> Optional[dict]: ) -> Optional[dict]:
"""If the test has a Jira link, log time to Tempo automatically. """If the test has a Jira link, log time to Tempo automatically.
@@ -97,7 +99,7 @@ def auto_log_test_worklog(
return None return None
def _calculate_duration(test, activity_type: str) -> int: def _calculate_duration(test: Test, activity_type: str) -> int:
"""Calculate real duration in seconds from the phase timing fields. """Calculate real duration in seconds from the phase timing fields.
Uses the actual start/end timestamps recorded by the workflow buttons, Uses the actual start/end timestamps recorded by the workflow buttons,
+4 -4
View File
@@ -63,7 +63,7 @@ def create_test(
*, *,
technique_id: uuid.UUID, technique_id: uuid.UUID,
creator_id: uuid.UUID, creator_id: uuid.UUID,
**fields: Any, **fields: object,
) -> Test: ) -> Test:
"""Create a new test linked to an existing technique. """Create a new test linked to an existing technique.
@@ -176,7 +176,7 @@ def update_test(
*, *,
updater_id: uuid.UUID, updater_id: uuid.UUID,
updater_role: str, updater_role: str,
**fields: Any, **fields: object,
) -> Test: ) -> Test:
"""Update general test fields (draft or rejected only). """Update general test fields (draft or rejected only).
@@ -204,7 +204,7 @@ def update_test(
return test return test
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test: def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
"""Update Red Team fields (draft or red_executing only). """Update Red Team fields (draft or red_executing only).
Raises BusinessRuleViolation if state not in (draft, red_executing). Raises BusinessRuleViolation if state not in (draft, red_executing).
@@ -226,7 +226,7 @@ def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
return test return test
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test: def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
"""Update Blue Team fields (blue_evaluating only). """Update Blue Team fields (blue_evaluating only).
Raises BusinessRuleViolation if state is not blue_evaluating. Raises BusinessRuleViolation if state is not blue_evaluating.
@@ -13,6 +13,7 @@ session via the Unit of Work pattern.
""" """
import logging import logging
import uuid
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -530,7 +531,7 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test |
return retest return retest
def get_retest_chain(db: Session, test_id) -> list[Test]: def get_retest_chain(db: Session, test_id: uuid.UUID) -> list[Test]:
"""Return the full chain of retests for a given test. """Return the full chain of retests for a given test.
Includes the original test and all subsequent retests, ordered Includes the original test and all subsequent retests, ordered
+3 -1
View File
@@ -7,11 +7,13 @@ line-length = 120
# F — pyflakes (unused imports, undefined names) # F — pyflakes (unused imports, undefined names)
# I — isort (import ordering per PEP8 convention) # I — isort (import ordering per PEP8 convention)
# N — pep8-naming (class/function/variable naming conventions) # N — pep8-naming (class/function/variable naming conventions)
select = ["E", "W", "F", "I", "N"] # ANN — flake8-annotations (type hint enforcement)
select = ["E", "W", "F", "I", "N", "ANN"]
ignore = [ ignore = [
# SQLAlchemy filter syntax requires `== True` / `== False` comparisons # SQLAlchemy filter syntax requires `== True` / `== False` comparisons
"E712", "E712",
# ANN101/ANN102 (self/cls type annotations) removed from ruff — not needed
] ]
[lint.per-file-ignores] [lint.per-file-ignores]