"""Test CRUD service — list, create, update, and query logic for security tests.

Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""

import uuid
from typing import Any

from sqlalchemy import or_
from sqlalchemy.orm import Session, joinedload

from app.domain.errors import (
    BusinessRuleViolation,
    EntityNotFoundError,
    PermissionViolation,
)
from app.models.audit import AuditLog
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.utils import escape_like


def list_tests(
    db: Session,
    *,
    state: str | None = None,
    technique_id: uuid.UUID | None = None,
    platform: str | None = None,
    created_by: uuid.UUID | None = None,
    pending_validation_side: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> list[Test]:
    """Return a paginated list of tests with optional filters.

    All supplied filters are combined with AND.

    Args:
        db (Session): Active SQLAlchemy database session.
        state (str | None): Filter by test state string value.
        technique_id (uuid.UUID | None): Filter by linked technique UUID.
        platform (str | None): Case-insensitive substring filter on the
            ``platform`` field; the input is passed through ``escape_like``.
        created_by (uuid.UUID | None): Filter by creator user UUID.
        pending_validation_side (str | None): When ``"red"`` or ``"blue"``,
            returns only ``in_review`` tests whose validation status for that
            side is ``"pending"`` or not yet set (NULL). Any other value is
            ignored.
        offset (int): Number of records to skip for pagination.
        limit (int): Maximum number of records to return.

    Returns:
        list[Test]: Matching test records (``technique`` eager-loaded),
        ordered by creation date descending.
    """
    query = db.query(Test).options(joinedload(Test.technique))

    if state:
        query = query.filter(Test.state == state)
    if technique_id:
        query = query.filter(Test.technique_id == technique_id)
    if platform:
        # NOTE(review): no ``escape=`` is passed to ilike(), so this relies on
        # the backend's default LIKE escape character matching whatever
        # escape_like() emits — confirm for the target database.
        query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
    if created_by:
        query = query.filter(Test.created_by == created_by)

    status_column = {
        "red": Test.red_validation_status,
        "blue": Test.blue_validation_status,
    }.get(pending_validation_side)
    # Compare with ``is not None``: SQLAlchemy column objects must not be
    # evaluated for truthiness.
    if status_column is not None:
        # ``col IN ('pending', NULL)`` never matches NULL rows under SQL
        # three-valued logic, so the "no vote yet" case needs an explicit
        # IS NULL alongside the equality check.
        query = query.filter(
            Test.state == TestState.in_review,
            or_(status_column == "pending", status_column.is_(None)),
        )

    return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()


def create_test(
    db: Session,
    *,
    technique_id: uuid.UUID,
    creator_id: uuid.UUID,
    **fields: object,
) -> Test:
    """Create a new draft test linked to an existing technique.

    Does not commit; the caller owns the transaction (UnitOfWork).

    Args:
        db (Session): Active SQLAlchemy database session.
        technique_id (uuid.UUID): UUID of the technique this test covers.
        creator_id (uuid.UUID): UUID of the user creating the test.
        **fields (object): Additional keyword arguments passed to the ``Test``
            constructor (e.g. ``name``, ``platform``, ``description``).
            Passing ``technique_id``, ``created_by`` or ``state`` here raises
            ``TypeError``, because those keywords are already set explicitly.

    Returns:
        Test: The newly created test ORM object, flushed but not committed.

    Raises:
        EntityNotFoundError: If no technique with ``technique_id`` exists.
    """
    technique = db.query(Technique).filter(Technique.id == technique_id).first()
    if technique is None:
        raise EntityNotFoundError("Technique", str(technique_id))

    test = Test(
        technique_id=technique_id,
        created_by=creator_id,
        state=TestState.draft,
        **fields,
    )
    db.add(test)
    # Flush (without committing) so the INSERT is issued within the caller's
    # transaction before the object is returned.
    db.flush()
    return test


def _resolve_technique(db: Session, technique_id_or_mitre: str) -> Technique:
    """Look up a technique by UUID string, falling back to its MITRE ID.

    Raises:
        EntityNotFoundError: If neither lookup finds a technique.
    """
    technique: Technique | None = None
    try:
        technique_uuid = uuid.UUID(technique_id_or_mitre)
    except ValueError:
        # Not UUID-shaped (e.g. "T1059.001"); the MITRE ID lookup below applies.
        pass
    else:
        technique = db.query(Technique).filter(Technique.id == technique_uuid).first()

    if technique is None:
        # Also reached when a well-formed UUID matched no technique row.
        technique = (
            db.query(Technique)
            .filter(Technique.mitre_id == technique_id_or_mitre)
            .first()
        )
    if technique is None:
        raise EntityNotFoundError("Technique", technique_id_or_mitre)
    return technique


def create_test_from_template(
    db: Session,
    *,
    template_id: uuid.UUID,
    technique_id_or_mitre: str,
    creator_id: uuid.UUID,
) -> Test:
    """Instantiate a draft Test from a TestTemplate.

    ``technique_id_or_mitre`` can be a UUID string or a MITRE ID
    (e.g. ``T1059.001``). Does not commit; the caller owns the transaction
    (UnitOfWork).

    Args:
        db (Session): Active SQLAlchemy database session.
        template_id (uuid.UUID): UUID of the template to instantiate.
        technique_id_or_mitre (str): UUID string or MITRE technique ID
            (e.g. ``"T1059.001"``) identifying the target technique.
        creator_id (uuid.UUID): UUID of the user creating the test.

    Returns:
        Test: The newly created test populated from template fields,
        flushed but not committed.

    Raises:
        EntityNotFoundError: If the template or the technique is not found
            (the template is checked first).
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if template is None:
        raise EntityNotFoundError("TestTemplate", str(template_id))

    technique = _resolve_technique(db, technique_id_or_mitre)

    # Template column names differ from Test column names, hence the explicit mapping.
    test = Test(
        technique_id=technique.id,
        name=template.name,
        description=template.description,
        platform=template.platform,
        procedure_text=template.attack_procedure,
        tool_used=template.tool_suggested,
        remediation_steps=template.suggested_remediation,
        created_by=creator_id,
        state=TestState.draft,
    )
    db.add(test)
    db.flush()
    return test


