Files
Aegis/backend/app/schemas/test.py
kitos 0830b36cd6
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
fix(schemas): avoid lazy-load in TestOut.model_validate
Accessing obj.evidences on a session-expired ORM object (mutation endpoints
do commit+refresh without joinload) triggers a lazy query that fails or
returns stale data. Use obj.__dict__.get('evidences') instead — SQLAlchemy
stores already-loaded relationships (e.g. via joinedload) in the instance
__dict__; a missing key means the relationship has not been loaded.

Mutation endpoints (submit-red, submit-blue, etc.) return empty evidence
lists, which is fine: the frontend invalidates and refetches GET /tests/{id},
which uses joinedload and correctly populates red_evidences / blue_evidences.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-28 12:06:34 +02:00

217 lines
7.7 KiB
Python

"""Pydantic schemas for Test endpoints."""
import uuid
from datetime import datetime
from pydantic import BaseModel, ConfigDict
from app.domain.enums import DataClassification
from app.models.enums import TestResult, TestState
from app.schemas.evidence import EvidenceOut
# ── Create ──────────────────────────────────────────────────────────
class TestCreate(BaseModel):
    """Payload for creating a new test.

    Only ``technique_id`` and ``name`` are required; every other descriptive
    field is optional and defaults to ``None``.
    """
    technique_id: uuid.UUID
    name: str
    description: str | None = None
    platform: str | None = None
    procedure_text: str | None = None
    tool_used: str | None = None
# ── Update (general) ───────────────────────────────────────────────
class TestClassificationUpdate(BaseModel):
    """Admin-only payload for changing a test's data classification.

    The schema itself does not enforce the admin restriction; presumably the
    endpoint checks the caller's role — confirm there. The value must be a
    valid ``DataClassification`` member.
    """
    data_classification: DataClassification
class TestUpdate(BaseModel):
    """Payload for partially updating an existing test.

    Every field is optional so callers send only what changed. An omitted
    field and an explicit ``null`` both validate to ``None`` here; telling
    them apart requires ``model_fields_set`` / ``model_dump(exclude_unset=True)``
    at the call site (NOTE(review): confirm the endpoint does this, otherwise
    omitted fields would overwrite stored values with ``None``).
    """
    name: str | None = None
    description: str | None = None
    platform: str | None = None
    procedure_text: str | None = None
    tool_used: str | None = None
    result: TestResult | None = None
# ── Red Team update ────────────────────────────────────────────────
class TestRedUpdate(BaseModel):
    """Fields that Red Team fills in during the red_executing phase.

    All fields are optional (partial update). This schema does not check the
    test's current phase; presumably the endpoint validates ``state`` — confirm.
    """
    name: str | None = None
    description: str | None = None
    procedure_text: str | None = None
    tool_used: str | None = None
    attack_success: bool | None = None
    red_summary: str | None = None
# ── Blue Team update ───────────────────────────────────────────────
class TestBlueUpdate(BaseModel):
    """Fields that Blue Team fills in during the blue_evaluating phase.

    Both fields are optional (partial update). ``detection_result`` must be a
    ``TestResult`` member. The phase restriction is not checked by this
    schema — presumably the endpoint validates ``state``; confirm.
    """
    detection_result: TestResult | None = None
    blue_summary: str | None = None
# ── Red Lead validation ────────────────────────────────────────────
class TestRedValidate(BaseModel):
    """Payload sent by Red Lead to approve/reject the red side."""
    # Expected values are "approved" or "rejected", but the field is a plain
    # ``str`` so any string validates; presumably the endpoint rejects other
    # values — NOTE(review): confirm.
    red_validation_status: str  # "approved" or "rejected"
    red_validation_notes: str | None = None
# ── Blue Lead validation ───────────────────────────────────────────
class TestBlueValidate(BaseModel):
    """Payload sent by Blue Lead to approve/reject the blue side."""
    # Expected values are "approved" or "rejected", but the field is a plain
    # ``str`` so any string validates; presumably the endpoint rejects other
    # values — NOTE(review): confirm.
    blue_validation_status: str  # "approved" or "rejected"
    blue_validation_notes: str | None = None
# ── Remediation update ────────────────────────────────────────────
class TestRemediationUpdate(BaseModel):
    """Payload for updating remediation fields.

    All fields are optional (partial update).
    """
    remediation_steps: str | None = None
    # The value set below is documented only; the field is a plain ``str`` so
    # the schema accepts any string.
    remediation_status: str | None = None  # pending / in_progress / completed / not_applicable
    # Presumably the id of the user the remediation is assigned to — confirm.
    remediation_assignee: uuid.UUID | None = None
# ── Legacy validate (kept for backwards compat) ────────────────────
class TestValidate(BaseModel):
    """Payload sent by a reviewer to validate / reject a test.

    Legacy schema kept for backwards compatibility. ``result`` is required
    and must be a ``TestResult`` member; ``comments`` is optional.
    """
    result: TestResult
    comments: str | None = None
