fix: resolve 20 security vulnerabilities from comprehensive audit

Critical (1-3):
- Replace hardcoded admin credentials with secure auto-generation (seed.py)
- Enforce SECRET_KEY configuration, fail in production if missing (config.py)
- Add Zip Slip and Zip Bomb protection to all ZIP import services

High/Medium (4-9):
- Add 50MB file size limit and extension whitelist to evidence uploads
- Configure CORS origins via environment variable instead of hardcoded
- Migrate JWT storage from localStorage to HttpOnly cookies (frontend+backend)
- Add rate limiting (5/min) on login endpoint via slowapi
- Replace generic dict payloads with Pydantic schemas (mass assignment)

Medium (10-17):
- Check is_active on login to prevent disabled users from authenticating
- Sanitize exception messages in API responses (system, data_sources)
- Escape LIKE wildcards in all ilike search filters across 8 routers
- Run Docker container as non-root user (appuser)
- Make MINIO_SECURE configurable via environment variable
- Add password complexity policy (12+ chars, upper/lower/digit/special)
- Implement JWT token revocation via in-memory blacklist + reduce TTL to 15min
- Replace xml.etree with defusedxml to prevent Billion Laughs attacks

Low (18-20):
- Add security headers to Nginx (CSP, X-Frame-Options, HSTS-ready, etc.)
- Disable Swagger UI/ReDoc/OpenAPI in production
- Restrict /health endpoint to internal networks via Nginx ACL

Also: rewrite install.sh as interactive wizard for guided deployment,
fix test-from-template validation error (technique_id UUID vs MITRE ID)
This commit is contained in:
2026-02-11 08:56:26 +01:00
parent e7e63161e8
commit 64d64080e0
36 changed files with 1154 additions and 311 deletions

View File

@@ -4,10 +4,13 @@ Security utilities: password hashing and JWT token management.
This module provides pure functions for:
- Hashing and verifying passwords using bcrypt via passlib.
- Creating JWT access tokens using python-jose.
- Managing an in-memory token blacklist for revocation.
No endpoints are defined here.
"""
import threading
import uuid as _uuid
from datetime import datetime, timedelta, timezone
from jose import jwt
@@ -38,13 +41,53 @@ def verify_password(plain: str, hashed: str) -> bool:
def create_access_token(data: dict) -> str:
"""Create a signed JWT containing *data* plus an ``exp`` claim.
"""Create a signed JWT containing *data* plus ``exp`` and ``jti`` claims.
The token expires after ``ACCESS_TOKEN_EXPIRE_MINUTES`` (from settings).
- ``jti`` (JWT ID): unique identifier that enables token revocation.
- ``exp``: expiration timestamp based on ``ACCESS_TOKEN_EXPIRE_MINUTES``.
"""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
)
to_encode.update({"exp": expire})
to_encode.update({
"exp": expire,
"jti": str(_uuid.uuid4()),
})
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
# ---------------------------------------------------------------------------
# Token blacklist (in-memory)
# ---------------------------------------------------------------------------
# Stores (jti, expiry_timestamp) tuples. Entries are automatically purged
# once they are past their original expiry (the token would be invalid
# anyway at that point). Thread-safe via a simple lock.
#
# For multi-worker / multi-process deployments, consider replacing this
# with a shared store like Redis.
# ---------------------------------------------------------------------------
_blacklist: dict[str, float] = {} # jti → expiry epoch
_blacklist_lock = threading.Lock()
def blacklist_token(jti: str, exp: float) -> None:
"""Add *jti* to the blacklist until it naturally expires at *exp*."""
with _blacklist_lock:
_blacklist[jti] = exp
_cleanup_blacklist()
def is_token_blacklisted(jti: str) -> bool:
"""Return ``True`` if *jti* has been revoked."""
with _blacklist_lock:
return jti in _blacklist
def _cleanup_blacklist() -> None:
"""Remove entries whose tokens have already expired (caller holds lock)."""
now = datetime.now(timezone.utc).timestamp()
expired = [k for k, exp in _blacklist.items() if exp < now]
for k in expired:
del _blacklist[k]

View File

