456 lines
14 KiB
Python
456 lines
14 KiB
Python
"""Test workflow service — state-machine transitions for the Red/Blue validation flow.

Controls which state transitions are valid and exposes high-level helpers
for each step in the test lifecycle:

    draft → red_executing → blue_evaluating → in_review → validated / rejected
                                                                         ↓
                                                                      rejected → draft

Every public function validates the transition, mutates the test, writes an
audit-log entry, and commits the session.
"""
|
|
|
|
import logging
from datetime import datetime

from fastapi import HTTPException, status
from sqlalchemy.orm import Session

from app.config import settings
from app.models.enums import TestState
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
from app.services.notification_service import notify_test_state_change, create_notification
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Valid transition map
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Maps each state to the list of states that may directly follow it.
# A state mapped to an empty list is terminal.
VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
    # Forward path of the Red/Blue flow.
    TestState.draft: [
        TestState.red_executing,
    ],
    TestState.red_executing: [
        TestState.blue_evaluating,
    ],
    TestState.blue_evaluating: [
        TestState.in_review,
    ],
    # Review can end either way.
    TestState.in_review: [
        TestState.validated,
        TestState.rejected,
    ],
    # A rejected test can only be sent back to the start.
    TestState.rejected: [
        TestState.draft,
    ],
    # Terminal state: nothing follows a validated test.
    TestState.validated: [],
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Core helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def can_transition(test: Test, target_state: TestState) -> bool:
    """Tell whether *test* may move directly to *target_state*.

    ``test.state`` is coerced to a :class:`TestState` when it is not already
    one, then looked up in ``VALID_TRANSITIONS``.  A state with no entry in
    the map allows no transitions.
    """
    state_now = test.state
    if not isinstance(state_now, TestState):
        # Stored as a raw value — convert it to the enum member first.
        state_now = TestState(state_now)

    allowed_targets = VALID_TRANSITIONS.get(state_now, [])
    return target_state in allowed_targets
|
|
|
|
|
|
def transition_state(
    db: Session,
    test: Test,
    target_state: TestState,
    user: User,
    *,
    action_name: str = "transition_state",
    extra_details: dict | None = None,
) -> Test:
    """Validate and apply a state transition, then write its audit-log entry.

    The new state is flushed to the session but **not committed** here;
    the caller is responsible for ``db.commit()``.

    Args:
        db: Session the test belongs to.
        test: Test whose state is being changed.
        target_state: State to move the test into.
        user: User performing the action; recorded in the audit log.
        action_name: Action label stored in the audit log.
        extra_details: Optional extra key/value pairs merged into the audit
            details (they override the default keys on collision).

    Returns:
        The same *test* instance, now in *target_state*.

    Raises:
        HTTPException: 400 with code ``INVALID_TRANSITION`` when the move is
            not allowed by ``VALID_TRANSITIONS``.
    """
    # Normalise once: the state may be held as an enum member or a raw value.
    current = test.state if isinstance(test.state, TestState) else TestState(test.state)

    if not can_transition(test, target_state):
        valid = [s.value for s in VALID_TRANSITIONS.get(current, [])]
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "message": (
                    f"Cannot transition from '{current.value}' to '{target_state.value}'. "
                    f"Valid transitions: {valid}"
                ),
                "code": "INVALID_TRANSITION",
                "current_state": current.value,
                "target_state": target_state.value,
                "valid_transitions": valid,
            },
        )

    previous_state = current.value
    test.state = target_state
    db.flush()

    details: dict = {
        "previous_state": previous_state,
        "new_state": target_state.value,
        "test_name": test.name,
        "technique_id": str(test.technique_id),
    }
    if extra_details:
        details.update(extra_details)

    log_action(
        db,
        user_id=user.id,
        action=action_name,
        entity_type="test",
        entity_id=test.id,
        details=details,
    )

    # Notifications are best-effort: a failure must not block the workflow,
    # but it is logged so it does not disappear silently.
    try:
        notify_test_state_change(db, test, target_state.value)
    except Exception:
        logging.getLogger(__name__).exception(
            "Failed to dispatch notifications for test %s entering state %s",
            test.id,
            target_state.value,
        )

    return test
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Lifecycle convenience functions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def start_execution(db: Session, test: Test, user: User) -> Test:
    """Begin the attack phase: ``draft`` → ``red_executing``.

    Typically called by a **red_tech** when they begin the attack.  After the
    transition the execution date is stamped with the current UTC time and
    the session is committed.
    """
    started = transition_state(
        db,
        test,
        target_state=TestState.red_executing,
        user=user,
        action_name="start_execution",
    )
    # Stamp the start time only once the transition has been accepted.
    started.execution_date = datetime.utcnow()
    db.commit()
    return started
