feat(phase-34): resolve blocking tech debt — Redis, domain exceptions, indexes, CI
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
Foundational changes required before any new feature work can begin. - 0.1 Redis infrastructure: add redis:7-alpine to docker-compose dev and prod, REDIS_URL config, singleton client in app/infrastructure/redis_client.py - 0.2 Token blacklist on Redis SEC-001: replace in-memory dict with Redis SETEX keyed by jti, auto-expiring TTL derived from token exp - 0.3 Database indexes SR-006: Alembic migration b019 with 5 composite indexes for scoring, MTTD/MTTR, remediation, and notification queries - 0.4 Domain exceptions TD-003: app/domain/exceptions.py with typed errors, error_handler middleware mapping them to HTTP, services decoupled from FastAPI - 0.5 Fix silenced exceptions TD-007: replace 4 bare except-pass blocks in test_workflow_service with logger.warning with exc_info - 0.6 CI pipeline TD-009: GitHub Actions workflow with Postgres and Redis service containers, ruff lint, pytest; ruff.toml for baseline config
This commit is contained in:
@@ -4,12 +4,12 @@ Security utilities: password hashing and JWT token management.
|
||||
This module provides pure functions for:
|
||||
- Hashing and verifying passwords using bcrypt via passlib.
|
||||
- Creating JWT access tokens using python-jose.
|
||||
- Managing an in-memory token blacklist for revocation.
|
||||
- Managing a Redis-backed token blacklist for revocation.
|
||||
|
||||
No endpoints are defined here.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import logging
|
||||
import uuid as _uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
@@ -18,6 +18,8 @@ from passlib.context import CryptContext
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Password hashing
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -58,36 +60,43 @@ def create_access_token(data: dict) -> str:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token blacklist (in-memory)
|
||||
# Token blacklist (Redis-backed)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stores (jti, expiry_timestamp) tuples. Entries are automatically purged
|
||||
# once they are past their original expiry (the token would be invalid
|
||||
# anyway at that point). Thread-safe via a simple lock.
|
||||
# Each revoked token's ``jti`` is stored in Redis with a TTL equal to the
|
||||
# token's remaining lifetime. This means entries auto-expire exactly when
|
||||
# the token would have become invalid anyway — no manual cleanup needed.
|
||||
#
|
||||
# For multi-worker / multi-process deployments, consider replacing this
|
||||
# with a shared store like Redis.
|
||||
# Redis survives backend restarts, so blacklisted tokens stay revoked
|
||||
# across deploys and multi-worker setups.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_blacklist: dict[str, float] = {} # jti → expiry epoch
|
||||
_blacklist_lock = threading.Lock()
|
||||
_BLACKLIST_PREFIX = "blacklist:"
|
||||
|
||||
|
||||
def blacklist_token(jti: str, exp: float) -> None:
|
||||
"""Add *jti* to the blacklist until it naturally expires at *exp*."""
|
||||
with _blacklist_lock:
|
||||
_blacklist[jti] = exp
|
||||
_cleanup_blacklist()
|
||||
"""Add *jti* to the Redis blacklist with a TTL derived from *exp*.
|
||||
|
||||
*exp* is the token's ``exp`` claim (epoch timestamp). The TTL is set
|
||||
to ``exp - now`` so the key vanishes when the token would have expired
|
||||
naturally.
|
||||
"""
|
||||
from app.infrastructure.redis_client import get_redis
|
||||
|
||||
ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
|
||||
try:
|
||||
r = get_redis()
|
||||
r.setex(f"{_BLACKLIST_PREFIX}{jti}", ttl, "1")
|
||||
except Exception:
|
||||
logger.warning("Failed to blacklist token %s in Redis", jti, exc_info=True)
|
||||
|
||||
|
||||
def is_token_blacklisted(jti: str) -> bool:
|
||||
"""Return ``True`` if *jti* has been revoked."""
|
||||
with _blacklist_lock:
|
||||
return jti in _blacklist
|
||||
"""Return ``True`` if *jti* has been revoked (exists in Redis)."""
|
||||
from app.infrastructure.redis_client import get_redis
|
||||
|
||||
|
||||
def _cleanup_blacklist() -> None:
|
||||
"""Remove entries whose tokens have already expired (caller holds lock)."""
|
||||
now = datetime.now(timezone.utc).timestamp()
|
||||
expired = [k for k, exp in _blacklist.items() if exp < now]
|
||||
for k in expired:
|
||||
del _blacklist[k]
|
||||
try:
|
||||
r = get_redis()
|
||||
return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0
|
||||
except Exception:
|
||||
logger.warning("Failed to check blacklist for %s in Redis", jti, exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -24,6 +24,9 @@ class Settings(BaseSettings):
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env
|
||||
|
||||
# ── Redis ─────────────────────────────────────────────────────────
|
||||
REDIS_URL: str = "redis://redis:6379/0"
|
||||
|
||||
# ── CORS ─────────────────────────────────────────────────────────
|
||||
# Comma-separated list of allowed origins, or a JSON array.
|
||||
# In dev this defaults to common local ports; in production set it
|
||||
|
||||
0
backend/app/domain/__init__.py
Normal file
0
backend/app/domain/__init__.py
Normal file
67
backend/app/domain/exceptions.py
Normal file
67
backend/app/domain/exceptions.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Domain exceptions for Aegis business logic.
|
||||
|
||||
These exceptions are raised by service-layer code and automatically
|
||||
mapped to HTTP responses by the error-handler middleware registered
|
||||
in ``app.main``. This keeps the service layer free from any HTTP
|
||||
or framework coupling.
|
||||
"""
|
||||
|
||||
|
||||
class DomainException(Exception):
|
||||
"""Base for all domain exceptions."""
|
||||
|
||||
def __init__(self, message: str, code: str = "DOMAIN_ERROR"):
|
||||
self.message = message
|
||||
self.code = code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EntityNotFoundError(DomainException):
|
||||
"""Raised when a requested entity does not exist."""
|
||||
|
||||
def __init__(self, entity: str, identifier: str):
|
||||
super().__init__(f"{entity} not found: {identifier}", "NOT_FOUND")
|
||||
self.entity = entity
|
||||
self.identifier = identifier
|
||||
|
||||
|
||||
class DuplicateEntityError(DomainException):
|
||||
"""Raised when creating an entity that already exists."""
|
||||
|
||||
def __init__(self, entity: str, field: str, value: str):
|
||||
super().__init__(
|
||||
f"{entity} with {field}='{value}' already exists",
|
||||
"DUPLICATE",
|
||||
)
|
||||
|
||||
|
||||
class InvalidTransitionError(DomainException):
|
||||
"""Raised when a state-machine transition is not allowed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
current_state: str,
|
||||
target_state: str,
|
||||
valid_transitions: list[str] | None = None,
|
||||
):
|
||||
msg = f"Cannot transition from '{current_state}' to '{target_state}'"
|
||||
if valid_transitions:
|
||||
msg += f". Valid transitions: {valid_transitions}"
|
||||
super().__init__(msg, "INVALID_TRANSITION")
|
||||
self.current_state = current_state
|
||||
self.target_state = target_state
|
||||
self.valid_transitions = valid_transitions or []
|
||||
|
||||
|
||||
class InvalidOperationError(DomainException):
|
||||
"""Raised when an operation is invalid in the current context."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message, "INVALID_OPERATION")
|
||||
|
||||
|
||||
class AuthorizationError(DomainException):
|
||||
"""Raised when the user lacks permissions for an action."""
|
||||
|
||||
def __init__(self, message: str = "Insufficient permissions"):
|
||||
super().__init__(message, "FORBIDDEN")
|
||||
0
backend/app/infrastructure/__init__.py
Normal file
0
backend/app/infrastructure/__init__.py
Normal file
34
backend/app/infrastructure/redis_client.py
Normal file
34
backend/app/infrastructure/redis_client.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Redis client singleton.
|
||||
|
||||
Provides a lazily-initialised Redis connection that is reused across
|
||||
the application. The connection URL is read from ``settings.REDIS_URL``.
|
||||
|
||||
Usage::
|
||||
|
||||
from app.infrastructure.redis_client import get_redis
|
||||
|
||||
r = get_redis()
|
||||
r.set("key", "value", ex=300)
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import redis
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
def get_redis() -> redis.Redis:
|
||||
"""Return a shared Redis client, creating it on first call."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(
|
||||
settings.REDIS_URL,
|
||||
decode_responses=True,
|
||||
)
|
||||
logger.info("Redis client connected to %s", settings.REDIS_URL)
|
||||
return _redis_client
|
||||
@@ -32,6 +32,8 @@ from app.routers import scores as scores_router
|
||||
from app.routers import operational_metrics as operational_metrics_router
|
||||
from app.routers import compliance as compliance_router
|
||||
from app.routers import snapshots as snapshots_router
|
||||
from app.domain.exceptions import DomainException
|
||||
from app.middleware.error_handler import domain_exception_handler
|
||||
from app.storage import ensure_bucket_exists
|
||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||
|
||||
@@ -68,6 +70,9 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# ── Domain exception → HTTP mapping ──────────────────────────────────────
|
||||
app.add_exception_handler(DomainException, domain_exception_handler)
|
||||
|
||||
# ── CORS ──────────────────────────────────────────────────────────────────
|
||||
from app.config import settings as _settings
|
||||
|
||||
|
||||
0
backend/app/middleware/__init__.py
Normal file
0
backend/app/middleware/__init__.py
Normal file
43
backend/app/middleware/error_handler.py
Normal file
43
backend/app/middleware/error_handler.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Domain exception → HTTP response mapping.
|
||||
|
||||
This module provides a single exception handler that converts
|
||||
domain-layer exceptions into structured JSON responses, keeping
|
||||
the service layer free from FastAPI's ``HTTPException``.
|
||||
"""
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.domain.exceptions import (
|
||||
AuthorizationError,
|
||||
DomainException,
|
||||
DuplicateEntityError,
|
||||
EntityNotFoundError,
|
||||
InvalidOperationError,
|
||||
InvalidTransitionError,
|
||||
)
|
||||
|
||||
EXCEPTION_STATUS_MAP: dict[type[DomainException], int] = {
|
||||
EntityNotFoundError: 404,
|
||||
DuplicateEntityError: 409,
|
||||
InvalidTransitionError: 400,
|
||||
InvalidOperationError: 400,
|
||||
AuthorizationError: 403,
|
||||
}
|
||||
|
||||
|
||||
async def domain_exception_handler(
|
||||
request: Request,
|
||||
exc: DomainException,
|
||||
) -> JSONResponse:
|
||||
"""Convert a :class:`DomainException` into a JSON error response."""
|
||||
status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400)
|
||||
|
||||
content: dict = {"detail": exc.message, "code": exc.code}
|
||||
|
||||
if isinstance(exc, InvalidTransitionError):
|
||||
content["current_state"] = exc.current_state
|
||||
content["target_state"] = exc.target_state
|
||||
content["valid_transitions"] = exc.valid_transitions
|
||||
|
||||
return JSONResponse(status_code=status_code, content=content)
|
||||
@@ -98,8 +98,9 @@ def logout(
|
||||
):
|
||||
"""Clear the authentication cookie and revoke the current token.
|
||||
|
||||
The token's ``jti`` is added to an in-memory blacklist so it cannot
|
||||
The token's ``jti`` is added to the Redis blacklist so it cannot
|
||||
be reused even if the cookie has already been copied elsewhere.
|
||||
The blacklist entry auto-expires when the token's ``exp`` is reached.
|
||||
"""
|
||||
# Attempt to blacklist the token's jti
|
||||
token = aegis_token or request.headers.get("Authorization", "").removeprefix("Bearer ").strip()
|
||||
|
||||
@@ -8,9 +8,9 @@ import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.exceptions import EntityNotFoundError, InvalidOperationError
|
||||
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
@@ -49,7 +49,7 @@ def validate_no_circular_dependency(
|
||||
) -> None:
|
||||
"""Walk the depends_on chain and verify no cycle is formed.
|
||||
|
||||
Raises HTTPException(400) if a circular dependency is detected.
|
||||
Raises :class:`InvalidOperationError` if a circular dependency is detected.
|
||||
"""
|
||||
if depends_on_id is None:
|
||||
return
|
||||
@@ -59,9 +59,8 @@ def validate_no_circular_dependency(
|
||||
|
||||
while current is not None:
|
||||
if current in visited or current == test_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Circular dependency detected in campaign test chain",
|
||||
raise InvalidOperationError(
|
||||
"Circular dependency detected in campaign test chain"
|
||||
)
|
||||
visited.add(current)
|
||||
parent = db.query(CampaignTest).filter_by(id=current).first()
|
||||
@@ -119,7 +118,7 @@ def generate_campaign_from_threat_actor(
|
||||
"""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
||||
raise EntityNotFoundError("ThreatActor", str(actor_id))
|
||||
|
||||
# Get unvalidated techniques for this actor
|
||||
gap_techniques = (
|
||||
@@ -132,9 +131,8 @@ def generate_campaign_from_threat_actor(
|
||||
)
|
||||
|
||||
if not gap_techniques:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No uncovered techniques found for {actor.name}",
|
||||
raise InvalidOperationError(
|
||||
f"No uncovered techniques found for {actor.name}"
|
||||
)
|
||||
|
||||
# Create the campaign
|
||||
|
||||
@@ -11,18 +11,21 @@ Every public function validates the transition, mutates the test, writes an
|
||||
audit-log entry, and commits the session.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.domain.exceptions import InvalidOperationError, InvalidTransitionError
|
||||
from app.models.enums import TestState
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.notification_service import notify_test_state_change, create_notification
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Valid transition map
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -59,23 +62,15 @@ def transition_state(
|
||||
) -> Test:
|
||||
"""Validate and perform a state transition, log it, and commit.
|
||||
|
||||
Raises :class:`~fastapi.HTTPException` 400 when the transition is invalid.
|
||||
Raises :class:`InvalidTransitionError` when the transition is invalid.
|
||||
"""
|
||||
if not can_transition(test, target_state):
|
||||
current = test.state if isinstance(test.state, TestState) else TestState(test.state)
|
||||
valid = [s.value for s in VALID_TRANSITIONS.get(current, [])]
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": (
|
||||
f"Cannot transition from '{current.value}' to '{target_state.value}'. "
|
||||
f"Valid transitions: {valid}"
|
||||
),
|
||||
"code": "INVALID_TRANSITION",
|
||||
"current_state": current.value,
|
||||
"target_state": target_state.value,
|
||||
"valid_transitions": valid,
|
||||
},
|
||||
raise InvalidTransitionError(
|
||||
current_state=current.value,
|
||||
target_state=target_state.value,
|
||||
valid_transitions=valid,
|
||||
)
|
||||
|
||||
previous_state = test.state.value if isinstance(test.state, TestState) else test.state
|
||||
@@ -103,8 +98,8 @@ def transition_state(
|
||||
# Dispatch in-app notifications for the new state
|
||||
try:
|
||||
notify_test_state_change(db, test, target_state.value)
|
||||
except Exception:
|
||||
pass # Notifications are best-effort — don't block the workflow
|
||||
except Exception as e:
|
||||
logger.warning("Notification failed for test %s: %s", test.id, e, exc_info=True)
|
||||
|
||||
return test
|
||||
|
||||
@@ -169,22 +164,13 @@ def validate_as_red_lead(
|
||||
"""
|
||||
current = test.state.value if isinstance(test.state, TestState) else test.state
|
||||
if test.state not in (TestState.in_review,):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": f"Cannot validate red side while test is in '{current}' state (must be in_review)",
|
||||
"code": "INVALID_STATE",
|
||||
"current_state": current,
|
||||
},
|
||||
raise InvalidOperationError(
|
||||
f"Cannot validate red side while test is in '{current}' state (must be in_review)"
|
||||
)
|
||||
|
||||
if validation_status not in ("approved", "rejected"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": "validation_status must be 'approved' or 'rejected'",
|
||||
"code": "INVALID_VALIDATION_STATUS",
|
||||
},
|
||||
raise InvalidOperationError(
|
||||
"validation_status must be 'approved' or 'rejected'"
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
@@ -225,22 +211,13 @@ def validate_as_blue_lead(
|
||||
"""
|
||||
current = test.state.value if isinstance(test.state, TestState) else test.state
|
||||
if test.state not in (TestState.in_review,):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": f"Cannot validate blue side while test is in '{current}' state (must be in_review)",
|
||||
"code": "INVALID_STATE",
|
||||
"current_state": current,
|
||||
},
|
||||
raise InvalidOperationError(
|
||||
f"Cannot validate blue side while test is in '{current}' state (must be in_review)"
|
||||
)
|
||||
|
||||
if validation_status not in ("approved", "rejected"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": "validation_status must be 'approved' or 'rejected'",
|
||||
"code": "INVALID_VALIDATION_STATUS",
|
||||
},
|
||||
raise InvalidOperationError(
|
||||
"validation_status must be 'approved' or 'rejected'"
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
@@ -283,8 +260,8 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
||||
db.commit()
|
||||
try:
|
||||
notify_test_state_change(db, test, "rejected")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True)
|
||||
elif red_status == "approved" and blue_status == "approved":
|
||||
test.state = TestState.validated
|
||||
db.commit()
|
||||
@@ -292,12 +269,12 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
||||
try:
|
||||
from app.services.score_cache import invalidate
|
||||
invalidate()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Score cache invalidation failed: %s", e, exc_info=True)
|
||||
try:
|
||||
notify_test_state_change(db, test, "validated")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True)
|
||||
else:
|
||||
# One side hasn't voted yet — stay in_review, just flush
|
||||
db.commit()
|
||||
|
||||
Reference in New Issue
Block a user