refactor(core): introduce Unit of Work and remove commits from services
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled

- Add UnitOfWork context manager in domain/unit_of_work.py with commit/rollback/flush API and auto-rollback on exception

- Remove all db.commit() from test_workflow_service (8 calls), notification_service (4 calls), status_service (1 call)

- Services now only stage changes via db.add/db.flush; caller owns the transaction boundary

- Update routers/tests.py: wrap 9 workflow endpoints in UnitOfWork context managers

- Update routers/notifications.py: wrap mark_as_read and mark_all_as_read in UnitOfWork
This commit is contained in:
2026-02-18 12:51:55 +01:00
parent 98e8ca1eef
commit bfce1a8a0e
6 changed files with 123 additions and 74 deletions

View File

@@ -0,0 +1,47 @@
"""Unit of Work — wraps a SQLAlchemy session for explicit transaction control.
Usage in routers::
with UnitOfWork(db) as uow:
service_a(db, ...)
service_b(db, ...)
uow.commit() # single commit for the entire operation
If an exception propagates, ``__exit__`` issues a rollback automatically.
Services should **never** call ``db.commit()``; they use ``db.add()`` /
``db.flush()`` to stage work and let the caller decide when to commit.
"""
from __future__ import annotations
from sqlalchemy.orm import Session
class UnitOfWork:
"""Lightweight transaction wrapper around an existing SQLAlchemy session."""
def __init__(self, session: Session) -> None:
self._session = session
# -- context manager -----------------------------------------------------
def __enter__(self) -> "UnitOfWork":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
if exc_type is not None:
self.rollback()
# -- public API ----------------------------------------------------------
def commit(self) -> None:
"""Flush pending changes and commit the transaction."""
self._session.commit()
def rollback(self) -> None:
"""Roll back the current transaction."""
self._session.rollback()
def flush(self) -> None:
"""Flush pending changes without committing (useful for getting IDs)."""
self._session.flush()

View File

