"""Test CRUD service — list, create, update, and query logic for security tests.

Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""

import uuid
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import or_
from sqlalchemy.orm import Session, joinedload

from app.domain.errors import (
    BusinessRuleViolation,
    EntityNotFoundError,
    PermissionViolation,
)
from app.models.audit import AuditLog
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.utils import escape_like


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value as the deprecated ``datetime.utcnow()`` (deprecated since
    Python 3.12). The tzinfo is stripped so stored timestamps keep the naive
    UTC representation this module has always written.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _pending_validation(status_column: Any) -> Any:
    """Build a filter clause matching a validation status that is still pending.

    An unset (NULL) status counts as pending. This cannot be written as
    ``status_column.in_(["pending", None])``: SQL ``x IN ('pending', NULL)``
    is never true when ``x`` is NULL, so the NULL case needs ``IS NULL``.
    """
    return or_(status_column == "pending", status_column.is_(None))


def _prefer(override: Any, fallback: Any) -> Any:
    """Return *override* unless it is None, otherwise *fallback*."""
    return fallback if override is None else override


def _resolve_technique(db: Session, technique_id_or_mitre: str) -> Technique:
    """Find a technique by UUID string, falling back to its MITRE ID.

    Raises EntityNotFoundError if neither lookup matches.
    """
    technique = None
    try:
        technique_uuid = uuid.UUID(technique_id_or_mitre)
    except ValueError:
        pass  # not a UUID — fall through to the MITRE ID lookup
    else:
        technique = db.query(Technique).filter(Technique.id == technique_uuid).first()

    if technique is None:
        technique = (
            db.query(Technique)
            .filter(Technique.mitre_id == technique_id_or_mitre)
            .first()
        )
    if technique is None:
        raise EntityNotFoundError("Technique", technique_id_or_mitre)
    return technique


def _get_test(db: Session, test_id: uuid.UUID, *load_options: Any) -> Test:
    """Fetch a test by ID, applying any given loader options (e.g. joinedload).

    Raises EntityNotFoundError if the test does not exist.
    """
    query = db.query(Test)
    if load_options:
        query = query.options(*load_options)
    test = query.filter(Test.id == test_id).first()
    if test is None:
        raise EntityNotFoundError("Test", str(test_id))
    return test


def _assign_and_flush(db: Session, test: Test, fields: dict[str, Any]) -> Test:
    """Set each ``fields`` item as an attribute on *test*, flush, return *test*."""
    # NOTE(review): keys are assigned verbatim with no whitelist; presumably the
    # router restricts them to the editable schema fields — confirm.
    for field, value in fields.items():
        setattr(test, field, value)
    db.flush()
    return test


# ---------------------------------------------------------------------------
# Queries
# ---------------------------------------------------------------------------


def list_tests(
    db: Session,
    *,
    state: str | None = None,
    technique_id: uuid.UUID | None = None,
    platform: str | None = None,
    created_by: uuid.UUID | None = None,
    pending_validation_side: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> list[Test]:
    """Return a paginated list of tests, newest first, with optional filters.

    Each test's technique is eager-loaded.

    Args:
        db: Active SQLAlchemy session.
        state: If truthy, only tests whose ``state`` equals this value.
        technique_id: If given, only tests linked to this technique.
        platform: If truthy, case-insensitive substring match on ``platform``.
        created_by: If given, only tests created by this user.
        pending_validation_side: ``"red"`` or ``"blue"`` restricts results to
            tests in ``in_review`` whose validation status for that side is
            ``"pending"`` or unset (NULL). Any other value is ignored.
        offset: Number of rows to skip.
        limit: Maximum number of rows to return.
    """
    query = db.query(Test).options(joinedload(Test.technique))
    if state:
        query = query.filter(Test.state == state)
    if technique_id:
        query = query.filter(Test.technique_id == technique_id)
    if platform:
        # NOTE(review): relies on escape_like() using the database's default
        # LIKE escape character (no explicit escape= is passed) — confirm.
        query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
    if created_by:
        query = query.filter(Test.created_by == created_by)

    if pending_validation_side == "red":
        status_column = Test.red_validation_status
    elif pending_validation_side == "blue":
        status_column = Test.blue_validation_status
    else:
        status_column = None
    # Compare with `is not None`: SQLAlchemy column objects must not be
    # evaluated for truthiness.
    if status_column is not None:
        query = query.filter(
            Test.state == TestState.in_review,
            _pending_validation(status_column),
        )

    return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()


# ---------------------------------------------------------------------------
# Creation
# ---------------------------------------------------------------------------


def create_test(
    db: Session,
    *,
    technique_id: uuid.UUID,
    creator_id: uuid.UUID,
    **fields: Any,
) -> Test:
    """Create a new draft test linked to an existing technique.

    Args:
        db: Active SQLAlchemy session.
        technique_id: ID of the technique the test belongs to.
        creator_id: Stored as the test's ``created_by``.
        **fields: Extra Test attributes. Must not include ``created_by``,
            ``state``, or ``created_at``. Those are set here, and passing them
            again raises TypeError (duplicate keyword argument).

    Raises:
        EntityNotFoundError: if the technique does not exist.

    Adds and flushes the new test but does not commit; the caller uses UnitOfWork.
    """
    technique = db.query(Technique).filter(Technique.id == technique_id).first()
    if technique is None:
        raise EntityNotFoundError("Technique", str(technique_id))

    test = Test(
        technique_id=technique_id,
        created_by=creator_id,
        state=TestState.draft,
        created_at=_utcnow(),  # explicit — DB column has no server default
        **fields,
    )
    db.add(test)
    db.flush()
    return test


def create_test_from_template(
    db: Session,
    *,
    template_id: uuid.UUID,
    technique_id_or_mitre: str,
    creator_id: uuid.UUID,
    # Optional user-edited overrides (take priority over template values)
    name_override: str | None = None,
    description_override: str | None = None,
    platform_override: str | None = None,
    procedure_text_override: str | None = None,
    tool_used_override: str | None = None,
) -> Test:
    """Instantiate a draft Test from a TestTemplate.

    Args:
        db: Active SQLAlchemy session.
        template_id: ID of the source TestTemplate.
        technique_id_or_mitre: Technique UUID string or MITRE ID
            (e.g. ``T1059.001``). A UUID is tried first. If it does not parse
            or matches no technique, the value is looked up as a MITRE ID.
        creator_id: Stored as the test's ``created_by``.
        *_override: When not None, used instead of the template's value.
            An empty string counts as an override.

    Field mapping when no override is given:
        name, description, and platform come from the template fields of the
        same name. procedure_text comes from ``attack_procedure``, tool_used
        from ``tool_suggested``, and remediation_steps from
        ``suggested_remediation``.

    Raises:
        EntityNotFoundError: if the template or the technique is not found.

    Adds and flushes the new test but does not commit; the caller uses UnitOfWork.
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if template is None:
        raise EntityNotFoundError("TestTemplate", str(template_id))

    technique = _resolve_technique(db, technique_id_or_mitre)

    test = Test(
        technique_id=technique.id,
        name=_prefer(name_override, template.name),
        description=_prefer(description_override, template.description),
        platform=_prefer(platform_override, template.platform),
        procedure_text=_prefer(procedure_text_override, template.attack_procedure),
        tool_used=_prefer(tool_used_override, template.tool_suggested),
        remediation_steps=template.suggested_remediation,
        created_by=creator_id,
        state=TestState.draft,
        created_at=_utcnow(),  # explicit — DB column has no server default
    )
    db.add(test)
    db.flush()
    return test


