From 633c8e46ad632c19f4e89c3149b046f332905459 Mon Sep 17 00:00:00 2001 From: Kitos Date: Wed, 18 Feb 2026 13:54:01 +0100 Subject: [PATCH] refactor(workflow): delegate transition_state to TestEntity transition_state() now hydrates a TestEntity from the ORM model and delegates state validation to entity.transition_to(). The entity is authoritative for which transitions are valid; VALID_TRANSITIONS and can_transition() are kept for backward compatibility. Also adds public transition_to() method to TestEntity as the stable API surface for callers that need a single validated transition without lifecycle side-effects. --- backend/app/domain/test_entity.py | 331 ++++++++++++++++++ backend/app/services/test_workflow_service.py | 23 +- 2 files changed, 341 insertions(+), 13 deletions(-) create mode 100644 backend/app/domain/test_entity.py diff --git a/backend/app/domain/test_entity.py b/backend/app/domain/test_entity.py new file mode 100644 index 0000000..29eea1d --- /dev/null +++ b/backend/app/domain/test_entity.py @@ -0,0 +1,331 @@ +"""TestEntity — pure domain object for the test lifecycle state machine. + +This entity owns ALL state-transition logic and business rules for a +security test. It has **no** dependency on FastAPI, SQLAlchemy, or any +infrastructure concern. + +Usage:: + + entity = TestEntity.from_orm(test_orm_model) + entity.start_execution() # draft → red_executing + entity.submit_red_evidence() # red_executing → blue_evaluating + entity.pause_timer() + entity.resume_timer() + entity.submit_blue_evidence() # blue_evaluating → in_review + entity.validate_red("approved") + entity.validate_blue("approved") # triggers dual-validation → validated + entity.reopen() # rejected → draft + +After mutations, the service layer copies ``entity.changes`` back onto +the ORM model and persists via Unit of Work. +""" + +from __future__ import annotations + +import enum +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from app.domain.errors import BusinessRuleViolation, InvalidStateTransition + + +# ── Value objects ──────────────────────────────────────────────────── + + +class TestState(str, enum.Enum): + draft = "draft" + red_executing = "red_executing" + blue_evaluating = "blue_evaluating" + in_review = "in_review" + validated = "validated" + rejected = "rejected" + + +VALID_TRANSITIONS: dict[TestState, list[TestState]] = { + TestState.draft: [TestState.red_executing], + TestState.red_executing: [TestState.blue_evaluating], + TestState.blue_evaluating: [TestState.in_review], + TestState.in_review: [TestState.validated, TestState.rejected], + TestState.rejected: [TestState.draft], + TestState.validated: [], +} + +_PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating}) + + +# ── Domain events (lightweight records of what happened) ───────────── + + +@dataclass(frozen=True) +class DomainEvent: + name: str + payload: dict[str, Any] = field(default_factory=dict) + + +# ── Entity ─────────────────────────────────────────────────────────── + + +@dataclass +class TestEntity: + """Pure domain representation of a security test.""" + + id: uuid.UUID + state: TestState + + # Red validation + red_validation_status: str | None = None + red_validated_by: uuid.UUID | None = None + red_validated_at: datetime | None = None + red_validation_notes: str | None = None + + # Blue validation + blue_validation_status: str | None = None + blue_validated_by: uuid.UUID | None = None + blue_validated_at: datetime | None = None + blue_validation_notes: str | None = None + + # Phase timing + execution_date: datetime | None = None + red_started_at: datetime | None = None + blue_started_at: datetime | None = None + paused_at: datetime | None = None + red_paused_seconds: int = 0 + blue_paused_seconds: int = 0 + + # Internal bookkeeping (not persisted as-is) + _events: list[DomainEvent] = field(default_factory=list, repr=False) + + # -- Factory -------------------------------------------------------- + + @classmethod + def from_orm(cls, model: Any) -> TestEntity: + """Build a TestEntity from a SQLAlchemy ``Test`` model instance.""" + raw_state = model.state + state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state) + return cls( + id=model.id, + state=state, + red_validation_status=model.red_validation_status, + red_validated_by=model.red_validated_by, + red_validated_at=model.red_validated_at, + red_validation_notes=model.red_validation_notes, + blue_validation_status=model.blue_validation_status, + blue_validated_by=model.blue_validated_by, + blue_validated_at=model.blue_validated_at, + blue_validation_notes=model.blue_validation_notes, + execution_date=model.execution_date, + red_started_at=model.red_started_at, + blue_started_at=model.blue_started_at, + paused_at=model.paused_at, + red_paused_seconds=model.red_paused_seconds or 0, + blue_paused_seconds=model.blue_paused_seconds or 0, + ) + + def apply_to(self, model: Any) -> None: + """Copy the entity's mutable fields back onto the ORM model.""" + model.state = self.state + model.red_validation_status = self.red_validation_status + model.red_validated_by = self.red_validated_by + model.red_validated_at = self.red_validated_at + model.red_validation_notes = self.red_validation_notes + model.blue_validation_status = self.blue_validation_status + model.blue_validated_by = self.blue_validated_by + model.blue_validated_at = self.blue_validated_at + model.blue_validation_notes = self.blue_validation_notes + model.execution_date = self.execution_date + model.red_started_at = self.red_started_at + model.blue_started_at = self.blue_started_at + model.paused_at = self.paused_at + model.red_paused_seconds = self.red_paused_seconds + model.blue_paused_seconds = self.blue_paused_seconds + + # -- Query helpers -------------------------------------------------- + + @property + def events(self) -> list[DomainEvent]: + return list(self._events) + + def can_transition(self, target: TestState) -> bool: + return target in VALID_TRANSITIONS.get(self.state, []) + + @property + def is_terminal(self) -> bool: + return self.state == TestState.validated + + # -- Core transition ------------------------------------------------ + + def transition_to(self, target: TestState | str) -> str: + """Validate and apply a state transition. + + Accepts either a :class:`TestState` member or its string value + (so callers using ``models.enums.TestState`` work transparently). + + Returns the *previous* state value as a plain string. + + Raises :class:`InvalidStateTransition` when the move is illegal. + """ + resolved = target if isinstance(target, TestState) else TestState(str(target)) + return self._transition(resolved) + + def _transition(self, target: TestState) -> str: + """Internal: validate and apply; return previous state value.""" + if not self.can_transition(target): + valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])] + raise InvalidStateTransition( + current_state=self.state.value, + target_state=target.value, + valid_transitions=valid, + ) + previous = self.state.value + self.state = target + self._events.append(DomainEvent( + "state_changed", + {"previous": previous, "new": target.value}, + )) + return previous + + # -- Lifecycle commands -------------------------------------------- + + def start_execution(self) -> None: + """``draft`` -> ``red_executing``.""" + self._transition(TestState.red_executing) + now = datetime.utcnow() + self.execution_date = now + self.red_started_at = now + self._events.append(DomainEvent("execution_started")) + + def submit_red_evidence(self) -> int: + """``red_executing`` -> ``blue_evaluating``. + + Auto-resumes if paused. Returns paused seconds accumulated + during this phase (for worklog calculation). + """ + paused_extra = self._auto_resume() + self._transition(TestState.blue_evaluating) + total_paused = self.red_paused_seconds + paused_extra + self.blue_started_at = datetime.utcnow() + self.blue_paused_seconds = 0 + self._events.append(DomainEvent( + "red_evidence_submitted", + {"red_paused_seconds": total_paused}, + )) + return total_paused + + def submit_blue_evidence(self) -> int: + """``blue_evaluating`` -> ``in_review``. + + Auto-resumes if paused. Returns paused seconds accumulated + during this phase (for worklog calculation). + """ + paused_extra = self._auto_resume() + self._transition(TestState.in_review) + total_paused = self.blue_paused_seconds + paused_extra + self._events.append(DomainEvent( + "blue_evidence_submitted", + {"blue_paused_seconds": total_paused}, + )) + return total_paused + + def pause_timer(self) -> None: + """Pause the active phase timer.""" + if self.state not in _PAUSABLE_STATES: + raise BusinessRuleViolation( + f"Cannot pause timer in '{self.state.value}' state" + ) + if self.paused_at is not None: + raise BusinessRuleViolation("Timer is already paused") + self.paused_at = datetime.utcnow() + self._events.append(DomainEvent("timer_paused")) + + def resume_timer(self) -> int: + """Resume a paused timer. Returns seconds that were paused.""" + if self.paused_at is None: + raise BusinessRuleViolation("Timer is not paused") + now = datetime.utcnow() + paused_seconds = max(int((now - self.paused_at).total_seconds()), 0) + if self.state == TestState.red_executing: + self.red_paused_seconds += paused_seconds + elif self.state == TestState.blue_evaluating: + self.blue_paused_seconds += paused_seconds + self.paused_at = None + self._events.append(DomainEvent("timer_resumed", {"paused_seconds": paused_seconds})) + return paused_seconds + + def validate_red(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None: + """Record Red Lead's validation decision.""" + self._assert_in_review("red") + self._assert_valid_vote(status) + now = datetime.utcnow() + self.red_validation_status = status + self.red_validated_by = by + self.red_validated_at = now + self.red_validation_notes = notes + self._events.append(DomainEvent("red_validated", {"status": status})) + self._check_dual_validation() + + def validate_blue(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None: + """Record Blue Lead's validation decision.""" + self._assert_in_review("blue") + self._assert_valid_vote(status) + now = datetime.utcnow() + self.blue_validation_status = status + self.blue_validated_by = by + self.blue_validated_at = now + self.blue_validation_notes = notes + self._events.append(DomainEvent("blue_validated", {"status": status})) + self._check_dual_validation() + + def reopen(self) -> None: + """``rejected`` -> ``draft``, clearing all validation/timing fields.""" + self._transition(TestState.draft) + self.red_validation_status = None + self.red_validated_by = None + self.red_validated_at = None + self.red_validation_notes = None + self.blue_validation_status = None + self.blue_validated_by = None + self.blue_validated_at = None + self.blue_validation_notes = None + self.red_started_at = None + self.blue_started_at = None + self.paused_at = None + self.red_paused_seconds = 0 + self.blue_paused_seconds = 0 + self._events.append(DomainEvent("test_reopened")) + + # -- Private ------------------------------------------------------- + + def _auto_resume(self) -> int: + """If paused, accumulate pause time and clear. Returns extra seconds.""" + if self.paused_at is None: + return 0 + now = datetime.utcnow() + extra = max(int((now - self.paused_at).total_seconds()), 0) + self.paused_at = None + return extra + + def _assert_in_review(self, side: str) -> None: + if self.state != TestState.in_review: + raise BusinessRuleViolation( + f"Cannot validate {side} side while test is in " + f"'{self.state.value}' state (must be in_review)" + ) + + @staticmethod + def _assert_valid_vote(status: str) -> None: + if status not in ("approved", "rejected"): + raise BusinessRuleViolation( + "validation_status must be 'approved' or 'rejected'" + ) + + def _check_dual_validation(self) -> None: + """If both leads have voted, advance to validated or rejected.""" + r, b = self.red_validation_status, self.blue_validation_status + if r == "rejected" or b == "rejected": + self.state = TestState.rejected + self._events.append(DomainEvent("dual_validation_rejected")) + elif r == "approved" and b == "approved": + self.state = TestState.validated + self._events.append(DomainEvent("dual_validation_approved")) diff --git a/backend/app/services/test_workflow_service.py b/backend/app/services/test_workflow_service.py index 4e3ead6..b1053e1 100644 --- a/backend/app/services/test_workflow_service.py +++ b/backend/app/services/test_workflow_service.py @@ -19,6 +19,7 @@ from sqlalchemy.orm import Session from app.config import settings from app.domain.exceptions import InvalidOperationError, InvalidTransitionError +from app.domain.test_entity import TestEntity from app.models.enums import TestState from app.models.test import Test from app.models.user import User @@ -61,21 +62,18 @@ def transition_state( action_name: str = "transition_state", extra_details: dict | None = None, ) -> Test: - """Validate and perform a state transition, log it, and commit. + """Validate and perform a state transition, log it, and flush. - Raises :class:`InvalidTransitionError` when the transition is invalid. + Delegates validation to :class:`TestEntity` which raises + :class:`InvalidStateTransition` (aliased as ``InvalidTransitionError``) + when the transition is illegal. The entity is authoritative for which + transitions are valid; the module-level ``VALID_TRANSITIONS`` dict is + kept temporarily for backward compatibility of ``can_transition()``. """ - if not can_transition(test, target_state): - current = test.state if isinstance(test.state, TestState) else TestState(test.state) - valid = [s.value for s in VALID_TRANSITIONS.get(current, [])] - raise InvalidTransitionError( - current_state=current.value, - target_state=target_state.value, - valid_transitions=valid, - ) + entity = TestEntity.from_orm(test) + previous_state = entity.transition_to(target_state) - previous_state = test.state.value if isinstance(test.state, TestState) else test.state - test.state = target_state + test.state = entity.state db.flush() details: dict = { @@ -96,7 +94,6 @@ def transition_state( details=details, ) - # Dispatch in-app notifications for the new state try: notify_test_state_change(db, test, target_state.value) except Exception as e: