refactor: remove db.commit() from audit_service.log_action, all callers use UoW

This commit is contained in:
2026-02-20 15:33:23 +01:00
parent 0c526c48f9
commit a9255e15ce
19 changed files with 345 additions and 337 deletions

View File

@@ -30,8 +30,9 @@ from app.services.campaign_crud_service import (
serialize_campaign,
update_campaign as crud_update,
)
from app.services.notification_service import notify_role
from app.domain.unit_of_work import UnitOfWork
from app.services.audit_service import log_action
from app.services.notification_service import notify_role
logger = logging.getLogger(__name__)
@@ -108,27 +109,27 @@ def create_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Create a new campaign."""
result = crud_create(
db,
creator_id=current_user.id,
name=payload.name,
description=payload.description,
type=payload.type,
threat_actor_id=payload.threat_actor_id,
target_platform=payload.target_platform,
tags=payload.tags,
scheduled_at=payload.scheduled_at,
)
log_action(
db,
user_id=current_user.id,
action="create_campaign",
entity_type="campaign",
entity_id=result["id"],
details={"name": payload.name, "type": payload.type},
)
db.commit()
with UnitOfWork(db) as uow:
result = crud_create(
db,
creator_id=current_user.id,
name=payload.name,
description=payload.description,
type=payload.type,
threat_actor_id=payload.threat_actor_id,
target_platform=payload.target_platform,
tags=payload.tags,
scheduled_at=payload.scheduled_at,
)
log_action(
db,
user_id=current_user.id,
action="create_campaign",
entity_type="campaign",
entity_id=result["id"],
details={"name": payload.name, "type": payload.type},
)
uow.commit()
return result
@@ -160,23 +161,23 @@ def update_campaign(
):
"""Update a campaign. Only allowed in draft or active state."""
update_data = payload.model_dump(exclude_unset=True)
result = crud_update(
db,
campaign_id,
updater_id=current_user.id,
updater_role=current_user.role,
**update_data,
)
log_action(
db,
user_id=current_user.id,
action="update_campaign",
entity_type="campaign",
entity_id=campaign_id,
details={"updated_fields": list(update_data.keys())},
)
db.commit()
with UnitOfWork(db) as uow:
result = crud_update(
db,
campaign_id,
updater_id=current_user.id,
updater_role=current_user.role,
**update_data,
)
log_action(
db,
user_id=current_user.id,
action="update_campaign",
entity_type="campaign",
entity_id=campaign_id,
details={"updated_fields": list(update_data.keys())},
)
uow.commit()
return result
@@ -193,15 +194,16 @@ def add_test_to_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Add a test to a campaign with optional ordering and dependency."""
result = crud_add_test(
db,
campaign_id,
test_id=payload.test_id,
order_index=payload.order_index,
depends_on=payload.depends_on,
phase=payload.phase,
)
db.commit()
with UnitOfWork(db) as uow:
result = crud_add_test(
db,
campaign_id,
test_id=payload.test_id,
order_index=payload.order_index,
depends_on=payload.depends_on,
phase=payload.phase,
)
uow.commit()
return result
@@ -217,8 +219,9 @@ def remove_test_from_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Remove a test from a campaign."""
crud_remove_test(db, campaign_id, campaign_test_id)
db.commit()
with UnitOfWork(db) as uow:
crud_remove_test(db, campaign_id, campaign_test_id)
uow.commit()
return {"detail": "Test removed from campaign"}
@@ -233,29 +236,28 @@ def activate_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Activate a campaign, moving it from draft to active."""
campaign = crud_activate(db, campaign_id)
db.commit()
with UnitOfWork(db) as uow:
campaign = crud_activate(db, campaign_id)
notify_role(
db,
role="red_tech",
type="campaign_activated",
title="Campaign activated",
message=f'Campaign "{campaign.name}" has been activated.',
entity_type="campaign",
entity_id=campaign.id,
)
log_action(
db,
user_id=current_user.id,
action="activate_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={"name": campaign.name},
)
uow.commit()
db.refresh(campaign)
notify_role(
db,
role="red_tech",
type="campaign_activated",
title="Campaign activated",
message=f'Campaign "{campaign.name}" has been activated.',
entity_type="campaign",
entity_id=campaign.id,
)
log_action(
db,
user_id=current_user.id,
action="activate_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={"name": campaign.name},
)
return serialize_campaign(db, campaign)
@@ -270,19 +272,19 @@ def complete_campaign(
current_user: User = Depends(require_any_role("red_lead", "admin")),
):
"""Mark a campaign as completed."""
campaign = crud_complete(db, campaign_id)
db.commit()
with UnitOfWork(db) as uow:
campaign = crud_complete(db, campaign_id)
log_action(
db,
user_id=current_user.id,
action="complete_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={"name": campaign.name},
)
uow.commit()
db.refresh(campaign)
log_action(
db,
user_id=current_user.id,
action="complete_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={"name": campaign.name},
)
return serialize_campaign(db, campaign)
@@ -321,14 +323,16 @@ def generate_campaign_from_actor(
current_user,
)
log_action(
db,
user_id=current_user.id,
action="generate_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={"actor_id": actor_id, "campaign_name": campaign.name},
)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="generate_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={"actor_id": actor_id, "campaign_name": campaign.name},
)
uow.commit()
return serialize_campaign(db, campaign)
@@ -348,31 +352,31 @@ def schedule_campaign(
Only the campaign creator or admin can change scheduling.
"""
campaign = crud_schedule(
db,
campaign_id,
owner_id=current_user.id,
owner_role=current_user.role,
is_recurring=payload.is_recurring,
recurrence_pattern=payload.recurrence_pattern,
next_run_at=payload.next_run_at,
)
db.commit()
with UnitOfWork(db) as uow:
campaign = crud_schedule(
db,
campaign_id,
owner_id=current_user.id,
owner_role=current_user.role,
is_recurring=payload.is_recurring,
recurrence_pattern=payload.recurrence_pattern,
next_run_at=payload.next_run_at,
)
log_action(
db,
user_id=current_user.id,
action="schedule_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={
"is_recurring": campaign.is_recurring,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
},
)
uow.commit()
db.refresh(campaign)
log_action(
db,
user_id=current_user.id,
action="schedule_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={
"is_recurring": campaign.is_recurring,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
},
)
return serialize_campaign(db, campaign)

View File

@@ -67,19 +67,18 @@ def update_data_source(
**Requires** the ``admin`` role.
"""
update_data = body.model_dump(exclude_unset=True)
update_source(db, source_id, **update_data)
with UnitOfWork(db) as uow:
update_source(db, source_id, **update_data)
log_action(
db,
user_id=current_user.id,
action="update_data_source",
entity_type="data_source",
entity_id=source_id,
details={"updates": update_data},
)
uow.commit()
log_action(
db,
user_id=current_user.id,
action="update_data_source",
entity_type="data_source",
entity_id=source_id,
details={"updates": update_data},
)
return {"message": "Data source updated", "id": source_id}
@@ -107,14 +106,16 @@ def sync_all_data_sources(
"""
results = sync_all_sources(db)
log_action(
db,
user_id=current_user.id,
action="sync_all_data_sources",
entity_type="data_source",
entity_id=None,
details={"results": results},
)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="sync_all_data_sources",
entity_type="data_source",
entity_id=None,
details={"results": results},
)
uow.commit()
return {"message": "Sync all complete", "results": results}

View File

@@ -28,6 +28,7 @@ from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.domain.unit_of_work import UnitOfWork
from app.dependencies.auth import get_current_user
from app.models.enums import TeamSide
from app.models.evidence import Evidence
@@ -107,35 +108,35 @@ async def upload_evidence(
# 5. Upload to MinIO
upload_file(content, key)
# 6. Persist metadata
evidence = Evidence(
test_id=test_id,
file_name=safe_name,
file_path=key,
sha256_hash=sha256,
uploaded_by=current_user.id,
team=team,
notes=notes,
)
db.add(evidence)
db.commit()
# 6. Persist metadata and audit
with UnitOfWork(db) as uow:
evidence = Evidence(
test_id=test_id,
file_name=safe_name,
file_path=key,
sha256_hash=sha256,
uploaded_by=current_user.id,
team=team,
notes=notes,
)
db.add(evidence)
db.flush() # Get evidence.id for audit
log_action(
db,
user_id=current_user.id,
action="upload_evidence",
entity_type="evidence",
entity_id=evidence.id,
details={
"file_name": safe_name,
"sha256": sha256,
"test_id": str(test_id),
"team": team.value,
},
)
uow.commit()
db.refresh(evidence)
# 7. Audit
log_action(
db,
user_id=current_user.id,
action="upload_evidence",
entity_type="evidence",
entity_id=evidence.id,
details={
"file_name": safe_name,
"sha256": sha256,
"test_id": str(test_id),
"team": team.value,
},
)
return _evidence_to_out(evidence)
@@ -195,21 +196,20 @@ def delete_evidence(
test = get_test_or_raise(db, evidence.test_id)
validate_delete_permission(test, evidence, current_user.role, current_user.id)
# Audit before deletion
log_action(
db,
user_id=current_user.id,
action="delete_evidence",
entity_type="evidence",
entity_id=evidence.id,
details={
"file_name": evidence.file_name,
"test_id": str(evidence.test_id),
"team": evidence.team.value if evidence.team else None,
},
)
db.delete(evidence)
db.commit()
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="delete_evidence",
entity_type="evidence",
entity_id=evidence.id,
details={
"file_name": evidence.file_name,
"test_id": str(evidence.test_id),
"team": evidence.team.value if evidence.team else None,
},
)
db.delete(evidence)
uow.commit()
return {"detail": "Evidence deleted"}

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.domain.unit_of_work import UnitOfWork
from app.models.jira_link import JiraLinkEntityType
from app.models.user import User
from app.schemas.jira_schema import (
@@ -40,29 +41,30 @@ def create_link(
user: User = Depends(get_current_user),
):
"""Associate an Aegis entity with a Jira issue."""
link = jira_service.create_link(
db,
entity_type=body.entity_type,
entity_id=body.entity_id,
jira_issue_key=body.jira_issue_key,
sync_direction=body.sync_direction,
created_by=user.id,
)
db.commit()
with UnitOfWork(db) as uow:
link = jira_service.create_link(
db,
entity_type=body.entity_type,
entity_id=body.entity_id,
jira_issue_key=body.jira_issue_key,
sync_direction=body.sync_direction,
created_by=user.id,
)
audit_service.log_action(
db,
user_id=user.id,
action="jira_link_created",
entity_type="jira_link",
entity_id=str(link.id),
details={
"linked_entity_type": body.entity_type.value,
"linked_entity_id": str(body.entity_id),
"jira_issue_key": body.jira_issue_key,
},
)
uow.commit()
db.refresh(link)
audit_service.log_action(
db,
user_id=user.id,
action="jira_link_created",
entity_type="jira_link",
entity_id=str(link.id),
details={
"linked_entity_type": body.entity_type.value,
"linked_entity_id": str(body.entity_id),
"jira_issue_key": body.jira_issue_key,
},
)
return link
@@ -88,9 +90,10 @@ def sync_link(
user: User = Depends(require_role("admin")),
):
"""Force bidirectional sync for a specific Jira link."""
link = jira_service.get_link_or_raise(db, link_id)
jira_service.sync_jira_to_aegis(db, link)
db.commit()
with UnitOfWork(db) as uow:
link = jira_service.get_link_or_raise(db, link_id)
jira_service.sync_jira_to_aegis(db, link)
uow.commit()
return {"message": "Sync completed", "jira_status": link.jira_status}
@@ -101,16 +104,17 @@ def delete_link(
user: User = Depends(get_current_user),
):
"""Remove a Jira link."""
link = jira_service.delete_link(db, link_id)
db.commit()
audit_service.log_action(
db,
user_id=user.id,
action="jira_link_deleted",
entity_type="jira_link",
entity_id=str(link_id),
details={"jira_issue_key": link.jira_issue_key},
)
with UnitOfWork(db) as uow:
link = jira_service.delete_link(db, link_id)
audit_service.log_action(
db,
user_id=user.id,
action="jira_link_deleted",
entity_type="jira_link",
entity_id=str(link_id),
details={"jira_issue_key": link.jira_issue_key},
)
uow.commit()
@router.post("/create-issue")
@@ -121,11 +125,12 @@ def create_issue_from_entity(
user: User = Depends(get_current_user),
):
"""Auto-create a Jira issue from an Aegis entity and link them."""
result = jira_service.create_issue_and_link(
db,
entity_type=entity_type,
entity_id=entity_id,
created_by=user.id,
)
db.commit()
with UnitOfWork(db) as uow:
result = jira_service.create_issue_and_link(
db,
entity_type=entity_type,
entity_id=entity_id,
created_by=user.id,
)
uow.commit()
return result

View File

@@ -68,14 +68,16 @@ def create_snapshot_endpoint(
"""Create a manual coverage snapshot with an optional name."""
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
log_action(
db,
user_id=current_user.id,
action="create_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name, "score": snapshot.organization_score},
)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="create_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name, "score": snapshot.organization_score},
)
uow.commit()
return serialize_snapshot_summary(snapshot)
@@ -128,16 +130,15 @@ def delete_snapshot_endpoint(
"""Delete a snapshot (admin only)."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
log_action(
db,
user_id=current_user.id,
action="delete_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name},
)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="delete_snapshot",
entity_type="snapshot",
entity_id=snapshot.id,
details={"name": snapshot.name},
)
delete_snapshot(db, snapshot_id)
uow.commit()

View File

@@ -99,17 +99,16 @@ def create_technique(
with UnitOfWork(db) as uow:
saved = repo.save(entity)
log_action(
db,
user_id=current_user.id,
action="create_technique",
entity_type="technique",
entity_id=saved.id,
details={"mitre_id": saved.mitre_id, "name": saved.name},
)
uow.commit()
log_action(
db,
user_id=current_user.id,
action="create_technique",
entity_type="technique",
entity_id=saved.id,
details={"mitre_id": saved.mitre_id, "name": saved.name},
)
return saved
@@ -137,17 +136,16 @@ def update_technique(
with UnitOfWork(db) as uow:
saved = repo.save(entity)
log_action(
db,
user_id=current_user.id,
action="update_technique",
entity_type="technique",
entity_id=saved.id,
details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())},
)
uow.commit()
log_action(
db,
user_id=current_user.id,
action="update_technique",
entity_type="technique",
entity_id=saved.id,
details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())},
)
return saved
@@ -176,15 +174,14 @@ def review_technique(
with UnitOfWork(db) as uow:
saved = repo.save(entity)
log_action(
db,
user_id=current_user.id,
action="review_technique",
entity_type="technique",
entity_id=saved.id,
details={"mitre_id": mitre_id},
)
uow.commit()
log_action(
db,
user_id=current_user.id,
action="review_technique",
entity_type="technique",
entity_id=saved.id,
details={"mitre_id": mitre_id},
)
return saved

View File

@@ -128,18 +128,17 @@ def create_test(
creator_id=current_user.id,
**payload.model_dump(exclude={"technique_id"}),
)
log_action(
db,
user_id=current_user.id,
action="create_test",
entity_type="test",
entity_id=test.id,
details={"name": test.name, "technique_id": str(test.technique_id)},
)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="create_test",
entity_type="test",
entity_id=test.id,
details={"name": test.name, "technique_id": str(test.technique_id)},
)
return test
@@ -169,22 +168,21 @@ def create_test_from_template(
technique_id_or_mitre=payload.technique_id,
creator_id=current_user.id,
)
log_action(
db,
user_id=current_user.id,
action="create_test_from_template",
entity_type="test",
entity_id=test.id,
details={
"name": test.name,
"template_id": str(payload.template_id),
"technique_id": str(test.technique_id),
},
)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="create_test_from_template",
entity_type="test",
entity_id=test.id,
details={
"name": test.name,
"template_id": str(payload.template_id),
"technique_id": str(test.technique_id),
},
)
return test
@@ -229,18 +227,17 @@ def update_test(
updater_role=current_user.role,
**update_data,
)
log_action(
db,
user_id=current_user.id,
action="update_test",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="update_test",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
return test
@@ -260,18 +257,17 @@ def update_test_red(
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
test = crud_update_test_red(db, test_id, **update_data)
log_action(
db,
user_id=current_user.id,
action="update_test_red",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="update_test_red",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
return test
@@ -291,18 +287,17 @@ def update_test_blue(
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
test = crud_update_test_blue(db, test_id, **update_data)
log_action(
db,
user_id=current_user.id,
action="update_test_blue",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
uow.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="update_test_blue",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
return test

View File

@@ -47,26 +47,25 @@ def create_user_route(
current_user: User = Depends(require_role("admin")),
):
"""Create a new user. **Requires admin role.**"""
user = create_user(
db,
username=payload.username,
email=payload.email,
password=payload.password,
role=payload.role,
)
with UnitOfWork(db) as uow:
user = create_user(
db,
username=payload.username,
email=payload.email,
password=payload.password,
role=payload.role,
)
log_action(
db,
user_id=current_user.id,
action="create_user",
entity_type="user",
entity_id=user.id,
details={"username": user.username, "role": user.role},
)
uow.commit()
db.refresh(user)
log_action(
db,
user_id=current_user.id,
action="create_user",
entity_type="user",
entity_id=user.id,
details={"username": user.username, "role": user.role},
)
return user
@@ -99,18 +98,17 @@ def update_user_route(
):
"""Update one or more fields of an existing user. **Requires admin role.**"""
update_data = payload.model_dump(exclude_unset=True)
user = update_user(db, user_id, **update_data)
with UnitOfWork(db) as uow:
user = update_user(db, user_id, **update_data)
log_action(
db,
user_id=current_user.id,
action="update_user",
entity_type="user",
entity_id=user.id,
details={"updated_fields": list(update_data.keys())},
)
uow.commit()
db.refresh(user)
log_action(
db,
user_id=current_user.id,
action="update_user",
entity_type="user",
entity_id=user.id,
details={"updated_fields": list(update_data.keys())},
)
return user