feat(audit): enhanced audit trail with IP, user-agent and integrity hash [FASE-3.1]
This commit is contained in:
26
backend/app/middleware/request_context.py
Normal file
26
backend/app/middleware/request_context.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Request context middleware — captures client IP and User-Agent per request."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
request_ip: ContextVar[str] = ContextVar("request_ip", default="")
|
||||
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
|
||||
|
||||
|
||||
def resolve_client_ip(request: Request) -> str:
|
||||
"""Extract the client IP, honouring ``X-Forwarded-For`` when present."""
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
|
||||
class RequestContextMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
request_ip.set(resolve_client_ip(request))
|
||||
request_user_agent.set(request.headers.get("User-Agent", ""))
|
||||
return await call_next(request)
|
||||
@@ -22,6 +22,10 @@ class AuditLog(Base):
|
||||
entity_id = Column(String, nullable=True)
|
||||
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||
details = Column(JSONB, nullable=True)
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
user_agent = Column(String(500), nullable=True)
|
||||
integrity_hash = Column(String(64), nullable=True)
|
||||
session_id = Column(String(100), nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User")
|
||||
|
||||
@@ -1,32 +1,65 @@
|
||||
"""Audit logging with request context and integrity hashing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.middleware.request_context import request_ip, request_user_agent
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
|
||||
def _integrity_payload(entry: AuditLog) -> str:
|
||||
ts = entry.timestamp
|
||||
if ts is None:
|
||||
ts = datetime.now(timezone.utc)
|
||||
user_part = str(entry.user_id) if entry.user_id else ""
|
||||
entity_type = entry.entity_type or ""
|
||||
entity_id = entry.entity_id or ""
|
||||
return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoformat()}"
|
||||
|
||||
|
||||
def compute_integrity_hash(entry: AuditLog) -> str:
|
||||
"""Return the SHA-256 hex digest for an audit log entry."""
|
||||
return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest()
|
||||
|
||||
|
||||
def verify_audit_integrity(entry: AuditLog) -> bool:
|
||||
"""Return whether the stored hash matches the entry's current fields."""
|
||||
if not entry.integrity_hash:
|
||||
return False
|
||||
return entry.integrity_hash == compute_integrity_hash(entry)
|
||||
|
||||
|
||||
def log_action(
|
||||
db: Session,
|
||||
user_id,
|
||||
action: str,
|
||||
entity_type: str = None,
|
||||
entity_id: str = None,
|
||||
details: dict = None
|
||||
):
|
||||
"""
|
||||
Log an action to the audit log.
|
||||
entity_type: str | None = None,
|
||||
entity_id: str | None = None,
|
||||
details: dict | None = None,
|
||||
*,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> AuditLog:
|
||||
"""Record an audit event. Does not commit — the caller owns the transaction."""
|
||||
ip = ip_address if ip_address is not None else request_ip.get("")
|
||||
ua = user_agent if user_agent is not None else request_user_agent.get("")
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: UUID of the user performing the action (can be None for system actions)
|
||||
action: Description of the action (e.g., "create_test", "validate_technique")
|
||||
entity_type: Type of entity affected (e.g., "technique", "test", "user")
|
||||
entity_id: ID of the entity affected
|
||||
details: Additional details as a dictionary (stored as JSONB)
|
||||
"""
|
||||
log = AuditLog(
|
||||
entry = AuditLog(
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
entity_type=entity_type,
|
||||
entity_id=str(entity_id) if entity_id else None,
|
||||
details=details,
|
||||
ip_address=ip or None,
|
||||
user_agent=ua or None,
|
||||
session_id=session_id,
|
||||
)
|
||||
db.add(log)
|
||||
db.add(entry)
|
||||
db.flush()
|
||||
entry.integrity_hash = compute_integrity_hash(entry)
|
||||
return entry
|
||||
|
||||
84
backend/tests/test_audit_trail.py
Normal file
84
backend/tests/test_audit_trail.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Tests for enhanced audit trail (IP, user-agent, integrity hash)."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.middleware.request_context import request_ip, request_user_agent
|
||||
from app.models.audit import AuditLog
|
||||
from app.services.audit_service import (
|
||||
compute_integrity_hash,
|
||||
log_action,
|
||||
verify_audit_integrity,
|
||||
)
|
||||
from app.jobs.retention_job import apply_retention_policies
|
||||
|
||||
|
||||
class TestAuditIntegrity:
|
||||
def test_integrity_hash_set_on_log(self, db):
|
||||
entry = log_action(
|
||||
db,
|
||||
user_id=None,
|
||||
action="test_action",
|
||||
entity_type="test",
|
||||
entity_id="abc",
|
||||
ip_address="10.0.0.1",
|
||||
user_agent="pytest",
|
||||
)
|
||||
db.commit()
|
||||
db.refresh(entry)
|
||||
assert entry.integrity_hash
|
||||
assert len(entry.integrity_hash) == 64
|
||||
assert verify_audit_integrity(entry)
|
||||
|
||||
def test_tampered_entry_fails_integrity(self, db):
|
||||
entry = log_action(
|
||||
db,
|
||||
user_id=None,
|
||||
action="test_action",
|
||||
entity_type="test",
|
||||
entity_id="abc",
|
||||
)
|
||||
db.commit()
|
||||
entry.entity_id = "tampered"
|
||||
assert not verify_audit_integrity(entry)
|
||||
|
||||
def test_recomputed_hash_matches_stored(self, db):
|
||||
entry = log_action(db, None, "update", "user", "1")
|
||||
db.commit()
|
||||
assert entry.integrity_hash == compute_integrity_hash(entry)
|
||||
|
||||
|
||||
class TestRequestContext:
|
||||
def test_context_vars_used_by_log_action(self, db):
|
||||
token_ip = request_ip.set("203.0.113.42")
|
||||
token_ua = request_user_agent.set("AegisTestClient/1.0")
|
||||
try:
|
||||
entry = log_action(db, None, "ctx_action", "system", None)
|
||||
db.commit()
|
||||
assert entry.ip_address == "203.0.113.42"
|
||||
assert entry.user_agent == "AegisTestClient/1.0"
|
||||
finally:
|
||||
request_ip.reset(token_ip)
|
||||
request_user_agent.reset(token_ua)
|
||||
|
||||
|
||||
class TestRetentionJob:
|
||||
def test_deletes_old_audit_logs(self, db):
|
||||
old = AuditLog(
|
||||
action="old",
|
||||
entity_type="system",
|
||||
timestamp=datetime.now(timezone.utc) - timedelta(days=800),
|
||||
)
|
||||
recent = AuditLog(
|
||||
action="recent",
|
||||
entity_type="system",
|
||||
timestamp=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
)
|
||||
db.add_all([old, recent])
|
||||
db.commit()
|
||||
|
||||
summary = apply_retention_policies(db)
|
||||
assert summary["audit_logs_deleted"] >= 1
|
||||
remaining = db.query(AuditLog).all()
|
||||
assert all(log.action != "old" for log in remaining)
|
||||
Reference in New Issue
Block a user