From 20738d11b3f6564f78eb96fb34051230a5edee26 Mon Sep 17 00:00:00 2001 From: Kitos Date: Thu, 19 Feb 2026 18:35:09 +0100 Subject: [PATCH] refactor(tests): extract CRUD/query logic to test_crud_service, router delegates to service with domain exceptions --- backend/app/routers/tests.py | 268 +++++---------------- backend/app/services/test_crud_service.py | 277 ++++++++++++++++++++++ backend/tests/test_templates_crud.py | 5 +- backend/tests/test_workflow.py | 8 +- 4 files changed, 349 insertions(+), 209 deletions(-) create mode 100644 backend/app/services/test_crud_service.py diff --git a/backend/app/routers/tests.py b/backend/app/routers/tests.py index df96902..1bf800b 100644 --- a/backend/app/routers/tests.py +++ b/backend/app/routers/tests.py @@ -22,15 +22,11 @@ import uuid from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query, status -from sqlalchemy.orm import Session, joinedload +from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user, require_any_role -from app.models.audit import AuditLog -from app.models.enums import TestState, TeamSide -from app.models.technique import Technique -from app.models.test import Test -from app.models.test_template import TestTemplate +from app.models.enums import TestState from app.models.user import User from app.schemas.test import ( TestCreate, @@ -46,6 +42,18 @@ from app.schemas.test_template import TestTemplateInstantiate from app.domain.unit_of_work import UnitOfWork from app.services.audit_service import log_action from app.services.status_service import recalculate_technique_status +from app.services.test_crud_service import ( + create_test as crud_create_test, + create_test_from_template as crud_create_from_template, + get_test_detail as crud_get_test_detail, + get_test_or_raise as crud_get_test_or_raise, + get_test_timeline as crud_get_test_timeline, + get_test_with_technique as crud_get_test_with_technique, + list_tests as crud_list_tests, + update_test as crud_update_test, + update_test_blue as crud_update_test_blue, + update_test_red as crud_update_test_red, +) from app.services.test_workflow_service import ( start_execution as wf_start_execution, submit_red_evidence as wf_submit_red, @@ -62,29 +70,6 @@ from app.services.test_workflow_service import ( router = APIRouter(prefix="/tests", tags=["tests"]) -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _get_test_or_404(db: Session, test_id: uuid.UUID) -> Test: - test = db.query(Test).filter(Test.id == test_id).first() - if test is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found") - return test - - -def _get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test: - test = ( - db.query(Test) - .options(joinedload(Test.technique)) - .filter(Test.id == test_id) - .first() - ) - if test is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found") - return test - - # --------------------------------------------------------------------------- # GET /tests — list with filters # --------------------------------------------------------------------------- @@ -105,30 +90,16 @@ def list_tests( current_user: User = Depends(get_current_user), ): """Return a paginated list of tests, optionally filtered by state, technique, platform or creator.""" - query = db.query(Test).options(joinedload(Test.technique)) - - if state: - query = query.filter(Test.state == state) - if technique_id: - query = query.filter(Test.technique_id == technique_id) - if platform: - from app.utils import escape_like - query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%")) - if created_by: - query = query.filter(Test.created_by == created_by) - if pending_validation_side == "red": - query = query.filter( - Test.state == TestState.in_review, - Test.red_validation_status.in_(["pending", None]), - ) - elif pending_validation_side == "blue": - query = query.filter( - Test.state == TestState.in_review, - Test.blue_validation_status.in_(["pending", None]), - ) - - tests = query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all() - return tests + return crud_list_tests( + db, + state=state, + technique_id=technique_id, + platform=platform, + created_by=created_by, + pending_validation_side=pending_validation_side, + offset=offset, + limit=limit, + ) # --------------------------------------------------------------------------- @@ -150,20 +121,14 @@ def create_test( ``created_by`` is set automatically and ``state`` defaults to *draft*. """ - technique = db.query(Technique).filter(Technique.id == payload.technique_id).first() - if technique is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Technique with id '{payload.technique_id}' not found", + with UnitOfWork(db) as uow: + test = crud_create_test( + db, + technique_id=payload.technique_id, + creator_id=current_user.id, + **payload.model_dump(exclude={"technique_id"}), ) - - test = Test( - **payload.model_dump(), - created_by=current_user.id, - state=TestState.draft, - ) - db.add(test) - db.commit() + uow.commit() db.refresh(test) log_action( @@ -197,43 +162,14 @@ def create_test_from_template( The template's fields are copied into the new test as starting data. """ - template = db.query(TestTemplate).filter(TestTemplate.id == payload.template_id).first() - if template is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"TestTemplate with id '{payload.template_id}' not found", + with UnitOfWork(db) as uow: + test = crud_create_from_template( + db, + template_id=payload.template_id, + technique_id_or_mitre=payload.technique_id, + creator_id=current_user.id, ) - - # Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001") - technique = None - try: - technique_uuid = uuid.UUID(payload.technique_id) - technique = db.query(Technique).filter(Technique.id == technique_uuid).first() - except ValueError: - pass - - if technique is None: - technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first() - - if technique is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Technique '{payload.technique_id}' not found", - ) - - test = Test( - technique_id=technique.id, - name=template.name, - description=template.description, - platform=template.platform, - procedure_text=template.attack_procedure, - tool_used=template.tool_suggested, - remediation_steps=template.suggested_remediation, - created_by=current_user.id, - state=TestState.draft, - ) - db.add(test) - db.commit() + uow.commit() db.refresh(test) log_action( @@ -244,7 +180,7 @@ def create_test_from_template( entity_id=test.id, details={ "name": test.name, - "template_id": str(template.id), + "template_id": str(payload.template_id), "technique_id": str(test.technique_id), }, ) @@ -264,20 +200,7 @@ def get_test( current_user: User = Depends(get_current_user), ): """Return full details for a single test, including its evidences.""" - test = ( - db.query(Test) - .options(joinedload(Test.evidences)) - .filter(Test.id == test_id) - .first() - ) - - if test is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Test not found", - ) - - return test + return crud_get_test_detail(db, test_id) # --------------------------------------------------------------------------- @@ -297,29 +220,16 @@ def update_test( Only leads or admins can update general test fields. The test must be in ``draft`` or ``rejected`` state. """ - test = _get_test_or_404(db, test_id) - - if current_user.role != "admin" and test.created_by != current_user.id: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail={"message": "Only the test creator or an admin can update this test", "code": "FORBIDDEN"}, - ) - - if test.state not in (TestState.draft, TestState.rejected): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "message": f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)", - "code": "INVALID_STATE", - "current_state": test.state.value, - }, - ) - update_data = payload.model_dump(exclude_unset=True) - for field, value in update_data.items(): - setattr(test, field, value) - - db.commit() + with UnitOfWork(db) as uow: + test = crud_update_test( + db, + test_id, + updater_id=current_user.id, + updater_role=current_user.role, + **update_data, + ) + uow.commit() db.refresh(test) log_action( @@ -347,23 +257,10 @@ def update_test_red( current_user: User = Depends(require_any_role("red_tech", "red_lead")), ): """Red Team updates their fields (allowed in ``draft`` and ``red_executing``).""" - test = _get_test_or_404(db, test_id) - - if test.state not in (TestState.draft, TestState.red_executing): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "message": f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)", - "code": "INVALID_STATE", - "current_state": test.state.value, - }, - ) - update_data = payload.model_dump(exclude_unset=True) - for field, value in update_data.items(): - setattr(test, field, value) - - db.commit() + with UnitOfWork(db) as uow: + test = crud_update_test_red(db, test_id, **update_data) + uow.commit() db.refresh(test) log_action( @@ -391,23 +288,10 @@ def update_test_blue( current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), ): """Blue Team updates their fields (allowed only in ``blue_evaluating``).""" - test = _get_test_or_404(db, test_id) - - if test.state != TestState.blue_evaluating: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "message": f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)", - "code": "INVALID_STATE", - "current_state": test.state.value, - }, - ) - update_data = payload.model_dump(exclude_unset=True) - for field, value in update_data.items(): - setattr(test, field, value) - - db.commit() + with UnitOfWork(db) as uow: + test = crud_update_test_blue(db, test_id, **update_data) + uow.commit() db.refresh(test) log_action( @@ -434,7 +318,7 @@ def start_execution( current_user: User = Depends(require_any_role("red_tech", "red_lead")), ): """Move a test from ``draft`` to ``red_executing``.""" - test = _get_test_or_404(db, test_id) + test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: test = wf_start_execution(db, test, current_user) uow.commit() @@ -454,7 +338,7 @@ def submit_red( current_user: User = Depends(require_any_role("red_tech", "red_lead")), ): """Red Team finalises — move from ``red_executing`` to ``blue_evaluating``.""" - test = _get_test_or_404(db, test_id) + test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: test = wf_submit_red(db, test, current_user) uow.commit() @@ -474,7 +358,7 @@ def submit_blue( current_user: User = Depends(require_any_role("blue_tech", "blue_lead")), ): """Blue Team finalises — move from ``blue_evaluating`` to ``in_review``.""" - test = _get_test_or_404(db, test_id) + test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: test = wf_submit_blue(db, test, current_user) uow.commit() @@ -494,7 +378,7 @@ def pause_timer( current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), ): """Pause the running timer for the current phase (red_executing or blue_evaluating).""" - test = _get_test_or_404(db, test_id) + test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: test = wf_pause_timer(db, test, current_user) uow.commit() @@ -514,7 +398,7 @@ def resume_timer( current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")), ): """Resume the paused timer for the current phase.""" - test = _get_test_or_404(db, test_id) + test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: test = wf_resume_timer(db, test, current_user) uow.commit() @@ -535,7 +419,7 @@ def validate_red( current_user: User = Depends(require_any_role("red_lead")), ): """Red Lead approves or rejects the red side of a test.""" - test = _get_test_with_technique(db, test_id) + test = crud_get_test_with_technique(db, test_id) with UnitOfWork(db) as uow: test = wf_validate_red( db, test, current_user, @@ -562,7 +446,7 @@ def validate_blue( current_user: User = Depends(require_any_role("blue_lead")), ): """Blue Lead approves or rejects the blue side of a test.""" - test = _get_test_with_technique(db, test_id) + test = crud_get_test_with_technique(db, test_id) with UnitOfWork(db) as uow: test = wf_validate_blue( db, test, current_user, @@ -588,7 +472,7 @@ def reopen( current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Reopen a rejected test, moving it back to ``draft``.""" - test = _get_test_or_404(db, test_id) + test = crud_get_test_or_raise(db, test_id) with UnitOfWork(db) as uow: test = wf_reopen(db, test, current_user) uow.commit() @@ -613,7 +497,7 @@ def update_remediation( When ``remediation_status`` transitions to ``'completed'``, an automatic re-test is created (subject to ``MAX_RETEST_COUNT``). """ - test = _get_test_or_404(db, test_id) + test = crud_get_test_or_raise(db, test_id) old_remediation_status = test.remediation_status @@ -653,29 +537,7 @@ def get_test_timeline( current_user: User = Depends(get_current_user), ): """Return the chronological audit-log history for a test.""" - # Verify the test exists - _get_test_or_404(db, test_id) - - logs = ( - db.query(AuditLog) - .filter( - AuditLog.entity_type == "test", - AuditLog.entity_id == str(test_id), - ) - .order_by(AuditLog.timestamp.asc()) - .all() - ) - - return [ - { - "id": str(log.id), - "action": log.action, - "user_id": str(log.user_id) if log.user_id else None, - "timestamp": log.timestamp.isoformat() if log.timestamp else None, - "details": log.details, - } - for log in logs - ] + return crud_get_test_timeline(db, test_id) # --------------------------------------------------------------------------- diff --git a/backend/app/services/test_crud_service.py b/backend/app/services/test_crud_service.py new file mode 100644 index 0000000..a5be075 --- /dev/null +++ b/backend/app/services/test_crud_service.py @@ -0,0 +1,277 @@ +"""Test CRUD service — list, create, update, and query logic for security tests. + +Framework-agnostic; uses domain exceptions from app.domain.errors. +The router is responsible for HTTP concerns, auth, audit logging, and commit. +""" + +import uuid +from typing import Any + +from sqlalchemy.orm import Session, joinedload + +from app.domain.errors import ( + BusinessRuleViolation, + EntityNotFoundError, + PermissionViolation, +) +from app.models.enums import TestState +from app.models.technique import Technique +from app.models.test import Test +from app.models.test_template import TestTemplate +from app.models.audit import AuditLog +from app.utils import escape_like + + +def list_tests( + db: Session, + *, + state: str | None = None, + technique_id: uuid.UUID | None = None, + platform: str | None = None, + created_by: uuid.UUID | None = None, + pending_validation_side: str | None = None, + offset: int = 0, + limit: int = 50, +) -> list[Test]: + """Return a paginated list of tests with optional filters.""" + query = db.query(Test).options(joinedload(Test.technique)) + + if state: + query = query.filter(Test.state == state) + if technique_id: + query = query.filter(Test.technique_id == technique_id) + if platform: + query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%")) + if created_by: + query = query.filter(Test.created_by == created_by) + if pending_validation_side == "red": + query = query.filter( + Test.state == TestState.in_review, + Test.red_validation_status.in_(["pending", None]), + ) + elif pending_validation_side == "blue": + query = query.filter( + Test.state == TestState.in_review, + Test.blue_validation_status.in_(["pending", None]), + ) + + return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all() + + +def create_test( + db: Session, + *, + technique_id: uuid.UUID, + creator_id: uuid.UUID, + **fields: Any, +) -> Test: + """Create a new test linked to an existing technique. + + Raises EntityNotFoundError if the technique does not exist. + Does not commit; caller uses UnitOfWork. + """ + technique = db.query(Technique).filter(Technique.id == technique_id).first() + if technique is None: + raise EntityNotFoundError("Technique", str(technique_id)) + + test = Test( + technique_id=technique_id, + created_by=creator_id, + state=TestState.draft, + **fields, + ) + db.add(test) + db.flush() + return test + + +def create_test_from_template( + db: Session, + *, + template_id: uuid.UUID, + technique_id_or_mitre: str, + creator_id: uuid.UUID, +) -> Test: + """Instantiate a Test from a TestTemplate. + + technique_id_or_mitre can be a UUID string or MITRE ID (e.g. T1059.001). + Raises EntityNotFoundError if template or technique not found. + Does not commit; caller uses UnitOfWork. + """ + template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() + if template is None: + raise EntityNotFoundError("TestTemplate", str(template_id)) + + technique = None + try: + technique_uuid = uuid.UUID(technique_id_or_mitre) + technique = db.query(Technique).filter(Technique.id == technique_uuid).first() + except ValueError: + pass + + if technique is None: + technique = db.query(Technique).filter( + Technique.mitre_id == technique_id_or_mitre + ).first() + + if technique is None: + raise EntityNotFoundError("Technique", technique_id_or_mitre) + + test = Test( + technique_id=technique.id, + name=template.name, + description=template.description, + platform=template.platform, + procedure_text=template.attack_procedure, + tool_used=template.tool_suggested, + remediation_steps=template.suggested_remediation, + created_by=creator_id, + state=TestState.draft, + ) + db.add(test) + db.flush() + return test + + +def get_test_detail(db: Session, test_id: uuid.UUID) -> Test: + """Fetch a test with evidences eager-loaded. + + Raises EntityNotFoundError if the test does not exist. + """ + test = ( + db.query(Test) + .options(joinedload(Test.evidences)) + .filter(Test.id == test_id) + .first() + ) + if test is None: + raise EntityNotFoundError("Test", str(test_id)) + return test + + +def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test: + """Fetch a test by ID. Raises EntityNotFoundError if not found.""" + test = db.query(Test).filter(Test.id == test_id).first() + if test is None: + raise EntityNotFoundError("Test", str(test_id)) + return test + + +def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test: + """Fetch a test with technique joined. Raises EntityNotFoundError if not found.""" + test = ( + db.query(Test) + .options(joinedload(Test.technique)) + .filter(Test.id == test_id) + .first() + ) + if test is None: + raise EntityNotFoundError("Test", str(test_id)) + return test + + +def update_test( + db: Session, + test_id: uuid.UUID, + *, + updater_id: uuid.UUID, + updater_role: str, + **fields: Any, +) -> Test: + """Update general test fields (draft or rejected only). + + Raises PermissionViolation if not creator or admin. + Raises BusinessRuleViolation if state is not draft or rejected. + Raises EntityNotFoundError if test not found. + Does not commit; caller uses UnitOfWork. + """ + test = get_test_or_raise(db, test_id) + + if updater_role != "admin" and test.created_by != updater_id: + raise PermissionViolation( + "Only the test creator or an admin can update this test" + ) + + if test.state not in (TestState.draft, TestState.rejected): + raise BusinessRuleViolation( + f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)" + ) + + for field, value in fields.items(): + setattr(test, field, value) + + db.flush() + return test + + +def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test: + """Update Red Team fields (draft or red_executing only). + + Raises BusinessRuleViolation if state not in (draft, red_executing). + Raises EntityNotFoundError if test not found. + Does not commit; caller uses UnitOfWork. + """ + test = get_test_or_raise(db, test_id) + + if test.state not in (TestState.draft, TestState.red_executing): + raise BusinessRuleViolation( + f"Cannot update red fields in '{test.state.value}' state " + "(must be draft or red_executing)" + ) + + for field, value in fields.items(): + setattr(test, field, value) + + db.flush() + return test + + +def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test: + """Update Blue Team fields (blue_evaluating only). + + Raises BusinessRuleViolation if state is not blue_evaluating. + Raises EntityNotFoundError if test not found. + Does not commit; caller uses UnitOfWork. + """ + test = get_test_or_raise(db, test_id) + + if test.state != TestState.blue_evaluating: + raise BusinessRuleViolation( + f"Cannot update blue fields in '{test.state.value}' state " + "(must be blue_evaluating)" + ) + + for field, value in fields.items(): + setattr(test, field, value) + + db.flush() + return test + + +def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]: + """Return chronological audit-log history for a test. + + Raises EntityNotFoundError if the test does not exist. + """ + get_test_or_raise(db, test_id) + + logs = ( + db.query(AuditLog) + .filter( + AuditLog.entity_type == "test", + AuditLog.entity_id == str(test_id), + ) + .order_by(AuditLog.timestamp.asc()) + .all() + ) + + return [ + { + "id": str(log.id), + "action": log.action, + "user_id": str(log.user_id) if log.user_id else None, + "timestamp": log.timestamp.isoformat() if log.timestamp else None, + "details": log.details, + } + for log in logs + ] diff --git a/backend/tests/test_templates_crud.py b/backend/tests/test_templates_crud.py index 7b11b07..2169208 100644 --- a/backend/tests/test_templates_crud.py +++ b/backend/tests/test_templates_crud.py @@ -101,7 +101,7 @@ from app.routers.test_templates import ( toggle_template_active, template_stats, ) -from app.routers.tests import create_test_from_template +from app.services.test_crud_service import create_test_from_template as crud_create_from_template from app.schemas.test_template import TestTemplateCreate @@ -174,7 +174,8 @@ def test_get_templates_by_technique(): def test_instantiate_template(): """POST /tests/from-template creates a test pre-filled from template data.""" - source = inspect.getsource(create_test_from_template) + # Template field copying lives in the service; router delegates to it + source = inspect.getsource(crud_create_from_template) # Verify it reads from template and copies fields assert "template" in source, "Must reference template" diff --git a/backend/tests/test_workflow.py b/backend/tests/test_workflow.py index 19dd96c..0b61848 100644 --- a/backend/tests/test_workflow.py +++ b/backend/tests/test_workflow.py @@ -477,15 +477,15 @@ def test_evidence_team_separation(): def test_red_edit_allowed_in_draft_and_red_executing(): - """Verify the red update router checks that state is draft or red_executing.""" - from app.routers.tests import update_test_red + """Verify the red update checks that state is draft or red_executing.""" + from app.services.test_crud_service import update_test_red import inspect source = inspect.getsource(update_test_red) - # The function must guard against states other than draft/red_executing + # The service must guard against states other than draft/red_executing assert "draft" in source, "Red update must allow draft state" assert "red_executing" in source, "Red update must allow red_executing state" - assert "400" in source or "HTTP_400_BAD_REQUEST" in source, "Red update must return 400 for invalid state" + assert "BusinessRuleViolation" in source, "Must raise domain exception for invalid state (mapped to 400)" # ===========================================================================