|
|
|
|
|
|
def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
    """Hand the test over to the blue side: ``red_executing`` → ``blue_evaluating``.

    Called by **red_tech** once they have finished documenting the attack.
    The session is committed after the transition succeeds.
    """
    handed_over = transition_state(
        db,
        test,
        target_state=TestState.blue_evaluating,
        user=user,
        action_name="submit_red_evidence",
    )
    db.commit()
    return handed_over
|
|
|
|
|
|
def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
    """Send the test to review: ``blue_evaluating`` → ``in_review``.

    Called by **blue_tech** once they have finished documenting detection.
    The session is committed after the transition succeeds.
    """
    under_review = transition_state(
        db,
        test,
        target_state=TestState.in_review,
        user=user,
        action_name="submit_blue_evidence",
    )
    db.commit()
    return under_review
|
|
|
|
|
|
def validate_as_red_lead(
    db: Session,
    test: Test,
    user: User,
    validation_status: str,
    notes: str | None = None,
) -> Test:
    """Record the Red Lead's validation decision for a test under review.

    Stores the decision, the deciding user, the current UTC time and the
    optional notes on the test, writes an audit-log entry, and then calls
    :func:`check_dual_validation`, which may advance the test to
    ``validated`` or ``rejected``.

    Args:
        db: Session the test belongs to.
        test: Test being validated; must be in ``in_review``.
        user: Lead recording the decision.
        validation_status: ``"approved"`` or ``"rejected"``.
        notes: Optional free-text justification.

    Returns:
        The same *test* instance.

    Raises:
        HTTPException: 400 with code ``INVALID_STATE`` when the test is not
            in ``in_review``; 400 with code ``INVALID_VALIDATION_STATUS``
            when *validation_status* is not an accepted value.
    """
    # Compare on the raw value so a state held as a plain string is treated
    # exactly like the corresponding TestState member.
    current = test.state.value if isinstance(test.state, TestState) else test.state
    if current != TestState.in_review.value:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "message": (
                    f"Cannot validate red side while test is in '{current}' state "
                    "(must be in_review)"
                ),
                "code": "INVALID_STATE",
                "current_state": current,
            },
        )

    if validation_status not in ("approved", "rejected"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "message": "validation_status must be 'approved' or 'rejected'",
                "code": "INVALID_VALIDATION_STATUS",
            },
        )

    test.red_validation_status = validation_status
    test.red_validated_by = user.id
    test.red_validated_at = datetime.utcnow()
    test.red_validation_notes = notes

    log_action(
        db,
        user_id=user.id,
        action="validate_as_red_lead",
        entity_type="test",
        entity_id=test.id,
        details={
            "validation_status": validation_status,
            "notes": notes,
            "technique_id": str(test.technique_id),
        },
    )

    # Resolves the final state if both leads have now decided.
    check_dual_validation(db, test)
    return test
