feat(phase-11): implement Red/Blue business logic services (T-106, T-107, T-108)

T-106: Create test_workflow_service.py with state-machine transitions for the complete test lifecycle (draft -> red_executing -> blue_evaluating -> in_review -> validated/rejected), dual validation by Red/Blue leads, and reopen capability with field cleanup.

T-107: Update status_service.py to use detection_result from Blue Team instead of legacy result field, and differentiate between partial progress (some validated) vs all-in-progress states.

T-108: Create atomic_import_service.py that downloads the Atomic Red Team repo as a ZIP (avoiding API rate limits), parses all atomics YAML files, and creates idempotent TestTemplate records mapped to MITRE techniques.

Includes validation tests for all three tasks (19 checks total).
This commit is contained in:
2026-02-09 09:58:54 +01:00
parent 086cc5c8bc
commit 7af6be10be
23 changed files with 2053 additions and 45 deletions

View File

@@ -2,12 +2,14 @@
from app.models.user import User
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.evidence import Evidence
from app.models.intel import IntelItem
from app.models.audit import AuditLog
from app.models.enums import TechniqueStatus, TestState, TestResult
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
__all__ = [
"User", "Technique", "Test", "Evidence", "IntelItem", "AuditLog",
"TechniqueStatus", "TestState", "TestResult"
"User", "Technique", "Test", "TestTemplate", "Evidence",
"IntelItem", "AuditLog",
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
]

View File

@@ -12,11 +12,18 @@ class TechniqueStatus(str, enum.Enum):
class TestState(str, enum.Enum):
draft = "draft"
red_executing = "red_executing" # Red Team documenting attack
blue_evaluating = "blue_evaluating" # Blue Team evaluating detection
in_review = "in_review"
validated = "validated"
rejected = "rejected"
class TeamSide(str, enum.Enum):
red = "red"
blue = "blue"
class TestResult(str, enum.Enum):
detected = "detected"
not_detected = "not_detected"

View File

