feat(audit): enhanced audit trail with IP, user-agent and integrity hash [FASE-3.1]
This commit is contained in:
26
backend/app/middleware/request_context.py
Normal file
26
backend/app/middleware/request_context.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""Request context middleware — captures client IP and User-Agent per request."""
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
request_ip: ContextVar[str] = ContextVar("request_ip", default="")
|
||||||
|
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_client_ip(request: Request) -> str:
|
||||||
|
"""Extract the client IP, honouring ``X-Forwarded-For`` when present."""
|
||||||
|
forwarded = request.headers.get("X-Forwarded-For")
|
||||||
|
if forwarded:
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
if request.client:
|
||||||
|
return request.client.host
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class RequestContextMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
request_ip.set(resolve_client_ip(request))
|
||||||
|
request_user_agent.set(request.headers.get("User-Agent", ""))
|
||||||
|
return await call_next(request)
|
||||||
@@ -22,6 +22,10 @@ class AuditLog(Base):
|
|||||||
entity_id = Column(String, nullable=True)
|
entity_id = Column(String, nullable=True)
|
||||||
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
details = Column(JSONB, nullable=True)
|
details = Column(JSONB, nullable=True)
|
||||||
|
ip_address = Column(String(45), nullable=True)
|
||||||
|
user_agent = Column(String(500), nullable=True)
|
||||||
|
integrity_hash = Column(String(64), nullable=True)
|
||||||
|
session_id = Column(String(100), nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
user = relationship("User")
|
user = relationship("User")
|
||||||
|
|||||||
@@ -1,32 +1,65 @@
|
|||||||
|
"""Audit logging with request context and integrity hashing."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.middleware.request_context import request_ip, request_user_agent
|
||||||
from app.models.audit import AuditLog
|
from app.models.audit import AuditLog
|
||||||
|
|
||||||
|
|
||||||
|
def _integrity_payload(entry: AuditLog) -> str:
|
||||||
|
ts = entry.timestamp
|
||||||
|
if ts is None:
|
||||||
|
ts = datetime.now(timezone.utc)
|
||||||
|
user_part = str(entry.user_id) if entry.user_id else ""
|
||||||
|
entity_type = entry.entity_type or ""
|
||||||
|
entity_id = entry.entity_id or ""
|
||||||
|
return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoformat()}"
|
||||||
|
|
||||||
|
|
||||||
|
def compute_integrity_hash(entry: AuditLog) -> str:
|
||||||
|
"""Return the SHA-256 hex digest for an audit log entry."""
|
||||||
|
return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_audit_integrity(entry: AuditLog) -> bool:
|
||||||
|
"""Return whether the stored hash matches the entry's current fields."""
|
||||||
|
if not entry.integrity_hash:
|
||||||
|
return False
|
||||||
|
return entry.integrity_hash == compute_integrity_hash(entry)
|
||||||
|
|
||||||
|
|
||||||
def log_action(
|
def log_action(
|
||||||
db: Session,
|
db: Session,
|
||||||
user_id,
|
user_id,
|
||||||
action: str,
|
action: str,
|
||||||
entity_type: str = None,
|
entity_type: str | None = None,
|
||||||
entity_id: str = None,
|
entity_id: str | None = None,
|
||||||
details: dict = None
|
details: dict | None = None,
|
||||||
):
|
*,
|
||||||
"""
|
ip_address: str | None = None,
|
||||||
Log an action to the audit log.
|
user_agent: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> AuditLog:
|
||||||
|
"""Record an audit event. Does not commit — the caller owns the transaction."""
|
||||||
|
ip = ip_address if ip_address is not None else request_ip.get("")
|
||||||
|
ua = user_agent if user_agent is not None else request_user_agent.get("")
|
||||||
|
|
||||||
Args:
|
entry = AuditLog(
|
||||||
db: Database session
|
|
||||||
user_id: UUID of the user performing the action (can be None for system actions)
|
|
||||||
action: Description of the action (e.g., "create_test", "validate_technique")
|
|
||||||
entity_type: Type of entity affected (e.g., "technique", "test", "user")
|
|
||||||
entity_id: ID of the entity affected
|
|
||||||
details: Additional details as a dictionary (stored as JSONB)
|
|
||||||
"""
|
|
||||||
log = AuditLog(
|
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
action=action,
|
action=action,
|
||||||
entity_type=entity_type,
|
entity_type=entity_type,
|
||||||
entity_id=str(entity_id) if entity_id else None,
|
entity_id=str(entity_id) if entity_id else None,
|
||||||
details=details,
|
details=details,
|
||||||
|
ip_address=ip or None,
|
||||||
|
user_agent=ua or None,
|
||||||
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
db.add(log)
|
db.add(entry)
|
||||||
|
db.flush()
|
||||||
|
entry.integrity_hash = compute_integrity_hash(entry)
|
||||||
|
return entry
|
||||||
|
|||||||
84
backend/tests/test_audit_trail.py
Normal file
84
backend/tests/test_audit_trail.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Tests for enhanced audit trail (IP, user-agent, integrity hash)."""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.middleware.request_context import request_ip, request_user_agent
|
||||||
|
from app.models.audit import AuditLog
|
||||||
|
from app.services.audit_service import (
|
||||||
|
compute_integrity_hash,
|
||||||
|
log_action,
|
||||||
|
verify_audit_integrity,
|
||||||
|
)
|
||||||
|
from app.jobs.retention_job import apply_retention_policies
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditIntegrity:
|
||||||
|
def test_integrity_hash_set_on_log(self, db):
|
||||||
|
entry = log_action(
|
||||||
|
db,
|
||||||
|
user_id=None,
|
||||||
|
action="test_action",
|
||||||
|
entity_type="test",
|
||||||
|
entity_id="abc",
|
||||||
|
ip_address="10.0.0.1",
|
||||||
|
user_agent="pytest",
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(entry)
|
||||||
|
assert entry.integrity_hash
|
||||||
|
assert len(entry.integrity_hash) == 64
|
||||||
|
assert verify_audit_integrity(entry)
|
||||||
|
|
||||||
|
def test_tampered_entry_fails_integrity(self, db):
|
||||||
|
entry = log_action(
|
||||||
|
db,
|
||||||
|
user_id=None,
|
||||||
|
action="test_action",
|
||||||
|
entity_type="test",
|
||||||
|
entity_id="abc",
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
entry.entity_id = "tampered"
|
||||||
|
assert not verify_audit_integrity(entry)
|
||||||
|
|
||||||
|
def test_recomputed_hash_matches_stored(self, db):
|
||||||
|
entry = log_action(db, None, "update", "user", "1")
|
||||||
|
db.commit()
|
||||||
|
assert entry.integrity_hash == compute_integrity_hash(entry)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestContext:
|
||||||
|
def test_context_vars_used_by_log_action(self, db):
|
||||||
|
token_ip = request_ip.set("203.0.113.42")
|
||||||
|
token_ua = request_user_agent.set("AegisTestClient/1.0")
|
||||||
|
try:
|
||||||
|
entry = log_action(db, None, "ctx_action", "system", None)
|
||||||
|
db.commit()
|
||||||
|
assert entry.ip_address == "203.0.113.42"
|
||||||
|
assert entry.user_agent == "AegisTestClient/1.0"
|
||||||
|
finally:
|
||||||
|
request_ip.reset(token_ip)
|
||||||
|
request_user_agent.reset(token_ua)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetentionJob:
|
||||||
|
def test_deletes_old_audit_logs(self, db):
|
||||||
|
old = AuditLog(
|
||||||
|
action="old",
|
||||||
|
entity_type="system",
|
||||||
|
timestamp=datetime.now(timezone.utc) - timedelta(days=800),
|
||||||
|
)
|
||||||
|
recent = AuditLog(
|
||||||
|
action="recent",
|
||||||
|
entity_type="system",
|
||||||
|
timestamp=datetime.now(timezone.utc) - timedelta(days=1),
|
||||||
|
)
|
||||||
|
db.add_all([old, recent])
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
summary = apply_retention_policies(db)
|
||||||
|
assert summary["audit_logs_deleted"] >= 1
|
||||||
|
remaining = db.query(AuditLog).all()
|
||||||
|
assert all(log.action != "old" for log in remaining)
|
||||||
Reference in New Issue
Block a user