@@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.domain.unit_of_work import UnitOfWork
from app.models.notification import Notification from app.models.notification import Notification
from app.models.user import User from app.models.user import User
from app.schemas.notification import NotificationOut, UnreadCountOut from app.schemas.notification import NotificationOut, UnreadCountOut
@@ -78,12 +79,14 @@ def read_notification(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Mark a single notification as read.""" """Mark a single notification as read."""
success = mark_as_read(db, notification_id, current_user.id) with UnitOfWork(db) as uow:
if not success: success = mark_as_read(db, notification_id, current_user.id)
raise HTTPException( if not success:
status_code=status.HTTP_404_NOT_FOUND, raise HTTPException(
detail="Notification not found", status_code=status.HTTP_404_NOT_FOUND,
) detail="Notification not found",
)
uow.commit()
notif = db.query(Notification).filter(Notification.id == notification_id).first() notif = db.query(Notification).filter(Notification.id == notification_id).first()
return notif return notif
@@ -99,5 +102,7 @@ def read_all_notifications(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Mark all notifications for the current user as read.""" """Mark all notifications for the current user as read."""
count = mark_all_as_read(db, current_user.id) with UnitOfWork(db) as uow:
count = mark_all_as_read(db, current_user.id)
uow.commit()
return {"detail": f"Marked {count} notifications as read"} return {"detail": f"Marked {count} notifications as read"}

View File

@@ -43,6 +43,7 @@ from app.schemas.test import (
TestRemediationUpdate, TestRemediationUpdate,
) )
from app.schemas.test_template import TestTemplateInstantiate from app.schemas.test_template import TestTemplateInstantiate
from app.domain.unit_of_work import UnitOfWork
from app.services.audit_service import log_action from app.services.audit_service import log_action
from app.services.status_service import recalculate_technique_status from app.services.status_service import recalculate_technique_status
from app.services.test_workflow_service import ( from app.services.test_workflow_service import (
@@ -434,7 +435,9 @@ def start_execution(
): ):
"""Move a test from ``draft`` to ``red_executing``.""" """Move a test from ``draft`` to ``red_executing``."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_start_execution(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_start_execution(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -452,7 +455,9 @@ def submit_red(
): ):
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``.""" """Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_submit_red(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_submit_red(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -470,7 +475,9 @@ def submit_blue(
): ):
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``.""" """Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_submit_blue(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_submit_blue(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -488,7 +495,9 @@ def pause_timer(
): ):
"""Pause the running timer for the current phase (red_executing or blue_evaluating).""" """Pause the running timer for the current phase (red_executing or blue_evaluating)."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_pause_timer(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_pause_timer(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -506,7 +515,9 @@ def resume_timer(
): ):
"""Resume the paused timer for the current phase.""" """Resume the paused timer for the current phase."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_resume_timer(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_resume_timer(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -525,16 +536,15 @@ def validate_red(
): ):
"""Red Lead approves or rejects the red side of a test.""" """Red Lead approves or rejects the red side of a test."""
test = _get_test_with_technique(db, test_id) test = _get_test_with_technique(db, test_id)
test = wf_validate_red( with UnitOfWork(db) as uow:
db, test, current_user, test = wf_validate_red(
validation_status=payload.red_validation_status, db, test, current_user,
notes=payload.red_validation_notes, validation_status=payload.red_validation_status,
) notes=payload.red_validation_notes,
)
# Recalculate technique status if test reached a terminal state if test.state in (TestState.validated, TestState.rejected):
if test.state in (TestState.validated, TestState.rejected): recalculate_technique_status(db, test.technique)
recalculate_technique_status(db, test.technique) uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -553,16 +563,15 @@ def validate_blue(
): ):
"""Blue Lead approves or rejects the blue side of a test.""" """Blue Lead approves or rejects the blue side of a test."""
test = _get_test_with_technique(db, test_id) test = _get_test_with_technique(db, test_id)
test = wf_validate_blue( with UnitOfWork(db) as uow:
db, test, current_user, test = wf_validate_blue(
validation_status=payload.blue_validation_status, db, test, current_user,
notes=payload.blue_validation_notes, validation_status=payload.blue_validation_status,
) notes=payload.blue_validation_notes,
)
# Recalculate technique status if test reached a terminal state if test.state in (TestState.validated, TestState.rejected):
if test.state in (TestState.validated, TestState.rejected): recalculate_technique_status(db, test.technique)
recalculate_technique_status(db, test.technique) uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -580,7 +589,9 @@ def reopen(
): ):
"""Reopen a rejected test, moving it back to ``draft``.""" """Reopen a rejected test, moving it back to ``draft``."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_reopen(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_reopen(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -610,25 +621,23 @@ def update_remediation(
for field, value in update_data.items(): for field, value in update_data.items():
setattr(test, field, value) setattr(test, field, value)
db.commit() with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="update_remediation",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
new_status = update_data.get("remediation_status")
if new_status == "completed" and old_remediation_status != "completed":
wf_handle_remediation(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="update_remediation",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
# Auto-create retest when remediation is marked completed
new_status = update_data.get("remediation_status")
if new_status == "completed" and old_remediation_status != "completed":
retest = wf_handle_remediation(db, test, current_user)
if retest:
db.refresh(test)
return test return test

View File

@@ -2,6 +2,9 @@
Provides helpers for generating notifications automatically when test Provides helpers for generating notifications automatically when test
state changes occur, plus CRUD for the notifications API. state changes occur, plus CRUD for the notifications API.
Functions in this module stage changes via ``db.add()`` / ``db.flush()``
but do **not** commit. The caller is responsible for committing.
""" """
import uuid import uuid
@@ -38,8 +41,7 @@ def create_notification(
entity_id=entity_id, entity_id=entity_id,
) )
db.add(notif) db.add(notif)
db.commit() db.flush()
db.refresh(notif)
return notif return notif
@@ -53,7 +55,6 @@ def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) ->
if notif is None: if notif is None:
return False return False
notif.read = True notif.read = True
db.commit()
return True return True
@@ -64,7 +65,6 @@ def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712 .filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
.update({"read": True}) .update({"read": True})
) )
db.commit()
return count return count
@@ -88,7 +88,6 @@ def cleanup_old_notifications(db: Session, days: int = 90) -> int:
) )
.delete() .delete()
) )
db.commit()
return count return count

