"""Authentication and RBAC dependencies for FastAPI.

Provides:

- ``get_current_user``: decodes a JWT from the HttpOnly cookie (preferred) or
  the ``Authorization`` header (fallback), fetches the user from the DB, and
  raises 401 on failure.
- ``require_password_changed``: rejects (403) users still flagged with
  ``must_change_password``.
- ``require_role``: factory returning a dependency that enforces a specific
  role (admins always pass).
- ``require_any_role``: factory returning a dependency that enforces any one
  of several roles (admins always pass).
"""

from collections.abc import Callable
from typing import Optional

from fastapi import Cookie, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from app import auth as auth_lib
from app.config import settings
from app.database import get_db
from app.models.user import User

# ---------------------------------------------------------------------------
# OAuth2 scheme (reads Authorization header — used as fallback / Swagger UI)
# ---------------------------------------------------------------------------

# auto_error=False: a missing/non-Bearer header resolves to None instead of an
# immediate 401, so the cookie path in ``get_current_user`` still gets a chance.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False)

# Cookie name — must match the one set in the auth router. It is passed as the
# ``Cookie(alias=...)`` below so this constant is the single source of truth.
_COOKIE_NAME = "aegis_token"


# ---------------------------------------------------------------------------
# Current-user dependency
# ---------------------------------------------------------------------------


async def get_current_user(
    aegis_token: Optional[str] = Cookie(None, alias=_COOKIE_NAME),
    bearer_token: Optional[str] = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
) -> User:
    """Decode the JWT, look up the user in *db*, and return it.

    Token resolution order:

    1. ``aegis_token`` **HttpOnly cookie** (preferred — not readable by page
       JavaScript).
    2. ``Authorization: Bearer <token>`` header (fallback for API clients and
       Swagger UI).

    Raises :class:`~fastapi.HTTPException` **401** when:

    - no token is found in either location,
    - the token cannot be decoded,
    - the ``sub`` claim is missing,
    - the token's ``jti`` is blacklisted (detail ``"Token has been revoked"``),
    - or no matching active user exists in the database.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    revoked_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token has been revoked",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Prefer the cookie, fall back to the header. ``or`` also skips an empty
    # cookie value; an empty/absent token from both sources is a 401.
    token = aegis_token or bearer_token
    if not token:
        raise credentials_exception

    try:
        payload = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
        username: str | None = payload.get("sub")
        if username is None:
            raise credentials_exception

        # Tokens carrying a blacklisted ``jti`` are rejected even if their
        # signature and expiry are still valid.
        jti: str | None = payload.get("jti")
        if jti and auth_lib.is_token_blacklisted(jti):
            raise revoked_exception
    except JWTError:
        # jose signals any undecodable token (bad signature, expired, malformed)
        # with JWTError or a subclass; map all of them to a generic 401.
        raise credentials_exception

    # NOTE(review): this is a synchronous SQLAlchemy ``Session`` query inside an
    # ``async`` dependency, so it runs on (and blocks) the event loop.
    user = db.query(User).filter(User.username == username).first()
    if user is None or not user.is_active:
        raise credentials_exception
    return user


# ---------------------------------------------------------------------------
# Password-change gate
# ---------------------------------------------------------------------------


async def require_password_changed(
    current_user: User = Depends(get_current_user),
) -> User:
    """Block the request while the user still has to change their password.

    Raises :class:`~fastapi.HTTPException` **403** with detail
    ``"PASSWORD_CHANGE_REQUIRED"`` when ``current_user.must_change_password``
    is truthy; a user object without that attribute is treated as ``False``.

    Only ``/auth/change-password`` and ``/auth/me`` are exempt — those
    endpoints do **not** depend on this function.
    """
    if getattr(current_user, "must_change_password", False):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="PASSWORD_CHANGE_REQUIRED",
        )
    return current_user


# ---------------------------------------------------------------------------
# Role-based access control dependencies
# ---------------------------------------------------------------------------
# NOTE(review): the role checkers depend on ``get_current_user`` directly, not
# on ``require_password_changed``, so on their own they do not enforce the
# password-change gate — confirm routers add it where required.


def require_role(required_role: str) -> Callable[..., object]:
    """Return a FastAPI dependency that enforces *required_role*.

    The dependency allows the request to proceed when
    ``user.role == required_role`` **or** ``user.role == "admin"``; otherwise
    it raises :class:`~fastapi.HTTPException` **403**. This is exactly
    ``require_any_role(required_role)``.
    """
    return require_any_role(required_role)


def require_any_role(*roles: str) -> Callable[..., object]:
    """Return a FastAPI dependency that enforces **any** of the given *roles*.

    Admins always pass. With no *roles*, only admins pass. Any other user gets
    :class:`~fastapi.HTTPException` **403** ``"Not enough permissions"``.

    Usage example::

        @router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))])
    """

    async def role_checker(
        current_user: User = Depends(get_current_user),
    ) -> User:
        # ``roles`` stays a tuple on purpose: tuple membership compares with
        # ``==``, so str-based Enum role values still match their plain strings.
        if current_user.role != "admin" and current_user.role not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Not enough permissions",
            )
        return current_user

    return role_checker