"""CRUD router for security Tests.""" import uuid from datetime import datetime from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session, joinedload from app.database import get_db from app.dependencies.auth import get_current_user, require_role, require_any_role from app.models.enums import TestState from app.models.technique import Technique from app.models.test import Test from app.models.user import User from app.schemas.test import TestCreate, TestOut, TestUpdate, TestValidate from app.services.audit_service import log_action from app.services.status_service import recalculate_technique_status router = APIRouter(prefix="/tests", tags=["tests"]) # --------------------------------------------------------------------------- # POST /tests — create (red_tech or admin) # --------------------------------------------------------------------------- @router.post( "", response_model=TestOut, status_code=status.HTTP_201_CREATED, ) def create_test( payload: TestCreate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_tech")), ): """Create a new test linked to an existing technique. The ``created_by`` field is set automatically to the current user and ``state`` defaults to *draft*. """ # Verify the parent technique exists technique = db.query(Technique).filter(Technique.id == payload.technique_id).first() if technique is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Technique with id '{payload.technique_id}' not found", ) test = Test( **payload.model_dump(), created_by=current_user.id, state=TestState.draft, ) db.add(test) db.commit() db.refresh(test) log_action( db, user_id=current_user.id, action="create_test", entity_type="test", entity_id=test.id, details={"name": test.name, "technique_id": str(test.technique_id)}, ) return test # --------------------------------------------------------------------------- # GET /tests/{id} — detail (with evidences) # --------------------------------------------------------------------------- @router.get("/{test_id}", response_model=TestOut) def get_test( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """Return full details for a single test, including its evidences.""" test = ( db.query(Test) .options(joinedload(Test.evidences)) .filter(Test.id == test_id) .first() ) if test is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Test not found", ) return test # --------------------------------------------------------------------------- # PATCH /tests/{id} — update (creator or admin, only in draft/rejected) # --------------------------------------------------------------------------- @router.patch("/{test_id}", response_model=TestOut) def update_test( test_id: uuid.UUID, payload: TestUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """Update one or more fields of an existing test. Only the original creator or an admin can update. The test must be in ``draft`` or ``rejected`` state. """ test = db.query(Test).filter(Test.id == test_id).first() if test is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Test not found", ) # Ownership / admin check if current_user.role != "admin" and test.created_by != current_user.id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions", ) # State guard if test.state not in (TestState.draft, TestState.rejected): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)", ) update_data = payload.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(test, field, value) db.commit() db.refresh(test) log_action( db, user_id=current_user.id, action="update_test", entity_type="test", entity_id=test.id, details={"updated_fields": list(update_data.keys())}, ) return test # --------------------------------------------------------------------------- # POST /tests/{id}/validate — validate (leads + admin) # --------------------------------------------------------------------------- @router.post("/{test_id}/validate", response_model=TestOut) def validate_test( test_id: uuid.UUID, payload: TestValidate, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Mark a test as validated. Sets ``state`` to *validated*, records ``validated_by`` / ``validated_at``, stores the ``result``, and recalculates the parent technique's global status. """ test = ( db.query(Test) .options(joinedload(Test.technique)) .filter(Test.id == test_id) .first() ) if test is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Test not found", ) test.state = TestState.validated test.result = payload.result test.validated_by = current_user.id test.validated_at = datetime.utcnow() db.commit() db.refresh(test) # Recalculate the parent technique's global status technique = test.technique recalculate_technique_status(db, technique) log_action( db, user_id=current_user.id, action="validate_test", entity_type="test", entity_id=test.id, details={ "result": payload.result.value, "technique_id": str(test.technique_id), }, ) return test # --------------------------------------------------------------------------- # POST /tests/{id}/reject — reject (leads + admin) # --------------------------------------------------------------------------- @router.post("/{test_id}/reject", response_model=TestOut) def reject_test( test_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Reject a test, setting its state to *rejected*.""" test = db.query(Test).filter(Test.id == test_id).first() if test is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Test not found", ) test.state = TestState.rejected db.commit() db.refresh(test) log_action( db, user_id=current_user.id, action="reject_test", entity_type="test", entity_id=test.id, details={"technique_id": str(test.technique_id)}, ) return test