feat(refactor): PEP8, type annotations, docstrings and PyJWT security fix
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""FastAPI dependency injection helpers for auth, DB, and shared state."""
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Authentication and RBAC dependencies for FastAPI.
|
||||
"""Authentication and RBAC dependencies for FastAPI.
|
||||
|
||||
Provides:
|
||||
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
|
||||
@@ -9,16 +8,34 @@ Provides:
|
||||
(admins always pass).
|
||||
"""
|
||||
|
||||
# Import Callable from collections.abc
|
||||
from collections.abc import Callable
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import Cookie, Depends, HTTPException, status from fastapi
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
|
||||
# Import OAuth2PasswordBearer from fastapi.security
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
|
||||
# Import jwt (PyJWT)
|
||||
import jwt
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import auth as auth_lib from app
|
||||
from app import auth as auth_lib
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.models.api_key import KEY_PREFIX
|
||||
|
||||
@@ -37,8 +54,11 @@ _COOKIE_NAME = "aegis_token"
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
# Entry: aegis_token
|
||||
aegis_token: Optional[str] = Cookie(None),
|
||||
# Entry: bearer_token
|
||||
bearer_token: Optional[str] = Depends(oauth2_scheme),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Decode the JWT, look up the user in *db*, and return it.
|
||||
@@ -54,20 +74,30 @@ async def get_current_user(
|
||||
- the ``sub`` claim is missing, or
|
||||
- no matching active user exists in the database.
|
||||
"""
|
||||
# Assign credentials_exception = HTTPException(
|
||||
credentials_exception = HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
# Keyword argument: detail
|
||||
detail="Could not validate credentials",
|
||||
# Keyword argument: headers
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
# Assign revoked_exception = HTTPException(
|
||||
revoked_exception = HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
# Keyword argument: detail
|
||||
detail="Token has been revoked",
|
||||
# Keyword argument: headers
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Prefer cookie, fall back to header
|
||||
token = aegis_token or bearer_token
|
||||
# Check: token is None
|
||||
if token is None:
|
||||
# Raise credentials_exception
|
||||
raise credentials_exception
|
||||
|
||||
# ── API Key path (Bearer token starts with "aegis_") ──────────────────
|
||||
@@ -80,25 +110,38 @@ async def get_current_user(
|
||||
|
||||
# ── JWT path ──────────────────────────────────────────────────────────
|
||||
try:
|
||||
# Assign payload = jwt.decode(
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
# Keyword argument: algorithms
|
||||
algorithms=[settings.ALGORITHM],
|
||||
)
|
||||
# Assign username = payload.get("sub")
|
||||
username: str | None = payload.get("sub")
|
||||
# Check: username is None
|
||||
if username is None:
|
||||
# Raise credentials_exception
|
||||
raise credentials_exception
|
||||
# Check token blacklist (revoked tokens)
|
||||
jti: str | None = payload.get("jti")
|
||||
# Check: jti and auth_lib.is_token_blacklisted(jti)
|
||||
if jti and auth_lib.is_token_blacklisted(jti):
|
||||
# Raise revoked_exception
|
||||
raise revoked_exception
|
||||
except JWTError:
|
||||
# Handle any JWT validation error (expired, invalid signature, malformed)
|
||||
except jwt.exceptions.InvalidTokenError:
|
||||
# Raise credentials_exception
|
||||
raise credentials_exception
|
||||
|
||||
# Assign user = db.query(User).filter(User.username == username).first()
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
# Check: user is None or not user.is_active
|
||||
if user is None or not user.is_active:
|
||||
# Raise credentials_exception
|
||||
raise credentials_exception
|
||||
|
||||
# Return user
|
||||
return user
|
||||
|
||||
|
||||
@@ -108,6 +151,7 @@ async def get_current_user(
|
||||
|
||||
|
||||
async def require_password_changed(
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
"""Block all requests when the user still needs to change their password.
|
||||
@@ -115,11 +159,16 @@ async def require_password_changed(
|
||||
Only ``/auth/change-password`` and ``/auth/me`` are exempt — those
|
||||
endpoints do **not** depend on this function.
|
||||
"""
|
||||
# Check: getattr(current_user, "must_change_password", False)
|
||||
if getattr(current_user, "must_change_password", False):
|
||||
# Raise HTTPException
|
||||
raise HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
# Keyword argument: detail
|
||||
detail="PASSWORD_CHANGE_REQUIRED",
|
||||
)
|
||||
# Return current_user
|
||||
return current_user
|
||||
|
||||
|
||||
@@ -147,22 +196,30 @@ def require_role(required_role: str):
|
||||
Otherwise it raises :class:`~fastapi.HTTPException` **403**.
|
||||
"""
|
||||
|
||||
# Define async function role_checker
|
||||
async def role_checker(
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
# Check: current_user.role != required_role and current_user.role != "admin"
|
||||
if current_user.role != required_role and current_user.role != "admin":
|
||||
# Raise HTTPException
|
||||
raise HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
# Keyword argument: detail
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
scope = "admin" if required_role == "admin" else "write"
|
||||
_check_api_key_scope(current_user, scope)
|
||||
return current_user
|
||||
|
||||
# Return role_checker
|
||||
return role_checker
|
||||
|
||||
|
||||
def require_any_role(*roles: str):
|
||||
# Define function require_any_role
|
||||
def require_any_role(*roles: str) -> Callable[..., object]:
|
||||
"""Return a FastAPI dependency that enforces **any** of the given *roles*.
|
||||
|
||||
Admins always pass. Also enforces API key scopes: if the only accepted
|
||||
@@ -174,18 +231,25 @@ def require_any_role(*roles: str):
|
||||
@router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))])
|
||||
"""
|
||||
|
||||
# Define async function role_checker
|
||||
async def role_checker(
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
# Check: current_user.role != "admin" and current_user.role not in roles
|
||||
if current_user.role != "admin" and current_user.role not in roles:
|
||||
# Raise HTTPException
|
||||
raise HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
# Keyword argument: detail
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
scope = "admin" if set(roles) == {"admin"} else "write"
|
||||
_check_api_key_scope(current_user, scope)
|
||||
return current_user
|
||||
|
||||
# Return role_checker
|
||||
return role_checker
|
||||
|
||||
|
||||
|
||||
@@ -4,27 +4,41 @@ Wiring lives ONLY in the presentation layer — use cases and services
|
||||
never know which concrete repository implementation they receive.
|
||||
"""
|
||||
|
||||
# Import Depends from fastapi
|
||||
from fastapi import Depends
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
|
||||
# Import from app.infrastructure.persistence.repositories.sa_test_repository
|
||||
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||
SATestRepository,
|
||||
)
|
||||
|
||||
|
||||
# Define function get_technique_repository
|
||||
def get_technique_repository(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
) -> SATechniqueRepository:
|
||||
"""Provide a TechniqueRepository backed by the current DB session."""
|
||||
# Return SATechniqueRepository(db)
|
||||
return SATechniqueRepository(db)
|
||||
|
||||
|
||||
# Define function get_test_repository
|
||||
def get_test_repository(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
) -> SATestRepository:
|
||||
"""Provide a TestRepository backed by the current DB session."""
|
||||
# Return SATestRepository(db)
|
||||
return SATestRepository(db)
|
||||
|
||||
Reference in New Issue
Block a user