feat(infra): add repository implementations, mappers, FastAPI wiring, and technique indexes

This commit is contained in:
2026-02-18 19:10:50 +01:00
parent 5c55e7c17f
commit 1521005b62
9 changed files with 618 additions and 0 deletions

View File

@@ -0,0 +1,237 @@
"""Integration tests for repository implementations.
Uses the SQLite-based ``db`` fixture from conftest.py.
"""
import uuid
import pytest
from app.domain.entities.technique import TechniqueEntity
from app.domain.enums import TechniqueStatus, TestState
from app.infrastructure.persistence.repositories.sa_technique_repository import (
SATechniqueRepository,
)
from app.infrastructure.persistence.repositories.sa_test_repository import (
SATestRepository,
)
from app.models.technique import Technique
from app.models.test import Test
# ── Helpers ──────────────────────────────────────────────────────────
def _create_technique(db, mitre_id="T1059", name="Test Technique", **kwargs):
tech = Technique(
id=uuid.uuid4(),
mitre_id=mitre_id,
name=name,
tactic=kwargs.get("tactic", "execution"),
status_global=kwargs.get("status_global", TechniqueStatus.not_evaluated),
)
db.add(tech)
db.flush()
return tech
def _create_test(db, technique, state=TestState.draft, **kwargs):
t = Test(
id=uuid.uuid4(),
technique_id=technique.id,
name=kwargs.get("name", "Test 1"),
state=state,
detection_result=kwargs.get("detection_result"),
)
db.add(t)
db.flush()
return t
# ── SATechniqueRepository ───────────────────────────────────────────
class TestSATechniqueRepository:
def test_find_by_id_returns_entity(self, db):
tech = _create_technique(db)
db.commit()
repo = SATechniqueRepository(db)
result = repo.find_by_id(tech.id)
assert result is not None
assert result.mitre_id == "T1059"
assert isinstance(result, TechniqueEntity)
def test_find_by_id_not_found(self, db):
repo = SATechniqueRepository(db)
assert repo.find_by_id(uuid.uuid4()) is None
def test_find_by_mitre_id(self, db):
_create_technique(db, mitre_id="T1548")
db.commit()
repo = SATechniqueRepository(db)
result = repo.find_by_mitre_id("T1548")
assert result is not None
assert result.mitre_id == "T1548"
def test_find_by_mitre_id_not_found(self, db):
repo = SATechniqueRepository(db)
assert repo.find_by_mitre_id("T9999") is None
def test_list_all_no_filters(self, db):
_create_technique(db, mitre_id="T1059", name="A")
_create_technique(db, mitre_id="T1060", name="B")
db.commit()
repo = SATechniqueRepository(db)
results = repo.list_all()
assert len(results) == 2
assert results[0].mitre_id == "T1059"
assert results[1].mitre_id == "T1060"
def test_list_all_filter_by_tactic(self, db):
_create_technique(db, mitre_id="T1059", tactic="execution")
_create_technique(db, mitre_id="T1060", tactic="persistence")
db.commit()
repo = SATechniqueRepository(db)
results = repo.list_all(tactic="execution")
assert len(results) == 1
assert results[0].mitre_id == "T1059"
def test_list_all_filter_by_status(self, db):
_create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
_create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.not_evaluated)
db.commit()
repo = SATechniqueRepository(db)
results = repo.list_all(status=TechniqueStatus.validated)
assert len(results) == 1
assert results[0].mitre_id == "T1059"
def test_list_by_ids(self, db):
t1 = _create_technique(db, mitre_id="T1059")
_create_technique(db, mitre_id="T1060")
db.commit()
repo = SATechniqueRepository(db)
results = repo.list_by_ids([t1.id])
assert len(results) == 1
assert results[0].mitre_id == "T1059"
def test_list_by_ids_empty(self, db):
repo = SATechniqueRepository(db)
assert repo.list_by_ids([]) == []
def test_count_by_status(self, db):
_create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
_create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.validated)
_create_technique(db, mitre_id="T1061", status_global=TechniqueStatus.not_evaluated)
db.commit()
repo = SATechniqueRepository(db)
counts = repo.count_by_status()
assert counts[TechniqueStatus.validated] == 2
assert counts[TechniqueStatus.not_evaluated] == 1
assert counts[TechniqueStatus.partial] == 0
def test_exists_by_mitre_id(self, db):
_create_technique(db, mitre_id="T1059")
db.commit()
repo = SATechniqueRepository(db)
assert repo.exists_by_mitre_id("T1059") is True
assert repo.exists_by_mitre_id("T9999") is False
def test_save_new_technique(self, db):
repo = SATechniqueRepository(db)
entity = TechniqueEntity.create(
mitre_id="T2000",
name="New Technique",
tactic="discovery",
)
saved = repo.save(entity)
db.commit()
assert saved.mitre_id == "T2000"
assert repo.find_by_mitre_id("T2000") is not None
def test_save_updates_existing(self, db):
tech = _create_technique(db, mitre_id="T1059")
db.commit()
repo = SATechniqueRepository(db)
entity = repo.find_by_id(tech.id)
entity.status_global = TechniqueStatus.validated
saved = repo.save(entity)
db.commit()
reloaded = repo.find_by_id(tech.id)
assert reloaded.status_global == TechniqueStatus.validated
# ── SATestRepository ─────────────────────────────────────────────────
class TestSATestRepository:
def test_find_by_id(self, db):
tech = _create_technique(db)
test = _create_test(db, tech)
db.commit()
repo = SATestRepository(db)
result = repo.find_by_id(test.id)
assert result is not None
assert result.id == test.id
def test_find_by_id_not_found(self, db):
repo = SATestRepository(db)
assert repo.find_by_id(uuid.uuid4()) is None
def test_list_by_technique(self, db):
tech = _create_technique(db)
_create_test(db, tech, name="Test A")
_create_test(db, tech, name="Test B")
db.commit()
repo = SATestRepository(db)
results = repo.list_by_technique(tech.id)
assert len(results) == 2
def test_list_by_state(self, db):
tech = _create_technique(db)
_create_test(db, tech, state=TestState.draft)
_create_test(db, tech, state=TestState.validated)
db.commit()
repo = SATestRepository(db)
drafts = repo.list_by_state(TestState.draft)
assert len(drafts) == 1
def test_count_by_technique_and_state(self, db):
tech = _create_technique(db)
_create_test(db, tech, state=TestState.draft)
_create_test(db, tech, state=TestState.draft)
_create_test(db, tech, state=TestState.validated)
db.commit()
repo = SATestRepository(db)
counts = repo.count_by_technique_and_state(tech.id)
assert counts.get(TestState.draft) == 2
assert counts.get(TestState.validated) == 1
def test_get_states_and_results(self, db):
tech = _create_technique(db)
_create_test(db, tech, state=TestState.validated, detection_result="detected")
_create_test(db, tech, state=TestState.draft)
db.commit()
repo = SATestRepository(db)
pairs = repo.get_states_and_results_for_technique(tech.id)
assert len(pairs) == 2
states = [p[0] for p in pairs]
assert "validated" in states
assert "draft" in states