From 5c55e7c17fae820c8a80aa276b51f5f9aa450f09 Mon Sep 17 00:00:00 2001 From: Kitos Date: Wed, 18 Feb 2026 19:10:31 +0100 Subject: [PATCH] feat(domain): add domain layer foundation -- enums, value objects, TechniqueEntity, repository ports --- backend/app/domain/entities/__init__.py | 3 + backend/app/domain/entities/technique.py | 159 +++++++++++++++++ backend/app/domain/enums.py | 37 ++++ backend/app/domain/ports/__init__.py | 0 .../app/domain/ports/repositories/__init__.py | 4 + .../repositories/technique_repository.py | 57 ++++++ .../ports/repositories/test_repository.py | 52 ++++++ backend/app/domain/value_objects/__init__.py | 4 + backend/app/domain/value_objects/mitre_id.py | 51 ++++++ .../domain/value_objects/scoring_weights.py | 48 +++++ backend/app/models/enums.py | 39 ++-- backend/tests/test_domain_enums.py | 53 ++++++ backend/tests/test_technique_entity.py | 168 ++++++++++++++++++ backend/tests/test_value_objects.py | 114 ++++++++++++ 14 files changed, 761 insertions(+), 28 deletions(-) create mode 100644 backend/app/domain/entities/__init__.py create mode 100644 backend/app/domain/entities/technique.py create mode 100644 backend/app/domain/enums.py create mode 100644 backend/app/domain/ports/__init__.py create mode 100644 backend/app/domain/ports/repositories/__init__.py create mode 100644 backend/app/domain/ports/repositories/technique_repository.py create mode 100644 backend/app/domain/ports/repositories/test_repository.py create mode 100644 backend/app/domain/value_objects/__init__.py create mode 100644 backend/app/domain/value_objects/mitre_id.py create mode 100644 backend/app/domain/value_objects/scoring_weights.py create mode 100644 backend/tests/test_domain_enums.py create mode 100644 backend/tests/test_technique_entity.py create mode 100644 backend/tests/test_value_objects.py diff --git a/backend/app/domain/entities/__init__.py b/backend/app/domain/entities/__init__.py new file mode 100644 index 0000000..79c7772 --- /dev/null +++ b/backend/app/domain/entities/__init__.py @@ -0,0 +1,3 @@ +from app.domain.entities.technique import TechniqueEntity + +__all__ = ["TechniqueEntity"] diff --git a/backend/app/domain/entities/technique.py b/backend/app/domain/entities/technique.py new file mode 100644 index 0000000..8c7571c --- /dev/null +++ b/backend/app/domain/entities/technique.py @@ -0,0 +1,159 @@ +"""TechniqueEntity — pure domain object for a MITRE ATT&CK technique. + +Owns the status recalculation logic that was previously in +``status_service.py``. Has **no** dependency on FastAPI, SQLAlchemy, +or any infrastructure concern. + +Usage:: + + entity = TechniqueEntity.from_orm(technique_orm_model) + entity.recalculate_status(test_states_and_results) + entity.mark_reviewed() + entity.apply_to(technique_orm_model) +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from app.domain.enums import TechniqueStatus, TestResult, TestState +from app.domain.value_objects.mitre_id import MitreId + + +@dataclass(frozen=True) +class _TestSnapshot: + """Minimal read-only view of a test for status calculation.""" + + state: TestState + detection_result: str | None + + +@dataclass +class TechniqueEntity: + """Pure domain representation of a MITRE ATT&CK technique.""" + + id: uuid.UUID + mitre_id: str + name: str + tactic: str | None = None + description: str | None = None + platforms: list[str] = field(default_factory=list) + is_subtechnique: bool = False + parent_mitre_id: str | None = None + status_global: TechniqueStatus = TechniqueStatus.not_evaluated + review_required: bool = False + last_review_date: datetime | None = None + + # -- Factory ----------------------------------------------------------- + + @classmethod + def create( + cls, + *, + mitre_id: str, + name: str, + tactic: str | None = None, + description: str | None = None, + platforms: list[str] | None = None, + ) -> TechniqueEntity: + """Create a new technique, validating the MITRE ID format.""" + validated_id = MitreId(mitre_id) + return cls( + id=uuid.uuid4(), + mitre_id=validated_id.value, + name=name, + tactic=tactic, + description=description, + platforms=platforms or [], + is_subtechnique=validated_id.is_subtechnique, + parent_mitre_id=validated_id.parent_id, + status_global=TechniqueStatus.not_evaluated, + ) + + @classmethod + def from_orm(cls, model: Any) -> TechniqueEntity: + """Build a TechniqueEntity from a SQLAlchemy Technique model.""" + raw_status = model.status_global + status = ( + raw_status + if isinstance(raw_status, TechniqueStatus) + else TechniqueStatus(raw_status) + ) + return cls( + id=model.id, + mitre_id=model.mitre_id, + name=model.name, + tactic=model.tactic, + description=model.description, + platforms=model.platforms or [], + is_subtechnique=model.is_subtechnique or False, + parent_mitre_id=model.parent_mitre_id, + status_global=status, + review_required=model.review_required or False, + last_review_date=model.last_review_date, + ) + + def apply_to(self, model: Any) -> None: + """Copy mutable fields back onto the ORM model.""" + model.status_global = self.status_global + model.review_required = self.review_required + model.last_review_date = self.last_review_date + + # -- Business logic ---------------------------------------------------- + + def recalculate_status( + self, + test_snapshots: list[tuple[str, str | None]], + ) -> TechniqueStatus: + """Recompute ``status_global`` from a list of (state, detection_result) pairs. + + Rules (v2): + 1. No tests -> not_evaluated + 2. All validated -> inspect detection results: + - All detected -> validated + - Any partially_detected -> partial + - Otherwise -> not_covered + 3. Some validated, others in progress -> partial + 4. All in intermediate states -> in_progress + + Returns the new status (also set on the entity). + """ + tests = [ + _TestSnapshot( + state=s if isinstance(s, TestState) else TestState(s), + detection_result=dr, + ) + for s, dr in test_snapshots + ] + + if not tests: + self.status_global = TechniqueStatus.not_evaluated + elif all(t.state == TestState.validated for t in tests): + results = [t.detection_result for t in tests if t.detection_result] + if results and all(r == TestResult.detected or r == "detected" for r in results): + self.status_global = TechniqueStatus.validated + elif any( + r == TestResult.partially_detected or r == "partially_detected" + for r in results + ): + self.status_global = TechniqueStatus.partial + else: + self.status_global = TechniqueStatus.not_covered + elif any(t.state == TestState.validated for t in tests): + self.status_global = TechniqueStatus.partial + else: + self.status_global = TechniqueStatus.in_progress + + return self.status_global + + def mark_reviewed(self) -> None: + """Mark the technique as reviewed, clearing the review flag.""" + self.review_required = False + self.last_review_date = datetime.utcnow() + + def flag_for_review(self) -> None: + """Flag the technique as needing review.""" + self.review_required = True diff --git a/backend/app/domain/enums.py b/backend/app/domain/enums.py new file mode 100644 index 0000000..8e7afa2 --- /dev/null +++ b/backend/app/domain/enums.py @@ -0,0 +1,37 @@ +"""Canonical domain enums for Aegis. + +These enums represent core domain concepts and are the single source of +truth. ``models/enums.py`` re-exports them so that existing ORM code +continues to work without changes. +""" + +import enum + + +class TechniqueStatus(str, enum.Enum): + not_evaluated = "not_evaluated" + in_progress = "in_progress" + validated = "validated" + partial = "partial" + not_covered = "not_covered" + review_required = "review_required" + + +class TestState(str, enum.Enum): + draft = "draft" + red_executing = "red_executing" + blue_evaluating = "blue_evaluating" + in_review = "in_review" + validated = "validated" + rejected = "rejected" + + +class TeamSide(str, enum.Enum): + red = "red" + blue = "blue" + + +class TestResult(str, enum.Enum): + detected = "detected" + not_detected = "not_detected" + partially_detected = "partially_detected" diff --git a/backend/app/domain/ports/__init__.py b/backend/app/domain/ports/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/domain/ports/repositories/__init__.py b/backend/app/domain/ports/repositories/__init__.py new file mode 100644 index 0000000..1260672 --- /dev/null +++ b/backend/app/domain/ports/repositories/__init__.py @@ -0,0 +1,4 @@ +from app.domain.ports.repositories.technique_repository import TechniqueRepository +from app.domain.ports.repositories.test_repository import TestRepository + +__all__ = ["TechniqueRepository", "TestRepository"] diff --git a/backend/app/domain/ports/repositories/technique_repository.py b/backend/app/domain/ports/repositories/technique_repository.py new file mode 100644 index 0000000..b5a45e7 --- /dev/null +++ b/backend/app/domain/ports/repositories/technique_repository.py @@ -0,0 +1,57 @@ +"""Port defining how the application accesses technique data. + +This is a domain contract — implementations live in infrastructure/. +The domain layer NEVER imports the implementation. +""" + +from __future__ import annotations + +import uuid +from typing import NamedTuple, Protocol, runtime_checkable + +from app.domain.entities.technique import TechniqueEntity +from app.domain.enums import TechniqueStatus + + +class TechniqueWithCounts(NamedTuple): + """Pre-aggregated technique data for heatmap/scoring.""" + + entity: TechniqueEntity + test_count: int + validated_test_count: int + detection_rule_count: int + + +@runtime_checkable +class TechniqueRepository(Protocol): + """Data access contract for techniques (one per aggregate root).""" + + # -- Single-entity access ---------------------------------------------- + + def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: ... + + def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: ... + + # -- List access ------------------------------------------------------- + + def list_all( + self, + *, + tactic: str | None = None, + status: TechniqueStatus | None = None, + review_required: bool | None = None, + ) -> list[TechniqueEntity]: ... + + def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: ... + + # -- Batch queries (scoring/heatmap performance) ----------------------- + + def count_by_status(self) -> dict[TechniqueStatus, int]: ... + + def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: ... + + # -- Mutations --------------------------------------------------------- + + def save(self, technique: TechniqueEntity) -> TechniqueEntity: ... + + def exists_by_mitre_id(self, mitre_id: str) -> bool: ... diff --git a/backend/app/domain/ports/repositories/test_repository.py b/backend/app/domain/ports/repositories/test_repository.py new file mode 100644 index 0000000..79b6a26 --- /dev/null +++ b/backend/app/domain/ports/repositories/test_repository.py @@ -0,0 +1,52 @@ +"""Port defining how the application accesses test data. + +This is a domain contract — implementations live in infrastructure/. +""" + +from __future__ import annotations + +import uuid +from typing import Protocol, runtime_checkable + +from app.domain.enums import TestState + + +class TestRepository(Protocol): + """Data access contract for tests.""" + + # -- Single-entity access ---------------------------------------------- + + def find_by_id(self, test_id: uuid.UUID) -> object | None: + """Return a Test ORM model by primary key, or None. + + Returns the ORM model directly (not a domain entity) because + the TestEntity is constructed at the service layer via + ``TestEntity.from_orm()``. + """ + ... + + # -- List access ------------------------------------------------------- + + def list_by_technique(self, technique_id: uuid.UUID) -> list[object]: ... + + def list_by_state(self, state: TestState) -> list[object]: ... + + def count_by_technique_and_state( + self, + technique_id: uuid.UUID, + ) -> dict[TestState, int]: + """Return test counts grouped by state for a single technique.""" + ... + + # -- Batch queries ----------------------------------------------------- + + def get_states_and_results_for_technique( + self, + technique_id: uuid.UUID, + ) -> list[tuple[str, str | None]]: + """Return (state, detection_result) pairs for all tests of a technique. + + Used by TechniqueEntity.recalculate_status() without loading full + test models. + """ + ... diff --git a/backend/app/domain/value_objects/__init__.py b/backend/app/domain/value_objects/__init__.py new file mode 100644 index 0000000..bc332a6 --- /dev/null +++ b/backend/app/domain/value_objects/__init__.py @@ -0,0 +1,4 @@ +from app.domain.value_objects.mitre_id import MitreId +from app.domain.value_objects.scoring_weights import ScoringWeights + +__all__ = ["MitreId", "ScoringWeights"] diff --git a/backend/app/domain/value_objects/mitre_id.py b/backend/app/domain/value_objects/mitre_id.py new file mode 100644 index 0000000..092a5a3 --- /dev/null +++ b/backend/app/domain/value_objects/mitre_id.py @@ -0,0 +1,51 @@ +"""MitreId — validated MITRE ATT&CK technique identifier. + +Immutable value object that ensures the identifier follows the ATT&CK +format: ``T`` followed by 4 digits, optionally a dot and 3 more digits +for sub-techniques (e.g. ``T1059``, ``T1059.001``). +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + +_MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$") + + +@dataclass(frozen=True, slots=True) +class MitreId: + """Validated MITRE ATT&CK technique identifier.""" + + value: str + + def __post_init__(self) -> None: + if not _MITRE_ID_RE.match(self.value): + raise ValueError( + f"Invalid MITRE ATT&CK ID '{self.value}'. " + "Expected format: T1234 or T1234.001" + ) + + @property + def is_subtechnique(self) -> bool: + return "." in self.value + + @property + def parent_id(self) -> str | None: + """Return the parent technique ID (e.g. T1059 for T1059.001).""" + if not self.is_subtechnique: + return None + return self.value.split(".")[0] + + def __str__(self) -> str: + return self.value + + def __eq__(self, other: object) -> bool: + if isinstance(other, MitreId): + return self.value == other.value + if isinstance(other, str): + return self.value == other + return NotImplemented + + def __hash__(self) -> int: + return hash(self.value) diff --git a/backend/app/domain/value_objects/scoring_weights.py b/backend/app/domain/value_objects/scoring_weights.py new file mode 100644 index 0000000..8d35016 --- /dev/null +++ b/backend/app/domain/value_objects/scoring_weights.py @@ -0,0 +1,48 @@ +"""ScoringWeights — validated immutable weight set for the scoring engine. + +Enforces that all five weights are non-negative and sum to exactly 100. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class ScoringWeights: + """Five scoring dimension weights that must sum to 100.""" + + tests: float + detection_rules: float + d3fend: float + freshness: float + platform_diversity: float + + def __post_init__(self) -> None: + fields = [ + self.tests, + self.detection_rules, + self.d3fend, + self.freshness, + self.platform_diversity, + ] + for f in fields: + if f < 0: + raise ValueError("Scoring weights must be non-negative") + + total = sum(fields) + if abs(total - 100) > 0.01: + raise ValueError( + f"Scoring weights must sum to 100, got {total}" + ) + + @classmethod + def default(cls) -> ScoringWeights: + """Return the default weight distribution.""" + return cls( + tests=40.0, + detection_rules=25.0, + d3fend=15.0, + freshness=10.0, + platform_diversity=10.0, + ) diff --git a/backend/app/models/enums.py b/backend/app/models/enums.py index 8df0125..6e1cb32 100644 --- a/backend/app/models/enums.py +++ b/backend/app/models/enums.py @@ -1,30 +1,13 @@ -import enum +"""ORM-level re-exports of the canonical domain enums. +The single source of truth lives in ``app.domain.enums``. This module +re-exports every enum so that existing model and router code keeps +working with ``from app.models.enums import ...``. +""" -class TechniqueStatus(str, enum.Enum): - not_evaluated = "not_evaluated" - in_progress = "in_progress" - validated = "validated" - partial = "partial" - not_covered = "not_covered" - review_required = "review_required" - - -class TestState(str, enum.Enum): - draft = "draft" - red_executing = "red_executing" # Red Team documenting attack - blue_evaluating = "blue_evaluating" # Blue Team evaluating detection - in_review = "in_review" - validated = "validated" - rejected = "rejected" - - -class TeamSide(str, enum.Enum): - red = "red" - blue = "blue" - - -class TestResult(str, enum.Enum): - detected = "detected" - not_detected = "not_detected" - partially_detected = "partially_detected" +from app.domain.enums import ( # noqa: F401 + TeamSide, + TechniqueStatus, + TestResult, + TestState, +) diff --git a/backend/tests/test_domain_enums.py b/backend/tests/test_domain_enums.py new file mode 100644 index 0000000..7254469 --- /dev/null +++ b/backend/tests/test_domain_enums.py @@ -0,0 +1,53 @@ +"""Tests verifying domain enums are canonical and properly re-exported.""" + +import sys +import os + +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if backend_dir not in sys.path: + sys.path.insert(0, backend_dir) + +from app.domain.enums import TechniqueStatus, TestState, TeamSide, TestResult + + +def test_technique_status_values(): + assert TechniqueStatus.not_evaluated == "not_evaluated" + assert TechniqueStatus.validated == "validated" + assert TechniqueStatus.partial == "partial" + assert TechniqueStatus.in_progress == "in_progress" + assert TechniqueStatus.not_covered == "not_covered" + assert TechniqueStatus.review_required == "review_required" + + +def test_test_state_values(): + assert TestState.draft == "draft" + assert TestState.red_executing == "red_executing" + assert TestState.blue_evaluating == "blue_evaluating" + assert TestState.in_review == "in_review" + assert TestState.validated == "validated" + assert TestState.rejected == "rejected" + + +def test_team_side_values(): + assert TeamSide.red == "red" + assert TeamSide.blue == "blue" + + +def test_test_result_values(): + assert TestResult.detected == "detected" + assert TestResult.not_detected == "not_detected" + assert TestResult.partially_detected == "partially_detected" + + +def test_models_enums_reexport_is_same_class(): + """Verify models/enums.py re-exports the exact same class objects.""" + from app.models.enums import ( + TechniqueStatus as MS, + TestState as MTS, + TeamSide as MTeam, + TestResult as MTR, + ) + assert MS is TechniqueStatus + assert MTS is TestState + assert MTeam is TeamSide + assert MTR is TestResult diff --git a/backend/tests/test_technique_entity.py b/backend/tests/test_technique_entity.py new file mode 100644 index 0000000..3b749d7 --- /dev/null +++ b/backend/tests/test_technique_entity.py @@ -0,0 +1,168 @@ +"""Tests for TechniqueEntity — pure domain logic, no DB.""" + +import uuid +import sys +import os +from datetime import datetime +from unittest.mock import MagicMock + +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if backend_dir not in sys.path: + sys.path.insert(0, backend_dir) + +import pytest + +from app.domain.entities.technique import TechniqueEntity +from app.domain.enums import TechniqueStatus + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _entity(**overrides) -> TechniqueEntity: + defaults = dict( + id=uuid.uuid4(), + mitre_id="T1059", + name="Command and Scripting Interpreter", + tactic="execution", + status_global=TechniqueStatus.not_evaluated, + ) + defaults.update(overrides) + return TechniqueEntity(**defaults) + + +def _fake_orm(**overrides) -> MagicMock: + m = MagicMock() + m.id = uuid.uuid4() + m.mitre_id = "T1059" + m.name = "Command and Scripting Interpreter" + m.tactic = "execution" + m.description = None + m.platforms = ["windows", "linux"] + m.is_subtechnique = False + m.parent_mitre_id = None + m.status_global = "not_evaluated" + m.review_required = False + m.last_review_date = None + for k, v in overrides.items(): + setattr(m, k, v) + return m + + +# ── 1. Factory: create() ──────────────────────────────────────────── + + +class TestCreate: + + def test_valid_technique(self): + e = TechniqueEntity.create( + mitre_id="T1059", + name="Command and Scripting Interpreter", + tactic="execution", + ) + assert e.mitre_id == "T1059" + assert e.name == "Command and Scripting Interpreter" + assert e.status_global == TechniqueStatus.not_evaluated + assert not e.is_subtechnique + + def test_valid_subtechnique(self): + e = TechniqueEntity.create( + mitre_id="T1059.001", + name="PowerShell", + tactic="execution", + ) + assert e.is_subtechnique + assert e.parent_mitre_id == "T1059" + + def test_invalid_mitre_id_raises(self): + with pytest.raises(ValueError, match="Invalid MITRE"): + TechniqueEntity.create(mitre_id="INVALID", name="Bad", tactic="x") + + def test_platforms_default_to_empty(self): + e = TechniqueEntity.create(mitre_id="T1059", name="Test") + assert e.platforms == [] + + +# ── 2. from_orm / apply_to ────────────────────────────────────────── + + +class TestOrmRoundTrip: + + def test_from_orm_basic(self): + orm = _fake_orm() + e = TechniqueEntity.from_orm(orm) + assert e.mitre_id == "T1059" + assert e.status_global == TechniqueStatus.not_evaluated + + def test_from_orm_coerces_string_status(self): + orm = _fake_orm(status_global="validated") + e = TechniqueEntity.from_orm(orm) + assert e.status_global == TechniqueStatus.validated + + def test_apply_to_updates_model(self): + orm = _fake_orm() + e = TechniqueEntity.from_orm(orm) + e.status_global = TechniqueStatus.validated + e.review_required = True + e.apply_to(orm) + assert orm.status_global == TechniqueStatus.validated + assert orm.review_required is True + + +# ── 3. recalculate_status ─────────────────────────────────────────── + + +class TestRecalculateStatus: + + def test_no_tests_gives_not_evaluated(self): + e = _entity() + result = e.recalculate_status([]) + assert result == TechniqueStatus.not_evaluated + assert e.status_global == TechniqueStatus.not_evaluated + + def test_all_validated_all_detected_gives_validated(self): + e = _entity() + tests = [("validated", "detected"), ("validated", "detected")] + assert e.recalculate_status(tests) == TechniqueStatus.validated + + def test_all_validated_some_partially_gives_partial(self): + e = _entity() + tests = [("validated", "detected"), ("validated", "partially_detected")] + assert e.recalculate_status(tests) == TechniqueStatus.partial + + def test_all_validated_none_detected_gives_not_covered(self): + e = _entity() + tests = [("validated", "not_detected"), ("validated", "not_detected")] + assert e.recalculate_status(tests) == TechniqueStatus.not_covered + + def test_all_validated_no_results_gives_not_covered(self): + e = _entity() + tests = [("validated", None)] + assert e.recalculate_status(tests) == TechniqueStatus.not_covered + + def test_mixed_validated_and_in_progress_gives_partial(self): + e = _entity() + tests = [("validated", "detected"), ("draft", None)] + assert e.recalculate_status(tests) == TechniqueStatus.partial + + def test_all_in_progress_gives_in_progress(self): + e = _entity() + tests = [("draft", None), ("red_executing", None)] + assert e.recalculate_status(tests) == TechniqueStatus.in_progress + + +# ── 4. mark_reviewed / flag_for_review ────────────────────────────── + + +class TestReviewCycle: + + def test_mark_reviewed_clears_flag(self): + e = _entity(review_required=True) + e.mark_reviewed() + assert e.review_required is False + assert e.last_review_date is not None + + def test_flag_for_review(self): + e = _entity(review_required=False) + e.flag_for_review() + assert e.review_required is True diff --git a/backend/tests/test_value_objects.py b/backend/tests/test_value_objects.py new file mode 100644 index 0000000..19a4460 --- /dev/null +++ b/backend/tests/test_value_objects.py @@ -0,0 +1,114 @@ +"""Tests for domain value objects: MitreId and ScoringWeights.""" + +import sys +import os + +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if backend_dir not in sys.path: + sys.path.insert(0, backend_dir) + +import pytest + +from app.domain.value_objects.mitre_id import MitreId +from app.domain.value_objects.scoring_weights import ScoringWeights + + +# ── MitreId ────────────────────────────────────────────────────────── + + +class TestMitreId: + + def test_valid_technique(self): + mid = MitreId("T1059") + assert mid.value == "T1059" + assert str(mid) == "T1059" + assert not mid.is_subtechnique + assert mid.parent_id is None + + def test_valid_subtechnique(self): + mid = MitreId("T1059.001") + assert mid.value == "T1059.001" + assert mid.is_subtechnique + assert mid.parent_id == "T1059" + + def test_invalid_empty_string(self): + with pytest.raises(ValueError, match="Invalid MITRE"): + MitreId("") + + def test_invalid_no_prefix(self): + with pytest.raises(ValueError, match="Invalid MITRE"): + MitreId("1059") + + def test_invalid_wrong_prefix(self): + with pytest.raises(ValueError, match="Invalid MITRE"): + MitreId("A1059") + + def test_invalid_too_few_digits(self): + with pytest.raises(ValueError, match="Invalid MITRE"): + MitreId("T105") + + def test_invalid_subtechnique_format(self): + with pytest.raises(ValueError, match="Invalid MITRE"): + MitreId("T1059.01") # needs 3 digits after dot + + def test_invalid_trailing_garbage(self): + with pytest.raises(ValueError, match="Invalid MITRE"): + MitreId("T1059.001.002") + + def test_equality_with_same_mitre_id(self): + assert MitreId("T1059") == MitreId("T1059") + + def test_equality_with_string(self): + assert MitreId("T1059") == "T1059" + + def test_inequality(self): + assert MitreId("T1059") != MitreId("T1060") + + def test_hashable(self): + s = {MitreId("T1059"), MitreId("T1059"), MitreId("T1060")} + assert len(s) == 2 + + def test_immutable(self): + mid = MitreId("T1059") + with pytest.raises(AttributeError): + mid.value = "T1060" + + +# ── ScoringWeights ─────────────────────────────────────────────────── + + +class TestScoringWeights: + + def test_valid_default(self): + w = ScoringWeights.default() + assert w.tests == 40.0 + assert w.detection_rules == 25.0 + assert w.d3fend == 15.0 + assert w.freshness == 10.0 + assert w.platform_diversity == 10.0 + + def test_valid_custom(self): + w = ScoringWeights( + tests=50, detection_rules=20, d3fend=10, + freshness=10, platform_diversity=10, + ) + assert w.tests == 50 + + def test_invalid_sum_not_100(self): + with pytest.raises(ValueError, match="sum to 100"): + ScoringWeights( + tests=50, detection_rules=20, d3fend=10, + freshness=10, platform_diversity=5, + ) + + def test_invalid_negative_weight(self): + with pytest.raises(ValueError, match="non-negative"): + ScoringWeights( + tests=-10, detection_rules=40, d3fend=30, + freshness=20, platform_diversity=20, + ) + + def test_immutable(self): + w = ScoringWeights.default() + with pytest.raises(AttributeError): + w.tests = 50