feat(infra): add repository implementations, mappers, FastAPI wiring, and technique indexes
This commit is contained in:
30
backend/app/dependencies/repositories.py
Normal file
30
backend/app/dependencies/repositories.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""FastAPI dependency providers for repositories.
|
||||
|
||||
Wiring lives ONLY in the presentation layer — use cases and services
|
||||
never know which concrete repository implementation they receive.
|
||||
"""
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||
SATestRepository,
|
||||
)
|
||||
|
||||
|
||||
def get_technique_repository(
|
||||
db: Session = Depends(get_db),
|
||||
) -> SATechniqueRepository:
|
||||
"""Provide a TechniqueRepository backed by the current DB session."""
|
||||
return SATechniqueRepository(db)
|
||||
|
||||
|
||||
def get_test_repository(
|
||||
db: Session = Depends(get_db),
|
||||
) -> SATestRepository:
|
||||
"""Provide a TestRepository backed by the current DB session."""
|
||||
return SATestRepository(db)
|
||||
0
backend/app/infrastructure/persistence/__init__.py
Normal file
0
backend/app/infrastructure/persistence/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Technique ORM model <-> domain entity mapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
from app.domain.enums import TechniqueStatus
|
||||
|
||||
|
||||
class TechniqueMapper:
|
||||
"""Converts between SQLAlchemy Technique model and TechniqueEntity."""
|
||||
|
||||
@staticmethod
|
||||
def to_entity(model: object) -> TechniqueEntity:
|
||||
"""Convert an ORM Technique model to a domain TechniqueEntity."""
|
||||
return TechniqueEntity.from_orm(model)
|
||||
|
||||
@staticmethod
|
||||
def to_model_updates(entity: TechniqueEntity, model: object) -> None:
|
||||
"""Apply entity changes back onto an existing ORM model."""
|
||||
entity.apply_to(model)
|
||||
@@ -0,0 +1,8 @@
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||
SATestRepository,
|
||||
)
|
||||
|
||||
__all__ = ["SATechniqueRepository", "SATestRepository"]
|
||||
@@ -0,0 +1,199 @@
|
||||
"""SQLAlchemy implementation of TechniqueRepository.
|
||||
|
||||
Receives a Session from the caller — does NOT create its own.
|
||||
Does NOT call commit() — the Unit of Work owns that.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
from app.domain.enums import TechniqueStatus, TestState
|
||||
from app.domain.ports.repositories.technique_repository import TechniqueWithCounts
|
||||
from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
|
||||
|
||||
class SATechniqueRepository:
|
||||
"""Concrete repository backed by SQLAlchemy."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self._session = session
|
||||
|
||||
# -- Single-entity access ----------------------------------------------
|
||||
|
||||
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
|
||||
model = (
|
||||
self._session.query(Technique)
|
||||
.filter(Technique.id == technique_id)
|
||||
.first()
|
||||
)
|
||||
return TechniqueMapper.to_entity(model) if model else None
|
||||
|
||||
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
|
||||
model = (
|
||||
self._session.query(Technique)
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
.first()
|
||||
)
|
||||
return TechniqueMapper.to_entity(model) if model else None
|
||||
|
||||
# -- List access -------------------------------------------------------
|
||||
|
||||
def list_all(
|
||||
self,
|
||||
*,
|
||||
tactic: str | None = None,
|
||||
status: TechniqueStatus | None = None,
|
||||
review_required: bool | None = None,
|
||||
) -> list[TechniqueEntity]:
|
||||
query = self._session.query(Technique)
|
||||
if tactic is not None:
|
||||
query = query.filter(Technique.tactic == tactic)
|
||||
if status is not None:
|
||||
query = query.filter(Technique.status_global == status)
|
||||
if review_required is not None:
|
||||
query = query.filter(Technique.review_required == review_required)
|
||||
models = query.order_by(Technique.mitre_id).all()
|
||||
return [TechniqueMapper.to_entity(m) for m in models]
|
||||
|
||||
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
|
||||
if not ids:
|
||||
return []
|
||||
models = (
|
||||
self._session.query(Technique)
|
||||
.filter(Technique.id.in_(ids))
|
||||
.all()
|
||||
)
|
||||
return [TechniqueMapper.to_entity(m) for m in models]
|
||||
|
||||
# -- Batch queries (for scoring/heatmap) -------------------------------
|
||||
|
||||
def count_by_status(self) -> dict[TechniqueStatus, int]:
|
||||
rows = (
|
||||
self._session.query(
|
||||
Technique.status_global,
|
||||
func.count(Technique.id),
|
||||
)
|
||||
.group_by(Technique.status_global)
|
||||
.all()
|
||||
)
|
||||
result = {s: 0 for s in TechniqueStatus}
|
||||
for status_val, count in rows:
|
||||
key = (
|
||||
status_val
|
||||
if isinstance(status_val, TechniqueStatus)
|
||||
else TechniqueStatus(status_val)
|
||||
)
|
||||
result[key] = count
|
||||
return result
|
||||
|
||||
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
|
||||
"""Single query replacing the N+1 pattern.
|
||||
|
||||
Returns all techniques with pre-aggregated test and detection
|
||||
rule counts via subqueries.
|
||||
"""
|
||||
test_count_sq = (
|
||||
self._session.query(
|
||||
Test.technique_id,
|
||||
func.count(Test.id).label("test_count"),
|
||||
func.sum(
|
||||
func.cast(Test.state == TestState.validated, self._int_type())
|
||||
).label("validated_count"),
|
||||
)
|
||||
.group_by(Test.technique_id)
|
||||
.subquery()
|
||||
)
|
||||
rule_count_sq = (
|
||||
self._session.query(
|
||||
DetectionRule.mitre_technique_id,
|
||||
func.count(DetectionRule.id).label("rule_count"),
|
||||
)
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
rows = (
|
||||
self._session.query(
|
||||
Technique,
|
||||
func.coalesce(test_count_sq.c.test_count, 0),
|
||||
func.coalesce(test_count_sq.c.validated_count, 0),
|
||||
func.coalesce(rule_count_sq.c.rule_count, 0),
|
||||
)
|
||||
.outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id)
|
||||
.outerjoin(
|
||||
rule_count_sq,
|
||||
Technique.mitre_id == rule_count_sq.c.mitre_technique_id,
|
||||
)
|
||||
.order_by(Technique.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
TechniqueWithCounts(
|
||||
entity=TechniqueMapper.to_entity(tech),
|
||||
test_count=int(tc),
|
||||
validated_test_count=int(vtc),
|
||||
detection_rule_count=int(rc),
|
||||
)
|
||||
for tech, tc, vtc, rc in rows
|
||||
]
|
||||
|
||||
# -- Mutations ---------------------------------------------------------
|
||||
|
||||
def save(self, technique: TechniqueEntity) -> TechniqueEntity:
|
||||
existing = (
|
||||
self._session.query(Technique)
|
||||
.filter(Technique.id == technique.id)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
technique.apply_to(existing)
|
||||
existing.mitre_id = technique.mitre_id
|
||||
existing.name = technique.name
|
||||
existing.tactic = technique.tactic
|
||||
existing.description = technique.description
|
||||
existing.platforms = technique.platforms
|
||||
existing.is_subtechnique = technique.is_subtechnique
|
||||
existing.parent_mitre_id = technique.parent_mitre_id
|
||||
self._session.flush()
|
||||
return TechniqueMapper.to_entity(existing)
|
||||
else:
|
||||
model = Technique(
|
||||
id=technique.id,
|
||||
mitre_id=technique.mitre_id,
|
||||
name=technique.name,
|
||||
tactic=technique.tactic,
|
||||
description=technique.description,
|
||||
platforms=technique.platforms,
|
||||
is_subtechnique=technique.is_subtechnique,
|
||||
parent_mitre_id=technique.parent_mitre_id,
|
||||
status_global=technique.status_global,
|
||||
review_required=technique.review_required,
|
||||
last_review_date=technique.last_review_date,
|
||||
)
|
||||
self._session.add(model)
|
||||
self._session.flush()
|
||||
return TechniqueMapper.to_entity(model)
|
||||
|
||||
def exists_by_mitre_id(self, mitre_id: str) -> bool:
|
||||
return (
|
||||
self._session.query(Technique.id)
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
.first()
|
||||
) is not None
|
||||
|
||||
# -- Internal ----------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _int_type():
|
||||
"""Return an Integer type for CAST expressions (SQLite-compatible)."""
|
||||
from sqlalchemy import Integer
|
||||
return Integer
|
||||
@@ -0,0 +1,86 @@
|
||||
"""SQLAlchemy implementation of TestRepository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.enums import TestState
|
||||
from app.models.test import Test
|
||||
|
||||
|
||||
class SATestRepository:
|
||||
"""Concrete test repository backed by SQLAlchemy."""
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self._session = session
|
||||
|
||||
def find_by_id(self, test_id: uuid.UUID) -> Test | None:
|
||||
return (
|
||||
self._session.query(Test)
|
||||
.filter(Test.id == test_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]:
|
||||
return (
|
||||
self._session.query(Test)
|
||||
.filter(Test.technique_id == technique_id)
|
||||
.order_by(Test.created_at)
|
||||
.all()
|
||||
)
|
||||
|
||||
def list_by_state(self, state: TestState) -> list[Test]:
|
||||
return (
|
||||
self._session.query(Test)
|
||||
.filter(Test.state == state)
|
||||
.all()
|
||||
)
|
||||
|
||||
def count_by_technique_and_state(
|
||||
self,
|
||||
technique_id: uuid.UUID,
|
||||
) -> dict[TestState, int]:
|
||||
rows = (
|
||||
self._session.query(Test.state, func.count(Test.id))
|
||||
.filter(Test.technique_id == technique_id)
|
||||
.group_by(Test.state)
|
||||
.all()
|
||||
)
|
||||
result: dict[TestState, int] = {}
|
||||
for state_val, count in rows:
|
||||
key = (
|
||||
state_val
|
||||
if isinstance(state_val, TestState)
|
||||
else TestState(state_val)
|
||||
)
|
||||
result[key] = count
|
||||
return result
|
||||
|
||||
def get_states_and_results_for_technique(
|
||||
self,
|
||||
technique_id: uuid.UUID,
|
||||
) -> list[tuple[str, str | None]]:
|
||||
"""Return lightweight (state, detection_result) pairs.
|
||||
|
||||
Used by TechniqueEntity.recalculate_status() without loading
|
||||
full Test models.
|
||||
"""
|
||||
rows = (
|
||||
self._session.query(Test.state, Test.detection_result)
|
||||
.filter(Test.technique_id == technique_id)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
(
|
||||
r.state.value if hasattr(r.state, "value") else str(r.state),
|
||||
(
|
||||
r.detection_result.value
|
||||
if hasattr(r.detection_result, "value")
|
||||
else r.detection_result
|
||||
),
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
Reference in New Issue
Block a user