@@ -1,11 +1,12 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, DateTime, ForeignKey
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.database import Base
from app.models.enums import TeamSide
class Evidence(Base):
@@ -14,6 +15,9 @@ class Evidence(Base):
Files are stored in MinIO, and this model tracks the file location,
integrity hash, and upload metadata.
The ``team`` field distinguishes whether this evidence was uploaded by
Red Team (attack evidence) or Blue Team (detection evidence).
"""
__tablename__ = "evidences"
@@ -24,6 +28,8 @@ class Evidence(Base):
sha256_hash = Column(String, nullable=False)
uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
uploaded_at = Column(DateTime, default=datetime.utcnow)
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
notes = Column(Text, nullable=True)
# Relationships
test = relationship("Test", back_populates="evidences")

View File

@@ -1,7 +1,7 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -12,12 +12,14 @@ from app.models.enums import TestState, TestResult
class Test(Base):
"""
Test model representing a security test for a MITRE ATT&CK technique.
Each test documents an attempt to validate coverage of a specific technique,
including the procedure, tools used, and outcome.
including the procedure, tools used, and outcome. V2 introduces dual
validation: Red Lead and Blue Lead must each approve independently.
"""
__tablename__ = "tests"
# ── Core fields ─────────────────────────────────────────────────
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False)
name = Column(String, nullable=False)
@@ -29,12 +31,27 @@ class Test(Base):
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
result = Column(Enum(TestResult, name="testresult"), nullable=True)
state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
validated_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
# ── Red Team fields ─────────────────────────────────────────────
red_summary = Column(Text, nullable=True)
attack_success = Column(Boolean, nullable=True)
red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
red_validated_at = Column(DateTime, nullable=True)
red_validation_status = Column(String, nullable=True) # pending / approved / rejected
red_validation_notes = Column(Text, nullable=True)
# ── Blue Team fields ────────────────────────────────────────────
blue_summary = Column(Text, nullable=True)
detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
blue_validated_at = Column(DateTime, nullable=True)
blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
blue_validation_notes = Column(Text, nullable=True)
# ── Relationships ───────────────────────────────────────────────
technique = relationship("Technique", back_populates="tests")
evidences = relationship("Evidence", back_populates="test")
creator = relationship("User", foreign_keys=[created_by])
validator = relationship("User", foreign_keys=[validated_by])
red_validator = relationship("User", foreign_keys=[red_validated_by])
blue_validator = relationship("User", foreign_keys=[blue_validated_by])

View File

@@ -0,0 +1,45 @@
"""TestTemplate model — predefined test catalog entries."""
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index
from sqlalchemy.dialects.postgresql import UUID
from app.database import Base
class TestTemplate(Base):
"""
Predefined test template mapped to a MITRE ATT&CK technique.
Templates come from several sources:
- **atomic_red_team**: Atomic Red Team by Red Canary
- **mitre**: MITRE ATT&CK procedure examples
- **custom**: Manually created by teams
Users can instantiate a real Test from a template.
"""
__tablename__ = "test_templates"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
name = Column(String, nullable=False)
description = Column(Text, nullable=True)
source = Column(String, nullable=False) # atomic_red_team / mitre / custom
source_url = Column(String, nullable=True)
attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
expected_detection = Column(Text, nullable=True) # What blue team should detect
platform = Column(String, nullable=True) # windows / linux / macos
tool_suggested = Column(String, nullable=True)
severity = Column(String, nullable=True) # low / medium / high / critical
atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
__table_args__ = (
Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'),
Index('ix_test_templates_source', 'source'),
Index('ix_test_templates_platform', 'platform'),
Index('ix_test_templates_severity', 'severity'),
)

View File

@@ -128,5 +128,7 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
sha256_hash=evidence.sha256_hash,
uploaded_by=evidence.uploaded_by,
uploaded_at=evidence.uploaded_at,
team=evidence.team,
notes=evidence.notes,
download_url=get_presigned_url(evidence.file_path),
)

View File

@@ -166,10 +166,11 @@ def validate_test(
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Mark a test as validated.
"""Validate the red or blue side of a test (dual validation).
Sets ``state`` to *validated*, records ``validated_by`` / ``validated_at``,
stores the ``result``, and recalculates the parent technique's global status.
Red Lead approves/rejects the red side; Blue Lead approves/rejects the
blue side. When *both* sides are approved the test state moves to
``validated``. If either side is rejected the state moves to ``rejected``.
"""
test = (
db.query(Test)
@@ -184,10 +185,39 @@ def validate_test(
detail="Test not found",
)
test.state = TestState.validated
now = datetime.utcnow()
if current_user.role in ("red_lead", "admin"):
test.red_validation_status = payload.result.value
test.red_validated_by = current_user.id
test.red_validated_at = now
side = "red"
elif current_user.role == "blue_lead":
test.blue_validation_status = payload.result.value
test.blue_validated_by = current_user.id
test.blue_validated_at = now
side = "blue"
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to validate",
)
# Store the overall result from the payload
test.result = payload.result
test.validated_by = current_user.id
test.validated_at = datetime.utcnow()
# Determine aggregate state
red_ok = test.red_validation_status == "approved"
blue_ok = test.blue_validation_status == "approved"
red_rej = test.red_validation_status == "rejected"
blue_rej = test.blue_validation_status == "rejected"
if red_ok and blue_ok:
test.state = TestState.validated
elif red_rej or blue_rej:
test.state = TestState.rejected
else:
test.state = TestState.in_review
db.commit()
db.refresh(test)
@@ -203,6 +233,7 @@ def validate_test(
entity_type="test",
entity_id=test.id,
details={
"side": side,
"result": payload.result.value,
"technique_id": str(test.technique_id),
},

View File

@@ -14,9 +14,20 @@ from app.schemas.test import (
TestOut,
TestUpdate,
TestValidate,
TestRedUpdate,
TestBlueUpdate,
TestRedValidate,
TestBlueValidate,
)
from app.schemas.evidence import EvidenceOut
from app.schemas.evidence import EvidenceOut, EvidenceUpload
from app.schemas.test_template import (
TestTemplateOut,
TestTemplateCreate,
TestTemplateSummary,
TestTemplateInstantiate,
)
__all__ = [
# Auth
@@ -33,6 +44,16 @@ __all__ = [
"TestOut",
"TestUpdate",
"TestValidate",
"TestRedUpdate",
"TestBlueUpdate",
"TestRedValidate",
"TestBlueValidate",
# Evidence
"EvidenceOut",
"EvidenceUpload",
# Test Template
"TestTemplateOut",
"TestTemplateCreate",
"TestTemplateSummary",
"TestTemplateInstantiate",
]

View File

@@ -5,6 +5,8 @@ from datetime import datetime
from pydantic import BaseModel, ConfigDict
from app.models.enums import TeamSide
class EvidenceOut(BaseModel):
"""Representation of an evidence record returned by the API.
@@ -18,6 +20,15 @@ class EvidenceOut(BaseModel):
sha256_hash: str
uploaded_by: uuid.UUID | None = None
uploaded_at: datetime | None = None
team: TeamSide = TeamSide.red
notes: str | None = None
download_url: str | None = None
model_config = ConfigDict(from_attributes=True)
class EvidenceUpload(BaseModel):
"""Metadata sent alongside an evidence file upload."""
team: TeamSide
notes: str | None = None

