fix: resolve 20 security vulnerabilities from comprehensive audit

Critical (1-3):
- Replace hardcoded admin credentials with secure auto-generation (seed.py)
- Enforce SECRET_KEY configuration, fail in production if missing (config.py)
- Add Zip Slip and Zip Bomb protection to all ZIP import services

High/Medium (4-9):
- Add 50MB file size limit and extension whitelist to evidence uploads
- Configure CORS origins via environment variable instead of hardcoded
- Migrate JWT storage from localStorage to HttpOnly cookies (frontend+backend)
- Add rate limiting (5/min) on login endpoint via slowapi
- Replace generic dict payloads with Pydantic schemas (mass assignment)

Medium (10-17):
- Check is_active on login to prevent disabled users from authenticating
- Sanitize exception messages in API responses (system, data_sources)
- Escape LIKE wildcards in all ilike search filters across 8 routers
- Run Docker container as non-root user (appuser)
- Make MINIO_SECURE configurable via environment variable
- Add password complexity policy (12+ chars, upper/lower/digit/special)
- Implement JWT token revocation via in-memory blacklist + reduce TTL to 15min
- Replace xml.etree with defusedxml to prevent Billion Laughs attacks

Low (18-20):
- Add security headers to Nginx (CSP, X-Frame-Options, HSTS-ready, etc.)
- Disable Swagger UI/ReDoc/OpenAPI in production
- Restrict /health endpoint to internal networks via Nginx ACL

Also: rewrite install.sh as interactive wizard for guided deployment,
fix test-from-template validation error (technique_id UUID vs MITRE ID)
This commit is contained in:
2026-02-11 08:56:26 +01:00
parent e7e63161e8
commit 64d64080e0
36 changed files with 1154 additions and 311 deletions

View File

@@ -1,17 +1,40 @@
"""Authentication router: login and current-user endpoints."""
"""Authentication router: login, logout and current-user endpoints.
from fastapi import APIRouter, Depends, HTTPException, status
The JWT access token is delivered as an **HttpOnly** cookie
(``aegis_token``) so it is inaccessible to client-side JavaScript,
mitigating XSS token-theft attacks. The JSON response also includes
the token in the body for backwards compatibility and for clients that
cannot use cookies (e.g. Swagger UI).
"""
import os
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.orm import Session
from app.auth import verify_password, create_access_token
from jose import jwt, JWTError
from app.auth import verify_password, create_access_token, blacklist_token
from app.config import settings
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.models.user import User
from app.schemas.auth import TokenResponse, UserOut
# Rate limiter instance (shares backend state via app.state.limiter)
limiter = Limiter(key_func=get_remote_address)
router = APIRouter(prefix="/auth", tags=["auth"])
# Detect whether we're behind HTTPS (production) so the cookie can be Secure
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
# Cookie name used to transport the JWT
_COOKIE_NAME = "aegis_token"
# ---------------------------------------------------------------------------
# POST /auth/login
@@ -19,11 +42,19 @@ router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/login", response_model=TokenResponse)
@limiter.limit("5/minute")
def login(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db),
):
"""Authenticate a user and return a JWT access token."""
"""Authenticate a user and return a JWT access token.
Rate-limited to **5 attempts per minute per IP** to prevent brute-force
attacks. The token is set as an HttpOnly cookie **and** returned in the
JSON body for API/Swagger compatibility.
"""
user = db.query(User).filter(User.username == form_data.username).first()
if user is None or not verify_password(form_data.password, user.hashed_password):
@@ -32,10 +63,70 @@ def login(
detail="Incorrect username or password",
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is disabled. Contact an administrator.",
)
access_token = create_access_token(data={"sub": user.username})
# Set HttpOnly cookie — inaccessible from JS
response.set_cookie(
key=_COOKIE_NAME,
value=access_token,
httponly=True,
secure=_IS_HTTPS,
samesite="strict",
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
path="/",
)
return TokenResponse(access_token=access_token)
# ---------------------------------------------------------------------------
# POST /auth/logout
# ---------------------------------------------------------------------------
@router.post("/logout")
def logout(
request: Request,
response: Response,
aegis_token: str | None = Cookie(None),
):
"""Clear the authentication cookie and revoke the current token.
The token's ``jti`` is added to an in-memory blacklist so it cannot
be reused even if the cookie has already been copied elsewhere.
"""
# Attempt to blacklist the token's jti
token = aegis_token or request.headers.get("Authorization", "").removeprefix("Bearer ").strip()
if token:
try:
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM],
)
jti = payload.get("jti")
exp = payload.get("exp", 0)
if jti:
blacklist_token(jti, float(exp))
except JWTError:
pass # token already invalid — nothing to revoke
response.delete_cookie(
key=_COOKIE_NAME,
httponly=True,
secure=_IS_HTTPS,
samesite="strict",
path="/",
)
return {"detail": "Logged out"}
# ---------------------------------------------------------------------------
# GET /auth/me
# ---------------------------------------------------------------------------

