refactor(tests): extract CRUD/query logic to test_crud_service, router delegates to service with domain exceptions

This commit is contained in:
2026-02-19 18:35:09 +01:00
parent 4e3787d091
commit 20738d11b3
4 changed files with 349 additions and 209 deletions

View File

@@ -22,15 +22,11 @@ import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.audit import AuditLog
from app.models.enums import TestState, TeamSide
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.enums import TestState
from app.models.user import User
from app.schemas.test import (
TestCreate,
@@ -46,6 +42,18 @@ from app.schemas.test_template import TestTemplateInstantiate
from app.domain.unit_of_work import UnitOfWork
from app.services.audit_service import log_action
from app.services.status_service import recalculate_technique_status
from app.services.test_crud_service import (
create_test as crud_create_test,
create_test_from_template as crud_create_from_template,
get_test_detail as crud_get_test_detail,
get_test_or_raise as crud_get_test_or_raise,
get_test_timeline as crud_get_test_timeline,
get_test_with_technique as crud_get_test_with_technique,
list_tests as crud_list_tests,
update_test as crud_update_test,
update_test_blue as crud_update_test_blue,
update_test_red as crud_update_test_red,
)
from app.services.test_workflow_service import (
start_execution as wf_start_execution,
submit_red_evidence as wf_submit_red,
@@ -62,29 +70,6 @@ from app.services.test_workflow_service import (
router = APIRouter(prefix="/tests", tags=["tests"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_test_or_404(db: Session, test_id: uuid.UUID) -> Test:
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return test
def _get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
test = (
db.query(Test)
.options(joinedload(Test.technique))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return test
# ---------------------------------------------------------------------------
# GET /tests — list with filters
# ---------------------------------------------------------------------------
@@ -105,30 +90,16 @@ def list_tests(
current_user: User = Depends(get_current_user),
):
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
query = db.query(Test).options(joinedload(Test.technique))
if state:
query = query.filter(Test.state == state)
if technique_id:
query = query.filter(Test.technique_id == technique_id)
if platform:
from app.utils import escape_like
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
if created_by:
query = query.filter(Test.created_by == created_by)
if pending_validation_side == "red":
query = query.filter(
Test.state == TestState.in_review,
Test.red_validation_status.in_(["pending", None]),
)
elif pending_validation_side == "blue":
query = query.filter(
Test.state == TestState.in_review,
Test.blue_validation_status.in_(["pending", None]),
)
tests = query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
return tests
return crud_list_tests(
db,
state=state,
technique_id=technique_id,
platform=platform,
created_by=created_by,
pending_validation_side=pending_validation_side,
offset=offset,
limit=limit,
)
# ---------------------------------------------------------------------------
@@ -150,20 +121,14 @@ def create_test(
``created_by`` is set automatically and ``state`` defaults to *draft*.
"""
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique with id '{payload.technique_id}' not found",
with UnitOfWork(db) as uow:
test = crud_create_test(
db,
technique_id=payload.technique_id,
creator_id=current_user.id,
**payload.model_dump(exclude={"technique_id"}),
)
test = Test(
**payload.model_dump(),
created_by=current_user.id,
state=TestState.draft,
)
db.add(test)
db.commit()
uow.commit()
db.refresh(test)
log_action(
@@ -197,43 +162,14 @@ def create_test_from_template(
The template's fields are copied into the new test as starting data.
"""
template = db.query(TestTemplate).filter(TestTemplate.id == payload.template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"TestTemplate with id '{payload.template_id}' not found",
with UnitOfWork(db) as uow:
test = crud_create_from_template(
db,
template_id=payload.template_id,
technique_id_or_mitre=payload.technique_id,
creator_id=current_user.id,
)
# Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001")
technique = None
try:
technique_uuid = uuid.UUID(payload.technique_id)
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
except ValueError:
pass
if technique is None:
technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first()
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique '{payload.technique_id}' not found",
)
test = Test(
technique_id=technique.id,
name=template.name,
description=template.description,
platform=template.platform,
procedure_text=template.attack_procedure,
tool_used=template.tool_suggested,
remediation_steps=template.suggested_remediation,
created_by=current_user.id,
state=TestState.draft,
)
db.add(test)
db.commit()
uow.commit()
db.refresh(test)
log_action(
@@ -244,7 +180,7 @@ def create_test_from_template(
entity_id=test.id,
details={
"name": test.name,
"template_id": str(template.id),
"template_id": str(payload.template_id),
"technique_id": str(test.technique_id),
},
)
@@ -264,20 +200,7 @@ def get_test(
current_user: User = Depends(get_current_user),
):
"""Return full details for a single test, including its evidences."""
test = (
db.query(Test)
.options(joinedload(Test.evidences))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test not found",
)
return test
return crud_get_test_detail(db, test_id)
# ---------------------------------------------------------------------------
@@ -297,29 +220,16 @@ def update_test(
Only leads or admins can update general test fields.
The test must be in ``draft`` or ``rejected`` state.
"""
test = _get_test_or_404(db, test_id)
if current_user.role != "admin" and test.created_by != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={"message": "Only the test creator or an admin can update this test", "code": "FORBIDDEN"},
)
if test.state not in (TestState.draft, TestState.rejected):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
with UnitOfWork(db) as uow:
test = crud_update_test(
db,
test_id,
updater_id=current_user.id,
updater_role=current_user.role,
**update_data,
)
uow.commit()
db.refresh(test)
log_action(
@@ -347,23 +257,10 @@ def update_test_red(
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
test = _get_test_or_404(db, test_id)
if test.state not in (TestState.draft, TestState.red_executing):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
with UnitOfWork(db) as uow:
test = crud_update_test_red(db, test_id, **update_data)
uow.commit()
db.refresh(test)
log_action(
@@ -391,23 +288,10 @@ def update_test_blue(
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
test = _get_test_or_404(db, test_id)
if test.state != TestState.blue_evaluating:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
with UnitOfWork(db) as uow:
test = crud_update_test_blue(db, test_id, **update_data)
uow.commit()
db.refresh(test)
log_action(
@@ -434,7 +318,7 @@ def start_execution(
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
"""Move a test from ``draft`` to ``red_executing``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_start_execution(db, test, current_user)
uow.commit()
@@ -454,7 +338,7 @@ def submit_red(
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_submit_red(db, test, current_user)
uow.commit()
@@ -474,7 +358,7 @@ def submit_blue(
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_submit_blue(db, test, current_user)
uow.commit()
@@ -494,7 +378,7 @@ def pause_timer(
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_pause_timer(db, test, current_user)
uow.commit()
@@ -514,7 +398,7 @@ def resume_timer(
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
"""Resume the paused timer for the current phase."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_resume_timer(db, test, current_user)
uow.commit()
@@ -535,7 +419,7 @@ def validate_red(
current_user: User = Depends(require_any_role("red_lead")),
):
"""Red Lead approves or rejects the red side of a test."""
test = _get_test_with_technique(db, test_id)
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
test = wf_validate_red(
db, test, current_user,
@@ -562,7 +446,7 @@ def validate_blue(
current_user: User = Depends(require_any_role("blue_lead")),
):
"""Blue Lead approves or rejects the blue side of a test."""
test = _get_test_with_technique(db, test_id)
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
test = wf_validate_blue(
db, test, current_user,
@@ -588,7 +472,7 @@ def reopen(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Reopen a rejected test, moving it back to ``draft``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_reopen(db, test, current_user)
uow.commit()
@@ -613,7 +497,7 @@ def update_remediation(
When ``remediation_status`` transitions to ``'completed'``, an automatic
re-test is created (subject to ``MAX_RETEST_COUNT``).
"""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
old_remediation_status = test.remediation_status
@@ -653,29 +537,7 @@ def get_test_timeline(
current_user: User = Depends(get_current_user),
):
"""Return the chronological audit-log history for a test."""
# Verify the test exists
_get_test_or_404(db, test_id)
logs = (
db.query(AuditLog)
.filter(
AuditLog.entity_type == "test",
AuditLog.entity_id == str(test_id),
)
.order_by(AuditLog.timestamp.asc())
.all()
)
return [
{
"id": str(log.id),
"action": log.action,
"user_id": str(log.user_id) if log.user_id else None,
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
"details": log.details,
}
for log in logs
]
return crud_get_test_timeline(db, test_id)
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,277 @@
"""Test CRUD service — list, create, update, and query logic for security tests.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
import uuid
from typing import Any
from sqlalchemy.orm import Session, joinedload
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.audit import AuditLog
from app.utils import escape_like
def list_tests(
db: Session,
*,
state: str | None = None,
technique_id: uuid.UUID | None = None,
platform: str | None = None,
created_by: uuid.UUID | None = None,
pending_validation_side: str | None = None,
offset: int = 0,
limit: int = 50,
) -> list[Test]:
"""Return a paginated list of tests with optional filters."""
query = db.query(Test).options(joinedload(Test.technique))
if state:
query = query.filter(Test.state == state)
if technique_id:
query = query.filter(Test.technique_id == technique_id)
if platform:
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
if created_by:
query = query.filter(Test.created_by == created_by)
if pending_validation_side == "red":
query = query.filter(
Test.state == TestState.in_review,
Test.red_validation_status.in_(["pending", None]),
)
elif pending_validation_side == "blue":
query = query.filter(
Test.state == TestState.in_review,
Test.blue_validation_status.in_(["pending", None]),
)
return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
def create_test(
db: Session,
*,
technique_id: uuid.UUID,
creator_id: uuid.UUID,
**fields: Any,
) -> Test:
"""Create a new test linked to an existing technique.
Raises EntityNotFoundError if the technique does not exist.
Does not commit; caller uses UnitOfWork.
"""
technique = db.query(Technique).filter(Technique.id == technique_id).first()
if technique is None:
raise EntityNotFoundError("Technique", str(technique_id))
test = Test(
technique_id=technique_id,
created_by=creator_id,
state=TestState.draft,
**fields,
)
db.add(test)
db.flush()
return test
def create_test_from_template(
db: Session,
*,
template_id: uuid.UUID,
technique_id_or_mitre: str,
creator_id: uuid.UUID,
) -> Test:
"""Instantiate a Test from a TestTemplate.
technique_id_or_mitre can be a UUID string or MITRE ID (e.g. T1059.001).
Raises EntityNotFoundError if template or technique not found.
Does not commit; caller uses UnitOfWork.
"""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise EntityNotFoundError("TestTemplate", str(template_id))
technique = None
try:
technique_uuid = uuid.UUID(technique_id_or_mitre)
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
except ValueError:
pass
if technique is None:
technique = db.query(Technique).filter(
Technique.mitre_id == technique_id_or_mitre
).first()
if technique is None:
raise EntityNotFoundError("Technique", technique_id_or_mitre)
test = Test(
technique_id=technique.id,
name=template.name,
description=template.description,
platform=template.platform,
procedure_text=template.attack_procedure,
tool_used=template.tool_suggested,
remediation_steps=template.suggested_remediation,
created_by=creator_id,
state=TestState.draft,
)
db.add(test)
db.flush()
return test
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
"""Fetch a test with evidences eager-loaded.
Raises EntityNotFoundError if the test does not exist.
"""
test = (
db.query(Test)
.options(joinedload(Test.evidences))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise EntityNotFoundError("Test", str(test_id))
return test
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
"""Fetch a test by ID. Raises EntityNotFoundError if not found."""
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise EntityNotFoundError("Test", str(test_id))
return test
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
"""Fetch a test with technique joined. Raises EntityNotFoundError if not found."""
test = (
db.query(Test)
.options(joinedload(Test.technique))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise EntityNotFoundError("Test", str(test_id))
return test
def update_test(
db: Session,
test_id: uuid.UUID,
*,
updater_id: uuid.UUID,
updater_role: str,
**fields: Any,
) -> Test:
"""Update general test fields (draft or rejected only).
Raises PermissionViolation if not creator or admin.
Raises BusinessRuleViolation if state is not draft or rejected.
Raises EntityNotFoundError if test not found.
Does not commit; caller uses UnitOfWork.
"""
test = get_test_or_raise(db, test_id)
if updater_role != "admin" and test.created_by != updater_id:
raise PermissionViolation(
"Only the test creator or an admin can update this test"
)
if test.state not in (TestState.draft, TestState.rejected):
raise BusinessRuleViolation(
f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
)
for field, value in fields.items():
setattr(test, field, value)
db.flush()
return test
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
"""Update Red Team fields (draft or red_executing only).
Raises BusinessRuleViolation if state not in (draft, red_executing).
Raises EntityNotFoundError if test not found.
Does not commit; caller uses UnitOfWork.
"""
test = get_test_or_raise(db, test_id)
if test.state not in (TestState.draft, TestState.red_executing):
raise BusinessRuleViolation(
f"Cannot update red fields in '{test.state.value}' state "
"(must be draft or red_executing)"
)
for field, value in fields.items():
setattr(test, field, value)
db.flush()
return test
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
"""Update Blue Team fields (blue_evaluating only).
Raises BusinessRuleViolation if state is not blue_evaluating.
Raises EntityNotFoundError if test not found.
Does not commit; caller uses UnitOfWork.
"""
test = get_test_or_raise(db, test_id)
if test.state != TestState.blue_evaluating:
raise BusinessRuleViolation(
f"Cannot update blue fields in '{test.state.value}' state "
"(must be blue_evaluating)"
)
for field, value in fields.items():
setattr(test, field, value)
db.flush()
return test
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
"""Return chronological audit-log history for a test.
Raises EntityNotFoundError if the test does not exist.
"""
get_test_or_raise(db, test_id)
logs = (
db.query(AuditLog)
.filter(
AuditLog.entity_type == "test",
AuditLog.entity_id == str(test_id),
)
.order_by(AuditLog.timestamp.asc())
.all()
)
return [
{
"id": str(log.id),
"action": log.action,
"user_id": str(log.user_id) if log.user_id else None,
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
"details": log.details,
}
for log in logs
]

View File

@@ -101,7 +101,7 @@ from app.routers.test_templates import (
toggle_template_active,
template_stats,
)
from app.routers.tests import create_test_from_template
from app.services.test_crud_service import create_test_from_template as crud_create_from_template
from app.schemas.test_template import TestTemplateCreate
@@ -174,7 +174,8 @@ def test_get_templates_by_technique():
def test_instantiate_template():
"""POST /tests/from-template creates a test pre-filled from template data."""
source = inspect.getsource(create_test_from_template)
# Template field copying lives in the service; router delegates to it
source = inspect.getsource(crud_create_from_template)
# Verify it reads from template and copies fields
assert "template" in source, "Must reference template"

View File

@@ -477,15 +477,15 @@ def test_evidence_team_separation():
def test_red_edit_allowed_in_draft_and_red_executing():
"""Verify the red update router checks that state is draft or red_executing."""
from app.routers.tests import update_test_red
"""Verify the red update checks that state is draft or red_executing."""
from app.services.test_crud_service import update_test_red
import inspect
source = inspect.getsource(update_test_red)
# The function must guard against states other than draft/red_executing
# The service must guard against states other than draft/red_executing
assert "draft" in source, "Red update must allow draft state"
assert "red_executing" in source, "Red update must allow red_executing state"
assert "400" in source or "HTTP_400_BAD_REQUEST" in source, "Red update must return 400 for invalid state"
assert "BusinessRuleViolation" in source, "Must raise domain exception for invalid state (mapped to 400)"
# ===========================================================================