# ── Read (full) ─────────────────────────────────────────────────────
class TestOut(BaseModel):
    """Complete representation of a test returned by the API.

    Intended to be built from the ORM object with ``TestOut.model_validate(obj)``
    (``from_attributes=True``). Besides the plain column fields, the overridden
    ``model_validate`` derives ``technique_mitre_id`` / ``technique_name`` from
    ``obj.technique`` and splits ``obj.evidences`` into ``red_evidences`` /
    ``blue_evidences``.
    """

    id: uuid.UUID
    technique_id: uuid.UUID
    name: str
    description: str | None = None
    platform: str | None = None
    procedure_text: str | None = None
    tool_used: str | None = None
    execution_date: datetime | None = None
    created_by: uuid.UUID | None = None
    result: TestResult | None = None
    state: TestState = TestState.draft
    created_at: datetime | None = None

    # Red Team fields
    red_summary: str | None = None
    attack_success: bool | None = None
    red_validated_by: uuid.UUID | None = None
    red_validated_at: datetime | None = None
    red_validation_status: str | None = None
    red_validation_notes: str | None = None

    # Blue Team fields
    blue_summary: str | None = None
    detection_result: TestResult | None = None
    blue_validated_by: uuid.UUID | None = None
    blue_validated_at: datetime | None = None
    blue_validation_status: str | None = None
    blue_validation_notes: str | None = None

    # Phase timing fields (for Tempo worklogs)
    red_started_at: datetime | None = None
    blue_started_at: datetime | None = None
    blue_work_started_at: datetime | None = None
    paused_at: datetime | None = None
    red_paused_seconds: int = 0
    blue_paused_seconds: int = 0

    # Remediation fields
    remediation_steps: str | None = None
    remediation_status: str | None = None
    remediation_assignee: uuid.UUID | None = None

    # Re-test fields
    retest_of: uuid.UUID | None = None
    retest_count: int = 0
    data_classification: str = "internal"

    # Technique info (filled in by model_validate when obj.technique is set)
    technique_mitre_id: str | None = None
    technique_name: str | None = None

    # Evidences split by team (filled in by model_validate when the
    # ``evidences`` relationship is already loaded). Pydantic copies mutable
    # defaults per instance, so the shared ``[]`` literal is safe.
    red_evidences: list[EvidenceOut] = []
    blue_evidences: list[EvidenceOut] = []

    model_config = ConfigDict(from_attributes=True)

    @classmethod
    def model_validate(cls, obj, **kwargs):
        """Validate ``obj`` and fill in the technique and per-team evidence fields.

        Evidences are only processed when the ``evidences`` relationship is
        already present in ``obj.__dict__`` (e.g. loaded via ``joinedload``).
        Reading ``obj.evidences`` on a session-expired ORM object would fire a
        lazy query, which fails on mutation endpoints that commit + refresh
        without joinedload. In that case the evidence lists are left empty and
        the frontend refetches the detail endpoint, which does joinedload.

        The derived values are applied to the *validated model* via
        ``model_copy(update=...)``; ``obj`` itself is never modified. (An earlier
        version wrote them into ``obj.__dict__``. SQLAlchemy's expiry only
        clears mapped attributes, so those planted keys outlived a commit /
        refresh. A later call on the same identity-mapped instance with
        ``evidences`` unloaded then returned stale lists instead of empty ones.)

        Args:
            obj: The source object — presumably the SQLAlchemy ``Test`` model,
                but anything ``BaseModel.model_validate`` accepts works; objects
                without ``technique`` / ``__dict__`` (e.g. dicts) skip the
                derivation entirely.
            **kwargs: Forwarded unchanged to ``BaseModel.model_validate``.

        Returns:
            The validated ``TestOut`` instance.

        NOTE: Pydantic only runs this Python override on direct
        ``TestOut.model_validate(...)`` calls. When ``TestOut`` is validated
        as a nested field or through a ``TypeAdapter``, the core validator is
        used and these derived fields stay at their defaults.
        """
        derived: dict[str, object] = {}

        # NOTE(review): unlike ``evidences`` below, this is a plain attribute
        # read, so an unloaded ``technique`` relationship may trigger a lazy
        # load — confirm it is always eager-loaded (or lazy loading is safe)
        # for every caller.
        technique = getattr(obj, "technique", None)
        if technique is not None:
            derived["technique_mitre_id"] = technique.mitre_id
            derived["technique_name"] = technique.name

        # Loaded relationships live in the instance ``__dict__``; a missing key
        # means "not loaded", and touching the attribute would lazy-load it.
        instance_dict = getattr(obj, "__dict__", None)
        raw_evs = instance_dict.get("evidences") if instance_dict is not None else None
        if raw_evs is not None:
            red_evs: list[EvidenceOut] = []
            blue_evs: list[EvidenceOut] = []
            for ev in raw_evs:
                ev_out = EvidenceOut(
                    id=ev.id,
                    test_id=ev.test_id,
                    file_name=ev.file_name,
                    sha256_hash=ev.sha256_hash,
                    uploaded_by=ev.uploaded_by,
                    uploaded_at=ev.uploaded_at,
                    team=ev.team,
                    notes=ev.notes,
                    download_url=f"/api/v1/evidence/{ev.id}/file",
                )
                # Anything not explicitly tagged "blue" (including a missing
                # team) is treated as red evidence.
                if ev.team and ev.team.value == "blue":
                    blue_evs.append(ev_out)
                else:
                    red_evs.append(ev_out)
            derived["red_evidences"] = red_evs
            derived["blue_evidences"] = blue_evs

        validated = super().model_validate(obj, **kwargs)
        # Overlay the derived values on the result instead of mutating ``obj``.
        return validated.model_copy(update=derived) if derived else validated