refactor(tests): extract CRUD/query logic to test_crud_service, router delegates to service with domain exceptions

This commit is contained in:
2026-02-19 18:35:09 +01:00
parent 4e3787d091
commit 20738d11b3
4 changed files with 349 additions and 209 deletions

View File

@@ -22,15 +22,11 @@ import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.audit import AuditLog
from app.models.enums import TestState, TeamSide
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.enums import TestState
from app.models.user import User
from app.schemas.test import (
TestCreate,
@@ -46,6 +42,18 @@ from app.schemas.test_template import TestTemplateInstantiate
from app.domain.unit_of_work import UnitOfWork
from app.services.audit_service import log_action
from app.services.status_service import recalculate_technique_status
from app.services.test_crud_service import (
create_test as crud_create_test,
create_test_from_template as crud_create_from_template,
get_test_detail as crud_get_test_detail,
get_test_or_raise as crud_get_test_or_raise,
get_test_timeline as crud_get_test_timeline,
get_test_with_technique as crud_get_test_with_technique,
list_tests as crud_list_tests,
update_test as crud_update_test,
update_test_blue as crud_update_test_blue,
update_test_red as crud_update_test_red,
)
from app.services.test_workflow_service import (
start_execution as wf_start_execution,
submit_red_evidence as wf_submit_red,
@@ -62,29 +70,6 @@ from app.services.test_workflow_service import (
router = APIRouter(prefix="/tests", tags=["tests"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_test_or_404(db: Session, test_id: uuid.UUID) -> Test:
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return test
def _get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
test = (
db.query(Test)
.options(joinedload(Test.technique))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return test
# ---------------------------------------------------------------------------
# GET /tests — list with filters
# ---------------------------------------------------------------------------
@@ -105,30 +90,16 @@ def list_tests(
current_user: User = Depends(get_current_user),
):
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
query = db.query(Test).options(joinedload(Test.technique))
if state:
query = query.filter(Test.state == state)
if technique_id:
query = query.filter(Test.technique_id == technique_id)
if platform:
from app.utils import escape_like
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
if created_by:
query = query.filter(Test.created_by == created_by)
if pending_validation_side == "red":
query = query.filter(
Test.state == TestState.in_review,
Test.red_validation_status.in_(["pending", None]),
)
elif pending_validation_side == "blue":
query = query.filter(
Test.state == TestState.in_review,
Test.blue_validation_status.in_(["pending", None]),
)
tests = query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
return tests
return crud_list_tests(
db,
state=state,
technique_id=technique_id,
platform=platform,
created_by=created_by,
pending_validation_side=pending_validation_side,
offset=offset,
limit=limit,
)
# ---------------------------------------------------------------------------
@@ -150,20 +121,14 @@ def create_test(
``created_by`` is set automatically and ``state`` defaults to *draft*.
"""
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique with id '{payload.technique_id}' not found",
with UnitOfWork(db) as uow:
test = crud_create_test(
db,
technique_id=payload.technique_id,
creator_id=current_user.id,
**payload.model_dump(exclude={"technique_id"}),
)
test = Test(
**payload.model_dump(),
created_by=current_user.id,
state=TestState.draft,
)
db.add(test)
db.commit()
uow.commit()
db.refresh(test)
log_action(
@@ -197,43 +162,14 @@ def create_test_from_template(
The template's fields are copied into the new test as starting data.
"""
template = db.query(TestTemplate).filter(TestTemplate.id == payload.template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"TestTemplate with id '{payload.template_id}' not found",
with UnitOfWork(db) as uow:
test = crud_create_from_template(
db,
template_id=payload.template_id,
technique_id_or_mitre=payload.technique_id,
creator_id=current_user.id,
)
# Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001")
technique = None
try:
technique_uuid = uuid.UUID(payload.technique_id)
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
except ValueError:
pass
if technique is None:
technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first()
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique '{payload.technique_id}' not found",
)
test = Test(
technique_id=technique.id,
name=template.name,
description=template.description,
platform=template.platform,
procedure_text=template.attack_procedure,
tool_used=template.tool_suggested,
remediation_steps=template.suggested_remediation,
created_by=current_user.id,
state=TestState.draft,
)
db.add(test)
db.commit()
uow.commit()
db.refresh(test)
log_action(
@@ -244,7 +180,7 @@ def create_test_from_template(
entity_id=test.id,
details={
"name": test.name,
"template_id": str(template.id),
"template_id": str(payload.template_id),
"technique_id": str(test.technique_id),
},
)
@@ -264,20 +200,7 @@ def get_test(
current_user: User = Depends(get_current_user),
):
"""Return full details for a single test, including its evidences."""
test = (
db.query(Test)
.options(joinedload(Test.evidences))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test not found",
)
return test
return crud_get_test_detail(db, test_id)
# ---------------------------------------------------------------------------
@@ -297,29 +220,16 @@ def update_test(
Only leads or admins can update general test fields.
The test must be in ``draft`` or ``rejected`` state.
"""
test = _get_test_or_404(db, test_id)
if current_user.role != "admin" and test.created_by != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={"message": "Only the test creator or an admin can update this test", "code": "FORBIDDEN"},
)
if test.state not in (TestState.draft, TestState.rejected):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
with UnitOfWork(db) as uow:
test = crud_update_test(
db,
test_id,
updater_id=current_user.id,
updater_role=current_user.role,
**update_data,
)
uow.commit()
db.refresh(test)
log_action(
@@ -347,23 +257,10 @@ def update_test_red(
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
test = _get_test_or_404(db, test_id)
if test.state not in (TestState.draft, TestState.red_executing):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
with UnitOfWork(db) as uow:
test = crud_update_test_red(db, test_id, **update_data)
uow.commit()
db.refresh(test)
log_action(
@@ -391,23 +288,10 @@ def update_test_blue(
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
test = _get_test_or_404(db, test_id)
if test.state != TestState.blue_evaluating:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
with UnitOfWork(db) as uow:
test = crud_update_test_blue(db, test_id, **update_data)
uow.commit()
db.refresh(test)
log_action(
@@ -434,7 +318,7 @@ def start_execution(
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
"""Move a test from ``draft`` to ``red_executing``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_start_execution(db, test, current_user)
uow.commit()
@@ -454,7 +338,7 @@ def submit_red(
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_submit_red(db, test, current_user)
uow.commit()
@@ -474,7 +358,7 @@ def submit_blue(
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_submit_blue(db, test, current_user)
uow.commit()
@@ -494,7 +378,7 @@ def pause_timer(
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_pause_timer(db, test, current_user)
uow.commit()
@@ -514,7 +398,7 @@ def resume_timer(
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
"""Resume the paused timer for the current phase."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_resume_timer(db, test, current_user)
uow.commit()
@@ -535,7 +419,7 @@ def validate_red(
current_user: User = Depends(require_any_role("red_lead")),
):
"""Red Lead approves or rejects the red side of a test."""
test = _get_test_with_technique(db, test_id)
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
test = wf_validate_red(
db, test, current_user,
@@ -562,7 +446,7 @@ def validate_blue(
current_user: User = Depends(require_any_role("blue_lead")),
):
"""Blue Lead approves or rejects the blue side of a test."""
test = _get_test_with_technique(db, test_id)
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
test = wf_validate_blue(
db, test, current_user,
@@ -588,7 +472,7 @@ def reopen(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Reopen a rejected test, moving it back to ``draft``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_reopen(db, test, current_user)
uow.commit()
@@ -613,7 +497,7 @@ def update_remediation(
When ``remediation_status`` transitions to ``'completed'``, an automatic
re-test is created (subject to ``MAX_RETEST_COUNT``).
"""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
old_remediation_status = test.remediation_status
@@ -653,29 +537,7 @@ def get_test_timeline(
current_user: User = Depends(get_current_user),
):
"""Return the chronological audit-log history for a test."""
# Verify the test exists
_get_test_or_404(db, test_id)
logs = (
db.query(AuditLog)
.filter(
AuditLog.entity_type == "test",
AuditLog.entity_id == str(test_id),
)
.order_by(AuditLog.timestamp.asc())
.all()
)
return [
{
"id": str(log.id),
"action": log.action,
"user_id": str(log.user_id) if log.user_id else None,
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
"details": log.details,
}
for log in logs
]
return crud_get_test_timeline(db, test_id)
# ---------------------------------------------------------------------------