@@ -1,20 +1,46 @@
import os
import secrets
import warnings
from pydantic_settings import BaseSettings
# ---------------------------------------------------------------------------
# Detect environment: "production" when AEGIS_ENV or common indicators are set
# ---------------------------------------------------------------------------
_is_production = os.environ.get("AEGIS_ENV", "").lower() == "production" or bool(
os.environ.get("SECRET_KEY") # having an explicit SECRET_KEY hints prod
)
class Settings(BaseSettings):
DATABASE_URL: str = "postgresql://postgres:postgres@postgres:5432/attackdb"
SECRET_KEY: str = "change-me-in-production"
# ── Security ──────────────────────────────────────────────────────
# SECRET_KEY has NO safe default. In development a random key is
# generated at startup (tokens invalidate on restart — acceptable
# for local dev). In production it MUST be supplied via env/.env
# so tokens survive restarts.
SECRET_KEY: str = ""
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env
# ── CORS ─────────────────────────────────────────────────────────
# Comma-separated list of allowed origins, or a JSON array.
# In dev this defaults to common local ports; in production set it
# to the actual frontend domain(s).
CORS_ORIGINS: str = "http://localhost:3000,http://localhost:5173"
# ── MinIO / S3 ───────────────────────────────────────────────────
MINIO_ENDPOINT: str = "minio:9000"
MINIO_ACCESS_KEY: str = "minioadmin"
MINIO_SECRET_KEY: str = "minioadmin"
MINIO_BUCKET: str = "evidence"
MINIO_SECURE: bool = False # True → use HTTPS to connect to MinIO
# Re-testing
# ── Re-testing ───────────────────────────────────────────────────
MAX_RETEST_COUNT: int = 3 # maximum automatic retests per original test
# Scoring weights (must sum to 100)
# ── Scoring weights (must sum to 100) ────────────────────────────
SCORING_WEIGHT_TESTS: int = 40
SCORING_WEIGHT_DETECTION_RULES: int = 20
SCORING_WEIGHT_D3FEND: int = 15
@@ -26,3 +52,29 @@ class Settings(BaseSettings):
settings = Settings()
# ---------------------------------------------------------------------------
# Post-init validation for SECRET_KEY
# ---------------------------------------------------------------------------
_UNSAFE_SECRETS = {
"",
"change-me-in-production",
"change-me-in-production-use-a-long-random-string",
}
if settings.SECRET_KEY in _UNSAFE_SECRETS:
if _is_production:
raise RuntimeError(
"CRITICAL: SECRET_KEY is not configured. "
"Set a strong random value (>= 32 chars) via the SECRET_KEY "
"environment variable or in your .env file before running in "
"production. Example: openssl rand -hex 32"
)
# Development: auto-generate an ephemeral key and warn
settings.SECRET_KEY = secrets.token_hex(32)
warnings.warn(
"SECRET_KEY was not set — using an auto-generated ephemeral key. "
"JWT tokens will be invalidated on every restart. "
"Set SECRET_KEY in your environment for persistent sessions.",
stacklevel=2,
)

View File

