feat(infra): add repository implementations, mappers, FastAPI wiring, and technique indexes

This commit is contained in:
2026-02-18 19:10:50 +01:00
parent 5c55e7c17f
commit 1521005b62
9 changed files with 618 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
from app.infrastructure.persistence.repositories.sa_technique_repository import (
SATechniqueRepository,
)
from app.infrastructure.persistence.repositories.sa_test_repository import (
SATestRepository,
)
__all__ = ["SATechniqueRepository", "SATestRepository"]

View File

@@ -0,0 +1,199 @@
"""SQLAlchemy implementation of TechniqueRepository.
Receives a Session from the caller — does NOT create its own.
Does NOT call commit() — the Unit of Work owns that.
"""
from __future__ import annotations
import uuid
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.entities.technique import TechniqueEntity
from app.domain.enums import TechniqueStatus, TestState
from app.domain.ports.repositories.technique_repository import TechniqueWithCounts
from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper
from app.models.detection_rule import DetectionRule
from app.models.technique import Technique
from app.models.test import Test
class SATechniqueRepository:
    """Technique repository implemented on top of a SQLAlchemy ``Session``.

    The session is supplied by the caller. This class only ever calls
    ``flush()`` on it and never ``commit()``, so transaction boundaries
    stay with whoever owns the session.
    """

    def __init__(self, session: Session) -> None:
        """Store the caller-provided session used for every query."""
        self._session = session

    # -- Single-entity access ----------------------------------------------

    def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
        """Return the technique whose ``id`` equals *technique_id*, or ``None``."""
        model = (
            self._session.query(Technique)
            .filter(Technique.id == technique_id)
            .first()
        )
        if not model:
            return None
        return TechniqueMapper.to_entity(model)

    def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
        """Return the technique whose ``mitre_id`` equals *mitre_id*, or ``None``."""
        model = (
            self._session.query(Technique)
            .filter(Technique.mitre_id == mitre_id)
            .first()
        )
        if not model:
            return None
        return TechniqueMapper.to_entity(model)

    # -- List access -------------------------------------------------------

    def list_all(
        self,
        *,
        tactic: str | None = None,
        status: TechniqueStatus | None = None,
        review_required: bool | None = None,
    ) -> list[TechniqueEntity]:
        """Return techniques ordered by ``mitre_id``, optionally filtered.

        Each keyword argument adds an equality filter only when it is not
        ``None``; ``review_required=False`` is therefore a real filter.
        """
        query = self._session.query(Technique)
        if tactic is not None:
            query = query.filter(Technique.tactic == tactic)
        if status is not None:
            query = query.filter(Technique.status_global == status)
        if review_required is not None:
            query = query.filter(Technique.review_required == review_required)
        models = query.order_by(Technique.mitre_id).all()
        return [TechniqueMapper.to_entity(m) for m in models]

    def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
        """Return the techniques whose ``id`` is in *ids*, ordered by ``mitre_id``.

        An empty *ids* returns ``[]`` without touching the session. Ids
        that match no row are silently absent from the result.
        """
        if not ids:
            return []
        models = (
            self._session.query(Technique)
            .filter(Technique.id.in_(ids))
            # Without ORDER BY the row order is unspecified; use the same
            # ordering as list_all() so results are deterministic.
            .order_by(Technique.mitre_id)
            .all()
        )
        return [TechniqueMapper.to_entity(m) for m in models]

    # -- Batch queries (for scoring/heatmap) -------------------------------

    def count_by_status(self) -> dict[TechniqueStatus, int]:
        """Return the number of techniques per global status.

        Every member of ``TechniqueStatus`` is present in the result;
        statuses with no rows map to ``0``.
        """
        rows = (
            self._session.query(
                Technique.status_global,
                func.count(Technique.id),
            )
            .group_by(Technique.status_global)
            .all()
        )
        result = {s: 0 for s in TechniqueStatus}
        for status_val, count in rows:
            # The column may come back as an enum member or as its raw
            # value; normalise to the enum before using it as a key.
            key = (
                status_val
                if isinstance(status_val, TechniqueStatus)
                else TechniqueStatus(status_val)
            )
            result[key] = count
        return result

    def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
        """Return every technique with its test and detection-rule counts.

        Counts are pre-aggregated in two grouped subqueries that are
        outer-joined to the technique table, so the whole result comes
        from a single query. Techniques without tests or rules get ``0``
        through ``COALESCE``. Results are ordered by ``mitre_id``.
        """
        test_count_sq = (
            self._session.query(
                Test.technique_id,
                func.count(Test.id).label("test_count"),
                # Summing the boolean comparison cast to an integer counts
                # the rows whose state is "validated".
                func.sum(
                    func.cast(Test.state == TestState.validated, self._int_type())
                ).label("validated_count"),
            )
            .group_by(Test.technique_id)
            .subquery()
        )
        rule_count_sq = (
            self._session.query(
                DetectionRule.mitre_technique_id,
                func.count(DetectionRule.id).label("rule_count"),
            )
            .group_by(DetectionRule.mitre_technique_id)
            .subquery()
        )
        rows = (
            self._session.query(
                Technique,
                func.coalesce(test_count_sq.c.test_count, 0),
                func.coalesce(test_count_sq.c.validated_count, 0),
                func.coalesce(rule_count_sq.c.rule_count, 0),
            )
            .outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id)
            # Rules are joined on the MITRE id string, not the technique's
            # primary key.
            .outerjoin(
                rule_count_sq,
                Technique.mitre_id == rule_count_sq.c.mitre_technique_id,
            )
            .order_by(Technique.mitre_id)
            .all()
        )
        return [
            TechniqueWithCounts(
                entity=TechniqueMapper.to_entity(tech),
                test_count=int(tc),
                validated_test_count=int(vtc),
                detection_rule_count=int(rc),
            )
            for tech, tc, vtc, rc in rows
        ]

    # -- Mutations ---------------------------------------------------------

    def save(self, technique: TechniqueEntity) -> TechniqueEntity:
        """Insert or update *technique* and return the persisted state.

        A row with the same ``id`` is updated in place; otherwise a new
        row is added. The session is flushed in both cases but never
        committed. The returned entity is re-mapped from the ORM model.
        """
        existing = (
            self._session.query(Technique)
            .filter(Technique.id == technique.id)
            .first()
        )
        if existing:
            # NOTE(review): apply_to() presumably copies the status and
            # review fields onto the model - confirm against TechniqueEntity.
            technique.apply_to(existing)
            existing.mitre_id = technique.mitre_id
            existing.name = technique.name
            existing.tactic = technique.tactic
            existing.description = technique.description
            existing.platforms = technique.platforms
            existing.is_subtechnique = technique.is_subtechnique
            existing.parent_mitre_id = technique.parent_mitre_id
            self._session.flush()
            return TechniqueMapper.to_entity(existing)

        model = Technique(
            id=technique.id,
            mitre_id=technique.mitre_id,
            name=technique.name,
            tactic=technique.tactic,
            description=technique.description,
            platforms=technique.platforms,
            is_subtechnique=technique.is_subtechnique,
            parent_mitre_id=technique.parent_mitre_id,
            status_global=technique.status_global,
            review_required=technique.review_required,
            last_review_date=technique.last_review_date,
        )
        self._session.add(model)
        self._session.flush()
        return TechniqueMapper.to_entity(model)

    def exists_by_mitre_id(self, mitre_id: str) -> bool:
        """Return ``True`` when a technique with *mitre_id* exists.

        Only the ``id`` column is selected, so no full model is loaded.
        """
        return (
            self._session.query(Technique.id)
            .filter(Technique.mitre_id == mitre_id)
            .first()
        ) is not None

    # -- Internal ----------------------------------------------------------

    @staticmethod
    def _int_type():
        """Return an Integer type for CAST expressions (SQLite-compatible)."""
        from sqlalchemy import Integer

        return Integer

