refactor(tests): extract CRUD/query logic to test_crud_service, router delegates to service with domain exceptions
This commit is contained in:
277
backend/app/services/test_crud_service.py
Normal file
277
backend/app/services/test_crud_service.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Test CRUD service — list, create, update, and query logic for security tests.
|
||||
|
||||
Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||
The router is responsible for HTTP concerns, auth, audit logging, and commit.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
EntityNotFoundError,
|
||||
PermissionViolation,
|
||||
)
|
||||
from app.models.enums import TestState
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.audit import AuditLog
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
def list_tests(
|
||||
db: Session,
|
||||
*,
|
||||
state: str | None = None,
|
||||
technique_id: uuid.UUID | None = None,
|
||||
platform: str | None = None,
|
||||
created_by: uuid.UUID | None = None,
|
||||
pending_validation_side: str | None = None,
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> list[Test]:
|
||||
"""Return a paginated list of tests with optional filters."""
|
||||
query = db.query(Test).options(joinedload(Test.technique))
|
||||
|
||||
if state:
|
||||
query = query.filter(Test.state == state)
|
||||
if technique_id:
|
||||
query = query.filter(Test.technique_id == technique_id)
|
||||
if platform:
|
||||
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
if created_by:
|
||||
query = query.filter(Test.created_by == created_by)
|
||||
if pending_validation_side == "red":
|
||||
query = query.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.red_validation_status.in_(["pending", None]),
|
||||
)
|
||||
elif pending_validation_side == "blue":
|
||||
query = query.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.blue_validation_status.in_(["pending", None]),
|
||||
)
|
||||
|
||||
return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
|
||||
def create_test(
|
||||
db: Session,
|
||||
*,
|
||||
technique_id: uuid.UUID,
|
||||
creator_id: uuid.UUID,
|
||||
**fields: Any,
|
||||
) -> Test:
|
||||
"""Create a new test linked to an existing technique.
|
||||
|
||||
Raises EntityNotFoundError if the technique does not exist.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
if technique is None:
|
||||
raise EntityNotFoundError("Technique", str(technique_id))
|
||||
|
||||
test = Test(
|
||||
technique_id=technique_id,
|
||||
created_by=creator_id,
|
||||
state=TestState.draft,
|
||||
**fields,
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def create_test_from_template(
|
||||
db: Session,
|
||||
*,
|
||||
template_id: uuid.UUID,
|
||||
technique_id_or_mitre: str,
|
||||
creator_id: uuid.UUID,
|
||||
) -> Test:
|
||||
"""Instantiate a Test from a TestTemplate.
|
||||
|
||||
technique_id_or_mitre can be a UUID string or MITRE ID (e.g. T1059.001).
|
||||
Raises EntityNotFoundError if template or technique not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
if template is None:
|
||||
raise EntityNotFoundError("TestTemplate", str(template_id))
|
||||
|
||||
technique = None
|
||||
try:
|
||||
technique_uuid = uuid.UUID(technique_id_or_mitre)
|
||||
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if technique is None:
|
||||
technique = db.query(Technique).filter(
|
||||
Technique.mitre_id == technique_id_or_mitre
|
||||
).first()
|
||||
|
||||
if technique is None:
|
||||
raise EntityNotFoundError("Technique", technique_id_or_mitre)
|
||||
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
platform=template.platform,
|
||||
procedure_text=template.attack_procedure,
|
||||
tool_used=template.tool_suggested,
|
||||
remediation_steps=template.suggested_remediation,
|
||||
created_by=creator_id,
|
||||
state=TestState.draft,
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch a test with evidences eager-loaded.
|
||||
|
||||
Raises EntityNotFoundError if the test does not exist.
|
||||
"""
|
||||
test = (
|
||||
db.query(Test)
|
||||
.options(joinedload(Test.evidences))
|
||||
.filter(Test.id == test_id)
|
||||
.first()
|
||||
)
|
||||
if test is None:
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
return test
|
||||
|
||||
|
||||
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch a test by ID. Raises EntityNotFoundError if not found."""
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test is None:
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
return test
|
||||
|
||||
|
||||
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch a test with technique joined. Raises EntityNotFoundError if not found."""
|
||||
test = (
|
||||
db.query(Test)
|
||||
.options(joinedload(Test.technique))
|
||||
.filter(Test.id == test_id)
|
||||
.first()
|
||||
)
|
||||
if test is None:
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
return test
|
||||
|
||||
|
||||
def update_test(
|
||||
db: Session,
|
||||
test_id: uuid.UUID,
|
||||
*,
|
||||
updater_id: uuid.UUID,
|
||||
updater_role: str,
|
||||
**fields: Any,
|
||||
) -> Test:
|
||||
"""Update general test fields (draft or rejected only).
|
||||
|
||||
Raises PermissionViolation if not creator or admin.
|
||||
Raises BusinessRuleViolation if state is not draft or rejected.
|
||||
Raises EntityNotFoundError if test not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
test = get_test_or_raise(db, test_id)
|
||||
|
||||
if updater_role != "admin" and test.created_by != updater_id:
|
||||
raise PermissionViolation(
|
||||
"Only the test creator or an admin can update this test"
|
||||
)
|
||||
|
||||
if test.state not in (TestState.draft, TestState.rejected):
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
|
||||
)
|
||||
|
||||
for field, value in fields.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||
"""Update Red Team fields (draft or red_executing only).
|
||||
|
||||
Raises BusinessRuleViolation if state not in (draft, red_executing).
|
||||
Raises EntityNotFoundError if test not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
test = get_test_or_raise(db, test_id)
|
||||
|
||||
if test.state not in (TestState.draft, TestState.red_executing):
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot update red fields in '{test.state.value}' state "
|
||||
"(must be draft or red_executing)"
|
||||
)
|
||||
|
||||
for field, value in fields.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||
"""Update Blue Team fields (blue_evaluating only).
|
||||
|
||||
Raises BusinessRuleViolation if state is not blue_evaluating.
|
||||
Raises EntityNotFoundError if test not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
test = get_test_or_raise(db, test_id)
|
||||
|
||||
if test.state != TestState.blue_evaluating:
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot update blue fields in '{test.state.value}' state "
|
||||
"(must be blue_evaluating)"
|
||||
)
|
||||
|
||||
for field, value in fields.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
|
||||
"""Return chronological audit-log history for a test.
|
||||
|
||||
Raises EntityNotFoundError if the test does not exist.
|
||||
"""
|
||||
get_test_or_raise(db, test_id)
|
||||
|
||||
logs = (
|
||||
db.query(AuditLog)
|
||||
.filter(
|
||||
AuditLog.entity_type == "test",
|
||||
AuditLog.entity_id == str(test_id),
|
||||
)
|
||||
.order_by(AuditLog.timestamp.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(log.id),
|
||||
"action": log.action,
|
||||
"user_id": str(log.user_id) if log.user_id else None,
|
||||
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
|
||||
"details": log.details,
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
Reference in New Issue
Block a user