View File

@@ -174,7 +174,8 @@ def list_campaigns(
if threat_actor_id:
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
total = query.count()

View File

@@ -42,7 +42,8 @@ def list_defensive_techniques(
query = query.filter(DefensiveTechnique.tactic == tactic)
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
DefensiveTechnique.name.ilike(pattern)
| DefensiveTechnique.d3fend_id.ilike(pattern)

View File

@@ -7,8 +7,10 @@ including sync triggers, enable/disable toggles, and statistics.
import logging
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
@@ -17,6 +19,17 @@ from app.models.user import User
from app.models.data_source import DataSource
from app.services.audit_service import log_action
# ---------------------------------------------------------------------------
# Pydantic schemas for request validation
# ---------------------------------------------------------------------------
class DataSourceUpdate(BaseModel):
"""Payload for updating a data source — only allowed fields."""
is_enabled: Optional[bool] = None
sync_frequency: Optional[str] = None
config: Optional[dict] = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
@@ -90,29 +103,26 @@ def list_data_sources(
@router.patch("/{source_id}")
def update_data_source(
source_id: str,
body: dict,
body: DataSourceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update a data source (enable/disable, change config).
**Requires** the ``admin`` role.
Body fields (all optional):
- ``is_enabled`` (bool)
- ``sync_frequency`` (str)
- ``config`` (dict)
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
if "is_enabled" in body:
ds.is_enabled = bool(body["is_enabled"])
if "sync_frequency" in body:
ds.sync_frequency = body["sync_frequency"]
if "config" in body:
ds.config = body["config"]
update_data = body.model_dump(exclude_unset=True)
if "is_enabled" in update_data:
ds.is_enabled = update_data["is_enabled"]
if "sync_frequency" in update_data:
ds.sync_frequency = update_data["sync_frequency"]
if "config" in update_data:
ds.config = update_data["config"]
db.commit()
@@ -122,7 +132,7 @@ def update_data_source(
action="update_data_source",
entity_type="data_source",
entity_id=str(ds.id),
details={"updates": body},
details={"updates": update_data},
)
return {"message": "Data source updated", "id": str(ds.id)}
@@ -156,14 +166,14 @@ def sync_data_source(
try:
summary = handler(db)
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc)
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
raise HTTPException(
status_code=500,
detail=f"Sync failed: {str(exc)}",
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
)
# Update DS record (the handler may already have done this,
@@ -222,7 +232,7 @@ def sync_all_data_sources(
"stats": summary,
})
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc)
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
@@ -230,7 +240,7 @@ def sync_all_data_sources(
results.append({
"source": ds.name,
"status": "error",
"detail": str(exc),
"detail": "Sync failed. Check server logs for details.",
})
log_action(

View File

@@ -5,10 +5,12 @@ and managing the template ↔ detection rule associations.
"""
import logging
import uuid
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -20,6 +22,18 @@ from app.models.test_template import TestTemplate
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
# ---------------------------------------------------------------------------
# Pydantic schemas for request validation
# ---------------------------------------------------------------------------
class DetectionRuleEvaluate(BaseModel):
"""Payload for evaluating a detection rule against a test."""
test_id: uuid.UUID
detection_rule_id: uuid.UUID
triggered: Optional[bool] = None
notes: Optional[str] = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
@@ -53,7 +67,8 @@ def list_detection_rules(
query = query.filter(DetectionRule.severity == severity)
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
DetectionRule.title.ilike(pattern)
| DetectionRule.description.ilike(pattern)
@@ -294,27 +309,15 @@ def get_detection_rules_for_test(
@router.post("/evaluate")
def evaluate_detection_rule(
payload: dict,
payload: DetectionRuleEvaluate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Save or update the evaluation result for a detection rule on a test.
Body:
{
"test_id": "...",
"detection_rule_id": "...",
"triggered": true | false | null,
"notes": "optional notes"
}
"""
test_id = payload.get("test_id")
detection_rule_id = payload.get("detection_rule_id")
triggered = payload.get("triggered")
notes = payload.get("notes")
if not test_id or not detection_rule_id:
raise HTTPException(status_code=400, detail="test_id and detection_rule_id are required")
"""Save or update the evaluation result for a detection rule on a test."""
test_id = payload.test_id
detection_rule_id = payload.detection_rule_id
triggered = payload.triggered
notes = payload.notes
# Check test exists
from app.models.test import Test

View File

@@ -20,6 +20,7 @@ Access Control
"""
import hashlib
import os
import uuid as _uuid
from typing import Optional
@@ -43,6 +44,29 @@ _RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
# States where blue evidence can be uploaded / deleted
_BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
# ---------------------------------------------------------------------------
# Upload safety limits
# ---------------------------------------------------------------------------
# Maximum upload size in bytes (default 50 MB)
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024
# Allowed file extensions (lowercase, with leading dot)
_ALLOWED_EXTENSIONS: set[str] = {
# Images / screenshots
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
# Documents
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
".md", ".rtf", ".odt", ".ods",
# Logs & captures
".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
".yaml", ".yml", ".toml",
# Archives (for bundled evidence)
".zip", ".tar", ".gz", ".7z",
# Other common evidence types
".har", ".eml", ".msg",
}
# ---------------------------------------------------------------------------
# Helpers
@@ -177,21 +201,39 @@ async def upload_evidence(
# Validate permissions
_validate_upload_permission(test, team, current_user)
# 1. Read content + hash
content = await file.read()
# 1. Validate file extension
file_name = file.filename or "unnamed"
_, ext = os.path.splitext(file_name)
if ext.lower() not in _ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File type '{ext}' is not allowed. "
f"Permitted types: {', '.join(sorted(_ALLOWED_EXTENSIONS))}",
)
# 2. Read content with size limit
content = await file.read(_MAX_UPLOAD_SIZE + 1)
if len(content) > _MAX_UPLOAD_SIZE:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File exceeds maximum upload size of "
f"{_MAX_UPLOAD_SIZE // (1024 * 1024)} MB",
)
# 3. Hash
sha256 = hashlib.sha256(content).hexdigest()
# 2. Object key
file_name = file.filename or "unnamed"
key = f"{test_id}/{_uuid.uuid4()}_{file_name}"
# 4. Object key (sanitise filename to prevent path traversal in storage)
safe_name = os.path.basename(file_name)
key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
# 3. Upload to MinIO
# 5. Upload to MinIO
upload_file(content, key)
# 4. Persist metadata
# 6. Persist metadata
evidence = Evidence(
test_id=test_id,
file_name=file_name,
file_name=safe_name,
file_path=key,
sha256_hash=sha256,
uploaded_by=current_user.id,
@@ -202,7 +244,7 @@ async def upload_evidence(
db.commit()
db.refresh(evidence)
# 5. Audit
# 7. Audit
log_action(
db,
user_id=current_user.id,
@@ -210,7 +252,7 @@ async def upload_evidence(
entity_type="evidence",
entity_id=evidence.id,
details={
"file_name": file_name,
"file_name": safe_name,
"sha256": sha256,
"test_id": str(test_id),
"team": team.value,

View File

@@ -107,9 +107,10 @@ def _apply_filters(
query = query.filter(or_(*platform_filters))
if tactics:
from sqlalchemy import or_
from app.utils import escape_like
tactic_filters = []
for tactic in tactics:
tactic_filters.append(model.tactic.ilike(f"%{tactic}%"))
tactic_filters.append(model.tactic.ilike(f"%{escape_like(tactic)}%"))
query = query.filter(or_(*tactic_filters))
return query

View File

@@ -43,7 +43,8 @@ def coverage_summary(
"""Full coverage report as JSON — technique-by-technique with test counts."""
query = db.query(Technique)
if tactic:
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
from app.utils import escape_like
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
techniques = query.order_by(Technique.mitre_id).all()
@@ -109,7 +110,8 @@ def coverage_csv(
"""Export coverage as a downloadable CSV."""
query = db.query(Technique)
if tactic:
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
from app.utils import escape_like
query = query.filter(Technique.tactic.ilike(f"%{escape_like(tactic)}%"))
techniques = query.order_by(Technique.mitre_id).all()

View File

@@ -80,10 +80,9 @@ def trigger_atomic_import(
try:
summary = import_atomic_red_team(db)
except Exception as exc:
logger.error("Atomic Red Team import failed: %s", exc)
logger.error("Atomic Red Team import failed: %s", exc, exc_info=True)
return {
"message": "Import failed",
"error": str(exc),
"message": "Import failed. Check server logs for details.",
}
return {

View File

@@ -69,13 +69,15 @@ def list_templates(
if source:
query = query.filter(TestTemplate.source == source)
if platform:
query = query.filter(TestTemplate.platform.ilike(f"%{platform}%"))
from app.utils import escape_like
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
if severity:
query = query.filter(TestTemplate.severity == severity)
if mitre_technique_id:
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
TestTemplate.name.ilike(pattern),

View File

@@ -109,7 +109,8 @@ def list_tests(
if technique_id:
query = query.filter(Test.technique_id == technique_id)
if platform:
query = query.filter(Test.platform.ilike(f"%{platform}%"))
from app.utils import escape_like
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
if created_by:
query = query.filter(Test.created_by == created_by)
if pending_validation_side == "red":
@@ -200,15 +201,25 @@ def create_test_from_template(
detail=f"TestTemplate with id '{payload.template_id}' not found",
)
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
# Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001")
technique = None
try:
technique_uuid = uuid.UUID(payload.technique_id)
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
except ValueError:
pass
if technique is None:
technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first()
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique with id '{payload.technique_id}' not found",
detail=f"Technique '{payload.technique_id}' not found",
)
test = Test(
technique_id=payload.technique_id,
technique_id=technique.id,
name=template.name,
description=template.description,
platform=template.platform,

View File

@@ -49,7 +49,8 @@ def list_threat_actors(
# Filters
if search:
pattern = f"%{search}%"
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
ThreatActor.name.ilike(pattern),
@@ -68,9 +69,10 @@ def list_threat_actors(
query = query.filter(ThreatActor.sophistication == sophistication)
if target_sectors:
from app.utils import escape_like
# JSONB contains check
query = query.filter(
func.cast(ThreatActor.target_sectors, func.text()).ilike(f"%{target_sectors}%")
func.cast(ThreatActor.target_sectors, func.text()).ilike(f"%{escape_like(target_sectors)}%")
)
# Total count