diff --git a/backend/alembic/versions/b026_add_technique_query_indexes.py b/backend/alembic/versions/b026_add_technique_query_indexes.py new file mode 100644 index 0000000..e62dc89 --- /dev/null +++ b/backend/alembic/versions/b026_add_technique_query_indexes.py @@ -0,0 +1,38 @@ +"""add_technique_query_indexes + +Add indexes on techniques table for common query patterns +(filter by tactic, filter by status_global) used in heatmap, scoring, +and list-all-techniques operations. + +These may already exist if the ORM model auto-created them; the +``if_not_exists`` flag makes this migration safe to run regardless. + +Revision ID: b026techidx +Revises: b025uqtdr +Create Date: 2026-02-18 18:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op + +revision: str = "b026techidx" +down_revision: Union[str, None] = "b025uqtdr" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute( + "CREATE INDEX IF NOT EXISTS ix_techniques_tactic " + "ON techniques (tactic)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_techniques_status_global " + "ON techniques (status_global)" + ) + + +def downgrade() -> None: + op.drop_index("ix_techniques_status_global", table_name="techniques") + op.drop_index("ix_techniques_tactic", table_name="techniques") diff --git a/backend/app/dependencies/repositories.py b/backend/app/dependencies/repositories.py new file mode 100644 index 0000000..eae5d62 --- /dev/null +++ b/backend/app/dependencies/repositories.py @@ -0,0 +1,30 @@ +"""FastAPI dependency providers for repositories. + +Wiring lives ONLY in the presentation layer — use cases and services +never know which concrete repository implementation they receive. +""" + +from fastapi import Depends +from sqlalchemy.orm import Session + +from app.database import get_db +from app.infrastructure.persistence.repositories.sa_technique_repository import ( + SATechniqueRepository, +) +from app.infrastructure.persistence.repositories.sa_test_repository import ( + SATestRepository, +) + + +def get_technique_repository( + db: Session = Depends(get_db), +) -> SATechniqueRepository: + """Provide a TechniqueRepository backed by the current DB session.""" + return SATechniqueRepository(db) + + +def get_test_repository( + db: Session = Depends(get_db), +) -> SATestRepository: + """Provide a TestRepository backed by the current DB session.""" + return SATestRepository(db) diff --git a/backend/app/infrastructure/persistence/__init__.py b/backend/app/infrastructure/persistence/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/infrastructure/persistence/mappers/__init__.py b/backend/app/infrastructure/persistence/mappers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/infrastructure/persistence/mappers/technique_mapper.py b/backend/app/infrastructure/persistence/mappers/technique_mapper.py new file mode 100644 index 0000000..74cd588 --- /dev/null +++ b/backend/app/infrastructure/persistence/mappers/technique_mapper.py @@ -0,0 +1,20 @@ +"""Technique ORM model <-> domain entity mapper.""" + +from __future__ import annotations + +from app.domain.entities.technique import TechniqueEntity +from app.domain.enums import TechniqueStatus + + +class TechniqueMapper: + """Converts between SQLAlchemy Technique model and TechniqueEntity.""" + + @staticmethod + def to_entity(model: object) -> TechniqueEntity: + """Convert an ORM Technique model to a domain TechniqueEntity.""" + return TechniqueEntity.from_orm(model) + + @staticmethod + def to_model_updates(entity: TechniqueEntity, model: object) -> None: + """Apply entity changes back onto an existing ORM model.""" + entity.apply_to(model) diff --git a/backend/app/infrastructure/persistence/repositories/__init__.py b/backend/app/infrastructure/persistence/repositories/__init__.py new file mode 100644 index 0000000..d6c0338 --- /dev/null +++ b/backend/app/infrastructure/persistence/repositories/__init__.py @@ -0,0 +1,8 @@ +from app.infrastructure.persistence.repositories.sa_technique_repository import ( + SATechniqueRepository, +) +from app.infrastructure.persistence.repositories.sa_test_repository import ( + SATestRepository, +) + +__all__ = ["SATechniqueRepository", "SATestRepository"] diff --git a/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py b/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py new file mode 100644 index 0000000..84fb08e --- /dev/null +++ b/backend/app/infrastructure/persistence/repositories/sa_technique_repository.py @@ -0,0 +1,199 @@ +"""SQLAlchemy implementation of TechniqueRepository. + +Receives a Session from the caller — does NOT create its own. +Does NOT call commit() — the Unit of Work owns that. +""" + +from __future__ import annotations + +import uuid + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.domain.entities.technique import TechniqueEntity +from app.domain.enums import TechniqueStatus, TestState +from app.domain.ports.repositories.technique_repository import TechniqueWithCounts +from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper +from app.models.detection_rule import DetectionRule +from app.models.technique import Technique +from app.models.test import Test + + +class SATechniqueRepository: + """Concrete repository backed by SQLAlchemy.""" + + def __init__(self, session: Session) -> None: + self._session = session + + # -- Single-entity access ---------------------------------------------- + + def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None: + model = ( + self._session.query(Technique) + .filter(Technique.id == technique_id) + .first() + ) + return TechniqueMapper.to_entity(model) if model else None + + def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None: + model = ( + self._session.query(Technique) + .filter(Technique.mitre_id == mitre_id) + .first() + ) + return TechniqueMapper.to_entity(model) if model else None + + # -- List access ------------------------------------------------------- + + def list_all( + self, + *, + tactic: str | None = None, + status: TechniqueStatus | None = None, + review_required: bool | None = None, + ) -> list[TechniqueEntity]: + query = self._session.query(Technique) + if tactic is not None: + query = query.filter(Technique.tactic == tactic) + if status is not None: + query = query.filter(Technique.status_global == status) + if review_required is not None: + query = query.filter(Technique.review_required == review_required) + models = query.order_by(Technique.mitre_id).all() + return [TechniqueMapper.to_entity(m) for m in models] + + def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]: + if not ids: + return [] + models = ( + self._session.query(Technique) + .filter(Technique.id.in_(ids)) + .all() + ) + return [TechniqueMapper.to_entity(m) for m in models] + + # -- Batch queries (for scoring/heatmap) ------------------------------- + + def count_by_status(self) -> dict[TechniqueStatus, int]: + rows = ( + self._session.query( + Technique.status_global, + func.count(Technique.id), + ) + .group_by(Technique.status_global) + .all() + ) + result = {s: 0 for s in TechniqueStatus} + for status_val, count in rows: + key = ( + status_val + if isinstance(status_val, TechniqueStatus) + else TechniqueStatus(status_val) + ) + result[key] = count + return result + + def find_all_with_test_counts(self) -> list[TechniqueWithCounts]: + """Single query replacing the N+1 pattern. + + Returns all techniques with pre-aggregated test and detection + rule counts via subqueries. + """ + test_count_sq = ( + self._session.query( + Test.technique_id, + func.count(Test.id).label("test_count"), + func.sum( + func.cast(Test.state == TestState.validated, self._int_type()) + ).label("validated_count"), + ) + .group_by(Test.technique_id) + .subquery() + ) + rule_count_sq = ( + self._session.query( + DetectionRule.mitre_technique_id, + func.count(DetectionRule.id).label("rule_count"), + ) + .group_by(DetectionRule.mitre_technique_id) + .subquery() + ) + + rows = ( + self._session.query( + Technique, + func.coalesce(test_count_sq.c.test_count, 0), + func.coalesce(test_count_sq.c.validated_count, 0), + func.coalesce(rule_count_sq.c.rule_count, 0), + ) + .outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id) + .outerjoin( + rule_count_sq, + Technique.mitre_id == rule_count_sq.c.mitre_technique_id, + ) + .order_by(Technique.mitre_id) + .all() + ) + + return [ + TechniqueWithCounts( + entity=TechniqueMapper.to_entity(tech), + test_count=int(tc), + validated_test_count=int(vtc), + detection_rule_count=int(rc), + ) + for tech, tc, vtc, rc in rows + ] + + # -- Mutations --------------------------------------------------------- + + def save(self, technique: TechniqueEntity) -> TechniqueEntity: + existing = ( + self._session.query(Technique) + .filter(Technique.id == technique.id) + .first() + ) + if existing: + technique.apply_to(existing) + existing.mitre_id = technique.mitre_id + existing.name = technique.name + existing.tactic = technique.tactic + existing.description = technique.description + existing.platforms = technique.platforms + existing.is_subtechnique = technique.is_subtechnique + existing.parent_mitre_id = technique.parent_mitre_id + self._session.flush() + return TechniqueMapper.to_entity(existing) + else: + model = Technique( + id=technique.id, + mitre_id=technique.mitre_id, + name=technique.name, + tactic=technique.tactic, + description=technique.description, + platforms=technique.platforms, + is_subtechnique=technique.is_subtechnique, + parent_mitre_id=technique.parent_mitre_id, + status_global=technique.status_global, + review_required=technique.review_required, + last_review_date=technique.last_review_date, + ) + self._session.add(model) + self._session.flush() + return TechniqueMapper.to_entity(model) + + def exists_by_mitre_id(self, mitre_id: str) -> bool: + return ( + self._session.query(Technique.id) + .filter(Technique.mitre_id == mitre_id) + .first() + ) is not None + + # -- Internal ---------------------------------------------------------- + + @staticmethod + def _int_type(): + """Return an Integer type for CAST expressions (SQLite-compatible).""" + from sqlalchemy import Integer + return Integer diff --git a/backend/app/infrastructure/persistence/repositories/sa_test_repository.py b/backend/app/infrastructure/persistence/repositories/sa_test_repository.py new file mode 100644 index 0000000..0a893f8 --- /dev/null +++ b/backend/app/infrastructure/persistence/repositories/sa_test_repository.py @@ -0,0 +1,86 @@ +"""SQLAlchemy implementation of TestRepository.""" + +from __future__ import annotations + +import uuid + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.domain.enums import TestState +from app.models.test import Test + + +class SATestRepository: + """Concrete test repository backed by SQLAlchemy.""" + + def __init__(self, session: Session) -> None: + self._session = session + + def find_by_id(self, test_id: uuid.UUID) -> Test | None: + return ( + self._session.query(Test) + .filter(Test.id == test_id) + .first() + ) + + def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]: + return ( + self._session.query(Test) + .filter(Test.technique_id == technique_id) + .order_by(Test.created_at) + .all() + ) + + def list_by_state(self, state: TestState) -> list[Test]: + return ( + self._session.query(Test) + .filter(Test.state == state) + .all() + ) + + def count_by_technique_and_state( + self, + technique_id: uuid.UUID, + ) -> dict[TestState, int]: + rows = ( + self._session.query(Test.state, func.count(Test.id)) + .filter(Test.technique_id == technique_id) + .group_by(Test.state) + .all() + ) + result: dict[TestState, int] = {} + for state_val, count in rows: + key = ( + state_val + if isinstance(state_val, TestState) + else TestState(state_val) + ) + result[key] = count + return result + + def get_states_and_results_for_technique( + self, + technique_id: uuid.UUID, + ) -> list[tuple[str, str | None]]: + """Return lightweight (state, detection_result) pairs. + + Used by TechniqueEntity.recalculate_status() without loading + full Test models. + """ + rows = ( + self._session.query(Test.state, Test.detection_result) + .filter(Test.technique_id == technique_id) + .all() + ) + return [ + ( + r.state.value if hasattr(r.state, "value") else str(r.state), + ( + r.detection_result.value + if hasattr(r.detection_result, "value") + else r.detection_result + ), + ) + for r in rows + ] diff --git a/backend/tests/test_repositories.py b/backend/tests/test_repositories.py new file mode 100644 index 0000000..f42cacf --- /dev/null +++ b/backend/tests/test_repositories.py @@ -0,0 +1,237 @@ +"""Integration tests for repository implementations. + +Uses the SQLite-based ``db`` fixture from conftest.py. +""" + +import uuid + +import pytest + +from app.domain.entities.technique import TechniqueEntity +from app.domain.enums import TechniqueStatus, TestState +from app.infrastructure.persistence.repositories.sa_technique_repository import ( + SATechniqueRepository, +) +from app.infrastructure.persistence.repositories.sa_test_repository import ( + SATestRepository, +) +from app.models.technique import Technique +from app.models.test import Test + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _create_technique(db, mitre_id="T1059", name="Test Technique", **kwargs): + tech = Technique( + id=uuid.uuid4(), + mitre_id=mitre_id, + name=name, + tactic=kwargs.get("tactic", "execution"), + status_global=kwargs.get("status_global", TechniqueStatus.not_evaluated), + ) + db.add(tech) + db.flush() + return tech + + +def _create_test(db, technique, state=TestState.draft, **kwargs): + t = Test( + id=uuid.uuid4(), + technique_id=technique.id, + name=kwargs.get("name", "Test 1"), + state=state, + detection_result=kwargs.get("detection_result"), + ) + db.add(t) + db.flush() + return t + + +# ── SATechniqueRepository ─────────────────────────────────────────── + + +class TestSATechniqueRepository: + + def test_find_by_id_returns_entity(self, db): + tech = _create_technique(db) + db.commit() + repo = SATechniqueRepository(db) + + result = repo.find_by_id(tech.id) + assert result is not None + assert result.mitre_id == "T1059" + assert isinstance(result, TechniqueEntity) + + def test_find_by_id_not_found(self, db): + repo = SATechniqueRepository(db) + assert repo.find_by_id(uuid.uuid4()) is None + + def test_find_by_mitre_id(self, db): + _create_technique(db, mitre_id="T1548") + db.commit() + repo = SATechniqueRepository(db) + + result = repo.find_by_mitre_id("T1548") + assert result is not None + assert result.mitre_id == "T1548" + + def test_find_by_mitre_id_not_found(self, db): + repo = SATechniqueRepository(db) + assert repo.find_by_mitre_id("T9999") is None + + def test_list_all_no_filters(self, db): + _create_technique(db, mitre_id="T1059", name="A") + _create_technique(db, mitre_id="T1060", name="B") + db.commit() + repo = SATechniqueRepository(db) + + results = repo.list_all() + assert len(results) == 2 + assert results[0].mitre_id == "T1059" + assert results[1].mitre_id == "T1060" + + def test_list_all_filter_by_tactic(self, db): + _create_technique(db, mitre_id="T1059", tactic="execution") + _create_technique(db, mitre_id="T1060", tactic="persistence") + db.commit() + repo = SATechniqueRepository(db) + + results = repo.list_all(tactic="execution") + assert len(results) == 1 + assert results[0].mitre_id == "T1059" + + def test_list_all_filter_by_status(self, db): + _create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated) + _create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.not_evaluated) + db.commit() + repo = SATechniqueRepository(db) + + results = repo.list_all(status=TechniqueStatus.validated) + assert len(results) == 1 + assert results[0].mitre_id == "T1059" + + def test_list_by_ids(self, db): + t1 = _create_technique(db, mitre_id="T1059") + _create_technique(db, mitre_id="T1060") + db.commit() + repo = SATechniqueRepository(db) + + results = repo.list_by_ids([t1.id]) + assert len(results) == 1 + assert results[0].mitre_id == "T1059" + + def test_list_by_ids_empty(self, db): + repo = SATechniqueRepository(db) + assert repo.list_by_ids([]) == [] + + def test_count_by_status(self, db): + _create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated) + _create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.validated) + _create_technique(db, mitre_id="T1061", status_global=TechniqueStatus.not_evaluated) + db.commit() + repo = SATechniqueRepository(db) + + counts = repo.count_by_status() + assert counts[TechniqueStatus.validated] == 2 + assert counts[TechniqueStatus.not_evaluated] == 1 + assert counts[TechniqueStatus.partial] == 0 + + def test_exists_by_mitre_id(self, db): + _create_technique(db, mitre_id="T1059") + db.commit() + repo = SATechniqueRepository(db) + + assert repo.exists_by_mitre_id("T1059") is True + assert repo.exists_by_mitre_id("T9999") is False + + def test_save_new_technique(self, db): + repo = SATechniqueRepository(db) + entity = TechniqueEntity.create( + mitre_id="T2000", + name="New Technique", + tactic="discovery", + ) + saved = repo.save(entity) + db.commit() + + assert saved.mitre_id == "T2000" + assert repo.find_by_mitre_id("T2000") is not None + + def test_save_updates_existing(self, db): + tech = _create_technique(db, mitre_id="T1059") + db.commit() + repo = SATechniqueRepository(db) + + entity = repo.find_by_id(tech.id) + entity.status_global = TechniqueStatus.validated + saved = repo.save(entity) + db.commit() + + reloaded = repo.find_by_id(tech.id) + assert reloaded.status_global == TechniqueStatus.validated + + +# ── SATestRepository ───────────────────────────────────────────────── + + +class TestSATestRepository: + + def test_find_by_id(self, db): + tech = _create_technique(db) + test = _create_test(db, tech) + db.commit() + repo = SATestRepository(db) + + result = repo.find_by_id(test.id) + assert result is not None + assert result.id == test.id + + def test_find_by_id_not_found(self, db): + repo = SATestRepository(db) + assert repo.find_by_id(uuid.uuid4()) is None + + def test_list_by_technique(self, db): + tech = _create_technique(db) + _create_test(db, tech, name="Test A") + _create_test(db, tech, name="Test B") + db.commit() + repo = SATestRepository(db) + + results = repo.list_by_technique(tech.id) + assert len(results) == 2 + + def test_list_by_state(self, db): + tech = _create_technique(db) + _create_test(db, tech, state=TestState.draft) + _create_test(db, tech, state=TestState.validated) + db.commit() + repo = SATestRepository(db) + + drafts = repo.list_by_state(TestState.draft) + assert len(drafts) == 1 + + def test_count_by_technique_and_state(self, db): + tech = _create_technique(db) + _create_test(db, tech, state=TestState.draft) + _create_test(db, tech, state=TestState.draft) + _create_test(db, tech, state=TestState.validated) + db.commit() + repo = SATestRepository(db) + + counts = repo.count_by_technique_and_state(tech.id) + assert counts.get(TestState.draft) == 2 + assert counts.get(TestState.validated) == 1 + + def test_get_states_and_results(self, db): + tech = _create_technique(db) + _create_test(db, tech, state=TestState.validated, detection_result="detected") + _create_test(db, tech, state=TestState.draft) + db.commit() + repo = SATestRepository(db) + + pairs = repo.get_states_and_results_for_technique(tech.id) + assert len(pairs) == 2 + states = [p[0] for p in pairs] + assert "validated" in states + assert "draft" in states