View File

@@ -10,6 +10,7 @@ from app.models.enums import TestResult, TestState
# ── Create ──────────────────────────────────────────────────────────
class TestCreate(BaseModel):
"""Payload for creating a new test."""
@@ -21,7 +22,8 @@ class TestCreate(BaseModel):
tool_used: str | None = None
# ── Update ──────────────────────────────────────────────────────────
# ── Update (general) ───────────────────────────────────────────────
class TestUpdate(BaseModel):
"""Payload for partially updating an existing test.
@@ -35,8 +37,63 @@ class TestUpdate(BaseModel):
result: TestResult | None = None
# ── Red Team update ────────────────────────────────────────────────
class TestRedUpdate(BaseModel):
"""Fields that Red Team fills in during the red_executing phase."""
name: str | None = None
description: str | None = None
procedure_text: str | None = None
tool_used: str | None = None
attack_success: bool | None = None
red_summary: str | None = None
# ── Blue Team update ───────────────────────────────────────────────
class TestBlueUpdate(BaseModel):
"""Fields that Blue Team fills in during the blue_evaluating phase."""
detection_result: TestResult | None = None
blue_summary: str | None = None
# ── Red Lead validation ────────────────────────────────────────────
class TestRedValidate(BaseModel):
"""Payload sent by Red Lead to approve/reject the red side."""
red_validation_status: str # "approved" or "rejected"
red_validation_notes: str | None = None
# ── Blue Lead validation ───────────────────────────────────────────
class TestBlueValidate(BaseModel):
"""Payload sent by Blue Lead to approve/reject the blue side."""
blue_validation_status: str # "approved" or "rejected"
blue_validation_notes: str | None = None
# ── Legacy validate (kept for backwards compat) ────────────────────
class TestValidate(BaseModel):
"""Payload sent by a reviewer to validate / reject a test."""
result: TestResult
comments: str | None = None
# ── Read (full) ─────────────────────────────────────────────────────
class TestOut(BaseModel):
"""Complete representation returned by the API."""
@@ -51,17 +108,22 @@ class TestOut(BaseModel):
created_by: uuid.UUID | None = None
result: TestResult | None = None
state: TestState = TestState.draft
validated_by: uuid.UUID | None = None
validated_at: datetime | None = None
created_at: datetime | None = None
# Red Team fields
red_summary: str | None = None
attack_success: bool | None = None
red_validated_by: uuid.UUID | None = None
red_validated_at: datetime | None = None
red_validation_status: str | None = None
red_validation_notes: str | None = None
# Blue Team fields
blue_summary: str | None = None
detection_result: TestResult | None = None
blue_validated_by: uuid.UUID | None = None
blue_validated_at: datetime | None = None
blue_validation_status: str | None = None
blue_validation_notes: str | None = None
model_config = ConfigDict(from_attributes=True)
# ── Validate ────────────────────────────────────────────────────────
class TestValidate(BaseModel):
"""Payload sent by a reviewer to validate / reject a test."""
result: TestResult
comments: str | None = None

View File

@@ -0,0 +1,75 @@
"""Pydantic schemas for TestTemplate endpoints."""
import uuid
from datetime import datetime
from pydantic import BaseModel, ConfigDict
# ── Full output ─────────────────────────────────────────────────────
class TestTemplateOut(BaseModel):
"""Complete representation of a test template."""
id: uuid.UUID
mitre_technique_id: str
name: str
description: str | None = None
source: str
source_url: str | None = None
attack_procedure: str | None = None
expected_detection: str | None = None
platform: str | None = None
tool_suggested: str | None = None
severity: str | None = None
atomic_test_id: str | None = None
is_active: bool = True
created_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
# ── Create ──────────────────────────────────────────────────────────
class TestTemplateCreate(BaseModel):
"""Payload for creating a custom test template."""
mitre_technique_id: str
name: str
description: str | None = None
source: str = "custom"
source_url: str | None = None
attack_procedure: str | None = None
expected_detection: str | None = None
platform: str | None = None
tool_suggested: str | None = None
severity: str | None = None
atomic_test_id: str | None = None
# ── Summary (for listings) ─────────────────────────────────────────
class TestTemplateSummary(BaseModel):
"""Lightweight representation for listing templates."""
id: uuid.UUID
mitre_technique_id: str
name: str
source: str
platform: str | None = None
severity: str | None = None
model_config = ConfigDict(from_attributes=True)
# ── Instantiate (create a real Test from a template) ────────────────
class TestTemplateInstantiate(BaseModel):
"""Payload to create a real test from an existing template."""
template_id: uuid.UUID
technique_id: uuid.UUID

View File

@@ -0,0 +1,231 @@
"""Atomic Red Team import service.
Downloads the Atomic Red Team repository ZIP from GitHub, parses every
``atomics/T*/T*.yaml`` file, and upserts :class:`TestTemplate` records
into the database.
Strategy
--------
The GitHub REST API without authentication only allows 60 req/hour.
Since the Atomic Red Team repo contains 1 500+ YAML files we avoid
per-file requests entirely. Instead we:
1. Download the full repo as a ZIP archive (~40 MB).
2. Extract in a temporary directory.
3. Walk ``atomics/T*/T*.yaml`` files parsing them with PyYAML.
4. Create / update ``TestTemplate`` rows keyed by ``atomic_test_id``.
5. Clean up the temporary directory.
Idempotency
-----------
Running the import twice does **not** create duplicates. Existing
templates are identified by their ``atomic_test_id`` and simply skipped.
"""
import io
import logging
import os
import shutil
import tempfile
import zipfile
from pathlib import Path
import requests as _requests
import yaml
from sqlalchemy.orm import Session
from app.models.test_template import TestTemplate
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
ATOMIC_RT_ZIP_URL = (
"https://github.com/redcanaryco/atomic-red-team"
"/archive/refs/heads/master.zip"
)
# Request timeout for the ZIP download (seconds)
_DOWNLOAD_TIMEOUT = 300
# Top-level directory name inside the ZIP
_ZIP_ROOT_PREFIX = "atomic-red-team-master"
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes:
"""Download the Atomic Red Team ZIP and return its raw bytes."""
logger.info("Downloading Atomic Red Team ZIP from %s", url)
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
resp.raise_for_status()
content = resp.content
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
return content
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
"""Extract *zip_bytes* into *dest* and return the path to the atomics/ dir."""
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
zf.extractall(dest)
atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
if not atomics_dir.is_dir():
raise FileNotFoundError(
f"Expected atomics directory not found at {atomics_dir}"
)
return atomics_dir
def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
"""Walk the atomics directory and parse all technique YAML files.
Returns a flat list of dicts, each representing a single atomic test
with the following keys::
technique_id, index, name, description, platforms,
executor_type, command, source_url
"""
results: list[dict] = []
yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
logger.info("Found %d YAML files to parse", len(yaml_files))
for yaml_path in yaml_files:
technique_id = yaml_path.stem # e.g. "T1059.001"
try:
with open(yaml_path, "r", encoding="utf-8") as fh:
data = yaml.safe_load(fh)
except Exception as exc:
logger.warning("Failed to parse %s: %s", yaml_path, exc)
continue
if not data or "atomic_tests" not in data:
continue
for idx, test in enumerate(data["atomic_tests"]):
name = test.get("name", "").strip()
description = test.get("description", "").strip()
platforms = test.get("supported_platforms", [])
executor = test.get("executor", {})
executor_type = executor.get("name", "") if isinstance(executor, dict) else ""
command = executor.get("command", "") if isinstance(executor, dict) else ""
# Build an atomic_test_id in the format "T1059.001-0"
atomic_test_id = f"{technique_id}-{idx}"
source_url = (
f"https://github.com/redcanaryco/atomic-red-team/blob/master"
f"/atomics/{technique_id}/{technique_id}.yaml"
)
results.append({
"technique_id": technique_id,
"index": idx,
"atomic_test_id": atomic_test_id,
"name": name,
"description": description,
"platforms": ", ".join(platforms) if isinstance(platforms, list) else str(platforms),
"executor_type": executor_type,
"command": command[:4000] if command else None, # cap at 4k chars
"source_url": source_url,
})
logger.info("Parsed %d atomic tests total", len(results))
return results
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def import_atomic_red_team(db: Session) -> dict:
"""Download and import Atomic Red Team tests as TestTemplates.
Parameters
----------
db : Session
Active SQLAlchemy database session.
Returns
-------
dict
Summary with keys ``created``, ``skipped_existing``,
``yaml_files_parsed``, ``total_tests_parsed``.
"""
tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
try:
zip_bytes = _download_zip()
atomics_dir = _extract_zip(zip_bytes, tmp_dir)
parsed_tests = _parse_yaml_files(atomics_dir)
finally:
# Always clean up
shutil.rmtree(tmp_dir, ignore_errors=True)
logger.info("Cleaned up temp directory %s", tmp_dir)
# Pre-load existing atomic_test_ids for dedup
existing_ids: set[str] = {
row[0]
for row in db.query(TestTemplate.atomic_test_id)
.filter(TestTemplate.atomic_test_id.isnot(None))
.all()
}
created = 0
skipped = 0
for item in parsed_tests:
if item["atomic_test_id"] in existing_ids:
skipped += 1
continue
template = TestTemplate(
mitre_technique_id=item["technique_id"],
name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}",
description=item["description"][:2000] if item["description"] else None,
source="atomic_red_team",
source_url=item["source_url"],
attack_procedure=item["command"],
platform=item["platforms"],
tool_suggested=item["executor_type"] if item["executor_type"] else None,
atomic_test_id=item["atomic_test_id"],
is_active=True,
)
db.add(template)
existing_ids.add(item["atomic_test_id"])
created += 1
db.commit()
# Count distinct YAML files by technique_id
yaml_files_count = len({t["technique_id"] for t in parsed_tests})
summary = {
"created": created,
"skipped_existing": skipped,
"yaml_files_parsed": yaml_files_count,
"total_tests_parsed": len(parsed_tests),
}
logger.info(
"Atomic Red Team import complete — created=%d, skipped=%d, "
"yaml_files=%d, total_tests=%d",
created, skipped, yaml_files_count, len(parsed_tests),
)
# Audit log (system action)
log_action(
db,
user_id=None,
action="import_atomic_red_team",
entity_type="test_template",
entity_id=None,
details=summary,
)
return summary

View File

@@ -1,36 +1,46 @@
"""Service for recalculating the global status of a Technique
based on the state and result of its associated tests."""
based on the state and result of its associated tests.
V2 rules account for dual Red/Blue validation and use
``detection_result`` (filled by Blue Team) instead of the legacy
``result`` field.
"""
from sqlalchemy.orm import Session
from app.models.enums import TechniqueStatus
from app.models.enums import TechniqueStatus, TestState
from app.models.technique import Technique
def recalculate_technique_status(db: Session, technique: Technique) -> None:
"""Recompute ``technique.status_global`` from its tests and commit.
Rules
-----
- No tests → ``not_evaluated``
- Any test not yet ``validated`` → ``in_progress``
- All validated and all ``detected`` → ``validated``
- All validated and any ``partially_detected`` → ``partial``
- Otherwise → ``not_covered``
Rules (v2)
----------
1. No tests → ``not_evaluated``
2. All tests ``validated`` → look at detection results:
- All ``detected`` → ``validated``
- Any ``partially_detected`` → ``partial``
- Otherwise → ``not_covered``
3. Some tests ``validated``, others still in progress → ``partial``
4. All tests in intermediate states (no validated) → ``in_progress``
"""
tests = technique.tests
if not tests:
technique.status_global = TechniqueStatus.not_evaluated
elif any(t.state != "validated" for t in tests):
technique.status_global = TechniqueStatus.in_progress
else:
results = [t.result for t in tests]
if all(r == "detected" for r in results):
elif all(t.state == TestState.validated for t in tests):
# All validated — inspect detection results
results = [t.detection_result for t in tests if t.detection_result]
if results and all(str(r) == "detected" or r == "detected" for r in results):
technique.status_global = TechniqueStatus.validated
elif any(r == "partially_detected" for r in results):
elif any(str(r) == "partially_detected" or r == "partially_detected" for r in results):
technique.status_global = TechniqueStatus.partial
else:
technique.status_global = TechniqueStatus.not_covered
elif any(t.state == TestState.validated for t in tests):
technique.status_global = TechniqueStatus.partial
else:
technique.status_global = TechniqueStatus.in_progress
db.commit()

