"""Pydantic schemas for Test endpoints."""

import logging
import uuid
from datetime import datetime
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator

from app.domain.enums import DataClassification
from app.models.enums import TestResult, TestState
from app.schemas.evidence import EvidenceOut

logger = logging.getLogger(__name__)

# Allowed values for the Red/Blue Lead validation payloads. Declared as
# ``Literal`` so Pydantic rejects anything else at the API boundary instead of
# letting an arbitrary string (e.g. a typo) reach the service layer.
ValidationStatus = Literal["approved", "rejected"]

# Allowed values for ``remediation_status`` on remediation update payloads.
RemediationStatus = Literal["pending", "in_progress", "completed", "not_applicable"]


# ── Create ──────────────────────────────────────────────────────────


class TestCreate(BaseModel):
    """Payload for creating a new test."""

    technique_id: uuid.UUID
    name: str
    description: str | None = None
    platform: str | None = None
    procedure_text: str | None = None
    tool_used: str | None = None


# ── Update (general) ───────────────────────────────────────────────


class TestClassificationUpdate(BaseModel):
    """Admin-only payload for changing data classification."""

    data_classification: DataClassification


class TestUpdate(BaseModel):
    """Payload for partially updating an existing test.

    Every field is optional so callers send only what changed.
    """

    name: str | None = None
    description: str | None = None
    platform: str | None = None
    procedure_text: str | None = None
    tool_used: str | None = None
    result: TestResult | None = None


# ── Red Team update ────────────────────────────────────────────────


class TestRedUpdate(BaseModel):
    """Fields that Red Team fills in during the red_executing phase."""

    name: str | None = None
    description: str | None = None
    procedure_text: str | None = None
    tool_used: str | None = None
    attack_success: bool | None = None
    red_summary: str | None = None


# ── Blue Team update ───────────────────────────────────────────────


class TestBlueUpdate(BaseModel):
    """Fields that Blue Team fills in during the blue_evaluating phase."""

    detection_result: TestResult | None = None
    blue_summary: str | None = None


# ── Red Lead validation ────────────────────────────────────────────


class TestRedValidate(BaseModel):
    """Payload sent by Red Lead to approve/reject the red side."""

    # Only "approved" or "rejected" are accepted; anything else is a 422.
    red_validation_status: ValidationStatus
    red_validation_notes: str | None = None


# ── Blue Lead validation ───────────────────────────────────────────


class TestBlueValidate(BaseModel):
    """Payload sent by Blue Lead to approve/reject the blue side."""

    # Only "approved" or "rejected" are accepted; anything else is a 422.
    blue_validation_status: ValidationStatus
    blue_validation_notes: str | None = None


# ── Remediation update ─────────────────────────────────────────────


class TestRemediationUpdate(BaseModel):
    """Payload for updating remediation fields."""

    remediation_steps: str | None = None
    # One of: pending / in_progress / completed / not_applicable (or omitted).
    remediation_status: RemediationStatus | None = None
    remediation_assignee: uuid.UUID | None = None


# ── Legacy validate (kept for backwards compat) ────────────────────


class TestValidate(BaseModel):
    """Payload sent by a reviewer to validate / reject a test."""

    result: TestResult
    comments: str | None = None


# ── Read (full) ────────────────────────────────────────────────────


def _evidence_to_out(ev: Any) -> EvidenceOut:
    """Build an ``EvidenceOut`` from an evidence row, adding its download URL."""
    return EvidenceOut(
        id=ev.id,
        test_id=ev.test_id,
        file_name=ev.file_name,
        sha256_hash=ev.sha256_hash,
        uploaded_by=ev.uploaded_by,
        uploaded_at=ev.uploaded_at,
        team=ev.team,
        notes=ev.notes,
        download_url=f"/api/v1/evidence/{ev.id}/file",
    )


def _is_blue_team(team: Any) -> bool:
    """Return True when *team* denotes the blue team.

    Accepts an enum member (compared via its ``.value``), a raw string, or
    ``None`` (treated as not-blue), so a plain-string column does not crash
    the split with ``AttributeError``.
    """
    return getattr(team, "value", team) == "blue"


