Files
Aegis/backend/app/main.py
T
kitos 394d5d9056 refactor(types): add comprehensive type annotations across backend Python codebase
Enable ANN rules in ruff.toml (flake8-annotations) and resolve all 221 violations:

ANN201/ANN202 — return types on 168 public/private functions:
- All 28 FastAPI routers: endpoints annotated with dict/list/specific schema/
  StreamingResponse/FileResponse/JSONResponse as appropriate
- main.py: lifespan→AsyncGenerator[None,None], exception handlers→JSONResponse
- database.py: get_db→Generator[Session,None,None], proxy methods→correct types
- middleware/request_context.py: dispatch→Response with Callable call_next type

ANN001/ANN002/ANN003 — 32 missing argument types:
- seed_demo.py: all db parameters typed as Session
- domain/unit_of_work.py: __aexit__ exc_type/exc_val/exc_tb typed with TracebackType
- services: audit_service user_id→UUID|None, heatmap_service query/model/builder,
  notification_service test→Test, tempo_service test→Test/user→User,
  test_workflow_service test_id→UUID, campaign_crud **fields→object,
  test_crud **fields→object (4 sites)

ANN401 — 16 Any usages resolved:
- Domain entities (campaign/technique/threat_actor/test_entity): replaced Any with
  actual ORM types via TYPE_CHECKING guards to avoid circular imports
- detection_rule_service: test_id/detection_rule_id/evaluator_id→UUID
- score_cache: kept Any with # noqa: ANN401 (genuinely generic cache)
- jira_service/tempo_service: kept Any with # noqa: ANN401 (lazy optional deps)
- d3fend_import_service: _to_str(v: Any) kept with # noqa: ANN401