@@ -2,25 +2,32 @@
Authentication and RBAC dependencies for FastAPI.
Provides:
- ``get_current_user``: decodes JWT, fetches user from DB, raises 401 on failure.
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
Authorization header (fallback), fetches user from DB, raises 401 on failure.
- ``require_role``: factory that returns a dependency enforcing a specific role
(admins always pass).
"""
from fastapi import Depends, HTTPException, status
from typing import Optional
from fastapi import Cookie, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session
from app.auth import is_token_blacklisted
from app.config import settings
from app.database import get_db
from app.models.user import User
# ---------------------------------------------------------------------------
# OAuth2 scheme
# OAuth2 scheme (reads Authorization header — used as fallback / Swagger UI)
# ---------------------------------------------------------------------------
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False)
# Cookie name — must match the one set in the auth router
_COOKIE_NAME = "aegis_token"
# ---------------------------------------------------------------------------
# Current-user dependency
@@ -28,12 +35,19 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
aegis_token: Optional[str] = Cookie(None),
bearer_token: Optional[str] = Depends(oauth2_scheme),
db: Session = Depends(get_db),
) -> User:
"""Decode the JWT *token*, look up the user in *db*, and return it.
"""Decode the JWT, look up the user in *db*, and return it.
Token resolution order:
1. ``aegis_token`` **HttpOnly cookie** (preferred — immune to XSS).
2. ``Authorization: Bearer <token>`` header (fallback for API clients
and Swagger UI).
Raises :class:`~fastapi.HTTPException` **401** when:
- no token is found in either location,
- the token cannot be decoded,
- the ``sub`` claim is missing, or
- no matching active user exists in the database.
@@ -44,6 +58,11 @@ async def get_current_user(
headers={"WWW-Authenticate": "Bearer"},
)
# Prefer cookie, fall back to header
token = aegis_token or bearer_token
if token is None:
raise credentials_exception
try:
payload = jwt.decode(
token,
@@ -53,11 +72,15 @@ async def get_current_user(
username: str | None = payload.get("sub")
if username is None:
raise credentials_exception
# Check token blacklist (revoked tokens)
jti: str | None = payload.get("jti")
if jti and is_token_blacklisted(jti):
raise credentials_exception
except JWTError:
raise credentials_exception
user = db.query(User).filter(User.username == username).first()
if user is None:
if user is None or not user.is_active:
raise credentials_exception
return user

View File

@@ -1,10 +1,14 @@
import logging
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sqlalchemy.exc import SQLAlchemyError
from app.routers import auth as auth_router
@@ -31,6 +35,9 @@ from app.routers import snapshots as snapshots_router
from app.storage import ensure_bucket_exists
from app.jobs.mitre_sync_job import start_scheduler, scheduler
# ── Environment detection ─────────────────────────────────────────────────
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
# ── Logging ───────────────────────────────────────────────────────────────
logging.basicConfig(
level=logging.INFO,
@@ -47,15 +54,33 @@ async def lifespan(app: FastAPI):
scheduler.shutdown(wait=False)
app = FastAPI(title="Attack Coverage Platform", lifespan=lifespan)
# ── In production, disable Swagger UI and ReDoc to hide API surface ──────
app = FastAPI(
title="Attack Coverage Platform",
lifespan=lifespan,
docs_url=None if _IS_PRODUCTION else "/docs",
redoc_url=None if _IS_PRODUCTION else "/redoc",
openapi_url=None if _IS_PRODUCTION else "/openapi.json",
)
# ── Rate Limiter ──────────────────────────────────────────────────────────
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# ── CORS ──────────────────────────────────────────────────────────────────
from app.config import settings as _settings
_cors_origins: list[str] = [
o.strip() for o in _settings.CORS_ORIGINS.split(",") if o.strip()
]
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000", "http://localhost:5173"],
allow_origins=_cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
)
# ── Routers ──────────────────────────────────────────────────────────────
@@ -82,8 +107,13 @@ app.include_router(compliance_router.router, prefix="/api/v1")
app.include_router(snapshots_router.router, prefix="/api/v1")
@app.get("/health")
@app.get("/health", include_in_schema=False)
def health():
"""Minimal health check — returns only an HTTP 200 with no service metadata.
Access is restricted to internal networks at the Nginx level
(see ``frontend/nginx.conf``).
"""
return {"status": "ok"}

View File

@@ -1,17 +1,40 @@
"""Authentication router: login and current-user endpoints."""
"""Authentication router: login, logout and current-user endpoints.
from fastapi import APIRouter, Depends, HTTPException, status
The JWT access token is delivered as an **HttpOnly** cookie
(``aegis_token``) so it is inaccessible to client-side JavaScript,
mitigating XSS token-theft attacks. The JSON response also includes
the token in the body for backwards compatibility and for clients that
cannot use cookies (e.g. Swagger UI).
"""
import os
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.orm import Session
from app.auth import verify_password, create_access_token
from jose import jwt, JWTError
from app.auth import verify_password, create_access_token, blacklist_token
from app.config import settings
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.models.user import User
from app.schemas.auth import TokenResponse, UserOut
# Rate limiter instance (shares backend state via app.state.limiter)
limiter = Limiter(key_func=get_remote_address)
router = APIRouter(prefix="/auth", tags=["auth"])
# Detect whether we're behind HTTPS (production) so the cookie can be Secure
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
# Cookie name used to transport the JWT
_COOKIE_NAME = "aegis_token"
# ---------------------------------------------------------------------------
# POST /auth/login
@@ -19,11 +42,19 @@ router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/login", response_model=TokenResponse)
@limiter.limit("5/minute")
def login(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db),
):
"""Authenticate a user and return a JWT access token."""
"""Authenticate a user and return a JWT access token.
Rate-limited to **5 attempts per minute per IP** to prevent brute-force
attacks. The token is set as an HttpOnly cookie **and** returned in the
JSON body for API/Swagger compatibility.
"""
user = db.query(User).filter(User.username == form_data.username).first()
if user is None or not verify_password(form_data.password, user.hashed_password):
@@ -32,10 +63,70 @@ def login(
detail="Incorrect username or password",
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is disabled. Contact an administrator.",
)
access_token = create_access_token(data={"sub": user.username})
# Set HttpOnly cookie — inaccessible from JS
response.set_cookie(
key=_COOKIE_NAME,
value=access_token,
httponly=True,
secure=_IS_HTTPS,
samesite="strict",
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
path="/",
)
return TokenResponse(access_token=access_token)
# ---------------------------------------------------------------------------
# POST /auth/logout
# ---------------------------------------------------------------------------
@router.post("/logout")
def logout(
request: Request,
response: Response,
aegis_token: str | None = Cookie(None),
):
"""Clear the authentication cookie and revoke the current token.
The token's ``jti`` is added to an in-memory blacklist so it cannot
be reused even if the cookie has already been copied elsewhere.
"""
# Attempt to blacklist the token's jti
token = aegis_token or request.headers.get("Authorization", "").removeprefix("Bearer ").strip()
if token:
try:
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM],
)
jti = payload.get("jti")
exp = payload.get("exp", 0)
if jti:
blacklist_token(jti, float(exp))
except JWTError:
pass # token already invalid — nothing to revoke
response.delete_cookie(
key=_COOKIE_NAME,
httponly=True,
secure=_IS_HTTPS,
samesite="strict",
path="/",
)
return {"detail": "Logged out"}
# ---------------------------------------------------------------------------
# GET /auth/me
# ---------------------------------------------------------------------------

View File

@@ -174,7 +174,8 @@ def list_campaigns(
if threat_actor_id:
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
total = query.count()

View File

@@ -42,7 +42,8 @@ def list_defensive_techniques(
query = query.filter(DefensiveTechnique.tactic == tactic)
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
DefensiveTechnique.name.ilike(pattern)
| DefensiveTechnique.d3fend_id.ilike(pattern)

View File

@@ -7,8 +7,10 @@ including sync triggers, enable/disable toggles, and statistics.
import logging
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
@@ -17,6 +19,17 @@ from app.models.user import User
from app.models.data_source import DataSource
from app.services.audit_service import log_action
# ---------------------------------------------------------------------------
# Pydantic schemas for request validation
# ---------------------------------------------------------------------------
class DataSourceUpdate(BaseModel):
"""Payload for updating a data source — only allowed fields."""
is_enabled: Optional[bool] = None
sync_frequency: Optional[str] = None
config: Optional[dict] = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
@@ -90,29 +103,26 @@ def list_data_sources(
@router.patch("/{source_id}")
def update_data_source(
source_id: str,
body: dict,
body: DataSourceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update a data source (enable/disable, change config).
**Requires** the ``admin`` role.
Body fields (all optional):
- ``is_enabled`` (bool)
- ``sync_frequency`` (str)
- ``config`` (dict)
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
if "is_enabled" in body:
ds.is_enabled = bool(body["is_enabled"])
if "sync_frequency" in body:
ds.sync_frequency = body["sync_frequency"]
if "config" in body:
ds.config = body["config"]
update_data = body.model_dump(exclude_unset=True)
if "is_enabled" in update_data:
ds.is_enabled = update_data["is_enabled"]
if "sync_frequency" in update_data:
ds.sync_frequency = update_data["sync_frequency"]
if "config" in update_data:
ds.config = update_data["config"]
db.commit()
@@ -122,7 +132,7 @@ def update_data_source(
action="update_data_source",
entity_type="data_source",
entity_id=str(ds.id),
details={"updates": body},
details={"updates": update_data},
)
return {"message": "Data source updated", "id": str(ds.id)}
@@ -156,14 +166,14 @@ def sync_data_source(
try:
summary = handler(db)
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc)
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
raise HTTPException(
status_code=500,
detail=f"Sync failed: {str(exc)}",
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
)
# Update DS record (the handler may already have done this,
@@ -222,7 +232,7 @@ def sync_all_data_sources(
"stats": summary,
})
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc)
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
@@ -230,7 +240,7 @@ def sync_all_data_sources(
results.append({
"source": ds.name,
"status": "error",
"detail": str(exc),
"detail": "Sync failed. Check server logs for details.",
})
log_action(

View File

@@ -5,10 +5,12 @@ and managing the template ↔ detection rule associations.
"""
import logging
import uuid
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -20,6 +22,18 @@ from app.models.test_template import TestTemplate
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
# ---------------------------------------------------------------------------
# Pydantic schemas for request validation
# ---------------------------------------------------------------------------
class DetectionRuleEvaluate(BaseModel):
"""Payload for evaluating a detection rule against a test."""
test_id: uuid.UUID
detection_rule_id: uuid.UUID
triggered: Optional[bool] = None
notes: Optional[str] = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
@@ -53,7 +67,8 @@ def list_detection_rules(
query = query.filter(DetectionRule.severity == severity)
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
DetectionRule.title.ilike(pattern)
| DetectionRule.description.ilike(pattern)
@@ -294,27 +309,15 @@ def get_detection_rules_for_test(
@router.post("/evaluate")
def evaluate_detection_rule(
payload: dict,
payload: DetectionRuleEvaluate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Save or update the evaluation result for a detection rule on a test.
Body:
{
"test_id": "...",
"detection_rule_id": "...",
"triggered": true | false | null,
"notes": "optional notes"
}
"""
test_id = payload.get("test_id")
detection_rule_id = payload.get("detection_rule_id")
triggered = payload.get("triggered")
notes = payload.get("notes")
if not test_id or not detection_rule_id:
raise HTTPException(status_code=400, detail="test_id and detection_rule_id are required")
"""Save or update the evaluation result for a detection rule on a test."""
test_id = payload.test_id
detection_rule_id = payload.detection_rule_id
triggered = payload.triggered
notes = payload.notes
# Check test exists
from app.models.test import Test

View File

@@ -20,6 +20,7 @@ Access Control
"""
import hashlib
import os
import uuid as _uuid
from typing import Optional
@@ -43,6 +44,29 @@ _RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
# States where blue evidence can be uploaded / deleted
_BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
# ---------------------------------------------------------------------------
# Upload safety limits
# ---------------------------------------------------------------------------
# Maximum upload size in bytes (default 50 MB)
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024
# Allowed file extensions (lowercase, with leading dot)
_ALLOWED_EXTENSIONS: set[str] = {
# Images / screenshots
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
# Documents
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
".md", ".rtf", ".odt", ".ods",
# Logs & captures
".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
".yaml", ".yml", ".toml",
# Archives (for bundled evidence)
".zip", ".tar", ".gz", ".7z",
# Other common evidence types
".har", ".eml", ".msg",
}
# ---------------------------------------------------------------------------
# Helpers
@@ -177,21 +201,39 @@ async def upload_evidence(
# Validate permissions
_validate_upload_permission(test, team, current_user)
# 1. Read content + hash
content = await file.read()
# 1. Validate file extension
file_name = file.filename or "unnamed"
_, ext = os.path.splitext(file_name)
if ext.lower() not in _ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File type '{ext}' is not allowed. "
f"Permitted types: {', '.join(sorted(_ALLOWED_EXTENSIONS))}",
)
# 2. Read content with size limit
content = await file.read(_MAX_UPLOAD_SIZE + 1)
if len(content) > _MAX_UPLOAD_SIZE:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File exceeds maximum upload size of "
f"{_MAX_UPLOAD_SIZE // (1024 * 1024)} MB",
)
# 3. Hash
sha256 = hashlib.sha256(content).hexdigest()
# 2. Object key
file_name = file.filename or "unnamed"
key = f"{test_id}/{_uuid.uuid4()}_{file_name}"
# 4. Object key (sanitise filename to prevent path traversal in storage)
safe_name = os.path.basename(file_name)
key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
# 3. Upload to MinIO
# 5. Upload to MinIO
upload_file(content, key)
# 4. Persist metadata
# 6. Persist metadata
evidence = Evidence(
test_id=test_id,
file_name=file_name,
file_name=safe_name,
file_path=key,
sha256_hash=sha256,
uploaded_by=current_user.id,
@@ -202,7 +244,7 @@ async def upload_evidence(
db.commit()
db.refresh(evidence)
# 5. Audit
# 7. Audit
log_action(
db,
user_id=current_user.id,
@@ -210,7 +252,7 @@ async def upload_evidence(
entity_type="evidence",
entity_id=evidence.id,
details={
"file_name": file_name,
"file_name": safe_name,
"sha256": sha256,
"test_id": str(test_id),
"team": team.value,

View File

@@ -107,9 +107,10 @@ def _apply_filters(
query = query.filter(or_(*platform_filters))
if tactics:
from sqlalchemy import or_
from app.utils import escape_like
tactic_filters = []
for tactic in tactics:
tactic_filters.append(model.tactic.ilike(f"%{tactic}%"))
tactic_filters.append(model.tactic.ilike(f"%{escape_like(tactic)}%"))
query = query.filter(or_(*tactic_filters))
return query

View File

@@ -43,7 +43,8 @@ def coverage_summary(
"""Full coverage report as JSON — technique-by-technique with test counts."""
query = db.query(Technique)
if tactic:
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
from app.utils import escape_like
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
techniques = query.order_by(Technique.mitre_id).all()
@@ -109,7 +110,8 @@ def coverage_csv(
"""Export coverage as a downloadable CSV."""
query = db.query(Technique)
if tactic:
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
from app.utils import escape_like
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
techniques = query.order_by(Technique.mitre_id).all()

View File

@@ -80,10 +80,9 @@ def trigger_atomic_import(
try:
summary = import_atomic_red_team(db)
except Exception as exc:
logger.error("Atomic Red Team import failed: %s", exc)
logger.error("Atomic Red Team import failed: %s", exc, exc_info=True)
return {
"message": "Import failed",
"error": str(exc),
"message": "Import failed. Check server logs for details.",
}
return {

View File

@@ -69,13 +69,15 @@ def list_templates(
if source:
query = query.filter(TestTemplate.source == source)
if platform:
query = query.filter(TestTemplate.platform.ilike(f"%{platform}%"))
from app.utils import escape_like
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
if severity:
query = query.filter(TestTemplate.severity == severity)
if mitre_technique_id:
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
TestTemplate.name.ilike(pattern),

View File

@@ -109,7 +109,8 @@ def list_tests(
if technique_id:
query = query.filter(Test.technique_id == technique_id)
if platform:
query = query.filter(Test.platform.ilike(f"%{platform}%"))
from app.utils import escape_like
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
if created_by:
query = query.filter(Test.created_by == created_by)
if pending_validation_side == "red":
@@ -200,15 +201,25 @@ def create_test_from_template(
detail=f"TestTemplate with id '{payload.template_id}' not found",
)
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
# Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001")
technique = None
try:
technique_uuid = uuid.UUID(payload.technique_id)
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
except ValueError:
pass
if technique is None:
technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first()
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique with id '{payload.technique_id}' not found",
detail=f"Technique '{payload.technique_id}' not found",
)
test = Test(
technique_id=payload.technique_id,
technique_id=technique.id,
name=template.name,
description=template.description,
platform=template.platform,

View File

@@ -49,7 +49,8 @@ def list_threat_actors(
# Filters
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
ThreatActor.name.ilike(pattern),
@@ -68,9 +69,10 @@ def list_threat_actors(
query = query.filter(ThreatActor.sophistication == sophistication)
if target_sectors:
from app.utils import escape_like
# JSONB contains check
query = query.filter(
func.cast(ThreatActor.target_sectors, func.text()).ilike(f"%{target_sectors}%")
func.cast(ThreatActor.target_sectors, func.text()).ilike(f"%{escape_like(target_sectors)}%")
)
# Total count

View File

@@ -75,4 +75,4 @@ class TestTemplateInstantiate(BaseModel):
"""Payload to create a real test from an existing template."""
template_id: uuid.UUID
technique_id: uuid.UUID
technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001")

View File

@@ -1,9 +1,49 @@
"""Pydantic schemas for User management endpoints."""
import re
import uuid
from datetime import datetime
from pydantic import BaseModel, ConfigDict, EmailStr
from pydantic import BaseModel, ConfigDict, EmailStr, field_validator
# ── Password policy ─────────────────────────────────────────────────
_MIN_PASSWORD_LENGTH = 12
_PASSWORD_RULES: list[tuple[str, str]] = [
(r"[A-Z]", "at least one uppercase letter"),
(r"[a-z]", "at least one lowercase letter"),
(r"[0-9]", "at least one digit"),
(r"[!@#$%^&*()_+\-=\[\]{};':\"\\|,.<>/?`~]", "at least one special character"),
]
def _validate_password_strength(password: str) -> str:
"""Check that *password* satisfies the complexity policy.
Rules:
- Minimum 12 characters
- At least one uppercase letter
- At least one lowercase letter
- At least one digit
- At least one special character
"""
errors: list[str] = []
if len(password) < _MIN_PASSWORD_LENGTH:
errors.append(f"must be at least {_MIN_PASSWORD_LENGTH} characters long")
for pattern, description in _PASSWORD_RULES:
if not re.search(pattern, password):
errors.append(description)
if errors:
raise ValueError(
"Password does not meet complexity requirements: " + "; ".join(errors)
)
return password
# ── Create ──────────────────────────────────────────────────────────
@@ -16,6 +56,11 @@ class UserCreate(BaseModel):
password: str
role: str = "viewer"
@field_validator("password")
@classmethod
def password_strength(cls, v: str) -> str:
return _validate_password_strength(v)
# ── Update ──────────────────────────────────────────────────────────
@@ -28,6 +73,13 @@ class UserUpdate(BaseModel):
is_active: bool | None = None
password: str | None = None
@field_validator("password")
@classmethod
def password_strength(cls, v: str | None) -> str | None:
if v is not None:
return _validate_password_strength(v)
return v
# ── Read (full) ─────────────────────────────────────────────────────

View File

@@ -1,32 +1,80 @@
"""
Seed script — creates the initial admin user if it does not already exist.
On first run the admin credentials are generated securely:
- Username is read from ``ADMIN_USERNAME`` env var (default: ``admin``).
- Password is read from ``ADMIN_PASSWORD`` env var. When the variable is
**not set**, a cryptographically random 16-character password is generated
automatically and printed to the startup logs so the operator can copy it.
Usage:
python -m app.seed
"""
import os
import secrets
import string
from app.auth import hash_password
from app.database import SessionLocal
from app.models.user import User
# Characters for auto-generated passwords (alphanumeric + safe symbols)
_PW_ALPHABET = string.ascii_letters + string.digits + "!@#$%&*-_+"
def _generate_password(length: int = 16) -> str:
"""Return a cryptographically random password of *length* characters."""
return "".join(secrets.choice(_PW_ALPHABET) for _ in range(length))
def seed_admin() -> None:
"""Create the default admin user when it is missing."""
"""Create the initial admin user when it is missing.
Reads ``ADMIN_USERNAME`` and ``ADMIN_PASSWORD`` from the environment.
If ``ADMIN_PASSWORD`` is empty or unset a secure random password is
generated and displayed in the logs.
"""
db = SessionLocal()
try:
existing = db.query(User).filter(User.username == "admin").first()
admin_username = os.environ.get("ADMIN_USERNAME", "admin").strip() or "admin"
existing = db.query(User).filter(User.username == admin_username).first()
if existing:
print("Admin user already exists — skipping.")
print(f"Admin user '{admin_username}' already exists — skipping.")
return
admin_password = os.environ.get("ADMIN_PASSWORD", "").strip()
password_was_generated = False
if not admin_password:
admin_password = _generate_password()
password_was_generated = True
admin = User(
username="admin",
hashed_password=hash_password("admin123"),
username=admin_username,
hashed_password=hash_password(admin_password),
role="admin",
)
db.add(admin)
db.commit()
print("Admin user created successfully.")
# ── Display credentials in startup logs ──────────────────────
print()
print("=" * 60)
print(" AEGIS — Initial Admin User Created")
print("=" * 60)
print(f" Username : {admin_username}")
if password_was_generated:
print(f" Password : {admin_password}")
print()
print(" ** This password was auto-generated because")
print(" ADMIN_PASSWORD was not set in the environment. **")
print(" ** Save it now — it will NOT be shown again. **")
else:
print(" Password : (set via ADMIN_PASSWORD env var)")
print("=" * 60)
print()
finally:
db.close()

View File

@@ -70,10 +70,50 @@ def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes:
return content
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
"""Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.
Raises :class:`ValueError` if any member tries to escape the target
directory (path traversal / Zip Slip) or if the archive exceeds the
safety limits.
"""
# Maximum uncompressed size: 500 MB — prevents zip-bomb DoS
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024
# Maximum number of entries
_MAX_ENTRIES = 50_000
dest_path = Path(dest).resolve()
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
entries = zf.infolist()
if len(entries) > _MAX_ENTRIES:
raise ValueError(
f"ZIP archive contains {len(entries)} entries "
f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
)
total_size = sum(info.file_size for info in entries)
if total_size > _MAX_UNCOMPRESSED_SIZE:
raise ValueError(
f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
)
for member in entries:
target = (dest_path / member.filename).resolve()
if not target.is_relative_to(dest_path):
raise ValueError(
f"Zip Slip detected — member '{member.filename}' "
f"resolves outside target directory"
)
zf.extractall(dest)
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
"""Extract *zip_bytes* into *dest* and return the path to the atomics/ dir."""
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
zf.extractall(dest)
_safe_extract_zip(zip_bytes, dest)
atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
if not atomics_dir.is_dir():
raise FileNotFoundError(

View File

@@ -75,10 +75,50 @@ def _download_zip(url: str = ELASTIC_ZIP_URL) -> bytes:
return content
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
"""Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.
Raises :class:`ValueError` if any member tries to escape the target
directory (path traversal / Zip Slip) or if the archive exceeds the
safety limits.
"""
# Maximum uncompressed size: 500 MB — prevents zip-bomb DoS
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024
# Maximum number of entries
_MAX_ENTRIES = 50_000
dest_path = Path(dest).resolve()
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
entries = zf.infolist()
if len(entries) > _MAX_ENTRIES:
raise ValueError(
f"ZIP archive contains {len(entries)} entries "
f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
)
total_size = sum(info.file_size for info in entries)
if total_size > _MAX_UNCOMPRESSED_SIZE:
raise ValueError(
f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
)
for member in entries:
target = (dest_path / member.filename).resolve()
if not target.is_relative_to(dest_path):
raise ValueError(
f"Zip Slip detected — member '{member.filename}' "
f"resolves outside target directory"
)
zf.extractall(dest)
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
"""Extract *zip_bytes* into *dest* and return rules/ dir."""
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
zf.extractall(dest)
_safe_extract_zip(zip_bytes, dest)
rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
if not rules_dir.is_dir():
raise FileNotFoundError(

View File

@@ -11,7 +11,7 @@ parser. No LLMs or paid APIs are used.
import logging
import re
import xml.etree.ElementTree as ET
import defusedxml.ElementTree as ET
from datetime import datetime
import requests as _requests

View File

@@ -81,10 +81,50 @@ def _download_zip(url: str = SIGMA_ZIP_URL) -> bytes:
return content
def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
"""Extract *zip_bytes* into *dest* with Zip Slip and Zip Bomb protection.
Raises :class:`ValueError` if any member tries to escape the target
directory (path traversal / Zip Slip) or if the archive exceeds the
safety limits.
"""
# Maximum uncompressed size: 500 MB — prevents zip-bomb DoS
_MAX_UNCOMPRESSED_SIZE = 500 * 1024 * 1024
# Maximum number of entries
_MAX_ENTRIES = 50_000
dest_path = Path(dest).resolve()
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
entries = zf.infolist()
if len(entries) > _MAX_ENTRIES:
raise ValueError(
f"ZIP archive contains {len(entries)} entries "
f"(limit: {_MAX_ENTRIES}) — possible zip bomb"
)
total_size = sum(info.file_size for info in entries)
if total_size > _MAX_UNCOMPRESSED_SIZE:
raise ValueError(
f"ZIP uncompressed size {total_size / (1024 * 1024):.0f} MB "
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
)
for member in entries:
target = (dest_path / member.filename).resolve()
if not target.is_relative_to(dest_path):
raise ValueError(
f"Zip Slip detected — member '{member.filename}' "
f"resolves outside target directory"
)
zf.extractall(dest)
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
"""Extract *zip_bytes* into *dest* and return the path to rules/ dir."""
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
zf.extractall(dest)
_safe_extract_zip(zip_bytes, dest)
rules_dir = Path(dest) / _ZIP_ROOT_PREFIX / "rules"
if not rules_dir.is_dir():
raise FileNotFoundError(

View File

@@ -13,9 +13,11 @@ from app.config import settings
# Shared client (module-level singleton)
# ---------------------------------------------------------------------------
_scheme = "https" if settings.MINIO_SECURE else "http"
_client = boto3.client(
"s3",
endpoint_url=f"http://{settings.MINIO_ENDPOINT}",
endpoint_url=f"{_scheme}://{settings.MINIO_ENDPOINT}",
aws_access_key_id=settings.MINIO_ACCESS_KEY,
aws_secret_access_key=settings.MINIO_SECRET_KEY,
region_name="us-east-1", # MinIO ignores this but boto3 requires it

21
backend/app/utils.py Normal file
View File

@@ -0,0 +1,21 @@
"""Shared utility helpers."""
def escape_like(value: str) -> str:
"""Escape SQL LIKE wildcard characters (``%`` and ``_``).
Prevents user-supplied search terms from being interpreted as LIKE
pattern metacharacters when used with SQLAlchemy's ``ilike``/``like``
methods.
Usage::
from app.utils import escape_like
query.filter(Model.name.ilike(f"%{escape_like(term)}%"))
"""
return (
value
.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
)