class TestOut(BaseModel):
    """Complete representation returned by the API."""

    id: uuid.UUID
    technique_id: uuid.UUID
    name: str
    description: str | None = None
    platform: str | None = None
    procedure_text: str | None = None
    tool_used: str | None = None
    execution_date: datetime | None = None
    created_by: uuid.UUID | None = None
    result: TestResult | None = None
    state: TestState = TestState.draft
    created_at: datetime | None = None

    # Red Team fields
    red_summary: str | None = None
    attack_success: bool | None = None
    red_validated_by: uuid.UUID | None = None
    red_validated_at: datetime | None = None
    red_validation_status: str | None = None
    red_validation_notes: str | None = None

    # Blue Team fields
    blue_summary: str | None = None
    detection_result: TestResult | None = None
    blue_validated_by: uuid.UUID | None = None
    blue_validated_at: datetime | None = None
    blue_validation_status: str | None = None
    blue_validation_notes: str | None = None

    # Phase timing fields (for Tempo worklogs)
    red_started_at: datetime | None = None
    blue_started_at: datetime | None = None
    blue_work_started_at: datetime | None = None
    paused_at: datetime | None = None
    red_paused_seconds: int = 0
    blue_paused_seconds: int = 0

    # Remediation fields
    remediation_steps: str | None = None
    remediation_status: str | None = None
    remediation_assignee: uuid.UUID | None = None

    # Re-test fields
    retest_of: uuid.UUID | None = None
    retest_count: int = 0

    data_classification: str = "internal"

    # Technique info (populated when joined)
    technique_mitre_id: str | None = None
    technique_name: str | None = None

    # Evidences split by team (populated from the ORM relationship)
    red_evidences: list[EvidenceOut] = Field(default_factory=list)
    blue_evidences: list[EvidenceOut] = Field(default_factory=list)

    model_config = ConfigDict(from_attributes=True)

    @model_validator(mode="before")
    @classmethod
    def _populate_derived_fields(cls, obj: Any) -> Any:
        """Populate technique and evidence fields from ORM relationships.

        Uses ``@model_validator(mode='before')`` so it is called by Pydantic's
        internal validation pipeline, including FastAPI's TypeAdapter path.
        A plain ``model_validate`` classmethod override is **not** invoked by
        FastAPI's response serialisation in Pydantic v2 — only registered
        validators are guaranteed to run.

        Evidences are only processed when the relationship was **explicitly
        loaded** (via joinedload or prior access). Accessing ``obj.evidences``
        blindly on a session-expired ORM object triggers a lazy query that
        fails on mutation endpoints that do not joinload the relationship.
        We inspect ``obj.__dict__`` directly — SQLAlchemy stores loaded
        relationships there; if the key is absent the relationship is
        unloaded and we leave the lists empty (the frontend invalidates and
        refetches the detail endpoint, which *does* joinload).

        The derived values are written into ``obj.__dict__`` so that the
        ``from_attributes`` lookup that follows picks them up. Inputs without
        an instance ``__dict__`` (e.g. plain dicts) are returned untouched.
        """
        if not hasattr(obj, "__dict__"):
            return obj

        # Technique info (lazy-load is fine here: session is still open on GET).
        try:
            technique = getattr(obj, "technique", None)
            if technique is not None:
                obj.__dict__["technique_mitre_id"] = technique.mitre_id
                obj.__dict__["technique_name"] = technique.name
        except Exception:  # deliberate best-effort; never fail serialisation
            # DetachedInstanceError or similar: leave technique fields None.
            # Log the type name only; repr() of a detached ORM object could
            # itself trigger a lazy load.
            logger.debug(
                "Could not load technique for %s; leaving technique fields empty",
                type(obj).__name__,
                exc_info=True,
            )

        # Only split evidences when they are already in memory (loaded via joinedload).
        raw_evs = obj.__dict__.get("evidences")
        if raw_evs is not None:
            red_evs: list[EvidenceOut] = []
            blue_evs: list[EvidenceOut] = []
            for ev in raw_evs:
                # Anything not explicitly blue (including team=None) goes to red.
                target = blue_evs if _is_blue_team(ev.team) else red_evs
                target.append(_evidence_to_out(ev))
            obj.__dict__["red_evidences"] = red_evs
            obj.__dict__["blue_evidences"] = blue_evs

        return obj