refactor: remove db.commit() from audit_service.log_action, all callers use UoW

This commit is contained in:
2026-02-20 15:33:23 +01:00
parent 0c526c48f9
commit a9255e15ce
19 changed files with 345 additions and 337 deletions

View File

@@ -12,8 +12,6 @@ Services should **never** call ``db.commit()``; they use ``db.add()`` /
``db.flush()`` to stage work and let the caller decide when to commit.
**Documented exceptions** (services that may commit internally):
- ``audit_service.log_action`` — called from 15+ routers; commits to ensure
audit records persist even when callers do not.
- Import services (atomic_import, sigma_import, etc.) — self-contained sync ops.
- Background jobs (campaign_scheduler, intel_service, stale_detection,
mitre_sync) — self-contained operations.

View File

@@ -30,8 +30,9 @@ from app.services.campaign_crud_service import (
serialize_campaign,
update_campaign as crud_update,
)
from app.services.notification_service import notify_role
from app.domain.unit_of_work import UnitOfWork
from app.services.audit_service import log_action
from app.services.notification_service import notify_role
logger = logging.getLogger(__name__)
@@ -108,6 +109,7 @@ def create_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Create a new campaign."""
with UnitOfWork(db) as uow:
result = crud_create(
db,
creator_id=current_user.id,
@@ -119,7 +121,6 @@ def create_campaign(
tags=payload.tags,
scheduled_at=payload.scheduled_at,
)
log_action(
db,
user_id=current_user.id,
@@ -128,7 +129,7 @@ def create_campaign(
entity_id=result["id"],
details={"name": payload.name, "type": payload.type},
)
db.commit()
uow.commit()
return result
@@ -160,6 +161,7 @@ def update_campaign(
):
"""Update a campaign. Only allowed in draft or active state."""
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
result = crud_update(
db,
campaign_id,
@@ -167,7 +169,6 @@ def update_campaign(
updater_role=current_user.role,
**update_data,
)
log_action(
db,
user_id=current_user.id,
@@ -176,7 +177,7 @@ def update_campaign(
entity_id=campaign_id,
details={"updated_fields": list(update_data.keys())},
)
db.commit()
uow.commit()
return result
@@ -193,6 +194,7 @@ def add_test_to_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Add a test to a campaign with optional ordering and dependency."""
with UnitOfWork(db) as uow:
result = crud_add_test(
db,
campaign_id,
@@ -201,7 +203,7 @@ def add_test_to_campaign(
depends_on=payload.depends_on,
phase=payload.phase,
)
db.commit()
uow.commit()
return result
@@ -217,8 +219,9 @@ def remove_test_from_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Remove a test from a campaign."""
with UnitOfWork(db) as uow:
crud_remove_test(db, campaign_id, campaign_test_id)
db.commit()
uow.commit()
return {"detail": "Test removed from campaign"}
@@ -233,10 +236,8 @@ def activate_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Activate a campaign, moving it from draft to active."""
with UnitOfWork(db) as uow:
campaign = crud_activate(db, campaign_id)
db.commit()
db.refresh(campaign)
notify_role(
db,
role="red_tech",
@@ -246,7 +247,6 @@ def activate_campaign(
entity_type="campaign",
entity_id=campaign.id,
)
log_action(
db,
user_id=current_user.id,
@@ -255,6 +255,8 @@ def activate_campaign(
entity_id=campaign.id,
details={"name": campaign.name},
)
uow.commit()
db.refresh(campaign)
return serialize_campaign(db, campaign)
@@ -270,10 +272,8 @@ def complete_campaign(
current_user: User = Depends(require_any_role("red_lead", "admin")),
):
"""Mark a campaign as completed."""
with UnitOfWork(db) as uow:
campaign = crud_complete(db, campaign_id)
db.commit()
db.refresh(campaign)
log_action(
db,
user_id=current_user.id,
@@ -282,6 +282,8 @@ def complete_campaign(
entity_id=campaign.id,
details={"name": campaign.name},
)
uow.commit()
db.refresh(campaign)
return serialize_campaign(db, campaign)
@@ -321,6 +323,7 @@ def generate_campaign_from_actor(
current_user,
)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
@@ -329,6 +332,7 @@ def generate_campaign_from_actor(
entity_id=campaign.id,
details={"actor_id": actor_id, "campaign_name": campaign.name},
)
uow.commit()
return serialize_campaign(db, campaign)
@@ -348,6 +352,7 @@ def schedule_campaign(
Only the campaign creator or admin can change scheduling.
"""
with UnitOfWork(db) as uow:
campaign = crud_schedule(
db,
campaign_id,
@@ -357,9 +362,6 @@ def schedule_campaign(
recurrence_pattern=payload.recurrence_pattern,
next_run_at=payload.next_run_at,
)
db.commit()
db.refresh(campaign)
log_action(
db,
user_id=current_user.id,
@@ -372,6 +374,8 @@ def schedule_campaign(
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
},
)
uow.commit()
db.refresh(campaign)
return serialize_campaign(db, campaign)

View File

@@ -67,10 +67,8 @@ def update_data_source(
**Requires** the ``admin`` role.
"""
update_data = body.model_dump(exclude_unset=True)
update_source(db, source_id, **update_data)
with UnitOfWork(db) as uow:
uow.commit()
update_source(db, source_id, **update_data)
log_action(
db,
user_id=current_user.id,
@@ -79,6 +77,7 @@ def update_data_source(
entity_id=source_id,
details={"updates": update_data},
)
uow.commit()
return {"message": "Data source updated", "id": source_id}
@@ -107,6 +106,7 @@ def sync_all_data_sources(
"""
results = sync_all_sources(db)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
@@ -115,6 +115,7 @@ def sync_all_data_sources(
entity_id=None,
details={"results": results},
)
uow.commit()
return {"message": "Sync all complete", "results": results}

View File

@@ -28,6 +28,7 @@ from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.domain.unit_of_work import UnitOfWork
from app.dependencies.auth import get_current_user
from app.models.enums import TeamSide
from app.models.evidence import Evidence
@@ -107,7 +108,8 @@ async def upload_evidence(
# 5. Upload to MinIO
upload_file(content, key)
# 6. Persist metadata
# 6. Persist metadata and audit
with UnitOfWork(db) as uow:
evidence = Evidence(
test_id=test_id,
file_name=safe_name,
@@ -118,10 +120,7 @@ async def upload_evidence(
notes=notes,
)
db.add(evidence)
db.commit()
db.refresh(evidence)
# 7. Audit
db.flush() # Get evidence.id for audit
log_action(
db,
user_id=current_user.id,
@@ -135,6 +134,8 @@ async def upload_evidence(
"team": team.value,
},
)
uow.commit()
db.refresh(evidence)
return _evidence_to_out(evidence)
@@ -195,7 +196,7 @@ def delete_evidence(
test = get_test_or_raise(db, evidence.test_id)
validate_delete_permission(test, evidence, current_user.role, current_user.id)
# Audit before deletion
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
@@ -208,8 +209,7 @@ def delete_evidence(
"team": evidence.team.value if evidence.team else None,
},
)
db.delete(evidence)
db.commit()
uow.commit()
return {"detail": "Evidence deleted"}

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.domain.unit_of_work import UnitOfWork
from app.models.jira_link import JiraLinkEntityType
from app.models.user import User
from app.schemas.jira_schema import (
@@ -40,6 +41,7 @@ def create_link(
user: User = Depends(get_current_user),
):
"""Associate an Aegis entity with a Jira issue."""
with UnitOfWork(db) as uow:
link = jira_service.create_link(
db,
entity_type=body.entity_type,
@@ -48,9 +50,6 @@ def create_link(
sync_direction=body.sync_direction,
created_by=user.id,
)
db.commit()
db.refresh(link)
audit_service.log_action(
db,
user_id=user.id,
@@ -63,6 +62,9 @@ def create_link(
"jira_issue_key": body.jira_issue_key,
},
)
uow.commit()
db.refresh(link)
return link
@@ -88,9 +90,10 @@ def sync_link(
user: User = Depends(require_role("admin")),
):
"""Force bidirectional sync for a specific Jira link."""
with UnitOfWork(db) as uow:
link = jira_service.get_link_or_raise(db, link_id)
jira_service.sync_jira_to_aegis(db, link)
db.commit()
uow.commit()
return {"message": "Sync completed", "jira_status": link.jira_status}
@@ -101,8 +104,8 @@ def delete_link(
user: User = Depends(get_current_user),
):
"""Remove a Jira link."""
with UnitOfWork(db) as uow:
link = jira_service.delete_link(db, link_id)
db.commit()
audit_service.log_action(
db,
user_id=user.id,
@@ -111,6 +114,7 @@ def delete_link(
entity_id=str(link_id),
details={"jira_issue_key": link.jira_issue_key},
)
uow.commit()
@router.post("/create-issue")
@@ -121,11 +125,12 @@ def create_issue_from_entity(
user: User = Depends(get_current_user),
):
"""Auto-create a Jira issue from an Aegis entity and link them."""
with UnitOfWork(db) as uow:
result = jira_service.create_issue_and_link(
db,
entity_type=entity_type,
entity_id=entity_id,
created_by=user.id,
)
db.commit()
uow.commit()
return result

View File

@@ -68,6 +68,7 @@ def create_snapshot_endpoint(
"""Create a manual coverage snapshot with an optional name."""
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
@@ -76,6 +77,7 @@ def create_snapshot_endpoint(
entity_id=snapshot.id,
details={"name": snapshot.name, "score": snapshot.organization_score},
)
uow.commit()
return serialize_snapshot_summary(snapshot)
@@ -128,6 +130,7 @@ def delete_snapshot_endpoint(
"""Delete a snapshot (admin only)."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
@@ -136,8 +139,6 @@ def delete_snapshot_endpoint(
entity_id=snapshot.id,
details={"name": snapshot.name},
)
with UnitOfWork(db) as uow:
delete_snapshot(db, snapshot_id)
uow.commit()

View File

@@ -99,8 +99,6 @@ def create_technique(
with UnitOfWork(db) as uow:
saved = repo.save(entity)
uow.commit()
log_action(
db,
user_id=current_user.id,
@@ -109,6 +107,7 @@ def create_technique(
entity_id=saved.id,
details={"mitre_id": saved.mitre_id, "name": saved.name},
)
uow.commit()
return saved
@@ -137,8 +136,6 @@ def update_technique(
with UnitOfWork(db) as uow:
saved = repo.save(entity)
uow.commit()
log_action(
db,
user_id=current_user.id,
@@ -147,6 +144,7 @@ def update_technique(
entity_id=saved.id,
details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())},
)
uow.commit()
return saved
@@ -176,8 +174,6 @@ def review_technique(
with UnitOfWork(db) as uow:
saved = repo.save(entity)
uow.commit()
log_action(
db,
user_id=current_user.id,
@@ -186,5 +182,6 @@ def review_technique(
entity_id=saved.id,
details={"mitre_id": mitre_id},
)
uow.commit()
return saved

View File

@@ -128,9 +128,6 @@ def create_test(
creator_id=current_user.id,
**payload.model_dump(exclude={"technique_id"}),
)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
@@ -139,6 +136,8 @@ def create_test(
entity_id=test.id,
details={"name": test.name, "technique_id": str(test.technique_id)},
)
uow.commit()
db.refresh(test)
return test
@@ -169,9 +168,6 @@ def create_test_from_template(
technique_id_or_mitre=payload.technique_id,
creator_id=current_user.id,
)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
@@ -184,6 +180,8 @@ def create_test_from_template(
"technique_id": str(test.technique_id),
},
)
uow.commit()
db.refresh(test)
return test
@@ -229,9 +227,6 @@ def update_test(
updater_role=current_user.role,
**update_data,
)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
@@ -240,6 +235,8 @@ def update_test(
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
uow.commit()
db.refresh(test)
return test
@@ -260,9 +257,6 @@ def update_test_red(
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
test = crud_update_test_red(db, test_id, **update_data)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
@@ -271,6 +265,8 @@ def update_test_red(
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
uow.commit()
db.refresh(test)
return test
@@ -291,9 +287,6 @@ def update_test_blue(
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
test = crud_update_test_blue(db, test_id, **update_data)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
@@ -302,6 +295,8 @@ def update_test_blue(
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
uow.commit()
db.refresh(test)
return test

View File

@@ -47,6 +47,7 @@ def create_user_route(
current_user: User = Depends(require_role("admin")),
):
"""Create a new user. **Requires admin role.**"""
with UnitOfWork(db) as uow:
user = create_user(
db,
username=payload.username,
@@ -54,10 +55,6 @@ def create_user_route(
password=payload.password,
role=payload.role,
)
with UnitOfWork(db) as uow:
uow.commit()
db.refresh(user)
log_action(
db,
user_id=current_user.id,
@@ -66,6 +63,8 @@ def create_user_route(
entity_id=user.id,
details={"username": user.username, "role": user.role},
)
uow.commit()
db.refresh(user)
return user
@@ -99,11 +98,8 @@ def update_user_route(
):
"""Update one or more fields of an existing user. **Requires admin role.**"""
update_data = payload.model_dump(exclude_unset=True)
user = update_user(db, user_id, **update_data)
with UnitOfWork(db) as uow:
uow.commit()
db.refresh(user)
user = update_user(db, user_id, **update_data)
log_action(
db,
user_id=current_user.id,
@@ -112,5 +108,7 @@ def update_user_route(
entity_id=user.id,
details={"updated_fields": list(update_data.keys())},
)
uow.commit()
db.refresh(user)
return user

View File

@@ -267,5 +267,6 @@ def import_atomic_red_team(db: Session) -> dict:
entity_id=None,
details=summary,
)
db.commit()
return summary

View File

@@ -30,4 +30,3 @@ def log_action(
details=details,
)
db.add(log)
db.commit()

View File

@@ -278,4 +278,5 @@ def sync(db: Session) -> dict:
logger.info("CALDERA import complete — %s", summary)
log_action(db, user_id=None, action="import_caldera",
entity_type="test_template", entity_id=None, details=summary)
db.commit()
return summary

View File

@@ -157,6 +157,7 @@ def check_and_run_recurring_campaigns(db: Session) -> int:
"pattern": campaign.recurrence_pattern,
},
)
db.commit()
# Notify
if campaign.created_by:

View File

@@ -358,4 +358,5 @@ def sync(db: Session) -> dict:
logger.info("Elastic import complete — %s", summary)
log_action(db, user_id=None, action="import_elastic_rules",
entity_type="detection_rule", entity_id=None, details=summary)
db.commit()
return summary

View File

@@ -250,5 +250,6 @@ def scan_intel(db: Session) -> dict:
entity_id=None,
details=summary,
)
db.commit()
return summary

View File

@@ -352,6 +352,7 @@ def sync(db: Session) -> dict:
logger.info("LOLBAS import complete — %s", summary)
log_action(db, user_id=None, action="import_lolbas",
entity_type="test_template", entity_id=None, details=summary)
db.commit()
return summary
@@ -381,4 +382,5 @@ def sync_gtfobins(db: Session) -> dict:
logger.info("GTFOBins import complete — %s", summary)
log_action(db, user_id=None, action="import_gtfobins",
entity_type="test_template", entity_id=None, details=summary)
db.commit()
return summary

View File

@@ -243,5 +243,6 @@ def sync_mitre(db: Session) -> dict:
entity_id=None,
details=summary,
)
db.commit()
return summary

View File

@@ -344,5 +344,6 @@ def sync(db: Session) -> dict:
entity_id=None,
details=summary,
)
db.commit()
return summary

View File

@@ -369,5 +369,6 @@ def sync(db: Session) -> dict:
entity_id=None,
details=summary,
)
db.commit()
return summary