From 394d5d9056a7bc9234575e654c22ba0320f8f842 Mon Sep 17 00:00:00 2001 From: kitos Date: Tue, 9 Jun 2026 17:04:51 +0200 Subject: [PATCH] refactor(types): add comprehensive type annotations across backend Python codebase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable ANN rules in ruff.toml (flake8-annotations) and resolve all 221 violations: ANN201/ANN202 — return types on 168 public/private functions: - All 28 FastAPI routers: endpoints annotated with dict/list/specific schema/ StreamingResponse/FileResponse/JSONResponse as appropriate - main.py: lifespan→AsyncGenerator[None,None], exception handlers→JSONResponse - database.py: get_db→Generator[Session,None,None], proxy methods→correct types - middleware/request_context.py: dispatch→Response with Callable call_next type ANN001/ANN002/ANN003 — 32 missing argument types: - seed_demo.py: all db parameters typed as Session - domain/unit_of_work.py: __aexit__ exc_type/exc_val/exc_tb typed with TracebackType - services: audit_service user_id→UUID|None, heatmap_service query/model/builder, notification_service test→Test, tempo_service test→Test/user→User, test_workflow_service test_id→UUID, campaign_crud **fields→object, test_crud **fields→object (4 sites) ANN401 — 16 Any usages resolved: - Domain entities (campaign/technique/threat_actor/test_entity): replaced Any with actual ORM types via TYPE_CHECKING guards to avoid circular imports - detection_rule_service: test_id/detection_rule_id/evaluator_id→UUID - score_cache: kept Any with # noqa: ANN401 (genuinely generic cache) - jira_service/tempo_service: kept Any with # noqa: ANN401 (lazy optional deps) - d3fend_import_service: _to_str(v: Any) kept with # noqa: ANN401 ANN204/ANN205/ANN206 — special/static/class methods: - database.py proxy __call__/__getattr__: *args: object/**kwargs: object - schemas/test.py model_validate: obj→object, **kwargs→object - sa_technique_repository._int_type→type All 439 unit tests pass. ruff check app/ → All checks passed! Co-Authored-By: Claude Sonnet 4.6 --- backend/app/database.py | 17 +++++---- backend/app/dependencies/auth.py | 5 ++- backend/app/domain/entities/campaign.py | 7 +++- backend/app/domain/entities/technique.py | 9 +++-- backend/app/domain/entities/threat_actor.py | 7 +++- backend/app/domain/test_entity.py | 9 +++-- backend/app/domain/unit_of_work.py | 9 ++++- .../repositories/sa_technique_repository.py | 2 +- backend/app/main.py | 11 +++--- backend/app/middleware/request_context.py | 8 +++- backend/app/routers/advanced_metrics.py | 8 ++-- backend/app/routers/analytics.py | 8 ++-- backend/app/routers/audit.py | 6 +-- backend/app/routers/auth.py | 8 ++-- backend/app/routers/campaigns.py | 24 ++++++------ backend/app/routers/compliance.py | 14 +++---- backend/app/routers/d3fend.py | 8 ++-- backend/app/routers/data_sources.py | 10 ++--- backend/app/routers/detection_rules.py | 10 ++--- backend/app/routers/evidence.py | 8 ++-- backend/app/routers/heatmap.py | 10 ++--- backend/app/routers/jira.py | 12 +++--- backend/app/routers/metrics.py | 12 +++--- backend/app/routers/notifications.py | 8 ++-- backend/app/routers/operational_metrics.py | 6 +-- backend/app/routers/osint.py | 10 ++--- backend/app/routers/professional_reports.py | 10 ++--- backend/app/routers/reports.py | 8 ++-- backend/app/routers/scores.py | 14 +++---- backend/app/routers/snapshots.py | 12 +++--- backend/app/routers/system.py | 8 ++-- backend/app/routers/techniques.py | 10 ++--- backend/app/routers/test_templates.py | 18 ++++----- backend/app/routers/tests.py | 38 +++++++++---------- backend/app/routers/threat_actors.py | 8 ++-- backend/app/routers/users.py | 8 ++-- backend/app/routers/worklogs.py | 8 ++-- backend/app/schemas/test.py | 2 +- backend/app/seed_demo.py | 18 +++++---- backend/app/services/audit_service.py | 3 +- backend/app/services/campaign_crud_service.py | 2 +- backend/app/services/d3fend_import_service.py | 5 ++- .../app/services/detection_rule_service.py | 7 ++-- backend/app/services/heatmap_service.py | 13 ++++--- backend/app/services/jira_service.py | 4 +- backend/app/services/notification_service.py | 3 +- backend/app/services/score_cache.py | 10 +++-- backend/app/services/tempo_service.py | 12 +++--- backend/app/services/test_crud_service.py | 8 ++-- backend/app/services/test_workflow_service.py | 3 +- backend/ruff.toml | 12 +++--- 51 files changed, 267 insertions(+), 223 deletions(-) diff --git a/backend/app/database.py b/backend/app/database.py index d84b999..6405ffe 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -1,5 +1,8 @@ +from collections.abc import Generator + from sqlalchemy import create_engine -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session, declarative_base, sessionmaker Base = declarative_base() @@ -10,7 +13,7 @@ _engine = None _SessionLocal = None -def _get_engine(): +def _get_engine() -> Engine: global _engine if _engine is None: from app.config import settings @@ -28,7 +31,7 @@ def _get_engine(): return _engine -def _get_session_factory(): +def _get_session_factory() -> sessionmaker: global _SessionLocal if _SessionLocal is None: _SessionLocal = sessionmaker( @@ -41,10 +44,10 @@ class _LazySessionLocal: """Proxy so ``SessionLocal()`` keeps working as before but the real sessionmaker is only created on first call.""" - def __call__(self, *args, **kwargs): + def __call__(self, *args: object, **kwargs: object) -> Session: return _get_session_factory()(*args, **kwargs) - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: return getattr(_get_session_factory(), name) @@ -53,14 +56,14 @@ SessionLocal = _LazySessionLocal() class _EngineProxy: """Thin proxy so ``from app.database import engine`` still works.""" - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: return getattr(_get_engine(), name) engine = _EngineProxy() # type: ignore[assignment] -def get_db(): +def get_db() -> Generator[Session, None, None]: db = SessionLocal() try: yield db diff --git a/backend/app/dependencies/auth.py b/backend/app/dependencies/auth.py index 9771d02..72a5895 100644 --- a/backend/app/dependencies/auth.py +++ b/backend/app/dependencies/auth.py @@ -8,6 +8,7 @@ Provides: (admins always pass). """ +from collections.abc import Callable from typing import Optional from fastapi import Cookie, Depends, HTTPException, status @@ -112,7 +113,7 @@ async def require_password_changed( return current_user -def require_role(required_role: str): +def require_role(required_role: str) -> Callable[..., object]: """Return a FastAPI dependency that enforces *required_role*. The dependency allows the request to proceed when @@ -133,7 +134,7 @@ def require_role(required_role: str): return role_checker -def require_any_role(*roles: str): +def require_any_role(*roles: str) -> Callable[..., object]: """Return a FastAPI dependency that enforces **any** of the given *roles*. Admins always pass. Usage example:: diff --git a/backend/app/domain/entities/campaign.py b/backend/app/domain/entities/campaign.py index 02c1487..0482f19 100644 --- a/backend/app/domain/entities/campaign.py +++ b/backend/app/domain/entities/campaign.py @@ -8,10 +8,13 @@ from __future__ import annotations import enum import uuid from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING from app.domain.errors import BusinessRuleViolation, InvalidStateTransition +if TYPE_CHECKING: + from app.models.campaign import Campaign as CampaignORM + class CampaignStatus(str, enum.Enum): draft = "draft" @@ -86,7 +89,7 @@ class CampaignEntity: ) @classmethod - def from_orm(cls, orm: Any) -> CampaignEntity: + def from_orm(cls, orm: CampaignORM) -> CampaignEntity: """Build a CampaignEntity from a SQLAlchemy Campaign model.""" test_count = len(getattr(orm, "campaign_tests", None) or []) return cls( diff --git a/backend/app/domain/entities/technique.py b/backend/app/domain/entities/technique.py index 2184df5..346bd44 100644 --- a/backend/app/domain/entities/technique.py +++ b/backend/app/domain/entities/technique.py @@ -17,11 +17,14 @@ from __future__ import annotations import uuid from dataclasses import dataclass, field from datetime import datetime -from typing import Any +from typing import TYPE_CHECKING from app.domain.enums import TechniqueStatus, TestResult, TestState from app.domain.value_objects.mitre_id import MitreId +if TYPE_CHECKING: + from app.models.technique import Technique as TechniqueORM + @dataclass(frozen=True) class _TestSnapshot: @@ -76,7 +79,7 @@ class TechniqueEntity: ) @classmethod - def from_orm(cls, model: Any) -> TechniqueEntity: + def from_orm(cls, model: TechniqueORM) -> TechniqueEntity: """Build a TechniqueEntity from a SQLAlchemy Technique model.""" raw_status = model.status_global if raw_status is None: @@ -101,7 +104,7 @@ class TechniqueEntity: mitre_last_modified=getattr(model, "mitre_last_modified", None), ) - def apply_to(self, model: Any) -> None: + def apply_to(self, model: TechniqueORM) -> None: """Copy mutable fields back onto the ORM model.""" model.status_global = self.status_global model.review_required = self.review_required diff --git a/backend/app/domain/entities/threat_actor.py b/backend/app/domain/entities/threat_actor.py index d477014..42821be 100644 --- a/backend/app/domain/entities/threat_actor.py +++ b/backend/app/domain/entities/threat_actor.py @@ -7,7 +7,10 @@ from __future__ import annotations import uuid from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.models.threat_actor import ThreatActor as ThreatActorORM @dataclass @@ -63,7 +66,7 @@ class ThreatActorEntity: return round(len(self.covered_techniques) / len(self.techniques) * 100, 1) @classmethod - def from_orm(cls, orm: Any) -> ThreatActorEntity: + def from_orm(cls, orm: ThreatActorORM) -> ThreatActorEntity: techs: list[ThreatActorTechniqueRef] = [] for tat in getattr(orm, "techniques", None) or []: technique = getattr(tat, "technique", None) diff --git a/backend/app/domain/test_entity.py b/backend/app/domain/test_entity.py index e91cdca..ec15ffb 100644 --- a/backend/app/domain/test_entity.py +++ b/backend/app/domain/test_entity.py @@ -26,7 +26,7 @@ import enum import uuid from dataclasses import dataclass, field from datetime import datetime -from typing import Any +from typing import TYPE_CHECKING, Any from app.domain.errors import ( BusinessRuleViolation, @@ -34,6 +34,9 @@ from app.domain.errors import ( InvalidStateTransition, ) +if TYPE_CHECKING: + from app.models.test import Test as TestORM + # ── Value objects ──────────────────────────────────────────────────── @@ -103,7 +106,7 @@ class TestEntity: # -- Factory -------------------------------------------------------- @classmethod - def from_orm(cls, model: Any) -> TestEntity: + def from_orm(cls, model: TestORM) -> TestEntity: """Build a TestEntity from a SQLAlchemy ``Test`` model instance.""" raw_state = model.state state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state) @@ -126,7 +129,7 @@ class TestEntity: blue_paused_seconds=model.blue_paused_seconds or 0, ) - def apply_to(self, model: Any) -> None: + def apply_to(self, model: TestORM) -> None: """Copy the entity's mutable fields back onto the ORM model.""" model.state = self.state model.red_validation_status = self.red_validation_status diff --git a/backend/app/domain/unit_of_work.py b/backend/app/domain/unit_of_work.py index 83b2400..8cf50c3 100644 --- a/backend/app/domain/unit_of_work.py +++ b/backend/app/domain/unit_of_work.py @@ -22,6 +22,8 @@ Services should **never** call ``db.commit()``; they use ``db.add()`` / from __future__ import annotations +from types import TracebackType + from sqlalchemy.orm import Session @@ -36,7 +38,12 @@ class UnitOfWork: def __enter__(self) -> "UnitOfWork": return self - def __exit__(self, exc_type, exc_val, exc_tb) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: if exc_type is not None: self.rollback() diff --git a/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py b/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py index f1a3828..0582d53 100644 --- a/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py +++ b/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py @@ -197,7 +197,7 @@ class SATechniqueRepository: # -- Internal ---------------------------------------------------------- @staticmethod - def _int_type(): + def _int_type() -> type: """Return an Integer type for CAST expressions (SQLite-compatible).""" from sqlalchemy import Integer return Integer diff --git a/backend/app/main.py b/backend/app/main.py index a0af92f..b17d07d 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,5 +1,6 @@ import logging import os +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from fastapi import FastAPI, Request, status @@ -53,7 +54,7 @@ setup_logging() _IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production" @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """Startup / shutdown logic.""" ensure_bucket_exists() start_scheduler() @@ -124,7 +125,7 @@ app.include_router(osint_router.router, prefix="/api/v1") @app.get("/health", include_in_schema=False) -def health(): +def health() -> dict[str, str]: """Minimal health check — returns only an HTTP 200 with no service metadata. Access is restricted to internal networks at the Nginx level @@ -149,7 +150,7 @@ def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]: @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError): +async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: """Handle validation errors with consistent format.""" return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -162,7 +163,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE @app.exception_handler(SQLAlchemyError) -async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError): +async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError) -> JSONResponse: """Handle database errors.""" logging.error(f"Database error: {exc}") return JSONResponse( @@ -175,7 +176,7 @@ async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError): @app.exception_handler(Exception) -async def general_exception_handler(request: Request, exc: Exception): +async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse: """Handle all unhandled exceptions.""" logging.error(f"Unhandled exception: {exc}") return JSONResponse( diff --git a/backend/app/middleware/request_context.py b/backend/app/middleware/request_context.py index f49ef57..79a588a 100644 --- a/backend/app/middleware/request_context.py +++ b/backend/app/middleware/request_context.py @@ -1,9 +1,11 @@ """Request context middleware — captures client IP and User-Agent per request.""" +from collections.abc import Awaitable, Callable from contextvars import ContextVar from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import Response request_ip: ContextVar[str] = ContextVar("request_ip", default="") request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="") @@ -20,7 +22,11 @@ def resolve_client_ip(request: Request) -> str: class RequestContextMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): + async def dispatch( + self, + request: Request, + call_next: Callable[[Request], Awaitable[Response]], + ) -> Response: request_ip.set(resolve_client_ip(request)) request_user_agent.set(request.headers.get("User-Agent", "")) return await call_next(request) diff --git a/backend/app/routers/advanced_metrics.py b/backend/app/routers/advanced_metrics.py index 7308d23..0de7661 100644 --- a/backend/app/routers/advanced_metrics.py +++ b/backend/app/routers/advanced_metrics.py @@ -15,7 +15,7 @@ router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"]) def coverage_by_tactic( db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> list: """Coverage percentage broken down by MITRE ATT&CK tactic.""" return advanced_metrics_service.get_coverage_by_tactic(db) @@ -24,7 +24,7 @@ def coverage_by_tactic( def never_tested_techniques( db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> list: """Techniques that have never had a test created.""" return advanced_metrics_service.get_never_tested_techniques(db) @@ -33,7 +33,7 @@ def never_tested_techniques( def avg_validation_time( db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> dict: """Average time from test creation to validation, computed from audit logs. Returns overall average and per-phase averages where data is available. @@ -45,6 +45,6 @@ def avg_validation_time( def detection_rate_trend( db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> list: """Monthly detection rate trend for the last 12 months.""" return advanced_metrics_service.get_detection_rate_trend(db) diff --git a/backend/app/routers/analytics.py b/backend/app/routers/analytics.py index 4997888..562bebc 100644 --- a/backend/app/routers/analytics.py +++ b/backend/app/routers/analytics.py @@ -19,7 +19,7 @@ router = APIRouter(prefix="/analytics", tags=["analytics"]) def analytics_coverage( db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> list: """Coverage per technique — flat format for BI dashboards.""" return analytics_service.get_coverage_analytics(db) @@ -30,7 +30,7 @@ def analytics_tests( date_to: str = Query(None, description="ISO date filter (<=)"), db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> list: """All tests with timestamps — flat format for BI dashboards.""" return analytics_service.get_tests_analytics( db, date_from=date_from, date_to=date_to @@ -41,7 +41,7 @@ def analytics_tests( def analytics_trends( db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> list: """Historical coverage snapshots for trend visualization.""" return analytics_service.get_trends_analytics(db) @@ -50,6 +50,6 @@ def analytics_trends( def analytics_operators( db: Session = Depends(get_db), user: User = Depends(require_role("admin")), -): +) -> list: """Per-operator metrics — for workload management dashboards.""" return analytics_service.get_operators_analytics(db) diff --git a/backend/app/routers/audit.py b/backend/app/routers/audit.py index 0dd257b..a96c393 100644 --- a/backend/app/routers/audit.py +++ b/backend/app/routers/audit.py @@ -30,7 +30,7 @@ def list_audit_logs( limit: int = Query(50, ge=1, le=100, description="Max records to return"), db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> AuditLogPage: """Return paginated audit logs with optional filters. **Requires admin role.** @@ -57,7 +57,7 @@ def list_audit_logs( def list_actions( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> list[str]: """Return a list of distinct action types in the audit log. **Requires admin role.** @@ -69,7 +69,7 @@ def list_actions( def list_entity_types( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> list[str]: """Return a list of distinct entity types in the audit log. **Requires admin role.** diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index f8a6c59..087a64f 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -46,7 +46,7 @@ def login( response: Response, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db), -): +) -> TokenResponse: """Authenticate a user and return a JWT access token. Rate-limited to **5 attempts per minute per IP**. Failed and successful @@ -110,7 +110,7 @@ def logout( request: Request, response: Response, aegis_token: str | None = Cookie(None), -): +) -> dict: """Clear the authentication cookie and revoke the current token.""" bearer = ( request.headers.get("Authorization") @@ -148,7 +148,7 @@ def logout( @router.get("/me", response_model=UserOut) -def read_current_user(current_user: User = Depends(get_current_user)): +def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut: """Return the profile of the currently authenticated user.""" return current_user @@ -158,7 +158,7 @@ def change_password( body: PasswordChange, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Change the current user's password.""" auth_change_password( db, diff --git a/backend/app/routers/campaigns.py b/backend/app/routers/campaigns.py index 9dc9d61..9fc281b 100644 --- a/backend/app/routers/campaigns.py +++ b/backend/app/routers/campaigns.py @@ -107,7 +107,7 @@ def list_campaigns( limit: int = Query(50, ge=1, le=200), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """List campaigns with optional filters and pagination.""" return crud_list( db, @@ -129,7 +129,7 @@ def create_campaign( payload: CampaignCreate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Create a new campaign.""" with UnitOfWork(db) as uow: result = crud_create( @@ -165,7 +165,7 @@ def get_campaign( campaign_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get detailed campaign info including tests and progress.""" return crud_get_detail(db, campaign_id) @@ -180,7 +180,7 @@ def update_campaign( payload: CampaignUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Update a campaign. Only allowed in draft or active state.""" update_data = payload.model_dump(exclude_unset=True) with UnitOfWork(db) as uow: @@ -214,7 +214,7 @@ def add_test_to_campaign( payload: AddTestPayload, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Add a test to a campaign with optional ordering and dependency.""" with UnitOfWork(db) as uow: result = crud_add_test( @@ -239,7 +239,7 @@ def remove_test_from_campaign( campaign_test_id: str, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Remove a test from a campaign.""" with UnitOfWork(db) as uow: crud_remove_test(db, campaign_id, campaign_test_id) @@ -256,7 +256,7 @@ def activate_campaign( campaign_id: str, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Activate a campaign, moving it from draft to active.""" with UnitOfWork(db) as uow: campaign = crud_activate(db, campaign_id) @@ -292,7 +292,7 @@ def complete_campaign( campaign_id: str, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "admin")), -): +) -> dict: """Mark a campaign as completed.""" with UnitOfWork(db) as uow: campaign = crud_complete(db, campaign_id) @@ -319,7 +319,7 @@ def get_campaign_progress_endpoint( campaign_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get progress statistics for a campaign.""" return crud_get_progress(db, campaign_id) @@ -333,7 +333,7 @@ def generate_campaign_from_actor( actor_id: str, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Auto-generate a campaign from a threat actor's uncovered techniques. Creates tests from the best available templates and orders them @@ -369,7 +369,7 @@ def schedule_campaign( payload: SchedulePayload, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Configure or update the recurrence schedule for a campaign. Only the campaign creator or admin can change scheduling. @@ -411,6 +411,6 @@ def get_campaign_history( campaign_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """List all child campaigns (execution history) of a recurring campaign.""" return crud_get_history(db, campaign_id) diff --git a/backend/app/routers/compliance.py b/backend/app/routers/compliance.py index ead0a91..1016e18 100644 --- a/backend/app/routers/compliance.py +++ b/backend/app/routers/compliance.py @@ -34,7 +34,7 @@ router = APIRouter(prefix="/compliance", tags=["compliance"]) def list_frameworks_endpoint( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """List all available compliance frameworks.""" return list_frameworks(db) @@ -47,7 +47,7 @@ def framework_status( framework_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get compliance status for each control in a framework.""" return get_framework_status(db, framework_id) @@ -60,7 +60,7 @@ def framework_report( framework_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get the full compliance report (same as status but marked as report).""" return get_framework_status(db, framework_id) @@ -73,7 +73,7 @@ def framework_report_csv( framework_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> StreamingResponse: """Export compliance report as CSV.""" csv_bytes, filename = build_framework_report_csv(db, framework_id) return StreamingResponse( @@ -93,7 +93,7 @@ def framework_gaps( framework_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get controls with techniques that are not adequately covered.""" return get_framework_gaps(db, framework_id) @@ -105,7 +105,7 @@ def framework_gaps( def import_nist( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Import NIST 800-53 Rev 5 mappings (admin only).""" result = import_nist_800_53_mappings(db) return result @@ -115,7 +115,7 @@ def import_nist( def import_cis( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Import CIS Controls v8 mappings (admin only).""" result = import_cis_controls_v8_mappings(db) return result diff --git a/backend/app/routers/d3fend.py b/backend/app/routers/d3fend.py index 1a61a9e..955c0da 100644 --- a/backend/app/routers/d3fend.py +++ b/backend/app/routers/d3fend.py @@ -38,7 +38,7 @@ def list_defensive_techniques( limit: int = Query(50, ge=1, le=200), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """List all D3FEND defensive techniques with optional filters.""" return list_defensive_techniques_svc( db, tactic=tactic, search=search, offset=offset, limit=limit @@ -53,7 +53,7 @@ def list_defensive_techniques( def list_d3fend_tactics_endpoint( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Return a list of all D3FEND tactics with counts.""" return list_d3fend_tactics(db) @@ -67,7 +67,7 @@ def get_defenses_for_attack_technique_endpoint( mitre_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" return get_defenses_for_attack_technique(db, mitre_id) @@ -80,7 +80,7 @@ def get_defenses_for_attack_technique_endpoint( def trigger_d3fend_import( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Import D3FEND techniques and ATT&CK mappings. Admin only.""" tech_result = import_d3fend_techniques(db) mapping_result = import_d3fend_mappings(db) diff --git a/backend/app/routers/data_sources.py b/backend/app/routers/data_sources.py index 2a751f8..e670e13 100644 --- a/backend/app/routers/data_sources.py +++ b/backend/app/routers/data_sources.py @@ -47,7 +47,7 @@ router = APIRouter(prefix="/data-sources", tags=["data-sources"]) def list_data_sources( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> list: """List all registered data sources. **Requires** the ``admin`` role. @@ -61,7 +61,7 @@ def update_data_source( body: DataSourceUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Update a data source (enable/disable, change config). **Requires** the ``admin`` role. @@ -87,7 +87,7 @@ def sync_data_source( source_id: str, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Trigger sync/import for a specific data source. **Requires** the ``admin`` role. @@ -99,7 +99,7 @@ def sync_data_source( def sync_all_data_sources( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Trigger sync for all enabled data sources (sequentially). **Requires** the ``admin`` role. @@ -125,7 +125,7 @@ def get_data_source_stats( source_id: str, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Get detailed statistics for a specific data source. **Requires** the ``admin`` role. diff --git a/backend/app/routers/detection_rules.py b/backend/app/routers/detection_rules.py index f6235a7..ffe9cf5 100644 --- a/backend/app/routers/detection_rules.py +++ b/backend/app/routers/detection_rules.py @@ -51,7 +51,7 @@ def list_detection_rules( limit: int = Query(50, ge=1, le=200), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """List detection rules with optional filters and pagination.""" return list_rules( db, @@ -72,7 +72,7 @@ def get_detection_rules_for_template( template_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Get detection rules associated with a test template.""" return get_rules_for_template(db, template_id) @@ -84,7 +84,7 @@ def get_detection_rules_for_template( def auto_associate_detection_rules( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Auto-associate test templates with detection rules by MITRE technique ID. For each active template, find all active detection rules for the same @@ -102,7 +102,7 @@ def get_detection_rules_for_test( test_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Get detection rules relevant to a test, along with their evaluation results. Finds rules by matching the test's technique_id to detection rules, @@ -119,7 +119,7 @@ def evaluate_detection_rule( payload: DetectionRuleEvaluate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), -): +) -> dict: """Save or update the evaluation result for a detection rule on a test.""" return evaluate_rule( db, diff --git a/backend/app/routers/evidence.py b/backend/app/routers/evidence.py index 956aad2..0256546 100644 --- a/backend/app/routers/evidence.py +++ b/backend/app/routers/evidence.py @@ -88,7 +88,7 @@ async def upload_evidence( notes: Optional[str] = Form(None), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> EvidenceOut: """Upload a file as evidence for the given test. The ``team`` field (sent as form data) determines whether this is @@ -154,7 +154,7 @@ def list_evidence( team: Optional[str] = Query(None, description="Filter by team: red or blue"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list[EvidenceOut]: """List all evidences for a test, optionally filtered by team.""" get_test_or_raise(db, test_id) evidences = list_evidence_for_test(db, test_id, team=team) @@ -171,7 +171,7 @@ def get_evidence( evidence_id: _uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> EvidenceOut: """Return evidence metadata together with a presigned download URL.""" evidence = get_evidence_or_raise(db, evidence_id) return _evidence_to_out(evidence) @@ -187,7 +187,7 @@ def delete_evidence( evidence_id: _uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Delete an evidence record. Only allowed in editable states: diff --git a/backend/app/routers/heatmap.py b/backend/app/routers/heatmap.py index 18ec7f4..454a811 100644 --- a/backend/app/routers/heatmap.py +++ b/backend/app/routers/heatmap.py @@ -28,7 +28,7 @@ def heatmap_coverage( min_score: int = Query(0, ge=0, le=100), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Coverage layer — score based on status_global of each technique.""" return heatmap_service.build_coverage_layer( db, platforms=platforms, tactics=tactics, min_score=min_score, @@ -43,7 +43,7 @@ def heatmap_threat_actor( min_score: int = Query(0, ge=0, le=100), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Threat actor layer — techniques used by an actor with coverage color.""" return heatmap_service.build_threat_actor_layer( db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score, @@ -57,7 +57,7 @@ def heatmap_detection_rules( min_score: int = Query(0, ge=0, le=100), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Detection rules layer — score based on ratio of rules available vs total.""" return heatmap_service.build_detection_rules_layer( db, platforms=platforms, tactics=tactics, min_score=min_score, @@ -72,7 +72,7 @@ def heatmap_campaign( min_score: int = Query(0, ge=0, le=100), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Campaign layer — only techniques in the campaign, colored by test state.""" return heatmap_service.build_campaign_layer( db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score, @@ -88,7 +88,7 @@ def export_navigator( min_score: int = Query(0, ge=0, le=100), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> StreamingResponse: """Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator.""" data = heatmap_service.build_navigator_export( db, layer, layer_id=layer_id, diff --git a/backend/app/routers/jira.py b/backend/app/routers/jira.py index 4f19d18..2153e53 100644 --- a/backend/app/routers/jira.py +++ b/backend/app/routers/jira.py @@ -29,7 +29,7 @@ def search_issues( q: str = Query(..., min_length=2), max_results: int = Query(10, le=50), user: User = Depends(get_current_user), -): +) -> list[JiraIssueResult]: """Search Jira issues by JQL or free text.""" return jira_service.search_jira_issues(q, max_results) @@ -39,7 +39,7 @@ def create_link( body: JiraLinkCreate, db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> JiraLinkOut: """Associate an Aegis entity with a Jira issue.""" with UnitOfWork(db) as uow: link = jira_service.create_link( @@ -74,7 +74,7 @@ def list_links( entity_id: Optional[UUID] = None, db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> list[JiraLinkOut]: """List Jira links, optionally filtered by entity.""" return jira_service.list_links( db, @@ -88,7 +88,7 @@ def sync_link( link_id: UUID, db: Session = Depends(get_db), user: User = Depends(require_role("admin")), -): +) -> dict: """Force bidirectional sync for a specific Jira link.""" with UnitOfWork(db) as uow: link = jira_service.get_link_or_raise(db, link_id) @@ -102,7 +102,7 @@ def delete_link( link_id: UUID, db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> None: """Remove a Jira link.""" with UnitOfWork(db) as uow: link = jira_service.delete_link(db, link_id) @@ -123,7 +123,7 @@ def create_issue_from_entity( entity_id: UUID, db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> dict: """Auto-create a Jira issue from an Aegis entity and link them.""" with UnitOfWork(db) as uow: result = jira_service.create_issue_and_link( diff --git a/backend/app/routers/metrics.py b/backend/app/routers/metrics.py index a8adcc3..f3b24f6 100644 --- a/backend/app/routers/metrics.py +++ b/backend/app/routers/metrics.py @@ -42,7 +42,7 @@ router = APIRouter(prefix="/metrics", tags=["metrics"]) def coverage_summary( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> CoverageSummary: """Return a global coverage summary across all techniques.""" return get_coverage_summary(db) @@ -56,7 +56,7 @@ def coverage_summary( def coverage_by_tactic( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list[TacticCoverage]: """Return coverage breakdown grouped by tactic.""" return get_coverage_by_tactic(db) @@ -70,7 +70,7 @@ def coverage_by_tactic( def test_pipeline( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> TestPipelineCounts: """Return how many tests are in each pipeline state.""" return get_test_pipeline_counts(db) @@ -84,7 +84,7 @@ def test_pipeline( def team_activity( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list[TeamActivity]: """Return activity summary for Red and Blue teams.""" return get_team_activity(db) @@ -98,7 +98,7 @@ def team_activity( def validation_rate( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list[ValidationRate]: """Return approval and rejection rates for Red Lead and Blue Lead.""" return get_validation_rate(db) @@ -112,6 +112,6 @@ def validation_rate( def recent_tests( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list[RecentTestItem]: """Return the 10 most recently created tests.""" return get_recent_tests(db, limit=10) diff --git a/backend/app/routers/notifications.py b/backend/app/routers/notifications.py index c60b8c7..591c724 100644 --- a/backend/app/routers/notifications.py +++ b/backend/app/routers/notifications.py @@ -39,7 +39,7 @@ def list_notifications_endpoint( limit: int = Query(20, ge=1, le=100), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list[NotificationOut]: """Return paginated notifications for the current user, newest first.""" return list_notifications(db, current_user.id, offset=offset, limit=limit) @@ -53,7 +53,7 @@ def list_notifications_endpoint( def unread_count( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> UnreadCountOut: """Return the number of unread notifications for the current user.""" count = get_unread_count(db, current_user.id) return UnreadCountOut(unread_count=count) @@ -69,7 +69,7 @@ def read_notification( notification_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> NotificationOut: """Mark a single notification as read.""" with UnitOfWork(db) as uow: notif = mark_as_read(db, notification_id, current_user.id) @@ -86,7 +86,7 @@ def read_notification( def read_all_notifications( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Mark all notifications for the current user as read.""" with UnitOfWork(db) as uow: count = mark_all_as_read(db, current_user.id) diff --git a/backend/app/routers/operational_metrics.py b/backend/app/routers/operational_metrics.py index a874a2c..976725d 100644 --- a/backend/app/routers/operational_metrics.py +++ b/backend/app/routers/operational_metrics.py @@ -25,7 +25,7 @@ router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"]) def operational_metrics( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min.""" from app.services.score_cache import get_operational_metrics_cached @@ -40,7 +40,7 @@ def operational_trend( period: str = Query("90d", pattern="^(30d|90d|1y)$"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get weekly trend data for operational metrics.""" return get_operational_trend(db, period) @@ -52,6 +52,6 @@ def operational_trend( def metrics_by_team( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get metrics broken down by Red Team vs Blue Team.""" return get_metrics_by_team(db) diff --git a/backend/app/routers/osint.py b/backend/app/routers/osint.py index 9670f67..97bd06f 100644 --- a/backend/app/routers/osint.py +++ b/backend/app/routers/osint.py @@ -57,7 +57,7 @@ def list_osint_items( limit: int = Query(50, ge=1, le=200), db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> list: """List OSINT items with optional filters.""" return service_list_osint_items( db, @@ -73,7 +73,7 @@ def list_osint_items( def osint_summary( db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> dict: """Summary statistics for OSINT items.""" return get_osint_summary(db) @@ -83,7 +83,7 @@ def review_osint_item( item_id: UUID, db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> dict: """Mark an OSINT item as reviewed.""" with UnitOfWork(db) as uow: item = mark_osint_reviewed(db, str(item_id)) @@ -101,7 +101,7 @@ def trigger_technique_enrichment( technique_id: UUID, db: Session = Depends(get_db), user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Manually trigger OSINT enrichment for a single technique.""" technique = get_technique_or_raise(db, technique_id) count = enrich_technique_with_cves(db, technique) @@ -119,7 +119,7 @@ def get_technique_osint( reviewed: bool | None = Query(None), db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> list: """Get all OSINT items for a specific technique.""" items = get_osint_items_for_technique( db, diff --git a/backend/app/routers/professional_reports.py b/backend/app/routers/professional_reports.py index 1414b1c..76e0e2f 100644 --- a/backend/app/routers/professional_reports.py +++ b/backend/app/routers/professional_reports.py @@ -29,7 +29,7 @@ def generate_purple_report( format: str = Query("pdf", pattern="^(pdf|docx|html)$"), db: Session = Depends(get_db), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), -): +) -> FileResponse: """Generate a Purple Team campaign assessment report.""" filepath = report_generation_service.generate_purple_campaign_report( db, str(campaign_id), output_format=format, @@ -48,7 +48,7 @@ def generate_coverage_report( format: str = Query("pdf", pattern="^(pdf|docx|html)$"), db: Session = Depends(get_db), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), -): +) -> FileResponse: """Generate an organization-wide MITRE ATT&CK coverage report.""" filepath = report_generation_service.generate_coverage_report( db, output_format=format, @@ -67,7 +67,7 @@ def generate_executive_report( format: str = Query("pdf", pattern="^(pdf|docx|html)$"), db: Session = Depends(get_db), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), -): +) -> FileResponse: """Generate an executive security summary report.""" filepath = report_generation_service.generate_executive_summary( db, output_format=format, @@ -86,7 +86,7 @@ def generate_quarterly_report( format: str = Query("pdf", pattern="^(pdf|docx|html)$"), db: Session = Depends(get_db), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), -): +) -> FileResponse: """Generate a quarterly security summary report.""" filepath = report_generation_service.generate_quarterly_summary( db, output_format=format, @@ -106,7 +106,7 @@ def generate_technique_report( format: str = Query("pdf", pattern="^(pdf|docx|html)$"), db: Session = Depends(get_db), user: User = Depends(get_current_user), -): +) -> FileResponse: """Generate a detailed report for one MITRE technique.""" filepath = report_generation_service.generate_technique_detail_report( db, str(technique_id), output_format=format, diff --git a/backend/app/routers/reports.py b/backend/app/routers/reports.py index d065116..1494640 100644 --- a/backend/app/routers/reports.py +++ b/backend/app/routers/reports.py @@ -38,7 +38,7 @@ def coverage_summary( platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Full coverage report as JSON — technique-by-technique with test counts.""" return build_coverage_summary(db, tactic=tactic, platform=platform) @@ -49,7 +49,7 @@ def coverage_csv( platform: Optional[str] = Query(None), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> StreamingResponse: """Export coverage as a downloadable CSV.""" rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform) @@ -74,7 +74,7 @@ def test_results( date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Report of test results with optional filters.""" return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to) @@ -84,6 +84,6 @@ def remediation_status( status: Optional[str] = Query(None, description="Filter by remediation status"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Report of remediation status across all tests.""" return build_remediation_status_report(db, status=status) diff --git a/backend/app/routers/scores.py b/backend/app/routers/scores.py index 4a2ae78..85d0d2a 100644 --- a/backend/app/routers/scores.py +++ b/backend/app/routers/scores.py @@ -35,7 +35,7 @@ def score_technique( mitre_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get detailed score with breakdown for a specific technique.""" return score_technique_by_mitre_id(db, mitre_id) @@ -48,7 +48,7 @@ def score_tactic( tactic: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get average score for a tactic.""" return calculate_tactic_score(tactic, db) @@ -61,7 +61,7 @@ def score_threat_actor( actor_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get coverage score against a specific threat actor.""" return score_actor_by_id(db, actor_id) @@ -73,7 +73,7 @@ def score_threat_actor( def score_organization( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get the overall organization security score (cached for 5 min).""" from app.services.score_cache import get_organization_score_cached @@ -88,7 +88,7 @@ def score_history( period: str = Query("90d", pattern="^(30d|90d|1y)$"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get historical score data points (weekly).""" return get_score_history(db, period) @@ -100,7 +100,7 @@ def score_history( def get_scoring_config( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Get current scoring weights (admin only).""" return get_weights_dict(db) @@ -123,7 +123,7 @@ def update_scoring_config( payload: ScoringConfigUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Update scoring weights (admin only). Weights are persisted in the database and survive restarts. diff --git a/backend/app/routers/snapshots.py b/backend/app/routers/snapshots.py index b956cfd..1d1c576 100644 --- a/backend/app/routers/snapshots.py +++ b/backend/app/routers/snapshots.py @@ -52,7 +52,7 @@ def list_snapshots( limit: int = Query(50, ge=1, le=200), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """List coverage snapshots ordered by creation date (newest first).""" return list_snapshots_svc(db, offset=offset, limit=limit) @@ -66,7 +66,7 @@ def create_snapshot_endpoint( payload: SnapshotCreate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")), -): +) -> dict: """Create a manual coverage snapshot with an optional name.""" snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id) @@ -94,7 +94,7 @@ def coverage_evolution( months: int = Query(12, ge=1, le=36), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Return coverage snapshots for trend charts (last *months* months).""" return get_coverage_evolution(db, months=months) @@ -109,7 +109,7 @@ def compare_snapshots_endpoint( b: str = Query(..., description="Snapshot B ID"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Compare two snapshots showing improved, worsened, and unchanged techniques.""" try: a_id = uuid.UUID(a) @@ -129,7 +129,7 @@ def get_snapshot( snapshot_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get detailed snapshot information including per-technique states.""" return get_snapshot_detail(db, snapshot_id) @@ -143,7 +143,7 @@ def delete_snapshot_endpoint( snapshot_id: str, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Delete a snapshot (admin only).""" snapshot = get_snapshot_or_raise(db, snapshot_id) diff --git a/backend/app/routers/system.py b/backend/app/routers/system.py index 9f2922c..55b63d0 100644 --- a/backend/app/routers/system.py +++ b/backend/app/routers/system.py @@ -30,7 +30,7 @@ def trigger_mitre_sync( request: Request, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Manually trigger a MITRE ATT&CK synchronisation. **Requires** the ``admin`` role. @@ -50,7 +50,7 @@ def trigger_mitre_sync( def trigger_intel_scan( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Manually trigger a threat-intelligence scan. **Requires** the ``admin`` role. @@ -71,7 +71,7 @@ def trigger_atomic_import( request: Request, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> dict: """Trigger an import of Atomic Red Team tests as TestTemplates. **Requires** the ``admin`` role. @@ -101,7 +101,7 @@ def trigger_atomic_import( @router.get("/scheduler-status") def scheduler_status( current_user: User = Depends(require_role("admin")), -): +) -> dict: """Return the current state of the background scheduler. **Requires** the ``admin`` role. diff --git a/backend/app/routers/techniques.py b/backend/app/routers/techniques.py index d04dd84..77b8d9c 100644 --- a/backend/app/routers/techniques.py +++ b/backend/app/routers/techniques.py @@ -45,7 +45,7 @@ def list_techniques( review_required: bool | None = Query(None, description="Filter by review flag"), repo: SATechniqueRepository = Depends(get_technique_repository), current_user: User = Depends(get_current_user), -): +) -> list: """Return a lightweight list of techniques, optionally filtered.""" return repo.list_all( tactic=tactic, @@ -64,7 +64,7 @@ def get_technique( mitre_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Return full details for a single technique, including its tests and D3FEND defenses.""" return get_technique_detail(db, mitre_id) @@ -84,7 +84,7 @@ def create_technique( db: Session = Depends(get_db), repo: SATechniqueRepository = Depends(get_technique_repository), current_user: User = Depends(require_role("admin")), -): +) -> TechniqueOut: """Create a new technique manually.""" if repo.exists_by_mitre_id(payload.mitre_id): raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id) @@ -124,7 +124,7 @@ def update_technique( db: Session = Depends(get_db), repo: SATechniqueRepository = Depends(get_technique_repository), current_user: User = Depends(require_role("admin")), -): +) -> TechniqueOut: """Update one or more fields of an existing technique.""" entity = repo.find_by_mitre_id(mitre_id) if entity is None: @@ -160,7 +160,7 @@ def review_technique( db: Session = Depends(get_db), repo: SATechniqueRepository = Depends(get_technique_repository), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TechniqueOut: """Mark a technique as reviewed. Sets ``review_required`` to *False* and records the current timestamp diff --git a/backend/app/routers/test_templates.py b/backend/app/routers/test_templates.py index 4cbd5d3..c015661 100644 --- a/backend/app/routers/test_templates.py +++ b/backend/app/routers/test_templates.py @@ -78,7 +78,7 @@ def _list_templates_handler( limit: int = Query(50, ge=1, le=200), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Return a paginated, filterable list of test templates.""" return list_templates( db, @@ -102,7 +102,7 @@ def _list_templates_handler( def template_stats( db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Return catalog statistics: active, by_source, by_platform.""" return get_template_stats(db) @@ -117,7 +117,7 @@ def bulk_activate_templates( activate: bool = Query(True, description="True to activate all, False to deactivate all"), db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Set all templates to active or inactive.""" count = bulk_activate(db, activate=activate) with UnitOfWork(db) as uow: @@ -148,7 +148,7 @@ def _templates_by_technique_handler( mitre_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Return all active templates mapped to a specific MITRE technique.""" return templates_by_technique(db, mitre_id) @@ -163,7 +163,7 @@ def get_template( template_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> TestTemplateOut: """Return full details for a single test template.""" return get_template_or_raise(db, template_id) @@ -182,7 +182,7 @@ def create_template( payload: TestTemplateCreate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestTemplateOut: """Create a custom test template.""" template = create_template_svc(db, **payload.model_dump()) with UnitOfWork(db) as uow: @@ -215,7 +215,7 @@ def update_template( payload: TestTemplateCreate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestTemplateOut: """Update fields of an existing test template.""" template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True)) with UnitOfWork(db) as uow: @@ -243,7 +243,7 @@ def toggle_template_active( template_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestTemplateOut: """Toggle a template between active and inactive (is_active = not is_active).""" template = toggle_template_active_svc(db, template_id) with UnitOfWork(db) as uow: @@ -271,7 +271,7 @@ def delete_template( template_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> dict: """Soft-delete a test template by setting ``is_active=False``.""" template = get_template_or_raise(db, template_id) soft_delete_template(db, template_id) diff --git a/backend/app/routers/tests.py b/backend/app/routers/tests.py index 39c1419..09a7685 100644 --- a/backend/app/routers/tests.py +++ b/backend/app/routers/tests.py @@ -126,7 +126,7 @@ def list_tests( limit: int = Query(50, ge=1, le=200), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Return a paginated list of tests, optionally filtered by state, technique, platform or creator.""" return crud_list_tests( db, @@ -156,7 +156,7 @@ def create_test( payload: TestCreate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestOut: """Create a new test linked to an existing technique. ``created_by`` is set automatically and ``state`` defaults to *draft*. @@ -198,7 +198,7 @@ def create_test_from_template( payload: TestTemplateInstantiate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestOut: """Instantiate a real Test from an existing TestTemplate. The template's fields are copied into the new test as starting data. @@ -238,7 +238,7 @@ def get_test( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> TestOut: """Return full details for a single test, including its evidences.""" return crud_get_test_detail(db, test_id) @@ -254,7 +254,7 @@ def update_test( payload: TestUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestOut: """Update one or more fields of an existing test. Only leads or admins can update general test fields. @@ -294,7 +294,7 @@ def update_test_classification( payload: TestClassificationUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> TestOut: """Update the data classification label for a test (admin only).""" with UnitOfWork(db) as uow: test = crud_get_test_or_raise(db, test_id) @@ -324,7 +324,7 @@ def update_test_red( payload: TestRedUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_tech", "red_lead")), -): +) -> TestOut: """Red Team updates their fields (allowed in ``draft`` and ``red_executing``).""" update_data = payload.model_dump(exclude_unset=True) with UnitOfWork(db) as uow: @@ -354,7 +354,7 @@ def update_test_blue( payload: TestBlueUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), -): +) -> TestOut: """Blue Team updates their fields (allowed only in ``blue_evaluating``).""" update_data = payload.model_dump(exclude_unset=True) with UnitOfWork(db) as uow: @@ -383,7 +383,7 @@ def start_execution( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_tech", "red_lead")), -): +) -> TestOut: """Move a test from ``draft`` to ``red_executing``.""" test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: @@ -403,7 +403,7 @@ def submit_red( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_tech", "red_lead")), -): +) -> TestOut: """Red Team finalises — move from ``red_executing`` to ``blue_evaluating``.""" test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: @@ -423,7 +423,7 @@ def submit_blue( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), -): +) -> TestOut: """Blue Team finalises — move from ``blue_evaluating`` to ``in_review``.""" test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: @@ -443,7 +443,7 @@ def pause_timer( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), -): +) -> TestOut: """Pause the running timer for the current phase (red_executing or blue_evaluating).""" test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: @@ -463,7 +463,7 @@ def resume_timer( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), -): +) -> TestOut: """Resume the paused timer for the current phase.""" test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: @@ -484,7 +484,7 @@ def validate_red( payload: TestRedValidate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead")), -): +) -> TestOut: """Red Lead approves or rejects the red side of a test.""" test = crud_get_test_with_technique(db, test_id) with UnitOfWork(db) as uow: @@ -511,7 +511,7 @@ def validate_blue( payload: TestBlueValidate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("blue_lead")), -): +) -> TestOut: """Blue Lead approves or rejects the blue side of a test.""" test = crud_get_test_with_technique(db, test_id) with UnitOfWork(db) as uow: @@ -537,7 +537,7 @@ def reopen( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestOut: """Reopen a rejected test, moving it back to ``draft``.""" test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: @@ -558,7 +558,7 @@ def update_remediation( payload: TestRemediationUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), -): +) -> TestOut: """Update remediation fields on a test. When ``remediation_status`` transitions to ``'completed'``, an automatic @@ -602,7 +602,7 @@ def get_test_timeline( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Return the chronological audit-log history for a test.""" return crud_get_test_timeline(db, test_id) @@ -617,7 +617,7 @@ def get_retest_chain( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Return the full chain of retests (original + all retests) for a test.""" chain = wf_get_retest_chain(db, test_id) if not chain: diff --git a/backend/app/routers/threat_actors.py b/backend/app/routers/threat_actors.py index 733112a..ff29314 100644 --- a/backend/app/routers/threat_actors.py +++ b/backend/app/routers/threat_actors.py @@ -36,7 +36,7 @@ def list_threat_actors( limit: int = Query(50, ge=1, le=200), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """List threat actors with optional filters and pagination. **Requires** authentication (any role). @@ -58,7 +58,7 @@ def get_threat_actor( actor_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Get detailed info about a threat actor including techniques. **Requires** authentication (any role). @@ -71,7 +71,7 @@ def get_threat_actor_coverage( actor_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> dict: """Calculate coverage percentage against a specific threat actor. **Requires** authentication (any role). @@ -87,7 +87,7 @@ def get_threat_actor_gaps( actor_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), -): +) -> list: """Identify techniques of this actor that are NOT fully validated. **Requires** authentication (any role). diff --git a/backend/app/routers/users.py b/backend/app/routers/users.py index 4119094..2d5be14 100644 --- a/backend/app/routers/users.py +++ b/backend/app/routers/users.py @@ -30,7 +30,7 @@ router = APIRouter(prefix="/users", tags=["users"]) def list_users_route( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> list[UserOut]: """Return a list of all users. **Requires admin role.**""" return list_users(db) @@ -45,7 +45,7 @@ def create_user_route( payload: UserCreate, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> UserOut: """Create a new user. **Requires admin role.**""" with UnitOfWork(db) as uow: user = create_user( @@ -79,7 +79,7 @@ def get_user( user_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> UserOut: """Return a single user by ID. **Requires admin role.**""" return get_user_or_raise(db, user_id) @@ -95,7 +95,7 @@ def update_user_route( payload: UserUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), -): +) -> UserOut: """Update one or more fields of an existing user. **Requires admin role.**""" update_data = payload.model_dump(exclude_unset=True) with UnitOfWork(db) as uow: diff --git a/backend/app/routers/worklogs.py b/backend/app/routers/worklogs.py index c27ecf8..0ed1e92 100644 --- a/backend/app/routers/worklogs.py +++ b/backend/app/routers/worklogs.py @@ -56,7 +56,7 @@ def create( body: WorklogCreate, db: Session = Depends(get_db), user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), -): +) -> WorklogOut: """Create a manually-logged worklog entry.""" with UnitOfWork(db) as uow: wl = worklog_service.create_worklog( @@ -82,7 +82,7 @@ def list_all( user_id: Optional[UUID] = None, db: Session = Depends(get_db), _user: User = Depends(get_current_user), -): +) -> list[WorklogOut]: """List worklogs with optional filters.""" return worklog_service.list_worklogs( db, @@ -97,7 +97,7 @@ def get_one( worklog_id: UUID, db: Session = Depends(get_db), _user: User = Depends(get_current_user), -): +) -> WorklogOut: """Get a single worklog by ID.""" return worklog_service.get_worklog_or_raise(db, worklog_id) @@ -107,7 +107,7 @@ def verify_integrity( worklog_id: UUID, db: Session = Depends(get_db), _user: User = Depends(get_current_user), -): +) -> dict: """Check whether a worklog's integrity hash is still valid.""" wl = worklog_service.get_worklog_or_raise(db, worklog_id) return { diff --git a/backend/app/schemas/test.py b/backend/app/schemas/test.py index 8da1351..9e7010f 100644 --- a/backend/app/schemas/test.py +++ b/backend/app/schemas/test.py @@ -167,7 +167,7 @@ class TestOut(BaseModel): model_config = ConfigDict(from_attributes=True) @classmethod - def model_validate(cls, obj, **kwargs): + def model_validate(cls, obj: object, **kwargs: object) -> "TestOut": """Override to populate technique fields from the relationship.""" if hasattr(obj, "technique") and obj.technique is not None: obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id diff --git a/backend/app/seed_demo.py b/backend/app/seed_demo.py index 82d7968..4d624cb 100644 --- a/backend/app/seed_demo.py +++ b/backend/app/seed_demo.py @@ -16,6 +16,8 @@ import random import uuid from datetime import datetime, timedelta +from sqlalchemy.orm import Session + from app.auth import hash_password from app.database import SessionLocal from app.models.audit import AuditLog @@ -102,7 +104,7 @@ TEMPLATE_NAMES = [ # --------------------------------------------------------------------------- -def _cleanup_demo_data(db) -> None: +def _cleanup_demo_data(db: Session) -> None: """Remove all previously seeded demo data.""" # Delete in order to respect FK constraints demo_users = db.query(User).filter(User.username.like(f"{DEMO_PREFIX}%")).all() @@ -154,7 +156,7 @@ def _cleanup_demo_data(db) -> None: # --------------------------------------------------------------------------- -def _seed_users(db) -> list[User]: +def _seed_users(db: Session) -> list[User]: """Create 5 users per role (25 total).""" users = [] for role in ROLES: @@ -173,7 +175,7 @@ def _seed_users(db) -> list[User]: return users -def _seed_technique_statuses(db, count: int = 50) -> list[Technique]: +def _seed_technique_statuses(db: Session, count: int = 50) -> list[Technique]: """Set varied statuses on up to *count* techniques.""" techniques = db.query(Technique).limit(count).all() if not techniques: @@ -192,7 +194,7 @@ def _seed_technique_statuses(db, count: int = 50) -> list[Technique]: return techniques -def _seed_tests(db, users: list[User], techniques: list[Technique], count: int = 100) -> list[Test]: +def _seed_tests(db: Session, users: list[User], techniques: list[Technique], count: int = 100) -> list[Test]: """Create *count* tests in various pipeline states.""" if not techniques: logger.warning("No techniques available — skipping test seeding.") @@ -267,7 +269,7 @@ def _seed_tests(db, users: list[User], techniques: list[Technique], count: int = return tests -def _seed_evidences(db, tests: list[Test], users: list[User], count: int = 50) -> list[Evidence]: +def _seed_evidences(db: Session, tests: list[Test], users: list[User], count: int = 50) -> list[Evidence]: """Create *count* dummy evidence records.""" if not tests: return [] @@ -305,7 +307,7 @@ def _seed_evidences(db, tests: list[Test], users: list[User], count: int = 50) - return evidences -def _seed_audit_logs(db, users: list[User], count: int = 20) -> None: +def _seed_audit_logs(db: Session, users: list[User], count: int = 20) -> None: """Create *count* varied audit log entries.""" for i in range(count): user = random.choice(users) @@ -323,7 +325,7 @@ def _seed_audit_logs(db, users: list[User], count: int = 20) -> None: logger.info("Created %d demo audit logs.", count) -def _seed_notifications(db, users: list[User], count: int = 30) -> None: +def _seed_notifications(db: Session, users: list[User], count: int = 30) -> None: """Create *count* notifications spread across demo users.""" for i in range(count): user = random.choice(users) @@ -344,7 +346,7 @@ def _seed_notifications(db, users: list[User], count: int = 30) -> None: logger.info("Created %d demo notifications.", count) -def _seed_templates(db, techniques: list[Technique], count: int = 10) -> None: +def _seed_templates(db: Session, techniques: list[Technique], count: int = 10) -> None: """Create *count* manual demo templates.""" if not techniques: return diff --git a/backend/app/services/audit_service.py b/backend/app/services/audit_service.py index bf1c3a6..f7d578e 100644 --- a/backend/app/services/audit_service.py +++ b/backend/app/services/audit_service.py @@ -4,6 +4,7 @@ from __future__ import annotations import hashlib from datetime import datetime, timezone +from uuid import UUID from sqlalchemy.orm import Session @@ -35,7 +36,7 @@ def verify_audit_integrity(entry: AuditLog) -> bool: def log_action( db: Session, - user_id, + user_id: UUID | None, action: str, entity_type: str | None = None, entity_id: str | None = None, diff --git a/backend/app/services/campaign_crud_service.py b/backend/app/services/campaign_crud_service.py index 6c4cfdb..f6e8be4 100644 --- a/backend/app/services/campaign_crud_service.py +++ b/backend/app/services/campaign_crud_service.py @@ -192,7 +192,7 @@ def update_campaign( *, updater_id: uuid.UUID, updater_role: str, - **fields, + **fields: object, ) -> dict: """Update a campaign. Only allowed in draft or active state. diff --git a/backend/app/services/d3fend_import_service.py b/backend/app/services/d3fend_import_service.py index 28b0c45..d5a5d0e 100644 --- a/backend/app/services/d3fend_import_service.py +++ b/backend/app/services/d3fend_import_service.py @@ -8,6 +8,7 @@ Uses the D3FEND public API: import logging from typing import Any +from uuid import UUID import httpx from sqlalchemy.orm import Session @@ -26,7 +27,7 @@ D3FEND_TACTICS = ["Detect", "Harden", "Isolate", "Deceive", "Evict", "Model"] # ── Import all D3FEND techniques ───────────────────────────────────── -def _to_str(v: Any) -> str: +def _to_str(v: Any) -> str: # noqa: ANN401 """Coerce an RDF value (str, dict with @value, or list) to a plain string.""" if isinstance(v, dict): return v.get("@value", str(v)) @@ -432,7 +433,7 @@ def sync(db: Session) -> dict: return summary -def get_defenses_for_technique(db: Session, technique_id) -> list[dict]: +def get_defenses_for_technique(db: Session, technique_id: UUID) -> list[dict]: """Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" mappings = ( db.query(DefensiveTechniqueMapping) diff --git a/backend/app/services/detection_rule_service.py b/backend/app/services/detection_rule_service.py index d73d643..9b859f4 100644 --- a/backend/app/services/detection_rule_service.py +++ b/backend/app/services/detection_rule_service.py @@ -10,6 +10,7 @@ from __future__ import annotations from datetime import datetime from typing import Any +from uuid import UUID from sqlalchemy.orm import Session @@ -258,11 +259,11 @@ def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]: def evaluate_rule( db: Session, *, - test_id: Any, - detection_rule_id: Any, + test_id: UUID, + detection_rule_id: UUID, triggered: bool | None, notes: str | None, - evaluator_id: Any, + evaluator_id: UUID, ) -> dict[str, Any]: """Save or update the evaluation result for a detection rule on a test. diff --git a/backend/app/services/heatmap_service.py b/backend/app/services/heatmap_service.py index b194ba4..fd9dca3 100644 --- a/backend/app/services/heatmap_service.py +++ b/backend/app/services/heatmap_service.py @@ -10,9 +10,10 @@ no ``db.commit()``. from __future__ import annotations import json +from collections.abc import Callable from sqlalchemy import func, or_ -from sqlalchemy.orm import Session +from sqlalchemy.orm import Query, Session from app.domain.errors import BusinessRuleViolation, EntityNotFoundError from app.models.campaign import Campaign, CampaignTest @@ -92,11 +93,11 @@ def _build_layer_skeleton( def _apply_filters( - query, - model, + query: Query, # type: ignore[type-arg] + model: type, platforms: list[str] | None = None, tactics: list[str] | None = None, -): +) -> Query: # type: ignore[type-arg] """Apply common platform and tactic filters to a technique query.""" if platforms: platform_filters = [ @@ -470,7 +471,7 @@ class _LayerRegistry: self._simple: dict[str, object] = {} self._with_id: dict[str, object] = {} - def register(self, name: str, builder, *, requires_id: bool = False) -> None: + def register(self, name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None: target = self._with_id if requires_id else self._simple target[name] = builder @@ -513,7 +514,7 @@ LAYER_REGISTRY.register("campaign", build_campaign_layer, requires_id=True) SUPPORTED_LAYER_TYPES = LAYER_REGISTRY.supported_types # snapshot of built-in types -def register_layer(name: str, builder, *, requires_id: bool = False) -> None: +def register_layer(name: str, builder: Callable[..., dict], *, requires_id: bool = False) -> None: """Public API to register a new heatmap layer type at import time.""" LAYER_REGISTRY.register(name, builder, requires_id=requires_id) diff --git a/backend/app/services/jira_service.py b/backend/app/services/jira_service.py index 9bc760e..709068b 100644 --- a/backend/app/services/jira_service.py +++ b/backend/app/services/jira_service.py @@ -2,7 +2,7 @@ import logging from datetime import datetime -from typing import Optional +from typing import Any, Optional from uuid import UUID from sqlalchemy.orm import Session @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) _jira_client = None -def get_jira_client(): +def get_jira_client() -> Any: # noqa: ANN401 # atlassian.Jira imported lazily from optional dep """Return a lazily-initialised Jira client, or raise if disabled.""" global _jira_client if not settings.JIRA_ENABLED: diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 5be53e9..8b34b51 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import Session from app.domain.errors import EntityNotFoundError from app.models.notification import Notification +from app.models.test import Test from app.models.user import User # --------------------------------------------------------------------------- @@ -157,7 +158,7 @@ def cleanup_old_notifications(db: Session, days: int = 90) -> int: # --------------------------------------------------------------------------- -def notify_test_state_change(db: Session, test, new_state: str) -> None: +def notify_test_state_change(db: Session, test: Test, new_state: str) -> None: """Dispatch notifications based on a test's new state. Called by the workflow service after each state transition. diff --git a/backend/app/services/score_cache.py b/backend/app/services/score_cache.py index 7bfd296..d2d2791 100644 --- a/backend/app/services/score_cache.py +++ b/backend/app/services/score_cache.py @@ -10,12 +10,14 @@ stale data does not persist longer than ``CACHE_TTL`` seconds. import time from typing import Any, Optional +from sqlalchemy.orm import Session + CACHE_TTL = 300 # 5 minutes _cache: dict[str, dict[str, Any]] = {} -def get(key: str) -> Optional[Any]: +def get(key: str) -> Optional[Any]: # noqa: ANN401 # generic cache returns whatever was stored """Return cached value if present and not expired, else None.""" entry = _cache.get(key) if entry is None: @@ -26,7 +28,7 @@ def get(key: str) -> Optional[Any]: return entry["data"] -def put(key: str, data: Any) -> None: +def put(key: str, data: Any) -> None: # noqa: ANN401 # generic cache accepts any serialisable value """Store *data* under *key* with the current timestamp.""" _cache[key] = {"data": data, "ts": time.time()} @@ -42,7 +44,7 @@ def invalidate(key: Optional[str] = None) -> None: # ── High-level helpers ──────────────────────────────────────────────── -def get_organization_score_cached(db): +def get_organization_score_cached(db: Session) -> dict: """Cached wrapper around ``calculate_organization_score``.""" from app.services.scoring_service import calculate_organization_score @@ -55,7 +57,7 @@ def get_organization_score_cached(db): return result -def get_operational_metrics_cached(db): +def get_operational_metrics_cached(db: Session) -> dict: """Cached wrapper around operational metrics (MTTD, MTTR, efficacy).""" from app.services.operational_metrics_service import ( calculate_alert_fidelity, diff --git a/backend/app/services/tempo_service.py b/backend/app/services/tempo_service.py index 9569787..d8dfa52 100644 --- a/backend/app/services/tempo_service.py +++ b/backend/app/services/tempo_service.py @@ -1,18 +1,20 @@ """Tempo time-tracking integration service.""" import logging -from typing import Optional +from typing import Any, Optional from sqlalchemy.orm import Session from app.config import settings from app.domain.exceptions import InvalidOperationError from app.models.jira_link import JiraLink, JiraLinkEntityType +from app.models.test import Test +from app.models.user import User logger = logging.getLogger(__name__) -def get_tempo_client(): +def get_tempo_client() -> Any: # noqa: ANN401 # tempoapiclient.Tempo imported lazily from optional dep """Return a Tempo API client, or raise if disabled.""" if not settings.TEMPO_ENABLED: raise InvalidOperationError("Tempo integration is not enabled") @@ -52,8 +54,8 @@ def log_worklog( def auto_log_test_worklog( db: Session, - test, - user, + test: Test, + user: User, activity_type: str, ) -> Optional[dict]: """If the test has a Jira link, log time to Tempo automatically. @@ -97,7 +99,7 @@ def auto_log_test_worklog( return None -def _calculate_duration(test, activity_type: str) -> int: +def _calculate_duration(test: Test, activity_type: str) -> int: """Calculate real duration in seconds from the phase timing fields. Uses the actual start/end timestamps recorded by the workflow buttons, diff --git a/backend/app/services/test_crud_service.py b/backend/app/services/test_crud_service.py index 2bf2e13..90dcfaa 100644 --- a/backend/app/services/test_crud_service.py +++ b/backend/app/services/test_crud_service.py @@ -63,7 +63,7 @@ def create_test( *, technique_id: uuid.UUID, creator_id: uuid.UUID, - **fields: Any, + **fields: object, ) -> Test: """Create a new test linked to an existing technique. @@ -176,7 +176,7 @@ def update_test( *, updater_id: uuid.UUID, updater_role: str, - **fields: Any, + **fields: object, ) -> Test: """Update general test fields (draft or rejected only). @@ -204,7 +204,7 @@ def update_test( return test -def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test: +def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test: """Update Red Team fields (draft or red_executing only). Raises BusinessRuleViolation if state not in (draft, red_executing). @@ -226,7 +226,7 @@ def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test: return test -def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test: +def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test: """Update Blue Team fields (blue_evaluating only). Raises BusinessRuleViolation if state is not blue_evaluating. diff --git a/backend/app/services/test_workflow_service.py b/backend/app/services/test_workflow_service.py index 0c3bbef..b16c130 100644 --- a/backend/app/services/test_workflow_service.py +++ b/backend/app/services/test_workflow_service.py @@ -13,6 +13,7 @@ session via the Unit of Work pattern. """ import logging +import uuid from datetime import datetime from sqlalchemy.orm import Session @@ -530,7 +531,7 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | return retest -def get_retest_chain(db: Session, test_id) -> list[Test]: +def get_retest_chain(db: Session, test_id: uuid.UUID) -> list[Test]: """Return the full chain of retests for a given test. Includes the original test and all subsequent retests, ordered diff --git a/backend/ruff.toml b/backend/ruff.toml index c7b8bd6..612adf7 100644 --- a/backend/ruff.toml +++ b/backend/ruff.toml @@ -3,15 +3,17 @@ line-length = 120 [lint] # PEP8 compliance rules enforced: -# E/W — pycodestyle (core PEP8 style and warnings) -# F — pyflakes (unused imports, undefined names) -# I — isort (import ordering per PEP8 convention) -# N — pep8-naming (class/function/variable naming conventions) -select = ["E", "W", "F", "I", "N"] +# E/W — pycodestyle (core PEP8 style and warnings) +# F — pyflakes (unused imports, undefined names) +# I — isort (import ordering per PEP8 convention) +# N — pep8-naming (class/function/variable naming conventions) +# ANN — flake8-annotations (type hint enforcement) +select = ["E", "W", "F", "I", "N", "ANN"] ignore = [ # SQLAlchemy filter syntax requires `== True` / `== False` comparisons "E712", + # ANN101/ANN102 (self/cls type annotations) removed from ruff — not needed ] [lint.per-file-ignores]