diff --git a/backend/app/domain/unit_of_work.py b/backend/app/domain/unit_of_work.py new file mode 100644 index 0000000..b93f2b9 --- /dev/null +++ b/backend/app/domain/unit_of_work.py @@ -0,0 +1,47 @@ +"""Unit of Work — wraps a SQLAlchemy session for explicit transaction control. + +Usage in routers:: + + with UnitOfWork(db) as uow: + service_a(db, ...) + service_b(db, ...) + uow.commit() # single commit for the entire operation + +If an exception propagates, ``__exit__`` issues a rollback automatically. +Services should **never** call ``db.commit()``; they use ``db.add()`` / +``db.flush()`` to stage work and let the caller decide when to commit. +""" + +from __future__ import annotations + +from sqlalchemy.orm import Session + + +class UnitOfWork: + """Lightweight transaction wrapper around an existing SQLAlchemy session.""" + + def __init__(self, session: Session) -> None: + self._session = session + + # -- context manager ----------------------------------------------------- + + def __enter__(self) -> "UnitOfWork": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if exc_type is not None: + self.rollback() + + # -- public API ---------------------------------------------------------- + + def commit(self) -> None: + """Flush pending changes and commit the transaction.""" + self._session.commit() + + def rollback(self) -> None: + """Roll back the current transaction.""" + self._session.rollback() + + def flush(self) -> None: + """Flush pending changes without committing (useful for getting IDs).""" + self._session.flush() diff --git a/backend/app/routers/notifications.py b/backend/app/routers/notifications.py index 5f25cd6..fa7d6b2 100644 --- a/backend/app/routers/notifications.py +++ b/backend/app/routers/notifications.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user +from app.domain.unit_of_work import UnitOfWork from app.models.notification import Notification from app.models.user import User from app.schemas.notification import NotificationOut, UnreadCountOut @@ -78,12 +79,14 @@ def read_notification( current_user: User = Depends(get_current_user), ): """Mark a single notification as read.""" - success = mark_as_read(db, notification_id, current_user.id) - if not success: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Notification not found", - ) + with UnitOfWork(db) as uow: + success = mark_as_read(db, notification_id, current_user.id) + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Notification not found", + ) + uow.commit() notif = db.query(Notification).filter(Notification.id == notification_id).first() return notif @@ -99,5 +102,7 @@ def read_all_notifications( current_user: User = Depends(get_current_user), ): """Mark all notifications for the current user as read.""" - count = mark_all_as_read(db, current_user.id) + with UnitOfWork(db) as uow: + count = mark_all_as_read(db, current_user.id) + uow.commit() return {"detail": f"Marked {count} notifications as read"} diff --git a/backend/app/routers/tests.py b/backend/app/routers/tests.py index 4ad1ca5..df96902 100644 --- a/backend/app/routers/tests.py +++ b/backend/app/routers/tests.py @@ -43,6 +43,7 @@ from app.schemas.test import ( TestRemediationUpdate, ) from app.schemas.test_template import TestTemplateInstantiate +from app.domain.unit_of_work import UnitOfWork from app.services.audit_service import log_action from app.services.status_service import recalculate_technique_status from app.services.test_workflow_service import ( @@ -434,7 +435,9 @@ def start_execution( ): """Move a test from ``draft`` to ``red_executing``.""" test = _get_test_or_404(db, test_id) - test = wf_start_execution(db, test, current_user) + with UnitOfWork(db) as uow: + test = wf_start_execution(db, test, current_user) + uow.commit() db.refresh(test) return test @@ -452,7 +455,9 @@ def submit_red( ): """Red Team finalises — move from ``red_executing`` to ``blue_evaluating``.""" test = _get_test_or_404(db, test_id) - test = wf_submit_red(db, test, current_user) + with UnitOfWork(db) as uow: + test = wf_submit_red(db, test, current_user) + uow.commit() db.refresh(test) return test @@ -470,7 +475,9 @@ def submit_blue( ): """Blue Team finalises — move from ``blue_evaluating`` to ``in_review``.""" test = _get_test_or_404(db, test_id) - test = wf_submit_blue(db, test, current_user) + with UnitOfWork(db) as uow: + test = wf_submit_blue(db, test, current_user) + uow.commit() db.refresh(test) return test @@ -488,7 +495,9 @@ def pause_timer( ): """Pause the running timer for the current phase (red_executing or blue_evaluating).""" test = _get_test_or_404(db, test_id) - test = wf_pause_timer(db, test, current_user) + with UnitOfWork(db) as uow: + test = wf_pause_timer(db, test, current_user) + uow.commit() db.refresh(test) return test @@ -506,7 +515,9 @@ def resume_timer( ): """Resume the paused timer for the current phase.""" test = _get_test_or_404(db, test_id) - test = wf_resume_timer(db, test, current_user) + with UnitOfWork(db) as uow: + test = wf_resume_timer(db, test, current_user) + uow.commit() db.refresh(test) return test @@ -525,16 +536,15 @@ def validate_red( ): """Red Lead approves or rejects the red side of a test.""" test = _get_test_with_technique(db, test_id) - test = wf_validate_red( - db, test, current_user, - validation_status=payload.red_validation_status, - notes=payload.red_validation_notes, - ) - - # Recalculate technique status if test reached a terminal state - if test.state in (TestState.validated, TestState.rejected): - recalculate_technique_status(db, test.technique) - + with UnitOfWork(db) as uow: + test = wf_validate_red( + db, test, current_user, + validation_status=payload.red_validation_status, + notes=payload.red_validation_notes, + ) + if test.state in (TestState.validated, TestState.rejected): + recalculate_technique_status(db, test.technique) + uow.commit() db.refresh(test) return test @@ -553,16 +563,15 @@ def validate_blue( ): """Blue Lead approves or rejects the blue side of a test.""" test = _get_test_with_technique(db, test_id) - test = wf_validate_blue( - db, test, current_user, - validation_status=payload.blue_validation_status, - notes=payload.blue_validation_notes, - ) - - # Recalculate technique status if test reached a terminal state - if test.state in (TestState.validated, TestState.rejected): - recalculate_technique_status(db, test.technique) - + with UnitOfWork(db) as uow: + test = wf_validate_blue( + db, test, current_user, + validation_status=payload.blue_validation_status, + notes=payload.blue_validation_notes, + ) + if test.state in (TestState.validated, TestState.rejected): + recalculate_technique_status(db, test.technique) + uow.commit() db.refresh(test) return test @@ -580,7 +589,9 @@ def reopen( ): """Reopen a rejected test, moving it back to ``draft``.""" test = _get_test_or_404(db, test_id) - test = wf_reopen(db, test, current_user) + with UnitOfWork(db) as uow: + test = wf_reopen(db, test, current_user) + uow.commit() db.refresh(test) return test @@ -610,25 +621,23 @@ def update_remediation( for field, value in update_data.items(): setattr(test, field, value) - db.commit() + with UnitOfWork(db) as uow: + log_action( + db, + user_id=current_user.id, + action="update_remediation", + entity_type="test", + entity_id=test.id, + details={"updated_fields": list(update_data.keys())}, + ) + + new_status = update_data.get("remediation_status") + if new_status == "completed" and old_remediation_status != "completed": + wf_handle_remediation(db, test, current_user) + + uow.commit() + db.refresh(test) - - log_action( - db, - user_id=current_user.id, - action="update_remediation", - entity_type="test", - entity_id=test.id, - details={"updated_fields": list(update_data.keys())}, - ) - - # Auto-create retest when remediation is marked completed - new_status = update_data.get("remediation_status") - if new_status == "completed" and old_remediation_status != "completed": - retest = wf_handle_remediation(db, test, current_user) - if retest: - db.refresh(test) - return test diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index df7c83c..86d523f 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -2,6 +2,9 @@ Provides helpers for generating notifications automatically when test state changes occur, plus CRUD for the notifications API. + +Functions in this module stage changes via ``db.add()`` / ``db.flush()`` +but do **not** commit. The caller is responsible for committing. """ import uuid @@ -38,8 +41,7 @@ def create_notification( entity_id=entity_id, ) db.add(notif) - db.commit() - db.refresh(notif) + db.flush() return notif @@ -53,7 +55,6 @@ def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> if notif is None: return False notif.read = True - db.commit() return True @@ -64,7 +65,6 @@ def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int: .filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712 .update({"read": True}) ) - db.commit() return count @@ -88,7 +88,6 @@ def cleanup_old_notifications(db: Session, days: int = 90) -> int: ) .delete() ) - db.commit() return count diff --git a/backend/app/services/status_service.py b/backend/app/services/status_service.py index 50962c2..6fe251d 100644 --- a/backend/app/services/status_service.py +++ b/backend/app/services/status_service.py @@ -4,6 +4,9 @@ based on the state and result of its associated tests. V2 rules account for dual Red/Blue validation and use ``detection_result`` (filled by Blue Team) instead of the legacy ``result`` field. + +This function mutates the technique but does **not** commit. +The caller is responsible for committing the session. """ from sqlalchemy.orm import Session @@ -42,5 +45,3 @@ def recalculate_technique_status(db: Session, technique: Technique) -> None: technique.status_global = TechniqueStatus.partial else: technique.status_global = TechniqueStatus.in_progress - - db.commit() diff --git a/backend/app/services/test_workflow_service.py b/backend/app/services/test_workflow_service.py index cd89eb9..4e3ead6 100644 --- a/backend/app/services/test_workflow_service.py +++ b/backend/app/services/test_workflow_service.py @@ -7,8 +7,9 @@ for each step in the test lifecycle: ↓ rejected → draft -Every public function validates the transition, mutates the test, writes an -audit-log entry, and commits the session. +Every public function validates the transition, mutates the test, and writes +an audit-log entry. The caller (router) is responsible for committing the +session via the Unit of Work pattern. """ import logging @@ -122,7 +123,6 @@ def start_execution(db: Session, test: Test, user: User) -> Test: ) test.execution_date = now test.red_started_at = now - db.commit() return test @@ -161,7 +161,6 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test: # Start Blue Team timer test.blue_started_at = now test.blue_paused_seconds = 0 - db.commit() return test @@ -196,7 +195,6 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test: description=f"Blue Team evaluation: {test.name}", ) - db.commit() return test @@ -222,7 +220,6 @@ def pause_timer(db: Session, test: Test, user: User) -> Test: entity_id=test.id, details={"state": test.state.value}, ) - db.commit() return test @@ -252,7 +249,6 @@ def resume_timer(db: Session, test: Test, user: User) -> Test: entity_id=test.id, details={"paused_seconds": paused_seconds, "state": test.state.value}, ) - db.commit() return test @@ -421,14 +417,12 @@ def check_dual_validation(db: Session, test: Test) -> Test: if red_status == "rejected" or blue_status == "rejected": test.state = TestState.rejected - db.commit() try: notify_test_state_change(db, test, "rejected") except Exception as e: logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True) elif red_status == "approved" and blue_status == "approved": test.state = TestState.validated - db.commit() # Invalidate cached scores — a validation changes org-level numbers try: from app.services.score_cache import invalidate @@ -439,10 +433,6 @@ def check_dual_validation(db: Session, test: Test) -> Test: notify_test_state_change(db, test, "validated") except Exception as e: logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True) - else: - # One side hasn't voted yet — stay in_review, just flush - db.commit() - return test @@ -533,8 +523,7 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | entity_id=retest.id, ) - db.commit() - db.refresh(retest) + db.flush() return retest @@ -599,5 +588,4 @@ def reopen_test(db: Session, test: Test, user: User) -> Test: test.red_paused_seconds = 0 test.blue_paused_seconds = 0 - db.commit() return test