feat(domain): add domain layer foundation -- enums, value objects, TechniqueEntity, repository ports
This commit is contained in:
3
backend/app/domain/entities/__init__.py
Normal file
3
backend/app/domain/entities/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
|
||||||
|
__all__ = ["TechniqueEntity"]
|
||||||
159
backend/app/domain/entities/technique.py
Normal file
159
backend/app/domain/entities/technique.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""TechniqueEntity — pure domain object for a MITRE ATT&CK technique.
|
||||||
|
|
||||||
|
Owns the status recalculation logic that was previously in
|
||||||
|
``status_service.py``. Has **no** dependency on FastAPI, SQLAlchemy,
|
||||||
|
or any infrastructure concern.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
entity = TechniqueEntity.from_orm(technique_orm_model)
|
||||||
|
entity.recalculate_status(test_states_and_results)
|
||||||
|
entity.mark_reviewed()
|
||||||
|
entity.apply_to(technique_orm_model)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.domain.enums import TechniqueStatus, TestResult, TestState
|
||||||
|
from app.domain.value_objects.mitre_id import MitreId
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _TestSnapshot:
|
||||||
|
"""Minimal read-only view of a test for status calculation."""
|
||||||
|
|
||||||
|
state: TestState
|
||||||
|
detection_result: str | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TechniqueEntity:
|
||||||
|
"""Pure domain representation of a MITRE ATT&CK technique."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
mitre_id: str
|
||||||
|
name: str
|
||||||
|
tactic: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
platforms: list[str] = field(default_factory=list)
|
||||||
|
is_subtechnique: bool = False
|
||||||
|
parent_mitre_id: str | None = None
|
||||||
|
status_global: TechniqueStatus = TechniqueStatus.not_evaluated
|
||||||
|
review_required: bool = False
|
||||||
|
last_review_date: datetime | None = None
|
||||||
|
|
||||||
|
# -- Factory -----------------------------------------------------------
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
mitre_id: str,
|
||||||
|
name: str,
|
||||||
|
tactic: str | None = None,
|
||||||
|
description: str | None = None,
|
||||||
|
platforms: list[str] | None = None,
|
||||||
|
) -> TechniqueEntity:
|
||||||
|
"""Create a new technique, validating the MITRE ID format."""
|
||||||
|
validated_id = MitreId(mitre_id)
|
||||||
|
return cls(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
mitre_id=validated_id.value,
|
||||||
|
name=name,
|
||||||
|
tactic=tactic,
|
||||||
|
description=description,
|
||||||
|
platforms=platforms or [],
|
||||||
|
is_subtechnique=validated_id.is_subtechnique,
|
||||||
|
parent_mitre_id=validated_id.parent_id,
|
||||||
|
status_global=TechniqueStatus.not_evaluated,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_orm(cls, model: Any) -> TechniqueEntity:
|
||||||
|
"""Build a TechniqueEntity from a SQLAlchemy Technique model."""
|
||||||
|
raw_status = model.status_global
|
||||||
|
status = (
|
||||||
|
raw_status
|
||||||
|
if isinstance(raw_status, TechniqueStatus)
|
||||||
|
else TechniqueStatus(raw_status)
|
||||||
|
)
|
||||||
|
return cls(
|
||||||
|
id=model.id,
|
||||||
|
mitre_id=model.mitre_id,
|
||||||
|
name=model.name,
|
||||||
|
tactic=model.tactic,
|
||||||
|
description=model.description,
|
||||||
|
platforms=model.platforms or [],
|
||||||
|
is_subtechnique=model.is_subtechnique or False,
|
||||||
|
parent_mitre_id=model.parent_mitre_id,
|
||||||
|
status_global=status,
|
||||||
|
review_required=model.review_required or False,
|
||||||
|
last_review_date=model.last_review_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_to(self, model: Any) -> None:
|
||||||
|
"""Copy mutable fields back onto the ORM model."""
|
||||||
|
model.status_global = self.status_global
|
||||||
|
model.review_required = self.review_required
|
||||||
|
model.last_review_date = self.last_review_date
|
||||||
|
|
||||||
|
# -- Business logic ----------------------------------------------------
|
||||||
|
|
||||||
|
def recalculate_status(
|
||||||
|
self,
|
||||||
|
test_snapshots: list[tuple[str, str | None]],
|
||||||
|
) -> TechniqueStatus:
|
||||||
|
"""Recompute ``status_global`` from a list of (state, detection_result) pairs.
|
||||||
|
|
||||||
|
Rules (v2):
|
||||||
|
1. No tests -> not_evaluated
|
||||||
|
2. All validated -> inspect detection results:
|
||||||
|
- All detected -> validated
|
||||||
|
- Any partially_detected -> partial
|
||||||
|
- Otherwise -> not_covered
|
||||||
|
3. Some validated, others in progress -> partial
|
||||||
|
4. All in intermediate states -> in_progress
|
||||||
|
|
||||||
|
Returns the new status (also set on the entity).
|
||||||
|
"""
|
||||||
|
tests = [
|
||||||
|
_TestSnapshot(
|
||||||
|
state=s if isinstance(s, TestState) else TestState(s),
|
||||||
|
detection_result=dr,
|
||||||
|
)
|
||||||
|
for s, dr in test_snapshots
|
||||||
|
]
|
||||||
|
|
||||||
|
if not tests:
|
||||||
|
self.status_global = TechniqueStatus.not_evaluated
|
||||||
|
elif all(t.state == TestState.validated for t in tests):
|
||||||
|
results = [t.detection_result for t in tests if t.detection_result]
|
||||||
|
if results and all(r == TestResult.detected or r == "detected" for r in results):
|
||||||
|
self.status_global = TechniqueStatus.validated
|
||||||
|
elif any(
|
||||||
|
r == TestResult.partially_detected or r == "partially_detected"
|
||||||
|
for r in results
|
||||||
|
):
|
||||||
|
self.status_global = TechniqueStatus.partial
|
||||||
|
else:
|
||||||
|
self.status_global = TechniqueStatus.not_covered
|
||||||
|
elif any(t.state == TestState.validated for t in tests):
|
||||||
|
self.status_global = TechniqueStatus.partial
|
||||||
|
else:
|
||||||
|
self.status_global = TechniqueStatus.in_progress
|
||||||
|
|
||||||
|
return self.status_global
|
||||||
|
|
||||||
|
def mark_reviewed(self) -> None:
|
||||||
|
"""Mark the technique as reviewed, clearing the review flag."""
|
||||||
|
self.review_required = False
|
||||||
|
self.last_review_date = datetime.utcnow()
|
||||||
|
|
||||||
|
def flag_for_review(self) -> None:
|
||||||
|
"""Flag the technique as needing review."""
|
||||||
|
self.review_required = True
|
||||||
37
backend/app/domain/enums.py
Normal file
37
backend/app/domain/enums.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Canonical domain enums for Aegis.
|
||||||
|
|
||||||
|
These enums represent core domain concepts and are the single source of
|
||||||
|
truth. ``models/enums.py`` re-exports them so that existing ORM code
|
||||||
|
continues to work without changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
class TechniqueStatus(str, enum.Enum):
|
||||||
|
not_evaluated = "not_evaluated"
|
||||||
|
in_progress = "in_progress"
|
||||||
|
validated = "validated"
|
||||||
|
partial = "partial"
|
||||||
|
not_covered = "not_covered"
|
||||||
|
review_required = "review_required"
|
||||||
|
|
||||||
|
|
||||||
|
class TestState(str, enum.Enum):
|
||||||
|
draft = "draft"
|
||||||
|
red_executing = "red_executing"
|
||||||
|
blue_evaluating = "blue_evaluating"
|
||||||
|
in_review = "in_review"
|
||||||
|
validated = "validated"
|
||||||
|
rejected = "rejected"
|
||||||
|
|
||||||
|
|
||||||
|
class TeamSide(str, enum.Enum):
|
||||||
|
red = "red"
|
||||||
|
blue = "blue"
|
||||||
|
|
||||||
|
|
||||||
|
class TestResult(str, enum.Enum):
|
||||||
|
detected = "detected"
|
||||||
|
not_detected = "not_detected"
|
||||||
|
partially_detected = "partially_detected"
|
||||||
0
backend/app/domain/ports/__init__.py
Normal file
0
backend/app/domain/ports/__init__.py
Normal file
4
backend/app/domain/ports/repositories/__init__.py
Normal file
4
backend/app/domain/ports/repositories/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.domain.ports.repositories.technique_repository import TechniqueRepository
|
||||||
|
from app.domain.ports.repositories.test_repository import TestRepository
|
||||||
|
|
||||||
|
__all__ = ["TechniqueRepository", "TestRepository"]
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
"""Port defining how the application accesses technique data.
|
||||||
|
|
||||||
|
This is a domain contract — implementations live in infrastructure/.
|
||||||
|
The domain layer NEVER imports the implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import NamedTuple, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
from app.domain.enums import TechniqueStatus
|
||||||
|
|
||||||
|
|
||||||
|
class TechniqueWithCounts(NamedTuple):
|
||||||
|
"""Pre-aggregated technique data for heatmap/scoring."""
|
||||||
|
|
||||||
|
entity: TechniqueEntity
|
||||||
|
test_count: int
|
||||||
|
validated_test_count: int
|
||||||
|
detection_rule_count: int
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class TechniqueRepository(Protocol):
|
||||||
|
"""Data access contract for techniques (one per aggregate root)."""
|
||||||
|
|
||||||
|
# -- Single-entity access ----------------------------------------------
|
||||||
|
|
||||||
|
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: ...
|
||||||
|
|
||||||
|
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: ...
|
||||||
|
|
||||||
|
# -- List access -------------------------------------------------------
|
||||||
|
|
||||||
|
def list_all(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tactic: str | None = None,
|
||||||
|
status: TechniqueStatus | None = None,
|
||||||
|
review_required: bool | None = None,
|
||||||
|
) -> list[TechniqueEntity]: ...
|
||||||
|
|
||||||
|
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: ...
|
||||||
|
|
||||||
|
# -- Batch queries (scoring/heatmap performance) -----------------------
|
||||||
|
|
||||||
|
def count_by_status(self) -> dict[TechniqueStatus, int]: ...
|
||||||
|
|
||||||
|
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: ...
|
||||||
|
|
||||||
|
# -- Mutations ---------------------------------------------------------
|
||||||
|
|
||||||
|
def save(self, technique: TechniqueEntity) -> TechniqueEntity: ...
|
||||||
|
|
||||||
|
def exists_by_mitre_id(self, mitre_id: str) -> bool: ...
|
||||||
52
backend/app/domain/ports/repositories/test_repository.py
Normal file
52
backend/app/domain/ports/repositories/test_repository.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Port defining how the application accesses test data.
|
||||||
|
|
||||||
|
This is a domain contract — implementations live in infrastructure/.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from app.domain.enums import TestState
|
||||||
|
|
||||||
|
|
||||||
|
class TestRepository(Protocol):
|
||||||
|
"""Data access contract for tests."""
|
||||||
|
|
||||||
|
# -- Single-entity access ----------------------------------------------
|
||||||
|
|
||||||
|
def find_by_id(self, test_id: uuid.UUID) -> object | None:
|
||||||
|
"""Return a Test ORM model by primary key, or None.
|
||||||
|
|
||||||
|
Returns the ORM model directly (not a domain entity) because
|
||||||
|
the TestEntity is constructed at the service layer via
|
||||||
|
``TestEntity.from_orm()``.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
# -- List access -------------------------------------------------------
|
||||||
|
|
||||||
|
def list_by_technique(self, technique_id: uuid.UUID) -> list[object]: ...
|
||||||
|
|
||||||
|
def list_by_state(self, state: TestState) -> list[object]: ...
|
||||||
|
|
||||||
|
def count_by_technique_and_state(
|
||||||
|
self,
|
||||||
|
technique_id: uuid.UUID,
|
||||||
|
) -> dict[TestState, int]:
|
||||||
|
"""Return test counts grouped by state for a single technique."""
|
||||||
|
...
|
||||||
|
|
||||||
|
# -- Batch queries -----------------------------------------------------
|
||||||
|
|
||||||
|
def get_states_and_results_for_technique(
|
||||||
|
self,
|
||||||
|
technique_id: uuid.UUID,
|
||||||
|
) -> list[tuple[str, str | None]]:
|
||||||
|
"""Return (state, detection_result) pairs for all tests of a technique.
|
||||||
|
|
||||||
|
Used by TechniqueEntity.recalculate_status() without loading full
|
||||||
|
test models.
|
||||||
|
"""
|
||||||
|
...
|
||||||
4
backend/app/domain/value_objects/__init__.py
Normal file
4
backend/app/domain/value_objects/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.domain.value_objects.mitre_id import MitreId
|
||||||
|
from app.domain.value_objects.scoring_weights import ScoringWeights
|
||||||
|
|
||||||
|
__all__ = ["MitreId", "ScoringWeights"]
|
||||||
51
backend/app/domain/value_objects/mitre_id.py
Normal file
51
backend/app/domain/value_objects/mitre_id.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""MitreId — validated MITRE ATT&CK technique identifier.
|
||||||
|
|
||||||
|
Immutable value object that ensures the identifier follows the ATT&CK
|
||||||
|
format: ``T`` followed by 4 digits, optionally a dot and 3 more digits
|
||||||
|
for sub-techniques (e.g. ``T1059``, ``T1059.001``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
_MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MitreId:
|
||||||
|
"""Validated MITRE ATT&CK technique identifier."""
|
||||||
|
|
||||||
|
value: str
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if not _MITRE_ID_RE.match(self.value):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid MITRE ATT&CK ID '{self.value}'. "
|
||||||
|
"Expected format: T1234 or T1234.001"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_subtechnique(self) -> bool:
|
||||||
|
return "." in self.value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parent_id(self) -> str | None:
|
||||||
|
"""Return the parent technique ID (e.g. T1059 for T1059.001)."""
|
||||||
|
if not self.is_subtechnique:
|
||||||
|
return None
|
||||||
|
return self.value.split(".")[0]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if isinstance(other, MitreId):
|
||||||
|
return self.value == other.value
|
||||||
|
if isinstance(other, str):
|
||||||
|
return self.value == other
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(self.value)
|
||||||
48
backend/app/domain/value_objects/scoring_weights.py
Normal file
48
backend/app/domain/value_objects/scoring_weights.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""ScoringWeights — validated immutable weight set for the scoring engine.
|
||||||
|
|
||||||
|
Enforces that all five weights are non-negative and sum to exactly 100.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class ScoringWeights:
|
||||||
|
"""Five scoring dimension weights that must sum to 100."""
|
||||||
|
|
||||||
|
tests: float
|
||||||
|
detection_rules: float
|
||||||
|
d3fend: float
|
||||||
|
freshness: float
|
||||||
|
platform_diversity: float
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
fields = [
|
||||||
|
self.tests,
|
||||||
|
self.detection_rules,
|
||||||
|
self.d3fend,
|
||||||
|
self.freshness,
|
||||||
|
self.platform_diversity,
|
||||||
|
]
|
||||||
|
for f in fields:
|
||||||
|
if f < 0:
|
||||||
|
raise ValueError("Scoring weights must be non-negative")
|
||||||
|
|
||||||
|
total = sum(fields)
|
||||||
|
if abs(total - 100) > 0.01:
|
||||||
|
raise ValueError(
|
||||||
|
f"Scoring weights must sum to 100, got {total}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls) -> ScoringWeights:
|
||||||
|
"""Return the default weight distribution."""
|
||||||
|
return cls(
|
||||||
|
tests=40.0,
|
||||||
|
detection_rules=25.0,
|
||||||
|
d3fend=15.0,
|
||||||
|
freshness=10.0,
|
||||||
|
platform_diversity=10.0,
|
||||||
|
)
|
||||||
@@ -1,30 +1,13 @@
|
|||||||
import enum
|
"""ORM-level re-exports of the canonical domain enums.
|
||||||
|
|
||||||
|
The single source of truth lives in ``app.domain.enums``. This module
|
||||||
|
re-exports every enum so that existing model and router code keeps
|
||||||
|
working with ``from app.models.enums import ...``.
|
||||||
|
"""
|
||||||
|
|
||||||
class TechniqueStatus(str, enum.Enum):
|
from app.domain.enums import ( # noqa: F401
|
||||||
not_evaluated = "not_evaluated"
|
TeamSide,
|
||||||
in_progress = "in_progress"
|
TechniqueStatus,
|
||||||
validated = "validated"
|
TestResult,
|
||||||
partial = "partial"
|
TestState,
|
||||||
not_covered = "not_covered"
|
)
|
||||||
review_required = "review_required"
|
|
||||||
|
|
||||||
|
|
||||||
class TestState(str, enum.Enum):
|
|
||||||
draft = "draft"
|
|
||||||
red_executing = "red_executing" # Red Team documenting attack
|
|
||||||
blue_evaluating = "blue_evaluating" # Blue Team evaluating detection
|
|
||||||
in_review = "in_review"
|
|
||||||
validated = "validated"
|
|
||||||
rejected = "rejected"
|
|
||||||
|
|
||||||
|
|
||||||
class TeamSide(str, enum.Enum):
|
|
||||||
red = "red"
|
|
||||||
blue = "blue"
|
|
||||||
|
|
||||||
|
|
||||||
class TestResult(str, enum.Enum):
|
|
||||||
detected = "detected"
|
|
||||||
not_detected = "not_detected"
|
|
||||||
partially_detected = "partially_detected"
|
|
||||||
|
|||||||
53
backend/tests/test_domain_enums.py
Normal file
53
backend/tests/test_domain_enums.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Tests verifying domain enums are canonical and properly re-exported."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
if backend_dir not in sys.path:
|
||||||
|
sys.path.insert(0, backend_dir)
|
||||||
|
|
||||||
|
from app.domain.enums import TechniqueStatus, TestState, TeamSide, TestResult
|
||||||
|
|
||||||
|
|
||||||
|
def test_technique_status_values():
|
||||||
|
assert TechniqueStatus.not_evaluated == "not_evaluated"
|
||||||
|
assert TechniqueStatus.validated == "validated"
|
||||||
|
assert TechniqueStatus.partial == "partial"
|
||||||
|
assert TechniqueStatus.in_progress == "in_progress"
|
||||||
|
assert TechniqueStatus.not_covered == "not_covered"
|
||||||
|
assert TechniqueStatus.review_required == "review_required"
|
||||||
|
|
||||||
|
|
||||||
|
def test_test_state_values():
|
||||||
|
assert TestState.draft == "draft"
|
||||||
|
assert TestState.red_executing == "red_executing"
|
||||||
|
assert TestState.blue_evaluating == "blue_evaluating"
|
||||||
|
assert TestState.in_review == "in_review"
|
||||||
|
assert TestState.validated == "validated"
|
||||||
|
assert TestState.rejected == "rejected"
|
||||||
|
|
||||||
|
|
||||||
|
def test_team_side_values():
|
||||||
|
assert TeamSide.red == "red"
|
||||||
|
assert TeamSide.blue == "blue"
|
||||||
|
|
||||||
|
|
||||||
|
def test_test_result_values():
|
||||||
|
assert TestResult.detected == "detected"
|
||||||
|
assert TestResult.not_detected == "not_detected"
|
||||||
|
assert TestResult.partially_detected == "partially_detected"
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_enums_reexport_is_same_class():
|
||||||
|
"""Verify models/enums.py re-exports the exact same class objects."""
|
||||||
|
from app.models.enums import (
|
||||||
|
TechniqueStatus as MS,
|
||||||
|
TestState as MTS,
|
||||||
|
TeamSide as MTeam,
|
||||||
|
TestResult as MTR,
|
||||||
|
)
|
||||||
|
assert MS is TechniqueStatus
|
||||||
|
assert MTS is TestState
|
||||||
|
assert MTeam is TeamSide
|
||||||
|
assert MTR is TestResult
|
||||||
168
backend/tests/test_technique_entity.py
Normal file
168
backend/tests/test_technique_entity.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Tests for TechniqueEntity — pure domain logic, no DB."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
if backend_dir not in sys.path:
|
||||||
|
sys.path.insert(0, backend_dir)
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
from app.domain.enums import TechniqueStatus
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _entity(**overrides) -> TechniqueEntity:
|
||||||
|
defaults = dict(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
mitre_id="T1059",
|
||||||
|
name="Command and Scripting Interpreter",
|
||||||
|
tactic="execution",
|
||||||
|
status_global=TechniqueStatus.not_evaluated,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
return TechniqueEntity(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_orm(**overrides) -> MagicMock:
|
||||||
|
m = MagicMock()
|
||||||
|
m.id = uuid.uuid4()
|
||||||
|
m.mitre_id = "T1059"
|
||||||
|
m.name = "Command and Scripting Interpreter"
|
||||||
|
m.tactic = "execution"
|
||||||
|
m.description = None
|
||||||
|
m.platforms = ["windows", "linux"]
|
||||||
|
m.is_subtechnique = False
|
||||||
|
m.parent_mitre_id = None
|
||||||
|
m.status_global = "not_evaluated"
|
||||||
|
m.review_required = False
|
||||||
|
m.last_review_date = None
|
||||||
|
for k, v in overrides.items():
|
||||||
|
setattr(m, k, v)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
# ── 1. Factory: create() ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreate:
|
||||||
|
|
||||||
|
def test_valid_technique(self):
|
||||||
|
e = TechniqueEntity.create(
|
||||||
|
mitre_id="T1059",
|
||||||
|
name="Command and Scripting Interpreter",
|
||||||
|
tactic="execution",
|
||||||
|
)
|
||||||
|
assert e.mitre_id == "T1059"
|
||||||
|
assert e.name == "Command and Scripting Interpreter"
|
||||||
|
assert e.status_global == TechniqueStatus.not_evaluated
|
||||||
|
assert not e.is_subtechnique
|
||||||
|
|
||||||
|
def test_valid_subtechnique(self):
|
||||||
|
e = TechniqueEntity.create(
|
||||||
|
mitre_id="T1059.001",
|
||||||
|
name="PowerShell",
|
||||||
|
tactic="execution",
|
||||||
|
)
|
||||||
|
assert e.is_subtechnique
|
||||||
|
assert e.parent_mitre_id == "T1059"
|
||||||
|
|
||||||
|
def test_invalid_mitre_id_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid MITRE"):
|
||||||
|
TechniqueEntity.create(mitre_id="INVALID", name="Bad", tactic="x")
|
||||||
|
|
||||||
|
def test_platforms_default_to_empty(self):
|
||||||
|
e = TechniqueEntity.create(mitre_id="T1059", name="Test")
|
||||||
|
assert e.platforms == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── 2. from_orm / apply_to ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrmRoundTrip:
|
||||||
|
|
||||||
|
def test_from_orm_basic(self):
|
||||||
|
orm = _fake_orm()
|
||||||
|
e = TechniqueEntity.from_orm(orm)
|
||||||
|
assert e.mitre_id == "T1059"
|
||||||
|
assert e.status_global == TechniqueStatus.not_evaluated
|
||||||
|
|
||||||
|
def test_from_orm_coerces_string_status(self):
|
||||||
|
orm = _fake_orm(status_global="validated")
|
||||||
|
e = TechniqueEntity.from_orm(orm)
|
||||||
|
assert e.status_global == TechniqueStatus.validated
|
||||||
|
|
||||||
|
def test_apply_to_updates_model(self):
|
||||||
|
orm = _fake_orm()
|
||||||
|
e = TechniqueEntity.from_orm(orm)
|
||||||
|
e.status_global = TechniqueStatus.validated
|
||||||
|
e.review_required = True
|
||||||
|
e.apply_to(orm)
|
||||||
|
assert orm.status_global == TechniqueStatus.validated
|
||||||
|
assert orm.review_required is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── 3. recalculate_status ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecalculateStatus:
|
||||||
|
|
||||||
|
def test_no_tests_gives_not_evaluated(self):
|
||||||
|
e = _entity()
|
||||||
|
result = e.recalculate_status([])
|
||||||
|
assert result == TechniqueStatus.not_evaluated
|
||||||
|
assert e.status_global == TechniqueStatus.not_evaluated
|
||||||
|
|
||||||
|
def test_all_validated_all_detected_gives_validated(self):
|
||||||
|
e = _entity()
|
||||||
|
tests = [("validated", "detected"), ("validated", "detected")]
|
||||||
|
assert e.recalculate_status(tests) == TechniqueStatus.validated
|
||||||
|
|
||||||
|
def test_all_validated_some_partially_gives_partial(self):
|
||||||
|
e = _entity()
|
||||||
|
tests = [("validated", "detected"), ("validated", "partially_detected")]
|
||||||
|
assert e.recalculate_status(tests) == TechniqueStatus.partial
|
||||||
|
|
||||||
|
def test_all_validated_none_detected_gives_not_covered(self):
|
||||||
|
e = _entity()
|
||||||
|
tests = [("validated", "not_detected"), ("validated", "not_detected")]
|
||||||
|
assert e.recalculate_status(tests) == TechniqueStatus.not_covered
|
||||||
|
|
||||||
|
def test_all_validated_no_results_gives_not_covered(self):
|
||||||
|
e = _entity()
|
||||||
|
tests = [("validated", None)]
|
||||||
|
assert e.recalculate_status(tests) == TechniqueStatus.not_covered
|
||||||
|
|
||||||
|
def test_mixed_validated_and_in_progress_gives_partial(self):
|
||||||
|
e = _entity()
|
||||||
|
tests = [("validated", "detected"), ("draft", None)]
|
||||||
|
assert e.recalculate_status(tests) == TechniqueStatus.partial
|
||||||
|
|
||||||
|
def test_all_in_progress_gives_in_progress(self):
|
||||||
|
e = _entity()
|
||||||
|
tests = [("draft", None), ("red_executing", None)]
|
||||||
|
assert e.recalculate_status(tests) == TechniqueStatus.in_progress
|
||||||
|
|
||||||
|
|
||||||
|
# ── 4. mark_reviewed / flag_for_review ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestReviewCycle:
|
||||||
|
|
||||||
|
def test_mark_reviewed_clears_flag(self):
|
||||||
|
e = _entity(review_required=True)
|
||||||
|
e.mark_reviewed()
|
||||||
|
assert e.review_required is False
|
||||||
|
assert e.last_review_date is not None
|
||||||
|
|
||||||
|
def test_flag_for_review(self):
|
||||||
|
e = _entity(review_required=False)
|
||||||
|
e.flag_for_review()
|
||||||
|
assert e.review_required is True
|
||||||
114
backend/tests/test_value_objects.py
Normal file
114
backend/tests/test_value_objects.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""Tests for domain value objects: MitreId and ScoringWeights."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
if backend_dir not in sys.path:
|
||||||
|
sys.path.insert(0, backend_dir)
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.domain.value_objects.mitre_id import MitreId
|
||||||
|
from app.domain.value_objects.scoring_weights import ScoringWeights
|
||||||
|
|
||||||
|
|
||||||
|
# ── MitreId ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMitreId:
|
||||||
|
|
||||||
|
def test_valid_technique(self):
|
||||||
|
mid = MitreId("T1059")
|
||||||
|
assert mid.value == "T1059"
|
||||||
|
assert str(mid) == "T1059"
|
||||||
|
assert not mid.is_subtechnique
|
||||||
|
assert mid.parent_id is None
|
||||||
|
|
||||||
|
def test_valid_subtechnique(self):
|
||||||
|
mid = MitreId("T1059.001")
|
||||||
|
assert mid.value == "T1059.001"
|
||||||
|
assert mid.is_subtechnique
|
||||||
|
assert mid.parent_id == "T1059"
|
||||||
|
|
||||||
|
def test_invalid_empty_string(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid MITRE"):
|
||||||
|
MitreId("")
|
||||||
|
|
||||||
|
def test_invalid_no_prefix(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid MITRE"):
|
||||||
|
MitreId("1059")
|
||||||
|
|
||||||
|
def test_invalid_wrong_prefix(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid MITRE"):
|
||||||
|
MitreId("A1059")
|
||||||
|
|
||||||
|
def test_invalid_too_few_digits(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid MITRE"):
|
||||||
|
MitreId("T105")
|
||||||
|
|
||||||
|
def test_invalid_subtechnique_format(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid MITRE"):
|
||||||
|
MitreId("T1059.01") # needs 3 digits after dot
|
||||||
|
|
||||||
|
def test_invalid_trailing_garbage(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid MITRE"):
|
||||||
|
MitreId("T1059.001.002")
|
||||||
|
|
||||||
|
def test_equality_with_same_mitre_id(self):
|
||||||
|
assert MitreId("T1059") == MitreId("T1059")
|
||||||
|
|
||||||
|
def test_equality_with_string(self):
|
||||||
|
assert MitreId("T1059") == "T1059"
|
||||||
|
|
||||||
|
def test_inequality(self):
|
||||||
|
assert MitreId("T1059") != MitreId("T1060")
|
||||||
|
|
||||||
|
def test_hashable(self):
|
||||||
|
s = {MitreId("T1059"), MitreId("T1059"), MitreId("T1060")}
|
||||||
|
assert len(s) == 2
|
||||||
|
|
||||||
|
def test_immutable(self):
|
||||||
|
mid = MitreId("T1059")
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
mid.value = "T1060"
|
||||||
|
|
||||||
|
|
||||||
|
# ── ScoringWeights ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoringWeights:
|
||||||
|
|
||||||
|
def test_valid_default(self):
|
||||||
|
w = ScoringWeights.default()
|
||||||
|
assert w.tests == 40.0
|
||||||
|
assert w.detection_rules == 25.0
|
||||||
|
assert w.d3fend == 15.0
|
||||||
|
assert w.freshness == 10.0
|
||||||
|
assert w.platform_diversity == 10.0
|
||||||
|
|
||||||
|
def test_valid_custom(self):
|
||||||
|
w = ScoringWeights(
|
||||||
|
tests=50, detection_rules=20, d3fend=10,
|
||||||
|
freshness=10, platform_diversity=10,
|
||||||
|
)
|
||||||
|
assert w.tests == 50
|
||||||
|
|
||||||
|
def test_invalid_sum_not_100(self):
|
||||||
|
with pytest.raises(ValueError, match="sum to 100"):
|
||||||
|
ScoringWeights(
|
||||||
|
tests=50, detection_rules=20, d3fend=10,
|
||||||
|
freshness=10, platform_diversity=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_invalid_negative_weight(self):
|
||||||
|
with pytest.raises(ValueError, match="non-negative"):
|
||||||
|
ScoringWeights(
|
||||||
|
tests=-10, detection_rules=40, d3fend=30,
|
||||||
|
freshness=20, platform_diversity=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_immutable(self):
|
||||||
|
w = ScoringWeights.default()
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
w.tests = 50
|
||||||
Reference in New Issue
Block a user