Compare commits
5 Commits
a8a24b5429
...
bdeeed54e1
| Author | SHA1 | Date | |
|---|---|---|---|
| bdeeed54e1 | |||
| 3e854b7b79 | |||
| 5b29c2fc56 | |||
| 6b076f52b2 | |||
| c0aff4cbeb |
@@ -0,0 +1,58 @@
|
||||
"""Phase 3: audit trail columns and data classification fields.
|
||||
|
||||
Revision ID: b029phase3
|
||||
Revises: b028phase0
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b029phase3"
|
||||
down_revision: Union[str, None] = "b028phase0"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _column_names(table: str) -> set[str]:
|
||||
bind = op.get_bind()
|
||||
insp = sa.inspect(bind)
|
||||
return {c["name"] for c in insp.get_columns(table)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
audit_cols = _column_names("audit_logs")
|
||||
if "ip_address" not in audit_cols:
|
||||
op.add_column("audit_logs", sa.Column("ip_address", sa.String(45), nullable=True))
|
||||
if "user_agent" not in audit_cols:
|
||||
op.add_column("audit_logs", sa.Column("user_agent", sa.String(500), nullable=True))
|
||||
if "integrity_hash" not in audit_cols:
|
||||
op.add_column("audit_logs", sa.Column("integrity_hash", sa.String(64), nullable=True))
|
||||
if "session_id" not in audit_cols:
|
||||
op.add_column("audit_logs", sa.Column("session_id", sa.String(100), nullable=True))
|
||||
|
||||
for table in ("tests", "evidences", "campaigns"):
|
||||
cols = _column_names(table)
|
||||
if "data_classification" not in cols:
|
||||
op.add_column(
|
||||
table,
|
||||
sa.Column(
|
||||
"data_classification",
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default="internal",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ("campaigns", "evidences", "tests"):
|
||||
cols = _column_names(table)
|
||||
if "data_classification" in cols:
|
||||
op.drop_column(table, "data_classification")
|
||||
|
||||
audit_cols = _column_names("audit_logs")
|
||||
for col in ("session_id", "integrity_hash", "user_agent", "ip_address"):
|
||||
if col in audit_cols:
|
||||
op.drop_column("audit_logs", col)
|
||||
@@ -35,3 +35,10 @@ class TestResult(str, enum.Enum):
|
||||
detected = "detected"
|
||||
not_detected = "not_detected"
|
||||
partially_detected = "partially_detected"
|
||||
|
||||
|
||||
class DataClassification(str, enum.Enum):
|
||||
public = "public"
|
||||
internal = "internal"
|
||||
sensitive = "sensitive"
|
||||
restricted = "restricted"
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.services.campaign_scheduler_service import check_and_run_recurring_camp
|
||||
from app.jobs.jira_sync_job import sync_all_jira_links
|
||||
from app.services.osint_enrichment_service import enrich_all_techniques
|
||||
from app.services.stale_detection_service import detect_stale_coverage
|
||||
from app.jobs.retention_job import run_retention_job
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -217,10 +218,19 @@ def start_scheduler() -> None:
|
||||
name="Stale coverage detection (daily)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
run_retention_job,
|
||||
trigger="interval",
|
||||
hours=24,
|
||||
id="retention_policies",
|
||||
name="Data retention policies (daily)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info(
|
||||
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
||||
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
|
||||
"recurring_campaigns (daily), jira_sync (1h), "
|
||||
"osint_enrichment (weekly), stale_detection (daily)"
|
||||
"osint_enrichment (weekly), stale_detection (daily), "
|
||||
"retention_policies (daily)"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Data retention policies — scheduled cleanup of aged records."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import SessionLocal
|
||||
from app.models.audit import AuditLog
|
||||
from app.services.notification_service import cleanup_old_notifications
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUDIT_LOG_RETENTION_DAYS = 730
|
||||
|
||||
|
||||
def apply_retention_policies(db: Session) -> dict[str, int]:
|
||||
"""Apply retention rules. Commits the session before returning."""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
|
||||
deleted_audit = (
|
||||
db.query(AuditLog)
|
||||
.filter(AuditLog.timestamp < cutoff)
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
if deleted_audit:
|
||||
logger.info(
|
||||
"Retention: deleted %d audit logs older than %d days",
|
||||
deleted_audit,
|
||||
AUDIT_LOG_RETENTION_DAYS,
|
||||
)
|
||||
|
||||
deleted_notifications = cleanup_old_notifications(db, days=90)
|
||||
db.commit()
|
||||
return {
|
||||
"audit_logs_deleted": deleted_audit,
|
||||
"notifications_deleted": deleted_notifications,
|
||||
}
|
||||
|
||||
|
||||
def run_retention_job() -> None:
|
||||
"""Entry point for the daily retention scheduler job."""
|
||||
logger.info("Scheduled retention job starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
summary = apply_retention_policies(db)
|
||||
logger.info("Retention job finished — %s", summary)
|
||||
except Exception:
|
||||
logger.exception("Retention job failed")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Shared SlowAPI rate limiter for all routers."""
|
||||
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
+19
-5
@@ -6,9 +6,8 @@ from fastapi import FastAPI, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.routers import auth as auth_router
|
||||
@@ -40,6 +39,8 @@ from app.routers import advanced_metrics as advanced_metrics_router
|
||||
from app.routers import osint as osint_router
|
||||
from app.domain.errors import DomainError
|
||||
from app.middleware.error_handler import domain_exception_handler
|
||||
from app.middleware.request_context import RequestContextMiddleware
|
||||
from app.limiter import limiter
|
||||
from app.storage import ensure_bucket_exists
|
||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||
|
||||
@@ -71,10 +72,11 @@ app = FastAPI(
|
||||
)
|
||||
|
||||
# ── Rate Limiter ──────────────────────────────────────────────────────────
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
app.add_middleware(RequestContextMiddleware)
|
||||
|
||||
# ── Domain exception → HTTP mapping ──────────────────────────────────────
|
||||
app.add_exception_handler(DomainError, domain_exception_handler)
|
||||
|
||||
@@ -136,15 +138,27 @@ def health():
|
||||
# ── Exception Handlers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
|
||||
"""Return validation errors safe for JSON (no raw exception objects)."""
|
||||
serialized: list[dict] = []
|
||||
for err in exc.errors():
|
||||
item = dict(err)
|
||||
ctx = item.get("ctx")
|
||||
if isinstance(ctx, dict):
|
||||
item["ctx"] = {key: str(value) for key, value in ctx.items()}
|
||||
serialized.append(item)
|
||||
return serialized
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""Handle validation errors with consistent format."""
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={
|
||||
"detail": "Validation error",
|
||||
"code": "VALIDATION_ERROR",
|
||||
"errors": exc.errors(),
|
||||
"errors": _serialize_validation_errors(exc),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Request context middleware — captures client IP and User-Agent per request."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
request_ip: ContextVar[str] = ContextVar("request_ip", default="")
|
||||
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
|
||||
|
||||
|
||||
def resolve_client_ip(request: Request) -> str:
|
||||
"""Extract the client IP, honouring ``X-Forwarded-For`` when present."""
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
|
||||
class RequestContextMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
request_ip.set(resolve_client_ip(request))
|
||||
request_user_agent.set(request.headers.get("User-Agent", ""))
|
||||
return await call_next(request)
|
||||
@@ -22,6 +22,10 @@ class AuditLog(Base):
|
||||
entity_id = Column(String, nullable=True)
|
||||
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||
details = Column(JSONB, nullable=True)
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
user_agent = Column(String(500), nullable=True)
|
||||
integrity_hash = Column(String(64), nullable=True)
|
||||
session_id = Column(String(100), nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User")
|
||||
|
||||
@@ -53,6 +53,7 @@ class Campaign(Base):
|
||||
target_platform = Column(String, nullable=True)
|
||||
tags = Column(JSONB, nullable=True, default=[])
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||
|
||||
# Recurring scheduling fields
|
||||
is_recurring = Column(Boolean, default=False)
|
||||
|
||||
@@ -6,6 +6,7 @@ working with ``from app.models.enums import ...``.
|
||||
"""
|
||||
|
||||
from app.domain.enums import ( # noqa: F401
|
||||
DataClassification,
|
||||
TeamSide,
|
||||
TechniqueStatus,
|
||||
TestResult,
|
||||
|
||||
@@ -28,6 +28,7 @@ class Evidence(Base):
|
||||
uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
|
||||
notes = Column(Text, nullable=True)
|
||||
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||
|
||||
# Relationships
|
||||
test = relationship("Test", back_populates="evidences")
|
||||
|
||||
@@ -62,6 +62,7 @@ class Test(Base):
|
||||
# ── Re-test fields ────────────────────────────────────────────
|
||||
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
|
||||
retest_count = Column(Integer, default=0)
|
||||
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||
|
||||
# ── Relationships ───────────────────────────────────────────────
|
||||
technique = relationship("Technique", back_populates="tests")
|
||||
|
||||
+51
-56
@@ -11,39 +11,33 @@ import os
|
||||
|
||||
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from jose import jwt, JWTError
|
||||
|
||||
from app.auth import create_access_token, blacklist_token
|
||||
from app.auth import create_access_token, blacklist_token, verify_password
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.limiter import limiter
|
||||
from app.middleware.request_context import resolve_client_ip
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import authenticate_user, change_password as auth_change_password
|
||||
from app.services.auth_service import (
|
||||
_DUMMY_HASH,
|
||||
change_password as auth_change_password,
|
||||
)
|
||||
from app.services.audit_service import log_action
|
||||
from app.schemas.auth import TokenResponse, UserOut
|
||||
from app.schemas.user import PasswordChange
|
||||
|
||||
# Rate limiter instance (shares backend state via app.state.limiter)
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# Detect whether we're behind HTTPS (production) so the cookie can be Secure
|
||||
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
|
||||
# Cookie name used to transport the JWT
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/login
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
@limiter.limit("5/minute")
|
||||
def login(
|
||||
@@ -54,19 +48,49 @@ def login(
|
||||
):
|
||||
"""Authenticate a user and return a JWT access token.
|
||||
|
||||
Rate-limited to **5 attempts per minute per IP** to prevent brute-force
|
||||
attacks. The token is set as an HttpOnly cookie **and** returned in the
|
||||
JSON body for API/Swagger compatibility.
|
||||
Rate-limited to **5 attempts per minute per IP**. Failed and successful
|
||||
logins are recorded in the audit log (SEC-009).
|
||||
"""
|
||||
user = authenticate_user(
|
||||
db,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
)
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
target_hash = user.hashed_password if user else _DUMMY_HASH
|
||||
password_valid = verify_password(form_data.password, target_hash)
|
||||
ip = resolve_client_ip(request)
|
||||
|
||||
if user is None or not password_valid:
|
||||
with UnitOfWork(db) as uow:
|
||||
log_action(
|
||||
db,
|
||||
user.id if user else None,
|
||||
"LOGIN_FAILED",
|
||||
"auth",
|
||||
None,
|
||||
details={
|
||||
"username": form_data.username,
|
||||
"ip": ip,
|
||||
"reason": "invalid_credentials",
|
||||
},
|
||||
ip_address=ip,
|
||||
)
|
||||
uow.commit()
|
||||
raise BusinessRuleViolation("Incorrect username or password")
|
||||
|
||||
if not user.is_active:
|
||||
raise PermissionViolation("Account is disabled. Contact an administrator.")
|
||||
|
||||
access_token = create_access_token(data={"sub": user.username})
|
||||
|
||||
# Set HttpOnly cookie — inaccessible from JS
|
||||
with UnitOfWork(db) as uow:
|
||||
log_action(
|
||||
db,
|
||||
user.id,
|
||||
"LOGIN_SUCCESS",
|
||||
"auth",
|
||||
str(user.id),
|
||||
details={"username": user.username, "ip": ip},
|
||||
ip_address=ip,
|
||||
)
|
||||
uow.commit()
|
||||
|
||||
response.set_cookie(
|
||||
key=_COOKIE_NAME,
|
||||
value=access_token,
|
||||
@@ -80,27 +104,13 @@ def login(
|
||||
return TokenResponse(access_token=access_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/logout
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(
|
||||
request: Request,
|
||||
response: Response,
|
||||
aegis_token: str | None = Cookie(None),
|
||||
):
|
||||
"""Clear the authentication cookie and revoke the current token.
|
||||
|
||||
The token's ``jti`` is added to the Redis blacklist so it cannot
|
||||
be reused even if the cookie has already been copied elsewhere.
|
||||
The blacklist entry auto-expires when the token's ``exp`` is reached.
|
||||
|
||||
When both HttpOnly cookie and ``Authorization: Bearer`` are present
|
||||
(typical for API clients), **both** are revoked so the session cannot
|
||||
survive on whichever credential the next request prefers.
|
||||
"""
|
||||
"""Clear the authentication cookie and revoke the current token."""
|
||||
bearer = (
|
||||
request.headers.get("Authorization")
|
||||
or request.headers.get("authorization")
|
||||
@@ -124,7 +134,7 @@ def logout(
|
||||
if jti:
|
||||
blacklist_token(jti, float(exp))
|
||||
except JWTError:
|
||||
pass # token already invalid — nothing to revoke for this raw value
|
||||
pass
|
||||
|
||||
response.delete_cookie(
|
||||
key=_COOKIE_NAME,
|
||||
@@ -136,34 +146,19 @@ def logout(
|
||||
return {"detail": "Logged out"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/me
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def read_current_user(current_user: User = Depends(get_current_user)):
|
||||
"""Return the profile of the currently authenticated user."""
|
||||
return current_user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/change-password
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/change-password")
|
||||
def change_password(
|
||||
body: PasswordChange,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Change the current user's password.
|
||||
|
||||
Requires the current password for verification. On success the
|
||||
``must_change_password`` flag is cleared so the user can proceed
|
||||
normally.
|
||||
"""
|
||||
"""Change the current user's password."""
|
||||
auth_change_password(
|
||||
db,
|
||||
current_user,
|
||||
|
||||
@@ -24,7 +24,7 @@ import os
|
||||
import uuid as _uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
@@ -44,6 +44,7 @@ from app.services.evidence_service import (
|
||||
validate_file,
|
||||
validate_upload_permission,
|
||||
)
|
||||
from app.limiter import limiter
|
||||
from app.storage import get_presigned_url, upload_file
|
||||
|
||||
router = APIRouter(tags=["evidence"])
|
||||
@@ -78,7 +79,9 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
response_model=EvidenceOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def upload_evidence(
|
||||
request: Request,
|
||||
test_id: _uuid.UUID,
|
||||
file: UploadFile = File(...),
|
||||
team: TeamSide = Form(TeamSide.red),
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
from app.limiter import limiter
|
||||
from app.services import report_generation_service
|
||||
|
||||
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
||||
@@ -21,7 +22,9 @@ _MEDIA_TYPES = {
|
||||
|
||||
|
||||
@router.get("/purple-campaign/{campaign_id}")
|
||||
@limiter.limit("5/minute")
|
||||
def generate_purple_report(
|
||||
request: Request,
|
||||
campaign_id: UUID,
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -39,7 +42,9 @@ def generate_purple_report(
|
||||
|
||||
|
||||
@router.get("/coverage-summary")
|
||||
@limiter.limit("5/minute")
|
||||
def generate_coverage_report(
|
||||
request: Request,
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||
@@ -56,7 +61,9 @@ def generate_coverage_report(
|
||||
|
||||
|
||||
@router.get("/executive-summary")
|
||||
@limiter.limit("5/minute")
|
||||
def generate_executive_report(
|
||||
request: Request,
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||
@@ -73,7 +80,9 @@ def generate_executive_report(
|
||||
|
||||
|
||||
@router.get("/quarterly-summary")
|
||||
@limiter.limit("5/minute")
|
||||
def generate_quarterly_report(
|
||||
request: Request,
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||
@@ -90,7 +99,9 @@ def generate_quarterly_report(
|
||||
|
||||
|
||||
@router.get("/technique/{technique_id}")
|
||||
@limiter.limit("5/minute")
|
||||
def generate_technique_report(
|
||||
request: Request,
|
||||
technique_id: UUID,
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
@@ -7,7 +7,7 @@ scheduler health introspection.
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
@@ -17,6 +17,7 @@ from app.services.mitre_sync_service import sync_mitre
|
||||
from app.services.intel_service import scan_intel
|
||||
from app.services.atomic_import_service import import_atomic_red_team
|
||||
from app.jobs.mitre_sync_job import scheduler
|
||||
from app.limiter import limiter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +25,9 @@ router = APIRouter(prefix="/system", tags=["system"])
|
||||
|
||||
|
||||
@router.post("/sync-mitre")
|
||||
@limiter.limit("2/hour")
|
||||
def trigger_mitre_sync(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
@@ -63,7 +66,9 @@ def trigger_intel_scan(
|
||||
|
||||
|
||||
@router.post("/import-atomic-tests")
|
||||
@limiter.limit("2/hour")
|
||||
def trigger_atomic_import(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
|
||||
@@ -21,11 +21,13 @@ GET /tests/{id}/timeline — audit-log history for this test
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||
from app.domain.enums import DataClassification
|
||||
from app.limiter import limiter
|
||||
from app.models.enums import TestState
|
||||
from app.models.user import User
|
||||
from app.schemas.test import (
|
||||
@@ -37,6 +39,7 @@ from app.schemas.test import (
|
||||
TestRedValidate,
|
||||
TestBlueValidate,
|
||||
TestRemediationUpdate,
|
||||
TestClassificationUpdate,
|
||||
)
|
||||
from app.schemas.test_template import TestTemplateInstantiate
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
@@ -112,7 +115,9 @@ def list_tests(
|
||||
response_model=TestOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
def create_test(
|
||||
request: Request,
|
||||
payload: TestCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
@@ -152,7 +157,9 @@ def create_test(
|
||||
response_model=TestOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
def create_test_from_template(
|
||||
request: Request,
|
||||
payload: TestTemplateInstantiate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
@@ -241,6 +248,36 @@ def update_test(
|
||||
return test
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /tests/{id}/classification — admin data classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/{test_id}/classification", response_model=TestOut)
|
||||
def update_test_classification(
|
||||
test_id: uuid.UUID,
|
||||
payload: TestClassificationUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update the data classification label for a test (admin only)."""
|
||||
with UnitOfWork(db) as uow:
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
test.data_classification = payload.data_classification.value
|
||||
db.flush()
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_test_classification",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={"data_classification": payload.data_classification.value},
|
||||
)
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
return test
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /tests/{id}/red — Red Team update (draft, red_executing)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from app.domain.enums import DataClassification
|
||||
from app.models.enums import TestResult, TestState
|
||||
|
||||
|
||||
@@ -25,6 +26,12 @@ class TestCreate(BaseModel):
|
||||
# ── Update (general) ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestClassificationUpdate(BaseModel):
|
||||
"""Admin-only payload for changing data classification."""
|
||||
|
||||
data_classification: DataClassification
|
||||
|
||||
|
||||
class TestUpdate(BaseModel):
|
||||
"""Payload for partially updating an existing test.
|
||||
Every field is optional so callers send only what changed."""
|
||||
@@ -152,6 +159,7 @@ class TestOut(BaseModel):
|
||||
# Re-test fields
|
||||
retest_of: uuid.UUID | None = None
|
||||
retest_count: int = 0
|
||||
data_classification: str = "internal"
|
||||
|
||||
# Technique info (populated when joined)
|
||||
technique_mitre_id: str | None = None
|
||||
|
||||
@@ -1,32 +1,65 @@
|
||||
"""Audit logging with request context and integrity hashing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.middleware.request_context import request_ip, request_user_agent
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
|
||||
def _integrity_payload(entry: AuditLog) -> str:
|
||||
ts = entry.timestamp
|
||||
if ts is None:
|
||||
ts = datetime.now(timezone.utc)
|
||||
user_part = str(entry.user_id) if entry.user_id else ""
|
||||
entity_type = entry.entity_type or ""
|
||||
entity_id = entry.entity_id or ""
|
||||
return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoformat()}"
|
||||
|
||||
|
||||
def compute_integrity_hash(entry: AuditLog) -> str:
|
||||
"""Return the SHA-256 hex digest for an audit log entry."""
|
||||
return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest()
|
||||
|
||||
|
||||
def verify_audit_integrity(entry: AuditLog) -> bool:
|
||||
"""Return whether the stored hash matches the entry's current fields."""
|
||||
if not entry.integrity_hash:
|
||||
return False
|
||||
return entry.integrity_hash == compute_integrity_hash(entry)
|
||||
|
||||
|
||||
def log_action(
|
||||
db: Session,
|
||||
user_id,
|
||||
action: str,
|
||||
entity_type: str = None,
|
||||
entity_id: str = None,
|
||||
details: dict = None
|
||||
):
|
||||
"""
|
||||
Log an action to the audit log.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: UUID of the user performing the action (can be None for system actions)
|
||||
action: Description of the action (e.g., "create_test", "validate_technique")
|
||||
entity_type: Type of entity affected (e.g., "technique", "test", "user")
|
||||
entity_id: ID of the entity affected
|
||||
details: Additional details as a dictionary (stored as JSONB)
|
||||
"""
|
||||
log = AuditLog(
|
||||
entity_type: str | None = None,
|
||||
entity_id: str | None = None,
|
||||
details: dict | None = None,
|
||||
*,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> AuditLog:
|
||||
"""Record an audit event. Does not commit — the caller owns the transaction."""
|
||||
ip = ip_address if ip_address is not None else request_ip.get("")
|
||||
ua = user_agent if user_agent is not None else request_user_agent.get("")
|
||||
|
||||
entry = AuditLog(
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
entity_type=entity_type,
|
||||
entity_id=str(entity_id) if entity_id else None,
|
||||
details=details,
|
||||
ip_address=ip or None,
|
||||
user_agent=ua or None,
|
||||
session_id=session_id,
|
||||
)
|
||||
db.add(log)
|
||||
db.add(entry)
|
||||
db.flush()
|
||||
entry.integrity_hash = compute_integrity_hash(entry)
|
||||
return entry
|
||||
|
||||
@@ -121,10 +121,8 @@ def client(db, monkeypatch):
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
if hasattr(app.state, "limiter"):
|
||||
app.state.limiter.enabled = False
|
||||
from app.routers.auth import limiter as auth_limiter
|
||||
auth_limiter.enabled = False
|
||||
from app.limiter import limiter
|
||||
limiter.enabled = False
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
with TestClient(app) as test_client:
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Tests for enhanced audit trail (IP, user-agent, integrity hash)."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.middleware.request_context import request_ip, request_user_agent
|
||||
from app.models.audit import AuditLog
|
||||
from app.services.audit_service import (
|
||||
compute_integrity_hash,
|
||||
log_action,
|
||||
verify_audit_integrity,
|
||||
)
|
||||
from app.jobs.retention_job import apply_retention_policies
|
||||
|
||||
|
||||
class TestAuditIntegrity:
|
||||
def test_integrity_hash_set_on_log(self, db):
|
||||
entry = log_action(
|
||||
db,
|
||||
user_id=None,
|
||||
action="test_action",
|
||||
entity_type="test",
|
||||
entity_id="abc",
|
||||
ip_address="10.0.0.1",
|
||||
user_agent="pytest",
|
||||
)
|
||||
db.commit()
|
||||
db.refresh(entry)
|
||||
assert entry.integrity_hash
|
||||
assert len(entry.integrity_hash) == 64
|
||||
assert verify_audit_integrity(entry)
|
||||
|
||||
def test_tampered_entry_fails_integrity(self, db):
|
||||
entry = log_action(
|
||||
db,
|
||||
user_id=None,
|
||||
action="test_action",
|
||||
entity_type="test",
|
||||
entity_id="abc",
|
||||
)
|
||||
db.commit()
|
||||
entry.entity_id = "tampered"
|
||||
assert not verify_audit_integrity(entry)
|
||||
|
||||
def test_recomputed_hash_matches_stored(self, db):
|
||||
entry = log_action(db, None, "update", "user", "1")
|
||||
db.commit()
|
||||
assert entry.integrity_hash == compute_integrity_hash(entry)
|
||||
|
||||
|
||||
class TestRequestContext:
|
||||
def test_context_vars_used_by_log_action(self, db):
|
||||
token_ip = request_ip.set("203.0.113.42")
|
||||
token_ua = request_user_agent.set("AegisTestClient/1.0")
|
||||
try:
|
||||
entry = log_action(db, None, "ctx_action", "system", None)
|
||||
db.commit()
|
||||
assert entry.ip_address == "203.0.113.42"
|
||||
assert entry.user_agent == "AegisTestClient/1.0"
|
||||
finally:
|
||||
request_ip.reset(token_ip)
|
||||
request_user_agent.reset(token_ua)
|
||||
|
||||
|
||||
class TestRetentionJob:
|
||||
def test_deletes_old_audit_logs(self, db):
|
||||
old = AuditLog(
|
||||
action="old",
|
||||
entity_type="system",
|
||||
timestamp=datetime.now(timezone.utc) - timedelta(days=800),
|
||||
)
|
||||
recent = AuditLog(
|
||||
action="recent",
|
||||
entity_type="system",
|
||||
timestamp=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
)
|
||||
db.add_all([old, recent])
|
||||
db.commit()
|
||||
|
||||
summary = apply_retention_policies(db)
|
||||
assert summary["audit_logs_deleted"] >= 1
|
||||
remaining = db.query(AuditLog).all()
|
||||
assert all(log.action != "old" for log in remaining)
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Tests for login attempt auditing (SEC-009)."""
|
||||
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
|
||||
def test_login_failed_creates_audit_entry(client, admin_user, db):
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "admin", "password": "wrong"},
|
||||
headers={"X-Forwarded-For": "198.51.100.10", "User-Agent": "LoginAuditTest/1.0"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
log = (
|
||||
db.query(AuditLog)
|
||||
.filter(AuditLog.action == "LOGIN_FAILED")
|
||||
.order_by(AuditLog.timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
assert log is not None
|
||||
assert log.entity_type == "auth"
|
||||
assert log.details["username"] == "admin"
|
||||
assert log.details["reason"] == "invalid_credentials"
|
||||
assert log.ip_address == "198.51.100.10"
|
||||
assert log.user_agent == "LoginAuditTest/1.0"
|
||||
assert log.integrity_hash
|
||||
|
||||
|
||||
def test_login_success_creates_audit_entry(client, admin_user, db):
|
||||
client.cookies.clear()
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "admin", "password": "admin123"},
|
||||
headers={"X-Forwarded-For": "198.51.100.20"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
log = (
|
||||
db.query(AuditLog)
|
||||
.filter(AuditLog.action == "LOGIN_SUCCESS")
|
||||
.order_by(AuditLog.timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
assert log is not None
|
||||
assert log.user_id == admin_user.id
|
||||
assert log.ip_address == "198.51.100.20"
|
||||
assert log.integrity_hash
|
||||
|
||||
|
||||
def test_login_unknown_user_still_audited(client, db):
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "nobody", "password": "password"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
log = db.query(AuditLog).filter(AuditLog.action == "LOGIN_FAILED").first()
|
||||
assert log is not None
|
||||
assert log.user_id is None
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Tests for data classification fields and admin updates."""
|
||||
|
||||
from app.models.enums import TestState
|
||||
from app.models.test import Test
|
||||
from app.models.technique import Technique
|
||||
|
||||
|
||||
def _seed_technique(db) -> Technique:
|
||||
technique = Technique(
|
||||
mitre_id="T9999",
|
||||
name="Test Technique",
|
||||
tactic="test",
|
||||
platforms=["linux"],
|
||||
)
|
||||
db.add(technique)
|
||||
db.commit()
|
||||
db.refresh(technique)
|
||||
return technique
|
||||
|
||||
|
||||
def test_new_test_defaults_to_internal(db, red_lead_user):
|
||||
technique = _seed_technique(db)
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name="Classification test",
|
||||
created_by=red_lead_user.id,
|
||||
)
|
||||
db.add(test)
|
||||
db.commit()
|
||||
db.refresh(test)
|
||||
assert test.data_classification == "internal"
|
||||
|
||||
|
||||
def test_admin_can_update_classification(client, db, admin_user, admin_token, red_lead_user):
|
||||
technique = _seed_technique(db)
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name="Classify me",
|
||||
created_by=red_lead_user.id,
|
||||
state=TestState.draft,
|
||||
)
|
||||
db.add(test)
|
||||
db.commit()
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/tests/{test.id}/classification",
|
||||
json={"data_classification": "sensitive"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["data_classification"] == "sensitive"
|
||||
|
||||
db.refresh(test)
|
||||
assert test.data_classification == "sensitive"
|
||||
|
||||
|
||||
def test_non_admin_cannot_update_classification(client, db, admin_user, red_lead_token, red_lead_user):
|
||||
technique = _seed_technique(db)
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name="Protected",
|
||||
created_by=red_lead_user.id,
|
||||
)
|
||||
db.add(test)
|
||||
db.commit()
|
||||
|
||||
response = client.patch(
|
||||
f"/api/v1/tests/{test.id}/classification",
|
||||
json={"data_classification": "restricted"},
|
||||
headers={"Authorization": f"Bearer {red_lead_token}"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Smoke tests for extended rate-limit decorators (SEC-003)."""
|
||||
|
||||
import inspect
|
||||
|
||||
from app.routers import evidence, professional_reports, system, tests
|
||||
|
||||
|
||||
def test_sync_mitre_has_hourly_limit():
|
||||
source = inspect.getsource(system.trigger_mitre_sync)
|
||||
assert "2/hour" in source
|
||||
|
||||
|
||||
def test_create_test_has_per_minute_limit():
|
||||
source = inspect.getsource(tests.create_test)
|
||||
assert "30/minute" in source
|
||||
|
||||
|
||||
def test_upload_evidence_has_per_minute_limit():
|
||||
source = inspect.getsource(evidence.upload_evidence)
|
||||
assert "10/minute" in source
|
||||
|
||||
|
||||
def test_report_endpoints_have_per_minute_limit():
|
||||
source = inspect.getsource(professional_reports.generate_coverage_report)
|
||||
assert "5/minute" in source
|
||||
@@ -42,6 +42,14 @@ class TestUsernameValidation:
|
||||
with pytest.raises(ValidationError, match="3-50 characters"):
|
||||
UserCreate(username="john@doe", password="SecurePass123!@#")
|
||||
|
||||
def test_reserved_username_system(self):
|
||||
with pytest.raises(ValidationError):
|
||||
UserCreate(username="system", password="SecurePass123!@#")
|
||||
|
||||
def test_invalid_username_path_chars(self):
|
||||
with pytest.raises(ValidationError):
|
||||
UserCreate(username="../admin", password="SecurePass123!@#")
|
||||
|
||||
def test_reserved_username_admin(self):
|
||||
with pytest.raises(ValidationError, match="reserved"):
|
||||
UserCreate(username="admin", password="SecurePass123!@#")
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
"""API-level validation tests for user creation (SEC-004, SEC-007)."""
|
||||
|
||||
|
||||
def test_create_user_weak_password_rejected(client, admin_user, admin_token):
|
||||
response = client.post(
|
||||
"/api/v1/users",
|
||||
json={
|
||||
"username": "newuser",
|
||||
"password": "123",
|
||||
"email": "new@test.com",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert "password" in response.text.lower()
|
||||
|
||||
|
||||
def test_create_user_reserved_username(client, admin_user, admin_token):
|
||||
response = client.post(
|
||||
"/api/v1/users",
|
||||
json={
|
||||
"username": "system",
|
||||
"password": "SecurePass123!@#",
|
||||
"email": "sys@test.com",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_create_user_invalid_username_chars(client, admin_user, admin_token):
|
||||
response = client.post(
|
||||
"/api/v1/users",
|
||||
json={
|
||||
"username": "../admin",
|
||||
"password": "SecurePass123!@#",
|
||||
"email": "bad@test.com",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_create_user_valid_password_accepted(client, admin_user, admin_token):
|
||||
response = client.post(
|
||||
"/api/v1/users",
|
||||
json={
|
||||
"username": "validuser99",
|
||||
"password": "ValidPass123!@#",
|
||||
"email": "valid@test.com",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["username"] == "validuser99"
|
||||
Reference in New Issue
Block a user