|
|
|
|
|
|
def validate_as_blue_lead(
    db: Session,
    test: Test,
    user: User,
    validation_status: str,
    notes: str | None = None,
) -> Test:
    """Record the Blue Lead's validation decision for a test under review.

    Stores the decision, the deciding user, the current UTC time and the
    optional notes on the test, writes an audit-log entry, and then calls
    :func:`check_dual_validation`, which may advance the test to
    ``validated`` or ``rejected``.

    Args:
        db: Session the test belongs to.
        test: Test being validated; must be in ``in_review``.
        user: Lead recording the decision.
        validation_status: ``"approved"`` or ``"rejected"``.
        notes: Optional free-text justification.

    Returns:
        The same *test* instance.

    Raises:
        HTTPException: 400 with code ``INVALID_STATE`` when the test is not
            in ``in_review``; 400 with code ``INVALID_VALIDATION_STATUS``
            when *validation_status* is not an accepted value.
    """
    # Compare on the raw value so a state held as a plain string is treated
    # exactly like the corresponding TestState member.
    current = test.state.value if isinstance(test.state, TestState) else test.state
    if current != TestState.in_review.value:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "message": (
                    f"Cannot validate blue side while test is in '{current}' state "
                    "(must be in_review)"
                ),
                "code": "INVALID_STATE",
                "current_state": current,
            },
        )

    if validation_status not in ("approved", "rejected"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "message": "validation_status must be 'approved' or 'rejected'",
                "code": "INVALID_VALIDATION_STATUS",
            },
        )

    test.blue_validation_status = validation_status
    test.blue_validated_by = user.id
    test.blue_validated_at = datetime.utcnow()
    test.blue_validation_notes = notes

    log_action(
        db,
        user_id=user.id,
        action="validate_as_blue_lead",
        entity_type="test",
        entity_id=test.id,
        details={
            "validation_status": validation_status,
            "notes": notes,
            "technique_id": str(test.technique_id),
        },
    )

    # Resolves the final state if both leads have now decided.
    check_dual_validation(db, test)
    return test
|
|
|
|
|
|
def check_dual_validation(db: Session, test: Test) -> Test:
    """Resolve the two leads' decisions into a final state when possible.

    - Either lead **rejected** → ``rejected`` (applies as soon as one
      rejection exists, even if the other lead has not voted).
    - Both leads **approved** → ``validated``; cached scores are invalidated.
    - Otherwise the state is left untouched (waiting for the other lead).

    The session is committed in every case, so a vote recorded by the caller
    is persisted even when the state does not change.  Failures while
    invalidating the score cache or dispatching notifications are logged and
    otherwise ignored.

    NOTE(review): the state is assigned directly rather than through
    ``transition_state``, so ``VALID_TRANSITIONS`` is not consulted and no
    audit-log entry is written for the state change itself.

    Args:
        db: Session the test belongs to.
        test: Test whose validation fields are evaluated.

    Returns:
        The same *test* instance.
    """
    logger = logging.getLogger(__name__)
    red_status = test.red_validation_status
    blue_status = test.blue_validation_status

    if red_status == "rejected" or blue_status == "rejected":
        outcome = "rejected"
        test.state = TestState.rejected
    elif red_status == "approved" and blue_status == "approved":
        outcome = "validated"
        test.state = TestState.validated
    else:
        # One side hasn't voted yet — state stays as it is.
        outcome = None

    # Commit unconditionally: persists the vote and any state change.
    db.commit()

    if outcome is None:
        return test

    if outcome == "validated":
        # A validation changes org-level numbers, so cached scores are stale.
        try:
            from app.services.score_cache import invalidate

            invalidate()
        except Exception:
            logger.exception(
                "Failed to invalidate score cache after validating test %s", test.id
            )

    # Notifications are best-effort — never block the workflow on them.
    try:
        notify_test_state_change(db, test, outcome)
    except Exception:
        logger.exception(
            "Failed to dispatch notifications for test %s entering state %s",
            test.id,
            outcome,
        )

    return test