# ---------------------------------------------------------------------------
# Single-test fetches
# ---------------------------------------------------------------------------


def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
    """Fetch a test with evidences and technique eager-loaded.

    Raises EntityNotFoundError if the test does not exist.
    """
    return _get_test(
        db, test_id, joinedload(Test.evidences), joinedload(Test.technique)
    )


def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
    """Fetch a test by ID. Raises EntityNotFoundError if not found."""
    return _get_test(db, test_id)


def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
    """Fetch a test with its technique eager-loaded (joined).

    Raises EntityNotFoundError if not found.
    """
    return _get_test(db, test_id, joinedload(Test.technique))


# ---------------------------------------------------------------------------
# Updates
# ---------------------------------------------------------------------------


def update_test(
    db: Session,
    test_id: uuid.UUID,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: Any,
) -> Test:
    """Update general test fields (draft or rejected only).

    The permission check runs before the state check.

    Raises:
        EntityNotFoundError: if the test is not found.
        PermissionViolation: if the updater is neither the creator nor an admin.
        BusinessRuleViolation: if the state is not draft or rejected.

    Flushes but does not commit; the caller uses UnitOfWork.
    """
    test = get_test_or_raise(db, test_id)
    if updater_role != "admin" and test.created_by != updater_id:
        raise PermissionViolation(
            "Only the test creator or an admin can update this test"
        )
    if test.state not in (TestState.draft, TestState.rejected):
        raise BusinessRuleViolation(
            f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
        )
    return _assign_and_flush(db, test, fields)


def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
    """Update Red Team fields (draft or red_executing only).

    No role or ownership check is done here.

    Raises:
        EntityNotFoundError: if the test is not found.
        BusinessRuleViolation: if the state is not draft or red_executing.

    Flushes but does not commit; the caller uses UnitOfWork.
    """
    test = get_test_or_raise(db, test_id)
    if test.state not in (TestState.draft, TestState.red_executing):
        raise BusinessRuleViolation(
            f"Cannot update red fields in '{test.state.value}' state "
            "(must be draft or red_executing)"
        )
    return _assign_and_flush(db, test, fields)


def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
    """Update Blue Team fields (blue_evaluating only).

    No role or ownership check is done here.

    Raises:
        EntityNotFoundError: if the test is not found.
        BusinessRuleViolation: if the state is not blue_evaluating.

    Flushes but does not commit; the caller uses UnitOfWork.
    """
    test = get_test_or_raise(db, test_id)
    if test.state != TestState.blue_evaluating:
        raise BusinessRuleViolation(
            f"Cannot update blue fields in '{test.state.value}' state "
            "(must be blue_evaluating)"
        )
    return _assign_and_flush(db, test, fields)


# ---------------------------------------------------------------------------
# History
# ---------------------------------------------------------------------------


def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
    """Return the test's audit-log history, oldest first.

    Each entry is a dict with keys ``id``, ``action``, ``user_id``,
    ``timestamp``, and ``details``. ``id`` and ``user_id`` are stringified,
    and ``timestamp`` is in ISO-8601 format. ``user_id`` and ``timestamp``
    are None when unset.

    Raises EntityNotFoundError if the test does not exist.
    """
    get_test_or_raise(db, test_id)  # existence check only

    logs = (
        db.query(AuditLog)
        .filter(
            AuditLog.entity_type == "test",
            AuditLog.entity_id == str(test_id),
        )
        .order_by(AuditLog.timestamp.asc())
        .all()
    )
    return [
        {
            "id": str(log.id),
            "action": log.action,
            "user_id": str(log.user_id) if log.user_id else None,
            "timestamp": log.timestamp.isoformat() if log.timestamp else None,
            "details": log.details,
        }
        for log in logs
    ]