fix: resolve 20 security vulnerabilities from comprehensive audit
Critical (1-3):
- Replace hardcoded admin credentials with secure auto-generation (seed.py)
- Enforce SECRET_KEY configuration, fail in production if missing (config.py)
- Add Zip Slip and Zip Bomb protection to all ZIP import services

High/Medium (4-9):
- Add 50MB file size limit and extension whitelist to evidence uploads
- Configure CORS origins via environment variable instead of hardcoded
- Migrate JWT storage from localStorage to HttpOnly cookies (frontend+backend)
- Add rate limiting (5/min) on login endpoint via slowapi
- Replace generic dict payloads with Pydantic schemas (mass assignment)

Medium (10-17):
- Check is_active on login to prevent disabled users from authenticating
- Sanitize exception messages in API responses (system, data_sources)
- Escape LIKE wildcards in all ilike search filters across 8 routers
- Run Docker container as non-root user (appuser)
- Make MINIO_SECURE configurable via environment variable
- Add password complexity policy (12+ chars, upper/lower/digit/special)
- Implement JWT token revocation via in-memory blacklist + reduce TTL to 15min
- Replace xml.etree with defusedxml to prevent Billion Laughs attacks

Low (18-20):
- Add security headers to Nginx (CSP, X-Frame-Options, HSTS-ready, etc.)
- Disable Swagger UI/ReDoc/OpenAPI in production
- Restrict /health endpoint to internal networks via Nginx ACL

Also: rewrite install.sh as interactive wizard for guided deployment, fix test-from-template validation error (technique_id UUID vs MITRE ID)
This commit is contained in:
@@ -4,10 +4,13 @@ Security utilities: password hashing and JWT token management.
|
||||
This module provides pure functions for:
|
||||
- Hashing and verifying passwords using bcrypt via passlib.
|
||||
- Creating JWT access tokens using python-jose.
|
||||
- Managing an in-memory token blacklist for revocation.
|
||||
|
||||
No endpoints are defined here.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import uuid as _uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import jwt
|
||||
@@ -38,13 +41,53 @@ def verify_password(plain: str, hashed: str) -> bool:
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
    """Create a signed JWT containing *data* plus ``exp`` and ``jti`` claims.

    - ``jti`` (JWT ID): a fresh random UUID4 string, unique per token, that
      enables token revocation.
    - ``exp``: expiration timestamp, ``ACCESS_TOKEN_EXPIRE_MINUTES`` (from
      settings) after the current UTC time.

    *data* itself is not mutated; the claims are added to a shallow copy.
    Any ``exp`` or ``jti`` key already present in *data* is overwritten.

    Returns the token encoded with ``settings.SECRET_KEY`` and
    ``settings.ALGORITHM``.
    """
    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + timedelta(
        minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
    )
    # Both claims are set in a single update so a token can never be issued
    # with an expiry but without the identifier needed to revoke it.
    to_encode.update({
        "exp": expire,
        "jti": str(_uuid.uuid4()),
    })
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token blacklist (in-memory)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stores (jti, expiry_timestamp) tuples. Entries are automatically purged
|
||||
# once they are past their original expiry (the token would be invalid
|
||||
# anyway at that point). Thread-safe via a simple lock.
|
||||
#
|
||||
# For multi-worker / multi-process deployments, consider replacing this
|
||||
# with a shared store like Redis.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_blacklist: dict[str, float] = {} # jti → expiry epoch
|
||||
_blacklist_lock = threading.Lock()
|
||||
|
||||
|
||||
def blacklist_token(jti: str, exp: float) -> None:
    """Record *jti* as revoked, remembering its expiry epoch *exp*.

    The entry is stored first and the blacklist is then swept for expired
    entries, all while holding the module lock. Because the sweep runs after
    the insert, an *exp* that is already in the past is removed again
    straight away.
    """
    _blacklist_lock.acquire()
    try:
        _blacklist.update({jti: exp})
        _cleanup_blacklist()
    finally:
        _blacklist_lock.release()
|
||||
|
||||
|
||||
def is_token_blacklisted(jti: str) -> bool:
    """Tell whether *jti* is currently present in the revocation blacklist.

    The lookup is done under the module lock. Expiry is not re-checked here:
    an entry counts as revoked until a cleanup pass removes it.
    """
    with _blacklist_lock:
        revoked = jti in _blacklist
    return revoked
|
||||
|
||||
|
||||
def _cleanup_blacklist() -> None:
    """Drop every blacklist entry whose expiry epoch lies before "now" (UTC).

    This function does not take the lock itself; the caller is expected to
    hold ``_blacklist_lock`` already.
    """
    cutoff = datetime.now(timezone.utc).timestamp()
    # Iterate over a snapshot so entries can be deleted while looping.
    for token_id, expires_at in list(_blacklist.items()):
        if expires_at < cutoff:
            del _blacklist[token_id]
|
||||
|
||||
@@ -1,20 +1,46 @@
|
||||
import os
|
||||
import secrets
|
||||
import warnings
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detect environment: "production" when AEGIS_ENV or common indicators are set
|
||||
# ---------------------------------------------------------------------------
|
||||
# True when either signal is present:
#   * a non-empty SECRET_KEY environment variable (an explicit key is taken
#     as a hint that this is a production deployment), or
#   * AEGIS_ENV equal to "production" (case-insensitive).
# NOTE(review): code that tests only AEGIS_ENV can disagree with this flag
# when SECRET_KEY is set but AEGIS_ENV is not — confirm that is intended.
_is_production = (
    bool(os.environ.get("SECRET_KEY"))
    or os.environ.get("AEGIS_ENV", "").lower() == "production"
)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql://postgres:postgres@postgres:5432/attackdb"
|
||||
SECRET_KEY: str = "change-me-in-production"
|
||||
|
||||
# ── Security ──────────────────────────────────────────────────────
|
||||
# SECRET_KEY has NO safe default. In development a random key is
|
||||
# generated at startup (tokens invalidate on restart — acceptable
|
||||
# for local dev). In production it MUST be supplied via env/.env
|
||||
# so tokens survive restarts.
|
||||
SECRET_KEY: str = ""
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env
|
||||
|
||||
# ── CORS ─────────────────────────────────────────────────────────
|
||||
# Comma-separated list of allowed origins, or a JSON array.
|
||||
# In dev this defaults to common local ports; in production set it
|
||||
# to the actual frontend domain(s).
|
||||
CORS_ORIGINS: str = "http://localhost:3000,http://localhost:5173"
|
||||
|
||||
# ── MinIO / S3 ───────────────────────────────────────────────────
|
||||
MINIO_ENDPOINT: str = "minio:9000"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "evidence"
|
||||
MINIO_SECURE: bool = False # True → use HTTPS to connect to MinIO
|
||||
|
||||
# Re-testing
|
||||
# ── Re-testing ───────────────────────────────────────────────────
|
||||
MAX_RETEST_COUNT: int = 3 # maximum automatic retests per original test
|
||||
|
||||
# Scoring weights (must sum to 100)
|
||||
# ── Scoring weights (must sum to 100) ────────────────────────────
|
||||
SCORING_WEIGHT_TESTS: int = 40
|
||||
SCORING_WEIGHT_DETECTION_RULES: int = 20
|
||||
SCORING_WEIGHT_D3FEND: int = 15
|
||||
@@ -26,3 +52,29 @@ class Settings(BaseSettings):
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Post-init validation for SECRET_KEY
|
||||
# ---------------------------------------------------------------------------
|
||||
# Values that must never be used as a signing key: the empty default and the
# well-known placeholder strings.
_UNSAFE_SECRETS = {
    "",
    "change-me-in-production",
    "change-me-in-production-use-a-long-random-string",
}

# Production refuses to start with an unsafe key. Anywhere else the key is
# replaced by a random 64-hex-character value and a warning is emitted.
# Only membership in _UNSAFE_SECRETS is checked; the ">= 32 chars" advice in
# the error message is not enforced here.
if settings.SECRET_KEY in _UNSAFE_SECRETS and _is_production:
    raise RuntimeError(
        "CRITICAL: SECRET_KEY is not configured. "
        "Set a strong random value (>= 32 chars) via the SECRET_KEY "
        "environment variable or in your .env file before running in "
        "production. Example: openssl rand -hex 32"
    )
elif settings.SECRET_KEY in _UNSAFE_SECRETS:
    # Development: auto-generate an ephemeral key and warn
    settings.SECRET_KEY = secrets.token_hex(32)
    warnings.warn(
        "SECRET_KEY was not set — using an auto-generated ephemeral key. "
        "JWT tokens will be invalidated on every restart. "
        "Set SECRET_KEY in your environment for persistent sessions.",
        stacklevel=2,
    )
|
||||
|
||||
@@ -2,25 +2,32 @@
|
||||
Authentication and RBAC dependencies for FastAPI.
|
||||
|
||||
Provides:
|
||||
- ``get_current_user``: decodes JWT, fetches user from DB, raises 401 on failure.
|
||||
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
|
||||
Authorization header (fallback), fetches user from DB, raises 401 on failure.
|
||||
- ``require_role``: factory that returns a dependency enforcing a specific role
|
||||
(admins always pass).
|
||||
"""
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth import is_token_blacklisted
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth2 scheme
|
||||
# OAuth2 scheme (reads Authorization header — used as fallback / Swagger UI)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False)
|
||||
|
||||
# Cookie name — must match the one set in the auth router
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Current-user dependency
|
||||
@@ -28,12 +35,19 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
aegis_token: Optional[str] = Cookie(None),
|
||||
bearer_token: Optional[str] = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Decode the JWT *token*, look up the user in *db*, and return it.
|
||||
"""Decode the JWT, look up the user in *db*, and return it.
|
||||
|
||||
Token resolution order:
|
||||
1. ``aegis_token`` **HttpOnly cookie** (preferred — immune to XSS).
|
||||
2. ``Authorization: Bearer <token>`` header (fallback for API clients
|
||||
and Swagger UI).
|
||||
|
||||
Raises :class:`~fastapi.HTTPException` **401** when:
|
||||
- no token is found in either location,
|
||||
- the token cannot be decoded,
|
||||
- the ``sub`` claim is missing, or
|
||||
- no matching active user exists in the database.
|
||||
@@ -44,6 +58,11 @@ async def get_current_user(
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Prefer cookie, fall back to header
|
||||
token = aegis_token or bearer_token
|
||||
if token is None:
|
||||
raise credentials_exception
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
@@ -53,11 +72,15 @@ async def get_current_user(
|
||||
username: str | None = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
# Check token blacklist (revoked tokens)
|
||||
jti: str | None = payload.get("jti")
|
||||
if jti and is_token_blacklisted(jti):
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if user is None:
|
||||
if user is None or not user.is_active:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.routers import auth as auth_router
|
||||
@@ -31,6 +35,9 @@ from app.routers import snapshots as snapshots_router
|
||||
from app.storage import ensure_bucket_exists
|
||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||
|
||||
# ── Environment detection ─────────────────────────────────────────────────
|
||||
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
|
||||
# ── Logging ───────────────────────────────────────────────────────────────
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -47,15 +54,33 @@ async def lifespan(app: FastAPI):
|
||||
scheduler.shutdown(wait=False)
|
||||
|
||||
|
||||
# ── In production, disable Swagger UI and ReDoc to hide API surface ──────
# The application object is created exactly once. When _IS_PRODUCTION is
# true, docs_url, redoc_url and openapi_url are all set to None; otherwise
# they keep the conventional /docs, /redoc and /openapi.json paths.
app = FastAPI(
    title="Attack Coverage Platform",
    lifespan=lifespan,
    docs_url=None if _IS_PRODUCTION else "/docs",
    redoc_url=None if _IS_PRODUCTION else "/redoc",
    openapi_url=None if _IS_PRODUCTION else "/openapi.json",
)
|
||||
|
||||
# ── Rate Limiter ──────────────────────────────────────────────────────────
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# ── CORS ──────────────────────────────────────────────────────────────────
|
||||
from app.config import settings as _settings
|
||||
|
||||
_cors_origins: list[str] = [
|
||||
o.strip() for o in _settings.CORS_ORIGINS.split(",") if o.strip()
|
||||
]
|
||||
|
||||
# Each keyword is passed exactly once (a repeated keyword argument is a
# SyntaxError). Origins come from the parsed CORS_ORIGINS setting, and
# methods and headers are explicit lists rather than the "*" wildcard.
# Credentials stay enabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
    allow_headers=["Authorization", "Content-Type"],
)
|
||||
|
||||
# ── Routers ──────────────────────────────────────────────────────────────
|
||||
@@ -82,8 +107,13 @@ app.include_router(compliance_router.router, prefix="/api/v1")
|
||||
app.include_router(snapshots_router.router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health", include_in_schema=False)
def health():
    """Minimal health check — returns only an HTTP 200 with no service metadata.

    The route is registered once and is excluded from the generated OpenAPI
    schema (``include_in_schema=False``).

    Access is restricted to internal networks at the Nginx level
    (see ``frontend/nginx.conf``).
    """
    return {"status": "ok"}
|
||||
|
||||
|
||||
|
||||
@@ -1,17 +1,40 @@
|
||||
"""Authentication router: login and current-user endpoints."""
|
||||
"""Authentication router: login, logout and current-user endpoints.
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
The JWT access token is delivered as an **HttpOnly** cookie
|
||||
(``aegis_token``) so it is inaccessible to client-side JavaScript,
|
||||
mitigating XSS token-theft attacks. The JSON response also includes
|
||||
the token in the body for backwards compatibility and for clients that
|
||||
cannot use cookies (e.g. Swagger UI).
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth import verify_password, create_access_token
|
||||
from jose import jwt, JWTError
|
||||
|
||||
from app.auth import verify_password, create_access_token, blacklist_token
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import TokenResponse, UserOut
|
||||
|
||||
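# Rate limiter used by the ``@limiter.limit`` decorators in this router.
# NOTE(review): this constructs its own Limiter object rather than reusing the
# one stored on app.state.limiter — confirm whether the two are meant to share state.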
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# Detect whether we're behind HTTPS (production) so the cookie can be Secure
|
||||
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
|
||||
# Cookie name used to transport the JWT
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/login
|
||||
@@ -19,11 +42,19 @@ router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
@limiter.limit("5/minute")
|
||||
def login(
|
||||
request: Request,
|
||||
response: Response,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Authenticate a user and return a JWT access token."""
|
||||
"""Authenticate a user and return a JWT access token.
|
||||
|
||||
Rate-limited to **5 attempts per minute per IP** to prevent brute-force
|
||||
attacks. The token is set as an HttpOnly cookie **and** returned in the
|
||||
JSON body for API/Swagger compatibility.
|
||||
"""
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
|
||||
if user is None or not verify_password(form_data.password, user.hashed_password):
|
||||
@@ -32,10 +63,70 @@ def login(
|
||||
detail="Incorrect username or password",
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account is disabled. Contact an administrator.",
|
||||
)
|
||||
|
||||
access_token = create_access_token(data={"sub": user.username})
|
||||
|
||||
# Set HttpOnly cookie — inaccessible from JS
|
||||
response.set_cookie(
|
||||
key=_COOKIE_NAME,
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
secure=_IS_HTTPS,
|
||||
samesite="strict",
|
||||
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
path="/",
|
||||
)
|
||||
|
||||
return TokenResponse(access_token=access_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/logout
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/logout")
def logout(
    request: Request,
    response: Response,
    aegis_token: str | None = Cookie(None),
):
    """Clear the authentication cookie and revoke the current token.

    The token is taken from the ``aegis_token`` cookie or, failing that,
    from the ``Authorization`` header with a leading ``"Bearer "`` removed.
    If it decodes successfully and carries a ``jti`` claim, that ``jti`` is
    added to the in-memory blacklist so the token cannot be reused even if
    the cookie has already been copied elsewhere.

    A token that fails to decode (bad signature, expired, malformed) is
    ignored: there is nothing to revoke. The cookie is deleted in every case.

    Returns ``{"detail": "Logged out"}``.
    """
    # Attempt to blacklist the token's jti
    token = aegis_token or request.headers.get("Authorization", "").removeprefix("Bearer ").strip()
    if token:
        try:
            payload = jwt.decode(
                token,
                settings.SECRET_KEY,
                algorithms=[settings.ALGORITHM],
            )
        except JWTError:
            pass  # token already invalid — nothing to revoke
        else:
            jti = payload.get("jti")
            if jti:
                exp = payload.get("exp")
                # A token without ``exp`` never expires on its own. Storing
                # it with expiry 0 would let the blacklist cleanup purge the
                # entry immediately, so keep it revoked indefinitely instead.
                expires_at = float(exp) if exp is not None else float("inf")
                blacklist_token(jti, expires_at)

    response.delete_cookie(
        key=_COOKIE_NAME,
        httponly=True,
        secure=_IS_HTTPS,
        samesite="strict",
        path="/",
    )
    return {"detail": "Logged out"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/me
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -174,7 +174,8 @@ def list_campaigns(
|
||||
if threat_actor_id:
|
||||
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
||||
|
||||
total = query.count()
|
||||
|
||||
@@ -42,7 +42,8 @@ def list_defensive_techniques(
|
||||
query = query.filter(DefensiveTechnique.tactic == tactic)
|
||||
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
DefensiveTechnique.name.ilike(pattern)
|
||||
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
||||
|
||||
@@ -7,8 +7,10 @@ including sync triggers, enable/disable toggles, and statistics.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
@@ -17,6 +19,17 @@ from app.models.user import User
|
||||
from app.models.data_source import DataSource
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic schemas for request validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DataSourceUpdate(BaseModel):
    """Request body for PATCH on a data source.

    Only the three fields declared here are part of the schema. Each one is
    optional and defaults to ``None``.
    """

    # Enable or disable the data source.
    is_enabled: Optional[bool] = None
    # Sync frequency as a free-form string; no format is validated here.
    sync_frequency: Optional[str] = None
    # Source-specific configuration; the dict's contents are not validated here.
    config: Optional[dict] = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||
@@ -90,29 +103,26 @@ def list_data_sources(
|
||||
@router.patch("/{source_id}")
|
||||
def update_data_source(
|
||||
source_id: str,
|
||||
body: dict,
|
||||
body: DataSourceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update a data source (enable/disable, change config).
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
|
||||
Body fields (all optional):
|
||||
- ``is_enabled`` (bool)
|
||||
- ``sync_frequency`` (str)
|
||||
- ``config`` (dict)
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
if "is_enabled" in body:
|
||||
ds.is_enabled = bool(body["is_enabled"])
|
||||
if "sync_frequency" in body:
|
||||
ds.sync_frequency = body["sync_frequency"]
|
||||
if "config" in body:
|
||||
ds.config = body["config"]
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
|
||||
if "is_enabled" in update_data:
|
||||
ds.is_enabled = update_data["is_enabled"]
|
||||
if "sync_frequency" in update_data:
|
||||
ds.sync_frequency = update_data["sync_frequency"]
|
||||
if "config" in update_data:
|
||||
ds.config = update_data["config"]
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -122,7 +132,7 @@ def update_data_source(
|
||||
action="update_data_source",
|
||||
entity_type="data_source",
|
||||
entity_id=str(ds.id),
|
||||
details={"updates": body},
|
||||
details={"updates": update_data},
|
||||
)
|
||||
|
||||
return {"message": "Data source updated", "id": str(ds.id)}
|
||||
@@ -156,14 +166,14 @@ def sync_data_source(
|
||||
try:
|
||||
summary = handler(db)
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc)
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
db.commit()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Sync failed: {str(exc)}",
|
||||
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
|
||||
)
|
||||
|
||||
# Update DS record (the handler may already have done this,
|
||||
@@ -222,7 +232,7 @@ def sync_all_data_sources(
|
||||
"stats": summary,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc)
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
@@ -230,7 +240,7 @@ def sync_all_data_sources(
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "error",
|
||||
"detail": str(exc),
|
||||
"detail": "Sync failed. Check server logs for details.",
|
||||
})
|
||||
|
||||
log_action(
|
||||
|
||||
@@ -5,10 +5,12 @@ and managing the template ↔ detection rule associations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -20,6 +22,18 @@ from app.models.test_template import TestTemplate
|
||||
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic schemas for request validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DetectionRuleEvaluate(BaseModel):
    """Request body for recording how a detection rule behaved on a test.

    ``test_id`` and ``detection_rule_id`` are required and must be UUIDs;
    ``triggered`` and ``notes`` are optional and default to ``None``.
    """

    # Identifier of the test being evaluated (required).
    test_id: uuid.UUID
    # Identifier of the detection rule being evaluated (required).
    detection_rule_id: uuid.UUID
    # True / False for the rule's outcome, or None when not specified.
    triggered: Optional[bool] = None
    # Optional free-text notes.
    notes: Optional[str] = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
||||
@@ -53,7 +67,8 @@ def list_detection_rules(
|
||||
query = query.filter(DetectionRule.severity == severity)
|
||||
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
DetectionRule.title.ilike(pattern)
|
||||
| DetectionRule.description.ilike(pattern)
|
||||
@@ -294,27 +309,15 @@ def get_detection_rules_for_test(
|
||||
|
||||
@router.post("/evaluate")
|
||||
def evaluate_detection_rule(
|
||||
payload: dict,
|
||||
payload: DetectionRuleEvaluate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||
):
|
||||
"""Save or update the evaluation result for a detection rule on a test.
|
||||
|
||||
Body:
|
||||
{
|
||||
"test_id": "...",
|
||||
"detection_rule_id": "...",
|
||||
"triggered": true | false | null,
|
||||
"notes": "optional notes"
|
||||
}
|
||||
"""
|
||||
test_id = payload.get("test_id")
|
||||
detection_rule_id = payload.get("detection_rule_id")
|
||||
triggered = payload.get("triggered")
|
||||
notes = payload.get("notes")
|
||||
|
||||
if not test_id or not detection_rule_id:
|
||||
raise HTTPException(status_code=400, detail="test_id and detection_rule_id are required")
|
||||
"""Save or update the evaluation result for a detection rule on a test."""
|
||||
test_id = payload.test_id
|
||||
detection_rule_id = payload.detection_rule_id
|
||||
triggered = payload.triggered
|
||||
notes = payload.notes
|
||||
|
||||
# Check test exists
|
||||
from app.models.test import Test
|
||||
|
||||
@@ -20,6 +20,7 @@ Access Control
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import uuid as _uuid
|
||||
from typing import Optional
|
||||
|
||||
@@ -43,6 +44,29 @@ _RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
|
||||
# States where blue evidence can be uploaded / deleted
|
||||
_BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload safety limits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Hard cap on a single evidence upload, in bytes (50 * 1024 * 1024).
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024

# Whitelist of accepted file extensions. Every entry is lowercase and keeps
# its leading dot. The groups are only for readability; the result is one
# flat set.
_ALLOWED_EXTENSIONS: set[str] = {
    # Images / screenshots
    *(".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg"),
    # Documents
    *(".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt"),
    *(".md", ".rtf", ".odt", ".ods"),
    # Logs & captures
    *(".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml"),
    *(".yaml", ".yml", ".toml"),
    # Archives (for bundled evidence)
    *(".zip", ".tar", ".gz", ".7z"),
    # Other common evidence types
    *(".har", ".eml", ".msg"),
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -177,21 +201,39 @@ async def upload_evidence(
|
||||
# Validate permissions
|
||||
_validate_upload_permission(test, team, current_user)
|
||||
|
||||
# 1. Read content + hash
|
||||
content = await file.read()
|
||||
# 1. Validate file extension
|
||||
file_name = file.filename or "unnamed"
|
||||
_, ext = os.path.splitext(file_name)
|
||||
if ext.lower() not in _ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File type '{ext}' is not allowed. "
|
||||
f"Permitted types: {', '.join(sorted(_ALLOWED_EXTENSIONS))}",
|
||||
)
|
||||
|
||||
# 2. Read content with size limit
|
||||
content = await file.read(_MAX_UPLOAD_SIZE + 1)
|
||||
if len(content) > _MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
detail=f"File exceeds maximum upload size of "
|
||||
f"{_MAX_UPLOAD_SIZE // (1024 * 1024)} MB",
|
||||
)
|
||||
|
||||
# 3. Hash
|
||||
sha256 = hashlib.sha256(content).hexdigest()
|
||||
|
||||
# 2. Object key
|
||||
file_name = file.filename or "unnamed"
|
||||
key = f"{test_id}/{_uuid.uuid4()}_{file_name}"
|
||||
# 4. Object key (sanitise filename to prevent path traversal in storage)
|
||||
safe_name = os.path.basename(file_name)
|
||||
key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
|
||||
|
||||
# 3. Upload to MinIO
|
||||
# 5. Upload to MinIO
|
||||
upload_file(content, key)
|
||||
|
||||
# 4. Persist metadata
|
||||
# 6. Persist metadata
|
||||
evidence = Evidence(
|
||||
test_id=test_id,
|
||||
file_name=file_name,
|
||||
file_name=safe_name,
|
||||
file_path=key,
|
||||
sha256_hash=sha256,
|
||||
uploaded_by=current_user.id,
|
||||
@@ -202,7 +244,7 @@ async def upload_evidence(
|
||||
db.commit()
|
||||
db.refresh(evidence)
|
||||
|
||||
# 5. Audit
|
||||
# 7. Audit
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
@@ -210,7 +252,7 @@ async def upload_evidence(
|
||||
entity_type="evidence",
|
||||
entity_id=evidence.id,
|
||||
details={
|
||||
"file_name": file_name,
|
||||
"file_name": safe_name,
|
||||
"sha256": sha256,
|
||||
"test_id": str(test_id),
|
||||
"team": team.value,
|
||||
|
||||
@@ -107,9 +107,10 @@ def _apply_filters(
|
||||
query = query.filter(or_(*platform_filters))
|
||||
if tactics:
|
||||
from sqlalchemy import or_
|
||||
from app.utils import escape_like
|
||||
tactic_filters = []
|
||||
for tactic in tactics:
|
||||
tactic_filters.append(model.tactic.ilike(f"%{tactic}%"))
|
||||
tactic_filters.append(model.tactic.ilike(f"%{escape_like(tactic)}%"))
|
||||
query = query.filter(or_(*tactic_filters))
|
||||
return query
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ def coverage_summary(
|
||||
"""Full coverage report as JSON — technique-by-technique with test counts."""
|
||||
query = db.query(Technique)
|
||||
if tactic:
|
||||
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
|
||||
from app.utils import escape_like
|
||||
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
|
||||
|
||||
techniques = query.order_by(Technique.mitre_id).all()
|
||||
|
||||
@@ -109,7 +110,8 @@ def coverage_csv(
|
||||
"""Export coverage as a downloadable CSV."""
|
||||
query = db.query(Technique)
|
||||
if tactic:
|
||||
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
|
||||
from app.utils import escape_like
|
||||
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
|
||||
|
||||
techniques = query.order_by(Technique.mitre_id).all()
|
||||
|
||||
|
||||
@@ -80,10 +80,9 @@ def trigger_atomic_import(
|
||||
try:
|
||||
summary = import_atomic_red_team(db)
|
||||
except Exception as exc:
|
||||
logger.error("Atomic Red Team import failed: %s", exc)
|
||||
logger.error("Atomic Red Team import failed: %s", exc, exc_info=True)
|
||||
return {
|
||||
"message": "Import failed",
|
||||
"error": str(exc),
|
||||
"message": "Import failed. Check server logs for details.",
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -69,13 +69,15 @@ def list_templates(
|
||||
if source:
|
||||
query = query.filter(TestTemplate.source == source)
|
||||
if platform:
|
||||
query = query.filter(TestTemplate.platform.ilike(f"%{platform}%"))
|
||||
from app.utils import escape_like
|
||||
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
if severity:
|
||||
query = query.filter(TestTemplate.severity == severity)
|
||||
if mitre_technique_id:
|
||||
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
TestTemplate.name.ilike(pattern),
|
||||
|
||||
@@ -109,7 +109,8 @@ def list_tests(
|
||||
if technique_id:
|
||||
query = query.filter(Test.technique_id == technique_id)
|
||||
if platform:
|
||||
query = query.filter(Test.platform.ilike(f"%{platform}%"))
|
||||
from app.utils import escape_like
|
||||
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
if created_by:
|
||||
query = query.filter(Test.created_by == created_by)
|
||||
if pending_validation_side == "red":
|
||||
@@ -200,15 +201,25 @@ def create_test_from_template(
|
||||
detail=f"TestTemplate with id '{payload.template_id}' not found",
|
||||
)
|
||||
|
||||
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
|
||||
# Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001")
|
||||
technique = None
|
||||
try:
|
||||
technique_uuid = uuid.UUID(payload.technique_id)
|
||||
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if technique is None:
|
||||
technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first()
|
||||
|
||||
if technique is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Technique with id '{payload.technique_id}' not found",
|
||||
detail=f"Technique '{payload.technique_id}' not found",
|
||||
)
|
||||
|
||||
test = Test(
|
||||
technique_id=payload.technique_id,
|
||||
technique_id=technique.id,
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
platform=template.platform,
|
||||
|
||||
@@ -49,7 +49,8 @@ def list_threat_actors(
|
||||
|
||||
# Filters
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
ThreatActor.name.ilike(pattern),
|
||||
@@ -68,9 +69,10 @@ def list_threat_actors(
|
||||
query = query.filter(ThreatActor.sophistication == sophistication)
|
||||
|
||||
if target_sectors:
|
||||
from app.utils import escape_like
|
||||
# JSONB contains check
|
||||
query = query.filter(
|
||||
func.cast(ThreatActor.target_sectors, func.text()).ilike(f"%{target_sectors}%")
|
||||
func.cast(ThreatActor.target_sectors, func.text()).ilike(f"%{escape_like(target_sectors)}%")
|
||||
)
|
||||
|
||||
# Total count
|
||||
|
||||
@@ -75,4 +75,4 @@ class TestTemplateInstantiate(BaseModel):
|
||||
"""Payload to create a real test from an existing template."""
|
||||
|
||||
template_id: uuid.UUID
|
||||
technique_id: uuid.UUID
|
||||
technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001")
|
||||
|
||||
@@ -1,9 +1,49 @@
|
||||
"""Pydantic schemas for User management endpoints."""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, field_validator
|
||||
|
||||
|
||||
# ── Password policy ─────────────────────────────────────────────────
|
||||
|
||||
# Shortest password the policy accepts.
_MIN_PASSWORD_LENGTH = 12

# (regex, human-readable requirement) pairs; a password must match every regex.
_PASSWORD_RULES: list[tuple[str, str]] = [
    (r"[A-Z]", "at least one uppercase letter"),
    (r"[a-z]", "at least one lowercase letter"),
    (r"[0-9]", "at least one digit"),
    (r"[!@#$%^&*()_+\-=\[\]{};':\"\\|,.<>/?`~]", "at least one special character"),
]


def _validate_password_strength(password: str) -> str:
    """Return *password* unchanged if it satisfies the complexity policy.

    The policy requires a minimum of ``_MIN_PASSWORD_LENGTH`` characters and
    at least one match for every pattern in ``_PASSWORD_RULES`` (uppercase,
    lowercase, digit, special character).

    Raises:
        ValueError: when one or more requirements are unmet. The message
            lists every failed requirement, length first, then the rules in
            the order they are declared.
    """
    length_violation = (
        [f"must be at least {_MIN_PASSWORD_LENGTH} characters long"]
        if len(password) < _MIN_PASSWORD_LENGTH
        else []
    )
    rule_violations = [
        requirement
        for regex, requirement in _PASSWORD_RULES
        if re.search(regex, password) is None
    ]
    violations = length_violation + rule_violations

    # Guard clause: a compliant password is handed straight back.
    if not violations:
        return password

    raise ValueError(
        "Password does not meet complexity requirements: " + "; ".join(violations)
    )
|
||||
|
||||
|
||||
# ── Create ──────────────────────────────────────────────────────────
|
||||
@@ -16,6 +56,11 @@ class UserCreate(BaseModel):
|
||||
password: str
|
||||
role: str = "viewer"
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
return _validate_password_strength(v)
|
||||
|
||||
|
||||
# ── Update ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -28,6 +73,13 @@ class UserUpdate(BaseModel):
|
||||
is_active: bool | None = None
|
||||
password: str | None = None
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str | None) -> str | None:
|
||||
if v is not None:
|
||||
return _validate_password_strength(v)
|
||||
return v
|
||||
|
||||
|
||||
# ── Read (full) ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -1,32 +1,80 @@
|
||||
"""
|
||||
Seed script — creates the initial admin user if it does not already exist.
|
||||
|
||||
On first run the admin credentials are generated securely:
|
||||
- Username is read from ``ADMIN_USERNAME`` env var (default: ``admin``).
|
||||
- Password is read from ``ADMIN_PASSWORD`` env var. When the variable is
|
||||
**not set**, a cryptographically random 16-character password is generated
|
||||
automatically and printed to the startup logs so the operator can copy it.
|
||||
|
||||
Usage:
|
||||
python -m app.seed
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
|
||||
from app.auth import hash_password
|
||||
from app.database import SessionLocal
|
||||
from app.models.user import User
|
||||
|
||||
# Characters for auto-generated passwords (alphanumeric + safe symbols)
_PW_ALPHABET = string.ascii_letters + string.digits + "!@#$%&*-_+"


def _generate_password(length: int = 16) -> str:
    """Return a cryptographically random password of *length* characters.

    Characters are drawn from ``_PW_ALPHABET`` using :func:`secrets.choice`.
    When *length* is at least 4, candidates are re-drawn until the result
    contains at least one uppercase letter, one lowercase letter, one digit
    and one symbol, so the generated password also passes an
    upper/lower/digit/special complexity check. Plain uniform sampling gives
    no such guarantee.

    Lengths below 4 cannot hold all four character classes; for those the
    password is returned exactly as drawn (an empty string for ``length <= 0``).
    """

    def _draw() -> str:
        return "".join(secrets.choice(_PW_ALPHABET) for _ in range(length))

    if length < 4:
        return _draw()

    # Rejection sampling keeps the distribution uniform over compliant passwords.
    while True:
        candidate = _draw()
        if (
            any(ch.isupper() for ch in candidate)
            and any(ch.islower() for ch in candidate)
            and any(ch.isdigit() for ch in candidate)
            # The alphabet is ASCII letters, digits and symbols, so any
            # non-alphanumeric character is one of the symbols.
            and any(not ch.isalnum() for ch in candidate)
        ):
            return candidate
|
||||
|
||||
|
||||
def seed_admin() -> None:
|
||||
"""Create the default admin user when it is missing."""
|
||||
"""Create the initial admin user when it is missing.
|
||||
|
||||
Reads ``ADMIN_USERNAME`` and ``ADMIN_PASSWORD`` from the environment.
|
||||
If ``ADMIN_PASSWORD`` is empty or unset a secure random password is
|
||||
generated and displayed in the logs.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
existing = db.query(User).filter(User.username == "admin").first()
|
||||
admin_username = os.environ.get("ADMIN_USERNAME", "admin").strip() or "admin"
|
||||
|
||||
existing = db.query(User).filter(User.username == admin_username).first()
|
||||
if existing:
|
||||
print("Admin user already exists — skipping.")
|
||||
print(f"Admin user '{admin_username}' already exists — skipping.")
|
||||
return
|
||||
|
||||
admin_password = os.environ.get("ADMIN_PASSWORD", "").strip()
|
||||
password_was_generated = False
|
||||
|
||||
if not admin_password:
|
||||
admin_password = _generate_password()
|
||||
password_was_generated = True
|
||||
|
||||
admin = User(
|
||||
username="admin",
|
||||
hashed_password=hash_password("admin123"),
|
||||
username=admin_username,
|
||||
hashed_password=hash_password(admin_password),
|
||||
role="admin",
|
||||
)
|
||||
db.add(admin)
|
||||
db.commit()
|
||||
print("Admin user created successfully.")
|
||||
|
||||
# ── Display credentials in startup logs ──────────────────────
|
||||
print()
|
||||
print("=" * 60)
|
||||
print(" AEGIS — Initial Admin User Created")
|
||||
print("=" * 60)
|
||||
print(f" Username : {admin_username}")
|
||||
if password_was_generated:
|
||||
print(f" Password : {admin_password}")
|
||||
print()
|
||||
print(" ** This password was auto-generated because")
|
||||
print(" ADMIN_PASSWORD was not set in the environment. **")
|
||||
print(" ** Save it now — it will NOT be shown again. **")
|
||||
else:
|
||||
print(" Password : (set via ADMIN_PASSWORD env var)")
|
||||
print("=" * 60)
|
||||
print()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@@ -70,10 +70,50 @@ def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes:
|
||||
return content
|
||||
|
||||
|
||||
def _safe_extract_zip(
    zip_bytes: bytes,
    dest: str,
    *,
    max_uncompressed_size: int = 500 * 1024 * 1024,
    max_entries: int = 50_000,
) -> None:
    """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.

    The whole archive is validated before anything is written to disk.

    Args:
        zip_bytes: Raw bytes of the ZIP archive.
        dest: Directory to extract into.
        max_uncompressed_size: Upper bound, in bytes, on the sum of the
            members' declared uncompressed sizes (default 500 MB).
        max_entries: Upper bound on the number of archive members.

    Raises:
        ValueError: if the archive has too many entries, if its declared
            uncompressed size exceeds the limit, or if any member would
            resolve outside *dest* (path traversal / Zip Slip).
    """
    dest_path = Path(dest).resolve()
    mib = 1024 * 1024

    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        entries = zf.infolist()

        if len(entries) > max_entries:
            raise ValueError(
                f"ZIP archive contains {len(entries)} entries "
                f"(limit: {max_entries}) — possible zip bomb"
            )

        # Sizes come from the archive's own metadata, so this check costs
        # nothing and runs before any data is decompressed.
        total_size = sum(info.file_size for info in entries)
        if total_size > max_uncompressed_size:
            raise ValueError(
                f"ZIP uncompressed size {total_size / mib:.0f} MB "
                f"exceeds limit of {max_uncompressed_size / mib:.0f} MB"
            )

        # Check every member before extracting so a malicious archive is
        # rejected as a whole rather than partially unpacked.
        for member in entries:
            target = (dest_path / member.filename).resolve()
            if not target.is_relative_to(dest_path):
                raise ValueError(
                    f"Zip Slip detected — member '{member.filename}' "
                    "resolves outside target directory"
                )

        zf.extractall(dest)
|
||||
|
||||
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return the path to the atomics/ dir."""
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
zf.extractall(dest)
|
||||
_safe_extract_zip(zip_bytes, dest)
|
||||
atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
|
||||
if not atomics_dir.is_dir():
|
||||
raise FileNotFoundError(
|
||||
|
||||
@@ -75,10 +75,50 @@ def _download_zip(url: str = ELASTIC_ZIP_URL) -> bytes:
|
||||
return content
|
||||
|
||||
|
||||
def _safe_extract_zip(
    zip_bytes: bytes,
    dest: str,
    *,
    max_uncompressed_size: int = 500 * 1024 * 1024,
    max_entries: int = 50_000,
) -> None:
    """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.

    The whole archive is validated before anything is written to disk.

    Args:
        zip_bytes: Raw bytes of the ZIP archive.
        dest: Directory to extract into.
        max_uncompressed_size: Upper bound, in bytes, on the sum of the
            members' declared uncompressed sizes (default 500 MB).
        max_entries: Upper bound on the number of archive members.

    Raises:
        ValueError: if the archive has too many entries, if its declared
            uncompressed size exceeds the limit, or if any member would
            resolve outside *dest* (path traversal / Zip Slip).
    """
    dest_path = Path(dest).resolve()
    mib = 1024 * 1024

    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        entries = zf.infolist()

        if len(entries) > max_entries:
            raise ValueError(
                f"ZIP archive contains {len(entries)} entries "
                f"(limit: {max_entries}) — possible zip bomb"
            )

        # Sizes come from the archive's own metadata, so this check costs
        # nothing and runs before any data is decompressed.
        total_size = sum(info.file_size for info in entries)
        if total_size > max_uncompressed_size:
            raise ValueError(
                f"ZIP uncompressed size {total_size / mib:.0f} MB "
                f"exceeds limit of {max_uncompressed_size / mib:.0f} MB"
            )

        # Check every member before extracting so a malicious archive is
        # rejected as a whole rather than partially unpacked.
        for member in entries:
            target = (dest_path / member.filename).resolve()
            if not target.is_relative_to(dest_path):
                raise ValueError(
                    f"Zip Slip detected — member '{member.filename}' "
                    "resolves outside target directory"
                )

        zf.extractall(dest)
|
||||
|
||||
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return rules/ dir."""
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
zf.extractall(dest)
|
||||
_safe_extract_zip(zip_bytes, dest)
|
||||
rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
|
||||
if not rules_dir.is_dir():
|
||||
raise FileNotFoundError(
|
||||
|
||||
@@ -11,7 +11,7 @@ parser. No LLMs or paid APIs are used.
|
||||
|
||||
import logging
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
import defusedxml.ElementTree as ET
|
||||
from datetime import datetime
|
||||
|
||||
import requests as _requests
|
||||
|
||||
@@ -81,10 +81,50 @@ def _download_zip(url: str = SIGMA_ZIP_URL) -> bytes:
|
||||
return content
|
||||
|
||||
|
||||
def _safe_extract_zip(
    zip_bytes: bytes,
    dest: str,
    *,
    max_uncompressed_size: int = 500 * 1024 * 1024,
    max_entries: int = 50_000,
) -> None:
    """Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.

    The whole archive is validated before anything is written to disk.

    Args:
        zip_bytes: Raw bytes of the ZIP archive.
        dest: Directory to extract into.
        max_uncompressed_size: Upper bound, in bytes, on the sum of the
            members' declared uncompressed sizes (default 500 MB).
        max_entries: Upper bound on the number of archive members.

    Raises:
        ValueError: if the archive has too many entries, if its declared
            uncompressed size exceeds the limit, or if any member would
            resolve outside *dest* (path traversal / Zip Slip).
    """
    dest_path = Path(dest).resolve()
    mib = 1024 * 1024

    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        entries = zf.infolist()

        if len(entries) > max_entries:
            raise ValueError(
                f"ZIP archive contains {len(entries)} entries "
                f"(limit: {max_entries}) — possible zip bomb"
            )

        # Sizes come from the archive's own metadata, so this check costs
        # nothing and runs before any data is decompressed.
        total_size = sum(info.file_size for info in entries)
        if total_size > max_uncompressed_size:
            raise ValueError(
                f"ZIP uncompressed size {total_size / mib:.0f} MB "
                f"exceeds limit of {max_uncompressed_size / mib:.0f} MB"
            )

        # Check every member before extracting so a malicious archive is
        # rejected as a whole rather than partially unpacked.
        for member in entries:
            target = (dest_path / member.filename).resolve()
            if not target.is_relative_to(dest_path):
                raise ValueError(
                    f"Zip Slip detected — member '{member.filename}' "
                    "resolves outside target directory"
                )

        zf.extractall(dest)
|
||||
|
||||
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return the path to rules/ dir."""
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
zf.extractall(dest)
|
||||
_safe_extract_zip(zip_bytes, dest)
|
||||
rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
|
||||
if not rules_dir.is_dir():
|
||||
raise FileNotFoundError(
|
||||
|
||||
@@ -13,9 +13,11 @@ from app.config import settings
|
||||
# Shared client (module-level singleton)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_scheme = "https" if settings.MINIO_SECURE else "http"
|
||||
|
||||
_client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=f"http://{settings.MINIO_ENDPOINT}",
|
||||
endpoint_url=f"{_scheme}://{settings.MINIO_ENDPOINT}",
|
||||
aws_access_key_id=settings.MINIO_ACCESS_KEY,
|
||||
aws_secret_access_key=settings.MINIO_SECRET_KEY,
|
||||
region_name="us-east-1", # MinIO ignores this but boto3 requires it
|
||||
|
||||
21
backend/app/utils.py
Normal file
21
backend/app/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Shared utility helpers."""
|
||||
|
||||
|
||||
def escape_like(value: str) -> str:
    """Return *value* with SQL LIKE metacharacters backslash-escaped.

    A backslash is placed in front of every ``%``, ``_`` and backslash in
    *value*, so a user-supplied search term is matched literally when it is
    embedded in a pattern passed to SQLAlchemy's ``ilike``/``like``.

    NOTE(review): this relies on the database treating backslash as the LIKE
    escape character (or on callers passing ``escape`` explicitly) — confirm
    for the target backend.

    Usage::

        from app.utils import escape_like
        query.filter(Model.name.ilike(f"%{escape_like(term)}%"))
    """
    # One pass over the string: each special character maps to itself
    # prefixed by a backslash, every other character is left untouched.
    escapes = {ord(ch): "\\" + ch for ch in "\\%_"}
    return value.translate(escapes)
|
||||
Reference in New Issue
Block a user