View File

@@ -4,6 +4,9 @@ based on the state and result of its associated tests.
V2 rules account for dual Red/Blue validation and use V2 rules account for dual Red/Blue validation and use
``detection_result`` (filled by Blue Team) instead of the legacy ``detection_result`` (filled by Blue Team) instead of the legacy
``result`` field. ``result`` field.
This function mutates the technique but does **not** commit.
The caller is responsible for committing the session.
""" """
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -42,5 +45,3 @@ def recalculate_technique_status(db: Session, technique: Technique) -> None:
technique.status_global = TechniqueStatus.partial technique.status_global = TechniqueStatus.partial
else: else:
technique.status_global = TechniqueStatus.in_progress technique.status_global = TechniqueStatus.in_progress
db.commit()

View File

@@ -7,8 +7,9 @@ for each step in the test lifecycle:
rejected → draft rejected → draft
Every public function validates the transition, mutates the test, writes an Every public function validates the transition, mutates the test, and writes
audit-log entry, and commits the session. an audit-log entry. The caller (router) is responsible for committing the
session via the Unit of Work pattern.
""" """
import logging import logging
@@ -122,7 +123,6 @@ def start_execution(db: Session, test: Test, user: User) -> Test:
) )
test.execution_date = now test.execution_date = now
test.red_started_at = now test.red_started_at = now
db.commit()
return test return test
@@ -161,7 +161,6 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
# Start Blue Team timer # Start Blue Team timer
test.blue_started_at = now test.blue_started_at = now
test.blue_paused_seconds = 0 test.blue_paused_seconds = 0
db.commit()
return test return test
@@ -196,7 +195,6 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
description=f"Blue Team evaluation: {test.name}", description=f"Blue Team evaluation: {test.name}",
) )
db.commit()
return test return test
@@ -222,7 +220,6 @@ def pause_timer(db: Session, test: Test, user: User) -> Test:
entity_id=test.id, entity_id=test.id,
details={"state": test.state.value}, details={"state": test.state.value},
) )
db.commit()
return test return test
@@ -252,7 +249,6 @@ def resume_timer(db: Session, test: Test, user: User) -> Test:
entity_id=test.id, entity_id=test.id,
details={"paused_seconds": paused_seconds, "state": test.state.value}, details={"paused_seconds": paused_seconds, "state": test.state.value},
) )
db.commit()
return test return test
@@ -421,14 +417,12 @@ def check_dual_validation(db: Session, test: Test) -> Test:
if red_status == "rejected" or blue_status == "rejected": if red_status == "rejected" or blue_status == "rejected":
test.state = TestState.rejected test.state = TestState.rejected
db.commit()
try: try:
notify_test_state_change(db, test, "rejected") notify_test_state_change(db, test, "rejected")
except Exception as e: except Exception as e:
logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True) logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True)
elif red_status == "approved" and blue_status == "approved": elif red_status == "approved" and blue_status == "approved":
test.state = TestState.validated test.state = TestState.validated
db.commit()
# Invalidate cached scores — a validation changes org-level numbers # Invalidate cached scores — a validation changes org-level numbers
try: try:
from app.services.score_cache import invalidate from app.services.score_cache import invalidate
@@ -439,10 +433,6 @@ def check_dual_validation(db: Session, test: Test) -> Test:
notify_test_state_change(db, test, "validated") notify_test_state_change(db, test, "validated")
except Exception as e: except Exception as e:
logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True) logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True)
else:
# One side hasn't voted yet — stay in_review, just flush
db.commit()
return test return test
@@ -533,8 +523,7 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test |
entity_id=retest.id, entity_id=retest.id,
) )
db.commit() db.flush()
db.refresh(retest)
return retest return retest
@@ -599,5 +588,4 @@ def reopen_test(db: Session, test: Test, user: User) -> Test:
test.red_paused_seconds = 0 test.red_paused_seconds = 0
test.blue_paused_seconds = 0 test.blue_paused_seconds = 0
db.commit()
return test return test