"""Integration tests for repository implementations.

Uses the SQLite-based ``db`` fixture from conftest.py.
"""

import uuid

import pytest

from app.domain.entities.technique import TechniqueEntity
from app.domain.enums import TechniqueStatus

# ``TestState`` and ``Test`` are aliased to names that do NOT start with
# "Test": under pytest's default ``python_classes = Test*`` rule, any class
# bound at module level with that prefix is a collection candidate, and these
# project classes (an enum and an ORM model) would trigger
# PytestCollectionWarning instead of being ignored.
from app.domain.enums import TestState as TState
from app.infrastructure.persistence.repositories.sa_technique_repository import (
    SATechniqueRepository,
)
from app.infrastructure.persistence.repositories.sa_test_repository import (
    SATestRepository,
)
from app.models.technique import Technique
from app.models.test import Test as OrmTest


# ── Helpers ──────────────────────────────────────────────────────────


def _create_technique(db, mitre_id="T1059", name="Test Technique", **kwargs):
    """Add a ``Technique`` row to the session, flush it, and return it.

    Optional ``kwargs``: ``tactic`` (default ``"execution"``) and
    ``status_global`` (default ``TechniqueStatus.not_evaluated``).
    The row is flushed but not committed; callers commit when ready.
    """
    tech = Technique(
        id=uuid.uuid4(),
        mitre_id=mitre_id,
        name=name,
        tactic=kwargs.get("tactic", "execution"),
        status_global=kwargs.get("status_global", TechniqueStatus.not_evaluated),
    )
    db.add(tech)
    db.flush()
    return tech


def _create_test(db, technique, state=TState.draft, **kwargs):
    """Add a ``Test`` row linked to ``technique``, flush it, and return it.

    Optional ``kwargs``: ``name`` (default ``"Test 1"``) and
    ``detection_result`` (default ``None``).
    The row is flushed but not committed; callers commit when ready.
    """
    t = OrmTest(
        id=uuid.uuid4(),
        technique_id=technique.id,
        name=kwargs.get("name", "Test 1"),
        state=state,
        detection_result=kwargs.get("detection_result"),
    )
    db.add(t)
    db.flush()
    return t


# ── SATechniqueRepository ───────────────────────────────────────────


class TestSATechniqueRepository:
    """Integration tests for ``SATechniqueRepository`` against the ``db`` session."""

    def test_find_by_id_returns_entity(self, db):
        tech = _create_technique(db)
        db.commit()

        repo = SATechniqueRepository(db)
        result = repo.find_by_id(tech.id)

        assert result is not None
        assert result.mitre_id == "T1059"
        assert isinstance(result, TechniqueEntity)

    def test_find_by_id_not_found(self, db):
        repo = SATechniqueRepository(db)
        assert repo.find_by_id(uuid.uuid4()) is None

    def test_find_by_mitre_id(self, db):
        _create_technique(db, mitre_id="T1548")
        db.commit()

        repo = SATechniqueRepository(db)
        result = repo.find_by_mitre_id("T1548")

        assert result is not None
        assert result.mitre_id == "T1548"

    def test_find_by_mitre_id_not_found(self, db):
        repo = SATechniqueRepository(db)
        assert repo.find_by_mitre_id("T9999") is None

    def test_list_all_no_filters(self, db):
        _create_technique(db, mitre_id="T1059", name="A")
        _create_technique(db, mitre_id="T1060", name="B")
        db.commit()

        repo = SATechniqueRepository(db)
        results = repo.list_all()

        assert len(results) == 2
        assert results[0].mitre_id == "T1059"
        assert results[1].mitre_id == "T1060"

    def test_list_all_filter_by_tactic(self, db):
        _create_technique(db, mitre_id="T1059", tactic="execution")
        _create_technique(db, mitre_id="T1060", tactic="persistence")
        db.commit()

        repo = SATechniqueRepository(db)
        results = repo.list_all(tactic="execution")

        assert len(results) == 1
        assert results[0].mitre_id == "T1059"

    def test_list_all_filter_by_status(self, db):
        _create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
        _create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.not_evaluated)
        db.commit()

        repo = SATechniqueRepository(db)
        results = repo.list_all(status=TechniqueStatus.validated)

        assert len(results) == 1
        assert results[0].mitre_id == "T1059"

    def test_list_by_ids(self, db):
        t1 = _create_technique(db, mitre_id="T1059")
        _create_technique(db, mitre_id="T1060")
        db.commit()

        repo = SATechniqueRepository(db)
        results = repo.list_by_ids([t1.id])

        assert len(results) == 1
        assert results[0].mitre_id == "T1059"

    def test_list_by_ids_empty(self, db):
        repo = SATechniqueRepository(db)
        assert repo.list_by_ids([]) == []

    def test_count_by_status(self, db):
        _create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
        _create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.validated)
        _create_technique(db, mitre_id="T1061", status_global=TechniqueStatus.not_evaluated)
        db.commit()

        repo = SATechniqueRepository(db)
        counts = repo.count_by_status()

        assert counts[TechniqueStatus.validated] == 2
        assert counts[TechniqueStatus.not_evaluated] == 1
        # Statuses with no rows are expected to be reported as zero.
        assert counts[TechniqueStatus.partial] == 0

    def test_exists_by_mitre_id(self, db):
        _create_technique(db, mitre_id="T1059")
        db.commit()

        repo = SATechniqueRepository(db)

        assert repo.exists_by_mitre_id("T1059") is True
        assert repo.exists_by_mitre_id("T9999") is False

    def test_save_new_technique(self, db):
        repo = SATechniqueRepository(db)
        entity = TechniqueEntity.create(
            mitre_id="T2000",
            name="New Technique",
            tactic="discovery",
        )

        saved = repo.save(entity)
        db.commit()

        assert saved.mitre_id == "T2000"
        assert repo.find_by_mitre_id("T2000") is not None

    def test_save_updates_existing(self, db):
        tech = _create_technique(db, mitre_id="T1059")
        db.commit()

        repo = SATechniqueRepository(db)
        entity = repo.find_by_id(tech.id)
        entity.status_global = TechniqueStatus.validated

        saved = repo.save(entity)
        db.commit()

        # Both the returned entity and a fresh lookup must reflect the update.
        assert saved.status_global == TechniqueStatus.validated
        reloaded = repo.find_by_id(tech.id)
        assert reloaded.status_global == TechniqueStatus.validated


