refactor(types): add comprehensive type annotations across backend Python codebase

Enable ANN rules in ruff.toml (flake8-annotations) and resolve all 221 violations:

ANN201/ANN202 — return types on 168 public/private functions:
- All 28 FastAPI routers: endpoints annotated with dict/list/specific schema/
  StreamingResponse/FileResponse/JSONResponse as appropriate
- main.py: lifespan→AsyncGenerator[None,None], exception handlers→JSONResponse
- database.py: get_db→Generator[Session,None,None], proxy methods→correct types
- middleware/request_context.py: dispatch→Response with Callable call_next type

ANN001/ANN002/ANN003 — 32 missing argument types:
- seed_demo.py: all db parameters typed as Session
- domain/unit_of_work.py: __aexit__ exc_type/exc_val/exc_tb typed with TracebackType
- services: audit_service user_id→UUID|None, heatmap_service query/model/builder,
  notification_service test→Test, tempo_service test→Test/user→User,
  test_workflow_service test_id→UUID, campaign_crud **fields→object,
  test_crud **fields→object (4 sites)

ANN401 — 16 Any usages resolved:
- Domain entities (campaign/technique/threat_actor/test_entity): replaced Any with
  actual ORM types via TYPE_CHECKING guards to avoid circular imports
- detection_rule_service: test_id/detection_rule_id/evaluator_id→UUID
- score_cache: kept Any with # noqa: ANN401 (genuinely generic cache)
- jira_service/tempo_service: kept Any with # noqa: ANN401 (lazy optional deps)
- d3fend_import_service: _to_str(v: Any) kept with # noqa: ANN401

