refactor(core): introduce Unit of Work and remove commits from services
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
- Add UnitOfWork context manager in domain/unit_of_work.py with commit/rollback/flush API and auto-rollback on exception - Remove all db.commit() from test_workflow_service (8 calls), notification_service (4 calls), status_service (1 call) - Services now only stage changes via db.add/db.flush; caller owns the transaction boundary - Update routers/tests.py: wrap 9 workflow endpoints in UnitOfWork context managers - Update routers/notifications.py: wrap mark_as_read and mark_all_as_read in UnitOfWork
This commit is contained in:
@@ -2,6 +2,9 @@
|
||||
|
||||
Provides helpers for generating notifications automatically when test
|
||||
state changes occur, plus CRUD for the notifications API.
|
||||
|
||||
Functions in this module stage changes via ``db.add()`` / ``db.flush()``
|
||||
but do **not** commit. The caller is responsible for committing.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
@@ -38,8 +41,7 @@ def create_notification(
|
||||
entity_id=entity_id,
|
||||
)
|
||||
db.add(notif)
|
||||
db.commit()
|
||||
db.refresh(notif)
|
||||
db.flush()
|
||||
return notif
|
||||
|
||||
|
||||
@@ -53,7 +55,6 @@ def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) ->
|
||||
if notif is None:
|
||||
return False
|
||||
notif.read = True
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
@@ -64,7 +65,6 @@ def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
|
||||
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
|
||||
.update({"read": True})
|
||||
)
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
|
||||
@@ -88,7 +88,6 @@ def cleanup_old_notifications(db: Session, days: int = 90) -> int:
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ based on the state and result of its associated tests.
|
||||
V2 rules account for dual Red/Blue validation and use
|
||||
``detection_result`` (filled by Blue Team) instead of the legacy
|
||||
``result`` field.
|
||||
|
||||
This function mutates the technique but does **not** commit.
|
||||
The caller is responsible for committing the session.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -42,5 +45,3 @@ def recalculate_technique_status(db: Session, technique: Technique) -> None:
|
||||
technique.status_global = TechniqueStatus.partial
|
||||
else:
|
||||
technique.status_global = TechniqueStatus.in_progress
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -7,8 +7,9 @@ for each step in the test lifecycle:
|
||||
↓
|
||||
rejected → draft
|
||||
|
||||
Every public function validates the transition, mutates the test, writes an
|
||||
audit-log entry, and commits the session.
|
||||
Every public function validates the transition, mutates the test, and writes
|
||||
an audit-log entry. The caller (router) is responsible for committing the
|
||||
session via the Unit of Work pattern.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -122,7 +123,6 @@ def start_execution(db: Session, test: Test, user: User) -> Test:
|
||||
)
|
||||
test.execution_date = now
|
||||
test.red_started_at = now
|
||||
db.commit()
|
||||
return test
|
||||
|
||||
|
||||
@@ -161,7 +161,6 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
|
||||
# Start Blue Team timer
|
||||
test.blue_started_at = now
|
||||
test.blue_paused_seconds = 0
|
||||
db.commit()
|
||||
return test
|
||||
|
||||
|
||||
@@ -196,7 +195,6 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
|
||||
description=f"Blue Team evaluation: {test.name}",
|
||||
)
|
||||
|
||||
db.commit()
|
||||
return test
|
||||
|
||||
|
||||
@@ -222,7 +220,6 @@ def pause_timer(db: Session, test: Test, user: User) -> Test:
|
||||
entity_id=test.id,
|
||||
details={"state": test.state.value},
|
||||
)
|
||||
db.commit()
|
||||
return test
|
||||
|
||||
|
||||
@@ -252,7 +249,6 @@ def resume_timer(db: Session, test: Test, user: User) -> Test:
|
||||
entity_id=test.id,
|
||||
details={"paused_seconds": paused_seconds, "state": test.state.value},
|
||||
)
|
||||
db.commit()
|
||||
return test
|
||||
|
||||
|
||||
@@ -421,14 +417,12 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
||||
|
||||
if red_status == "rejected" or blue_status == "rejected":
|
||||
test.state = TestState.rejected
|
||||
db.commit()
|
||||
try:
|
||||
notify_test_state_change(db, test, "rejected")
|
||||
except Exception as e:
|
||||
logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True)
|
||||
elif red_status == "approved" and blue_status == "approved":
|
||||
test.state = TestState.validated
|
||||
db.commit()
|
||||
# Invalidate cached scores — a validation changes org-level numbers
|
||||
try:
|
||||
from app.services.score_cache import invalidate
|
||||
@@ -439,10 +433,6 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
||||
notify_test_state_change(db, test, "validated")
|
||||
except Exception as e:
|
||||
logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True)
|
||||
else:
|
||||
# One side hasn't voted yet — stay in_review, just flush
|
||||
db.commit()
|
||||
|
||||
return test
|
||||
|
||||
|
||||
@@ -533,8 +523,7 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test |
|
||||
entity_id=retest.id,
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(retest)
|
||||
db.flush()
|
||||
return retest
|
||||
|
||||
|
||||
@@ -599,5 +588,4 @@ def reopen_test(db: Session, test: Test, user: User) -> Test:
|
||||
test.red_paused_seconds = 0
|
||||
test.blue_paused_seconds = 0
|
||||
|
||||
db.commit()
|
||||
return test
|
||||
|
||||
Reference in New Issue
Block a user