feat(infra): add repository implementations, mappers, FastAPI wiring, and technique indexes
This commit is contained in:
237
backend/tests/test_repositories.py
Normal file
237
backend/tests/test_repositories.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Integration tests for repository implementations.
|
||||
|
||||
Uses the SQLite-based ``db`` fixture from conftest.py.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
from app.domain.enums import TechniqueStatus, TestState
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||
SATestRepository,
|
||||
)
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _create_technique(db, mitre_id="T1059", name="Test Technique", **kwargs):
|
||||
tech = Technique(
|
||||
id=uuid.uuid4(),
|
||||
mitre_id=mitre_id,
|
||||
name=name,
|
||||
tactic=kwargs.get("tactic", "execution"),
|
||||
status_global=kwargs.get("status_global", TechniqueStatus.not_evaluated),
|
||||
)
|
||||
db.add(tech)
|
||||
db.flush()
|
||||
return tech
|
||||
|
||||
|
||||
def _create_test(db, technique, state=TestState.draft, **kwargs):
|
||||
t = Test(
|
||||
id=uuid.uuid4(),
|
||||
technique_id=technique.id,
|
||||
name=kwargs.get("name", "Test 1"),
|
||||
state=state,
|
||||
detection_result=kwargs.get("detection_result"),
|
||||
)
|
||||
db.add(t)
|
||||
db.flush()
|
||||
return t
|
||||
|
||||
|
||||
# ── SATechniqueRepository ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSATechniqueRepository:
|
||||
|
||||
def test_find_by_id_returns_entity(self, db):
|
||||
tech = _create_technique(db)
|
||||
db.commit()
|
||||
repo = SATechniqueRepository(db)
|
||||
|
||||
result = repo.find_by_id(tech.id)
|
||||
assert result is not None
|
||||
assert result.mitre_id == "T1059"
|
||||
assert isinstance(result, TechniqueEntity)
|
||||
|
||||
def test_find_by_id_not_found(self, db):
|
||||
repo = SATechniqueRepository(db)
|
||||
assert repo.find_by_id(uuid.uuid4()) is None
|
||||
|
||||
def test_find_by_mitre_id(self, db):
|
||||
_create_technique(db, mitre_id="T1548")
|
||||
db.commit()
|
||||
repo = SATechniqueRepository(db)
|
||||
|
||||
result = repo.find_by_mitre_id("T1548")
|
||||
assert result is not None
|
||||
assert result.mitre_id == "T1548"
|
||||
|
||||
def test_find_by_mitre_id_not_found(self, db):
|
||||
repo = SATechniqueRepository(db)
|
||||
assert repo.find_by_mitre_id("T9999") is None
|
||||
|
||||
def test_list_all_no_filters(self, db):
|
||||
_create_technique(db, mitre_id="T1059", name="A")
|
||||
_create_technique(db, mitre_id="T1060", name="B")
|
||||
db.commit()
|
||||
repo = SATechniqueRepository(db)
|
||||
|
||||
results = repo.list_all()
|
||||
assert len(results) == 2
|
||||
assert results[0].mitre_id == "T1059"
|
||||
assert results[1].mitre_id == "T1060"
|
||||
|
||||
def test_list_all_filter_by_tactic(self, db):
|
||||
_create_technique(db, mitre_id="T1059", tactic="execution")
|
||||
_create_technique(db, mitre_id="T1060", tactic="persistence")
|
||||
db.commit()
|
||||
repo = SATechniqueRepository(db)
|
||||
|
||||
results = repo.list_all(tactic="execution")
|
||||
assert len(results) == 1
|
||||
assert results[0].mitre_id == "T1059"
|
||||
|
||||
def test_list_all_filter_by_status(self, db):
|
||||
_create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
|
||||
_create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.not_evaluated)
|
||||
db.commit()
|
||||
repo = SATechniqueRepository(db)
|
||||
|
||||
results = repo.list_all(status=TechniqueStatus.validated)
|
||||
assert len(results) == 1
|
||||
assert results[0].mitre_id == "T1059"
|
||||
|
||||
def test_list_by_ids(self, db):
|
||||
t1 = _create_technique(db, mitre_id="T1059")
|
||||
_create_technique(db, mitre_id="T1060")
|
||||
db.commit()
|
||||
repo = SATechniqueRepository(db)
|
||||
|
||||
results = repo.list_by_ids([t1.id])
|
||||
assert len(results) == 1
|
||||
assert results[0].mitre_id == "T1059"
|
||||
|
||||
def test_list_by_ids_empty(self, db):
|
||||
repo = SATechniqueRepository(db)
|
||||
assert repo.list_by_ids([]) == []
|
||||
|
||||
def test_count_by_status(self, db):
|
||||
_create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
|
||||
_create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.validated)
|
||||
_create_technique(db, mitre_id="T1061", status_global=TechniqueStatus.not_evaluated)
|
||||
db.commit()
|
||||
repo = SATechniqueRepository(db)
|
||||
|
||||
counts = repo.count_by_status()
|
||||
assert counts[TechniqueStatus.validated] == 2
|
||||
assert counts[TechniqueStatus.not_evaluated] == 1
|
||||
assert counts[TechniqueStatus.partial] == 0
|
||||
|
||||
def test_exists_by_mitre_id(self, db):
|
||||
_create_technique(db, mitre_id="T1059")
|
||||
db.commit()
|
||||
repo = SATechniqueRepository(db)
|
||||
|
||||
assert repo.exists_by_mitre_id("T1059") is True
|
||||
assert repo.exists_by_mitre_id("T9999") is False
|
||||
|
||||
def test_save_new_technique(self, db):
|
||||
repo = SATechniqueRepository(db)
|
||||
entity = TechniqueEntity.create(
|
||||
mitre_id="T2000",
|
||||
name="New Technique",
|
||||
tactic="discovery",
|
||||
)
|
||||
saved = repo.save(entity)
|
||||
db.commit()
|
||||
|
||||
assert saved.mitre_id == "T2000"
|
||||
assert repo.find_by_mitre_id("T2000") is not None
|
||||
|
||||
def test_save_updates_existing(self, db):
|
||||
tech = _create_technique(db, mitre_id="T1059")
|
||||
db.commit()
|
||||
repo = SATechniqueRepository(db)
|
||||
|
||||
entity = repo.find_by_id(tech.id)
|
||||
entity.status_global = TechniqueStatus.validated
|
||||
saved = repo.save(entity)
|
||||
db.commit()
|
||||
|
||||
reloaded = repo.find_by_id(tech.id)
|
||||
assert reloaded.status_global == TechniqueStatus.validated
|
||||
|
||||
|
||||
# ── SATestRepository ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSATestRepository:
|
||||
|
||||
def test_find_by_id(self, db):
|
||||
tech = _create_technique(db)
|
||||
test = _create_test(db, tech)
|
||||
db.commit()
|
||||
repo = SATestRepository(db)
|
||||
|
||||
result = repo.find_by_id(test.id)
|
||||
assert result is not None
|
||||
assert result.id == test.id
|
||||
|
||||
def test_find_by_id_not_found(self, db):
|
||||
repo = SATestRepository(db)
|
||||
assert repo.find_by_id(uuid.uuid4()) is None
|
||||
|
||||
def test_list_by_technique(self, db):
|
||||
tech = _create_technique(db)
|
||||
_create_test(db, tech, name="Test A")
|
||||
_create_test(db, tech, name="Test B")
|
||||
db.commit()
|
||||
repo = SATestRepository(db)
|
||||
|
||||
results = repo.list_by_technique(tech.id)
|
||||
assert len(results) == 2
|
||||
|
||||
def test_list_by_state(self, db):
|
||||
tech = _create_technique(db)
|
||||
_create_test(db, tech, state=TestState.draft)
|
||||
_create_test(db, tech, state=TestState.validated)
|
||||
db.commit()
|
||||
repo = SATestRepository(db)
|
||||
|
||||
drafts = repo.list_by_state(TestState.draft)
|
||||
assert len(drafts) == 1
|
||||
|
||||
def test_count_by_technique_and_state(self, db):
|
||||
tech = _create_technique(db)
|
||||
_create_test(db, tech, state=TestState.draft)
|
||||
_create_test(db, tech, state=TestState.draft)
|
||||
_create_test(db, tech, state=TestState.validated)
|
||||
db.commit()
|
||||
repo = SATestRepository(db)
|
||||
|
||||
counts = repo.count_by_technique_and_state(tech.id)
|
||||
assert counts.get(TestState.draft) == 2
|
||||
assert counts.get(TestState.validated) == 1
|
||||
|
||||
def test_get_states_and_results(self, db):
|
||||
tech = _create_technique(db)
|
||||
_create_test(db, tech, state=TestState.validated, detection_result="detected")
|
||||
_create_test(db, tech, state=TestState.draft)
|
||||
db.commit()
|
||||
repo = SATestRepository(db)
|
||||
|
||||
pairs = repo.get_states_and_results_for_technique(tech.id)
|
||||
assert len(pairs) == 2
|
||||
states = [p[0] for p in pairs]
|
||||
assert "validated" in states
|
||||
assert "draft" in states
|
||||
Reference in New Issue
Block a user