ANN204/ANN205/ANN206 — special/static/class methods:
- database.py proxy __call__/__getattr__: *args: object/**kwargs: object
- schemas/test.py model_validate: obj→object, **kwargs→object
- sa_technique_repository._int_type→type

All 439 unit tests pass. ruff check app/ → All checks passed!
This commit is contained in:
kitos
2026-06-09 17:04:51 +02:00
parent 8f98bdd273
commit 9ff0f04ba3
51 changed files with 267 additions and 223 deletions
+10 -7
View File
@@ -1,5 +1,8 @@
from collections.abc import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, declarative_base, sessionmaker
Base = declarative_base()
@@ -10,7 +13,7 @@ _engine = None
_SessionLocal = None
def _get_engine():
def _get_engine() -> Engine:
global _engine
if _engine is None:
from app.config import settings
@@ -28,7 +31,7 @@ def _get_engine():
return _engine
def _get_session_factory():
def _get_session_factory() -> sessionmaker:
global _SessionLocal
if _SessionLocal is None:
_SessionLocal = sessionmaker(
@@ -41,10 +44,10 @@ class _LazySessionLocal:
"""Proxy so ``SessionLocal()`` keeps working as before but the real
sessionmaker is only created on first call."""
def __call__(self, *args, **kwargs):
def __call__(self, *args: object, **kwargs: object) -> Session:
return _get_session_factory()(*args, **kwargs)
def __getattr__(self, name):
def __getattr__(self, name: str) -> object:
return getattr(_get_session_factory(), name)
@@ -53,14 +56,14 @@ SessionLocal = _LazySessionLocal()
class _EngineProxy:
"""Thin proxy so ``from app.database import engine`` still works."""
def __getattr__(self, name):
def __getattr__(self, name: str) -> object:
return getattr(_get_engine(), name)
engine = _EngineProxy() # type: ignore[assignment]
def get_db():
def get_db() -> Generator[Session, None, None]:
db = SessionLocal()
try:
yield db
+3 -2
View File
@@ -8,6 +8,7 @@ Provides:
(admins always pass).
"""
from collections.abc import Callable
from typing import Optional
from fastapi import Cookie, Depends, HTTPException, status
@@ -112,7 +113,7 @@ async def require_password_changed(
return current_user
def require_role(required_role: str):
def require_role(required_role: str) -> Callable[..., object]:
"""Return a FastAPI dependency that enforces *required_role*.
The dependency allows the request to proceed when
@@ -133,7 +134,7 @@ def require_role(required_role: str):
return role_checker
def require_any_role(*roles: str):
def require_any_role(*roles: str) -> Callable[..., object]:
"""Return a FastAPI dependency that enforces **any** of the given *roles*.
Admins always pass. Usage example::
+5 -2
View File
@@ -8,10 +8,13 @@ from __future__ import annotations
import enum
import uuid
from dataclasses import dataclass, field
from typing import Any
from typing import TYPE_CHECKING
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
if TYPE_CHECKING:
from app.models.campaign import Campaign as CampaignORM
class CampaignStatus(str, enum.Enum):
draft = "draft"
@@ -86,7 +89,7 @@ class CampaignEntity:
)
@classmethod
def from_orm(cls, orm: Any) -> CampaignEntity:
def from_orm(cls, orm: CampaignORM) -> CampaignEntity:
"""Build a CampaignEntity from a SQLAlchemy Campaign model."""
test_count = len(getattr(orm, "campaign_tests", None) or [])
return cls(
+6 -3
View File
@@ -17,11 +17,14 @@ from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
from app.domain.enums import TechniqueStatus, TestResult, TestState
from app.domain.value_objects.mitre_id import MitreId
if TYPE_CHECKING:
from app.models.technique import Technique as TechniqueORM
@dataclass(frozen=True)
class _TestSnapshot:
@@ -76,7 +79,7 @@ class TechniqueEntity:
)
@classmethod
def from_orm(cls, model: Any) -> TechniqueEntity:
def from_orm(cls, model: TechniqueORM) -> TechniqueEntity:
"""Build a TechniqueEntity from a SQLAlchemy Technique model."""
raw_status = model.status_global
if raw_status is None:
@@ -101,7 +104,7 @@ class TechniqueEntity:
mitre_last_modified=getattr(model, "mitre_last_modified", None),
)
def apply_to(self, model: Any) -> None:
def apply_to(self, model: TechniqueORM) -> None:
"""Copy mutable fields back onto the ORM model."""
model.status_global = self.status_global
model.review_required = self.review_required
+5 -2
View File
@@ -7,7 +7,10 @@ from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.models.threat_actor import ThreatActor as ThreatActorORM
@dataclass
@@ -63,7 +66,7 @@ class ThreatActorEntity:
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
@classmethod
def from_orm(cls, orm: Any) -> ThreatActorEntity:
def from_orm(cls, orm: ThreatActorORM) -> ThreatActorEntity:
techs: list[ThreatActorTechniqueRef] = []
for tat in getattr(orm, "techniques", None) or []:
technique = getattr(tat, "technique", None)
+6 -3
View File
@@ -26,7 +26,7 @@ import enum
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING, Any
from app.domain.errors import (
BusinessRuleViolation,
@@ -34,6 +34,9 @@ from app.domain.errors import (
InvalidStateTransition,
)
if TYPE_CHECKING:
from app.models.test import Test as TestORM
# ── Value objects ────────────────────────────────────────────────────
@@ -103,7 +106,7 @@ class TestEntity:
# -- Factory --------------------------------------------------------
@classmethod
def from_orm(cls, model: Any) -> TestEntity:
def from_orm(cls, model: TestORM) -> TestEntity:
"""Build a TestEntity from a SQLAlchemy ``Test`` model instance."""
raw_state = model.state
state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state)
@@ -126,7 +129,7 @@ class TestEntity:
blue_paused_seconds=model.blue_paused_seconds or 0,
)
def apply_to(self, model: Any) -> None:
def apply_to(self, model: TestORM) -> None:
"""Copy the entity's mutable fields back onto the ORM model."""
model.state = self.state
model.red_validation_status = self.red_validation_status
+8 -1
View File
@@ -22,6 +22,8 @@ Services should **never** call ``db.commit()``; they use ``db.add()`` /
from __future__ import annotations
from types import TracebackType
from sqlalchemy.orm import Session
@@ -36,7 +38,12 @@ class UnitOfWork:
def __enter__(self) -> "UnitOfWork":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if exc_type is not None:
self.rollback()
@@ -197,7 +197,7 @@ class SATechniqueRepository:
# -- Internal ----------------------------------------------------------
@staticmethod
def _int_type():
def _int_type() -> type:
"""Return an Integer type for CAST expressions (SQLite-compatible)."""
from sqlalchemy import Integer
return Integer
+6 -5
View File
@@ -1,5 +1,6 @@
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status
@@ -53,7 +54,7 @@ setup_logging()
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Startup / shutdown logic."""
ensure_bucket_exists()
start_scheduler()
@@ -124,7 +125,7 @@ app.include_router(osint_router.router, prefix="/api/v1")
@app.get("/health", include_in_schema=False)
def health():
def health() -> dict[str, str]:
"""Minimal health check — returns only an HTTP 200 with no service metadata.
Access is restricted to internal networks at the Nginx level
@@ -149,7 +150,7 @@ def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
"""Handle validation errors with consistent format."""
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -162,7 +163,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
@app.exception_handler(SQLAlchemyError)
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError):
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError) -> JSONResponse:
"""Handle database errors."""
logging.error(f"Database error: {exc}")
return JSONResponse(
@@ -175,7 +176,7 @@ async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError):
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Handle all unhandled exceptions."""
logging.error(f"Unhandled exception: {exc}")
return JSONResponse(
+7 -1
View File
@@ -1,9 +1,11 @@
"""Request context middleware — captures client IP and User-Agent per request."""
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
request_ip: ContextVar[str] = ContextVar("request_ip", default="")
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
@@ -20,7 +22,11 @@ def resolve_client_ip(request: Request) -> str:
class RequestContextMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
async def dispatch(
self,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
request_ip.set(resolve_client_ip(request))
request_user_agent.set(request.headers.get("User-Agent", ""))
return await call_next(request)
+4 -4
View File
@@ -15,7 +15,7 @@ router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
def coverage_by_tactic(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
return advanced_metrics_service.get_coverage_by_tactic(db)
@@ -24,7 +24,7 @@ def coverage_by_tactic(
def never_tested_techniques(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Techniques that have never had a test created."""
return advanced_metrics_service.get_never_tested_techniques(db)
@@ -33,7 +33,7 @@ def never_tested_techniques(
def avg_validation_time(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> dict:
"""Average time from test creation to validation, computed from audit logs.
Returns overall average and per-phase averages where data is available.
@@ -45,6 +45,6 @@ def avg_validation_time(
def detection_rate_trend(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Monthly detection rate trend for the last 12 months."""
return advanced_metrics_service.get_detection_rate_trend(db)
+4 -4
View File
@@ -19,7 +19,7 @@ router = APIRouter(prefix="/analytics", tags=["analytics"])
def analytics_coverage(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Coverage per technique — flat format for BI dashboards."""
return analytics_service.get_coverage_analytics(db)
@@ -30,7 +30,7 @@ def analytics_tests(
date_to: str = Query(None, description="ISO date filter (<=)"),
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""All tests with timestamps — flat format for BI dashboards."""
return analytics_service.get_tests_analytics(
db, date_from=date_from, date_to=date_to
@@ -41,7 +41,7 @@ def analytics_tests(
def analytics_trends(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Historical coverage snapshots for trend visualization."""
return analytics_service.get_trends_analytics(db)
@@ -50,6 +50,6 @@ def analytics_trends(
def analytics_operators(
db: Session = Depends(get_db),
user: User = Depends(require_role("admin")),
):
) -> list:
"""Per-operator metrics — for workload management dashboards."""
return analytics_service.get_operators_analytics(db)
+3 -3
View File
@@ -30,7 +30,7 @@ def list_audit_logs(
limit: int = Query(50, ge=1, le=100, description="Max records to return"),
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> AuditLogPage:
"""Return paginated audit logs with optional filters.
**Requires admin role.**
@@ -57,7 +57,7 @@ def list_audit_logs(
def list_actions(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> list[str]:
"""Return a list of distinct action types in the audit log.
**Requires admin role.**
@@ -69,7 +69,7 @@ def list_actions(
def list_entity_types(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> list[str]:
"""Return a list of distinct entity types in the audit log.
**Requires admin role.**
+4 -4
View File
@@ -46,7 +46,7 @@ def login(
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db),
):
) -> TokenResponse:
"""Authenticate a user and return a JWT access token.
Rate-limited to **5 attempts per minute per IP**. Failed and successful
@@ -110,7 +110,7 @@ def logout(
request: Request,
response: Response,
aegis_token: str | None = Cookie(None),
):
) -> dict:
"""Clear the authentication cookie and revoke the current token."""
bearer = (
request.headers.get("Authorization")
@@ -148,7 +148,7 @@ def logout(
@router.get("/me", response_model=UserOut)
def read_current_user(current_user: User = Depends(get_current_user)):
def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut:
"""Return the profile of the currently authenticated user."""
return current_user
@@ -158,7 +158,7 @@ def change_password(
body: PasswordChange,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Change the current user's password."""
auth_change_password(
db,
+12 -12
View File
@@ -107,7 +107,7 @@ def list_campaigns(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List campaigns with optional filters and pagination."""
return crud_list(
db,
@@ -129,7 +129,7 @@ def create_campaign(
payload: CampaignCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Create a new campaign."""
with UnitOfWork(db) as uow:
result = crud_create(
@@ -165,7 +165,7 @@ def get_campaign(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get detailed campaign info including tests and progress."""
return crud_get_detail(db, campaign_id)
@@ -180,7 +180,7 @@ def update_campaign(
payload: CampaignUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Update a campaign. Only allowed in draft or active state."""
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
@@ -214,7 +214,7 @@ def add_test_to_campaign(
payload: AddTestPayload,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Add a test to a campaign with optional ordering and dependency."""
with UnitOfWork(db) as uow:
result = crud_add_test(
@@ -239,7 +239,7 @@ def remove_test_from_campaign(
campaign_test_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Remove a test from a campaign."""
with UnitOfWork(db) as uow:
crud_remove_test(db, campaign_id, campaign_test_id)
@@ -256,7 +256,7 @@ def activate_campaign(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Activate a campaign, moving it from draft to active."""
with UnitOfWork(db) as uow:
campaign = crud_activate(db, campaign_id)
@@ -292,7 +292,7 @@ def complete_campaign(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "admin")),
):
) -> dict:
"""Mark a campaign as completed."""
with UnitOfWork(db) as uow:
campaign = crud_complete(db, campaign_id)
@@ -319,7 +319,7 @@ def get_campaign_progress_endpoint(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get progress statistics for a campaign."""
return crud_get_progress(db, campaign_id)
@@ -333,7 +333,7 @@ def generate_campaign_from_actor(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Auto-generate a campaign from a threat actor's uncovered techniques.
Creates tests from the best available templates and orders them
@@ -369,7 +369,7 @@ def schedule_campaign(
payload: SchedulePayload,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Configure or update the recurrence schedule for a campaign.
Only the campaign creator or admin can change scheduling.
@@ -411,6 +411,6 @@ def get_campaign_history(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List all child campaigns (execution history) of a recurring campaign."""
return crud_get_history(db, campaign_id)
+7 -7
View File
@@ -34,7 +34,7 @@ router = APIRouter(prefix="/compliance", tags=["compliance"])
def list_frameworks_endpoint(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List all available compliance frameworks."""
return list_frameworks(db)
@@ -47,7 +47,7 @@ def framework_status(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get compliance status for each control in a framework."""
return get_framework_status(db, framework_id)
@@ -60,7 +60,7 @@ def framework_report(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get the full compliance report (same as status but marked as report)."""
return get_framework_status(db, framework_id)
@@ -73,7 +73,7 @@ def framework_report_csv(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> StreamingResponse:
"""Export compliance report as CSV."""
csv_bytes, filename = build_framework_report_csv(db, framework_id)
return StreamingResponse(
@@ -93,7 +93,7 @@ def framework_gaps(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get controls with techniques that are not adequately covered."""
return get_framework_gaps(db, framework_id)
@@ -105,7 +105,7 @@ def framework_gaps(
def import_nist(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Import NIST 800-53 Rev 5 mappings (admin only)."""
result = import_nist_800_53_mappings(db)
return result
@@ -115,7 +115,7 @@ def import_nist(
def import_cis(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Import CIS Controls v8 mappings (admin only)."""
result = import_cis_controls_v8_mappings(db)
return result
+4 -4
View File
@@ -38,7 +38,7 @@ def list_defensive_techniques(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List all D3FEND defensive techniques with optional filters."""
return list_defensive_techniques_svc(
db, tactic=tactic, search=search, offset=offset, limit=limit
@@ -53,7 +53,7 @@ def list_defensive_techniques(
def list_d3fend_tactics_endpoint(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return a list of all D3FEND tactics with counts."""
return list_d3fend_tactics(db)
@@ -67,7 +67,7 @@ def get_defenses_for_attack_technique_endpoint(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
return get_defenses_for_attack_technique(db, mitre_id)
@@ -80,7 +80,7 @@ def get_defenses_for_attack_technique_endpoint(
def trigger_d3fend_import(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Import D3FEND techniques and ATT&CK mappings. Admin only."""
tech_result = import_d3fend_techniques(db)
mapping_result = import_d3fend_mappings(db)
+5 -5
View File
@@ -47,7 +47,7 @@ router = APIRouter(prefix="/data-sources", tags=["data-sources"])
def list_data_sources(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> list:
"""List all registered data sources.
**Requires** the ``admin`` role.
@@ -61,7 +61,7 @@ def update_data_source(
body: DataSourceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Update a data source (enable/disable, change config).
**Requires** the ``admin`` role.
@@ -87,7 +87,7 @@ def sync_data_source(
source_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Trigger sync/import for a specific data source.
**Requires** the ``admin`` role.
@@ -99,7 +99,7 @@ def sync_data_source(
def sync_all_data_sources(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Trigger sync for all enabled data sources (sequentially).
**Requires** the ``admin`` role.
@@ -125,7 +125,7 @@ def get_data_source_stats(
source_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Get detailed statistics for a specific data source.
**Requires** the ``admin`` role.
+5 -5
View File
@@ -51,7 +51,7 @@ def list_detection_rules(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List detection rules with optional filters and pagination."""
return list_rules(
db,
@@ -72,7 +72,7 @@ def get_detection_rules_for_template(
template_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Get detection rules associated with a test template."""
return get_rules_for_template(db, template_id)
@@ -84,7 +84,7 @@ def get_detection_rules_for_template(
def auto_associate_detection_rules(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Auto-associate test templates with detection rules by MITRE technique ID.
For each active template, find all active detection rules for the same
@@ -102,7 +102,7 @@ def get_detection_rules_for_test(
test_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Get detection rules relevant to a test, along with their evaluation results.
Finds rules by matching the test's technique_id to detection rules,
@@ -119,7 +119,7 @@ def evaluate_detection_rule(
payload: DetectionRuleEvaluate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
) -> dict:
"""Save or update the evaluation result for a detection rule on a test."""
return evaluate_rule(
db,
+4 -4
View File
@@ -88,7 +88,7 @@ async def upload_evidence(
notes: Optional[str] = Form(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> EvidenceOut:
"""Upload a file as evidence for the given test.
The ``team`` field (sent as form data) determines whether this is
@@ -154,7 +154,7 @@ def list_evidence(
team: Optional[str] = Query(None, description="Filter by team: red or blue"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[EvidenceOut]:
"""List all evidences for a test, optionally filtered by team."""
get_test_or_raise(db, test_id)
evidences = list_evidence_for_test(db, test_id, team=team)
@@ -171,7 +171,7 @@ def get_evidence(
evidence_id: _uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> EvidenceOut:
"""Return evidence metadata together with a presigned download URL."""
evidence = get_evidence_or_raise(db, evidence_id)
return _evidence_to_out(evidence)
@@ -187,7 +187,7 @@ def delete_evidence(
evidence_id: _uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Delete an evidence record.
Only allowed in editable states:
+5 -5
View File
@@ -28,7 +28,7 @@ def heatmap_coverage(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Coverage layer — score based on status_global of each technique."""
return heatmap_service.build_coverage_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -43,7 +43,7 @@ def heatmap_threat_actor(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Threat actor layer — techniques used by an actor with coverage color."""
return heatmap_service.build_threat_actor_layer(
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -57,7 +57,7 @@ def heatmap_detection_rules(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Detection rules layer — score based on ratio of rules available vs total."""
return heatmap_service.build_detection_rules_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -72,7 +72,7 @@ def heatmap_campaign(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Campaign layer — only techniques in the campaign, colored by test state."""
return heatmap_service.build_campaign_layer(
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -88,7 +88,7 @@ def export_navigator(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> StreamingResponse:
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
data = heatmap_service.build_navigator_export(
db, layer, layer_id=layer_id,
+6 -6
View File
@@ -29,7 +29,7 @@ def search_issues(
q: str = Query(..., min_length=2),
max_results: int = Query(10, le=50),
user: User = Depends(get_current_user),
):
) -> list[JiraIssueResult]:
"""Search Jira issues by JQL or free text."""
return jira_service.search_jira_issues(q, max_results)
@@ -39,7 +39,7 @@ def create_link(
body: JiraLinkCreate,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> JiraLinkOut:
"""Associate an Aegis entity with a Jira issue."""
with UnitOfWork(db) as uow:
link = jira_service.create_link(
@@ -74,7 +74,7 @@ def list_links(
entity_id: Optional[UUID] = None,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list[JiraLinkOut]:
"""List Jira links, optionally filtered by entity."""
return jira_service.list_links(
db,
@@ -88,7 +88,7 @@ def sync_link(
link_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(require_role("admin")),
):
) -> dict:
"""Force bidirectional sync for a specific Jira link."""
with UnitOfWork(db) as uow:
link = jira_service.get_link_or_raise(db, link_id)
@@ -102,7 +102,7 @@ def delete_link(
link_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> None:
"""Remove a Jira link."""
with UnitOfWork(db) as uow:
link = jira_service.delete_link(db, link_id)
@@ -123,7 +123,7 @@ def create_issue_from_entity(
entity_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> dict:
"""Auto-create a Jira issue from an Aegis entity and link them."""
with UnitOfWork(db) as uow:
result = jira_service.create_issue_and_link(
+6 -6
View File
@@ -42,7 +42,7 @@ router = APIRouter(prefix="/metrics", tags=["metrics"])
def coverage_summary(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> CoverageSummary:
"""Return a global coverage summary across all techniques."""
return get_coverage_summary(db)
@@ -56,7 +56,7 @@ def coverage_summary(
def coverage_by_tactic(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[TacticCoverage]:
"""Return coverage breakdown grouped by tactic."""
return get_coverage_by_tactic(db)
@@ -70,7 +70,7 @@ def coverage_by_tactic(
def test_pipeline(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> TestPipelineCounts:
"""Return how many tests are in each pipeline state."""
return get_test_pipeline_counts(db)
@@ -84,7 +84,7 @@ def test_pipeline(
def team_activity(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[TeamActivity]:
"""Return activity summary for Red and Blue teams."""
return get_team_activity(db)
@@ -98,7 +98,7 @@ def team_activity(
def validation_rate(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[ValidationRate]:
"""Return approval and rejection rates for Red Lead and Blue Lead."""
return get_validation_rate(db)
@@ -112,6 +112,6 @@ def validation_rate(
def recent_tests(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[RecentTestItem]:
"""Return the 10 most recently created tests."""
return get_recent_tests(db, limit=10)
+4 -4
View File
@@ -39,7 +39,7 @@ def list_notifications_endpoint(
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[NotificationOut]:
"""Return paginated notifications for the current user, newest first."""
return list_notifications(db, current_user.id, offset=offset, limit=limit)
@@ -53,7 +53,7 @@ def list_notifications_endpoint(
def unread_count(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> UnreadCountOut:
"""Return the number of unread notifications for the current user."""
count = get_unread_count(db, current_user.id)
return UnreadCountOut(unread_count=count)
@@ -69,7 +69,7 @@ def read_notification(
notification_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> NotificationOut:
"""Mark a single notification as read."""
with UnitOfWork(db) as uow:
notif = mark_as_read(db, notification_id, current_user.id)
@@ -86,7 +86,7 @@ def read_notification(
def read_all_notifications(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Mark all notifications for the current user as read."""
with UnitOfWork(db) as uow:
count = mark_all_as_read(db, current_user.id)
+3 -3
View File
@@ -25,7 +25,7 @@ router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
def operational_metrics(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min."""
from app.services.score_cache import get_operational_metrics_cached
@@ -40,7 +40,7 @@ def operational_trend(
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get weekly trend data for operational metrics."""
return get_operational_trend(db, period)
@@ -52,6 +52,6 @@ def operational_trend(
def metrics_by_team(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get metrics broken down by Red Team vs Blue Team."""
return get_metrics_by_team(db)
+5 -5
View File
@@ -57,7 +57,7 @@ def list_osint_items(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""List OSINT items with optional filters."""
return service_list_osint_items(
db,
@@ -73,7 +73,7 @@ def list_osint_items(
def osint_summary(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> dict:
"""Summary statistics for OSINT items."""
return get_osint_summary(db)
@@ -83,7 +83,7 @@ def review_osint_item(
item_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> dict:
"""Mark an OSINT item as reviewed."""
with UnitOfWork(db) as uow:
item = mark_osint_reviewed(db, str(item_id))
@@ -101,7 +101,7 @@ def trigger_technique_enrichment(
technique_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Manually trigger OSINT enrichment for a single technique."""
technique = get_technique_or_raise(db, technique_id)
count = enrich_technique_with_cves(db, technique)
@@ -119,7 +119,7 @@ def get_technique_osint(
reviewed: bool | None = Query(None),
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Get all OSINT items for a specific technique."""
items = get_osint_items_for_technique(
db,
+5 -5
View File
@@ -29,7 +29,7 @@ def generate_purple_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
):
) -> FileResponse:
"""Generate a Purple Team campaign assessment report."""
filepath = report_generation_service.generate_purple_campaign_report(
db, str(campaign_id), output_format=format,
@@ -48,7 +48,7 @@ def generate_coverage_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
):
) -> FileResponse:
"""Generate an organization-wide MITRE ATT&CK coverage report."""
filepath = report_generation_service.generate_coverage_report(
db, output_format=format,
@@ -67,7 +67,7 @@ def generate_executive_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
):
) -> FileResponse:
"""Generate an executive security summary report."""
filepath = report_generation_service.generate_executive_summary(
db, output_format=format,
@@ -86,7 +86,7 @@ def generate_quarterly_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
):
) -> FileResponse:
"""Generate a quarterly security summary report."""
filepath = report_generation_service.generate_quarterly_summary(
db, output_format=format,
@@ -106,7 +106,7 @@ def generate_technique_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> FileResponse:
"""Generate a detailed report for one MITRE technique."""
filepath = report_generation_service.generate_technique_detail_report(
db, str(technique_id), output_format=format,
+4 -4
View File
@@ -38,7 +38,7 @@ def coverage_summary(
platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Full coverage report as JSON — technique-by-technique with test counts."""
return build_coverage_summary(db, tactic=tactic, platform=platform)
@@ -49,7 +49,7 @@ def coverage_csv(
platform: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> StreamingResponse:
"""Export coverage as a downloadable CSV."""
rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
@@ -74,7 +74,7 @@ def test_results(
date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Report of test results with optional filters."""
return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
@@ -84,6 +84,6 @@ def remediation_status(
status: Optional[str] = Query(None, description="Filter by remediation status"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Report of remediation status across all tests."""
return build_remediation_status_report(db, status=status)
+7 -7
View File
@@ -35,7 +35,7 @@ def score_technique(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get detailed score with breakdown for a specific technique."""
return score_technique_by_mitre_id(db, mitre_id)
@@ -48,7 +48,7 @@ def score_tactic(
tactic: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get average score for a tactic."""
return calculate_tactic_score(tactic, db)
@@ -61,7 +61,7 @@ def score_threat_actor(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get coverage score against a specific threat actor."""
return score_actor_by_id(db, actor_id)
@@ -73,7 +73,7 @@ def score_threat_actor(
def score_organization(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get the overall organization security score (cached for 5 min)."""
from app.services.score_cache import get_organization_score_cached
@@ -88,7 +88,7 @@ def score_history(
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get historical score data points (weekly)."""
return get_score_history(db, period)
@@ -100,7 +100,7 @@ def score_history(
def get_scoring_config(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Get current scoring weights (admin only)."""
return get_weights_dict(db)
@@ -123,7 +123,7 @@ def update_scoring_config(
payload: ScoringConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Update scoring weights (admin only).
Weights are persisted in the database and survive restarts.
+6 -6
View File
@@ -52,7 +52,7 @@ def list_snapshots(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List coverage snapshots ordered by creation date (newest first)."""
return list_snapshots_svc(db, offset=offset, limit=limit)
@@ -66,7 +66,7 @@ def create_snapshot_endpoint(
payload: SnapshotCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
):
) -> dict:
"""Create a manual coverage snapshot with an optional name."""
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
@@ -94,7 +94,7 @@ def coverage_evolution(
months: int = Query(12, ge=1, le=36),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return coverage snapshots for trend charts (last *months* months)."""
return get_coverage_evolution(db, months=months)
@@ -109,7 +109,7 @@ def compare_snapshots_endpoint(
b: str = Query(..., description="Snapshot B ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Compare two snapshots showing improved, worsened, and unchanged techniques."""
try:
a_id = uuid.UUID(a)
@@ -129,7 +129,7 @@ def get_snapshot(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get detailed snapshot information including per-technique states."""
return get_snapshot_detail(db, snapshot_id)
@@ -143,7 +143,7 @@ def delete_snapshot_endpoint(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Delete a snapshot (admin only)."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
+4 -4
View File
@@ -30,7 +30,7 @@ def trigger_mitre_sync(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Manually trigger a MITRE ATT&CK synchronisation.
**Requires** the ``admin`` role.
@@ -50,7 +50,7 @@ def trigger_mitre_sync(
def trigger_intel_scan(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Manually trigger a threat-intelligence scan.
**Requires** the ``admin`` role.
@@ -71,7 +71,7 @@ def trigger_atomic_import(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Trigger an import of Atomic Red Team tests as TestTemplates.
**Requires** the ``admin`` role.
@@ -101,7 +101,7 @@ def trigger_atomic_import(
@router.get("/scheduler-status")
def scheduler_status(
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Return the current state of the background scheduler.
**Requires** the ``admin`` role.
+5 -5
View File
@@ -45,7 +45,7 @@ def list_techniques(
review_required: bool | None = Query(None, description="Filter by review flag"),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return a lightweight list of techniques, optionally filtered."""
return repo.list_all(
tactic=tactic,
@@ -64,7 +64,7 @@ def get_technique(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Return full details for a single technique, including its tests and D3FEND defenses."""
return get_technique_detail(db, mitre_id)
@@ -84,7 +84,7 @@ def create_technique(
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_role("admin")),
):
) -> TechniqueOut:
"""Create a new technique manually."""
if repo.exists_by_mitre_id(payload.mitre_id):
raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id)
@@ -124,7 +124,7 @@ def update_technique(
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_role("admin")),
):
) -> TechniqueOut:
"""Update one or more fields of an existing technique."""
entity = repo.find_by_mitre_id(mitre_id)
if entity is None:
@@ -160,7 +160,7 @@ def review_technique(
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TechniqueOut:
"""Mark a technique as reviewed.
Sets ``review_required`` to *False* and records the current timestamp
+9 -9
View File
@@ -78,7 +78,7 @@ def _list_templates_handler(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return a paginated, filterable list of test templates."""
return list_templates(
db,
@@ -102,7 +102,7 @@ def _list_templates_handler(
def template_stats(
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Return catalog statistics: active, by_source, by_platform."""
return get_template_stats(db)
@@ -117,7 +117,7 @@ def bulk_activate_templates(
activate: bool = Query(True, description="True to activate all, False to deactivate all"),
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Set all templates to active or inactive."""
count = bulk_activate(db, activate=activate)
with UnitOfWork(db) as uow:
@@ -148,7 +148,7 @@ def _templates_by_technique_handler(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return all active templates mapped to a specific MITRE technique."""
return templates_by_technique(db, mitre_id)
@@ -163,7 +163,7 @@ def get_template(
template_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> TestTemplateOut:
"""Return full details for a single test template."""
return get_template_or_raise(db, template_id)
@@ -182,7 +182,7 @@ def create_template(
payload: TestTemplateCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestTemplateOut:
"""Create a custom test template."""
template = create_template_svc(db, **payload.model_dump())
with UnitOfWork(db) as uow:
@@ -215,7 +215,7 @@ def update_template(
payload: TestTemplateCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestTemplateOut:
"""Update fields of an existing test template."""
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
with UnitOfWork(db) as uow:
@@ -243,7 +243,7 @@ def toggle_template_active(
template_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestTemplateOut:
"""Toggle a template between active and inactive (is_active = not is_active)."""
template = toggle_template_active_svc(db, template_id)
with UnitOfWork(db) as uow:
@@ -271,7 +271,7 @@ def delete_template(
template_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Soft-delete a test template by setting ``is_active=False``."""
template = get_template_or_raise(db, template_id)
soft_delete_template(db, template_id)
+19 -19
View File
@@ -126,7 +126,7 @@ def list_tests(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
return crud_list_tests(
db,
@@ -156,7 +156,7 @@ def create_test(
payload: TestCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Create a new test linked to an existing technique.
``created_by`` is set automatically and ``state`` defaults to *draft*.
@@ -198,7 +198,7 @@ def create_test_from_template(
payload: TestTemplateInstantiate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Instantiate a real Test from an existing TestTemplate.
The template's fields are copied into the new test as starting data.
@@ -238,7 +238,7 @@ def get_test(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> TestOut:
"""Return full details for a single test, including its evidences."""
return crud_get_test_detail(db, test_id)
@@ -254,7 +254,7 @@ def update_test(
payload: TestUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Update one or more fields of an existing test.
Only leads or admins can update general test fields.
@@ -294,7 +294,7 @@ def update_test_classification(
payload: TestClassificationUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> TestOut:
"""Update the data classification label for a test (admin only)."""
with UnitOfWork(db) as uow:
test = crud_get_test_or_raise(db, test_id)
@@ -324,7 +324,7 @@ def update_test_red(
payload: TestRedUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
) -> TestOut:
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
@@ -354,7 +354,7 @@ def update_test_blue(
payload: TestBlueUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
) -> TestOut:
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
@@ -383,7 +383,7 @@ def start_execution(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
) -> TestOut:
"""Move a test from ``draft`` to ``red_executing``."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -403,7 +403,7 @@ def submit_red(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
) -> TestOut:
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -423,7 +423,7 @@ def submit_blue(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
) -> TestOut:
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -443,7 +443,7 @@ def pause_timer(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
) -> TestOut:
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -463,7 +463,7 @@ def resume_timer(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
) -> TestOut:
"""Resume the paused timer for the current phase."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -484,7 +484,7 @@ def validate_red(
payload: TestRedValidate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead")),
):
) -> TestOut:
"""Red Lead approves or rejects the red side of a test."""
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
@@ -511,7 +511,7 @@ def validate_blue(
payload: TestBlueValidate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_lead")),
):
) -> TestOut:
"""Blue Lead approves or rejects the blue side of a test."""
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
@@ -537,7 +537,7 @@ def reopen(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Reopen a rejected test, moving it back to ``draft``."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -558,7 +558,7 @@ def update_remediation(
payload: TestRemediationUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Update remediation fields on a test.
When ``remediation_status`` transitions to ``'completed'``, an automatic
@@ -602,7 +602,7 @@ def get_test_timeline(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return the chronological audit-log history for a test."""
return crud_get_test_timeline(db, test_id)
@@ -617,7 +617,7 @@ def get_retest_chain(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return the full chain of retests (original + all retests) for a test."""
chain = wf_get_retest_chain(db, test_id)
if not chain:
+4 -4
View File
@@ -36,7 +36,7 @@ def list_threat_actors(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List threat actors with optional filters and pagination.
**Requires** authentication (any role).
@@ -58,7 +58,7 @@ def get_threat_actor(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get detailed info about a threat actor including techniques.
**Requires** authentication (any role).
@@ -71,7 +71,7 @@ def get_threat_actor_coverage(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Calculate coverage percentage against a specific threat actor.
**Requires** authentication (any role).
@@ -87,7 +87,7 @@ def get_threat_actor_gaps(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Identify techniques of this actor that are NOT fully validated.
**Requires** authentication (any role).
+4 -4
View File
@@ -30,7 +30,7 @@ router = APIRouter(prefix="/users", tags=["users"])
def list_users_route(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> list[UserOut]:
"""Return a list of all users. **Requires admin role.**"""
return list_users(db)
@@ -45,7 +45,7 @@ def create_user_route(
payload: UserCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> UserOut:
"""Create a new user. **Requires admin role.**"""
with UnitOfWork(db) as uow:
user = create_user(
@@ -79,7 +79,7 @@ def get_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> UserOut:
"""Return a single user by ID. **Requires admin role.**"""
return get_user_or_raise(db, user_id)
@@ -95,7 +95,7 @@ def update_user_route(
payload: UserUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> UserOut:
"""Update one or more fields of an existing user. **Requires admin role.**"""
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
+4 -4
View File
@@ -56,7 +56,7 @@ def create(
body: WorklogCreate,
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
) -> WorklogOut:
"""Create a manually-logged worklog entry."""
with UnitOfWork(db) as uow:
wl = worklog_service.create_worklog(
@@ -82,7 +82,7 @@ def list_all(
user_id: Optional[UUID] = None,
db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
):
) -> list[WorklogOut]:
"""List worklogs with optional filters."""
return worklog_service.list_worklogs(
db,
@@ -97,7 +97,7 @@ def get_one(
worklog_id: UUID,
db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
):
) -> WorklogOut:
"""Get a single worklog by ID."""
return worklog_service.get_worklog_or_raise(db, worklog_id)
@@ -107,7 +107,7 @@ def verify_integrity(
worklog_id: UUID,
db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
):
) -> dict:
"""Check whether a worklog's integrity hash is still valid."""
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
return {
+1 -1
View File
@@ -167,7 +167,7 @@ class TestOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
@classmethod
def model_validate(cls, obj, **kwargs):
def model_validate(cls, obj: object, **kwargs: object) -> "TestOut":
"""Override to populate technique fields from the relationship."""
if hasattr(obj, "technique") and obj.technique is not None:
obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id
+10 -8
View File
@@ -16,6 +16,8 @@ import random
import uuid
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from app.auth import hash_password
from app.database import SessionLocal
from app.models.audit import AuditLog
@@ -102,7 +104,7 @@ TEMPLATE_NAMES = [
# ---------------------------------------------------------------------------
def _cleanup_demo_data(db) -> None:
def _cleanup_demo_data(db: Session) -> None:
"""Remove all previously seeded demo data."""
# Delete in order to respect FK constraints
demo_users = db.query(User).filter(User.username.like(f"{DEMO_PREFIX}%")).all()
@@ -154,7 +156,7 @@ def _cleanup_demo_data(db) -> None:
# ---------------------------------------------------------------------------
def _seed_users(db) -> list[User]:
def _seed_users(db: Session) -> list[User]:
"""Create 5 users per role (25 total)."""
users = []
for role in ROLES:
@@ -173,7 +175,7 @@ def _seed_users(db) -> list[User]:
return users
def _seed_technique_statuses(db, count: int = 50) -> list[Technique]:
def _seed_technique_statuses(db: Session, count: int = 50) -> list[Technique]:
"""Set varied statuses on up to *count* techniques."""
techniques = db.query(Technique).limit(count).all()
if not techniques:
@@ -192,7 +194,7 @@ def _seed_technique_statuses(db, count: int = 50) -> list[Technique]:
return techniques
def _seed_tests(db, users: list[User], techniques: list[Technique], count: int = 100) -> list[Test]:
def _seed_tests(db: Session, users: list[User], techniques: list[Technique], count: int = 100) -> list[Test]:
"""Create *count* tests in various pipeline states."""
if not techniques:
logger.warning("No techniques available — skipping test seeding.")
@@ -267,7 +269,7 @@ def _seed_tests(db, users: list[User], techniques: list[Technique], count: int =
return tests
def _seed_evidences(db, tests: list[Test], users: list[User], count: int = 50) -> list[Evidence]:
def _seed_evidences(db: Session, tests: list[Test], users: list[User], count: int = 50) -> list[Evidence]:
"""Create *count* dummy evidence records."""
if not tests:
return []
@@ -305,7 +307,7 @@ def _seed_evidences(db, tests: list[Test], users: list[User], count: int = 50) -
return evidences
def _seed_audit_logs(db, users: list[User], count: int = 20) -> None:
def _seed_audit_logs(db: Session, users: list[User], count: int = 20) -> None:
"""Create *count* varied audit log entries."""
for i in range(count):
user = random.choice(users)
@@ -323,7 +325,7 @@ def _seed_audit_logs(db, users: list[User], count: int = 20) -> None:
logger.info("Created %d demo audit logs.", count)
def _seed_notifications(db, users: list[User], count: int = 30) -> None:
def _seed_notifications(db: Session, users: list[User], count: int = 30) -> None:
"""Create *count* notifications spread across demo users."""
for i in range(count):
user = random.choice(users)
@@ -344,7 +346,7 @@ def _seed_notifications(db, users: list[User], count: int = 30) -> None:
logger.info("Created %d demo notifications.", count)
def _seed_templates(db, techniques: list[Technique], count: int = 10) -> None:
def _seed_templates(db: Session, techniques: list[Technique], count: int = 10) -> None:
"""Create *count* manual demo templates."""
if not techniques:
return
+2 -1
View File
@@ -4,6 +4,7 @@ from __future__ import annotations
import hashlib
from datetime import datetime, timezone
from uuid import UUID
from sqlalchemy.orm import Session
@@ -35,7 +36,7 @@ def verify_audit_integrity(entry: AuditLog) -> bool:
def log_action(
db: Session,
user_id,
user_id: UUID | None,
action: str,
entity_type: str | None = None,
entity_id: str | None = None,
@@ -192,7 +192,7 @@ def update_campaign(
*,
updater_id: uuid.UUID,
updater_role: str,
**fields,
**fields: object,
) -> dict:
"""Update a campaign. Only allowed in draft or active state.
@@ -8,6 +8,7 @@ Uses the D3FEND public API:
import logging
from typing import Any
from uuid import UUID
import httpx
from sqlalchemy.orm import Session
@@ -26,7 +27,7 @@ D3FEND_TACTICS = ["Detect", "Harden", "Isolate", "Deceive", "Evict", "Model"]
# ── Import all D3FEND techniques ─────────────────────────────────────
def _to_str(v: Any) -> str:
def _to_str(v: Any) -> str: # noqa: ANN401
"""Coerce an RDF value (str, dict with @value, or list) to a plain string."""
if isinstance(v, dict):
return v.get("@value", str(v))
@@ -432,7 +433,7 @@ def sync(db: Session) -> dict:
return summary
def get_defenses_for_technique(db: Session, technique_id) -> list[dict]:
def get_defenses_for_technique(db: Session, technique_id: UUID) -> list[dict]:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
mappings = (
db.query(DefensiveTechniqueMapping)
@@ -10,6 +10,7 @@ from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy.orm import Session
@@ -258,11 +259,11 @@ def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]:
def evaluate_rule(
db: Session,
*,
test_id: Any,
detection_rule_id: Any,
test_id: UUID,
detection_rule_id: UUID,
triggered: bool | None,
notes: str | None,
evaluator_id: Any,
evaluator_id: UUID,
) -> dict[str, Any]:
"""Save or update the evaluation result for a detection rule on a test.
+7 -6
View File
@@ -10,9 +10,10 @@ no ``db.commit()``.
from __future__ import annotations
import json
from collections.abc import Callable
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from sqlalchemy.orm import Query, Session
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
from app.models.campaign import Campaign, CampaignTest
@@ -92,11 +93,11 @@ def _build_layer_skeleton(
def _apply_filters(
query,
model,
query: Query, # type: ignore[type-arg]
model: type,
platforms: list[str] | None = None,
tactics: list[str] | None = None,
):
) -> Query: # type: ignore[type-arg]
"""Apply common platform and tactic filters to a technique query."""
if platforms:
platform_filters = [
@@ -470,7 +471,7 @@ class _LayerRegistry:
self._simple: dict[str, object] = {}
self._with_id: dict[str, object] = {}
def register(self, name: str, builder, *, requires_id: bool = False) -> None:
def register(self, name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None:
target = self._with_id if requires_id else self._simple
target[name] = builder
@@ -513,7 +514,7 @@ LAYER_REGISTRY.register("campaign", build_campaign_layer, requires_id=True)
SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types # snapshot of built-in types
def register_layer(name: str, builder, *, requires_id: bool = False) -> None:
def register_layer(name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None:
"""Public API to register a new heatmap layer type at import time."""
LAYER_REGISTRY.register(name, builder, requires_id=requires_id)
+2 -2
View File
@@ -2,7 +2,7 @@
import logging
from datetime import datetime
from typing import Optional
from typing import Any, Optional
from uuid import UUID
from sqlalchemy.orm import Session
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
_jira_client = None
def get_jira_client():
def get_jira_client() -> Any: # noqa: ANN401 # atlassian.Jira imported lazily from optional dep
"""Return a lazily-initialised Jira client, or raise if disabled."""
global _jira_client
if not settings.JIRA_ENABLED:
+2 -1
View File
@@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.notification import Notification
from app.models.test import Test
from app.models.user import User
# ---------------------------------------------------------------------------
@@ -157,7 +158,7 @@ def cleanup_old_notifications(db: Session, days: int = 90) -> int:
# ---------------------------------------------------------------------------
def notify_test_state_change(db: Session, test, new_state: str) -> None:
def notify_test_state_change(db: Session, test: Test, new_state: str) -> None:
"""Dispatch notifications based on a test's new state.
Called by the workflow service after each state transition.
+6 -4
View File
@@ -10,12 +10,14 @@ stale data does not persist longer than ``CACHE_TTL`` seconds.
import time
from typing import Any, Optional
from sqlalchemy.orm import Session
CACHE_TTL = 300 # 5 minutes
_cache: dict[str, dict[str, Any]] = {}
def get(key: str) -> Optional[Any]:
def get(key: str) -> Optional[Any]: # noqa: ANN401 # generic cache returns whatever was stored
"""Return cached value if present and not expired, else None."""
entry = _cache.get(key)
if entry is None:
@@ -26,7 +28,7 @@ def get(key: str) -> Optional[Any]:
return entry["data"]
def put(key: str, data: Any) -> None:
def put(key: str, data: Any) -> None: # noqa: ANN401 # generic cache accepts any serialisable value
"""Store *data* under *key* with the current timestamp."""
_cache[key] = {"data": data, "ts": time.time()}
@@ -42,7 +44,7 @@ def invalidate(key: Optional[str] = None) -> None:
# ── High-level helpers ────────────────────────────────────────────────
def get_organization_score_cached(db):
def get_organization_score_cached(db: Session) -> dict:
"""Cached wrapper around ``calculate_organization_score``."""
from app.services.scoring_service import calculate_organization_score
@@ -55,7 +57,7 @@ def get_organization_score_cached(db):
return result
def get_operational_metrics_cached(db):
def get_operational_metrics_cached(db: Session) -> dict:
"""Cached wrapper around operational metrics (MTTD, MTTR, efficacy)."""
from app.services.operational_metrics_service import (
calculate_alert_fidelity,
+7 -5
View File
@@ -1,18 +1,20 @@
"""Tempo time-tracking integration service."""
import logging
from typing import Optional
from typing import Any, Optional
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.exceptions import InvalidOperationError
from app.models.jira_link import JiraLink, JiraLinkEntityType
from app.models.test import Test
from app.models.user import User
logger = logging.getLogger(__name__)
def get_tempo_client():
def get_tempo_client() -> Any: # noqa: ANN401 # tempoapiclient.Tempo imported lazily from optional dep
"""Return a Tempo API client, or raise if disabled."""
if not settings.TEMPO_ENABLED:
raise InvalidOperationError("Tempo integration is not enabled")
@@ -52,8 +54,8 @@ def log_worklog(
def auto_log_test_worklog(
db: Session,
test,
user,
test: Test,
user: User,
activity_type: str,
) -> Optional[dict]:
"""If the test has a Jira link, log time to Tempo automatically.
@@ -97,7 +99,7 @@ def auto_log_test_worklog(
return None
def _calculate_duration(test, activity_type: str) -> int:
def _calculate_duration(test: Test, activity_type: str) -> int:
"""Calculate real duration in seconds from the phase timing fields.
Uses the actual start/end timestamps recorded by the workflow buttons,
+4 -4
View File
@@ -63,7 +63,7 @@ def create_test(
*,
technique_id: uuid.UUID,
creator_id: uuid.UUID,
**fields: Any,
**fields: object,
) -> Test:
"""Create a new test linked to an existing technique.
@@ -176,7 +176,7 @@ def update_test(
*,
updater_id: uuid.UUID,
updater_role: str,
**fields: Any,
**fields: object,
) -> Test:
"""Update general test fields (draft or rejected only).
@@ -204,7 +204,7 @@ def update_test(
return test
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
"""Update Red Team fields (draft or red_executing only).
Raises BusinessRuleViolation if state not in (draft, red_executing).
@@ -226,7 +226,7 @@ def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
return test
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
"""Update Blue Team fields (blue_evaluating only).
Raises BusinessRuleViolation if state is not blue_evaluating.
@@ -13,6 +13,7 @@ session via the Unit of Work pattern.
"""
import logging
import uuid
from datetime import datetime
from sqlalchemy.orm import Session
@@ -530,7 +531,7 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test |
return retest
def get_retest_chain(db: Session, test_id) -> list[Test]:
def get_retest_chain(db: Session, test_id: uuid.UUID) -> list[Test]:
"""Return the full chain of retests for a given test.
Includes the original test and all subsequent retests, ordered
+3 -1
View File
@@ -7,11 +7,13 @@ line-length = 120
# F — pyflakes (unused imports, undefined names)
# I — isort (import ordering per PEP8 convention)
# N — pep8-naming (class/function/variable naming conventions)
select = ["E", "W", "F", "I", "N"]
# ANN — flake8-annotations (type hint enforcement)
select = ["E", "W", "F", "I", "N", "ANN"]
ignore = [
# SQLAlchemy filter syntax requires `== True` / `== False` comparisons
"E712",
# ANN101/ANN102 (self/cls type annotations) removed from ruff — not needed
]
[lint.per-file-ignores]