Compare commits
5 Commits
a8a24b5429
...
bdeeed54e1
| Author | SHA1 | Date | |
|---|---|---|---|
| bdeeed54e1 | |||
| 3e854b7b79 | |||
| 5b29c2fc56 | |||
| 6b076f52b2 | |||
| c0aff4cbeb |
@@ -0,0 +1,58 @@
|
|||||||
|
"""Phase 3: audit trail columns and data classification fields.
|
||||||
|
|
||||||
|
Revision ID: b029phase3
|
||||||
|
Revises: b028phase0
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "b029phase3"
|
||||||
|
down_revision: Union[str, None] = "b028phase0"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _column_names(table: str) -> set[str]:
|
||||||
|
bind = op.get_bind()
|
||||||
|
insp = sa.inspect(bind)
|
||||||
|
return {c["name"] for c in insp.get_columns(table)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
audit_cols = _column_names("audit_logs")
|
||||||
|
if "ip_address" not in audit_cols:
|
||||||
|
op.add_column("audit_logs", sa.Column("ip_address", sa.String(45), nullable=True))
|
||||||
|
if "user_agent" not in audit_cols:
|
||||||
|
op.add_column("audit_logs", sa.Column("user_agent", sa.String(500), nullable=True))
|
||||||
|
if "integrity_hash" not in audit_cols:
|
||||||
|
op.add_column("audit_logs", sa.Column("integrity_hash", sa.String(64), nullable=True))
|
||||||
|
if "session_id" not in audit_cols:
|
||||||
|
op.add_column("audit_logs", sa.Column("session_id", sa.String(100), nullable=True))
|
||||||
|
|
||||||
|
for table in ("tests", "evidences", "campaigns"):
|
||||||
|
cols = _column_names(table)
|
||||||
|
if "data_classification" not in cols:
|
||||||
|
op.add_column(
|
||||||
|
table,
|
||||||
|
sa.Column(
|
||||||
|
"data_classification",
|
||||||
|
sa.String(20),
|
||||||
|
nullable=False,
|
||||||
|
server_default="internal",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
for table in ("campaigns", "evidences", "tests"):
|
||||||
|
cols = _column_names(table)
|
||||||
|
if "data_classification" in cols:
|
||||||
|
op.drop_column(table, "data_classification")
|
||||||
|
|
||||||
|
audit_cols = _column_names("audit_logs")
|
||||||
|
for col in ("session_id", "integrity_hash", "user_agent", "ip_address"):
|
||||||
|
if col in audit_cols:
|
||||||
|
op.drop_column("audit_logs", col)
|
||||||
@@ -35,3 +35,10 @@ class TestResult(str, enum.Enum):
|
|||||||
detected = "detected"
|
detected = "detected"
|
||||||
not_detected = "not_detected"
|
not_detected = "not_detected"
|
||||||
partially_detected = "partially_detected"
|
partially_detected = "partially_detected"
|
||||||
|
|
||||||
|
|
||||||
|
class DataClassification(str, enum.Enum):
|
||||||
|
public = "public"
|
||||||
|
internal = "internal"
|
||||||
|
sensitive = "sensitive"
|
||||||
|
restricted = "restricted"
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from app.services.campaign_scheduler_service import check_and_run_recurring_camp
|
|||||||
from app.jobs.jira_sync_job import sync_all_jira_links
|
from app.jobs.jira_sync_job import sync_all_jira_links
|
||||||
from app.services.osint_enrichment_service import enrich_all_techniques
|
from app.services.osint_enrichment_service import enrich_all_techniques
|
||||||
from app.services.stale_detection_service import detect_stale_coverage
|
from app.services.stale_detection_service import detect_stale_coverage
|
||||||
|
from app.jobs.retention_job import run_retention_job
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -217,10 +218,19 @@ def start_scheduler() -> None:
|
|||||||
name="Stale coverage detection (daily)",
|
name="Stale coverage detection (daily)",
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
scheduler.add_job(
|
||||||
|
run_retention_job,
|
||||||
|
trigger="interval",
|
||||||
|
hours=24,
|
||||||
|
id="retention_policies",
|
||||||
|
name="Data retention policies (daily)",
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
||||||
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
|
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
|
||||||
"recurring_campaigns (daily), jira_sync (1h), "
|
"recurring_campaigns (daily), jira_sync (1h), "
|
||||||
"osint_enrichment (weekly), stale_detection (daily)"
|
"osint_enrichment (weekly), stale_detection (daily), "
|
||||||
|
"retention_policies (daily)"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
"""Data retention policies — scheduled cleanup of aged records."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.database import SessionLocal
|
||||||
|
from app.models.audit import AuditLog
|
||||||
|
from app.services.notification_service import cleanup_old_notifications
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
AUDIT_LOG_RETENTION_DAYS = 730
|
||||||
|
|
||||||
|
|
||||||
|
def apply_retention_policies(db: Session) -> dict[str, int]:
|
||||||
|
"""Apply retention rules. Commits the session before returning."""
|
||||||
|
cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
|
||||||
|
deleted_audit = (
|
||||||
|
db.query(AuditLog)
|
||||||
|
.filter(AuditLog.timestamp < cutoff)
|
||||||
|
.delete(synchronize_session=False)
|
||||||
|
)
|
||||||
|
if deleted_audit:
|
||||||
|
logger.info(
|
||||||
|
"Retention: deleted %d audit logs older than %d days",
|
||||||
|
deleted_audit,
|
||||||
|
AUDIT_LOG_RETENTION_DAYS,
|
||||||
|
)
|
||||||
|
|
||||||
|
deleted_notifications = cleanup_old_notifications(db, days=90)
|
||||||
|
db.commit()
|
||||||
|
return {
|
||||||
|
"audit_logs_deleted": deleted_audit,
|
||||||
|
"notifications_deleted": deleted_notifications,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_retention_job() -> None:
|
||||||
|
"""Entry point for the daily retention scheduler job."""
|
||||||
|
logger.info("Scheduled retention job starting...")
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
summary = apply_retention_policies(db)
|
||||||
|
logger.info("Retention job finished — %s", summary)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Retention job failed")
|
||||||
|
db.rollback()
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Shared SlowAPI rate limiter for all routers."""
|
||||||
|
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
+19
-5
@@ -6,9 +6,8 @@ from fastapi import FastAPI, Request, status
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
from slowapi import _rate_limit_exceeded_handler
|
||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
from slowapi.util import get_remote_address
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from app.routers import auth as auth_router
|
from app.routers import auth as auth_router
|
||||||
@@ -40,6 +39,8 @@ from app.routers import advanced_metrics as advanced_metrics_router
|
|||||||
from app.routers import osint as osint_router
|
from app.routers import osint as osint_router
|
||||||
from app.domain.errors import DomainError
|
from app.domain.errors import DomainError
|
||||||
from app.middleware.error_handler import domain_exception_handler
|
from app.middleware.error_handler import domain_exception_handler
|
||||||
|
from app.middleware.request_context import RequestContextMiddleware
|
||||||
|
from app.limiter import limiter
|
||||||
from app.storage import ensure_bucket_exists
|
from app.storage import ensure_bucket_exists
|
||||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||||
|
|
||||||
@@ -71,10 +72,11 @@ app = FastAPI(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ── Rate Limiter ──────────────────────────────────────────────────────────
|
# ── Rate Limiter ──────────────────────────────────────────────────────────
|
||||||
limiter = Limiter(key_func=get_remote_address)
|
|
||||||
app.state.limiter = limiter
|
app.state.limiter = limiter
|
||||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
|
||||||
|
app.add_middleware(RequestContextMiddleware)
|
||||||
|
|
||||||
# ── Domain exception → HTTP mapping ──────────────────────────────────────
|
# ── Domain exception → HTTP mapping ──────────────────────────────────────
|
||||||
app.add_exception_handler(DomainError, domain_exception_handler)
|
app.add_exception_handler(DomainError, domain_exception_handler)
|
||||||
|
|
||||||
@@ -136,15 +138,27 @@ def health():
|
|||||||
# ── Exception Handlers ────────────────────────────────────────────────────
|
# ── Exception Handlers ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
|
||||||
|
"""Return validation errors safe for JSON (no raw exception objects)."""
|
||||||
|
serialized: list[dict] = []
|
||||||
|
for err in exc.errors():
|
||||||
|
item = dict(err)
|
||||||
|
ctx = item.get("ctx")
|
||||||
|
if isinstance(ctx, dict):
|
||||||
|
item["ctx"] = {key: str(value) for key, value in ctx.items()}
|
||||||
|
serialized.append(item)
|
||||||
|
return serialized
|
||||||
|
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
"""Handle validation errors with consistent format."""
|
"""Handle validation errors with consistent format."""
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
content={
|
content={
|
||||||
"detail": "Validation error",
|
"detail": "Validation error",
|
||||||
"code": "VALIDATION_ERROR",
|
"code": "VALIDATION_ERROR",
|
||||||
"errors": exc.errors(),
|
"errors": _serialize_validation_errors(exc),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,26 @@
|
|||||||
|
"""Request context middleware — captures client IP and User-Agent per request."""
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
request_ip: ContextVar[str] = ContextVar("request_ip", default="")
|
||||||
|
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_client_ip(request: Request) -> str:
|
||||||
|
"""Extract the client IP, honouring ``X-Forwarded-For`` when present."""
|
||||||
|
forwarded = request.headers.get("X-Forwarded-For")
|
||||||
|
if forwarded:
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
if request.client:
|
||||||
|
return request.client.host
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class RequestContextMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
request_ip.set(resolve_client_ip(request))
|
||||||
|
request_user_agent.set(request.headers.get("User-Agent", ""))
|
||||||
|
return await call_next(request)
|
||||||
@@ -22,6 +22,10 @@ class AuditLog(Base):
|
|||||||
entity_id = Column(String, nullable=True)
|
entity_id = Column(String, nullable=True)
|
||||||
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
details = Column(JSONB, nullable=True)
|
details = Column(JSONB, nullable=True)
|
||||||
|
ip_address = Column(String(45), nullable=True)
|
||||||
|
user_agent = Column(String(500), nullable=True)
|
||||||
|
integrity_hash = Column(String(64), nullable=True)
|
||||||
|
session_id = Column(String(100), nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
user = relationship("User")
|
user = relationship("User")
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ class Campaign(Base):
|
|||||||
target_platform = Column(String, nullable=True)
|
target_platform = Column(String, nullable=True)
|
||||||
tags = Column(JSONB, nullable=True, default=[])
|
tags = Column(JSONB, nullable=True, default=[])
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||||
|
|
||||||
# Recurring scheduling fields
|
# Recurring scheduling fields
|
||||||
is_recurring = Column(Boolean, default=False)
|
is_recurring = Column(Boolean, default=False)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ working with ``from app.models.enums import ...``.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from app.domain.enums import ( # noqa: F401
|
from app.domain.enums import ( # noqa: F401
|
||||||
|
DataClassification,
|
||||||
TeamSide,
|
TeamSide,
|
||||||
TechniqueStatus,
|
TechniqueStatus,
|
||||||
TestResult,
|
TestResult,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ class Evidence(Base):
|
|||||||
uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
|
uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
|
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
|
||||||
notes = Column(Text, nullable=True)
|
notes = Column(Text, nullable=True)
|
||||||
|
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
test = relationship("Test", back_populates="evidences")
|
test = relationship("Test", back_populates="evidences")
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ class Test(Base):
|
|||||||
# ── Re-test fields ────────────────────────────────────────────
|
# ── Re-test fields ────────────────────────────────────────────
|
||||||
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
|
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
|
||||||
retest_count = Column(Integer, default=0)
|
retest_count = Column(Integer, default=0)
|
||||||
|
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||||
|
|
||||||
# ── Relationships ───────────────────────────────────────────────
|
# ── Relationships ───────────────────────────────────────────────
|
||||||
technique = relationship("Technique", back_populates="tests")
|
technique = relationship("Technique", back_populates="tests")
|
||||||
|
|||||||
+51
-56
@@ -11,39 +11,33 @@ import os
|
|||||||
|
|
||||||
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from slowapi import Limiter
|
|
||||||
from slowapi.util import get_remote_address
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from jose import jwt, JWTError
|
from jose import jwt, JWTError
|
||||||
|
|
||||||
from app.auth import create_access_token, blacklist_token
|
from app.auth import create_access_token, blacklist_token, verify_password
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
|
from app.limiter import limiter
|
||||||
|
from app.middleware.request_context import resolve_client_ip
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.auth_service import authenticate_user, change_password as auth_change_password
|
from app.services.auth_service import (
|
||||||
|
_DUMMY_HASH,
|
||||||
|
change_password as auth_change_password,
|
||||||
|
)
|
||||||
|
from app.services.audit_service import log_action
|
||||||
from app.schemas.auth import TokenResponse, UserOut
|
from app.schemas.auth import TokenResponse, UserOut
|
||||||
from app.schemas.user import PasswordChange
|
from app.schemas.user import PasswordChange
|
||||||
|
|
||||||
# Rate limiter instance (shares backend state via app.state.limiter)
|
|
||||||
limiter = Limiter(key_func=get_remote_address)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
# Detect whether we're behind HTTPS (production) so the cookie can be Secure
|
|
||||||
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||||
|
|
||||||
# Cookie name used to transport the JWT
|
|
||||||
_COOKIE_NAME = "aegis_token"
|
_COOKIE_NAME = "aegis_token"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# POST /auth/login
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=TokenResponse)
|
@router.post("/login", response_model=TokenResponse)
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
def login(
|
def login(
|
||||||
@@ -54,19 +48,49 @@ def login(
|
|||||||
):
|
):
|
||||||
"""Authenticate a user and return a JWT access token.
|
"""Authenticate a user and return a JWT access token.
|
||||||
|
|
||||||
Rate-limited to **5 attempts per minute per IP** to prevent brute-force
|
Rate-limited to **5 attempts per minute per IP**. Failed and successful
|
||||||
attacks. The token is set as an HttpOnly cookie **and** returned in the
|
logins are recorded in the audit log (SEC-009).
|
||||||
JSON body for API/Swagger compatibility.
|
|
||||||
"""
|
"""
|
||||||
user = authenticate_user(
|
user = db.query(User).filter(User.username == form_data.username).first()
|
||||||
db,
|
target_hash = user.hashed_password if user else _DUMMY_HASH
|
||||||
username=form_data.username,
|
password_valid = verify_password(form_data.password, target_hash)
|
||||||
password=form_data.password,
|
ip = resolve_client_ip(request)
|
||||||
)
|
|
||||||
|
if user is None or not password_valid:
|
||||||
|
with UnitOfWork(db) as uow:
|
||||||
|
log_action(
|
||||||
|
db,
|
||||||
|
user.id if user else None,
|
||||||
|
"LOGIN_FAILED",
|
||||||
|
"auth",
|
||||||
|
None,
|
||||||
|
details={
|
||||||
|
"username": form_data.username,
|
||||||
|
"ip": ip,
|
||||||
|
"reason": "invalid_credentials",
|
||||||
|
},
|
||||||
|
ip_address=ip,
|
||||||
|
)
|
||||||
|
uow.commit()
|
||||||
|
raise BusinessRuleViolation("Incorrect username or password")
|
||||||
|
|
||||||
|
if not user.is_active:
|
||||||
|
raise PermissionViolation("Account is disabled. Contact an administrator.")
|
||||||
|
|
||||||
access_token = create_access_token(data={"sub": user.username})
|
access_token = create_access_token(data={"sub": user.username})
|
||||||
|
|
||||||
# Set HttpOnly cookie — inaccessible from JS
|
with UnitOfWork(db) as uow:
|
||||||
|
log_action(
|
||||||
|
db,
|
||||||
|
user.id,
|
||||||
|
"LOGIN_SUCCESS",
|
||||||
|
"auth",
|
||||||
|
str(user.id),
|
||||||
|
details={"username": user.username, "ip": ip},
|
||||||
|
ip_address=ip,
|
||||||
|
)
|
||||||
|
uow.commit()
|
||||||
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key=_COOKIE_NAME,
|
key=_COOKIE_NAME,
|
||||||
value=access_token,
|
value=access_token,
|
||||||
@@ -80,27 +104,13 @@ def login(
|
|||||||
return TokenResponse(access_token=access_token)
|
return TokenResponse(access_token=access_token)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# POST /auth/logout
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logout")
|
@router.post("/logout")
|
||||||
def logout(
|
def logout(
|
||||||
request: Request,
|
request: Request,
|
||||||
response: Response,
|
response: Response,
|
||||||
aegis_token: str | None = Cookie(None),
|
aegis_token: str | None = Cookie(None),
|
||||||
):
|
):
|
||||||
"""Clear the authentication cookie and revoke the current token.
|
"""Clear the authentication cookie and revoke the current token."""
|
||||||
|
|
||||||
The token's ``jti`` is added to the Redis blacklist so it cannot
|
|
||||||
be reused even if the cookie has already been copied elsewhere.
|
|
||||||
The blacklist entry auto-expires when the token's ``exp`` is reached.
|
|
||||||
|
|
||||||
When both HttpOnly cookie and ``Authorization: Bearer`` are present
|
|
||||||
(typical for API clients), **both** are revoked so the session cannot
|
|
||||||
survive on whichever credential the next request prefers.
|
|
||||||
"""
|
|
||||||
bearer = (
|
bearer = (
|
||||||
request.headers.get("Authorization")
|
request.headers.get("Authorization")
|
||||||
or request.headers.get("authorization")
|
or request.headers.get("authorization")
|
||||||
@@ -124,7 +134,7 @@ def logout(
|
|||||||
if jti:
|
if jti:
|
||||||
blacklist_token(jti, float(exp))
|
blacklist_token(jti, float(exp))
|
||||||
except JWTError:
|
except JWTError:
|
||||||
pass # token already invalid — nothing to revoke for this raw value
|
pass
|
||||||
|
|
||||||
response.delete_cookie(
|
response.delete_cookie(
|
||||||
key=_COOKIE_NAME,
|
key=_COOKIE_NAME,
|
||||||
@@ -136,34 +146,19 @@ def logout(
|
|||||||
return {"detail": "Logged out"}
|
return {"detail": "Logged out"}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# GET /auth/me
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserOut)
|
@router.get("/me", response_model=UserOut)
|
||||||
def read_current_user(current_user: User = Depends(get_current_user)):
|
def read_current_user(current_user: User = Depends(get_current_user)):
|
||||||
"""Return the profile of the currently authenticated user."""
|
"""Return the profile of the currently authenticated user."""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# POST /auth/change-password
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/change-password")
|
@router.post("/change-password")
|
||||||
def change_password(
|
def change_password(
|
||||||
body: PasswordChange,
|
body: PasswordChange,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Change the current user's password.
|
"""Change the current user's password."""
|
||||||
|
|
||||||
Requires the current password for verification. On success the
|
|
||||||
``must_change_password`` flag is cleared so the user can proceed
|
|
||||||
normally.
|
|
||||||
"""
|
|
||||||
auth_change_password(
|
auth_change_password(
|
||||||
db,
|
db,
|
||||||
current_user,
|
current_user,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import os
|
|||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, status
|
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
@@ -44,6 +44,7 @@ from app.services.evidence_service import (
|
|||||||
validate_file,
|
validate_file,
|
||||||
validate_upload_permission,
|
validate_upload_permission,
|
||||||
)
|
)
|
||||||
|
from app.limiter import limiter
|
||||||
from app.storage import get_presigned_url, upload_file
|
from app.storage import get_presigned_url, upload_file
|
||||||
|
|
||||||
router = APIRouter(tags=["evidence"])
|
router = APIRouter(tags=["evidence"])
|
||||||
@@ -78,7 +79,9 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
|||||||
response_model=EvidenceOut,
|
response_model=EvidenceOut,
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
)
|
)
|
||||||
|
@limiter.limit("10/minute")
|
||||||
async def upload_evidence(
|
async def upload_evidence(
|
||||||
|
request: Request,
|
||||||
test_id: _uuid.UUID,
|
test_id: _uuid.UUID,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
team: TeamSide = Form(TeamSide.red),
|
team: TeamSide = Form(TeamSide.red),
|
||||||
|
|||||||
@@ -2,13 +2,14 @@
|
|||||||
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.limiter import limiter
|
||||||
from app.services import report_generation_service
|
from app.services import report_generation_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
||||||
@@ -21,7 +22,9 @@ _MEDIA_TYPES = {
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/purple-campaign/{campaign_id}")
|
@router.get("/purple-campaign/{campaign_id}")
|
||||||
|
@limiter.limit("5/minute")
|
||||||
def generate_purple_report(
|
def generate_purple_report(
|
||||||
|
request: Request,
|
||||||
campaign_id: UUID,
|
campaign_id: UUID,
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -39,7 +42,9 @@ def generate_purple_report(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/coverage-summary")
|
@router.get("/coverage-summary")
|
||||||
|
@limiter.limit("5/minute")
|
||||||
def generate_coverage_report(
|
def generate_coverage_report(
|
||||||
|
request: Request,
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||||
@@ -56,7 +61,9 @@ def generate_coverage_report(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/executive-summary")
|
@router.get("/executive-summary")
|
||||||
|
@limiter.limit("5/minute")
|
||||||
def generate_executive_report(
|
def generate_executive_report(
|
||||||
|
request: Request,
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||||
@@ -73,7 +80,9 @@ def generate_executive_report(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/quarterly-summary")
|
@router.get("/quarterly-summary")
|
||||||
|
@limiter.limit("5/minute")
|
||||||
def generate_quarterly_report(
|
def generate_quarterly_report(
|
||||||
|
request: Request,
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||||
@@ -90,7 +99,9 @@ def generate_quarterly_report(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/technique/{technique_id}")
|
@router.get("/technique/{technique_id}")
|
||||||
|
@limiter.limit("5/minute")
|
||||||
def generate_technique_report(
|
def generate_technique_report(
|
||||||
|
request: Request,
|
||||||
technique_id: UUID,
|
technique_id: UUID,
|
||||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ scheduler health introspection.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
@@ -17,6 +17,7 @@ from app.services.mitre_sync_service import sync_mitre
|
|||||||
from app.services.intel_service import scan_intel
|
from app.services.intel_service import scan_intel
|
||||||
from app.services.atomic_import_service import import_atomic_red_team
|
from app.services.atomic_import_service import import_atomic_red_team
|
||||||
from app.jobs.mitre_sync_job import scheduler
|
from app.jobs.mitre_sync_job import scheduler
|
||||||
|
from app.limiter import limiter
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -24,7 +25,9 @@ router = APIRouter(prefix="/system", tags=["system"])
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/sync-mitre")
|
@router.post("/sync-mitre")
|
||||||
|
@limiter.limit("2/hour")
|
||||||
def trigger_mitre_sync(
|
def trigger_mitre_sync(
|
||||||
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
@@ -63,7 +66,9 @@ def trigger_intel_scan(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/import-atomic-tests")
|
@router.post("/import-atomic-tests")
|
||||||
|
@limiter.limit("2/hour")
|
||||||
def trigger_atomic_import(
|
def trigger_atomic_import(
|
||||||
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -21,11 +21,13 @@ GET /tests/{id}/timeline — audit-log history for this test
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||||
|
from app.domain.enums import DataClassification
|
||||||
|
from app.limiter import limiter
|
||||||
from app.models.enums import TestState
|
from app.models.enums import TestState
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.test import (
|
from app.schemas.test import (
|
||||||
@@ -37,6 +39,7 @@ from app.schemas.test import (
|
|||||||
TestRedValidate,
|
TestRedValidate,
|
||||||
TestBlueValidate,
|
TestBlueValidate,
|
||||||
TestRemediationUpdate,
|
TestRemediationUpdate,
|
||||||
|
TestClassificationUpdate,
|
||||||
)
|
)
|
||||||
from app.schemas.test_template import TestTemplateInstantiate
|
from app.schemas.test_template import TestTemplateInstantiate
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
@@ -112,7 +115,9 @@ def list_tests(
|
|||||||
response_model=TestOut,
|
response_model=TestOut,
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
)
|
)
|
||||||
|
@limiter.limit("30/minute")
|
||||||
def create_test(
|
def create_test(
|
||||||
|
request: Request,
|
||||||
payload: TestCreate,
|
payload: TestCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
@@ -152,7 +157,9 @@ def create_test(
|
|||||||
response_model=TestOut,
|
response_model=TestOut,
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
)
|
)
|
||||||
|
@limiter.limit("30/minute")
|
||||||
def create_test_from_template(
|
def create_test_from_template(
|
||||||
|
request: Request,
|
||||||
payload: TestTemplateInstantiate,
|
payload: TestTemplateInstantiate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
@@ -241,6 +248,36 @@ def update_test(
|
|||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PATCH /tests/{id}/classification — admin data classification
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{test_id}/classification", response_model=TestOut)
|
||||||
|
def update_test_classification(
|
||||||
|
test_id: uuid.UUID,
|
||||||
|
payload: TestClassificationUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_role("admin")),
|
||||||
|
):
|
||||||
|
"""Update the data classification label for a test (admin only)."""
|
||||||
|
with UnitOfWork(db) as uow:
|
||||||
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
|
test.data_classification = payload.data_classification.value
|
||||||
|
db.flush()
|
||||||
|
log_action(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
action="update_test_classification",
|
||||||
|
entity_type="test",
|
||||||
|
entity_id=test.id,
|
||||||
|
details={"data_classification": payload.data_classification.value},
|
||||||
|
)
|
||||||
|
uow.commit()
|
||||||
|
db.refresh(test)
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# PATCH /tests/{id}/red — Red Team update (draft, red_executing)
|
# PATCH /tests/{id}/red — Red Team update (draft, red_executing)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from app.domain.enums import DataClassification
|
||||||
from app.models.enums import TestResult, TestState
|
from app.models.enums import TestResult, TestState
|
||||||
|
|
||||||
|
|
||||||
@@ -25,6 +26,12 @@ class TestCreate(BaseModel):
|
|||||||
# ── Update (general) ───────────────────────────────────────────────
|
# ── Update (general) ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestClassificationUpdate(BaseModel):
|
||||||
|
"""Admin-only payload for changing data classification."""
|
||||||
|
|
||||||
|
data_classification: DataClassification
|
||||||
|
|
||||||
|
|
||||||
class TestUpdate(BaseModel):
|
class TestUpdate(BaseModel):
|
||||||
"""Payload for partially updating an existing test.
|
"""Payload for partially updating an existing test.
|
||||||
Every field is optional so callers send only what changed."""
|
Every field is optional so callers send only what changed."""
|
||||||
@@ -152,6 +159,7 @@ class TestOut(BaseModel):
|
|||||||
# Re-test fields
|
# Re-test fields
|
||||||
retest_of: uuid.UUID | None = None
|
retest_of: uuid.UUID | None = None
|
||||||
retest_count: int = 0
|
retest_count: int = 0
|
||||||
|
data_classification: str = "internal"
|
||||||
|
|
||||||
# Technique info (populated when joined)
|
# Technique info (populated when joined)
|
||||||
technique_mitre_id: str | None = None
|
technique_mitre_id: str | None = None
|
||||||
|
|||||||
@@ -1,32 +1,65 @@
|
|||||||
|
"""Audit logging with request context and integrity hashing."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.middleware.request_context import request_ip, request_user_agent
|
||||||
from app.models.audit import AuditLog
|
from app.models.audit import AuditLog
|
||||||
|
|
||||||
|
|
||||||
|
def _integrity_payload(entry: AuditLog) -> str:
|
||||||
|
ts = entry.timestamp
|
||||||
|
if ts is None:
|
||||||
|
ts = datetime.now(timezone.utc)
|
||||||
|
user_part = str(entry.user_id) if entry.user_id else ""
|
||||||
|
entity_type = entry.entity_type or ""
|
||||||
|
entity_id = entry.entity_id or ""
|
||||||
|
return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoformat()}"
|
||||||
|
|
||||||
|
|
||||||
|
def compute_integrity_hash(entry: AuditLog) -> str:
|
||||||
|
"""Return the SHA-256 hex digest for an audit log entry."""
|
||||||
|
return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_audit_integrity(entry: AuditLog) -> bool:
|
||||||
|
"""Return whether the stored hash matches the entry's current fields."""
|
||||||
|
if not entry.integrity_hash:
|
||||||
|
return False
|
||||||
|
return entry.integrity_hash == compute_integrity_hash(entry)
|
||||||
|
|
||||||
|
|
||||||
def log_action(
|
def log_action(
|
||||||
db: Session,
|
db: Session,
|
||||||
user_id,
|
user_id,
|
||||||
action: str,
|
action: str,
|
||||||
entity_type: str = None,
|
entity_type: str | None = None,
|
||||||
entity_id: str = None,
|
entity_id: str | None = None,
|
||||||
details: dict = None
|
details: dict | None = None,
|
||||||
):
|
*,
|
||||||
"""
|
ip_address: str | None = None,
|
||||||
Log an action to the audit log.
|
user_agent: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> AuditLog:
|
||||||
|
"""Record an audit event. Does not commit — the caller owns the transaction."""
|
||||||
|
ip = ip_address if ip_address is not None else request_ip.get("")
|
||||||
|
ua = user_agent if user_agent is not None else request_user_agent.get("")
|
||||||
|
|
||||||
Args:
|
entry = AuditLog(
|
||||||
db: Database session
|
|
||||||
user_id: UUID of the user performing the action (can be None for system actions)
|
|
||||||
action: Description of the action (e.g., "create_test", "validate_technique")
|
|
||||||
entity_type: Type of entity affected (e.g., "technique", "test", "user")
|
|
||||||
entity_id: ID of the entity affected
|
|
||||||
details: Additional details as a dictionary (stored as JSONB)
|
|
||||||
"""
|
|
||||||
log = AuditLog(
|
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
action=action,
|
action=action,
|
||||||
entity_type=entity_type,
|
entity_type=entity_type,
|
||||||
entity_id=str(entity_id) if entity_id else None,
|
entity_id=str(entity_id) if entity_id else None,
|
||||||
details=details,
|
details=details,
|
||||||
|
ip_address=ip or None,
|
||||||
|
user_agent=ua or None,
|
||||||
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
db.add(log)
|
db.add(entry)
|
||||||
|
db.flush()
|
||||||
|
entry.integrity_hash = compute_integrity_hash(entry)
|
||||||
|
return entry
|
||||||
|
|||||||
@@ -121,10 +121,8 @@ def client(db, monkeypatch):
|
|||||||
app.dependency_overrides[get_db] = override_get_db
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
if hasattr(app.state, "limiter"):
|
from app.limiter import limiter
|
||||||
app.state.limiter.enabled = False
|
limiter.enabled = False
|
||||||
from app.routers.auth import limiter as auth_limiter
|
|
||||||
auth_limiter.enabled = False
|
|
||||||
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
with TestClient(app) as test_client:
|
with TestClient(app) as test_client:
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
"""Tests for enhanced audit trail (IP, user-agent, integrity hash)."""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.middleware.request_context import request_ip, request_user_agent
|
||||||
|
from app.models.audit import AuditLog
|
||||||
|
from app.services.audit_service import (
|
||||||
|
compute_integrity_hash,
|
||||||
|
log_action,
|
||||||
|
verify_audit_integrity,
|
||||||
|
)
|
||||||
|
from app.jobs.retention_job import apply_retention_policies
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditIntegrity:
|
||||||
|
def test_integrity_hash_set_on_log(self, db):
|
||||||
|
entry = log_action(
|
||||||
|
db,
|
||||||
|
user_id=None,
|
||||||
|
action="test_action",
|
||||||
|
entity_type="test",
|
||||||
|
entity_id="abc",
|
||||||
|
ip_address="10.0.0.1",
|
||||||
|
user_agent="pytest",
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(entry)
|
||||||
|
assert entry.integrity_hash
|
||||||
|
assert len(entry.integrity_hash) == 64
|
||||||
|
assert verify_audit_integrity(entry)
|
||||||
|
|
||||||
|
def test_tampered_entry_fails_integrity(self, db):
|
||||||
|
entry = log_action(
|
||||||
|
db,
|
||||||
|
user_id=None,
|
||||||
|
action="test_action",
|
||||||
|
entity_type="test",
|
||||||
|
entity_id="abc",
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
entry.entity_id = "tampered"
|
||||||
|
assert not verify_audit_integrity(entry)
|
||||||
|
|
||||||
|
def test_recomputed_hash_matches_stored(self, db):
|
||||||
|
entry = log_action(db, None, "update", "user", "1")
|
||||||
|
db.commit()
|
||||||
|
assert entry.integrity_hash == compute_integrity_hash(entry)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestContext:
|
||||||
|
def test_context_vars_used_by_log_action(self, db):
|
||||||
|
token_ip = request_ip.set("203.0.113.42")
|
||||||
|
token_ua = request_user_agent.set("AegisTestClient/1.0")
|
||||||
|
try:
|
||||||
|
entry = log_action(db, None, "ctx_action", "system", None)
|
||||||
|
db.commit()
|
||||||
|
assert entry.ip_address == "203.0.113.42"
|
||||||
|
assert entry.user_agent == "AegisTestClient/1.0"
|
||||||
|
finally:
|
||||||
|
request_ip.reset(token_ip)
|
||||||
|
request_user_agent.reset(token_ua)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetentionJob:
|
||||||
|
def test_deletes_old_audit_logs(self, db):
|
||||||
|
old = AuditLog(
|
||||||
|
action="old",
|
||||||
|
entity_type="system",
|
||||||
|
timestamp=datetime.now(timezone.utc) - timedelta(days=800),
|
||||||
|
)
|
||||||
|
recent = AuditLog(
|
||||||
|
action="recent",
|
||||||
|
entity_type="system",
|
||||||
|
timestamp=datetime.now(timezone.utc) - timedelta(days=1),
|
||||||
|
)
|
||||||
|
db.add_all([old, recent])
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
summary = apply_retention_policies(db)
|
||||||
|
assert summary["audit_logs_deleted"] >= 1
|
||||||
|
remaining = db.query(AuditLog).all()
|
||||||
|
assert all(log.action != "old" for log in remaining)
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
"""Tests for login attempt auditing (SEC-009)."""
|
||||||
|
|
||||||
|
from app.models.audit import AuditLog
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_failed_creates_audit_entry(client, admin_user, db):
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data={"username": "admin", "password": "wrong"},
|
||||||
|
headers={"X-Forwarded-For": "198.51.100.10", "User-Agent": "LoginAuditTest/1.0"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
log = (
|
||||||
|
db.query(AuditLog)
|
||||||
|
.filter(AuditLog.action == "LOGIN_FAILED")
|
||||||
|
.order_by(AuditLog.timestamp.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
assert log is not None
|
||||||
|
assert log.entity_type == "auth"
|
||||||
|
assert log.details["username"] == "admin"
|
||||||
|
assert log.details["reason"] == "invalid_credentials"
|
||||||
|
assert log.ip_address == "198.51.100.10"
|
||||||
|
assert log.user_agent == "LoginAuditTest/1.0"
|
||||||
|
assert log.integrity_hash
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_success_creates_audit_entry(client, admin_user, db):
|
||||||
|
client.cookies.clear()
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data={"username": "admin", "password": "admin123"},
|
||||||
|
headers={"X-Forwarded-For": "198.51.100.20"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
log = (
|
||||||
|
db.query(AuditLog)
|
||||||
|
.filter(AuditLog.action == "LOGIN_SUCCESS")
|
||||||
|
.order_by(AuditLog.timestamp.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
assert log is not None
|
||||||
|
assert log.user_id == admin_user.id
|
||||||
|
assert log.ip_address == "198.51.100.20"
|
||||||
|
assert log.integrity_hash
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_unknown_user_still_audited(client, db):
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data={"username": "nobody", "password": "password"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
log = db.query(AuditLog).filter(AuditLog.action == "LOGIN_FAILED").first()
|
||||||
|
assert log is not None
|
||||||
|
assert log.user_id is None
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
"""Tests for data classification fields and admin updates."""
|
||||||
|
|
||||||
|
from app.models.enums import TestState
|
||||||
|
from app.models.test import Test
|
||||||
|
from app.models.technique import Technique
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_technique(db) -> Technique:
|
||||||
|
technique = Technique(
|
||||||
|
mitre_id="T9999",
|
||||||
|
name="Test Technique",
|
||||||
|
tactic="test",
|
||||||
|
platforms=["linux"],
|
||||||
|
)
|
||||||
|
db.add(technique)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(technique)
|
||||||
|
return technique
|
||||||
|
|
||||||
|
|
||||||
|
def test_new_test_defaults_to_internal(db, red_lead_user):
|
||||||
|
technique = _seed_technique(db)
|
||||||
|
test = Test(
|
||||||
|
technique_id=technique.id,
|
||||||
|
name="Classification test",
|
||||||
|
created_by=red_lead_user.id,
|
||||||
|
)
|
||||||
|
db.add(test)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(test)
|
||||||
|
assert test.data_classification == "internal"
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_can_update_classification(client, db, admin_user, admin_token, red_lead_user):
|
||||||
|
technique = _seed_technique(db)
|
||||||
|
test = Test(
|
||||||
|
technique_id=technique.id,
|
||||||
|
name="Classify me",
|
||||||
|
created_by=red_lead_user.id,
|
||||||
|
state=TestState.draft,
|
||||||
|
)
|
||||||
|
db.add(test)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
response = client.patch(
|
||||||
|
f"/api/v1/tests/{test.id}/classification",
|
||||||
|
json={"data_classification": "sensitive"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["data_classification"] == "sensitive"
|
||||||
|
|
||||||
|
db.refresh(test)
|
||||||
|
assert test.data_classification == "sensitive"
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_admin_cannot_update_classification(client, db, admin_user, red_lead_token, red_lead_user):
|
||||||
|
technique = _seed_technique(db)
|
||||||
|
test = Test(
|
||||||
|
technique_id=technique.id,
|
||||||
|
name="Protected",
|
||||||
|
created_by=red_lead_user.id,
|
||||||
|
)
|
||||||
|
db.add(test)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
response = client.patch(
|
||||||
|
f"/api/v1/tests/{test.id}/classification",
|
||||||
|
json={"data_classification": "restricted"},
|
||||||
|
headers={"Authorization": f"Bearer {red_lead_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
"""Smoke tests for extended rate-limit decorators (SEC-003)."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from app.routers import evidence, professional_reports, system, tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_mitre_has_hourly_limit():
|
||||||
|
source = inspect.getsource(system.trigger_mitre_sync)
|
||||||
|
assert "2/hour" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_test_has_per_minute_limit():
|
||||||
|
source = inspect.getsource(tests.create_test)
|
||||||
|
assert "30/minute" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_evidence_has_per_minute_limit():
|
||||||
|
source = inspect.getsource(evidence.upload_evidence)
|
||||||
|
assert "10/minute" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_report_endpoints_have_per_minute_limit():
|
||||||
|
source = inspect.getsource(professional_reports.generate_coverage_report)
|
||||||
|
assert "5/minute" in source
|
||||||
@@ -42,6 +42,14 @@ class TestUsernameValidation:
|
|||||||
with pytest.raises(ValidationError, match="3-50 characters"):
|
with pytest.raises(ValidationError, match="3-50 characters"):
|
||||||
UserCreate(username="john@doe", password="SecurePass123!@#")
|
UserCreate(username="john@doe", password="SecurePass123!@#")
|
||||||
|
|
||||||
|
def test_reserved_username_system(self):
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
UserCreate(username="system", password="SecurePass123!@#")
|
||||||
|
|
||||||
|
def test_invalid_username_path_chars(self):
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
UserCreate(username="../admin", password="SecurePass123!@#")
|
||||||
|
|
||||||
def test_reserved_username_admin(self):
|
def test_reserved_username_admin(self):
|
||||||
with pytest.raises(ValidationError, match="reserved"):
|
with pytest.raises(ValidationError, match="reserved"):
|
||||||
UserCreate(username="admin", password="SecurePass123!@#")
|
UserCreate(username="admin", password="SecurePass123!@#")
|
||||||
|
|||||||
@@ -0,0 +1,59 @@
|
|||||||
|
"""API-level validation tests for user creation (SEC-004, SEC-007)."""
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_user_weak_password_rejected(client, admin_user, admin_token):
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/users",
|
||||||
|
json={
|
||||||
|
"username": "newuser",
|
||||||
|
"password": "123",
|
||||||
|
"email": "new@test.com",
|
||||||
|
"role": "viewer",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {admin_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
|
assert "password" in response.text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_user_reserved_username(client, admin_user, admin_token):
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/users",
|
||||||
|
json={
|
||||||
|
"username": "system",
|
||||||
|
"password": "SecurePass123!@#",
|
||||||
|
"email": "sys@test.com",
|
||||||
|
"role": "viewer",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {admin_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_user_invalid_username_chars(client, admin_user, admin_token):
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/users",
|
||||||
|
json={
|
||||||
|
"username": "../admin",
|
||||||
|
"password": "SecurePass123!@#",
|
||||||
|
"email": "bad@test.com",
|
||||||
|
"role": "viewer",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {admin_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_user_valid_password_accepted(client, admin_user, admin_token):
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/users",
|
||||||
|
json={
|
||||||
|
"username": "validuser99",
|
||||||
|
"password": "ValidPass123!@#",
|
||||||
|
"email": "valid@test.com",
|
||||||
|
"role": "viewer",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {admin_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 201
|
||||||
|
assert response.json()["username"] == "validuser99"
|
||||||
Reference in New Issue
Block a user