ANN204/ANN205/ANN206 — special/static/class methods:
- database.py proxy __call__/__getattr__: *args: object/**kwargs: object
- schemas/test.py model_validate: obj→object, **kwargs→object
- sa_technique_repository._int_type→type

All 439 unit tests pass. ruff check app/ → All checks passed!

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-09 17:04:51 +02:00

189 lines
8.3 KiB
Python

import logging
import math
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from sqlalchemy.exc import SQLAlchemyError

from app.config import settings as _settings
from app.domain.errors import DomainError
from app.jobs.mitre_sync_job import scheduler, start_scheduler
from app.limiter import limiter
from app.logging_config import setup_logging
from app.middleware.error_handler import domain_exception_handler
from app.middleware.request_context import RequestContextMiddleware
from app.routers import advanced_metrics as advanced_metrics_router
from app.routers import analytics as analytics_router
from app.routers import audit as audit_router
from app.routers import auth as auth_router
from app.routers import campaigns as campaigns_router
from app.routers import compliance as compliance_router
from app.routers import d3fend as d3fend_router
from app.routers import data_sources as data_sources_router
from app.routers import detection_rules as detection_rules_router
from app.routers import evidence as evidence_router
from app.routers import heatmap as heatmap_router
from app.routers import jira as jira_router
from app.routers import metrics as metrics_router
from app.routers import notifications as notifications_router
from app.routers import operational_metrics as operational_metrics_router
from app.routers import osint as osint_router
from app.routers import professional_reports as professional_reports_router
from app.routers import reports as reports_router
from app.routers import scores as scores_router
from app.routers import snapshots as snapshots_router
from app.routers import system as system_router
from app.routers import techniques as techniques_router
from app.routers import test_templates as test_templates_router
from app.routers import tests as tests_router
from app.routers import threat_actors as threat_actors_router
from app.routers import users as users_router
from app.routers import worklogs as worklogs_router
from app.storage import ensure_bucket_exists
# Set up structured logging right after the imports, before the application
# object, middleware and handlers below are created.
setup_logging()

# ── Environment detection ─────────────────────────────────────────────────
# Production mode is on only when AEGIS_ENV equals "production" (compared
# case-insensitively); an unset variable counts as non-production.
_IS_PRODUCTION = os.getenv("AEGIS_ENV", "").lower() == "production"
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Startup / shutdown logic for the application.

    On startup, ``ensure_bucket_exists()`` runs first, then
    ``start_scheduler()``. Control is then yielded to FastAPI for the serving
    period. ``scheduler.shutdown(wait=False)`` sits in a ``finally`` clause so
    it runs when the serving period ends for any reason — a clean exit, an
    exception thrown into the context manager, or cancellation.

    Args:
        app: The FastAPI application (not used here; the lifespan protocol
            passes it in).

    Yields:
        None, once startup has completed.
    """
    ensure_bucket_exists()
    start_scheduler()
    try:
        yield
    finally:
        # Graceful shutdown of the background scheduler.
        # NOTE(review): assuming an APScheduler-style scheduler, wait=False
        # means shutdown does not block on jobs that are still running —
        # confirm against app.jobs.mitre_sync_job.
        scheduler.shutdown(wait=False)
# ── API documentation endpoints ───────────────────────────────────────────
# Swagger UI (/docs), ReDoc (/redoc) and the raw schema (/openapi.json) are
# only served outside production, so the API surface is not advertised there.
app = FastAPI(
    title="Attack Coverage Platform",
    docs_url="/docs" if not _IS_PRODUCTION else None,
    redoc_url="/redoc" if not _IS_PRODUCTION else None,
    openapi_url="/openapi.json" if not _IS_PRODUCTION else None,
    lifespan=lifespan,
)
# ── Rate Limiter ──────────────────────────────────────────────────────────
# Attach the shared slowapi Limiter to app.state (the attribute slowapi's
# documented setup reads) and let slowapi's stock handler answer requests
# that raise RateLimitExceeded.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Project-local request-context middleware (see app.middleware.request_context).
# NOTE(review): Starlette makes the most recently added middleware the
# outermost layer, so the CORSMiddleware registered further down wraps this
# one — keep the registration order in mind when editing this section.
app.add_middleware(RequestContextMiddleware)
# ── Domain exception → HTTP mapping ──────────────────────────────────────
# DomainError is rendered by domain_exception_handler; presumably subclasses
# are covered too, since Starlette resolves handlers along the exception's MRO.
app.add_exception_handler(DomainError, domain_exception_handler)
# ── CORS ──────────────────────────────────────────────────────────────────
# CORS_ORIGINS is a comma-separated string: each entry is trimmed of
# surrounding whitespace and empty entries (e.g. from a trailing comma) are
# dropped.
_cors_origins: list[str] = [
    origin
    for origin in (raw.strip() for raw in _settings.CORS_ORIGINS.split(","))
    if origin
]
# Credentials are allowed; methods and request headers are limited to the
# explicit allow-lists below.
# NOTE(review): if CORS_ORIGINS ever contains "*", combining it with
# allow_credentials=True may let any origin make credentialed requests —
# confirm deployments always list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_headers=["Authorization", "Content-Type"],
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
)
# ── Routers ──────────────────────────────────────────────────────────────
# Every feature router is mounted under /api/v1. The sequence below preserves
# the original registration order, which FastAPI uses when matching routes.
for _router_module in (
    auth_router,
    techniques_router,
    tests_router,
    evidence_router,
    test_templates_router,
    system_router,
    metrics_router,
    users_router,
    audit_router,
    notifications_router,
    reports_router,
    data_sources_router,
    threat_actors_router,
    d3fend_router,
    detection_rules_router,
    campaigns_router,
    heatmap_router,
    scores_router,
    operational_metrics_router,
    compliance_router,
    snapshots_router,
    jira_router,
    worklogs_router,
    professional_reports_router,
    analytics_router,
    advanced_metrics_router,
    osint_router,
):
    app.include_router(_router_module.router, prefix="/api/v1")
# Keep the loop variable out of the module namespace.
del _router_module
@app.get("/health", include_in_schema=False)
def health() -> dict[str, str]:
    """Liveness check: answer HTTP 200 with the fixed body ``{"status": "ok"}``.

    Deliberately reports no service metadata and is excluded from the OpenAPI
    schema. The original author notes that Nginx restricts this route to
    internal networks (see ``frontend/nginx.conf``).
    """
    return {"status": "ok"}
# ── Exception Handlers ────────────────────────────────────────────────────
def _json_safe(value: object) -> object:
    """Return a version of *value* that ``JSONResponse`` is able to encode.

    ``None``, ``str``, ``bool``, ``int`` and finite ``float`` values pass
    through unchanged. Dicts are rebuilt with ``str`` keys, and lists, tuples
    and sets become lists, with their contents converted recursively. Anything
    else is replaced by its ``str()``. That includes bytes, arbitrary objects
    and NaN/±inf, which Starlette's ``JSONResponse`` refuses because it
    encodes with ``allow_nan=False``.
    """
    if value is None or isinstance(value, (str, bool, int)):
        return value
    if isinstance(value, float):
        return value if math.isfinite(value) else str(value)
    if isinstance(value, dict):
        return {str(key): _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple, set, frozenset)):
        return [_json_safe(item) for item in value]
    return str(value)


def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
    """Return validation errors safe for JSON (no raw exception objects).

    Each error dict from ``exc.errors()`` is shallow-copied, so the
    exception's own data is left untouched.

    * Every value under ``ctx`` is stringified, because ``ctx`` may carry raw
      exception objects.
    * The echoed ``input`` goes through ``_json_safe``, because it can hold
      arbitrary request data that the JSON encoder would reject. Without
      this, the 422 response would fail to render and become a 500.

    Args:
        exc: The request-validation exception raised by FastAPI.

    Returns:
        A list of JSON-serializable error dicts, in the original order.
    """
    serialized: list[dict] = []
    for err in exc.errors():
        item = dict(err)
        ctx = item.get("ctx")
        if isinstance(ctx, dict):
            item["ctx"] = {key: str(value) for key, value in ctx.items()}
        if "input" in item:
            item["input"] = _json_safe(item["input"])
        serialized.append(item)
    return serialized
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Render request-validation failures in the API's standard error envelope.

    Args:
        request: The request being served (unused; required by the handler
            signature).
        exc: The validation exception raised by FastAPI.

    Returns:
        A 422 ``JSONResponse`` with ``detail``, ``code`` and the serialized
        ``errors`` list.

    NOTE(review): recent Starlette releases deprecate
    ``HTTP_422_UNPROCESSABLE_ENTITY`` in favour of
    ``HTTP_422_UNPROCESSABLE_CONTENT`` — check the pinned version.
    """
    body = {
        "detail": "Validation error",
        "code": "VALIDATION_ERROR",
        "errors": _serialize_validation_errors(exc),
    }
    return JSONResponse(content=body, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@app.exception_handler(SQLAlchemyError)
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError) -> JSONResponse:
    """Translate an uncaught SQLAlchemy error into a generic HTTP 500 response.

    The client only sees a fixed message and the ``DATABASE_ERROR`` code. The
    exception itself, including its traceback, goes to the log.

    Args:
        request: The request being served (unused; required by the handler
            signature).
        exc: The database exception raised while handling the request.

    Returns:
        A 500 ``JSONResponse`` with ``detail`` and ``code`` keys.
    """
    # Pass the exception instance as exc_info so the traceback is recorded
    # whether or not this runs inside an ``except`` block. The %-style args
    # defer formatting to the logging framework.
    # NOTE(review): SQLAlchemy error strings may embed the SQL statement and
    # bound parameters; consider create_engine(hide_parameters=True) if those
    # can be sensitive.
    logging.error("Database error: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={
            "detail": "Database error occurred",
            "code": "DATABASE_ERROR",
        },
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all: turn any otherwise-unhandled exception into a generic HTTP 500.

    The client receives only a fixed message and the ``INTERNAL_ERROR`` code,
    so no exception details appear in the response. The details are logged.

    Args:
        request: The request being served (unused; required by the handler
            signature).
        exc: The unhandled exception.

    Returns:
        A 500 ``JSONResponse`` with ``detail`` and ``code`` keys.
    """
    # str(exc) alone can be empty (e.g. ``RuntimeError()``) and never includes
    # the traceback, so the exception instance is attached via exc_info.
    # NOTE(review): Starlette dispatches ``Exception`` handlers from
    # ServerErrorMiddleware, which re-raises after responding, so the server
    # may log this traceback a second time — confirm and dedupe if noisy.
    logging.error("Unhandled exception: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={
            "detail": "An internal server error occurred",
            "code": "INTERNAL_ERROR",
        },
    )