refactor(tests): extract CRUD/query logic to test_crud_service, router delegates to service with domain exceptions
This commit is contained in:
@@ -22,15 +22,11 @@ import uuid
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.audit import AuditLog
|
from app.models.enums import TestState
|
||||||
from app.models.enums import TestState, TeamSide
|
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.test import Test
|
|
||||||
from app.models.test_template import TestTemplate
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.test import (
|
from app.schemas.test import (
|
||||||
TestCreate,
|
TestCreate,
|
||||||
@@ -46,6 +42,18 @@ from app.schemas.test_template import TestTemplateInstantiate
|
|||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
from app.services.status_service import recalculate_technique_status
|
from app.services.status_service import recalculate_technique_status
|
||||||
|
from app.services.test_crud_service import (
|
||||||
|
create_test as crud_create_test,
|
||||||
|
create_test_from_template as crud_create_from_template,
|
||||||
|
get_test_detail as crud_get_test_detail,
|
||||||
|
get_test_or_raise as crud_get_test_or_raise,
|
||||||
|
get_test_timeline as crud_get_test_timeline,
|
||||||
|
get_test_with_technique as crud_get_test_with_technique,
|
||||||
|
list_tests as crud_list_tests,
|
||||||
|
update_test as crud_update_test,
|
||||||
|
update_test_blue as crud_update_test_blue,
|
||||||
|
update_test_red as crud_update_test_red,
|
||||||
|
)
|
||||||
from app.services.test_workflow_service import (
|
from app.services.test_workflow_service import (
|
||||||
start_execution as wf_start_execution,
|
start_execution as wf_start_execution,
|
||||||
submit_red_evidence as wf_submit_red,
|
submit_red_evidence as wf_submit_red,
|
||||||
@@ -62,29 +70,6 @@ from app.services.test_workflow_service import (
|
|||||||
router = APIRouter(prefix="/tests", tags=["tests"])
|
router = APIRouter(prefix="/tests", tags=["tests"])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_test_or_404(db: Session, test_id: uuid.UUID) -> Test:
|
|
||||||
test = db.query(Test).filter(Test.id == test_id).first()
|
|
||||||
if test is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
|
|
||||||
return test
|
|
||||||
|
|
||||||
|
|
||||||
def _get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
|
|
||||||
test = (
|
|
||||||
db.query(Test)
|
|
||||||
.options(joinedload(Test.technique))
|
|
||||||
.filter(Test.id == test_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if test is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
|
|
||||||
return test
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /tests — list with filters
|
# GET /tests — list with filters
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -105,30 +90,16 @@ def list_tests(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
|
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
|
||||||
query = db.query(Test).options(joinedload(Test.technique))
|
return crud_list_tests(
|
||||||
|
db,
|
||||||
if state:
|
state=state,
|
||||||
query = query.filter(Test.state == state)
|
technique_id=technique_id,
|
||||||
if technique_id:
|
platform=platform,
|
||||||
query = query.filter(Test.technique_id == technique_id)
|
created_by=created_by,
|
||||||
if platform:
|
pending_validation_side=pending_validation_side,
|
||||||
from app.utils import escape_like
|
offset=offset,
|
||||||
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
limit=limit,
|
||||||
if created_by:
|
|
||||||
query = query.filter(Test.created_by == created_by)
|
|
||||||
if pending_validation_side == "red":
|
|
||||||
query = query.filter(
|
|
||||||
Test.state == TestState.in_review,
|
|
||||||
Test.red_validation_status.in_(["pending", None]),
|
|
||||||
)
|
)
|
||||||
elif pending_validation_side == "blue":
|
|
||||||
query = query.filter(
|
|
||||||
Test.state == TestState.in_review,
|
|
||||||
Test.blue_validation_status.in_(["pending", None]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tests = query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
|
|
||||||
return tests
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -150,20 +121,14 @@ def create_test(
|
|||||||
|
|
||||||
``created_by`` is set automatically and ``state`` defaults to *draft*.
|
``created_by`` is set automatically and ``state`` defaults to *draft*.
|
||||||
"""
|
"""
|
||||||
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
|
with UnitOfWork(db) as uow:
|
||||||
if technique is None:
|
test = crud_create_test(
|
||||||
raise HTTPException(
|
db,
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
technique_id=payload.technique_id,
|
||||||
detail=f"Technique with id '{payload.technique_id}' not found",
|
creator_id=current_user.id,
|
||||||
|
**payload.model_dump(exclude={"technique_id"}),
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
test = Test(
|
|
||||||
**payload.model_dump(),
|
|
||||||
created_by=current_user.id,
|
|
||||||
state=TestState.draft,
|
|
||||||
)
|
|
||||||
db.add(test)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -197,43 +162,14 @@ def create_test_from_template(
|
|||||||
|
|
||||||
The template's fields are copied into the new test as starting data.
|
The template's fields are copied into the new test as starting data.
|
||||||
"""
|
"""
|
||||||
template = db.query(TestTemplate).filter(TestTemplate.id == payload.template_id).first()
|
with UnitOfWork(db) as uow:
|
||||||
if template is None:
|
test = crud_create_from_template(
|
||||||
raise HTTPException(
|
db,
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
template_id=payload.template_id,
|
||||||
detail=f"TestTemplate with id '{payload.template_id}' not found",
|
technique_id_or_mitre=payload.technique_id,
|
||||||
|
creator_id=current_user.id,
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
# Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001")
|
|
||||||
technique = None
|
|
||||||
try:
|
|
||||||
technique_uuid = uuid.UUID(payload.technique_id)
|
|
||||||
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if technique is None:
|
|
||||||
technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first()
|
|
||||||
|
|
||||||
if technique is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Technique '{payload.technique_id}' not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
test = Test(
|
|
||||||
technique_id=technique.id,
|
|
||||||
name=template.name,
|
|
||||||
description=template.description,
|
|
||||||
platform=template.platform,
|
|
||||||
procedure_text=template.attack_procedure,
|
|
||||||
tool_used=template.tool_suggested,
|
|
||||||
remediation_steps=template.suggested_remediation,
|
|
||||||
created_by=current_user.id,
|
|
||||||
state=TestState.draft,
|
|
||||||
)
|
|
||||||
db.add(test)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -244,7 +180,7 @@ def create_test_from_template(
|
|||||||
entity_id=test.id,
|
entity_id=test.id,
|
||||||
details={
|
details={
|
||||||
"name": test.name,
|
"name": test.name,
|
||||||
"template_id": str(template.id),
|
"template_id": str(payload.template_id),
|
||||||
"technique_id": str(test.technique_id),
|
"technique_id": str(test.technique_id),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -264,20 +200,7 @@ def get_test(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return full details for a single test, including its evidences."""
|
"""Return full details for a single test, including its evidences."""
|
||||||
test = (
|
return crud_get_test_detail(db, test_id)
|
||||||
db.query(Test)
|
|
||||||
.options(joinedload(Test.evidences))
|
|
||||||
.filter(Test.id == test_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if test is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Test not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
return test
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -297,29 +220,16 @@ def update_test(
|
|||||||
Only leads or admins can update general test fields.
|
Only leads or admins can update general test fields.
|
||||||
The test must be in ``draft`` or ``rejected`` state.
|
The test must be in ``draft`` or ``rejected`` state.
|
||||||
"""
|
"""
|
||||||
test = _get_test_or_404(db, test_id)
|
|
||||||
|
|
||||||
if current_user.role != "admin" and test.created_by != current_user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail={"message": "Only the test creator or an admin can update this test", "code": "FORBIDDEN"},
|
|
||||||
)
|
|
||||||
|
|
||||||
if test.state not in (TestState.draft, TestState.rejected):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"message": f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)",
|
|
||||||
"code": "INVALID_STATE",
|
|
||||||
"current_state": test.state.value,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
with UnitOfWork(db) as uow:
|
||||||
setattr(test, field, value)
|
test = crud_update_test(
|
||||||
|
db,
|
||||||
db.commit()
|
test_id,
|
||||||
|
updater_id=current_user.id,
|
||||||
|
updater_role=current_user.role,
|
||||||
|
**update_data,
|
||||||
|
)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -347,23 +257,10 @@ def update_test_red(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
||||||
):
|
):
|
||||||
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
|
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
|
||||||
test = _get_test_or_404(db, test_id)
|
|
||||||
|
|
||||||
if test.state not in (TestState.draft, TestState.red_executing):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"message": f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)",
|
|
||||||
"code": "INVALID_STATE",
|
|
||||||
"current_state": test.state.value,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
with UnitOfWork(db) as uow:
|
||||||
setattr(test, field, value)
|
test = crud_update_test_red(db, test_id, **update_data)
|
||||||
|
uow.commit()
|
||||||
db.commit()
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -391,23 +288,10 @@ def update_test_blue(
|
|||||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
|
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
|
||||||
test = _get_test_or_404(db, test_id)
|
|
||||||
|
|
||||||
if test.state != TestState.blue_evaluating:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"message": f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)",
|
|
||||||
"code": "INVALID_STATE",
|
|
||||||
"current_state": test.state.value,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
with UnitOfWork(db) as uow:
|
||||||
setattr(test, field, value)
|
test = crud_update_test_blue(db, test_id, **update_data)
|
||||||
|
uow.commit()
|
||||||
db.commit()
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -434,7 +318,7 @@ def start_execution(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
||||||
):
|
):
|
||||||
"""Move a test from ``draft`` to ``red_executing``."""
|
"""Move a test from ``draft`` to ``red_executing``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_start_execution(db, test, current_user)
|
test = wf_start_execution(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -454,7 +338,7 @@ def submit_red(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
||||||
):
|
):
|
||||||
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
|
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_submit_red(db, test, current_user)
|
test = wf_submit_red(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -474,7 +358,7 @@ def submit_blue(
|
|||||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
|
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_submit_blue(db, test, current_user)
|
test = wf_submit_blue(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -494,7 +378,7 @@ def pause_timer(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
|
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_pause_timer(db, test, current_user)
|
test = wf_pause_timer(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -514,7 +398,7 @@ def resume_timer(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Resume the paused timer for the current phase."""
|
"""Resume the paused timer for the current phase."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_resume_timer(db, test, current_user)
|
test = wf_resume_timer(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -535,7 +419,7 @@ def validate_red(
|
|||||||
current_user: User = Depends(require_any_role("red_lead")),
|
current_user: User = Depends(require_any_role("red_lead")),
|
||||||
):
|
):
|
||||||
"""Red Lead approves or rejects the red side of a test."""
|
"""Red Lead approves or rejects the red side of a test."""
|
||||||
test = _get_test_with_technique(db, test_id)
|
test = crud_get_test_with_technique(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_validate_red(
|
test = wf_validate_red(
|
||||||
db, test, current_user,
|
db, test, current_user,
|
||||||
@@ -562,7 +446,7 @@ def validate_blue(
|
|||||||
current_user: User = Depends(require_any_role("blue_lead")),
|
current_user: User = Depends(require_any_role("blue_lead")),
|
||||||
):
|
):
|
||||||
"""Blue Lead approves or rejects the blue side of a test."""
|
"""Blue Lead approves or rejects the blue side of a test."""
|
||||||
test = _get_test_with_technique(db, test_id)
|
test = crud_get_test_with_technique(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_validate_blue(
|
test = wf_validate_blue(
|
||||||
db, test, current_user,
|
db, test, current_user,
|
||||||
@@ -588,7 +472,7 @@ def reopen(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Reopen a rejected test, moving it back to ``draft``."""
|
"""Reopen a rejected test, moving it back to ``draft``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_reopen(db, test, current_user)
|
test = wf_reopen(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -613,7 +497,7 @@ def update_remediation(
|
|||||||
When ``remediation_status`` transitions to ``'completed'``, an automatic
|
When ``remediation_status`` transitions to ``'completed'``, an automatic
|
||||||
re-test is created (subject to ``MAX_RETEST_COUNT``).
|
re-test is created (subject to ``MAX_RETEST_COUNT``).
|
||||||
"""
|
"""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
old_remediation_status = test.remediation_status
|
old_remediation_status = test.remediation_status
|
||||||
|
|
||||||
@@ -653,29 +537,7 @@ def get_test_timeline(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return the chronological audit-log history for a test."""
|
"""Return the chronological audit-log history for a test."""
|
||||||
# Verify the test exists
|
return crud_get_test_timeline(db, test_id)
|
||||||
_get_test_or_404(db, test_id)
|
|
||||||
|
|
||||||
logs = (
|
|
||||||
db.query(AuditLog)
|
|
||||||
.filter(
|
|
||||||
AuditLog.entity_type == "test",
|
|
||||||
AuditLog.entity_id == str(test_id),
|
|
||||||
)
|
|
||||||
.order_by(AuditLog.timestamp.asc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": str(log.id),
|
|
||||||
"action": log.action,
|
|
||||||
"user_id": str(log.user_id) if log.user_id else None,
|
|
||||||
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
|
|
||||||
"details": log.details,
|
|
||||||
}
|
|
||||||
for log in logs
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
277
backend/app/services/test_crud_service.py
Normal file
277
backend/app/services/test_crud_service.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
"""Test CRUD service — list, create, update, and query logic for security tests.
|
||||||
|
|
||||||
|
Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||||
|
The router is responsible for HTTP concerns, auth, audit logging, and commit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.domain.errors import (
|
||||||
|
BusinessRuleViolation,
|
||||||
|
EntityNotFoundError,
|
||||||
|
PermissionViolation,
|
||||||
|
)
|
||||||
|
from app.models.enums import TestState
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.models.test import Test
|
||||||
|
from app.models.test_template import TestTemplate
|
||||||
|
from app.models.audit import AuditLog
|
||||||
|
from app.utils import escape_like
|
||||||
|
|
||||||
|
|
||||||
|
def list_tests(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
state: str | None = None,
|
||||||
|
technique_id: uuid.UUID | None = None,
|
||||||
|
platform: str | None = None,
|
||||||
|
created_by: uuid.UUID | None = None,
|
||||||
|
pending_validation_side: str | None = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> list[Test]:
|
||||||
|
"""Return a paginated list of tests with optional filters."""
|
||||||
|
query = db.query(Test).options(joinedload(Test.technique))
|
||||||
|
|
||||||
|
if state:
|
||||||
|
query = query.filter(Test.state == state)
|
||||||
|
if technique_id:
|
||||||
|
query = query.filter(Test.technique_id == technique_id)
|
||||||
|
if platform:
|
||||||
|
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
||||||
|
if created_by:
|
||||||
|
query = query.filter(Test.created_by == created_by)
|
||||||
|
if pending_validation_side == "red":
|
||||||
|
query = query.filter(
|
||||||
|
Test.state == TestState.in_review,
|
||||||
|
Test.red_validation_status.in_(["pending", None]),
|
||||||
|
)
|
||||||
|
elif pending_validation_side == "blue":
|
||||||
|
query = query.filter(
|
||||||
|
Test.state == TestState.in_review,
|
||||||
|
Test.blue_validation_status.in_(["pending", None]),
|
||||||
|
)
|
||||||
|
|
||||||
|
return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
|
||||||
|
|
||||||
|
|
||||||
|
def create_test(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
technique_id: uuid.UUID,
|
||||||
|
creator_id: uuid.UUID,
|
||||||
|
**fields: Any,
|
||||||
|
) -> Test:
|
||||||
|
"""Create a new test linked to an existing technique.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if the technique does not exist.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||||
|
if technique is None:
|
||||||
|
raise EntityNotFoundError("Technique", str(technique_id))
|
||||||
|
|
||||||
|
test = Test(
|
||||||
|
technique_id=technique_id,
|
||||||
|
created_by=creator_id,
|
||||||
|
state=TestState.draft,
|
||||||
|
**fields,
|
||||||
|
)
|
||||||
|
db.add(test)
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_from_template(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
template_id: uuid.UUID,
|
||||||
|
technique_id_or_mitre: str,
|
||||||
|
creator_id: uuid.UUID,
|
||||||
|
) -> Test:
|
||||||
|
"""Instantiate a Test from a TestTemplate.
|
||||||
|
|
||||||
|
technique_id_or_mitre can be a UUID string or MITRE ID (e.g. T1059.001).
|
||||||
|
Raises EntityNotFoundError if template or technique not found.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||||
|
if template is None:
|
||||||
|
raise EntityNotFoundError("TestTemplate", str(template_id))
|
||||||
|
|
||||||
|
technique = None
|
||||||
|
try:
|
||||||
|
technique_uuid = uuid.UUID(technique_id_or_mitre)
|
||||||
|
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if technique is None:
|
||||||
|
technique = db.query(Technique).filter(
|
||||||
|
Technique.mitre_id == technique_id_or_mitre
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if technique is None:
|
||||||
|
raise EntityNotFoundError("Technique", technique_id_or_mitre)
|
||||||
|
|
||||||
|
test = Test(
|
||||||
|
technique_id=technique.id,
|
||||||
|
name=template.name,
|
||||||
|
description=template.description,
|
||||||
|
platform=template.platform,
|
||||||
|
procedure_text=template.attack_procedure,
|
||||||
|
tool_used=template.tool_suggested,
|
||||||
|
remediation_steps=template.suggested_remediation,
|
||||||
|
created_by=creator_id,
|
||||||
|
state=TestState.draft,
|
||||||
|
)
|
||||||
|
db.add(test)
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
|
||||||
|
"""Fetch a test with evidences eager-loaded.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if the test does not exist.
|
||||||
|
"""
|
||||||
|
test = (
|
||||||
|
db.query(Test)
|
||||||
|
.options(joinedload(Test.evidences))
|
||||||
|
.filter(Test.id == test_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if test is None:
|
||||||
|
raise EntityNotFoundError("Test", str(test_id))
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
|
||||||
|
"""Fetch a test by ID. Raises EntityNotFoundError if not found."""
|
||||||
|
test = db.query(Test).filter(Test.id == test_id).first()
|
||||||
|
if test is None:
|
||||||
|
raise EntityNotFoundError("Test", str(test_id))
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
|
||||||
|
"""Fetch a test with technique joined. Raises EntityNotFoundError if not found."""
|
||||||
|
test = (
|
||||||
|
db.query(Test)
|
||||||
|
.options(joinedload(Test.technique))
|
||||||
|
.filter(Test.id == test_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if test is None:
|
||||||
|
raise EntityNotFoundError("Test", str(test_id))
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def update_test(
|
||||||
|
db: Session,
|
||||||
|
test_id: uuid.UUID,
|
||||||
|
*,
|
||||||
|
updater_id: uuid.UUID,
|
||||||
|
updater_role: str,
|
||||||
|
**fields: Any,
|
||||||
|
) -> Test:
|
||||||
|
"""Update general test fields (draft or rejected only).
|
||||||
|
|
||||||
|
Raises PermissionViolation if not creator or admin.
|
||||||
|
Raises BusinessRuleViolation if state is not draft or rejected.
|
||||||
|
Raises EntityNotFoundError if test not found.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
test = get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
|
if updater_role != "admin" and test.created_by != updater_id:
|
||||||
|
raise PermissionViolation(
|
||||||
|
"Only the test creator or an admin can update this test"
|
||||||
|
)
|
||||||
|
|
||||||
|
if test.state not in (TestState.draft, TestState.rejected):
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
setattr(test, field, value)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||||
|
"""Update Red Team fields (draft or red_executing only).
|
||||||
|
|
||||||
|
Raises BusinessRuleViolation if state not in (draft, red_executing).
|
||||||
|
Raises EntityNotFoundError if test not found.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
test = get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
|
if test.state not in (TestState.draft, TestState.red_executing):
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot update red fields in '{test.state.value}' state "
|
||||||
|
"(must be draft or red_executing)"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
setattr(test, field, value)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||||
|
"""Update Blue Team fields (blue_evaluating only).
|
||||||
|
|
||||||
|
Raises BusinessRuleViolation if state is not blue_evaluating.
|
||||||
|
Raises EntityNotFoundError if test not found.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
test = get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
|
if test.state != TestState.blue_evaluating:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot update blue fields in '{test.state.value}' state "
|
||||||
|
"(must be blue_evaluating)"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
setattr(test, field, value)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
|
||||||
|
"""Return chronological audit-log history for a test.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if the test does not exist.
|
||||||
|
"""
|
||||||
|
get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
|
logs = (
|
||||||
|
db.query(AuditLog)
|
||||||
|
.filter(
|
||||||
|
AuditLog.entity_type == "test",
|
||||||
|
AuditLog.entity_id == str(test_id),
|
||||||
|
)
|
||||||
|
.order_by(AuditLog.timestamp.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(log.id),
|
||||||
|
"action": log.action,
|
||||||
|
"user_id": str(log.user_id) if log.user_id else None,
|
||||||
|
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
|
||||||
|
"details": log.details,
|
||||||
|
}
|
||||||
|
for log in logs
|
||||||
|
]
|
||||||
@@ -101,7 +101,7 @@ from app.routers.test_templates import (
|
|||||||
toggle_template_active,
|
toggle_template_active,
|
||||||
template_stats,
|
template_stats,
|
||||||
)
|
)
|
||||||
from app.routers.tests import create_test_from_template
|
from app.services.test_crud_service import create_test_from_template as crud_create_from_template
|
||||||
from app.schemas.test_template import TestTemplateCreate
|
from app.schemas.test_template import TestTemplateCreate
|
||||||
|
|
||||||
|
|
||||||
@@ -174,7 +174,8 @@ def test_get_templates_by_technique():
|
|||||||
|
|
||||||
def test_instantiate_template():
|
def test_instantiate_template():
|
||||||
"""POST /tests/from-template creates a test pre-filled from template data."""
|
"""POST /tests/from-template creates a test pre-filled from template data."""
|
||||||
source = inspect.getsource(create_test_from_template)
|
# Template field copying lives in the service; router delegates to it
|
||||||
|
source = inspect.getsource(crud_create_from_template)
|
||||||
|
|
||||||
# Verify it reads from template and copies fields
|
# Verify it reads from template and copies fields
|
||||||
assert "template" in source, "Must reference template"
|
assert "template" in source, "Must reference template"
|
||||||
|
|||||||
@@ -477,15 +477,15 @@ def test_evidence_team_separation():
|
|||||||
|
|
||||||
|
|
||||||
def test_red_edit_allowed_in_draft_and_red_executing():
|
def test_red_edit_allowed_in_draft_and_red_executing():
|
||||||
"""Verify the red update router checks that state is draft or red_executing."""
|
"""Verify the red update checks that state is draft or red_executing."""
|
||||||
from app.routers.tests import update_test_red
|
from app.services.test_crud_service import update_test_red
|
||||||
import inspect
|
import inspect
|
||||||
source = inspect.getsource(update_test_red)
|
source = inspect.getsource(update_test_red)
|
||||||
|
|
||||||
# The function must guard against states other than draft/red_executing
|
# The service must guard against states other than draft/red_executing
|
||||||
assert "draft" in source, "Red update must allow draft state"
|
assert "draft" in source, "Red update must allow draft state"
|
||||||
assert "red_executing" in source, "Red update must allow red_executing state"
|
assert "red_executing" in source, "Red update must allow red_executing state"
|
||||||
assert "400" in source or "HTTP_400_BAD_REQUEST" in source, "Red update must return 400 for invalid state"
|
assert "BusinessRuleViolation" in source, "Must raise domain exception for invalid state (mapped to 400)"
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user