From c0aff4cbebf27baaad469003fde2f50e75d2b95c Mon Sep 17 00:00:00 2001 From: Kitos Date: Mon, 18 May 2026 14:16:18 +0200 Subject: [PATCH] feat(audit): enhanced audit trail with IP, user-agent and integrity hash [FASE-3.1] --- backend/app/middleware/request_context.py | 26 +++++++ backend/app/models/audit.py | 4 ++ backend/app/services/audit_service.py | 67 +++++++++++++----- backend/tests/test_audit_trail.py | 84 +++++++++++++++++++++++ 4 files changed, 164 insertions(+), 17 deletions(-) create mode 100644 backend/app/middleware/request_context.py create mode 100644 backend/tests/test_audit_trail.py diff --git a/backend/app/middleware/request_context.py b/backend/app/middleware/request_context.py new file mode 100644 index 0000000..f49ef57 --- /dev/null +++ b/backend/app/middleware/request_context.py @@ -0,0 +1,26 @@ +"""Request context middleware — captures client IP and User-Agent per request.""" + +from contextvars import ContextVar + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +request_ip: ContextVar[str] = ContextVar("request_ip", default="") +request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="") + + +def resolve_client_ip(request: Request) -> str: + """Extract the client IP, honouring ``X-Forwarded-For`` when present.""" + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + if request.client: + return request.client.host + return "unknown" + + +class RequestContextMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + request_ip.set(resolve_client_ip(request)) + request_user_agent.set(request.headers.get("User-Agent", "")) + return await call_next(request) diff --git a/backend/app/models/audit.py b/backend/app/models/audit.py index 90aed66..dda16b5 100644 --- a/backend/app/models/audit.py +++ b/backend/app/models/audit.py @@ -22,6 +22,10 @@ class AuditLog(Base): entity_id = Column(String, nullable=True) timestamp = Column(DateTime(timezone=True), server_default=func.now()) details = Column(JSONB, nullable=True) + ip_address = Column(String(45), nullable=True) + user_agent = Column(String(500), nullable=True) + integrity_hash = Column(String(64), nullable=True) + session_id = Column(String(100), nullable=True) # Relationships user = relationship("User") diff --git a/backend/app/services/audit_service.py b/backend/app/services/audit_service.py index 3ee58f5..bf1c3a6 100644 --- a/backend/app/services/audit_service.py +++ b/backend/app/services/audit_service.py @@ -1,32 +1,65 @@ +"""Audit logging with request context and integrity hashing.""" + +from __future__ import annotations + +import hashlib +from datetime import datetime, timezone + from sqlalchemy.orm import Session +from app.middleware.request_context import request_ip, request_user_agent from app.models.audit import AuditLog +def _integrity_payload(entry: AuditLog) -> str: + ts = entry.timestamp + if ts is None: + ts = datetime.now(timezone.utc) + user_part = str(entry.user_id) if entry.user_id else "" + entity_type = entry.entity_type or "" + entity_id = entry.entity_id or "" + return f"{user_part}:{entry.action}:{entity_type}:{entity_id}:{ts.isoformat()}" + + +def compute_integrity_hash(entry: AuditLog) -> str: + """Return the SHA-256 hex digest for an audit log entry.""" + return hashlib.sha256(_integrity_payload(entry).encode()).hexdigest() + + +def verify_audit_integrity(entry: AuditLog) -> bool: + """Return whether the stored hash matches the entry's current fields.""" + if not entry.integrity_hash: + return False + return entry.integrity_hash == compute_integrity_hash(entry) + + def log_action( db: Session, user_id, action: str, - entity_type: str = None, - entity_id: str = None, - details: dict = None -): - """ - Log an action to the audit log. - - Args: - db: Database session - user_id: UUID of the user performing the action (can be None for system actions) - action: Description of the action (e.g., "create_test", "validate_technique") - entity_type: Type of entity affected (e.g., "technique", "test", "user") - entity_id: ID of the entity affected - details: Additional details as a dictionary (stored as JSONB) - """ - log = AuditLog( + entity_type: str | None = None, + entity_id: str | None = None, + details: dict | None = None, + *, + ip_address: str | None = None, + user_agent: str | None = None, + session_id: str | None = None, +) -> AuditLog: + """Record an audit event. Does not commit — the caller owns the transaction.""" + ip = ip_address if ip_address is not None else request_ip.get("") + ua = user_agent if user_agent is not None else request_user_agent.get("") + + entry = AuditLog( user_id=user_id, action=action, entity_type=entity_type, entity_id=str(entity_id) if entity_id else None, details=details, + ip_address=ip or None, + user_agent=ua or None, + session_id=session_id, ) - db.add(log) + db.add(entry) + db.flush() + entry.integrity_hash = compute_integrity_hash(entry) + return entry diff --git a/backend/tests/test_audit_trail.py b/backend/tests/test_audit_trail.py new file mode 100644 index 0000000..2003183 --- /dev/null +++ b/backend/tests/test_audit_trail.py @@ -0,0 +1,84 @@ +"""Tests for enhanced audit trail (IP, user-agent, integrity hash).""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from app.middleware.request_context import request_ip, request_user_agent +from app.models.audit import AuditLog +from app.services.audit_service import ( + compute_integrity_hash, + log_action, + verify_audit_integrity, +) +from app.jobs.retention_job import apply_retention_policies + + +class TestAuditIntegrity: + def test_integrity_hash_set_on_log(self, db): + entry = log_action( + db, + user_id=None, + action="test_action", + entity_type="test", + entity_id="abc", + ip_address="10.0.0.1", + user_agent="pytest", + ) + db.commit() + db.refresh(entry) + assert entry.integrity_hash + assert len(entry.integrity_hash) == 64 + assert verify_audit_integrity(entry) + + def test_tampered_entry_fails_integrity(self, db): + entry = log_action( + db, + user_id=None, + action="test_action", + entity_type="test", + entity_id="abc", + ) + db.commit() + entry.entity_id = "tampered" + assert not verify_audit_integrity(entry) + + def test_recomputed_hash_matches_stored(self, db): + entry = log_action(db, None, "update", "user", "1") + db.commit() + assert entry.integrity_hash == compute_integrity_hash(entry) + + +class TestRequestContext: + def test_context_vars_used_by_log_action(self, db): + token_ip = request_ip.set("203.0.113.42") + token_ua = request_user_agent.set("AegisTestClient/1.0") + try: + entry = log_action(db, None, "ctx_action", "system", None) + db.commit() + assert entry.ip_address == "203.0.113.42" + assert entry.user_agent == "AegisTestClient/1.0" + finally: + request_ip.reset(token_ip) + request_user_agent.reset(token_ua) + + +class TestRetentionJob: + def test_deletes_old_audit_logs(self, db): + old = AuditLog( + action="old", + entity_type="system", + timestamp=datetime.now(timezone.utc) - timedelta(days=800), + ) + recent = AuditLog( + action="recent", + entity_type="system", + timestamp=datetime.now(timezone.utc) - timedelta(days=1), + ) + db.add_all([old, recent]) + db.commit() + + summary = apply_retention_policies(db) + assert summary["audit_logs_deleted"] >= 1 + remaining = db.query(AuditLog).all() + assert all(log.action != "old" for log in remaining)