Compare commits

..

5 Commits

26 changed files with 666 additions and 88 deletions
@@ -0,0 +1,58 @@
"""Phase 3: audit trail columns and data classification fields.
Revision ID: b029phase3
Revises: b028phase0
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "b029phase3"
down_revision: Union[str, None] = "b028phase0"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _column_names(table: str) -> set[str]:
bind = op.get_bind()
insp = sa.inspect(bind)
return {c["name"] for c in insp.get_columns(table)}
def upgrade() -> None:
audit_cols = _column_names("audit_logs")
if "ip_address" not in audit_cols:
op.add_column("audit_logs", sa.Column("ip_address", sa.String(45), nullable=True))
if "user_agent" not in audit_cols:
op.add_column("audit_logs", sa.Column("user_agent", sa.String(500), nullable=True))
if "integrity_hash" not in audit_cols:
op.add_column("audit_logs", sa.Column("integrity_hash", sa.String(64), nullable=True))
if "session_id" not in audit_cols:
op.add_column("audit_logs", sa.Column("session_id", sa.String(100), nullable=True))
for table in ("tests", "evidences", "campaigns"):
cols = _column_names(table)
if "data_classification" not in cols:
op.add_column(
table,
sa.Column(
"data_classification",
sa.String(20),
nullable=False,
server_default="internal",
),
)
def downgrade() -> None:
for table in ("campaigns", "evidences", "tests"):
cols = _column_names(table)
if "data_classification" in cols:
op.drop_column(table, "data_classification")
audit_cols = _column_names("audit_logs")
for col in ("session_id", "integrity_hash", "user_agent", "ip_address"):
if col in audit_cols:
op.drop_column("audit_logs", col)
+7
View File
@@ -35,3 +35,10 @@ class TestResult(str, enum.Enum):
detected = "detected" detected = "detected"
not_detected = "not_detected" not_detected = "not_detected"
partially_detected = "partially_detected" partially_detected = "partially_detected"
class DataClassification(str, enum.Enum):
public = "public"
internal = "internal"
sensitive = "sensitive"
restricted = "restricted"
+11 -1
View File
@@ -23,6 +23,7 @@ from app.services.campaign_scheduler_service import check_and_run_recurring_camp
from app.jobs.jira_sync_job import sync_all_jira_links from app.jobs.jira_sync_job import sync_all_jira_links
from app.services.osint_enrichment_service import enrich_all_techniques from app.services.osint_enrichment_service import enrich_all_techniques
from app.services.stale_detection_service import detect_stale_coverage from app.services.stale_detection_service import detect_stale_coverage
from app.jobs.retention_job import run_retention_job
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -217,10 +218,19 @@ def start_scheduler() -> None:
name="Stale coverage detection (daily)", name="Stale coverage detection (daily)",
replace_existing=True, replace_existing=True,
) )
scheduler.add_job(
run_retention_job,
trigger="interval",
hours=24,
id="retention_policies",
name="Data retention policies (daily)",
replace_existing=True,
)
scheduler.start() scheduler.start()
logger.info( logger.info(
"Background scheduler started — mitre_sync (24h), intel_scan (7d), " "Background scheduler started — mitre_sync (24h), intel_scan (7d), "
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), " "notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
"recurring_campaigns (daily), jira_sync (1h), " "recurring_campaigns (daily), jira_sync (1h), "
"osint_enrichment (weekly), stale_detection (daily)" "osint_enrichment (weekly), stale_detection (daily), "
"retention_policies (daily)"
) )
+53
View File
@@ -0,0 +1,53 @@
"""Data retention policies — scheduled cleanup of aged records."""
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from app.database import SessionLocal
from app.models.audit import AuditLog
from app.services.notification_service import cleanup_old_notifications
logger = logging.getLogger(__name__)
AUDIT_LOG_RETENTION_DAYS = 730
def apply_retention_policies(db: Session) -> dict[str, int]:
"""Apply retention rules. Commits the session before returning."""
cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
deleted_audit = (
db.query(AuditLog)
.filter(AuditLog.timestamp < cutoff)
.delete(synchronize_session=False)
)
if deleted_audit:
logger.info(
"Retention: deleted %d audit logs older than %d days",
deleted_audit,
AUDIT_LOG_RETENTION_DAYS,
)
deleted_notifications = cleanup_old_notifications(db, days=90)
db.commit()
return {
"audit_logs_deleted": deleted_audit,
"notifications_deleted": deleted_notifications,
}
def run_retention_job() -> None:
"""Entry point for the daily retention scheduler job."""
logger.info("Scheduled retention job starting...")
db = SessionLocal()
try:
summary = apply_retention_policies(db)
logger.info("Retention job finished — %s", summary)
except Exception:
logger.exception("Retention job failed")
db.rollback()
finally:
db.close()
+6
View File
@@ -0,0 +1,6 @@
"""Shared SlowAPI rate limiter for all routers."""
from slowapi import Limiter
from slowapi.util import get_remote_address
limiter = Limiter(key_func=get_remote_address)
+19 -5
View File
@@ -6,9 +6,8 @@ from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from app.routers import auth as auth_router from app.routers import auth as auth_router
@@ -40,6 +39,8 @@ from app.routers import advanced_metrics as advanced_metrics_router
from app.routers import osint as osint_router from app.routers import osint as osint_router
from app.domain.errors import DomainError from app.domain.errors import DomainError
from app.middleware.error_handler import domain_exception_handler from app.middleware.error_handler import domain_exception_handler
from app.middleware.request_context import RequestContextMiddleware
from app.limiter import limiter
from app.storage import ensure_bucket_exists from app.storage import ensure_bucket_exists
from app.jobs.mitre_sync_job import start_scheduler, scheduler from app.jobs.mitre_sync_job import start_scheduler, scheduler
@@ -71,10 +72,11 @@ app = FastAPI(
) )
# ── Rate Limiter ────────────────────────────────────────────────────────── # ── Rate Limiter ──────────────────────────────────────────────────────────
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(RequestContextMiddleware)
# ── Domain exception → HTTP mapping ────────────────────────────────────── # ── Domain exception → HTTP mapping ──────────────────────────────────────
app.add_exception_handler(DomainError, domain_exception_handler) app.add_exception_handler(DomainError, domain_exception_handler)
@@ -136,15 +138,27 @@ def health():
# ── Exception Handlers ──────────────────────────────────────────────────── # ── Exception Handlers ────────────────────────────────────────────────────
def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
"""Return validation errors safe for JSON (no raw exception objects)."""
serialized: list[dict] = []
for err in exc.errors():
item = dict(err)
ctx = item.get("ctx")
if isinstance(ctx, dict):
item["ctx"] = {key: str(value) for key, value in ctx.items()}
serialized.append(item)
return serialized
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError): async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Handle validation errors with consistent format.""" """Handle validation errors with consistent format."""
return JSONResponse( return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={ content={
"detail": "Validation error", "detail": "Validation error",
"code": "VALIDATION_ERROR", "code": "VALIDATION_ERROR",
"errors": exc.errors(), "errors": _serialize_validation_errors(exc),
}, },
) )
+26
View File
@@ -0,0 +1,26 @@
"""Request context middleware — captures client IP and User-Agent per request."""
from contextvars import ContextVar
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
request_ip: ContextVar[str] = ContextVar("request_ip", default="")
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
def resolve_client_ip(request: Request) -> str:
"""Extract the client IP, honouring ``X-Forwarded-For`` when present."""
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
if request.client:
return request.client.host
return "unknown"
class RequestContextMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
request_ip.set(resolve_client_ip(request))
request_user_agent.set(request.headers.get("User-Agent", ""))
return await call_next(request)
+4
View File
@@ -22,6 +22,10 @@ class AuditLog(Base):
entity_id = Column(String, nullable=True) entity_id = Column(String, nullable=True)
timestamp = Column(DateTime(timezone=True), server_default=func.now()) timestamp = Column(DateTime(timezone=True), server_default=func.now())
details = Column(JSONB, nullable=True) details = Column(JSONB, nullable=True)
ip_address = Column(String(45), nullable=True)
user_agent = Column(String(500), nullable=True)
integrity_hash = Column(String(64), nullable=True)
session_id = Column(String(100), nullable=True)
# Relationships # Relationships
user = relationship("User") user = relationship("User")
+1
View File
@@ -53,6 +53,7 @@ class Campaign(Base):
target_platform = Column(String, nullable=True) target_platform = Column(String, nullable=True)
tags = Column(JSONB, nullable=True, default=[]) tags = Column(JSONB, nullable=True, default=[])
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
data_classification = Column(String(20), nullable=False, server_default="internal")
# Recurring scheduling fields # Recurring scheduling fields
is_recurring = Column(Boolean, default=False) is_recurring = Column(Boolean, default=False)
+1
View File
@@ -6,6 +6,7 @@ working with ``from app.models.enums import ...``.
""" """
from app.domain.enums import ( # noqa: F401 from app.domain.enums import ( # noqa: F401
DataClassification,
TeamSide, TeamSide,
TechniqueStatus, TechniqueStatus,
TestResult, TestResult,
+1
View File
@@ -28,6 +28,7 @@ class Evidence(Base):
uploaded_at = Column(DateTime(timezone=True), server_default=func.now()) uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red) team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
notes = Column(Text, nullable=True) notes = Column(Text, nullable=True)
data_classification = Column(String(20), nullable=False, server_default="internal")
# Relationships # Relationships
test = relationship("Test", back_populates="evidences") test = relationship("Test", back_populates="evidences")
+1
View File
@@ -62,6 +62,7 @@ class Test(Base):
# ── Re-test fields ──────────────────────────────────────────── # ── Re-test fields ────────────────────────────────────────────
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True) retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
retest_count = Column(Integer, default=0) retest_count = Column(Integer, default=0)
data_classification = Column(String(20), nullable=False, server_default="internal")
# ── Relationships ─────────────────────────────────────────────── # ── Relationships ───────────────────────────────────────────────
technique = relationship("Technique", back_populates="tests") technique = relationship("Technique", back_populates="tests")
+51 -56
View File
@@ -11,39 +11,33 @@ import os
from fastapi import APIRouter, Cookie, Depends, Request, Response from fastapi import APIRouter, Cookie, Depends, Request, Response
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from jose import jwt, JWTError from jose import jwt, JWTError
from app.auth import create_access_token, blacklist_token from app.auth import create_access_token, blacklist_token, verify_password
from app.config import settings from app.config import settings
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.domain.errors import BusinessRuleViolation, PermissionViolation
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
from app.limiter import limiter
from app.middleware.request_context import resolve_client_ip
from app.models.user import User from app.models.user import User
from app.services.auth_service import authenticate_user, change_password as auth_change_password from app.services.auth_service import (
_DUMMY_HASH,
change_password as auth_change_password,
)
from app.services.audit_service import log_action
from app.schemas.auth import TokenResponse, UserOut from app.schemas.auth import TokenResponse, UserOut
from app.schemas.user import PasswordChange from app.schemas.user import PasswordChange
# Rate limiter instance (shares backend state via app.state.limiter)
limiter = Limiter(key_func=get_remote_address)
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
# Detect whether we're behind HTTPS (production) so the cookie can be Secure
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production" _IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
# Cookie name used to transport the JWT
_COOKIE_NAME = "aegis_token" _COOKIE_NAME = "aegis_token"
# ---------------------------------------------------------------------------
# POST /auth/login
# ---------------------------------------------------------------------------
@router.post("/login", response_model=TokenResponse) @router.post("/login", response_model=TokenResponse)
@limiter.limit("5/minute") @limiter.limit("5/minute")
def login( def login(
@@ -54,19 +48,49 @@ def login(
): ):
"""Authenticate a user and return a JWT access token. """Authenticate a user and return a JWT access token.
Rate-limited to **5 attempts per minute per IP** to prevent brute-force Rate-limited to **5 attempts per minute per IP**. Failed and successful
attacks. The token is set as an HttpOnly cookie **and** returned in the logins are recorded in the audit log (SEC-009).
JSON body for API/Swagger compatibility.
""" """
user = authenticate_user( user = db.query(User).filter(User.username == form_data.username).first()
db, target_hash = user.hashed_password if user else _DUMMY_HASH
username=form_data.username, password_valid = verify_password(form_data.password, target_hash)
password=form_data.password, ip = resolve_client_ip(request)
)
if user is None or not password_valid:
with UnitOfWork(db) as uow:
log_action(
db,
user.id if user else None,
"LOGIN_FAILED",
"auth",
None,
details={
"username": form_data.username,
"ip": ip,
"reason": "invalid_credentials",
},
ip_address=ip,
)
uow.commit()
raise BusinessRuleViolation("Incorrect username or password")
if not user.is_active:
raise PermissionViolation("Account is disabled. Contact an administrator.")
access_token = create_access_token(data={"sub": user.username}) access_token = create_access_token(data={"sub": user.username})
# Set HttpOnly cookie — inaccessible from JS with UnitOfWork(db) as uow:
log_action(
db,
user.id,
"LOGIN_SUCCESS",
"auth",
str(user.id),
details={"username": user.username, "ip": ip},
ip_address=ip,
)
uow.commit()
response.set_cookie( response.set_cookie(
key=_COOKIE_NAME, key=_COOKIE_NAME,
value=access_token, value=access_token,
@@ -80,27 +104,13 @@ def login(
return TokenResponse(access_token=access_token) return TokenResponse(access_token=access_token)
# ---------------------------------------------------------------------------
# POST /auth/logout
# ---------------------------------------------------------------------------
@router.post("/logout") @router.post("/logout")
def logout( def logout(
request: Request, request: Request,
response: Response, response: Response,
aegis_token: str | None = Cookie(None), aegis_token: str | None = Cookie(None),
): ):
"""Clear the authentication cookie and revoke the current token. """Clear the authentication cookie and revoke the current token."""
The token's ``jti`` is added to the Redis blacklist so it cannot
be reused even if the cookie has already been copied elsewhere.
The blacklist entry auto-expires when the token's ``exp`` is reached.
When both HttpOnly cookie and ``Authorization: Bearer`` are present
(typical for API clients), **both** are revoked so the session cannot
survive on whichever credential the next request prefers.
"""
bearer = ( bearer = (
request.headers.get("Authorization") request.headers.get("Authorization")
or request.headers.get("authorization") or request.headers.get("authorization")
@@ -124,7 +134,7 @@ def logout(
if jti: if jti:
blacklist_token(jti, float(exp)) blacklist_token(jti, float(exp))
except JWTError: except JWTError:
pass # token already invalid — nothing to revoke for this raw value pass
response.delete_cookie( response.delete_cookie(
key=_COOKIE_NAME, key=_COOKIE_NAME,
@@ -136,34 +146,19 @@ def logout(
return {"detail": "Logged out"} return {"detail": "Logged out"}
# ---------------------------------------------------------------------------
# GET /auth/me
# ---------------------------------------------------------------------------
@router.get("/me", response_model=UserOut) @router.get("/me", response_model=UserOut)
def read_current_user(current_user: User = Depends(get_current_user)): def read_current_user(current_user: User = Depends(get_current_user)):
"""Return the profile of the currently authenticated user.""" """Return the profile of the currently authenticated user."""
return current_user return current_user
# ---------------------------------------------------------------------------
# POST /auth/change-password
# ---------------------------------------------------------------------------
@router.post("/change-password") @router.post("/change-password")
def change_password( def change_password(
body: PasswordChange, body: PasswordChange,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Change the current user's password. """Change the current user's password."""
Requires the current password for verification. On success the
``must_change_password`` flag is cleared so the user can proceed
normally.
"""
auth_change_password( auth_change_password(
db, db,
current_user, current_user,
+4 -1
View File
@@ -24,7 +24,7 @@ import os
import uuid as _uuid import uuid as _uuid
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, status from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
@@ -44,6 +44,7 @@ from app.services.evidence_service import (
validate_file, validate_file,
validate_upload_permission, validate_upload_permission,
) )
from app.limiter import limiter
from app.storage import get_presigned_url, upload_file from app.storage import get_presigned_url, upload_file
router = APIRouter(tags=["evidence"]) router = APIRouter(tags=["evidence"])
@@ -78,7 +79,9 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
response_model=EvidenceOut, response_model=EvidenceOut,
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
) )
@limiter.limit("10/minute")
async def upload_evidence( async def upload_evidence(
request: Request,
test_id: _uuid.UUID, test_id: _uuid.UUID,
file: UploadFile = File(...), file: UploadFile = File(...),
team: TeamSide = Form(TeamSide.red), team: TeamSide = Form(TeamSide.red),
+12 -1
View File
@@ -2,13 +2,14 @@
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
from app.models.user import User from app.models.user import User
from app.limiter import limiter
from app.services import report_generation_service from app.services import report_generation_service
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"]) router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
@@ -21,7 +22,9 @@ _MEDIA_TYPES = {
@router.get("/purple-campaign/{campaign_id}") @router.get("/purple-campaign/{campaign_id}")
@limiter.limit("5/minute")
def generate_purple_report( def generate_purple_report(
request: Request,
campaign_id: UUID, campaign_id: UUID,
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -39,7 +42,9 @@ def generate_purple_report(
@router.get("/coverage-summary") @router.get("/coverage-summary")
@limiter.limit("5/minute")
def generate_coverage_report( def generate_coverage_report(
request: Request,
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
@@ -56,7 +61,9 @@ def generate_coverage_report(
@router.get("/executive-summary") @router.get("/executive-summary")
@limiter.limit("5/minute")
def generate_executive_report( def generate_executive_report(
request: Request,
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
@@ -73,7 +80,9 @@ def generate_executive_report(
@router.get("/quarterly-summary") @router.get("/quarterly-summary")
@limiter.limit("5/minute")
def generate_quarterly_report( def generate_quarterly_report(
request: Request,
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")), user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
@@ -90,7 +99,9 @@ def generate_quarterly_report(
@router.get("/technique/{technique_id}") @router.get("/technique/{technique_id}")
@limiter.limit("5/minute")
def generate_technique_report( def generate_technique_report(
request: Request,
technique_id: UUID, technique_id: UUID,
format: str = Query("pdf", pattern="^(pdf|docx|html)$"), format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
+6 -1
View File
@@ -7,7 +7,7 @@ scheduler health introspection.
import logging import logging
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
@@ -17,6 +17,7 @@ from app.services.mitre_sync_service import sync_mitre
from app.services.intel_service import scan_intel from app.services.intel_service import scan_intel
from app.services.atomic_import_service import import_atomic_red_team from app.services.atomic_import_service import import_atomic_red_team
from app.jobs.mitre_sync_job import scheduler from app.jobs.mitre_sync_job import scheduler
from app.limiter import limiter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -24,7 +25,9 @@ router = APIRouter(prefix="/system", tags=["system"])
@router.post("/sync-mitre") @router.post("/sync-mitre")
@limiter.limit("2/hour")
def trigger_mitre_sync( def trigger_mitre_sync(
request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ):
@@ -63,7 +66,9 @@ def trigger_intel_scan(
@router.post("/import-atomic-tests") @router.post("/import-atomic-tests")
@limiter.limit("2/hour")
def trigger_atomic_import( def trigger_atomic_import(
request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ):
+39 -2
View File
@@ -21,11 +21,13 @@ GET /tests/{id}/timeline — audit-log history for this test
import uuid import uuid
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role, require_role
from app.domain.enums import DataClassification
from app.limiter import limiter
from app.models.enums import TestState from app.models.enums import TestState
from app.models.user import User from app.models.user import User
from app.schemas.test import ( from app.schemas.test import (
@@ -37,6 +39,7 @@ from app.schemas.test import (
TestRedValidate, TestRedValidate,
TestBlueValidate, TestBlueValidate,
TestRemediationUpdate, TestRemediationUpdate,
TestClassificationUpdate,
) )
from app.schemas.test_template import TestTemplateInstantiate from app.schemas.test_template import TestTemplateInstantiate
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
@@ -112,7 +115,9 @@ def list_tests(
response_model=TestOut, response_model=TestOut,
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
) )
@limiter.limit("30/minute")
def create_test( def create_test(
request: Request,
payload: TestCreate, payload: TestCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
@@ -152,7 +157,9 @@ def create_test(
response_model=TestOut, response_model=TestOut,
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
) )
@limiter.limit("30/minute")
def create_test_from_template( def create_test_from_template(
request: Request,
payload: TestTemplateInstantiate, payload: TestTemplateInstantiate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
@@ -241,6 +248,36 @@ def update_test(
return test return test
# ---------------------------------------------------------------------------
# PATCH /tests/{id}/classification — admin data classification
# ---------------------------------------------------------------------------
@router.patch("/{test_id}/classification", response_model=TestOut)
def update_test_classification(
test_id: uuid.UUID,
payload: TestClassificationUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update the data classification label for a test (admin only)."""
with UnitOfWork(db) as uow:
test = crud_get_test_or_raise(db, test_id)
test.data_classification = payload.data_classification.value
db.flush()
log_action(
db,
user_id=current_user.id,
action="update_test_classification",
entity_type="test",
entity_id=test.id,
details={"data_classification": payload.data_classification.value},
)
uow.commit()
db.refresh(test)
return test
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# PATCH /tests/{id}/red — Red Team update (draft, red_executing) # PATCH /tests/{id}/red — Red Team update (draft, red_executing)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+8
View File
@@ -5,6 +5,7 @@ from datetime import datetime
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from app.domain.enums import DataClassification
from app.models.enums import TestResult, TestState from app.models.enums import TestResult, TestState
@@ -25,6 +26,12 @@ class TestCreate(BaseModel):
# ── Update (general) ─────────────────────────────────────────────── # ── Update (general) ───────────────────────────────────────────────
class TestClassificationUpdate(BaseModel):
"""Admin-only payload for changing data classification."""
data_classification: DataClassification
class TestUpdate(BaseModel): class TestUpdate(BaseModel):
"""Payload for partially updating an existing test. """Payload for partially updating an existing test.
Every field is optional so callers send only what changed.""" Every field is optional so callers send only what changed."""
@@ -152,6 +159,7 @@ class TestOut(BaseModel):
# Re-test fields # Re-test fields
retest_of: uuid.UUID | None = None retest_of: uuid.UUID | None = None
retest_count: int = 0 retest_count: int = 0
data_classification: str = "internal"
# Technique info (populated when joined) # Technique info (populated when joined)
technique_mitre_id: str | None = None technique_mitre_id: str | None = None
+50 -17
View File
@@ -1,32 +1,65 @@
"""Audit logging with request context and integrity hashing."""
from __future__ import annotations
import hashlib
from datetime import datetime, timezone
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.middleware.request_context import request_ip, request_user_agent
from app.models.audit import AuditLog from app.models.audit import AuditLog
def _integrity_payload(entry: AuditLog) -> str:
ts = entry.timestamp
if ts is None:
ts = datetime.now(timezone.utc)
user_part = str(entry.user_id) if entry.user_id else ""
entity_type = entry.entity_type or ""
entity_id = entry.entity_id or ""
return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoformat()}"
def compute_integrity_hash(entry: AuditLog) -> str:
"""Return the SHA-256 hex digest for an audit log entry."""
return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest()
def verify_audit_integrity(entry: AuditLog) -> bool:
"""Return whether the stored hash matches the entry's current fields."""
if not entry.integrity_hash:
return False
return entry.integrity_hash == compute_integrity_hash(entry)
def log_action( def log_action(
db: Session, db: Session,
user_id, user_id,
action: str, action: str,
entity_type: str = None, entity_type: str | None = None,
entity_id: str = None, entity_id: str | None = None,
details: dict = None details: dict | None = None,
): *,
""" ip_address: str | None = None,
Log an action to the audit log. user_agent: str | None = None,
session_id: str | None = None,
Args: ) -> AuditLog:
db: Database session """Record an audit event. Does not commit — the caller owns the transaction."""
user_id: UUID of the user performing the action (can be None for system actions) ip = ip_address if ip_address is not None else request_ip.get("")
action: Description of the action (e.g., "create_test", "validate_technique") ua = user_agent if user_agent is not None else request_user_agent.get("")
entity_type: Type of entity affected (e.g., "technique", "test", "user")
entity_id: ID of the entity affected entry = AuditLog(
details: Additional details as a dictionary (stored as JSONB)
"""
log = AuditLog(
user_id=user_id, user_id=user_id,
action=action, action=action,
entity_type=entity_type, entity_type=entity_type,
entity_id=str(entity_id) if entity_id else None, entity_id=str(entity_id) if entity_id else None,
details=details, details=details,
ip_address=ip or None,
user_agent=ua or None,
session_id=session_id,
) )
db.add(log) db.add(entry)
db.flush()
entry.integrity_hash = compute_integrity_hash(entry)
return entry
+2 -4
View File
@@ -121,10 +121,8 @@ def client(db, monkeypatch):
app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_db] = override_get_db
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
if hasattr(app.state, "limiter"): from app.limiter import limiter
app.state.limiter.enabled = False limiter.enabled = False
from app.routers.auth import limiter as auth_limiter
auth_limiter.enabled = False
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
with TestClient(app) as test_client: with TestClient(app) as test_client:
+84
View File
@@ -0,0 +1,84 @@
"""Tests for enhanced audit trail (IP, user-agent, integrity hash)."""
from datetime import datetime, timedelta, timezone
import pytest
from app.middleware.request_context import request_ip, request_user_agent
from app.models.audit import AuditLog
from app.services.audit_service import (
compute_integrity_hash,
log_action,
verify_audit_integrity,
)
from app.jobs.retention_job import apply_retention_policies
class TestAuditIntegrity:
def test_integrity_hash_set_on_log(self, db):
entry = log_action(
db,
user_id=None,
action="test_action",
entity_type="test",
entity_id="abc",
ip_address="10.0.0.1",
user_agent="pytest",
)
db.commit()
db.refresh(entry)
assert entry.integrity_hash
assert len(entry.integrity_hash) == 64
assert verify_audit_integrity(entry)
def test_tampered_entry_fails_integrity(self, db):
entry = log_action(
db,
user_id=None,
action="test_action",
entity_type="test",
entity_id="abc",
)
db.commit()
entry.entity_id = "tampered"
assert not verify_audit_integrity(entry)
def test_recomputed_hash_matches_stored(self, db):
entry = log_action(db, None, "update", "user", "1")
db.commit()
assert entry.integrity_hash == compute_integrity_hash(entry)
class TestRequestContext:
def test_context_vars_used_by_log_action(self, db):
token_ip = request_ip.set("203.0.113.42")
token_ua = request_user_agent.set("AegisTestClient/1.0")
try:
entry = log_action(db, None, "ctx_action", "system", None)
db.commit()
assert entry.ip_address == "203.0.113.42"
assert entry.user_agent == "AegisTestClient/1.0"
finally:
request_ip.reset(token_ip)
request_user_agent.reset(token_ua)
class TestRetentionJob:
def test_deletes_old_audit_logs(self, db):
old = AuditLog(
action="old",
entity_type="system",
timestamp=datetime.now(timezone.utc) - timedelta(days=800),
)
recent = AuditLog(
action="recent",
entity_type="system",
timestamp=datetime.now(timezone.utc) - timedelta(days=1),
)
db.add_all([old, recent])
db.commit()
summary = apply_retention_policies(db)
assert summary["audit_logs_deleted"] >= 1
remaining = db.query(AuditLog).all()
assert all(log.action != "old" for log in remaining)
+58
View File
@@ -0,0 +1,58 @@
"""Tests for login attempt auditing (SEC-009)."""
from app.models.audit import AuditLog
def test_login_failed_creates_audit_entry(client, admin_user, db):
response = client.post(
"/api/v1/auth/login",
data={"username": "admin", "password": "wrong"},
headers={"X-Forwarded-For": "198.51.100.10", "User-Agent": "LoginAuditTest/1.0"},
)
assert response.status_code == 400
log = (
db.query(AuditLog)
.filter(AuditLog.action == "LOGIN_FAILED")
.order_by(AuditLog.timestamp.desc())
.first()
)
assert log is not None
assert log.entity_type == "auth"
assert log.details["username"] == "admin"
assert log.details["reason"] == "invalid_credentials"
assert log.ip_address == "198.51.100.10"
assert log.user_agent == "LoginAuditTest/1.0"
assert log.integrity_hash
def test_login_success_creates_audit_entry(client, admin_user, db):
client.cookies.clear()
response = client.post(
"/api/v1/auth/login",
data={"username": "admin", "password": "admin123"},
headers={"X-Forwarded-For": "198.51.100.20"},
)
assert response.status_code == 200
log = (
db.query(AuditLog)
.filter(AuditLog.action == "LOGIN_SUCCESS")
.order_by(AuditLog.timestamp.desc())
.first()
)
assert log is not None
assert log.user_id == admin_user.id
assert log.ip_address == "198.51.100.20"
assert log.integrity_hash
def test_login_unknown_user_still_audited(client, db):
response = client.post(
"/api/v1/auth/login",
data={"username": "nobody", "password": "password"},
)
assert response.status_code == 400
log = db.query(AuditLog).filter(AuditLog.action == "LOGIN_FAILED").first()
assert log is not None
assert log.user_id is None
+72
View File
@@ -0,0 +1,72 @@
"""Tests for data classification fields and admin updates."""
from app.models.enums import TestState
from app.models.test import Test
from app.models.technique import Technique
def _seed_technique(db) -> Technique:
technique = Technique(
mitre_id="T9999",
name="Test Technique",
tactic="test",
platforms=["linux"],
)
db.add(technique)
db.commit()
db.refresh(technique)
return technique
def test_new_test_defaults_to_internal(db, red_lead_user):
technique = _seed_technique(db)
test = Test(
technique_id=technique.id,
name="Classification test",
created_by=red_lead_user.id,
)
db.add(test)
db.commit()
db.refresh(test)
assert test.data_classification == "internal"
def test_admin_can_update_classification(client, db, admin_user, admin_token, red_lead_user):
technique = _seed_technique(db)
test = Test(
technique_id=technique.id,
name="Classify me",
created_by=red_lead_user.id,
state=TestState.draft,
)
db.add(test)
db.commit()
response = client.patch(
f"/api/v1/tests/{test.id}/classification",
json={"data_classification": "sensitive"},
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 200
assert response.json()["data_classification"] == "sensitive"
db.refresh(test)
assert test.data_classification == "sensitive"
def test_non_admin_cannot_update_classification(client, db, admin_user, red_lead_token, red_lead_user):
technique = _seed_technique(db)
test = Test(
technique_id=technique.id,
name="Protected",
created_by=red_lead_user.id,
)
db.add(test)
db.commit()
response = client.patch(
f"/api/v1/tests/{test.id}/classification",
json={"data_classification": "restricted"},
headers={"Authorization": f"Bearer {red_lead_token}"},
)
assert response.status_code == 403
+25
View File
@@ -0,0 +1,25 @@
"""Smoke tests for extended rate-limit decorators (SEC-003)."""
import inspect
from app.routers import evidence, professional_reports, system, tests
def test_sync_mitre_has_hourly_limit():
source = inspect.getsource(system.trigger_mitre_sync)
assert "2/hour" in source
def test_create_test_has_per_minute_limit():
source = inspect.getsource(tests.create_test)
assert "30/minute" in source
def test_upload_evidence_has_per_minute_limit():
source = inspect.getsource(evidence.upload_evidence)
assert "10/minute" in source
def test_report_endpoints_have_per_minute_limit():
source = inspect.getsource(professional_reports.generate_coverage_report)
assert "5/minute" in source
@@ -42,6 +42,14 @@ class TestUsernameValidation:
with pytest.raises(ValidationError, match="3-50 characters"): with pytest.raises(ValidationError, match="3-50 characters"):
UserCreate(username="john@doe", password="SecurePass123!@#") UserCreate(username="john@doe", password="SecurePass123!@#")
def test_reserved_username_system(self):
with pytest.raises(ValidationError):
UserCreate(username="system", password="SecurePass123!@#")
def test_invalid_username_path_chars(self):
with pytest.raises(ValidationError):
UserCreate(username="../admin", password="SecurePass123!@#")
def test_reserved_username_admin(self): def test_reserved_username_admin(self):
with pytest.raises(ValidationError, match="reserved"): with pytest.raises(ValidationError, match="reserved"):
UserCreate(username="admin", password="SecurePass123!@#") UserCreate(username="admin", password="SecurePass123!@#")
+59
View File
@@ -0,0 +1,59 @@
"""API-level validation tests for user creation (SEC-004, SEC-007)."""
def test_create_user_weak_password_rejected(client, admin_user, admin_token):
response = client.post(
"/api/v1/users",
json={
"username": "newuser",
"password": "123",
"email": "new@test.com",
"role": "viewer",
},
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 422
assert "password" in response.text.lower()
def test_create_user_reserved_username(client, admin_user, admin_token):
response = client.post(
"/api/v1/users",
json={
"username": "system",
"password": "SecurePass123!@#",
"email": "sys@test.com",
"role": "viewer",
},
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 422
def test_create_user_invalid_username_chars(client, admin_user, admin_token):
response = client.post(
"/api/v1/users",
json={
"username": "../admin",
"password": "SecurePass123!@#",
"email": "bad@test.com",
"role": "viewer",
},
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 422
def test_create_user_valid_password_accepted(client, admin_user, admin_token):
response = client.post(
"/api/v1/users",
json={
"username": "validuser99",
"password": "ValidPass123!@#",
"email": "valid@test.com",
"role": "viewer",
},
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 201
assert response.json()["username"] == "validuser99"