View File

@@ -0,0 +1,285 @@
"""Test workflow service — state-machine transitions for the Red/Blue validation flow.
Controls which state transitions are valid and exposes high-level helpers
for each step in the test lifecycle:
draft → red_executing → blue_evaluating → in_review → validated / rejected
rejected → draft
Every public function validates the transition, mutates the test, writes an
audit-log entry, and commits the session.
"""
from datetime import datetime
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from app.models.enums import TestState
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
# ---------------------------------------------------------------------------
# Valid transition map
# ---------------------------------------------------------------------------
VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
TestState.draft: [TestState.red_executing],
TestState.red_executing: [TestState.blue_evaluating],
TestState.blue_evaluating: [TestState.in_review],
TestState.in_review: [TestState.validated, TestState.rejected],
TestState.rejected: [TestState.draft],
TestState.validated: [], # terminal state
}
# ---------------------------------------------------------------------------
# Core helpers
# ---------------------------------------------------------------------------
def can_transition(test: Test, target_state: TestState) -> bool:
"""Return *True* if moving *test* to *target_state* is allowed."""
current = test.state if isinstance(test.state, TestState) else TestState(test.state)
return target_state in VALID_TRANSITIONS.get(current, [])
def transition_state(
db: Session,
test: Test,
target_state: TestState,
user: User,
*,
action_name: str = "transition_state",
extra_details: dict | None = None,
) -> Test:
"""Validate and perform a state transition, log it, and commit.
Raises :class:`~fastapi.HTTPException` 400 when the transition is invalid.
"""
if not can_transition(test, target_state):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
f"Invalid transition: cannot move from "
f"'{test.state.value if isinstance(test.state, TestState) else test.state}' "
f"to '{target_state.value}'"
),
)
previous_state = test.state.value if isinstance(test.state, TestState) else test.state
test.state = target_state
db.flush()
details: dict = {
"previous_state": previous_state,
"new_state": target_state.value,
"test_name": test.name,
"technique_id": str(test.technique_id),
}
if extra_details:
details.update(extra_details)
log_action(
db,
user_id=user.id,
action=action_name,
entity_type="test",
entity_id=test.id,
details=details,
)
return test
# ---------------------------------------------------------------------------
# Lifecycle convenience functions
# ---------------------------------------------------------------------------
def start_execution(db: Session, test: Test, user: User) -> Test:
"""Move from ``draft`` → ``red_executing``.
Typically called by a **red_tech** when they begin the attack.
"""
test = transition_state(
db, test, TestState.red_executing, user,
action_name="start_execution",
)
test.execution_date = datetime.utcnow()
db.commit()
return test
def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
"""Move from ``red_executing`` → ``blue_evaluating``.
Called by **red_tech** once they have finished documenting the attack.
"""
test = transition_state(
db, test, TestState.blue_evaluating, user,
action_name="submit_red_evidence",
)
db.commit()
return test
def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
"""Move from ``blue_evaluating`` → ``in_review``.
Called by **blue_tech** once they have finished documenting detection.
"""
test = transition_state(
db, test, TestState.in_review, user,
action_name="submit_blue_evidence",
)
db.commit()
return test
def validate_as_red_lead(
db: Session,
test: Test,
user: User,
validation_status: str,
notes: str | None = None,
) -> Test:
"""Record Red Lead's validation decision.
*validation_status* must be ``"approved"`` or ``"rejected"``.
After recording the decision, :func:`check_dual_validation` is called
to potentially advance the test to ``validated`` or ``rejected``.
"""
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot validate red side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)",
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="validation_status must be 'approved' or 'rejected'",
)
now = datetime.utcnow()
test.red_validation_status = validation_status
test.red_validated_by = user.id
test.red_validated_at = now
test.red_validation_notes = notes
log_action(
db,
user_id=user.id,
action="validate_as_red_lead",
entity_type="test",
entity_id=test.id,
details={
"validation_status": validation_status,
"notes": notes,
"technique_id": str(test.technique_id),
},
)
check_dual_validation(db, test)
return test
def validate_as_blue_lead(
db: Session,
test: Test,
user: User,
validation_status: str,
notes: str | None = None,
) -> Test:
"""Record Blue Lead's validation decision.
*validation_status* must be ``"approved"`` or ``"rejected"``.
After recording the decision, :func:`check_dual_validation` is called
to potentially advance the test to ``validated`` or ``rejected``.
"""
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot validate blue side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)",
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="validation_status must be 'approved' or 'rejected'",
)
now = datetime.utcnow()
test.blue_validation_status = validation_status
test.blue_validated_by = user.id
test.blue_validated_at = now
test.blue_validation_notes = notes
log_action(
db,
user_id=user.id,
action="validate_as_blue_lead",
entity_type="test",
entity_id=test.id,
details={
"validation_status": validation_status,
"notes": notes,
"technique_id": str(test.technique_id),
},
)
check_dual_validation(db, test)
return test
def check_dual_validation(db: Session, test: Test) -> Test:
"""Evaluate both leads' decisions and advance the test if both have voted.
- Both **approved** → ``validated``
- Either **rejected** → ``rejected``
- Otherwise no state change (waiting for the other lead).
Commits only when the state actually changes.
"""
red_status = test.red_validation_status
blue_status = test.blue_validation_status
if red_status == "rejected" or blue_status == "rejected":
test.state = TestState.rejected
db.commit()
elif red_status == "approved" and blue_status == "approved":
test.state = TestState.validated
db.commit()
else:
# One side hasn't voted yet — stay in_review, just flush
db.commit()
return test
def reopen_test(db: Session, test: Test, user: User) -> Test:
"""Move a ``rejected`` test back to ``draft``, clearing validation fields.
This allows the teams to redo the test cycle.
"""
test = transition_state(
db, test, TestState.draft, user,
action_name="reopen_test",
)
# Clear dual-validation fields
test.red_validation_status = None
test.red_validated_by = None
test.red_validated_at = None
test.red_validation_notes = None
test.blue_validation_status = None
test.blue_validated_by = None
test.blue_validated_at = None
test.blue_validation_notes = None
db.commit()
return test