feat(domain): add domain layer foundation -- enums, value objects, TechniqueEntity, repository ports

This commit is contained in:
2026-02-18 19:10:31 +01:00
parent e651ef8a8c
commit 5c55e7c17f
14 changed files with 761 additions and 28 deletions

View File

@@ -0,0 +1,3 @@
from app.domain.entities.technique import TechniqueEntity
__all__ = ["TechniqueEntity"]

View File

@@ -0,0 +1,159 @@
"""TechniqueEntity — pure domain object for a MITRE ATT&CK technique.
Owns the status recalculation logic that was previously in
``status_service.py``. Has **no** dependency on FastAPI, SQLAlchemy,
or any infrastructure concern.
Usage::
entity = TechniqueEntity.from_orm(technique_orm_model)
entity.recalculate_status(test_states_and_results)
entity.mark_reviewed()
entity.apply_to(technique_orm_model)
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.domain.enums import TechniqueStatus, TestResult, TestState
from app.domain.value_objects.mitre_id import MitreId
@dataclass(frozen=True)
class _TestSnapshot:
"""Minimal read-only view of a test for status calculation."""
state: TestState
detection_result: str | None
@dataclass
class TechniqueEntity:
"""Pure domain representation of a MITRE ATT&CK technique."""
id: uuid.UUID
mitre_id: str
name: str
tactic: str | None = None
description: str | None = None
platforms: list[str] = field(default_factory=list)
is_subtechnique: bool = False
parent_mitre_id: str | None = None
status_global: TechniqueStatus = TechniqueStatus.not_evaluated
review_required: bool = False
last_review_date: datetime | None = None
# -- Factory -----------------------------------------------------------
@classmethod
def create(
cls,
*,
mitre_id: str,
name: str,
tactic: str | None = None,
description: str | None = None,
platforms: list[str] | None = None,
) -> TechniqueEntity:
"""Create a new technique, validating the MITRE ID format."""
validated_id = MitreId(mitre_id)
return cls(
id=uuid.uuid4(),
mitre_id=validated_id.value,
name=name,
tactic=tactic,
description=description,
platforms=platforms or [],
is_subtechnique=validated_id.is_subtechnique,
parent_mitre_id=validated_id.parent_id,
status_global=TechniqueStatus.not_evaluated,
)
@classmethod
def from_orm(cls, model: Any) -> TechniqueEntity:
"""Build a TechniqueEntity from a SQLAlchemy Technique model."""
raw_status = model.status_global
status = (
raw_status
if isinstance(raw_status, TechniqueStatus)
else TechniqueStatus(raw_status)
)
return cls(
id=model.id,
mitre_id=model.mitre_id,
name=model.name,
tactic=model.tactic,
description=model.description,
platforms=model.platforms or [],
is_subtechnique=model.is_subtechnique or False,
parent_mitre_id=model.parent_mitre_id,
status_global=status,
review_required=model.review_required or False,
last_review_date=model.last_review_date,
)
def apply_to(self, model: Any) -> None:
"""Copy mutable fields back onto the ORM model."""
model.status_global = self.status_global
model.review_required = self.review_required
model.last_review_date = self.last_review_date
# -- Business logic ----------------------------------------------------
def recalculate_status(
self,
test_snapshots: list[tuple[str, str | None]],
) -> TechniqueStatus:
"""Recompute ``status_global`` from a list of (state, detection_result) pairs.
Rules (v2):
1. No tests -> not_evaluated
2. All validated -> inspect detection results:
- All detected -> validated
- Any partially_detected -> partial
- Otherwise -> not_covered
3. Some validated, others in progress -> partial
4. All in intermediate states -> in_progress
Returns the new status (also set on the entity).
"""
tests = [
_TestSnapshot(
state=s if isinstance(s, TestState) else TestState(s),
detection_result=dr,
)
for s, dr in test_snapshots
]
if not tests:
self.status_global = TechniqueStatus.not_evaluated
elif all(t.state == TestState.validated for t in tests):
results = [t.detection_result for t in tests if t.detection_result]
if results and all(r == TestResult.detected or r == "detected" for r in results):
self.status_global = TechniqueStatus.validated
elif any(
r == TestResult.partially_detected or r == "partially_detected"
for r in results
):
self.status_global = TechniqueStatus.partial
else:
self.status_global = TechniqueStatus.not_covered
elif any(t.state == TestState.validated for t in tests):
self.status_global = TechniqueStatus.partial
else:
self.status_global = TechniqueStatus.in_progress
return self.status_global
def mark_reviewed(self) -> None:
"""Mark the technique as reviewed, clearing the review flag."""
self.review_required = False
self.last_review_date = datetime.utcnow()
def flag_for_review(self) -> None:
"""Flag the technique as needing review."""
self.review_required = True

View File

@@ -0,0 +1,37 @@
"""Canonical domain enums for Aegis.
These enums represent core domain concepts and are the single source of
truth. ``models/enums.py`` re-exports them so that existing ORM code
continues to work without changes.
"""
import enum
class TechniqueStatus(str, enum.Enum):
not_evaluated = "not_evaluated"
in_progress = "in_progress"
validated = "validated"
partial = "partial"
not_covered = "not_covered"
review_required = "review_required"
class TestState(str, enum.Enum):
draft = "draft"
red_executing = "red_executing"
blue_evaluating = "blue_evaluating"
in_review = "in_review"
validated = "validated"
rejected = "rejected"
class TeamSide(str, enum.Enum):
red = "red"
blue = "blue"
class TestResult(str, enum.Enum):
detected = "detected"
not_detected = "not_detected"
partially_detected = "partially_detected"

View File

View File

@@ -0,0 +1,4 @@
from app.domain.ports.repositories.technique_repository import TechniqueRepository
from app.domain.ports.repositories.test_repository import TestRepository
__all__ = ["TechniqueRepository", "TestRepository"]

View File

@@ -0,0 +1,57 @@
"""Port defining how the application accesses technique data.
This is a domain contract — implementations live in infrastructure/.
The domain layer NEVER imports the implementation.
"""
from __future__ import annotations
import uuid
from typing import NamedTuple, Protocol, runtime_checkable
from app.domain.entities.technique import TechniqueEntity
from app.domain.enums import TechniqueStatus
class TechniqueWithCounts(NamedTuple):
"""Pre-aggregated technique data for heatmap/scoring."""
entity: TechniqueEntity
test_count: int
validated_test_count: int
detection_rule_count: int
@runtime_checkable
class TechniqueRepository(Protocol):
"""Data access contract for techniques (one per aggregate root)."""
# -- Single-entity access ----------------------------------------------
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: ...
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: ...
# -- List access -------------------------------------------------------
def list_all(
self,
*,
tactic: str | None = None,
status: TechniqueStatus | None = None,
review_required: bool | None = None,
) -> list[TechniqueEntity]: ...
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: ...
# -- Batch queries (scoring/heatmap performance) -----------------------
def count_by_status(self) -> dict[TechniqueStatus, int]: ...
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: ...
# -- Mutations ---------------------------------------------------------
def save(self, technique: TechniqueEntity) -> TechniqueEntity: ...
def exists_by_mitre_id(self, mitre_id: str) -> bool: ...

View File

@@ -0,0 +1,52 @@
"""Port defining how the application accesses test data.
This is a domain contract — implementations live in infrastructure/.
"""
from __future__ import annotations
import uuid
from typing import Protocol, runtime_checkable
from app.domain.enums import TestState
class TestRepository(Protocol):
"""Data access contract for tests."""
# -- Single-entity access ----------------------------------------------
def find_by_id(self, test_id: uuid.UUID) -> object | None:
"""Return a Test ORM model by primary key, or None.
Returns the ORM model directly (not a domain entity) because
the TestEntity is constructed at the service layer via
``TestEntity.from_orm()``.
"""
...
# -- List access -------------------------------------------------------
def list_by_technique(self, technique_id: uuid.UUID) -> list[object]: ...
def list_by_state(self, state: TestState) -> list[object]: ...
def count_by_technique_and_state(
self,
technique_id: uuid.UUID,
) -> dict[TestState, int]:
"""Return test counts grouped by state for a single technique."""
...
# -- Batch queries -----------------------------------------------------
def get_states_and_results_for_technique(
self,
technique_id: uuid.UUID,
) -> list[tuple[str, str | None]]:
"""Return (state, detection_result) pairs for all tests of a technique.
Used by TechniqueEntity.recalculate_status() without loading full
test models.
"""
...

View File

@@ -0,0 +1,4 @@
from app.domain.value_objects.mitre_id import MitreId
from app.domain.value_objects.scoring_weights import ScoringWeights
__all__ = ["MitreId", "ScoringWeights"]

View File

@@ -0,0 +1,51 @@
"""MitreId — validated MITRE ATT&CK technique identifier.
Immutable value object that ensures the identifier follows the ATT&CK
format: ``T`` followed by 4 digits, optionally a dot and 3 more digits
for sub-techniques (e.g. ``T1059``, ``T1059.001``).
"""
from __future__ import annotations
import re
from dataclasses import dataclass
_MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
@dataclass(frozen=True, slots=True)
class MitreId:
"""Validated MITRE ATT&CK technique identifier."""
value: str
def __post_init__(self) -> None:
if not _MITRE_ID_RE.match(self.value):
raise ValueError(
f"Invalid MITRE ATT&CK ID '{self.value}'. "
"Expected format: T1234 or T1234.001"
)
@property
def is_subtechnique(self) -> bool:
return "." in self.value
@property
def parent_id(self) -> str | None:
"""Return the parent technique ID (e.g. T1059 for T1059.001)."""
if not self.is_subtechnique:
return None
return self.value.split(".")[0]
def __str__(self) -> str:
return self.value
def __eq__(self, other: object) -> bool:
if isinstance(other, MitreId):
return self.value == other.value
if isinstance(other, str):
return self.value == other
return NotImplemented
def __hash__(self) -> int:
return hash(self.value)

View File

@@ -0,0 +1,48 @@
"""ScoringWeights — validated immutable weight set for the scoring engine.
Enforces that all five weights are non-negative and sum to exactly 100.
"""
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class ScoringWeights:
"""Five scoring dimension weights that must sum to 100."""
tests: float
detection_rules: float
d3fend: float
freshness: float
platform_diversity: float
def __post_init__(self) -> None:
fields = [
self.tests,
self.detection_rules,
self.d3fend,
self.freshness,
self.platform_diversity,
]
for f in fields:
if f < 0:
raise ValueError("Scoring weights must be non-negative")
total = sum(fields)
if abs(total - 100) > 0.01:
raise ValueError(
f"Scoring weights must sum to 100, got {total}"
)
@classmethod
def default(cls) -> ScoringWeights:
"""Return the default weight distribution."""
return cls(
tests=40.0,
detection_rules=25.0,
d3fend=15.0,
freshness=10.0,
platform_diversity=10.0,
)