# ── SATestRepository ─────────────────────────────────────────────────


class TestSATestRepository:
    """Integration tests for ``SATestRepository`` against the ``db`` session."""

    def test_find_by_id(self, db):
        tech = _create_technique(db)
        created = _create_test(db, tech)
        db.commit()

        repo = SATestRepository(db)
        result = repo.find_by_id(created.id)

        assert result is not None
        assert result.id == created.id

    def test_find_by_id_not_found(self, db):
        repo = SATestRepository(db)
        assert repo.find_by_id(uuid.uuid4()) is None

    def test_list_by_technique(self, db):
        tech = _create_technique(db)
        other = _create_technique(db, mitre_id="T1060")
        _create_test(db, tech, name="Test A")
        _create_test(db, tech, name="Test B")
        # A test under another technique must be excluded; without it this
        # test could not distinguish filtering from "return everything".
        _create_test(db, other, name="Test C")
        db.commit()

        repo = SATestRepository(db)
        results = repo.list_by_technique(tech.id)

        assert len(results) == 2

    def test_list_by_state(self, db):
        tech = _create_technique(db)
        _create_test(db, tech, state=TState.draft)
        _create_test(db, tech, state=TState.validated)
        db.commit()

        repo = SATestRepository(db)
        drafts = repo.list_by_state(TState.draft)

        assert len(drafts) == 1

    def test_count_by_technique_and_state(self, db):
        tech = _create_technique(db)
        other = _create_technique(db, mitre_id="T1060")
        _create_test(db, tech, state=TState.draft, name="Draft A")
        _create_test(db, tech, state=TState.draft, name="Draft B")
        _create_test(db, tech, state=TState.validated, name="Validated")
        # Tests under a different technique must not be counted.
        _create_test(db, other, state=TState.draft, name="Other draft")
        db.commit()

        repo = SATestRepository(db)
        counts = repo.count_by_technique_and_state(tech.id)

        assert counts.get(TState.draft) == 2
        assert counts.get(TState.validated) == 1

    def test_get_states_and_results(self, db):
        tech = _create_technique(db)
        other = _create_technique(db, mitre_id="T1060")
        _create_test(db, tech, state=TState.validated, detection_result="detected")
        _create_test(db, tech, state=TState.draft)
        # Must not appear in the results for ``tech``.
        _create_test(db, other, state=TState.draft)
        db.commit()

        repo = SATestRepository(db)
        pairs = repo.get_states_and_results_for_technique(tech.id)

        assert len(pairs) == 2
        states = [p[0] for p in pairs]
        assert "validated" in states
        assert "draft" in states