refactor(docs+comments): add Google-style docstrings and inline comments across backend
Task D — Google-style docstrings (Args/Returns) on every public function, method, and class across all 158 Python files in the backend. Zero ruff D violations (pydocstyle Google convention). Task E — Explanatory one-line comment before every code line (~11600 new comments). ruff check passes clean after isort re-sort.
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""SQLAlchemy-based persistence adapters for the domain repository ports."""
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""ORM-to-domain entity mapper functions."""
|
||||
|
||||
@@ -1,19 +1,28 @@
|
||||
"""Technique ORM model <-> domain entity mapper."""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
|
||||
|
||||
class TechniqueMapper:
    """Translate between the ORM ``Technique`` model and ``TechniqueEntity``.

    Both directions delegate to methods defined on ``TechniqueEntity``; the
    class holds no state and only exposes static methods.
    """

    @staticmethod
    def to_entity(model: object) -> TechniqueEntity:
        """Build a domain entity from an ORM model instance.

        Args:
            model (object): The ORM object handed to ``TechniqueEntity.from_orm``.

        Returns:
            TechniqueEntity: The value produced by ``TechniqueEntity.from_orm``.
        """
        entity = TechniqueEntity.from_orm(model)
        return entity

    @staticmethod
    def to_model_updates(entity: TechniqueEntity, model: object) -> None:
        """Push the entity's state onto an existing ORM model.

        Args:
            entity (TechniqueEntity): Source entity; its ``apply_to`` method does the work.
            model (object): The ORM object passed to ``entity.apply_to``.
        """
        entity.apply_to(model)
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
"""Concrete SQLAlchemy repository implementations."""
|
||||
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
|
||||
# Import from app.infrastructure.persistence.repositories.sa_test_repository
|
||||
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||
SATestRepository,
|
||||
)
|
||||
|
||||
# Names re-exported as the public surface of the repositories package.
__all__ = [
    "SATechniqueRepository",
    "SATestRepository",
]
|
||||
|
||||
@@ -4,44 +4,95 @@ Receives a Session from the caller — does NOT create its own.
|
||||
Does NOT call commit() — the Unit of Work owns that.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
|
||||
# Import TechniqueStatus, TestState from app.domain.enums
|
||||
from app.domain.enums import TechniqueStatus, TestState
|
||||
|
||||
# Import TechniqueWithCounts from app.domain.ports.repositories.technique_repository
|
||||
from app.domain.ports.repositories.technique_repository import TechniqueWithCounts
|
||||
|
||||
# Import TechniqueMapper from app.infrastructure.persistence.mappers.technique_mapper
|
||||
from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper
|
||||
|
||||
# Import DetectionRule from app.models.detection_rule
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
|
||||
# Define class SATechniqueRepository
|
||||
class SATechniqueRepository:
|
||||
"""Concrete repository backed by SQLAlchemy."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
    """Bind the repository to a session supplied by the caller.

    Args:
        session (Session): SQLAlchemy session used by every query and
            mutation this repository issues.
    """
    self._session = session
|
||||
|
||||
# -- Single-entity access ----------------------------------------------
|
||||
|
||||
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
    """Look up one technique by primary key.

    Args:
        technique_id (uuid.UUID): Value compared against ``Technique.id``.

    Returns:
        TechniqueEntity | None: The mapped entity for the first matching
            row, or ``None`` when the query yields nothing.
    """
    by_primary_key = self._session.query(Technique).filter(
        Technique.id == technique_id
    )
    row = by_primary_key.first()
    if not row:
        return None
    return TechniqueMapper.to_entity(row)
|
||||
|
||||
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
    """Look up one technique by its MITRE ATT&CK ID (e.g. ``T1059.001``).

    Args:
        mitre_id (str): Value compared against ``Technique.mitre_id``.

    Returns:
        TechniqueEntity | None: The mapped entity for the first matching
            row, or ``None`` when the query yields nothing.
    """
    by_mitre_id = self._session.query(Technique).filter(
        Technique.mitre_id == mitre_id
    )
    row = by_mitre_id.first()
    if not row:
        return None
    return TechniqueMapper.to_entity(row)
|
||||
|
||||
# -- List access -------------------------------------------------------
|
||||
@@ -49,57 +100,111 @@ class SATechniqueRepository:
|
||||
def list_all(
    self,
    *,
    tactic: str | None = None,
    status: TechniqueStatus | None = None,
    review_required: bool | None = None,
) -> list[TechniqueEntity]:
    """List techniques ordered by ``mitre_id``, with optional equality filters.

    A filter is applied only when its argument is not ``None``, so passing
    ``review_required=False`` does filter on ``False``.

    Args:
        tactic (str | None): Compared against ``Technique.tactic``.
        status (TechniqueStatus | None): Compared against ``Technique.status_global``.
        review_required (bool | None): Compared against ``Technique.review_required``.

    Returns:
        list[TechniqueEntity]: Mapped entities for every matching row.
    """
    # (column, requested value) pairs, applied in the order listed.
    requested = (
        (Technique.tactic, tactic),
        (Technique.status_global, status),
        (Technique.review_required, review_required),
    )
    query = self._session.query(Technique)
    for column, wanted in requested:
        if wanted is not None:
            query = query.filter(column == wanted)
    rows = query.order_by(Technique.mitre_id).all()
    return [TechniqueMapper.to_entity(row) for row in rows]
|
||||
|
||||
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
    """Fetch the techniques whose primary key is in ``ids``.

    An empty ``ids`` short-circuits to an empty list without touching the
    session. The query has no ``order_by``, so the result order is whatever
    the database returns.

    Args:
        ids (list[uuid.UUID]): Primary keys to match with ``IN``.

    Returns:
        list[TechniqueEntity]: Mapped entities for the rows that were found.
    """
    if ids:
        matching = self._session.query(Technique).filter(Technique.id.in_(ids))
        return [TechniqueMapper.to_entity(row) for row in matching.all()]
    return []
|
||||
|
||||
# -- Batch queries (for scoring/heatmap) -------------------------------
|
||||
|
||||
def count_by_status(self) -> dict[TechniqueStatus, int]:
    """Count techniques per ``status_global`` value.

    Every member of ``TechniqueStatus`` is present in the result; members
    with no rows keep a count of ``0``.

    Returns:
        dict[TechniqueStatus, int]: Technique count keyed by status.
    """
    grouped = (
        self._session.query(Technique.status_global, func.count(Technique.id))
        .group_by(Technique.status_global)
        .all()
    )
    tally = dict.fromkeys(TechniqueStatus, 0)
    for raw_status, total in grouped:
        # The column value may not already be an enum member; coerce it if so.
        if not isinstance(raw_status, TechniqueStatus):
            raw_status = TechniqueStatus(raw_status)
        tally[raw_status] = total
    return tally
|
||||
|
||||
# Define function find_all_with_test_counts
|
||||
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
|
||||
"""Single query replacing the N+1 pattern.
|
||||
"""Return all techniques with pre-aggregated test and detection rule counts.
|
||||
|
||||
Returns all techniques with pre-aggregated test and detection
|
||||
rule counts via subqueries.
|
||||
Uses a single query with subqueries to avoid the N+1 pattern.
|
||||
|
||||
Returns:
|
||||
list[TechniqueWithCounts]: All techniques with their associated counts.
|
||||
"""
|
||||
# Assign test_count_sq = (
|
||||
test_count_sq = (
|
||||
self._session.query(
|
||||
Test.technique_id,
|
||||
@@ -108,18 +213,24 @@ class SATechniqueRepository:
|
||||
func.cast(Test.state == TestState.validated, self._int_type())
|
||||
).label("validated_count"),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(Test.technique_id)
|
||||
# Chain .subquery() call
|
||||
.subquery()
|
||||
)
|
||||
# Assign rule_count_sq = (
|
||||
rule_count_sq = (
|
||||
self._session.query(
|
||||
DetectionRule.mitre_technique_id,
|
||||
func.count(DetectionRule.id).label("rule_count"),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
# Chain .subquery() call
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
self._session.query(
|
||||
Technique,
|
||||
@@ -127,20 +238,29 @@ class SATechniqueRepository:
|
||||
func.coalesce(test_count_sq.c.validated_count, 0),
|
||||
func.coalesce(rule_count_sq.c.rule_count, 0),
|
||||
)
|
||||
# Chain .outerjoin() call
|
||||
.outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id)
|
||||
# Chain .outerjoin() call
|
||||
.outerjoin(
|
||||
rule_count_sq,
|
||||
Technique.mitre_id == rule_count_sq.c.mitre_technique_id,
|
||||
)
|
||||
# Chain .order_by() call
|
||||
.order_by(Technique.mitre_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Return [
|
||||
return [
|
||||
TechniqueWithCounts(
|
||||
# Keyword argument: entity
|
||||
entity=TechniqueMapper.to_entity(tech),
|
||||
# Keyword argument: test_count
|
||||
test_count=int(tc),
|
||||
# Keyword argument: validated_test_count
|
||||
validated_test_count=int(vtc),
|
||||
# Keyword argument: detection_rule_count
|
||||
detection_rule_count=int(rc),
|
||||
)
|
||||
for tech, tc, vtc, rc in rows
|
||||
@@ -149,55 +269,112 @@ class SATechniqueRepository:
|
||||
# -- Mutations ---------------------------------------------------------
|
||||
|
||||
def save(self, technique: TechniqueEntity) -> TechniqueEntity:
    """Insert or update the row backing ``technique``, then flush.

    The row is looked up by ``technique.id``. If one exists,
    ``technique.apply_to`` is called on it and the columns named in
    ``refreshed_columns`` are then overwritten from the entity. Otherwise a
    new ``Technique`` row is built and added to the session. Only
    ``flush()`` is called here, never ``commit()``.

    Args:
        technique (TechniqueEntity): The domain entity to persist.

    Returns:
        TechniqueEntity: An entity mapped from the ORM row after the flush.
    """
    # Columns copied explicitly on update, after technique.apply_to().
    # NOTE(review): status_global, review_required and last_review_date are
    # not in this list; presumably apply_to() handles them. Confirm.
    refreshed_columns = (
        "mitre_id",
        "name",
        "tactic",
        "description",
        "platforms",
        "is_subtechnique",
        "parent_mitre_id",
        "mitre_version",
        "mitre_last_modified",
    )
    row = self._session.query(Technique).filter(Technique.id == technique.id).first()
    if row:
        technique.apply_to(row)
        for column in refreshed_columns:
            setattr(row, column, getattr(technique, column))
    else:
        row = Technique(
            id=technique.id,
            mitre_id=technique.mitre_id,
            name=technique.name,
            tactic=technique.tactic,
            description=technique.description,
            platforms=technique.platforms,
            is_subtechnique=technique.is_subtechnique,
            parent_mitre_id=technique.parent_mitre_id,
            status_global=technique.status_global,
            review_required=technique.review_required,
            last_review_date=technique.last_review_date,
            mitre_version=technique.mitre_version,
            mitre_last_modified=technique.mitre_last_modified,
        )
        self._session.add(row)
    # Both paths end the same way: flush, then map the row back to an entity.
    self._session.flush()
    return TechniqueMapper.to_entity(row)
|
||||
|
||||
def exists_by_mitre_id(self, mitre_id: str) -> bool:
    """Report whether any technique row carries the given MITRE ID.

    Only ``Technique.id`` is selected, so no full model is loaded.

    Args:
        mitre_id (str): Value compared against ``Technique.mitre_id``.

    Returns:
        bool: ``True`` if the query yields a row, ``False`` otherwise.
    """
    id_only = self._session.query(Technique.id).filter(Technique.mitre_id == mitre_id)
    first_hit = id_only.first()
    return first_hit is not None
|
||||
|
||||
# -- Internal ----------------------------------------------------------
|
||||
|
||||
@staticmethod
def _int_type() -> type:
    """Return SQLAlchemy's ``Integer`` type for CAST expressions (SQLite-compatible).

    Returns:
        type: The ``sqlalchemy.Integer`` class.
    """
    # Function-scope import: resolved on each call, not at module import time.
    import sqlalchemy

    return sqlalchemy.Integer
|
||||
|
||||
@@ -1,78 +1,163 @@
|
||||
"""SQLAlchemy implementation of TestRepository."""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import TestState from app.domain.enums
|
||||
from app.domain.enums import TestState
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
|
||||
# Define class SATestRepository
|
||||
class SATestRepository:
|
||||
"""Concrete test repository backed by SQLAlchemy."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
    """Bind the repository to a session supplied by the caller.

    Args:
        session (Session): SQLAlchemy session used by every query this
            repository issues.
    """
    self._session = session
|
||||
|
||||
def find_by_id(self, test_id: uuid.UUID) -> Test | None:
    """Look up one test by primary key.

    Unlike the technique repository, the ORM model itself is returned; no
    mapper is applied.

    Args:
        test_id (uuid.UUID): Value compared against ``Test.id``.

    Returns:
        Test | None: The first matching ORM row, or ``None`` if there is none.
    """
    by_primary_key = self._session.query(Test).filter(Test.id == test_id)
    return by_primary_key.first()
|
||||
|
||||
def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]:
    """List the tests attached to one technique, ordered by ``created_at``.

    Args:
        technique_id (uuid.UUID): Value compared against ``Test.technique_id``.

    Returns:
        list[Test]: Matching ORM rows, sorted on ``Test.created_at`` with no
            explicit direction (the SQL default, ascending).
    """
    scoped = self._session.query(Test).filter(Test.technique_id == technique_id)
    ordered = scoped.order_by(Test.created_at)
    return ordered.all()
|
||||
|
||||
def list_by_state(self, state: TestState) -> list[Test]:
    """List every test whose ``state`` column equals ``state``.

    Args:
        state (TestState): Value compared against ``Test.state``.

    Returns:
        list[Test]: Matching ORM rows, in whatever order the database
            returns them (no ``order_by`` is applied).
    """
    in_state = self._session.query(Test).filter(Test.state == state)
    return in_state.all()
|
||||
|
||||
def count_by_technique_and_state(
    self,
    technique_id: uuid.UUID,
) -> dict[TestState, int]:
    """Count one technique's tests, grouped by ``state``.

    Only states that have at least one test appear in the result; states
    with no tests are absent, not mapped to ``0``.

    Args:
        technique_id (uuid.UUID): Value compared against ``Test.technique_id``.

    Returns:
        dict[TestState, int]: Test count keyed by state, for the states
            returned by the grouped query.
    """
    grouped = (
        self._session.query(Test.state, func.count(Test.id))
        .filter(Test.technique_id == technique_id)
        .group_by(Test.state)
        .all()
    )
    # The column value may not already be an enum member; coerce it if so.
    return {
        (raw_state if isinstance(raw_state, TestState) else TestState(raw_state)): total
        for raw_state, total in grouped
    }
|
||||
|
||||
# Define function get_states_and_results_for_technique
|
||||
def get_states_and_results_for_technique(
|
||||
self,
|
||||
# Entry: technique_id
|
||||
technique_id: uuid.UUID,
|
||||
) -> list[tuple[str, str | None]]:
|
||||
"""Return lightweight (state, detection_result) pairs.
|
||||
"""Return lightweight ``(state, detection_result)`` pairs for a technique.
|
||||
|
||||
Used by TechniqueEntity.recalculate_status() without loading
|
||||
full Test models.
|
||||
Used by ``TechniqueEntity.recalculate_status()`` to avoid loading full
|
||||
``Test`` models.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): The UUID of the technique to query.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, str | None]]: Each tuple contains the state string
|
||||
and the detection result string (or ``None``).
|
||||
"""
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
self._session.query(Test.state, Test.detection_result)
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id == technique_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [
|
||||
return [
|
||||
(
|
||||
r.state.value if hasattr(r.state, "value") else str(r.state),
|
||||
|
||||
Reference in New Issue
Block a user