"""Test CRUD service — list, create, update, and query logic for security tests. Framework-agnostic; uses domain exceptions from app.domain.errors. The router is responsible for HTTP concerns, auth, audit logging, and commit. """ # Import uuid import uuid from datetime import datetime from typing import Any # Import Session, joinedload from sqlalchemy.orm from sqlalchemy.orm import Session, joinedload # Import from app.domain.errors from app.domain.errors import ( BusinessRuleViolation, EntityNotFoundError, PermissionViolation, ) from app.models.enums import TestState from app.models.technique import Technique from app.models.test import Test from app.models.test_template import TestTemplate from app.models.campaign import Campaign, CampaignTest from app.models.audit import AuditLog from app.utils import escape_like # Define function list_tests def list_tests( # Entry: db db: Session, *, # Entry: state state: str | None = None, # Entry: technique_id technique_id: uuid.UUID | None = None, # Entry: platform platform: str | None = None, # Entry: created_by created_by: uuid.UUID | None = None, # Entry: pending_validation_side pending_validation_side: str | None = None, not_in_any_campaign: bool = False, offset: int = 0, # Entry: limit limit: int = 50, ) -> list[Test]: """Return a paginated list of tests with optional filters. Tests that belong to a campaign still in 'draft' status AND with a start_date in the future are always excluded — they should not appear in the team's queue until the campaign is activated on its start date. """ query = db.query(Test).options(joinedload(Test.technique)) # Check: state if state: # Assign query = query.filter(Test.state == state) query = query.filter(Test.state == state) # Check: technique_id if technique_id: # Assign query = query.filter(Test.technique_id == technique_id) query = query.filter(Test.technique_id == technique_id) # Check: platform if platform: # Assign query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%")) query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%")) # Check: created_by if created_by: # Assign query = query.filter(Test.created_by == created_by) query = query.filter(Test.created_by == created_by) # Check: pending_validation_side == "red" if pending_validation_side == "red": # Assign query = query.filter( query = query.filter( Test.state == TestState.in_review, Test.red_validation_status.in_(["pending", None]), ) # Alternative: pending_validation_side == "blue" elif pending_validation_side == "blue": # Assign query = query.filter( query = query.filter( Test.state == TestState.in_review, Test.blue_validation_status.in_(["pending", None]), ) if not_in_any_campaign: linked = db.query(CampaignTest.test_id).distinct().subquery() query = query.filter(~Test.id.in_(linked)) # Always hide tests from scheduled campaigns that haven't started yet. # A "scheduled-but-not-yet-active" campaign = draft status + start_date in future. now = datetime.utcnow() future_draft_tests = ( db.query(CampaignTest.test_id) .join(Campaign, Campaign.id == CampaignTest.campaign_id) .filter( Campaign.status == "draft", Campaign.start_date.isnot(None), Campaign.start_date > now, ) .distinct() .subquery() ) query = query.filter(~Test.id.in_(future_draft_tests)) # Return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).... return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all() # Define function create_test def create_test( # Entry: db db: Session, *, # Entry: technique_id technique_id: uuid.UUID, # Entry: creator_id creator_id: uuid.UUID, **fields: object, ) -> Test: """Create a new test linked to an existing technique. Raises EntityNotFoundError if the technique does not exist. Does not commit; caller uses UnitOfWork. Args: db (Session): Active SQLAlchemy database session. technique_id (uuid.UUID): UUID of the technique this test covers. creator_id (uuid.UUID): UUID of the user creating the test. **fields (object): Additional keyword arguments set as attributes on the new test (e.g. ``name``, ``platform``, ``description``). Returns: Test: The newly created test ORM object, flushed but not committed. """ # Assign technique = db.query(Technique).filter(Technique.id == technique_id).first() technique = db.query(Technique).filter(Technique.id == technique_id).first() # Check: technique is None if technique is None: # Raise EntityNotFoundError raise EntityNotFoundError("Technique", str(technique_id)) # Assign test = Test( test = Test( # Keyword argument: technique_id technique_id=technique_id, # Keyword argument: created_by created_by=creator_id, # Keyword argument: state state=TestState.draft, created_at=datetime.utcnow(), # explicit — DB column has no server default **fields, ) # Stage new record(s) for database insertion db.add(test) # Flush changes to DB without committing the transaction db.flush() # Return test return test # Define function create_test_from_template def create_test_from_template( # Entry: db db: Session, *, # Entry: template_id template_id: uuid.UUID, # Entry: technique_id_or_mitre technique_id_or_mitre: str, # Entry: creator_id creator_id: uuid.UUID, # Optional user-edited overrides (take priority over template values) name_override: str | None = None, description_override: str | None = None, platform_override: str | None = None, procedure_text_override: str | None = None, tool_used_override: str | None = None, ) -> Test: """Instantiate a Test from a TestTemplate. technique_id_or_mitre can be a UUID string or MITRE ID (e.g. T1059.001). Override fields, when provided, take precedence over the template's values. Raises EntityNotFoundError if template or technique not found. Does not commit; caller uses UnitOfWork. Args: db (Session): Active SQLAlchemy database session. template_id (uuid.UUID): UUID of the template to instantiate. technique_id_or_mitre (str): UUID string or MITRE technique ID (e.g. ``"T1059.001"``) identifying the target technique. creator_id (uuid.UUID): UUID of the user creating the test. Returns: Test: The newly created test populated from template fields, flushed but not committed. """ # Assign template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() # Check: template is None if template is None: # Raise EntityNotFoundError raise EntityNotFoundError("TestTemplate", str(template_id)) # Assign technique = None technique = None # Attempt the following; catch errors below try: # Assign technique_uuid = uuid.UUID(technique_id_or_mitre) technique_uuid = uuid.UUID(technique_id_or_mitre) # Assign technique = db.query(Technique).filter(Technique.id == technique_uuid).first() technique = db.query(Technique).filter(Technique.id == technique_uuid).first() # Handle ValueError except ValueError: # Intentional no-op placeholder pass # Check: technique is None if technique is None: # Assign technique = db.query(Technique).filter( technique = db.query(Technique).filter( Technique.mitre_id == technique_id_or_mitre ).first() # Check: technique is None if technique is None: # Raise EntityNotFoundError raise EntityNotFoundError("Technique", technique_id_or_mitre) # Assign test = Test( test = Test( # Keyword argument: technique_id technique_id=technique.id, name=name_override if name_override is not None else template.name, description=description_override if description_override is not None else template.description, platform=platform_override if platform_override is not None else template.platform, procedure_text=procedure_text_override if procedure_text_override is not None else template.attack_procedure, tool_used=tool_used_override if tool_used_override is not None else template.tool_suggested, remediation_steps=template.suggested_remediation, # Keyword argument: created_by created_by=creator_id, # Keyword argument: state state=TestState.draft, created_at=datetime.utcnow(), # explicit — DB column has no server default ) # Stage new record(s) for database insertion db.add(test) # Flush changes to DB without committing the transaction db.flush() # Return test return test # Define function get_test_detail def get_test_detail(db: Session, test_id: uuid.UUID) -> Test: """Fetch a test with evidences and technique eager-loaded. Raises EntityNotFoundError if the test does not exist. Args: db (Session): Active SQLAlchemy database session. test_id (uuid.UUID): UUID of the test to retrieve. Returns: Test: The test ORM object with ``evidences`` relationship loaded. """ # Assign test = ( test = ( db.query(Test) .options(joinedload(Test.evidences), joinedload(Test.technique)) .filter(Test.id == test_id) # Chain .first() call .first() ) # Check: test is None if test is None: # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(test_id)) # Return test return test # Define function get_test_or_raise def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test: """Fetch a test by ID. Raises EntityNotFoundError if not found. Args: db (Session): Active SQLAlchemy database session. test_id (uuid.UUID): UUID of the test to retrieve. Returns: Test: The matching test ORM object. """ # Assign test = db.query(Test).filter(Test.id == test_id).first() test = db.query(Test).filter(Test.id == test_id).first() # Check: test is None if test is None: # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(test_id)) # Return test return test # Define function get_test_with_technique def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test: """Fetch a test with technique joined. Raises EntityNotFoundError if not found. Args: db (Session): Active SQLAlchemy database session. test_id (uuid.UUID): UUID of the test to retrieve. Returns: Test: The test ORM object with ``technique`` relationship loaded. """ # Assign test = ( test = ( db.query(Test) # Chain .options() call .options(joinedload(Test.technique)) # Chain .filter() call .filter(Test.id == test_id) # Chain .first() call .first() ) # Check: test is None if test is None: # Raise EntityNotFoundError raise EntityNotFoundError("Test", str(test_id)) # Return test return test # Define function update_test def update_test( # Entry: db db: Session, # Entry: test_id test_id: uuid.UUID, *, # Entry: updater_id updater_id: uuid.UUID, # Entry: updater_role updater_role: str, **fields: object, ) -> Test: """Update general test fields (draft or rejected only). Raises PermissionViolation if not creator or admin. Raises BusinessRuleViolation if state is not draft or rejected. Raises EntityNotFoundError if test not found. Does not commit; caller uses UnitOfWork. Args: db (Session): Active SQLAlchemy database session. test_id (uuid.UUID): UUID of the test to update. updater_id (uuid.UUID): UUID of the user performing the update. updater_role (str): Role of the updater; ``"admin"`` bypasses the creator-only restriction. **fields (object): Keyword arguments mapped directly onto test attributes. Returns: Test: The updated test ORM object, flushed but not committed. """ # Assign test = get_test_or_raise(db, test_id) test = get_test_or_raise(db, test_id) # Check: updater_role != "admin" and test.created_by != updater_id if updater_role != "admin" and test.created_by != updater_id: # Raise PermissionViolation raise PermissionViolation( # Literal argument value "Only the test creator or an admin can update this test" ) # Check: test.state not in (TestState.draft, TestState.rejected) if test.state not in (TestState.draft, TestState.rejected): # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)" ) # Iterate over fields.items() for field, value in fields.items(): # Call setattr() setattr(test, field, value) # Flush changes to DB without committing the transaction db.flush() # Return test return test # Define function update_test_red def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test: """Update Red Team fields (draft or red_executing only). Raises BusinessRuleViolation if state not in (draft, red_executing). Raises EntityNotFoundError if test not found. Does not commit; caller uses UnitOfWork. Args: db (Session): Active SQLAlchemy database session. test_id (uuid.UUID): UUID of the test to update. **fields (object): Red-team field names and their new values. Returns: Test: The updated test ORM object, flushed but not committed. """ # Assign test = get_test_or_raise(db, test_id) test = get_test_or_raise(db, test_id) # Check: test.state not in (TestState.draft, TestState.red_executing) if test.state not in (TestState.draft, TestState.red_executing): # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot update red fields in '{test.state.value}' state " # Literal argument value "(must be draft or red_executing)" ) # Iterate over fields.items() for field, value in fields.items(): # Call setattr() setattr(test, field, value) # Flush changes to DB without committing the transaction db.flush() # Return test return test # Define function update_test_blue def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test: """Update Blue Team fields (blue_evaluating only). Raises BusinessRuleViolation if state is not blue_evaluating. Raises EntityNotFoundError if test not found. Does not commit; caller uses UnitOfWork. Args: db (Session): Active SQLAlchemy database session. test_id (uuid.UUID): UUID of the test to update. **fields (object): Blue-team field names and their new values. Returns: Test: The updated test ORM object, flushed but not committed. """ # Assign test = get_test_or_raise(db, test_id) test = get_test_or_raise(db, test_id) # Check: test.state != TestState.blue_evaluating if test.state != TestState.blue_evaluating: # Raise BusinessRuleViolation raise BusinessRuleViolation( f"Cannot update blue fields in '{test.state.value}' state " # Literal argument value "(must be blue_evaluating)" ) # Iterate over fields.items() for field, value in fields.items(): # Call setattr() setattr(test, field, value) # Flush changes to DB without committing the transaction db.flush() # Return test return test # Define function get_test_timeline def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]: """Return chronological audit-log history for a test. Raises EntityNotFoundError if the test does not exist. Args: db (Session): Active SQLAlchemy database session. test_id (uuid.UUID): UUID of the test whose history is requested. Returns: list[dict[str, Any]]: Audit-log entries ordered by timestamp ascending, each containing ``id``, ``action``, ``user_id``, ``timestamp``, and ``details``. """ # Call get_test_or_raise() get_test_or_raise(db, test_id) # Assign logs = ( logs = ( db.query(AuditLog) # Chain .filter() call .filter( AuditLog.entity_type == "test", AuditLog.entity_id == str(test_id), ) # Chain .order_by() call .order_by(AuditLog.timestamp.asc()) # Chain .all() call .all() ) # Return [ return [ { # Literal argument value "id": str(log.id), # Literal argument value "action": log.action, # Literal argument value "user_id": str(log.user_id) if log.user_id else None, # Literal argument value "timestamp": log.timestamp.isoformat() if log.timestamp else None, # Literal argument value "details": log.details, } for log in logs ]