|
|
|
|
|
|
def handle_remediation_completed(db: Session, test: Test, user: User) -> Test | None:
    """Create a re-test once remediation for *test* has been completed.

    The re-test is a new ``draft`` test copying the technique, description,
    platform, procedure, tool and creator of *test*.  It always points at the
    original test of the chain (``retest_of``), never at an intermediate
    re-test, and its name carries a ``[Retest #N]`` prefix.

    To prevent endless re-test loops, nothing is created once
    ``retest_count`` has reached ``settings.MAX_RETEST_COUNT``; instead the
    creator (if any) is notified and the event is audit-logged.

    The session is committed on both paths.

    Args:
        db: Session the test belongs to.
        test: Test whose remediation was completed.
        user: User triggering the action; recorded in the audit log.

    Returns:
        The newly created re-test, or ``None`` if the limit was reached.
    """
    # Always reference the original test, not an intermediate retest.
    original_test_id = test.retest_of or test.id
    # Treat an unset counter as "no retests yet" instead of failing on None.
    current_count = test.retest_count or 0

    if current_count >= settings.MAX_RETEST_COUNT:
        # Max retests reached — notify, record, and bail out.
        if test.created_by:
            create_notification(
                db,
                user_id=test.created_by,
                type="max_retests_reached",
                title="Maximum retests reached",
                message=(
                    f'Test "{test.name}" has reached the maximum of '
                    f'{settings.MAX_RETEST_COUNT} retests. Manual review required.'
                ),
                entity_type="test",
                entity_id=test.id,
            )

        log_action(
            db,
            user_id=user.id,
            action="max_retests_reached",
            entity_type="test",
            entity_id=test.id,
            details={
                "retest_count": current_count,
                "max_allowed": settings.MAX_RETEST_COUNT,
                "original_test_id": str(original_test_id),
            },
        )
        # Persist the notification and audit entry, as the success path does.
        db.commit()
        return None

    next_count = current_count + 1
    # Strip only a leading tag from the source test so prefixes don't stack;
    # the same text appearing elsewhere in the name is left alone.
    base_name = test.name.removeprefix(f"[Retest #{current_count}] ")

    retest = Test(
        technique_id=test.technique_id,
        name=f"[Retest #{next_count}] {base_name}",
        description=test.description,
        platform=test.platform,
        procedure_text=test.procedure_text,
        tool_used=test.tool_used,
        state=TestState.draft,
        created_by=test.created_by,
        retest_of=original_test_id,
        retest_count=next_count,
    )
    db.add(retest)
    # Flush so retest.id is available for the audit entry and notification.
    db.flush()

    log_action(
        db,
        user_id=user.id,
        action="create_retest",
        entity_type="test",
        entity_id=retest.id,
        details={
            "original_test_id": str(original_test_id),
            "retest_number": retest.retest_count,
            "source_test_id": str(test.id),
        },
    )

    # Notify the test creator.
    if test.created_by:
        create_notification(
            db,
            user_id=test.created_by,
            type="retest_created",
            title="Re-test created",
            message=(
                f'A re-test has been automatically created for "{test.name}" '
                f'after remediation was completed.'
            ),
            entity_type="test",
            entity_id=retest.id,
        )

    db.commit()
    db.refresh(retest)
    return retest
|
|
|
|
|
|
def get_retest_chain(db: Session, test_id) -> list[Test]:
    """Return the full re-test chain that *test_id* belongs to.

    The chain is the original test followed by every test whose
    ``retest_of`` points at it, ordered by ``retest_count``.

    Args:
        db: Session used for the lookups.
        test_id: ID of any test in the chain, as a ``uuid.UUID`` or anything
            whose ``str()`` is a valid UUID.

    Returns:
        ``[]`` when no test has that ID; ``[test]`` when the test references
        an original that no longer exists; otherwise the original followed by
        its re-tests.

    Raises:
        ValueError: If *test_id* cannot be parsed as a UUID.
    """
    import uuid as _uuid

    tid = test_id if isinstance(test_id, _uuid.UUID) else _uuid.UUID(str(test_id))

    # Look up the requested test; it may be the original or one of its retests.
    test = db.query(Test).filter(Test.id == tid).first()
    if test is None:
        return []

    if test.retest_of:
        # The requested test is a retest — fetch the original it points at.
        original_id = test.retest_of
        original = db.query(Test).filter(Test.id == original_id).first()
        if original is None:
            return [test]
    else:
        # The requested test is itself the original; no second lookup needed.
        original_id = test.id
        original = test

    retests = (
        db.query(Test)
        .filter(Test.retest_of == original_id)
        .order_by(Test.retest_count)
        .all()
    )

    return [original, *retests]
|
|
|
|
|
|
def reopen_test(db: Session, test: Test, user: User) -> Test:
    """Send a ``rejected`` test back to ``draft`` and wipe both leads' verdicts.

    This allows the teams to redo the test cycle.  The session is committed
    once the validation fields have been cleared.
    """
    reopened = transition_state(
        db,
        test,
        target_state=TestState.draft,
        user=user,
        action_name="reopen_test",
    )

    # Clear the dual-validation fields: red side first, then blue.
    for side in ("red", "blue"):
        for field in (
            "validation_status",
            "validated_by",
            "validated_at",
            "validation_notes",
        ):
            setattr(reopened, f"{side}_{field}", None)

    db.commit()
    return reopened
|