"""Unit of Work — wraps a SQLAlchemy session for explicit transaction control. Usage in routers:: with UnitOfWork(db) as uow: service_a(db, ...) service_b(db, ...) uow.commit() # single commit for the entire operation If an exception propagates, ``__exit__`` issues a rollback automatically. Services should **never** call ``db.commit()``; they use ``db.add()`` / ``db.flush()`` to stage work and let the caller decide when to commit. **Documented exceptions** (services that may commit internally): - Import services (atomic_import, sigma_import, etc.) — self-contained sync ops. - Background jobs (campaign_scheduler, intel_service, stale_detection, mitre_sync) — self-contained operations. - Self-contained batch ops (e.g. detection_rule_service.auto_associate_rules, snapshot_service.create_snapshot, campaign_service.generate_campaign_from_*, osint_enrichment_service.enrich_technique_with_cves). """ # Enable future language features for compatibility from __future__ import annotations # Import TracebackType from types from types import TracebackType # Import Session from sqlalchemy.orm from sqlalchemy.orm import Session # Define class UnitOfWork class UnitOfWork: """Lightweight transaction wrapper around an existing SQLAlchemy session.""" # Define function __init__ def __init__(self, session: Session) -> None: """Wrap an existing SQLAlchemy session in a Unit of Work. Args: session (Session): The active SQLAlchemy session to manage. Returns: None """ # Assign self._session = session self._session = session # -- context manager ----------------------------------------------------- def __enter__(self) -> "UnitOfWork": """Enter the runtime context, returning this UnitOfWork instance. Returns: UnitOfWork: The UnitOfWork itself, for use in ``with`` statements. """ # Return self return self # Define function __exit__ def __exit__( self, # Entry: exc_type exc_type: type[BaseException] | None, # Entry: exc_val exc_val: BaseException | None, # Entry: exc_tb exc_tb: TracebackType | None, ) -> None: """Exit the runtime context, rolling back if an exception propagated. Args: exc_type (type[BaseException] | None): Exception class, if raised. exc_val (BaseException | None): Exception instance, if raised. exc_tb (TracebackType | None): Traceback object, if an exception was raised. Returns: None """ # Check: exc_type is not None if exc_type is not None: # Call self.rollback() self.rollback() # -- public API ---------------------------------------------------------- def commit(self) -> None: """Flush pending changes and commit the transaction.""" # Call self._session.commit() self._session.commit() # Define function rollback def rollback(self) -> None: """Roll back the current transaction.""" # Call self._session.rollback() self._session.rollback() # Define function flush def flush(self) -> None: """Flush pending changes without committing (useful for getting IDs).""" # Call self._session.flush() self._session.flush()