refactor(core): introduce Unit of Work and remove commits from services
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
- Add UnitOfWork context manager in domain/unit_of_work.py with commit/rollback/flush API and auto-rollback on exception - Remove all db.commit() from test_workflow_service (8 calls), notification_service (4 calls), status_service (1 call) - Services now only stage changes via db.add/db.flush; caller owns the transaction boundary - Update routers/tests.py: wrap 9 workflow endpoints in UnitOfWork context managers - Update routers/notifications.py: wrap mark_as_read and mark_all_as_read in UnitOfWork
This commit is contained in:
47
backend/app/domain/unit_of_work.py
Normal file
47
backend/app/domain/unit_of_work.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""Unit of Work — wraps a SQLAlchemy session for explicit transaction control.
|
||||||
|
|
||||||
|
Usage in routers::
|
||||||
|
|
||||||
|
with UnitOfWork(db) as uow:
|
||||||
|
service_a(db, ...)
|
||||||
|
service_b(db, ...)
|
||||||
|
uow.commit() # single commit for the entire operation
|
||||||
|
|
||||||
|
If an exception propagates, ``__exit__`` issues a rollback automatically.
|
||||||
|
Services should **never** call ``db.commit()``; they use ``db.add()`` /
|
||||||
|
``db.flush()`` to stage work and let the caller decide when to commit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
class UnitOfWork:
|
||||||
|
"""Lightweight transaction wrapper around an existing SQLAlchemy session."""
|
||||||
|
|
||||||
|
def __init__(self, session: Session) -> None:
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
# -- context manager -----------------------------------------------------
|
||||||
|
|
||||||
|
def __enter__(self) -> "UnitOfWork":
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
|
if exc_type is not None:
|
||||||
|
self.rollback()
|
||||||
|
|
||||||
|
# -- public API ----------------------------------------------------------
|
||||||
|
|
||||||
|
def commit(self) -> None:
|
||||||
|
"""Flush pending changes and commit the transaction."""
|
||||||
|
self._session.commit()
|
||||||
|
|
||||||
|
def rollback(self) -> None:
|
||||||
|
"""Roll back the current transaction."""
|
||||||
|
self._session.rollback()
|
||||||
|
|
||||||
|
def flush(self) -> None:
|
||||||
|
"""Flush pending changes without committing (useful for getting IDs)."""
|
||||||
|
self._session.flush()
|
||||||
@@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.notification import Notification
|
from app.models.notification import Notification
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.notification import NotificationOut, UnreadCountOut
|
from app.schemas.notification import NotificationOut, UnreadCountOut
|
||||||
@@ -78,12 +79,14 @@ def read_notification(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Mark a single notification as read."""
|
"""Mark a single notification as read."""
|
||||||
success = mark_as_read(db, notification_id, current_user.id)
|
with UnitOfWork(db) as uow:
|
||||||
if not success:
|
success = mark_as_read(db, notification_id, current_user.id)
|
||||||
raise HTTPException(
|
if not success:
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
raise HTTPException(
|
||||||
detail="Notification not found",
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
)
|
detail="Notification not found",
|
||||||
|
)
|
||||||
|
uow.commit()
|
||||||
notif = db.query(Notification).filter(Notification.id == notification_id).first()
|
notif = db.query(Notification).filter(Notification.id == notification_id).first()
|
||||||
return notif
|
return notif
|
||||||
|
|
||||||
@@ -99,5 +102,7 @@ def read_all_notifications(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Mark all notifications for the current user as read."""
|
"""Mark all notifications for the current user as read."""
|
||||||
count = mark_all_as_read(db, current_user.id)
|
with UnitOfWork(db) as uow:
|
||||||
|
count = mark_all_as_read(db, current_user.id)
|
||||||
|
uow.commit()
|
||||||
return {"detail": f"Marked {count} notifications as read"}
|
return {"detail": f"Marked {count} notifications as read"}
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from app.schemas.test import (
|
|||||||
TestRemediationUpdate,
|
TestRemediationUpdate,
|
||||||
)
|
)
|
||||||
from app.schemas.test_template import TestTemplateInstantiate
|
from app.schemas.test_template import TestTemplateInstantiate
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
from app.services.status_service import recalculate_technique_status
|
from app.services.status_service import recalculate_technique_status
|
||||||
from app.services.test_workflow_service import (
|
from app.services.test_workflow_service import (
|
||||||
@@ -434,7 +435,9 @@ def start_execution(
|
|||||||
):
|
):
|
||||||
"""Move a test from ``draft`` to ``red_executing``."""
|
"""Move a test from ``draft`` to ``red_executing``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_start_execution(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_start_execution(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -452,7 +455,9 @@ def submit_red(
|
|||||||
):
|
):
|
||||||
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
|
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_submit_red(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_submit_red(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -470,7 +475,9 @@ def submit_blue(
|
|||||||
):
|
):
|
||||||
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
|
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_submit_blue(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_submit_blue(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -488,7 +495,9 @@ def pause_timer(
|
|||||||
):
|
):
|
||||||
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
|
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_pause_timer(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_pause_timer(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -506,7 +515,9 @@ def resume_timer(
|
|||||||
):
|
):
|
||||||
"""Resume the paused timer for the current phase."""
|
"""Resume the paused timer for the current phase."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_resume_timer(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_resume_timer(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -525,16 +536,15 @@ def validate_red(
|
|||||||
):
|
):
|
||||||
"""Red Lead approves or rejects the red side of a test."""
|
"""Red Lead approves or rejects the red side of a test."""
|
||||||
test = _get_test_with_technique(db, test_id)
|
test = _get_test_with_technique(db, test_id)
|
||||||
test = wf_validate_red(
|
with UnitOfWork(db) as uow:
|
||||||
db, test, current_user,
|
test = wf_validate_red(
|
||||||
validation_status=payload.red_validation_status,
|
db, test, current_user,
|
||||||
notes=payload.red_validation_notes,
|
validation_status=payload.red_validation_status,
|
||||||
)
|
notes=payload.red_validation_notes,
|
||||||
|
)
|
||||||
# Recalculate technique status if test reached a terminal state
|
if test.state in (TestState.validated, TestState.rejected):
|
||||||
if test.state in (TestState.validated, TestState.rejected):
|
recalculate_technique_status(db, test.technique)
|
||||||
recalculate_technique_status(db, test.technique)
|
uow.commit()
|
||||||
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -553,16 +563,15 @@ def validate_blue(
|
|||||||
):
|
):
|
||||||
"""Blue Lead approves or rejects the blue side of a test."""
|
"""Blue Lead approves or rejects the blue side of a test."""
|
||||||
test = _get_test_with_technique(db, test_id)
|
test = _get_test_with_technique(db, test_id)
|
||||||
test = wf_validate_blue(
|
with UnitOfWork(db) as uow:
|
||||||
db, test, current_user,
|
test = wf_validate_blue(
|
||||||
validation_status=payload.blue_validation_status,
|
db, test, current_user,
|
||||||
notes=payload.blue_validation_notes,
|
validation_status=payload.blue_validation_status,
|
||||||
)
|
notes=payload.blue_validation_notes,
|
||||||
|
)
|
||||||
# Recalculate technique status if test reached a terminal state
|
if test.state in (TestState.validated, TestState.rejected):
|
||||||
if test.state in (TestState.validated, TestState.rejected):
|
recalculate_technique_status(db, test.technique)
|
||||||
recalculate_technique_status(db, test.technique)
|
uow.commit()
|
||||||
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -580,7 +589,9 @@ def reopen(
|
|||||||
):
|
):
|
||||||
"""Reopen a rejected test, moving it back to ``draft``."""
|
"""Reopen a rejected test, moving it back to ``draft``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_reopen(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_reopen(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -610,25 +621,23 @@ def update_remediation(
|
|||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(test, field, value)
|
setattr(test, field, value)
|
||||||
|
|
||||||
db.commit()
|
with UnitOfWork(db) as uow:
|
||||||
|
log_action(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
action="update_remediation",
|
||||||
|
entity_type="test",
|
||||||
|
entity_id=test.id,
|
||||||
|
details={"updated_fields": list(update_data.keys())},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_status = update_data.get("remediation_status")
|
||||||
|
if new_status == "completed" and old_remediation_status != "completed":
|
||||||
|
wf_handle_remediation(db, test, current_user)
|
||||||
|
|
||||||
|
uow.commit()
|
||||||
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
|
||||||
db,
|
|
||||||
user_id=current_user.id,
|
|
||||||
action="update_remediation",
|
|
||||||
entity_type="test",
|
|
||||||
entity_id=test.id,
|
|
||||||
details={"updated_fields": list(update_data.keys())},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Auto-create retest when remediation is marked completed
|
|
||||||
new_status = update_data.get("remediation_status")
|
|
||||||
if new_status == "completed" and old_remediation_status != "completed":
|
|
||||||
retest = wf_handle_remediation(db, test, current_user)
|
|
||||||
if retest:
|
|
||||||
db.refresh(test)
|
|
||||||
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
|
|
||||||
Provides helpers for generating notifications automatically when test
|
Provides helpers for generating notifications automatically when test
|
||||||
state changes occur, plus CRUD for the notifications API.
|
state changes occur, plus CRUD for the notifications API.
|
||||||
|
|
||||||
|
Functions in this module stage changes via ``db.add()`` / ``db.flush()``
|
||||||
|
but do **not** commit. The caller is responsible for committing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
@@ -38,8 +41,7 @@ def create_notification(
|
|||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
)
|
)
|
||||||
db.add(notif)
|
db.add(notif)
|
||||||
db.commit()
|
db.flush()
|
||||||
db.refresh(notif)
|
|
||||||
return notif
|
return notif
|
||||||
|
|
||||||
|
|
||||||
@@ -53,7 +55,6 @@ def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) ->
|
|||||||
if notif is None:
|
if notif is None:
|
||||||
return False
|
return False
|
||||||
notif.read = True
|
notif.read = True
|
||||||
db.commit()
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@@ -64,7 +65,6 @@ def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
|
|||||||
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
|
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
|
||||||
.update({"read": True})
|
.update({"read": True})
|
||||||
)
|
)
|
||||||
db.commit()
|
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
@@ -88,7 +88,6 @@ def cleanup_old_notifications(db: Session, days: int = 90) -> int:
|
|||||||
)
|
)
|
||||||
.delete()
|
.delete()
|
||||||
)
|
)
|
||||||
db.commit()
|
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ based on the state and result of its associated tests.
|
|||||||
V2 rules account for dual Red/Blue validation and use
|
V2 rules account for dual Red/Blue validation and use
|
||||||
``detection_result`` (filled by Blue Team) instead of the legacy
|
``detection_result`` (filled by Blue Team) instead of the legacy
|
||||||
``result`` field.
|
``result`` field.
|
||||||
|
|
||||||
|
This function mutates the technique but does **not** commit.
|
||||||
|
The caller is responsible for committing the session.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -42,5 +45,3 @@ def recalculate_technique_status(db: Session, technique: Technique) -> None:
|
|||||||
technique.status_global = TechniqueStatus.partial
|
technique.status_global = TechniqueStatus.partial
|
||||||
else:
|
else:
|
||||||
technique.status_global = TechniqueStatus.in_progress
|
technique.status_global = TechniqueStatus.in_progress
|
||||||
|
|
||||||
db.commit()
|
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ for each step in the test lifecycle:
|
|||||||
↓
|
↓
|
||||||
rejected → draft
|
rejected → draft
|
||||||
|
|
||||||
Every public function validates the transition, mutates the test, writes an
|
Every public function validates the transition, mutates the test, and writes
|
||||||
audit-log entry, and commits the session.
|
an audit-log entry. The caller (router) is responsible for committing the
|
||||||
|
session via the Unit of Work pattern.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -122,7 +123,6 @@ def start_execution(db: Session, test: Test, user: User) -> Test:
|
|||||||
)
|
)
|
||||||
test.execution_date = now
|
test.execution_date = now
|
||||||
test.red_started_at = now
|
test.red_started_at = now
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -161,7 +161,6 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
|
|||||||
# Start Blue Team timer
|
# Start Blue Team timer
|
||||||
test.blue_started_at = now
|
test.blue_started_at = now
|
||||||
test.blue_paused_seconds = 0
|
test.blue_paused_seconds = 0
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -196,7 +195,6 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
|
|||||||
description=f"Blue Team evaluation: {test.name}",
|
description=f"Blue Team evaluation: {test.name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -222,7 +220,6 @@ def pause_timer(db: Session, test: Test, user: User) -> Test:
|
|||||||
entity_id=test.id,
|
entity_id=test.id,
|
||||||
details={"state": test.state.value},
|
details={"state": test.state.value},
|
||||||
)
|
)
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -252,7 +249,6 @@ def resume_timer(db: Session, test: Test, user: User) -> Test:
|
|||||||
entity_id=test.id,
|
entity_id=test.id,
|
||||||
details={"paused_seconds": paused_seconds, "state": test.state.value},
|
details={"paused_seconds": paused_seconds, "state": test.state.value},
|
||||||
)
|
)
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -421,14 +417,12 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
|||||||
|
|
||||||
if red_status == "rejected" or blue_status == "rejected":
|
if red_status == "rejected" or blue_status == "rejected":
|
||||||
test.state = TestState.rejected
|
test.state = TestState.rejected
|
||||||
db.commit()
|
|
||||||
try:
|
try:
|
||||||
notify_test_state_change(db, test, "rejected")
|
notify_test_state_change(db, test, "rejected")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True)
|
logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True)
|
||||||
elif red_status == "approved" and blue_status == "approved":
|
elif red_status == "approved" and blue_status == "approved":
|
||||||
test.state = TestState.validated
|
test.state = TestState.validated
|
||||||
db.commit()
|
|
||||||
# Invalidate cached scores — a validation changes org-level numbers
|
# Invalidate cached scores — a validation changes org-level numbers
|
||||||
try:
|
try:
|
||||||
from app.services.score_cache import invalidate
|
from app.services.score_cache import invalidate
|
||||||
@@ -439,10 +433,6 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
|||||||
notify_test_state_change(db, test, "validated")
|
notify_test_state_change(db, test, "validated")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True)
|
logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True)
|
||||||
else:
|
|
||||||
# One side hasn't voted yet — stay in_review, just flush
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -533,8 +523,7 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test |
|
|||||||
entity_id=retest.id,
|
entity_id=retest.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.commit()
|
db.flush()
|
||||||
db.refresh(retest)
|
|
||||||
return retest
|
return retest
|
||||||
|
|
||||||
|
|
||||||
@@ -599,5 +588,4 @@ def reopen_test(db: Session, test: Test, user: User) -> Test:
|
|||||||
test.red_paused_seconds = 0
|
test.red_paused_seconds = 0
|
||||||
test.blue_paused_seconds = 0
|
test.blue_paused_seconds = 0
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|||||||
Reference in New Issue
Block a user