View File

@@ -0,0 +1,86 @@
"""SQLAlchemy implementation of TestRepository."""
from __future__ import annotations
import uuid
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.enums import TestState
from app.models.test import Test
class SATestRepository:
    """Test repository implemented on top of a SQLAlchemy ``Session``.

    All methods are read-only queries against the caller-supplied
    session; nothing is flushed or committed here. Lookups return ORM
    ``Test`` models directly rather than domain entities.
    """

    def __init__(self, session: Session) -> None:
        """Store the caller-provided session used for every query."""
        self._session = session

    def find_by_id(self, test_id: uuid.UUID) -> Test | None:
        """Return the test whose ``id`` equals *test_id*, or ``None``."""
        return (
            self._session.query(Test)
            .filter(Test.id == test_id)
            .first()
        )

    def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]:
        """Return the tests of one technique, ordered by ``created_at``."""
        return (
            self._session.query(Test)
            .filter(Test.technique_id == technique_id)
            .order_by(Test.created_at)
            .all()
        )

    def list_by_state(self, state: TestState) -> list[Test]:
        """Return every test currently in *state*, ordered by ``created_at``."""
        return (
            self._session.query(Test)
            .filter(Test.state == state)
            # Without ORDER BY the row order is unspecified; use the same
            # ordering as list_by_technique() so results are deterministic.
            .order_by(Test.created_at)
            .all()
        )

    def count_by_technique_and_state(
        self,
        technique_id: uuid.UUID,
    ) -> dict[TestState, int]:
        """Return the number of tests per state for one technique.

        Only states that have at least one test appear as keys; a
        technique without tests yields an empty dict.
        """
        rows = (
            self._session.query(Test.state, func.count(Test.id))
            .filter(Test.technique_id == technique_id)
            .group_by(Test.state)
            .all()
        )
        result: dict[TestState, int] = {}
        for state_val, count in rows:
            # The column may come back as an enum member or as its raw
            # value; normalise to the enum before using it as a key.
            key = (
                state_val
                if isinstance(state_val, TestState)
                else TestState(state_val)
            )
            result[key] = count
        return result

    def get_states_and_results_for_technique(
        self,
        technique_id: uuid.UUID,
    ) -> list[tuple[str, str | None]]:
        """Return lightweight ``(state, detection_result)`` pairs.

        Only the two columns are selected, so full ``Test`` models are
        not loaded. Values exposing a ``.value`` attribute are unwrapped
        to it. A state without ``.value`` is passed through ``str()``;
        a detection result without ``.value`` is returned unchanged,
        which keeps ``None`` as ``None``.
        """
        rows = (
            self._session.query(Test.state, Test.detection_result)
            .filter(Test.technique_id == technique_id)
            .all()
        )
        return [
            (
                r.state.value if hasattr(r.state, "value") else str(r.state),
                (
                    r.detection_result.value
                    if hasattr(r.detection_result, "value")
                    else r.detection_result
                ),
            )
            for r in rows
        ]