def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
    """Fetch a test with its ``evidences`` relationship eager-loaded.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to retrieve.

    Returns:
        Test: The test ORM object with ``evidences`` loaded.

    Raises:
        EntityNotFoundError: If the test does not exist.
    """
    test = (
        db.query(Test)
        .options(joinedload(Test.evidences))
        .filter(Test.id == test_id)
        .first()
    )
    if test is None:
        raise EntityNotFoundError("Test", str(test_id))
    return test


def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
    """Fetch a test by ID.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to retrieve.

    Returns:
        Test: The matching test ORM object.

    Raises:
        EntityNotFoundError: If the test does not exist.
    """
    test = db.query(Test).filter(Test.id == test_id).first()
    if test is None:
        raise EntityNotFoundError("Test", str(test_id))
    return test


def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
    """Fetch a test with its ``technique`` relationship eager-loaded.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to retrieve.

    Returns:
        Test: The test ORM object with ``technique`` loaded.

    Raises:
        EntityNotFoundError: If the test does not exist.
    """
    test = (
        db.query(Test)
        .options(joinedload(Test.technique))
        .filter(Test.id == test_id)
        .first()
    )
    if test is None:
        raise EntityNotFoundError("Test", str(test_id))
    return test


def _apply_fields(db: Session, test: Test, fields: dict[str, object]) -> Test:
    """Set every ``fields`` entry as an attribute on ``test``, then flush."""
    for field, value in fields.items():
        setattr(test, field, value)
    db.flush()
    return test


def update_test(
    db: Session,
    test_id: uuid.UUID,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: object,
) -> Test:
    """Update general test fields (draft or rejected tests only).

    Does not commit; the caller owns the transaction (UnitOfWork).

    NOTE: ``fields`` are applied verbatim via ``setattr`` with no allow-list
    here; callers must restrict which attributes may be passed.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to update.
        updater_id (uuid.UUID): UUID of the user performing the update.
        updater_role (str): Role of the updater; ``"admin"`` bypasses the
            creator-only restriction.
        **fields (object): Keyword arguments mapped directly onto test
            attributes.

    Returns:
        Test: The updated test ORM object, flushed but not committed.

    Raises:
        EntityNotFoundError: If the test does not exist.
        PermissionViolation: If the updater is neither the creator nor an admin.
        BusinessRuleViolation: If the test is not in draft or rejected state.
    """
    test = get_test_or_raise(db, test_id)

    # The permission check runs before the state check, so unauthorized users
    # get a PermissionViolation regardless of the test's state.
    if updater_role != "admin" and test.created_by != updater_id:
        raise PermissionViolation(
            "Only the test creator or an admin can update this test"
        )
    if test.state not in (TestState.draft, TestState.rejected):
        raise BusinessRuleViolation(
            f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
        )

    return _apply_fields(db, test, fields)


def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
    """Update Red Team fields (draft or red_executing tests only).

    Does not commit; the caller owns the transaction (UnitOfWork).

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to update.
        **fields (object): Red-team field names and their new values, applied
            verbatim via ``setattr``.

    Returns:
        Test: The updated test ORM object, flushed but not committed.

    Raises:
        EntityNotFoundError: If the test does not exist.
        BusinessRuleViolation: If the state is not draft or red_executing.
    """
    test = get_test_or_raise(db, test_id)
    if test.state not in (TestState.draft, TestState.red_executing):
        raise BusinessRuleViolation(
            f"Cannot update red fields in '{test.state.value}' state "
            "(must be draft or red_executing)"
        )
    return _apply_fields(db, test, fields)


def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
    """Update Blue Team fields (blue_evaluating tests only).

    Does not commit; the caller owns the transaction (UnitOfWork).

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test to update.
        **fields (object): Blue-team field names and their new values, applied
            verbatim via ``setattr``.

    Returns:
        Test: The updated test ORM object, flushed but not committed.

    Raises:
        EntityNotFoundError: If the test does not exist.
        BusinessRuleViolation: If the state is not blue_evaluating.
    """
    test = get_test_or_raise(db, test_id)
    if test.state != TestState.blue_evaluating:
        raise BusinessRuleViolation(
            f"Cannot update blue fields in '{test.state.value}' state "
            "(must be blue_evaluating)"
        )
    return _apply_fields(db, test, fields)


def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
    """Return the chronological audit-log history for a test.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): UUID of the test whose history is requested.

    Returns:
        list[dict[str, Any]]: Audit-log entries ordered by timestamp ascending,
        each containing ``id``, ``action``, ``user_id``, ``timestamp``
        (ISO-8601 string or None), and ``details``.

    Raises:
        EntityNotFoundError: If the test does not exist.
    """
    # Existence check only: a missing test raises instead of returning [].
    get_test_or_raise(db, test_id)

    # Audit rows store the entity id as a string, so compare against str(test_id).
    logs = (
        db.query(AuditLog)
        .filter(
            AuditLog.entity_type == "test",
            AuditLog.entity_id == str(test_id),
        )
        .order_by(AuditLog.timestamp.asc())
        .all()
    )
    return [
        {
            "id": str(log.id),
            "action": log.action,
            "user_id": str(log.user_id) if log.user_id else None,
            "timestamp": log.timestamp.isoformat() if log.timestamp else None,
            "details": log.details,
        }
        for log in logs
    ]