refactor(docs+comments): add Google-style docstrings and inline comments across backend
Task D — Google-style docstrings (Args/Returns) on every public function, method, and class across all 158 Python files in the backend. Zero ruff D violations (pydocstyle Google convention). Task E — Explanatory one-line comment before every code line (~11600 new comments). ruff check passes clean after isort re-sort. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Domain layer — entities, value objects, errors, and repository ports."""
|
||||
|
||||
@@ -1,18 +1,34 @@
|
||||
"""Domain entity classes representing core business objects."""
|
||||
# Import CampaignEntity from app.domain.entities.campaign
|
||||
from app.domain.entities.campaign import CampaignEntity
|
||||
|
||||
# Import from app.domain.entities.compliance
|
||||
from app.domain.entities.compliance import (
|
||||
ComplianceControlEntity,
|
||||
ComplianceFrameworkEntity,
|
||||
ControlCoverageStatus,
|
||||
)
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
|
||||
# Import ThreatActorEntity, ThreatActorTechniqueRef from app.domain.entities.threat_actor
|
||||
from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef
|
||||
|
||||
# Assign __all__ = [
|
||||
__all__ = [
|
||||
# Literal argument value
|
||||
"CampaignEntity",
|
||||
# Literal argument value
|
||||
"ComplianceControlEntity",
|
||||
# Literal argument value
|
||||
"ComplianceFrameworkEntity",
|
||||
# Literal argument value
|
||||
"ControlCoverageStatus",
|
||||
# Literal argument value
|
||||
"TechniqueEntity",
|
||||
# Literal argument value
|
||||
"ThreatActorEntity",
|
||||
# Literal argument value
|
||||
"ThreatActorTechniqueRef",
|
||||
]
|
||||
|
||||
@@ -3,33 +3,59 @@
|
||||
Pure domain logic — no framework imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import enum
|
||||
import enum
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Import TYPE_CHECKING from typing
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# Import BusinessRuleViolation, InvalidStateTransition from app.domain.errors
|
||||
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
||||
|
||||
# Check: TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
# Import Campaign as CampaignORM from app.models.campaign
|
||||
from app.models.campaign import Campaign as CampaignORM
|
||||
|
||||
|
||||
# Define class CampaignStatus
|
||||
class CampaignStatus(str, enum.Enum):
|
||||
"""Lifecycle states for a campaign."""
|
||||
|
||||
# Assign draft = "draft"
|
||||
draft = "draft"
|
||||
# Assign active = "active"
|
||||
active = "active"
|
||||
# Assign completed = "completed"
|
||||
completed = "completed"
|
||||
# Assign archived = "archived"
|
||||
archived = "archived"
|
||||
|
||||
|
||||
# Define class CampaignType
|
||||
class CampaignType(str, enum.Enum):
|
||||
"""Classification of the campaign's testing methodology."""
|
||||
|
||||
# Assign custom = "custom"
|
||||
custom = "custom"
|
||||
# Assign apt_emulation = "apt_emulation"
|
||||
apt_emulation = "apt_emulation"
|
||||
# Assign kill_chain = "kill_chain"
|
||||
kill_chain = "kill_chain"
|
||||
# Assign compliance = "compliance"
|
||||
compliance = "compliance"
|
||||
|
||||
|
||||
# Assign VALID_TRANSITIONS = {
|
||||
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
|
||||
CampaignStatus.draft: [CampaignStatus.active],
|
||||
CampaignStatus.active: [CampaignStatus.completed],
|
||||
@@ -38,69 +64,156 @@ VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
|
||||
}
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class CampaignEntity
|
||||
class CampaignEntity:
|
||||
"""Pure domain representation of a security testing campaign.
|
||||
|
||||
Owns all lifecycle state-machine logic for campaign activation,
|
||||
completion, and archival.
|
||||
"""
|
||||
|
||||
# name: str
|
||||
name: str
|
||||
# Assign type = CampaignType.custom
|
||||
type: CampaignType = CampaignType.custom
|
||||
# Assign status = CampaignStatus.draft
|
||||
status: CampaignStatus = CampaignStatus.draft
|
||||
# Assign id = None
|
||||
id: uuid.UUID | None = None
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign threat_actor_id = None
|
||||
threat_actor_id: uuid.UUID | None = None
|
||||
# Assign created_by = None
|
||||
created_by: uuid.UUID | None = None
|
||||
# Assign target_platform = None
|
||||
target_platform: str | None = None
|
||||
# Assign tags = field(default_factory=list)
|
||||
tags: list[str] = field(default_factory=list)
|
||||
# Assign test_count = 0
|
||||
test_count: int = 0
|
||||
|
||||
# Define function can_transition_to
|
||||
def can_transition_to(self, target: CampaignStatus) -> bool:
|
||||
"""Check whether transitioning from the current status to *target* is valid.
|
||||
|
||||
Args:
|
||||
target (CampaignStatus): The desired next status.
|
||||
|
||||
Returns:
|
||||
bool: True if the transition is allowed, False otherwise.
|
||||
"""
|
||||
# Return target in VALID_TRANSITIONS.get(self.status, [])
|
||||
return target in VALID_TRANSITIONS.get(self.status, [])
|
||||
|
||||
# Define function activate
|
||||
def activate(self) -> None:
|
||||
"""Transition the campaign from ``draft`` to ``active``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: not self.can_transition_to(CampaignStatus.active)
|
||||
if not self.can_transition_to(CampaignStatus.active):
|
||||
# Raise InvalidStateTransition
|
||||
raise InvalidStateTransition(
|
||||
self.status.value, CampaignStatus.active.value,
|
||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||
)
|
||||
# Check: self.test_count == 0
|
||||
if self.test_count == 0:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
# Literal argument value
|
||||
"Campaign must have at least one test to activate"
|
||||
)
|
||||
# Assign self.status = CampaignStatus.active
|
||||
self.status = CampaignStatus.active
|
||||
|
||||
# Define function complete
|
||||
def complete(self) -> None:
|
||||
"""Transition the campaign from ``active`` to ``completed``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: not self.can_transition_to(CampaignStatus.completed)
|
||||
if not self.can_transition_to(CampaignStatus.completed):
|
||||
# Raise InvalidStateTransition
|
||||
raise InvalidStateTransition(
|
||||
self.status.value, CampaignStatus.completed.value,
|
||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||
)
|
||||
# Assign self.status = CampaignStatus.completed
|
||||
self.status = CampaignStatus.completed
|
||||
|
||||
# Define function archive
|
||||
def archive(self) -> None:
|
||||
"""Transition the campaign from ``completed`` to ``archived``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: not self.can_transition_to(CampaignStatus.archived)
|
||||
if not self.can_transition_to(CampaignStatus.archived):
|
||||
# Raise InvalidStateTransition
|
||||
raise InvalidStateTransition(
|
||||
self.status.value, CampaignStatus.archived.value,
|
||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||
)
|
||||
# Assign self.status = CampaignStatus.archived
|
||||
self.status = CampaignStatus.archived
|
||||
|
||||
# Define function ensure_modifiable
|
||||
def ensure_modifiable(self) -> None:
|
||||
"""Raise BusinessRuleViolation if the campaign is not in a modifiable state.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: self.status not in (CampaignStatus.draft, CampaignStatus.active)
|
||||
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot modify campaign in '{self.status.value}' state"
|
||||
)
|
||||
|
||||
# Apply the @classmethod decorator
|
||||
@classmethod
|
||||
# Define function from_orm
|
||||
def from_orm(cls, orm: CampaignORM) -> CampaignEntity:
|
||||
"""Build a CampaignEntity from a SQLAlchemy Campaign model."""
|
||||
"""Build a CampaignEntity from a SQLAlchemy Campaign model.
|
||||
|
||||
Args:
|
||||
orm (CampaignORM): The SQLAlchemy Campaign ORM model instance.
|
||||
|
||||
Returns:
|
||||
CampaignEntity: A fully populated domain entity reflecting the ORM state.
|
||||
"""
|
||||
# Assign test_count = len(getattr(orm, "campaign_tests", None) or [])
|
||||
test_count = len(getattr(orm, "campaign_tests", None) or [])
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=orm.id,
|
||||
# Keyword argument: name
|
||||
name=orm.name,
|
||||
# Keyword argument: type
|
||||
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
|
||||
# Keyword argument: status
|
||||
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
|
||||
# Keyword argument: description
|
||||
description=orm.description,
|
||||
# Keyword argument: threat_actor_id
|
||||
threat_actor_id=orm.threat_actor_id,
|
||||
# Keyword argument: created_by
|
||||
created_by=orm.created_by,
|
||||
# Keyword argument: target_platform
|
||||
target_platform=orm.target_platform,
|
||||
# Keyword argument: tags
|
||||
tags=orm.tags or [],
|
||||
# Keyword argument: test_count
|
||||
test_count=test_count,
|
||||
)
|
||||
|
||||
@@ -3,68 +3,161 @@
|
||||
Pure domain logic — no framework imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import enum
|
||||
import enum
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
# Define class ControlCoverageStatus
|
||||
class ControlCoverageStatus(str, enum.Enum):
|
||||
"""Computed coverage level for a single compliance control."""
|
||||
|
||||
# Assign covered = "covered"
|
||||
covered = "covered"
|
||||
# Assign partially_covered = "partially_covered"
|
||||
partially_covered = "partially_covered"
|
||||
# Assign not_covered = "not_covered"
|
||||
not_covered = "not_covered"
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class ComplianceControlEntity
|
||||
class ComplianceControlEntity:
|
||||
"""Pure domain representation of a single compliance framework control.
|
||||
|
||||
Derives its coverage status from the technique statuses associated
|
||||
with it via the ``technique_statuses`` list.
|
||||
"""
|
||||
|
||||
# control_id: str
|
||||
control_id: str
|
||||
# title: str
|
||||
title: str
|
||||
# Assign id = None
|
||||
id: uuid.UUID | None = None
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign category = None
|
||||
category: str | None = None
|
||||
# Assign technique_statuses = field(default_factory=list)
|
||||
technique_statuses: list[str] = field(default_factory=list)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function coverage_status
|
||||
def coverage_status(self) -> ControlCoverageStatus:
|
||||
"""Compute the coverage status for this control based on linked technique statuses.
|
||||
|
||||
Returns:
|
||||
ControlCoverageStatus: ``covered`` when all techniques are covered,
|
||||
``partially_covered`` when at least one is covered, and
|
||||
``not_covered`` when none are covered or the control has no techniques.
|
||||
"""
|
||||
# Check: not self.technique_statuses
|
||||
if not self.technique_statuses:
|
||||
# Return ControlCoverageStatus.not_covered
|
||||
return ControlCoverageStatus.not_covered
|
||||
# Assign covered_statuses = {"validated", "partial"}
|
||||
covered_statuses = {"validated", "partial"}
|
||||
# Assign covered = [s for s in self.technique_statuses if s in covered_statuses]
|
||||
covered = [s for s in self.technique_statuses if s in covered_statuses]
|
||||
# Check: len(covered) == len(self.technique_statuses)
|
||||
if len(covered) == len(self.technique_statuses):
|
||||
# Return ControlCoverageStatus.covered
|
||||
return ControlCoverageStatus.covered
|
||||
# Alternative: len(covered) > 0
|
||||
elif len(covered) > 0:
|
||||
# Return ControlCoverageStatus.partially_covered
|
||||
return ControlCoverageStatus.partially_covered
|
||||
# Return ControlCoverageStatus.not_covered
|
||||
return ControlCoverageStatus.not_covered
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class ComplianceFrameworkEntity
|
||||
class ComplianceFrameworkEntity:
|
||||
"""Pure domain representation of a compliance framework (e.g. NIST 800-53, PCI-DSS).
|
||||
|
||||
Aggregates a collection of controls and provides aggregate coverage statistics.
|
||||
"""
|
||||
|
||||
# name: str
|
||||
name: str
|
||||
# Assign id = None
|
||||
id: uuid.UUID | None = None
|
||||
# Assign version = None
|
||||
version: str | None = None
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign is_active = True
|
||||
is_active: bool = True
|
||||
# Assign controls = field(default_factory=list)
|
||||
controls: list[ComplianceControlEntity] = field(default_factory=list)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function total_controls
|
||||
def total_controls(self) -> int:
|
||||
"""Return the total number of controls in this framework.
|
||||
|
||||
Returns:
|
||||
int: Count of all controls regardless of coverage status.
|
||||
"""
|
||||
# Return len(self.controls)
|
||||
return len(self.controls)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function covered_controls
|
||||
def covered_controls(self) -> int:
|
||||
"""Return the number of fully covered controls in this framework.
|
||||
|
||||
Returns:
|
||||
int: Count of controls with ``ControlCoverageStatus.covered`` status.
|
||||
"""
|
||||
# Return sum(
|
||||
return sum(
|
||||
# Literal argument value
|
||||
1 for c in self.controls
|
||||
if c.coverage_status == ControlCoverageStatus.covered
|
||||
)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function coverage_pct
|
||||
def coverage_pct(self) -> float:
|
||||
"""Return the percentage of controls that are fully covered.
|
||||
|
||||
Returns:
|
||||
float: A value from 0.0 to 100.0, rounded to one decimal place.
|
||||
Returns 0.0 when the framework has no controls.
|
||||
"""
|
||||
# Check: self.total_controls == 0
|
||||
if self.total_controls == 0:
|
||||
# Return 0.0
|
||||
return 0.0
|
||||
# Return round(self.covered_controls / self.total_controls * 100, 1)
|
||||
return round(self.covered_controls / self.total_controls * 100, 1)
|
||||
|
||||
# Define function get_gap_controls
|
||||
def get_gap_controls(self) -> list[ComplianceControlEntity]:
|
||||
"""Return controls that are not fully covered.
|
||||
|
||||
Returns:
|
||||
list[ComplianceControlEntity]: Controls with ``partially_covered`` or
|
||||
``not_covered`` status.
|
||||
"""
|
||||
# Return [
|
||||
return [
|
||||
c for c in self.controls
|
||||
if c.coverage_status != ControlCoverageStatus.covered
|
||||
|
||||
@@ -12,108 +12,211 @@ Usage::
|
||||
entity.apply_to(technique_orm_model)
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import TYPE_CHECKING from typing
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# Import TechniqueStatus, TestResult, TestState from app.domain.enums
|
||||
from app.domain.enums import TechniqueStatus, TestResult, TestState
|
||||
|
||||
# Import MitreId from app.domain.value_objects.mitre_id
|
||||
from app.domain.value_objects.mitre_id import MitreId
|
||||
|
||||
# Check: TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
# Import Technique as TechniqueORM from app.models.technique
|
||||
from app.models.technique import Technique as TechniqueORM
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass(frozen=True)
|
||||
# Define class _TestSnapshot
|
||||
class _TestSnapshot:
|
||||
"""Minimal read-only view of a test for status calculation."""
|
||||
|
||||
# state: TestState
|
||||
state: TestState
|
||||
# detection_result: str | None
|
||||
detection_result: str | None
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class TechniqueEntity
|
||||
class TechniqueEntity:
|
||||
"""Pure domain representation of a MITRE ATT&CK technique."""
|
||||
|
||||
# id: uuid.UUID
|
||||
id: uuid.UUID
|
||||
# mitre_id: str
|
||||
mitre_id: str
|
||||
# name: str
|
||||
name: str
|
||||
# Assign tactic = None
|
||||
tactic: str | None = None
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign platforms = field(default_factory=list)
|
||||
platforms: list[str] = field(default_factory=list)
|
||||
# Assign is_subtechnique = False
|
||||
is_subtechnique: bool = False
|
||||
# Assign parent_mitre_id = None
|
||||
parent_mitre_id: str | None = None
|
||||
# Assign status_global = TechniqueStatus.not_evaluated
|
||||
status_global: TechniqueStatus = TechniqueStatus.not_evaluated
|
||||
# Assign review_required = False
|
||||
review_required: bool = False
|
||||
# Assign last_review_date = None
|
||||
last_review_date: datetime | None = None
|
||||
# Assign mitre_version = None
|
||||
mitre_version: str | None = None
|
||||
# Assign mitre_last_modified = None
|
||||
mitre_last_modified: datetime | None = None
|
||||
|
||||
# -- Factory -----------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
# Define function create
|
||||
def create(
|
||||
cls,
|
||||
*,
|
||||
# Entry: mitre_id
|
||||
mitre_id: str,
|
||||
# Entry: name
|
||||
name: str,
|
||||
# Entry: tactic
|
||||
tactic: str | None = None,
|
||||
# Entry: description
|
||||
description: str | None = None,
|
||||
# Entry: platforms
|
||||
platforms: list[str] | None = None,
|
||||
) -> TechniqueEntity:
|
||||
"""Create a new technique, validating the MITRE ID format."""
|
||||
"""Create a new technique, validating the MITRE ID format.
|
||||
|
||||
Args:
|
||||
mitre_id (str): MITRE ATT&CK identifier (e.g. ``"T1059"`` or ``"T1059.001"``).
|
||||
name (str): Human-readable name of the technique.
|
||||
tactic (str | None): MITRE tactic category the technique belongs to.
|
||||
description (str | None): Optional free-text description.
|
||||
platforms (list[str] | None): List of platform strings the technique applies to.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity: A new entity with a freshly generated UUID and
|
||||
``status_global`` set to ``not_evaluated``.
|
||||
"""
|
||||
# Assign validated_id = MitreId(mitre_id)
|
||||
validated_id = MitreId(mitre_id)
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=uuid.uuid4(),
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=validated_id.value,
|
||||
# Keyword argument: name
|
||||
name=name,
|
||||
# Keyword argument: tactic
|
||||
tactic=tactic,
|
||||
# Keyword argument: description
|
||||
description=description,
|
||||
# Keyword argument: platforms
|
||||
platforms=platforms or [],
|
||||
# Keyword argument: is_subtechnique
|
||||
is_subtechnique=validated_id.is_subtechnique,
|
||||
# Keyword argument: parent_mitre_id
|
||||
parent_mitre_id=validated_id.parent_id,
|
||||
# Keyword argument: status_global
|
||||
status_global=TechniqueStatus.not_evaluated,
|
||||
)
|
||||
|
||||
# Apply the @classmethod decorator
|
||||
@classmethod
|
||||
# Define function from_orm
|
||||
def from_orm(cls, model: TechniqueORM) -> TechniqueEntity:
|
||||
"""Build a TechniqueEntity from a SQLAlchemy Technique model."""
|
||||
"""Build a TechniqueEntity from a SQLAlchemy Technique model.
|
||||
|
||||
Args:
|
||||
model (TechniqueORM): The ORM model instance to convert.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity: A fully populated domain entity reflecting the ORM state.
|
||||
"""
|
||||
# Assign raw_status = model.status_global
|
||||
raw_status = model.status_global
|
||||
# Check: raw_status is None
|
||||
if raw_status is None:
|
||||
# Assign status = TechniqueStatus.not_evaluated
|
||||
status = TechniqueStatus.not_evaluated
|
||||
# Alternative: isinstance(raw_status, TechniqueStatus)
|
||||
elif isinstance(raw_status, TechniqueStatus):
|
||||
# Assign status = raw_status
|
||||
status = raw_status
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign status = TechniqueStatus(raw_status)
|
||||
status = TechniqueStatus(raw_status)
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=model.id,
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=model.mitre_id,
|
||||
# Keyword argument: name
|
||||
name=model.name,
|
||||
# Keyword argument: tactic
|
||||
tactic=model.tactic,
|
||||
# Keyword argument: description
|
||||
description=model.description,
|
||||
# Keyword argument: platforms
|
||||
platforms=model.platforms or [],
|
||||
# Keyword argument: is_subtechnique
|
||||
is_subtechnique=model.is_subtechnique or False,
|
||||
# Keyword argument: parent_mitre_id
|
||||
parent_mitre_id=model.parent_mitre_id,
|
||||
# Keyword argument: status_global
|
||||
status_global=status,
|
||||
# Keyword argument: review_required
|
||||
review_required=model.review_required or False,
|
||||
# Keyword argument: last_review_date
|
||||
last_review_date=model.last_review_date,
|
||||
# Keyword argument: mitre_version
|
||||
mitre_version=getattr(model, "mitre_version", None),
|
||||
# Keyword argument: mitre_last_modified
|
||||
mitre_last_modified=getattr(model, "mitre_last_modified", None),
|
||||
)
|
||||
|
||||
# Define function apply_to
|
||||
def apply_to(self, model: TechniqueORM) -> None:
|
||||
"""Copy mutable fields back onto the ORM model."""
|
||||
"""Copy mutable fields back onto the ORM model.
|
||||
|
||||
Args:
|
||||
model (TechniqueORM): The ORM model to update in-place.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign model.status_global = self.status_global
|
||||
model.status_global = self.status_global
|
||||
# Assign model.review_required = self.review_required
|
||||
model.review_required = self.review_required
|
||||
# Assign model.last_review_date = self.last_review_date
|
||||
model.last_review_date = self.last_review_date
|
||||
|
||||
# -- Business logic ----------------------------------------------------
|
||||
|
||||
def recalculate_status(
|
||||
self,
|
||||
# Entry: test_snapshots
|
||||
test_snapshots: list[tuple[str, str | None]],
|
||||
) -> TechniqueStatus:
|
||||
"""Recompute ``status_global`` from a list of (state, detection_result) pairs.
|
||||
@@ -127,41 +230,81 @@ class TechniqueEntity:
|
||||
3. Some validated, others in progress -> partial
|
||||
4. All in intermediate states -> in_progress
|
||||
|
||||
Returns the new status (also set on the entity).
|
||||
Args:
|
||||
test_snapshots (list[tuple[str, str | None]]): Each element is a
|
||||
``(state, detection_result)`` pair where *state* is a
|
||||
:class:`TestState` value string and *detection_result* is a
|
||||
:class:`TestResult` value string or ``None``.
|
||||
|
||||
Returns:
|
||||
TechniqueStatus: The newly computed status, which is also stored on
|
||||
the entity's ``status_global`` field.
|
||||
"""
|
||||
# Assign tests = [
|
||||
tests = [
|
||||
_TestSnapshot(
|
||||
# Keyword argument: state
|
||||
state=s if isinstance(s, TestState) else TestState(s),
|
||||
# Keyword argument: detection_result
|
||||
detection_result=dr,
|
||||
)
|
||||
for s, dr in test_snapshots
|
||||
]
|
||||
|
||||
# Check: not tests
|
||||
if not tests:
|
||||
# Assign self.status_global = TechniqueStatus.not_evaluated
|
||||
self.status_global = TechniqueStatus.not_evaluated
|
||||
# Alternative: all(t.state == TestState.validated for t in tests)
|
||||
elif all(t.state == TestState.validated for t in tests):
|
||||
# Assign results = [t.detection_result for t in tests if t.detection_result]
|
||||
results = [t.detection_result for t in tests if t.detection_result]
|
||||
# Check: results and all(r == TestResult.detected or r == "detected" for r i...
|
||||
if results and all(r == TestResult.detected or r == "detected" for r in results):
|
||||
# Assign self.status_global = TechniqueStatus.validated
|
||||
self.status_global = TechniqueStatus.validated
|
||||
# elif any(
|
||||
elif any(
|
||||
# Keyword argument: r
|
||||
r == TestResult.partially_detected or r == "partially_detected"
|
||||
for r in results
|
||||
):
|
||||
# Assign self.status_global = TechniqueStatus.partial
|
||||
self.status_global = TechniqueStatus.partial
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign self.status_global = TechniqueStatus.not_covered
|
||||
self.status_global = TechniqueStatus.not_covered
|
||||
# Alternative: any(t.state == TestState.validated for t in tests)
|
||||
elif any(t.state == TestState.validated for t in tests):
|
||||
# Assign self.status_global = TechniqueStatus.partial
|
||||
self.status_global = TechniqueStatus.partial
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign self.status_global = TechniqueStatus.in_progress
|
||||
self.status_global = TechniqueStatus.in_progress
|
||||
|
||||
# Return self.status_global
|
||||
return self.status_global
|
||||
|
||||
# Define function mark_reviewed
|
||||
def mark_reviewed(self) -> None:
|
||||
"""Mark the technique as reviewed, clearing the review flag."""
|
||||
"""Mark the technique as reviewed, clearing the review flag.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self.review_required = False
|
||||
self.review_required = False
|
||||
# Assign self.last_review_date = datetime.utcnow()
|
||||
self.last_review_date = datetime.utcnow()
|
||||
|
||||
# Define function flag_for_review
|
||||
def flag_for_review(self) -> None:
|
||||
"""Flag the technique as needing review."""
|
||||
"""Flag the technique as needing review.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self.review_required = True
|
||||
self.review_required = True
|
||||
|
||||
@@ -3,97 +3,204 @@
|
||||
Pure domain logic — no framework imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Import TYPE_CHECKING from typing
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# Check: TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
# Import ThreatActor as ThreatActorORM from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActor as ThreatActorORM
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class ThreatActorTechniqueRef
|
||||
class ThreatActorTechniqueRef:
|
||||
"""Lightweight reference to a technique used by an actor."""
|
||||
|
||||
# technique_id: uuid.UUID
|
||||
technique_id: uuid.UUID
|
||||
# Assign mitre_id = None
|
||||
mitre_id: str | None = None
|
||||
# Assign name = None
|
||||
name: str | None = None
|
||||
# Assign status = None
|
||||
status: str | None = None
|
||||
# Assign usage_description = None
|
||||
usage_description: str | None = None
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class ThreatActorEntity
|
||||
class ThreatActorEntity:
|
||||
"""Pure domain representation of a MITRE ATT&CK threat actor (group).
|
||||
|
||||
Aggregates references to the techniques the actor is known to use and
|
||||
provides coverage analysis properties.
|
||||
"""
|
||||
|
||||
# name: str
|
||||
name: str
|
||||
# Assign id = None
|
||||
id: uuid.UUID | None = None
|
||||
# Assign mitre_id = None
|
||||
mitre_id: str | None = None
|
||||
# Assign aliases = field(default_factory=list)
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign country = None
|
||||
country: str | None = None
|
||||
# Assign target_sectors = field(default_factory=list)
|
||||
target_sectors: list[str] = field(default_factory=list)
|
||||
# Assign target_regions = field(default_factory=list)
|
||||
target_regions: list[str] = field(default_factory=list)
|
||||
# Assign motivation = None
|
||||
motivation: str | None = None
|
||||
# Assign sophistication = None
|
||||
sophistication: str | None = None
|
||||
# Assign first_seen = None
|
||||
first_seen: str | None = None
|
||||
# Assign last_seen = None
|
||||
last_seen: str | None = None
|
||||
# Assign is_active = True
|
||||
is_active: bool = True
|
||||
# Assign techniques = field(default_factory=list)
|
||||
techniques: list[ThreatActorTechniqueRef] = field(default_factory=list)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function technique_count
|
||||
def technique_count(self) -> int:
|
||||
"""Return the total number of techniques associated with this actor.
|
||||
|
||||
Returns:
|
||||
int: Count of technique references.
|
||||
"""
|
||||
# Return len(self.techniques)
|
||||
return len(self.techniques)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function covered_techniques
|
||||
def covered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||
"""Return technique references whose coverage status is ``validated`` or ``partial``.
|
||||
|
||||
Returns:
|
||||
list[ThreatActorTechniqueRef]: Subset of techniques considered covered.
|
||||
"""
|
||||
# Return [
|
||||
return [
|
||||
t for t in self.techniques
|
||||
if t.status in ("validated", "partial")
|
||||
]
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function uncovered_techniques
|
||||
def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||
"""Return technique references whose coverage status is neither ``validated`` nor ``partial``.
|
||||
|
||||
Returns:
|
||||
list[ThreatActorTechniqueRef]: Subset of techniques not yet covered.
|
||||
"""
|
||||
# Return [
|
||||
return [
|
||||
t for t in self.techniques
|
||||
if t.status not in ("validated", "partial")
|
||||
]
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function coverage_pct
|
||||
def coverage_pct(self) -> float:
|
||||
"""Return the percentage of the actor's techniques that are covered.
|
||||
|
||||
Returns:
|
||||
float: A value from 0.0 to 100.0, rounded to one decimal place.
|
||||
Returns 0.0 when the actor has no associated techniques.
|
||||
"""
|
||||
# Check: not self.techniques
|
||||
if not self.techniques:
|
||||
# Return 0.0
|
||||
return 0.0
|
||||
# Return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
|
||||
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
|
||||
|
||||
# Apply the @classmethod decorator
|
||||
@classmethod
|
||||
# Define function from_orm
|
||||
def from_orm(cls, orm: ThreatActorORM) -> ThreatActorEntity:
|
||||
"""Build a ThreatActorEntity from a SQLAlchemy ThreatActor model.
|
||||
|
||||
Args:
|
||||
orm (ThreatActorORM): The ORM model instance to convert.
|
||||
|
||||
Returns:
|
||||
ThreatActorEntity: A fully populated domain entity including
|
||||
technique references resolved from the ORM relationship.
|
||||
"""
|
||||
# Assign techs = []
|
||||
techs: list[ThreatActorTechniqueRef] = []
|
||||
# Iterate over getattr(orm, "techniques", None) or []
|
||||
for tat in getattr(orm, "techniques", None) or []:
|
||||
# Assign technique = getattr(tat, "technique", None)
|
||||
technique = getattr(tat, "technique", None)
|
||||
# Call techs.append()
|
||||
techs.append(ThreatActorTechniqueRef(
|
||||
# Keyword argument: technique_id
|
||||
technique_id=tat.technique_id,
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=getattr(technique, "mitre_id", None) if technique else None,
|
||||
# Keyword argument: name
|
||||
name=getattr(technique, "name", None) if technique else None,
|
||||
# Keyword argument: status
|
||||
status=(
|
||||
technique.status_global.value
|
||||
if technique and hasattr(technique.status_global, "value")
|
||||
else getattr(technique, "status_global", None) if technique else None
|
||||
),
|
||||
# Keyword argument: usage_description
|
||||
usage_description=tat.usage_description,
|
||||
))
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=orm.id,
|
||||
# Keyword argument: name
|
||||
name=orm.name,
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=orm.mitre_id,
|
||||
# Keyword argument: aliases
|
||||
aliases=orm.aliases or [],
|
||||
# Keyword argument: description
|
||||
description=orm.description,
|
||||
# Keyword argument: country
|
||||
country=orm.country,
|
||||
# Keyword argument: target_sectors
|
||||
target_sectors=orm.target_sectors or [],
|
||||
# Keyword argument: target_regions
|
||||
target_regions=orm.target_regions or [],
|
||||
# Keyword argument: motivation
|
||||
motivation=orm.motivation,
|
||||
# Keyword argument: sophistication
|
||||
sophistication=orm.sophistication,
|
||||
# Keyword argument: first_seen
|
||||
first_seen=orm.first_seen,
|
||||
# Keyword argument: last_seen
|
||||
last_seen=orm.last_seen,
|
||||
# Keyword argument: is_active
|
||||
is_active=orm.is_active if orm.is_active is not None else True,
|
||||
# Keyword argument: techniques
|
||||
techniques=techs,
|
||||
)
|
||||
|
||||
@@ -5,40 +5,77 @@ truth. ``models/enums.py`` re-exports them so that existing ORM code
|
||||
continues to work without changes.
|
||||
"""
|
||||
|
||||
# Import enum
|
||||
import enum
|
||||
|
||||
|
||||
# Define class TechniqueStatus
|
||||
class TechniqueStatus(str, enum.Enum):
|
||||
"""Coverage and evaluation status for a MITRE ATT&CK technique."""
|
||||
|
||||
# Assign not_evaluated = "not_evaluated"
|
||||
not_evaluated = "not_evaluated"
|
||||
# Assign in_progress = "in_progress"
|
||||
in_progress = "in_progress"
|
||||
# Assign validated = "validated"
|
||||
validated = "validated"
|
||||
# Assign partial = "partial"
|
||||
partial = "partial"
|
||||
# Assign not_covered = "not_covered"
|
||||
not_covered = "not_covered"
|
||||
# Assign review_required = "review_required"
|
||||
review_required = "review_required"
|
||||
|
||||
|
||||
# Define class TestState
|
||||
class TestState(str, enum.Enum):
|
||||
"""Lifecycle states in the security test state machine."""
|
||||
|
||||
# Assign draft = "draft"
|
||||
draft = "draft"
|
||||
# Assign red_executing = "red_executing"
|
||||
red_executing = "red_executing"
|
||||
# Assign blue_evaluating = "blue_evaluating"
|
||||
blue_evaluating = "blue_evaluating"
|
||||
# Assign in_review = "in_review"
|
||||
in_review = "in_review"
|
||||
# Assign validated = "validated"
|
||||
validated = "validated"
|
||||
# Assign rejected = "rejected"
|
||||
rejected = "rejected"
|
||||
|
||||
|
||||
# Define class TeamSide
|
||||
class TeamSide(str, enum.Enum):
|
||||
"""Identifies which team (red or blue) an action belongs to."""
|
||||
|
||||
# Assign red = "red"
|
||||
red = "red"
|
||||
# Assign blue = "blue"
|
||||
blue = "blue"
|
||||
|
||||
|
||||
# Define class TestResult
|
||||
class TestResult(str, enum.Enum):
|
||||
"""Outcome of a red-team test from a detection perspective."""
|
||||
|
||||
# Assign detected = "detected"
|
||||
detected = "detected"
|
||||
# Assign not_detected = "not_detected"
|
||||
not_detected = "not_detected"
|
||||
# Assign partially_detected = "partially_detected"
|
||||
partially_detected = "partially_detected"
|
||||
|
||||
|
||||
# Define class DataClassification
|
||||
class DataClassification(str, enum.Enum):
|
||||
"""Data sensitivity classification levels for compliance and retention policies."""
|
||||
|
||||
# Assign public = "public"
|
||||
public = "public"
|
||||
# Assign internal = "internal"
|
||||
internal = "internal"
|
||||
# Assign sensitive = "sensitive"
|
||||
sensitive = "sensitive"
|
||||
# Assign restricted = "restricted"
|
||||
restricted = "restricted"
|
||||
|
||||
@@ -9,15 +9,30 @@ Existing code that imports from ``app.domain.exceptions`` continues to
|
||||
work — that module re-exports everything defined here.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
# Define class DomainError
|
||||
class DomainError(Exception):
|
||||
"""Base for all domain errors."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, message: str, *, code: str = "DOMAIN_ERROR") -> None:
|
||||
"""Initialise the domain error with a human-readable message and error code.
|
||||
|
||||
Args:
|
||||
message (str): Human-readable description of the error.
|
||||
code (str): Machine-readable error code used by the HTTP error handler.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self.message = message
|
||||
self.message = message
|
||||
# Assign self.code = code
|
||||
self.code = code
|
||||
# Call super()
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@@ -27,18 +42,45 @@ class DomainError(Exception):
|
||||
class EntityNotFoundError(DomainError):
|
||||
"""A requested entity does not exist."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, entity: str, identifier: str) -> None:
|
||||
"""Initialise an entity-not-found error.
|
||||
|
||||
Args:
|
||||
entity (str): Name of the entity type that was not found (e.g. "Technique").
|
||||
identifier (str): The ID or key used in the failed lookup.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(f"{entity} not found: {identifier}", code="NOT_FOUND")
|
||||
# Assign self.entity = entity
|
||||
self.entity = entity
|
||||
# Assign self.identifier = identifier
|
||||
self.identifier = identifier
|
||||
|
||||
|
||||
# Define class DuplicateEntityError
|
||||
class DuplicateEntityError(DomainError):
|
||||
"""Creating an entity that already exists."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, entity: str, field: str, value: str) -> None:
|
||||
"""Initialise a duplicate-entity error.
|
||||
|
||||
Args:
|
||||
entity (str): Name of the entity type that already exists (e.g. "Campaign").
|
||||
field (str): Name of the field whose value conflicts (e.g. "name").
|
||||
value (str): The conflicting value that is already in use.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(
|
||||
f"{entity} with {field}='{value}' already exists",
|
||||
# Keyword argument: code
|
||||
code="DUPLICATE",
|
||||
)
|
||||
|
||||
@@ -49,18 +91,40 @@ class DuplicateEntityError(DomainError):
|
||||
class InvalidStateTransition(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
|
||||
"""A state-machine transition is not allowed."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(
|
||||
self,
|
||||
# Entry: current_state
|
||||
current_state: str,
|
||||
# Entry: target_state
|
||||
target_state: str,
|
||||
# Entry: valid_transitions
|
||||
valid_transitions: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Initialise an invalid state-transition error.
|
||||
|
||||
Args:
|
||||
current_state (str): The entity's present state (e.g. "draft").
|
||||
target_state (str): The state that was illegally requested.
|
||||
valid_transitions (list[str] | None): Allowed target states from the
|
||||
current state; included in the error message when provided.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign msg = f"Cannot transition from '{current_state}' to '{target_state}'"
|
||||
msg = f"Cannot transition from '{current_state}' to '{target_state}'"
|
||||
# Check: valid_transitions
|
||||
if valid_transitions:
|
||||
# Assign msg = f". Valid transitions: {valid_transitions}"
|
||||
msg += f". Valid transitions: {valid_transitions}"
|
||||
# Call super()
|
||||
super().__init__(msg, code="INVALID_TRANSITION")
|
||||
# Assign self.current_state = current_state
|
||||
self.current_state = current_state
|
||||
# Assign self.target_state = target_state
|
||||
self.target_state = target_state
|
||||
# Assign self.valid_transitions = valid_transitions or []
|
||||
self.valid_transitions = valid_transitions or []
|
||||
|
||||
|
||||
@@ -70,10 +134,21 @@ class InvalidStateTransition(DomainError): # noqa: N818 — DDD term, renaming
|
||||
class BusinessRuleViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
|
||||
"""An operation violates a business invariant."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, message: str) -> None:
|
||||
"""Initialise a business-rule violation error.
|
||||
|
||||
Args:
|
||||
message (str): Human-readable description of the violated rule.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(message, code="BUSINESS_RULE_VIOLATION")
|
||||
|
||||
|
||||
# Define class InvalidOperationError
|
||||
class InvalidOperationError(BusinessRuleViolation):
|
||||
"""An operation is invalid in the current context.
|
||||
|
||||
@@ -81,8 +156,19 @@ class InvalidOperationError(BusinessRuleViolation):
|
||||
:class:`BusinessRuleViolation` directly.
|
||||
"""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, message: str) -> None:
|
||||
"""Initialise an invalid-operation error.
|
||||
|
||||
Args:
|
||||
message (str): Human-readable description of why the operation is invalid.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(message)
|
||||
# Assign self.code = "INVALID_OPERATION"
|
||||
self.code = "INVALID_OPERATION"
|
||||
|
||||
|
||||
@@ -92,5 +178,15 @@ class InvalidOperationError(BusinessRuleViolation):
|
||||
class PermissionViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
|
||||
"""The user lacks permissions for an action."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, message: str = "Insufficient permissions") -> None:
|
||||
"""Initialise a permission-violation error.
|
||||
|
||||
Args:
|
||||
message (str): Human-readable description of the access denial.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(message, code="FORBIDDEN")
|
||||
|
||||
@@ -6,6 +6,7 @@ old import paths so that existing code keeps working without changes::
|
||||
from app.domain.exceptions import InvalidTransitionError # still works
|
||||
"""
|
||||
|
||||
# Import # noqa: F401 from app.domain.errors
|
||||
from app.domain.errors import ( # noqa: F401
|
||||
BusinessRuleViolation,
|
||||
DomainError,
|
||||
@@ -18,5 +19,7 @@ from app.domain.errors import ( # noqa: F401
|
||||
|
||||
# Legacy aliases — old name → new name
|
||||
DomainException = DomainError
|
||||
# Assign InvalidTransitionError = InvalidStateTransition
|
||||
InvalidTransitionError = InvalidStateTransition
|
||||
# Assign AuthorizationError = PermissionViolation
|
||||
AuthorizationError = PermissionViolation
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Abstract port interfaces that infrastructure adapters must implement."""
|
||||
|
||||
@@ -12,14 +12,19 @@ This satisfies the Open/Closed Principle — the system is open for new
|
||||
import sources without modifying existing code.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import Any, Protocol, runtime_checkable from typing
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
# Apply the @runtime_checkable decorator
|
||||
@runtime_checkable
|
||||
# Define class ImportService
|
||||
class ImportService(Protocol):
|
||||
"""Contract for any data-import operation.
|
||||
|
||||
@@ -27,62 +32,134 @@ class ImportService(Protocol):
|
||||
downloads, parses, and upserts records from an external source.
|
||||
"""
|
||||
|
||||
def __call__(self, db: Session) -> dict[str, Any]: ...
|
||||
# Define function __call__
|
||||
def __call__(self, db: Session) -> dict[str, Any]:
|
||||
"""Execute the import operation against the given database session.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy session to use for all DB operations.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Summary statistics for the import run (e.g. created,
|
||||
updated, skipped counts).
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
|
||||
# Define class ImportServiceEntry
|
||||
class ImportServiceEntry:
|
||||
"""Lazy-loading wrapper that resolves a module-level function on first call."""
|
||||
|
||||
# Assign __slots__ = ("_module_path", "_func_name", "_resolved")
|
||||
__slots__ = ("_module_path", "_func_name", "_resolved")
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, module_path: str, func_name: str) -> None:
|
||||
"""Initialise the lazy entry with the module path and function name to resolve later.
|
||||
|
||||
Args:
|
||||
module_path (str): Dotted Python module path, e.g.
|
||||
``"app.services.atomic_import_service"``.
|
||||
func_name (str): Name of the callable to import from *module_path*.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self._module_path = module_path
|
||||
self._module_path = module_path
|
||||
# Assign self._func_name = func_name
|
||||
self._func_name = func_name
|
||||
# Assign self._resolved = None
|
||||
self._resolved: ImportService | None = None
|
||||
|
||||
# Define function __call__
|
||||
def __call__(self, db: Session) -> dict[str, Any]:
|
||||
"""Resolve the import function on first call and invoke it with *db*.
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy session passed through to the underlying
|
||||
import function.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Import statistics returned by the underlying function
|
||||
(e.g. counts of created/updated/skipped records).
|
||||
"""
|
||||
# Check: self._resolved is None
|
||||
if self._resolved is None:
|
||||
# Import importlib
|
||||
import importlib
|
||||
# Assign mod = importlib.import_module(self._module_path)
|
||||
mod = importlib.import_module(self._module_path)
|
||||
# Assign self._resolved = getattr(mod, self._func_name)
|
||||
self._resolved = getattr(mod, self._func_name)
|
||||
# Return self._resolved(db)
|
||||
return self._resolved(db)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function source_info
|
||||
def source_info(self) -> str:
|
||||
"""Return a human-readable identifier for this import entry.
|
||||
|
||||
Returns:
|
||||
str: The fully qualified function reference as
|
||||
``"<module_path>.<func_name>"``.
|
||||
"""
|
||||
# Return f"{self._module_path}.{self._func_name}"
|
||||
return f"{self._module_path}.{self._func_name}"
|
||||
|
||||
|
||||
# Assign IMPORT_REGISTRY = {
|
||||
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
|
||||
# Literal argument value
|
||||
"atomic_red_team": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.atomic_import_service", "import_atomic_red_team",
|
||||
),
|
||||
# Literal argument value
|
||||
"sigma": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.sigma_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"lolbas": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.lolbas_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"gtfobins": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.lolbas_import_service", "sync_gtfobins",
|
||||
),
|
||||
# Literal argument value
|
||||
"caldera": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.caldera_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"elastic_rules": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.elastic_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"mitre_cti": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.threat_actor_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"d3fend": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.d3fend_import_service", "sync",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# Define function get_import_handler
|
||||
def get_import_handler(source_name: str) -> ImportServiceEntry | None:
|
||||
"""Look up the import handler for *source_name*.
|
||||
|
||||
Returns ``None`` when no handler is registered.
|
||||
"""
|
||||
# Return IMPORT_REGISTRY.get(source_name)
|
||||
return IMPORT_REGISTRY.get(source_name)
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
"""Abstract repository port interfaces for domain entity persistence."""
|
||||
# Import TechniqueRepository from app.domain.ports.repositories.technique_repository
|
||||
from app.domain.ports.repositories.technique_repository import TechniqueRepository
|
||||
|
||||
# Import TestRepository from app.domain.ports.repositories.test_repository
|
||||
from app.domain.ports.repositories.test_repository import TestRepository
|
||||
|
||||
# Assign __all__ = ["TechniqueRepository", "TestRepository"]
|
||||
__all__ = ["TechniqueRepository", "TestRepository"]
|
||||
|
||||
@@ -4,54 +4,157 @@ This is a domain contract — implementations live in infrastructure/.
|
||||
The domain layer NEVER imports the implementation.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import NamedTuple, Protocol, runtime_checkable from typing
|
||||
from typing import NamedTuple, Protocol, runtime_checkable
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
|
||||
# Import TechniqueStatus from app.domain.enums
|
||||
from app.domain.enums import TechniqueStatus
|
||||
|
||||
|
||||
# Define class TechniqueWithCounts
|
||||
class TechniqueWithCounts(NamedTuple):
|
||||
"""Pre-aggregated technique data for heatmap/scoring."""
|
||||
|
||||
# entity: TechniqueEntity
|
||||
entity: TechniqueEntity
|
||||
# test_count: int
|
||||
test_count: int
|
||||
# validated_test_count: int
|
||||
validated_test_count: int
|
||||
# detection_rule_count: int
|
||||
detection_rule_count: int
|
||||
|
||||
|
||||
# Apply the @runtime_checkable decorator
|
||||
@runtime_checkable
|
||||
# Define class TechniqueRepository
|
||||
class TechniqueRepository(Protocol):
|
||||
"""Data access contract for techniques (one per aggregate root)."""
|
||||
|
||||
# -- Single-entity access ----------------------------------------------
|
||||
|
||||
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: ...
|
||||
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
|
||||
"""Return the technique with the given primary key, or None if absent.
|
||||
|
||||
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: ...
|
||||
Args:
|
||||
technique_id (uuid.UUID): Primary key of the technique to look up.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity | None: The matching entity, or None if not found.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function find_by_mitre_id
|
||||
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
|
||||
"""Return the technique matching the given MITRE ATT&CK identifier, or None.
|
||||
|
||||
Args:
|
||||
mitre_id (str): MITRE ATT&CK ID (e.g. ``"T1059"`` or ``"T1059.001"``).
|
||||
|
||||
Returns:
|
||||
TechniqueEntity | None: The matching entity, or None if not found.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- List access -------------------------------------------------------
|
||||
|
||||
def list_all(
|
||||
self,
|
||||
*,
|
||||
# Entry: tactic
|
||||
tactic: str | None = None,
|
||||
# Entry: status
|
||||
status: TechniqueStatus | None = None,
|
||||
# Entry: review_required
|
||||
review_required: bool | None = None,
|
||||
) -> list[TechniqueEntity]: ...
|
||||
) -> list[TechniqueEntity]:
|
||||
"""Return all techniques, optionally filtered by tactic, status, or review flag.
|
||||
|
||||
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: ...
|
||||
Args:
|
||||
tactic (str | None): When provided, restrict results to this tactic category.
|
||||
status (TechniqueStatus | None): When provided, restrict results to this status.
|
||||
review_required (bool | None): When provided, restrict results to techniques
|
||||
whose ``review_required`` flag matches this value.
|
||||
|
||||
Returns:
|
||||
list[TechniqueEntity]: Matching technique entities; may be empty.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function list_by_ids
|
||||
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
|
||||
"""Return all techniques whose primary keys are in *ids*.
|
||||
|
||||
Args:
|
||||
ids (list[uuid.UUID]): List of technique UUIDs to retrieve.
|
||||
|
||||
Returns:
|
||||
list[TechniqueEntity]: Entities found for the supplied IDs; order
|
||||
is not guaranteed and missing IDs are silently omitted.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- Batch queries (scoring/heatmap performance) -----------------------
|
||||
|
||||
def count_by_status(self) -> dict[TechniqueStatus, int]: ...
|
||||
def count_by_status(self) -> dict[TechniqueStatus, int]:
|
||||
"""Return a count of techniques grouped by their global status.
|
||||
|
||||
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: ...
|
||||
Returns:
|
||||
dict[TechniqueStatus, int]: Mapping from each status value to the
|
||||
number of techniques in that state.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function find_all_with_test_counts
|
||||
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
|
||||
"""Return all techniques together with pre-aggregated test and rule counts.
|
||||
|
||||
Returns:
|
||||
list[TechniqueWithCounts]: Each element bundles a TechniqueEntity
|
||||
with its total, validated, and detection-rule counts for use
|
||||
in heatmap and scoring calculations.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- Mutations ---------------------------------------------------------
|
||||
|
||||
def save(self, technique: TechniqueEntity) -> TechniqueEntity: ...
|
||||
def save(self, technique: TechniqueEntity) -> TechniqueEntity:
|
||||
"""Persist a technique entity and return the saved state.
|
||||
|
||||
def exists_by_mitre_id(self, mitre_id: str) -> bool: ...
|
||||
Args:
|
||||
technique (TechniqueEntity): The entity to create or update.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity: The persisted entity, potentially with updated
|
||||
fields (e.g. server-side timestamps).
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function exists_by_mitre_id
|
||||
def exists_by_mitre_id(self, mitre_id: str) -> bool:
|
||||
"""Return True if a technique with the given MITRE ID exists in the repository.
|
||||
|
||||
Args:
|
||||
mitre_id (str): MITRE ATT&CK ID to check (e.g. ``"T1059"``).
|
||||
|
||||
Returns:
|
||||
bool: True if a matching technique is found, False otherwise.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
@@ -3,14 +3,20 @@
|
||||
This is a domain contract — implementations live in infrastructure/.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Protocol from typing
|
||||
from typing import Protocol
|
||||
|
||||
# Import TestState from app.domain.enums
|
||||
from app.domain.enums import TestState
|
||||
|
||||
|
||||
# Define class TestRepository
|
||||
class TestRepository(Protocol):
|
||||
"""Data access contract for tests."""
|
||||
|
||||
@@ -22,31 +28,81 @@ class TestRepository(Protocol):
|
||||
Returns the ORM model directly (not a domain entity) because
|
||||
the TestEntity is constructed at the service layer via
|
||||
``TestEntity.from_orm()``.
|
||||
|
||||
Args:
|
||||
test_id (uuid.UUID): Primary key of the test to look up.
|
||||
|
||||
Returns:
|
||||
object | None: The ORM model instance, or None if not found.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- List access -------------------------------------------------------
|
||||
|
||||
def list_by_technique(self, technique_id: uuid.UUID) -> list[object]: ...
|
||||
def list_by_technique(self, technique_id: uuid.UUID) -> list[object]:
|
||||
"""Return all test ORM models associated with the given technique.
|
||||
|
||||
def list_by_state(self, state: TestState) -> list[object]: ...
|
||||
Args:
|
||||
technique_id (uuid.UUID): Primary key of the technique whose tests to retrieve.
|
||||
|
||||
Returns:
|
||||
list[object]: ORM model instances for all tests linked to this technique.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function list_by_state
|
||||
def list_by_state(self, state: TestState) -> list[object]:
|
||||
"""Return all test ORM models in the given state.
|
||||
|
||||
Args:
|
||||
state (TestState): The state to filter tests by.
|
||||
|
||||
Returns:
|
||||
list[object]: ORM model instances for all tests currently in *state*.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function count_by_technique_and_state
|
||||
def count_by_technique_and_state(
|
||||
self,
|
||||
# Entry: technique_id
|
||||
technique_id: uuid.UUID,
|
||||
) -> dict[TestState, int]:
|
||||
"""Return test counts grouped by state for a single technique."""
|
||||
"""Return test counts grouped by state for a single technique.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): Primary key of the technique whose test
|
||||
counts to aggregate.
|
||||
|
||||
Returns:
|
||||
dict[TestState, int]: Mapping from each test state to the number of
|
||||
tests in that state for the given technique.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- Batch queries -----------------------------------------------------
|
||||
|
||||
def get_states_and_results_for_technique(
|
||||
self,
|
||||
# Entry: technique_id
|
||||
technique_id: uuid.UUID,
|
||||
) -> list[tuple[str, str | None]]:
|
||||
"""Return (state, detection_result) pairs for all tests of a technique.
|
||||
|
||||
Used by TechniqueEntity.recalculate_status() without loading full
|
||||
test models.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): Primary key of the technique whose test
|
||||
data to retrieve.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, str | None]]: Each tuple contains the test state
|
||||
string and the detection result string (or None if not yet set).
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
@@ -20,35 +20,57 @@ After mutations, the service layer copies ``entity.changes`` back onto
|
||||
the ORM model and persists via Unit of Work.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import enum
|
||||
import enum
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import TYPE_CHECKING, Any from typing
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
# Import from app.domain.errors
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
InvalidOperationError,
|
||||
InvalidStateTransition,
|
||||
)
|
||||
|
||||
# Check: TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
# Import Test as TestORM from app.models.test
|
||||
from app.models.test import Test as TestORM
|
||||
|
||||
# ── Value objects ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestState(str, enum.Enum):
|
||||
"""Ordered lifecycle states for a security test."""
|
||||
|
||||
# Assign draft = "draft"
|
||||
draft = "draft"
|
||||
# Assign red_executing = "red_executing"
|
||||
red_executing = "red_executing"
|
||||
# Assign blue_evaluating = "blue_evaluating"
|
||||
blue_evaluating = "blue_evaluating"
|
||||
# Assign in_review = "in_review"
|
||||
in_review = "in_review"
|
||||
# Assign validated = "validated"
|
||||
validated = "validated"
|
||||
# Assign rejected = "rejected"
|
||||
rejected = "rejected"
|
||||
|
||||
|
||||
# Assign VALID_TRANSITIONS = {
|
||||
VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
|
||||
TestState.draft: [TestState.red_executing],
|
||||
TestState.red_executing: [TestState.blue_evaluating],
|
||||
@@ -58,6 +80,7 @@ VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
|
||||
TestState.validated: [],
|
||||
}
|
||||
|
||||
# Assign _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating})
|
||||
_PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating})
|
||||
|
||||
|
||||
@@ -65,8 +88,13 @@ _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
# Define class DomainEvent
|
||||
class DomainEvent:
|
||||
"""Immutable record of a domain-level event emitted by the test entity."""
|
||||
|
||||
# name: str
|
||||
name: str
|
||||
# Assign payload = field(default_factory=dict)
|
||||
payload: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -74,30 +102,44 @@ class DomainEvent:
|
||||
|
||||
|
||||
@dataclass
|
||||
# Define class TestEntity
|
||||
class TestEntity:
|
||||
"""Pure domain representation of a security test."""
|
||||
|
||||
# id: uuid.UUID
|
||||
id: uuid.UUID
|
||||
# state: TestState
|
||||
state: TestState
|
||||
|
||||
# Red validation
|
||||
red_validation_status: str | None = None
|
||||
# Assign red_validated_by = None
|
||||
red_validated_by: uuid.UUID | None = None
|
||||
# Assign red_validated_at = None
|
||||
red_validated_at: datetime | None = None
|
||||
# Assign red_validation_notes = None
|
||||
red_validation_notes: str | None = None
|
||||
|
||||
# Blue validation
|
||||
blue_validation_status: str | None = None
|
||||
# Assign blue_validated_by = None
|
||||
blue_validated_by: uuid.UUID | None = None
|
||||
# Assign blue_validated_at = None
|
||||
blue_validated_at: datetime | None = None
|
||||
# Assign blue_validation_notes = None
|
||||
blue_validation_notes: str | None = None
|
||||
|
||||
# Phase timing
|
||||
execution_date: datetime | None = None
|
||||
# Assign red_started_at = None
|
||||
red_started_at: datetime | None = None
|
||||
# Assign blue_started_at = None
|
||||
blue_started_at: datetime | None = None
|
||||
# Assign paused_at = None
|
||||
paused_at: datetime | None = None
|
||||
# Assign red_paused_seconds = 0
|
||||
red_paused_seconds: int = 0
|
||||
# Assign blue_paused_seconds = 0
|
||||
blue_paused_seconds: int = 0
|
||||
|
||||
# Internal bookkeeping (not persisted as-is)
|
||||
@@ -106,58 +148,134 @@ class TestEntity:
|
||||
# -- Factory --------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
# Define function from_orm
|
||||
def from_orm(cls, model: TestORM) -> TestEntity:
|
||||
"""Build a TestEntity from a SQLAlchemy ``Test`` model instance."""
|
||||
"""Build a TestEntity from a SQLAlchemy ``Test`` model instance.
|
||||
|
||||
Args:
|
||||
model (TestORM): The ORM model whose fields will be copied into the entity.
|
||||
|
||||
Returns:
|
||||
TestEntity: A fully populated domain entity reflecting the ORM state.
|
||||
"""
|
||||
# Assign raw_state = model.state
|
||||
raw_state = model.state
|
||||
# Assign state = raw_state if isinstance(raw_state, TestState) else TestState(raw_st...
|
||||
state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state)
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=model.id,
|
||||
# Keyword argument: state
|
||||
state=state,
|
||||
# Keyword argument: red_validation_status
|
||||
red_validation_status=model.red_validation_status,
|
||||
# Keyword argument: red_validated_by
|
||||
red_validated_by=model.red_validated_by,
|
||||
# Keyword argument: red_validated_at
|
||||
red_validated_at=model.red_validated_at,
|
||||
# Keyword argument: red_validation_notes
|
||||
red_validation_notes=model.red_validation_notes,
|
||||
# Keyword argument: blue_validation_status
|
||||
blue_validation_status=model.blue_validation_status,
|
||||
# Keyword argument: blue_validated_by
|
||||
blue_validated_by=model.blue_validated_by,
|
||||
# Keyword argument: blue_validated_at
|
||||
blue_validated_at=model.blue_validated_at,
|
||||
# Keyword argument: blue_validation_notes
|
||||
blue_validation_notes=model.blue_validation_notes,
|
||||
# Keyword argument: execution_date
|
||||
execution_date=model.execution_date,
|
||||
# Keyword argument: red_started_at
|
||||
red_started_at=model.red_started_at,
|
||||
# Keyword argument: blue_started_at
|
||||
blue_started_at=model.blue_started_at,
|
||||
# Keyword argument: paused_at
|
||||
paused_at=model.paused_at,
|
||||
# Keyword argument: red_paused_seconds
|
||||
red_paused_seconds=model.red_paused_seconds or 0,
|
||||
# Keyword argument: blue_paused_seconds
|
||||
blue_paused_seconds=model.blue_paused_seconds or 0,
|
||||
)
|
||||
|
||||
# Define function apply_to
|
||||
def apply_to(self, model: TestORM) -> None:
|
||||
"""Copy the entity's mutable fields back onto the ORM model."""
|
||||
"""Copy the entity's mutable fields back onto the ORM model.
|
||||
|
||||
Args:
|
||||
model (TestORM): The ORM model to update in-place.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign model.state = self.state
|
||||
model.state = self.state
|
||||
# Assign model.red_validation_status = self.red_validation_status
|
||||
model.red_validation_status = self.red_validation_status
|
||||
# Assign model.red_validated_by = self.red_validated_by
|
||||
model.red_validated_by = self.red_validated_by
|
||||
# Assign model.red_validated_at = self.red_validated_at
|
||||
model.red_validated_at = self.red_validated_at
|
||||
# Assign model.red_validation_notes = self.red_validation_notes
|
||||
model.red_validation_notes = self.red_validation_notes
|
||||
# Assign model.blue_validation_status = self.blue_validation_status
|
||||
model.blue_validation_status = self.blue_validation_status
|
||||
# Assign model.blue_validated_by = self.blue_validated_by
|
||||
model.blue_validated_by = self.blue_validated_by
|
||||
# Assign model.blue_validated_at = self.blue_validated_at
|
||||
model.blue_validated_at = self.blue_validated_at
|
||||
# Assign model.blue_validation_notes = self.blue_validation_notes
|
||||
model.blue_validation_notes = self.blue_validation_notes
|
||||
# Assign model.execution_date = self.execution_date
|
||||
model.execution_date = self.execution_date
|
||||
# Assign model.red_started_at = self.red_started_at
|
||||
model.red_started_at = self.red_started_at
|
||||
# Assign model.blue_started_at = self.blue_started_at
|
||||
model.blue_started_at = self.blue_started_at
|
||||
# Assign model.paused_at = self.paused_at
|
||||
model.paused_at = self.paused_at
|
||||
# Assign model.red_paused_seconds = self.red_paused_seconds
|
||||
model.red_paused_seconds = self.red_paused_seconds
|
||||
# Assign model.blue_paused_seconds = self.blue_paused_seconds
|
||||
model.blue_paused_seconds = self.blue_paused_seconds
|
||||
|
||||
# -- Query helpers --------------------------------------------------
|
||||
|
||||
@property
|
||||
# Define function events
|
||||
def events(self) -> list[DomainEvent]:
|
||||
"""Return a snapshot of all domain events raised on this entity.
|
||||
|
||||
Returns:
|
||||
list[DomainEvent]: Ordered list of events emitted since the entity
|
||||
was constructed or last cleared.
|
||||
"""
|
||||
# Return list(self._events)
|
||||
return list(self._events)
|
||||
|
||||
# Define function can_transition
|
||||
def can_transition(self, target: TestState) -> bool:
|
||||
"""Check whether a transition from the current state to *target* is valid.
|
||||
|
||||
Args:
|
||||
target (TestState): The desired next state.
|
||||
|
||||
Returns:
|
||||
bool: True if the transition is allowed, False otherwise.
|
||||
"""
|
||||
# Return target in VALID_TRANSITIONS.get(self.state, [])
|
||||
return target in VALID_TRANSITIONS.get(self.state, [])
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function is_terminal
|
||||
def is_terminal(self) -> bool:
|
||||
"""Return True if the test has reached its final (validated) state.
|
||||
|
||||
Returns:
|
||||
bool: True when state is ``validated``, False for all other states.
|
||||
"""
|
||||
# Return self.state == TestState.validated
|
||||
return self.state == TestState.validated
|
||||
|
||||
# -- Core transition ------------------------------------------------
|
||||
@@ -171,148 +289,305 @@ class TestEntity:
|
||||
Returns the *previous* state value as a plain string.
|
||||
|
||||
Raises :class:`InvalidStateTransition` when the move is illegal.
|
||||
|
||||
Args:
|
||||
target (TestState | str): The desired next state, as an enum member
|
||||
or its string equivalent.
|
||||
|
||||
Returns:
|
||||
str: The previous state value before the transition.
|
||||
"""
|
||||
# Assign value = target.value if hasattr(target, "value") else str(target)
|
||||
value = target.value if hasattr(target, "value") else str(target)
|
||||
# Assign resolved = target if isinstance(target, TestState) else TestState(value)
|
||||
resolved = target if isinstance(target, TestState) else TestState(value)
|
||||
# Return self._transition(resolved)
|
||||
return self._transition(resolved)
|
||||
|
||||
# Define function _transition
|
||||
def _transition(self, target: TestState) -> str:
|
||||
"""Internal: validate and apply; return previous state value."""
|
||||
"""Validate and apply a state transition, returning the previous state value.
|
||||
|
||||
Args:
|
||||
target (TestState): The desired next state enum member.
|
||||
|
||||
Returns:
|
||||
str: The previous state value before the transition was applied.
|
||||
"""
|
||||
# Check: not self.can_transition(target)
|
||||
if not self.can_transition(target):
|
||||
# Assign valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])]
|
||||
valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])]
|
||||
# Raise InvalidStateTransition
|
||||
raise InvalidStateTransition(
|
||||
# Keyword argument: current_state
|
||||
current_state=self.state.value,
|
||||
# Keyword argument: target_state
|
||||
target_state=target.value,
|
||||
# Keyword argument: valid_transitions
|
||||
valid_transitions=valid,
|
||||
)
|
||||
# Assign previous = self.state.value
|
||||
previous = self.state.value
|
||||
# Assign self.state = target
|
||||
self.state = target
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent(
|
||||
# Literal argument value
|
||||
"state_changed",
|
||||
{"previous": previous, "new": target.value},
|
||||
))
|
||||
# Return previous
|
||||
return previous
|
||||
|
||||
# -- Lifecycle commands --------------------------------------------
|
||||
|
||||
def start_execution(self) -> None:
|
||||
"""``draft`` -> ``red_executing``."""
|
||||
"""Transition the test from ``draft`` to ``red_executing``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._transition()
|
||||
self._transition(TestState.red_executing)
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign self.execution_date = now
|
||||
self.execution_date = now
|
||||
# Assign self.red_started_at = now
|
||||
self.red_started_at = now
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("execution_started"))
|
||||
|
||||
# Define function submit_red_evidence
|
||||
def submit_red_evidence(self) -> int:
|
||||
"""``red_executing`` -> ``blue_evaluating``.
|
||||
"""Transition the test from ``red_executing`` to ``blue_evaluating``.
|
||||
|
||||
Auto-resumes if paused. Returns paused seconds accumulated
|
||||
during this phase (for worklog calculation).
|
||||
|
||||
Returns:
|
||||
int: Total seconds the red phase was paused.
|
||||
"""
|
||||
# Assign paused_extra = self._auto_resume()
|
||||
paused_extra = self._auto_resume()
|
||||
# Call self._transition()
|
||||
self._transition(TestState.blue_evaluating)
|
||||
# Assign total_paused = self.red_paused_seconds + paused_extra
|
||||
total_paused = self.red_paused_seconds + paused_extra
|
||||
# Assign self.blue_started_at = datetime.utcnow()
|
||||
self.blue_started_at = datetime.utcnow()
|
||||
# Assign self.blue_paused_seconds = 0
|
||||
self.blue_paused_seconds = 0
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent(
|
||||
# Literal argument value
|
||||
"red_evidence_submitted",
|
||||
{"red_paused_seconds": total_paused},
|
||||
))
|
||||
# Return total_paused
|
||||
return total_paused
|
||||
|
||||
# Define function submit_blue_evidence
|
||||
def submit_blue_evidence(self) -> int:
|
||||
"""``blue_evaluating`` -> ``in_review``.
|
||||
"""Transition the test from ``blue_evaluating`` to ``in_review``.
|
||||
|
||||
Auto-resumes if paused. Returns paused seconds accumulated
|
||||
during this phase (for worklog calculation).
|
||||
|
||||
Returns:
|
||||
int: Total seconds the blue phase was paused.
|
||||
"""
|
||||
# Assign paused_extra = self._auto_resume()
|
||||
paused_extra = self._auto_resume()
|
||||
# Call self._transition()
|
||||
self._transition(TestState.in_review)
|
||||
# Assign total_paused = self.blue_paused_seconds + paused_extra
|
||||
total_paused = self.blue_paused_seconds + paused_extra
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent(
|
||||
# Literal argument value
|
||||
"blue_evidence_submitted",
|
||||
{"blue_paused_seconds": total_paused},
|
||||
))
|
||||
# Return total_paused
|
||||
return total_paused
|
||||
|
||||
# Define function pause_timer
|
||||
def pause_timer(self) -> None:
|
||||
"""Pause the active phase timer."""
|
||||
"""Pause the active phase timer.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: self.state not in _PAUSABLE_STATES
|
||||
if self.state not in _PAUSABLE_STATES:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot pause timer in '{self.state.value}' state"
|
||||
)
|
||||
# Check: self.paused_at is not None
|
||||
if self.paused_at is not None:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Timer is already paused")
|
||||
# Assign self.paused_at = datetime.utcnow()
|
||||
self.paused_at = datetime.utcnow()
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("timer_paused"))
|
||||
|
||||
# Define function resume_timer
|
||||
def resume_timer(self) -> int:
|
||||
"""Resume a paused timer. Returns seconds that were paused."""
|
||||
"""Resume a paused timer.
|
||||
|
||||
Returns:
|
||||
int: Number of seconds the timer was paused for.
|
||||
"""
|
||||
# Check: self.paused_at is None
|
||||
if self.paused_at is None:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Timer is not paused")
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign paused_seconds = max(int((now - self.paused_at).total_seconds()), 0)
|
||||
paused_seconds = max(int((now - self.paused_at).total_seconds()), 0)
|
||||
# Check: self.state == TestState.red_executing
|
||||
if self.state == TestState.red_executing:
|
||||
# Assign self.red_paused_seconds = paused_seconds
|
||||
self.red_paused_seconds += paused_seconds
|
||||
# Alternative: self.state == TestState.blue_evaluating
|
||||
elif self.state == TestState.blue_evaluating:
|
||||
# Assign self.blue_paused_seconds = paused_seconds
|
||||
self.blue_paused_seconds += paused_seconds
|
||||
# Assign self.paused_at = None
|
||||
self.paused_at = None
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("timer_resumed", {"paused_seconds": paused_seconds}))
|
||||
# Return paused_seconds
|
||||
return paused_seconds
|
||||
|
||||
# Define function validate_red
|
||||
def validate_red(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
|
||||
"""Record Red Lead's validation decision."""
|
||||
"""Record Red Lead's validation decision.
|
||||
|
||||
Args:
|
||||
status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``.
|
||||
by (uuid.UUID): UUID of the Red Lead recording the decision.
|
||||
notes (str | None): Optional free-text notes about the decision.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._assert_in_review()
|
||||
self._assert_in_review("red")
|
||||
# Call self._assert_valid_vote()
|
||||
self._assert_valid_vote(status)
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign self.red_validation_status = status
|
||||
self.red_validation_status = status
|
||||
# Assign self.red_validated_by = by
|
||||
self.red_validated_by = by
|
||||
# Assign self.red_validated_at = now
|
||||
self.red_validated_at = now
|
||||
# Assign self.red_validation_notes = notes
|
||||
self.red_validation_notes = notes
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("red_validated", {"status": status}))
|
||||
# Call self._check_dual_validation()
|
||||
self._check_dual_validation()
|
||||
|
||||
# Define function validate_blue
|
||||
def validate_blue(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
|
||||
"""Record Blue Lead's validation decision."""
|
||||
"""Record Blue Lead's validation decision.
|
||||
|
||||
Args:
|
||||
status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``.
|
||||
by (uuid.UUID): UUID of the Blue Lead recording the decision.
|
||||
notes (str | None): Optional free-text notes about the decision.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._assert_in_review()
|
||||
self._assert_in_review("blue")
|
||||
# Call self._assert_valid_vote()
|
||||
self._assert_valid_vote(status)
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign self.blue_validation_status = status
|
||||
self.blue_validation_status = status
|
||||
# Assign self.blue_validated_by = by
|
||||
self.blue_validated_by = by
|
||||
# Assign self.blue_validated_at = now
|
||||
self.blue_validated_at = now
|
||||
# Assign self.blue_validation_notes = notes
|
||||
self.blue_validation_notes = notes
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("blue_validated", {"status": status}))
|
||||
# Call self._check_dual_validation()
|
||||
self._check_dual_validation()
|
||||
|
||||
# Define function reopen
|
||||
def reopen(self) -> None:
|
||||
"""``rejected`` -> ``draft``, clearing all validation/timing fields."""
|
||||
"""Transition the test from ``rejected`` back to ``draft``, clearing all validation and timing fields.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._transition()
|
||||
self._transition(TestState.draft)
|
||||
# Assign self.red_validation_status = None
|
||||
self.red_validation_status = None
|
||||
# Assign self.red_validated_by = None
|
||||
self.red_validated_by = None
|
||||
# Assign self.red_validated_at = None
|
||||
self.red_validated_at = None
|
||||
# Assign self.red_validation_notes = None
|
||||
self.red_validation_notes = None
|
||||
# Assign self.blue_validation_status = None
|
||||
self.blue_validation_status = None
|
||||
# Assign self.blue_validated_by = None
|
||||
self.blue_validated_by = None
|
||||
# Assign self.blue_validated_at = None
|
||||
self.blue_validated_at = None
|
||||
# Assign self.blue_validation_notes = None
|
||||
self.blue_validation_notes = None
|
||||
# Assign self.red_started_at = None
|
||||
self.red_started_at = None
|
||||
# Assign self.blue_started_at = None
|
||||
self.blue_started_at = None
|
||||
# Assign self.paused_at = None
|
||||
self.paused_at = None
|
||||
# Assign self.red_paused_seconds = 0
|
||||
self.red_paused_seconds = 0
|
||||
# Assign self.blue_paused_seconds = 0
|
||||
self.blue_paused_seconds = 0
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("test_reopened"))
|
||||
|
||||
# -- Private -------------------------------------------------------
|
||||
|
||||
def _auto_resume(self) -> int:
|
||||
"""If paused, accumulate pause time and clear. Returns extra seconds."""
|
||||
"""Accumulate pause time and clear the paused timestamp if currently paused.
|
||||
|
||||
Returns:
|
||||
int: Extra seconds that were accumulated from the current pause, or 0
|
||||
if the timer was not paused.
|
||||
"""
|
||||
# Check: self.paused_at is None
|
||||
if self.paused_at is None:
|
||||
# Return 0
|
||||
return 0
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign extra = max(int((now - self.paused_at).total_seconds()), 0)
|
||||
extra = max(int((now - self.paused_at).total_seconds()), 0)
|
||||
# Assign self.paused_at = None
|
||||
self.paused_at = None
|
||||
# Return extra
|
||||
return extra
|
||||
|
||||
# Define function check_dual_validation
|
||||
def check_dual_validation(self) -> None:
|
||||
"""Evaluate both leads' votes and advance state if appropriate.
|
||||
|
||||
@@ -323,29 +598,70 @@ class TestEntity:
|
||||
Called automatically by :meth:`validate_red` and :meth:`validate_blue`.
|
||||
Also available as a standalone entry point for backward compatibility
|
||||
when validation fields are set externally.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._check_dual_validation()
|
||||
self._check_dual_validation()
|
||||
|
||||
# Define function _assert_in_review
|
||||
def _assert_in_review(self, side: str) -> None:
|
||||
"""Raise InvalidOperationError unless the test is in ``in_review`` state.
|
||||
|
||||
Args:
|
||||
side (str): The team side being validated (``"red"`` or ``"blue"``),
|
||||
used in the error message.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: self.state != TestState.in_review
|
||||
if self.state != TestState.in_review:
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError(
|
||||
f"Cannot validate {side} side while test is in "
|
||||
f"'{self.state.value}' state (must be in_review)"
|
||||
)
|
||||
|
||||
# Apply the @staticmethod decorator
|
||||
@staticmethod
|
||||
# Define function _assert_valid_vote
|
||||
def _assert_valid_vote(status: str) -> None:
|
||||
"""Raise InvalidOperationError if *status* is not a valid vote value.
|
||||
|
||||
Args:
|
||||
status (str): The vote value to validate; must be ``"approved"`` or ``"rejected"``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: status not in ("approved", "rejected")
|
||||
if status not in ("approved", "rejected"):
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError(
|
||||
# Literal argument value
|
||||
"validation_status must be 'approved' or 'rejected'"
|
||||
)
|
||||
|
||||
# Define function _check_dual_validation
|
||||
def _check_dual_validation(self) -> None:
|
||||
"""If both leads have voted, advance to validated or rejected."""
|
||||
"""Advance to ``validated`` or ``rejected`` once both leads have voted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# r, b = self.red_validation_status, self.blue_validation_status
|
||||
r, b = self.red_validation_status, self.blue_validation_status
|
||||
# Check: r == "rejected" or b == "rejected"
|
||||
if r == "rejected" or b == "rejected":
|
||||
# Assign self.state = TestState.rejected
|
||||
self.state = TestState.rejected
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("dual_validation_rejected"))
|
||||
# Alternative: r == "approved" and b == "approved"
|
||||
elif r == "approved" and b == "approved":
|
||||
# Assign self.state = TestState.validated
|
||||
self.state = TestState.validated
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("dual_validation_approved"))
|
||||
|
||||
@@ -20,43 +20,84 @@ Services should **never** call ``db.commit()``; they use ``db.add()`` /
|
||||
osint_enrichment_service.enrich_technique_with_cves).
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import TracebackType from types
|
||||
from types import TracebackType
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
# Define class UnitOfWork
|
||||
class UnitOfWork:
|
||||
"""Lightweight transaction wrapper around an existing SQLAlchemy session."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, session: Session) -> None:
|
||||
"""Wrap an existing SQLAlchemy session in a Unit of Work.
|
||||
|
||||
Args:
|
||||
session (Session): The active SQLAlchemy session to manage.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self._session = session
|
||||
self._session = session
|
||||
|
||||
# -- context manager -----------------------------------------------------
|
||||
|
||||
def __enter__(self) -> "UnitOfWork":
|
||||
"""Enter the runtime context, returning this UnitOfWork instance.
|
||||
|
||||
Returns:
|
||||
UnitOfWork: The UnitOfWork itself, for use in ``with`` statements.
|
||||
"""
|
||||
# Return self
|
||||
return self
|
||||
|
||||
# Define function __exit__
|
||||
def __exit__(
|
||||
self,
|
||||
# Entry: exc_type
|
||||
exc_type: type[BaseException] | None,
|
||||
# Entry: exc_val
|
||||
exc_val: BaseException | None,
|
||||
# Entry: exc_tb
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the runtime context, rolling back if an exception propagated.
|
||||
|
||||
Args:
|
||||
exc_type (type[BaseException] | None): Exception class, if raised.
|
||||
exc_val (BaseException | None): Exception instance, if raised.
|
||||
exc_tb (TracebackType | None): Traceback object, if an exception was raised.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: exc_type is not None
|
||||
if exc_type is not None:
|
||||
# Call self.rollback()
|
||||
self.rollback()
|
||||
|
||||
# -- public API ----------------------------------------------------------
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Flush pending changes and commit the transaction."""
|
||||
# Call self._session.commit()
|
||||
self._session.commit()
|
||||
|
||||
# Define function rollback
|
||||
def rollback(self) -> None:
|
||||
"""Roll back the current transaction."""
|
||||
# Call self._session.rollback()
|
||||
self._session.rollback()
|
||||
|
||||
# Define function flush
|
||||
def flush(self) -> None:
|
||||
"""Flush pending changes without committing (useful for getting IDs)."""
|
||||
# Call self._session.flush()
|
||||
self._session.flush()
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
"""Immutable domain value objects."""
|
||||
# Import MitreId from app.domain.value_objects.mitre_id
|
||||
from app.domain.value_objects.mitre_id import MitreId
|
||||
|
||||
# Import ScoringWeights from app.domain.value_objects.scoring_weights
|
||||
from app.domain.value_objects.scoring_weights import ScoringWeights
|
||||
|
||||
# Assign __all__ = ["MitreId", "ScoringWeights"]
|
||||
__all__ = ["MitreId", "ScoringWeights"]
|
||||
|
||||
@@ -5,47 +5,111 @@ format: ``T`` followed by 4 digits, optionally a dot and 3 more digits
|
||||
for sub-techniques (e.g. ``T1059``, ``T1059.001``).
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import re
|
||||
import re
|
||||
|
||||
# Import dataclass from dataclasses
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Assign _MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
|
||||
_MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass(frozen=True, slots=True)
|
||||
# Define class MitreId
|
||||
class MitreId:
|
||||
"""Validated MITRE ATT&CK technique identifier."""
|
||||
|
||||
# value: str
|
||||
value: str
|
||||
|
||||
# Define function __post_init__
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that *value* matches the expected MITRE ATT&CK ID format.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: not _MITRE_ID_RE.match(self.value)
|
||||
if not _MITRE_ID_RE.match(self.value):
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"Invalid MITRE ATT&CK ID '{self.value}'. "
|
||||
# Literal argument value
|
||||
"Expected format: T1234 or T1234.001"
|
||||
)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function is_subtechnique
|
||||
def is_subtechnique(self) -> bool:
|
||||
"""Return True if this identifier represents a sub-technique.
|
||||
|
||||
Returns:
|
||||
bool: True when the ID contains a dot (e.g. ``T1059.001``).
|
||||
"""
|
||||
# Return "." in self.value
|
||||
return "." in self.value
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function parent_id
|
||||
def parent_id(self) -> str | None:
|
||||
"""Return the parent technique ID (e.g. T1059 for T1059.001)."""
|
||||
"""Return the parent technique ID (e.g. ``T1059`` for ``T1059.001``).
|
||||
|
||||
Returns:
|
||||
str | None: The parent ID string, or None if this is not a sub-technique.
|
||||
"""
|
||||
# Check: not self.is_subtechnique
|
||||
if not self.is_subtechnique:
|
||||
# Return None
|
||||
return None
|
||||
# Return self.value.split(".")[0]
|
||||
return self.value.split(".")[0]
|
||||
|
||||
# Define function __str__
|
||||
def __str__(self) -> str:
|
||||
"""Return the string representation of the MITRE ID.
|
||||
|
||||
Returns:
|
||||
str: The raw identifier string (e.g. ``"T1059.001"``).
|
||||
"""
|
||||
# Return self.value
|
||||
return self.value
|
||||
|
||||
# Define function __eq__
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare this MitreId to another MitreId or a plain string.
|
||||
|
||||
Args:
|
||||
other (object): The value to compare against; may be a
|
||||
:class:`MitreId` instance or a plain ``str``.
|
||||
|
||||
Returns:
|
||||
bool: True if the identifiers are equal, NotImplemented for
|
||||
unsupported types.
|
||||
"""
|
||||
# Check: isinstance(other, MitreId)
|
||||
if isinstance(other, MitreId):
|
||||
# Return self.value == other.value
|
||||
return self.value == other.value
|
||||
# Check: isinstance(other, str)
|
||||
if isinstance(other, str):
|
||||
# Return self.value == other
|
||||
return self.value == other
|
||||
# Return NotImplemented
|
||||
return NotImplemented
|
||||
|
||||
# Define function __hash__
|
||||
def __hash__(self) -> int:
|
||||
"""Return the hash of the identifier string.
|
||||
|
||||
Returns:
|
||||
int: Hash value derived from the raw identifier string.
|
||||
"""
|
||||
# Return hash(self.value)
|
||||
return hash(self.value)
|
||||
|
||||
@@ -3,22 +3,38 @@
|
||||
Enforces that all five weights are non-negative and sum to exactly 100.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import dataclass from dataclasses
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass(frozen=True, slots=True)
|
||||
# Define class ScoringWeights
|
||||
class ScoringWeights:
|
||||
"""Five scoring dimension weights that must sum to 100."""
|
||||
|
||||
# tests: float
|
||||
tests: float
|
||||
# detection_rules: float
|
||||
detection_rules: float
|
||||
# d3fend: float
|
||||
d3fend: float
|
||||
# recency: float
|
||||
recency: float
|
||||
# severity: float
|
||||
severity: float
|
||||
|
||||
# Define function __post_init__
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that all weights are non-negative and sum to exactly 100.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign fields = [
|
||||
fields = [
|
||||
self.tests,
|
||||
self.detection_rules,
|
||||
@@ -26,32 +42,66 @@ class ScoringWeights:
|
||||
self.recency,
|
||||
self.severity,
|
||||
]
|
||||
# Iterate over fields
|
||||
for f in fields:
|
||||
# Check: f < 0
|
||||
if f < 0:
|
||||
# Raise ValueError
|
||||
raise ValueError("Scoring weights must be non-negative")
|
||||
|
||||
# Assign total = sum(fields)
|
||||
total = sum(fields)
|
||||
# Check: abs(total - 100) > 0.01
|
||||
if abs(total - 100) > 0.01:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"Scoring weights must sum to 100, got {total}"
|
||||
)
|
||||
|
||||
# Apply the @classmethod decorator
|
||||
@classmethod
|
||||
# Define function default
|
||||
def default(cls) -> ScoringWeights:
|
||||
"""Return the default weight distribution."""
|
||||
"""Return the default weight distribution.
|
||||
|
||||
Returns:
|
||||
ScoringWeights: A weight set with tests=40, detection_rules=25,
|
||||
d3fend=15, recency=10, severity=10.
|
||||
"""
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: tests
|
||||
tests=40.0,
|
||||
# Keyword argument: detection_rules
|
||||
detection_rules=25.0,
|
||||
# Keyword argument: d3fend
|
||||
d3fend=15.0,
|
||||
# Keyword argument: recency
|
||||
recency=10.0,
|
||||
# Keyword argument: severity
|
||||
severity=10.0,
|
||||
)
|
||||
|
||||
# Backward-compatible aliases for older API payloads
|
||||
@property
|
||||
# Define function freshness
|
||||
def freshness(self) -> float:
|
||||
"""Return the recency weight (backward-compatible alias).
|
||||
|
||||
Returns:
|
||||
float: The value of the ``recency`` weight.
|
||||
"""
|
||||
# Return self.recency
|
||||
return self.recency
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function platform_diversity
|
||||
def platform_diversity(self) -> float:
|
||||
"""Return the severity weight (backward-compatible alias).
|
||||
|
||||
Returns:
|
||||
float: The value of the ``severity`` weight.
|
||||
"""
|
||||
# Return self.severity
|
||||
return self.severity
|
||||
|
||||
Reference in New Issue
Block a user