feat(phase-34): resolve blocking tech debt — Redis, domain exceptions, indexes, CI
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled

Foundational changes required before any new feature work can begin.

- 0.1 Redis infrastructure: add redis:7-alpine to docker-compose dev and prod,
  REDIS_URL config, singleton client in app/infrastructure/redis_client.py
- 0.2 Token blacklist on Redis SEC-001: replace in-memory dict with Redis SETEX
  keyed by jti, auto-expiring TTL derived from token exp
- 0.3 Database indexes SR-006: Alembic migration b019 with 5 composite indexes
  for scoring, MTTD/MTTR, remediation, and notification queries
- 0.4 Domain exceptions TD-003: app/domain/exceptions.py with typed errors,
  error_handler middleware mapping them to HTTP, services decoupled from FastAPI
- 0.5 Fix silenced exceptions TD-007: replace 4 `except Exception: pass` blocks
  in test_workflow_service with logger.warning(..., exc_info=True)
- 0.6 CI pipeline TD-009: GitHub Actions workflow with Postgres and Redis
  service containers, ruff lint, pytest; ruff.toml for baseline config
This commit is contained in:
2026-02-17 15:43:05 +01:00
parent 6a327f6b51
commit 6d18a5417d
21 changed files with 464 additions and 124 deletions

View File

@@ -4,12 +4,12 @@ Security utilities: password hashing and JWT token management.
This module provides pure functions for:
- Hashing and verifying passwords using bcrypt via passlib.
- Creating JWT access tokens using python-jose.
- Managing an in-memory token blacklist for revocation.
- Managing a Redis-backed token blacklist for revocation.
No endpoints are defined here.
"""
import threading
import logging
import uuid as _uuid
from datetime import datetime, timedelta, timezone
@@ -18,6 +18,8 @@ from passlib.context import CryptContext
from app.config import settings
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Password hashing
# ---------------------------------------------------------------------------
@@ -58,36 +60,43 @@ def create_access_token(data: dict) -> str:
# ---------------------------------------------------------------------------
# Token blacklist (in-memory)
# Token blacklist (Redis-backed)
# ---------------------------------------------------------------------------
# Stores (jti, expiry_timestamp) tuples. Entries are automatically purged
# once they are past their original expiry (the token would be invalid
# anyway at that point). Thread-safe via a simple lock.
# Each revoked token's ``jti`` is stored in Redis with a TTL equal to the
# token's remaining lifetime. This means entries auto-expire exactly when
# the token would have become invalid anyway — no manual cleanup needed.
#
# For multi-worker / multi-process deployments, consider replacing this
# with a shared store like Redis.
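# The blacklist lives in Redis rather than in process memory, so revoked
# tokens stay revoked across backend restarts/deploys and are shared by
# all workers (as long as the Redis instance itself keeps its data).
# NOTE: both helpers below swallow Redis errors and only log a warning;
# if Redis is unreachable the blacklist check returns False (fails open).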
# ---------------------------------------------------------------------------
_blacklist: dict[str, float] = {} # jti → expiry epoch
_blacklist_lock = threading.Lock()
_BLACKLIST_PREFIX = "blacklist:"
def blacklist_token(jti: str, exp: float) -> None:
"""Add *jti* to the blacklist until it naturally expires at *exp*."""
with _blacklist_lock:
_blacklist[jti] = exp
_cleanup_blacklist()
"""Add *jti* to the Redis blacklist with a TTL derived from *exp*.
*exp* is the token's ``exp`` claim (epoch timestamp). The TTL is set
to ``exp - now`` so the key vanishes when the token would have expired
naturally.
"""
from app.infrastructure.redis_client import get_redis
ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
try:
r = get_redis()
r.setex(f"{_BLACKLIST_PREFIX}{jti}", ttl, "1")
except Exception:
logger.warning("Failed to blacklist token %s in Redis", jti, exc_info=True)
def is_token_blacklisted(jti: str) -> bool:
"""Return ``True`` if *jti* has been revoked."""
with _blacklist_lock:
return jti in _blacklist
"""Return ``True`` if *jti* has been revoked (exists in Redis)."""
from app.infrastructure.redis_client import get_redis
def _cleanup_blacklist() -> None:
"""Remove entries whose tokens have already expired (caller holds lock)."""
now = datetime.now(timezone.utc).timestamp()
expired = [k for k, exp in _blacklist.items() if exp < now]
for k in expired:
del _blacklist[k]
try:
r = get_redis()
return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0
except Exception:
logger.warning("Failed to check blacklist for %s in Redis", jti, exc_info=True)
return False

View File

@@ -24,6 +24,9 @@ class Settings(BaseSettings):
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env
# ── Redis ─────────────────────────────────────────────────────────
REDIS_URL: str = "redis://redis:6379/0"
# ── CORS ─────────────────────────────────────────────────────────
# Comma-separated list of allowed origins, or a JSON array.
# In dev this defaults to common local ports; in production set it

View File

View File

@@ -0,0 +1,67 @@
"""Domain exceptions for Aegis business logic.
These exceptions are raised by service-layer code and automatically
mapped to HTTP responses by the error-handler middleware registered
in ``app.main``. This keeps the service layer free from any HTTP
or framework coupling.
"""
class DomainException(Exception):
    """Root of the domain-layer exception hierarchy.

    Carries a human-readable ``message`` and a short machine-readable
    ``code`` (``"DOMAIN_ERROR"`` unless the caller supplies another value).
    """

    def __init__(self, message: str, code: str = "DOMAIN_ERROR"):
        # Hand the text to ``Exception`` so ``str(exc)`` and ``exc.args``
        # carry it, then expose both fields as plain attributes.
        super().__init__(message)
        self.code = code
        self.message = message
class EntityNotFoundError(DomainException):
    """Signals that a lookup for a specific entity came back empty.

    Attributes:
        entity: Name of the entity kind that was looked up.
        identifier: The identifier that did not match anything.
    """

    def __init__(self, entity: str, identifier: str):
        self.entity = entity
        self.identifier = identifier
        description = f"{entity} not found: {identifier}"
        super().__init__(description, code="NOT_FOUND")
class DuplicateEntityError(DomainException):
    """Raised when creating an entity that already exists.

    The error code is ``"DUPLICATE"``.

    Attributes:
        entity: Name of the entity kind that collided.
        field: Name of the field whose value is already taken.
        value: The conflicting value.
    """

    def __init__(self, entity: str, field: str, value: str):
        super().__init__(
            f"{entity} with {field}='{value}' already exists",
            "DUPLICATE",
        )
        # Keep the structured parts available to handlers and callers,
        # mirroring what the sibling exceptions do with their arguments.
        self.entity = entity
        self.field = field
        self.value = value
class InvalidTransitionError(DomainException):
    """Signals a state-machine move that the current state does not permit.

    Attributes:
        current_state: The state the object is in.
        target_state: The state that was requested.
        valid_transitions: The allowed targets passed by the caller, or an
            empty list when none (or an empty value) was given.
    """

    def __init__(
        self,
        current_state: str,
        target_state: str,
        valid_transitions: list[str] | None = None,
    ):
        # The hint about allowed targets is appended only when there is
        # something to list.
        hint = f". Valid transitions: {valid_transitions}" if valid_transitions else ""
        super().__init__(
            f"Cannot transition from '{current_state}' to '{target_state}'" + hint,
            "INVALID_TRANSITION",
        )
        self.current_state = current_state
        self.target_state = target_state
        self.valid_transitions = valid_transitions if valid_transitions else []
class InvalidOperationError(DomainException):
    """Signals an operation that is not valid in the current context.

    The error code is always ``"INVALID_OPERATION"``.
    """

    def __init__(self, message: str):
        super().__init__(message, code="INVALID_OPERATION")
class AuthorizationError(DomainException):
    """Signals that the acting user is not permitted to perform an action.

    The error code is always ``"FORBIDDEN"``; the message defaults to
    ``"Insufficient permissions"``.
    """

    def __init__(self, message: str = "Insufficient permissions"):
        super().__init__(message, code="FORBIDDEN")

View File

View File

@@ -0,0 +1,34 @@
"""Redis client singleton.
Provides a lazily-initialised Redis connection that is reused across
the application. The connection URL is read from ``settings.REDIS_URL``.
Usage::
from app.infrastructure.redis_client import get_redis
r = get_redis()
r.set("key", "value", ex=300)
"""
import logging
import redis
from app.config import settings
logger = logging.getLogger(__name__)
_redis_client: redis.Redis | None = None
def _redact_url(url: str) -> str:
    """Return *url* with any ``user:password@`` part replaced by ``***@``."""
    from urllib.parse import urlsplit

    try:
        parts = urlsplit(url)
    except ValueError:
        # Never echo a URL we could not parse; it may still hold a secret.
        return "<unparseable URL>"
    _userinfo, separator, host = parts.netloc.rpartition("@")
    if not separator:
        return url  # no credentials embedded
    return parts._replace(netloc=f"***@{host}").geturl()


def get_redis() -> redis.Redis:
    """Return a shared Redis client, creating it on first call.

    The client is built from ``settings.REDIS_URL`` with
    ``decode_responses=True`` and cached in the module-level
    ``_redis_client`` so later calls reuse the same object.

    Returns:
        The process-wide ``redis.Redis`` instance.
    """
    global _redis_client
    if _redis_client is None:
        _redis_client = redis.from_url(
            settings.REDIS_URL,
            decode_responses=True,
        )
        # REDIS_URL may embed a password (redis://user:pass@host/0), so
        # only a redacted form is written to the logs.
        # NOTE(review): from_url is expected to connect lazily, so this
        # message does not prove the server is reachable — confirm.
        logger.info("Redis client connected to %s", _redact_url(settings.REDIS_URL))
    return _redis_client

View File

@@ -32,6 +32,8 @@ from app.routers import scores as scores_router
from app.routers import operational_metrics as operational_metrics_router
from app.routers import compliance as compliance_router
from app.routers import snapshots as snapshots_router
from app.domain.exceptions import DomainException
from app.middleware.error_handler import domain_exception_handler
from app.storage import ensure_bucket_exists
from app.jobs.mitre_sync_job import start_scheduler, scheduler
@@ -68,6 +70,9 @@ limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# ── Domain exception → HTTP mapping ──────────────────────────────────────
app.add_exception_handler(DomainException, domain_exception_handler)
# ── CORS ──────────────────────────────────────────────────────────────────
from app.config import settings as _settings

View File

View File

@@ -0,0 +1,43 @@
"""Domain exception → HTTP response mapping.
This module provides a single exception handler that converts
domain-layer exceptions into structured JSON responses, keeping
the service layer free from FastAPI's ``HTTPException``.
"""
from fastapi import Request
from fastapi.responses import JSONResponse
from app.domain.exceptions import (
AuthorizationError,
DomainException,
DuplicateEntityError,
EntityNotFoundError,
InvalidOperationError,
InvalidTransitionError,
)
# HTTP status code returned for each concrete domain exception class.
# Exception types that are not listed here fall back to the default status
# chosen by the handler that consults this table.
EXCEPTION_STATUS_MAP: dict[type[DomainException], int] = {
EntityNotFoundError: 404,
DuplicateEntityError: 409,
InvalidTransitionError: 400,
InvalidOperationError: 400,
AuthorizationError: 403,
}
async def domain_exception_handler(
    request: Request,
    exc: DomainException,
) -> JSONResponse:
    """Convert a :class:`DomainException` into a JSON error response.

    The status code comes from ``EXCEPTION_STATUS_MAP``. The lookup walks
    the exception's MRO, so a subclass of a mapped exception inherits its
    parent's status instead of silently falling back to 400.

    Args:
        request: The incoming request (unused; required by the handler
            signature).
        exc: The domain exception raised while handling the request.

    Returns:
        A ``JSONResponse`` whose body holds ``detail`` and ``code``, plus
        ``current_state``, ``target_state`` and ``valid_transitions`` for
        :class:`InvalidTransitionError`.
    """
    # First mapped class in the MRO wins (most specific first); anything
    # unmapped, including a bare DomainException, is a 400.
    status_code = next(
        (
            EXCEPTION_STATUS_MAP[cls]
            for cls in type(exc).__mro__
            if cls in EXCEPTION_STATUS_MAP
        ),
        400,
    )
    content: dict = {"detail": exc.message, "code": exc.code}
    if isinstance(exc, InvalidTransitionError):
        # Expose the structured state-machine details to the client.
        content["current_state"] = exc.current_state
        content["target_state"] = exc.target_state
        content["valid_transitions"] = exc.valid_transitions
    return JSONResponse(status_code=status_code, content=content)

View File

@@ -98,8 +98,9 @@ def logout(
):
"""Clear the authentication cookie and revoke the current token.
The token's ``jti`` is added to an in-memory blacklist so it cannot
The token's ``jti`` is added to the Redis blacklist so it cannot
be reused even if the cookie has already been copied elsewhere.
The blacklist entry auto-expires when the token's ``exp`` is reached.
"""
# Attempt to blacklist the token's jti
token = aegis_token or request.headers.get("Authorization", "").removeprefix("Bearer ").strip()

View File

@@ -8,9 +8,9 @@ import logging
import uuid
from datetime import datetime
from fastapi import HTTPException
from sqlalchemy.orm import Session
from app.domain.exceptions import EntityNotFoundError, InvalidOperationError
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
from app.models.test import Test
from app.models.test_template import TestTemplate
@@ -49,7 +49,7 @@ def validate_no_circular_dependency(
) -> None:
"""Walk the depends_on chain and verify no cycle is formed.
Raises HTTPException(400) if a circular dependency is detected.
Raises :class:`InvalidOperationError` if a circular dependency is detected.
"""
if depends_on_id is None:
return
@@ -59,9 +59,8 @@ def validate_no_circular_dependency(
while current is not None:
if current in visited or current == test_id:
raise HTTPException(
status_code=400,
detail="Circular dependency detected in campaign test chain",
raise InvalidOperationError(
"Circular dependency detected in campaign test chain"
)
visited.add(current)
parent = db.query(CampaignTest).filter_by(id=current).first()
@@ -119,7 +118,7 @@ def generate_campaign_from_threat_actor(
"""
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
if not actor:
raise HTTPException(status_code=404, detail="Threat actor not found")
raise EntityNotFoundError("ThreatActor", str(actor_id))
# Get unvalidated techniques for this actor
gap_techniques = (
@@ -132,9 +131,8 @@ def generate_campaign_from_threat_actor(
)
if not gap_techniques:
raise HTTPException(
status_code=400,
detail=f"No uncovered techniques found for {actor.name}",
raise InvalidOperationError(
f"No uncovered techniques found for {actor.name}"
)
# Create the campaign

View File

@@ -11,18 +11,21 @@ Every public function validates the transition, mutates the test, writes an
audit-log entry, and commits the session.
"""
import logging
from datetime import datetime
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.exceptions import InvalidOperationError, InvalidTransitionError
from app.models.enums import TestState
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
from app.services.notification_service import notify_test_state_change, create_notification
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Valid transition map
# ---------------------------------------------------------------------------
@@ -59,23 +62,15 @@ def transition_state(
) -> Test:
"""Validate and perform a state transition, log it, and commit.
Raises :class:`~fastapi.HTTPException` 400 when the transition is invalid.
Raises :class:`InvalidTransitionError` when the transition is invalid.
"""
if not can_transition(test, target_state):
current = test.state if isinstance(test.state, TestState) else TestState(test.state)
valid = [s.value for s in VALID_TRANSITIONS.get(current, [])]
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": (
f"Cannot transition from '{current.value}' to '{target_state.value}'. "
f"Valid transitions: {valid}"
),
"code": "INVALID_TRANSITION",
"current_state": current.value,
"target_state": target_state.value,
"valid_transitions": valid,
},
raise InvalidTransitionError(
current_state=current.value,
target_state=target_state.value,
valid_transitions=valid,
)
previous_state = test.state.value if isinstance(test.state, TestState) else test.state
@@ -103,8 +98,8 @@ def transition_state(
# Dispatch in-app notifications for the new state
try:
notify_test_state_change(db, test, target_state.value)
except Exception:
pass # Notifications are best-effort — don't block the workflow
except Exception as e:
logger.warning("Notification failed for test %s: %s", test.id, e, exc_info=True)
return test
@@ -169,22 +164,13 @@ def validate_as_red_lead(
"""
current = test.state.value if isinstance(test.state, TestState) else test.state
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot validate red side while test is in '{current}' state (must be in_review)",
"code": "INVALID_STATE",
"current_state": current,
},
raise InvalidOperationError(
f"Cannot validate red side while test is in '{current}' state (must be in_review)"
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": "validation_status must be 'approved' or 'rejected'",
"code": "INVALID_VALIDATION_STATUS",
},
raise InvalidOperationError(
"validation_status must be 'approved' or 'rejected'"
)
now = datetime.utcnow()
@@ -225,22 +211,13 @@ def validate_as_blue_lead(
"""
current = test.state.value if isinstance(test.state, TestState) else test.state
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot validate blue side while test is in '{current}' state (must be in_review)",
"code": "INVALID_STATE",
"current_state": current,
},
raise InvalidOperationError(
f"Cannot validate blue side while test is in '{current}' state (must be in_review)"
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": "validation_status must be 'approved' or 'rejected'",
"code": "INVALID_VALIDATION_STATUS",
},
raise InvalidOperationError(
"validation_status must be 'approved' or 'rejected'"
)
now = datetime.utcnow()
@@ -283,8 +260,8 @@ def check_dual_validation(db: Session, test: Test) -> Test:
db.commit()
try:
notify_test_state_change(db, test, "rejected")
except Exception:
pass
except Exception as e:
logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True)
elif red_status == "approved" and blue_status == "approved":
test.state = TestState.validated
db.commit()
@@ -292,12 +269,12 @@ def check_dual_validation(db: Session, test: Test) -> Test:
try:
from app.services.score_cache import invalidate
invalidate()
except Exception:
pass
except Exception as e:
logger.warning("Score cache invalidation failed: %s", e, exc_info=True)
try:
notify_test_state_change(db, test, "validated")
except Exception:
pass
except Exception as e:
logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True)
else:
# One side hasn't voted yet — stay in_review, just flush
db.commit()