238 lines
7.8 KiB
Python
238 lines
7.8 KiB
Python
"""Integration tests for repository implementations.
|
|
|
|
Uses the SQLite-based ``db`` fixture from conftest.py.
|
|
"""
|
|
|
|
import uuid
|
|
|
|
import pytest
|
|
|
|
from app.domain.entities.technique import TechniqueEntity
|
|
from app.domain.enums import TechniqueStatus, TestState
|
|
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
|
SATechniqueRepository,
|
|
)
|
|
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
|
SATestRepository,
|
|
)
|
|
from app.models.technique import Technique
|
|
from app.models.test import Test
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────
|
|
|
|
|
|
def _create_technique(
    db,
    mitre_id="T1059",
    name="Test Technique",
    *,
    tactic="execution",
    status_global=TechniqueStatus.not_evaluated,
    **extra_fields,
):
    """Insert a ``Technique`` row and flush it so its primary key is usable.

    Only ``db.add`` and ``db.flush`` are called; committing is left to the
    caller.

    Args:
        db: Session provided by the ``db`` fixture.
        mitre_id: Value for ``Technique.mitre_id``.
        name: Value for ``Technique.name``.
        tactic: Value for ``Technique.tactic``.
        status_global: Value for ``Technique.status_global``.
        **extra_fields: Any other column values, forwarded verbatim to the
            ``Technique`` constructor instead of being silently ignored, so
            a misspelled keyword is not lost without notice. An ``id`` entry
            overrides the randomly generated primary key.

    Returns:
        The flushed (not yet committed) ``Technique`` instance.
    """
    # A random UUID is the default; a caller-supplied ``id`` wins in the merge.
    columns = {"id": uuid.uuid4(), **extra_fields}
    tech = Technique(
        mitre_id=mitre_id,
        name=name,
        tactic=tactic,
        status_global=status_global,
        **columns,
    )
    db.add(tech)
    # Flush so the row is visible to queries in the same session before commit.
    db.flush()
    return tech
|
|
|
|
|
|
def _create_test(
    db,
    technique,
    state=TestState.draft,
    *,
    name="Test 1",
    detection_result=None,
    **extra_fields,
):
    """Insert a ``Test`` row attached to *technique* and flush it.

    Only ``db.add`` and ``db.flush`` are called; committing is left to the
    caller.

    Args:
        db: Session provided by the ``db`` fixture.
        technique: Parent technique; only its ``id`` is read.
        state: Value for ``Test.state``.
        name: Value for ``Test.name``.
        detection_result: Value for ``Test.detection_result``.
        **extra_fields: Any other column values, forwarded verbatim to the
            ``Test`` constructor instead of being silently ignored. An
            ``id`` entry overrides the randomly generated primary key.

    Returns:
        The flushed (not yet committed) ``Test`` instance.
    """
    # A random UUID is the default; a caller-supplied ``id`` wins in the merge.
    columns = {"id": uuid.uuid4(), **extra_fields}
    t = Test(
        technique_id=technique.id,
        name=name,
        state=state,
        detection_result=detection_result,
        **columns,
    )
    db.add(t)
    # Flush so the row is visible to queries in the same session before commit.
    db.flush()
    return t
|
|
|
|
|
|
# ── SATechniqueRepository ───────────────────────────────────────────
|
|
|
|
|
|
class TestSATechniqueRepository:
    """Integration tests for ``SATechniqueRepository`` against the ``db`` fixture."""

    def test_find_by_id_returns_entity(self, db):
        tech = _create_technique(db)
        db.commit()
        repo = SATechniqueRepository(db)

        result = repo.find_by_id(tech.id)
        assert result is not None
        # Check the type before attributes so a wrong return type fails with a
        # clear message rather than a confusing attribute mismatch.
        assert isinstance(result, TechniqueEntity)
        assert result.mitre_id == "T1059"

    def test_find_by_id_not_found(self, db):
        repo = SATechniqueRepository(db)
        assert repo.find_by_id(uuid.uuid4()) is None

    def test_find_by_mitre_id(self, db):
        _create_technique(db, mitre_id="T1548")
        db.commit()
        repo = SATechniqueRepository(db)

        result = repo.find_by_mitre_id("T1548")
        assert result is not None
        assert result.mitre_id == "T1548"

    def test_find_by_mitre_id_not_found(self, db):
        repo = SATechniqueRepository(db)
        assert repo.find_by_mitre_id("T9999") is None

    def test_list_all_no_filters(self, db):
        _create_technique(db, mitre_id="T1059", name="A")
        _create_technique(db, mitre_id="T1060", name="B")
        db.commit()
        repo = SATechniqueRepository(db)

        results = repo.list_all()
        assert len(results) == 2
        # NOTE(review): the index checks below rely on list_all() returning
        # rows in mitre_id (or insertion) order — confirm the repository
        # applies an explicit ORDER BY, otherwise this may be flaky.
        assert results[0].mitre_id == "T1059"
        assert results[1].mitre_id == "T1060"

    def test_list_all_filter_by_tactic(self, db):
        _create_technique(db, mitre_id="T1059", tactic="execution")
        _create_technique(db, mitre_id="T1060", tactic="persistence")
        db.commit()
        repo = SATechniqueRepository(db)

        results = repo.list_all(tactic="execution")
        assert len(results) == 1
        assert results[0].mitre_id == "T1059"

    def test_list_all_filter_by_status(self, db):
        _create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
        _create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.not_evaluated)
        db.commit()
        repo = SATechniqueRepository(db)

        results = repo.list_all(status=TechniqueStatus.validated)
        assert len(results) == 1
        assert results[0].mitre_id == "T1059"

    def test_list_by_ids(self, db):
        t1 = _create_technique(db, mitre_id="T1059")
        _create_technique(db, mitre_id="T1060")
        db.commit()
        repo = SATechniqueRepository(db)

        results = repo.list_by_ids([t1.id])
        assert len(results) == 1
        assert results[0].mitre_id == "T1059"

    def test_list_by_ids_empty(self, db):
        repo = SATechniqueRepository(db)
        assert repo.list_by_ids([]) == []

    def test_count_by_status(self, db):
        _create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
        _create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.validated)
        _create_technique(db, mitre_id="T1061", status_global=TechniqueStatus.not_evaluated)
        db.commit()
        repo = SATechniqueRepository(db)

        counts = repo.count_by_status()
        assert counts[TechniqueStatus.validated] == 2
        assert counts[TechniqueStatus.not_evaluated] == 1
        # Statuses with no rows are still expected to be present with a zero count.
        assert counts[TechniqueStatus.partial] == 0

    def test_exists_by_mitre_id(self, db):
        _create_technique(db, mitre_id="T1059")
        db.commit()
        repo = SATechniqueRepository(db)

        assert repo.exists_by_mitre_id("T1059") is True
        assert repo.exists_by_mitre_id("T9999") is False

    def test_save_new_technique(self, db):
        repo = SATechniqueRepository(db)
        entity = TechniqueEntity.create(
            mitre_id="T2000",
            name="New Technique",
            tactic="discovery",
        )
        saved = repo.save(entity)
        db.commit()

        assert saved.mitre_id == "T2000"
        assert repo.find_by_mitre_id("T2000") is not None

    def test_save_updates_existing(self, db):
        tech = _create_technique(db, mitre_id="T1059")
        db.commit()
        repo = SATechniqueRepository(db)

        entity = repo.find_by_id(tech.id)
        assert entity is not None
        entity.status_global = TechniqueStatus.validated
        saved = repo.save(entity)
        db.commit()

        # The value returned by save() must reflect the update, as it does for
        # a brand-new technique in test_save_new_technique.
        assert saved.status_global == TechniqueStatus.validated

        reloaded = repo.find_by_id(tech.id)
        assert reloaded is not None
        assert reloaded.status_global == TechniqueStatus.validated
        # Saving an existing entity must update in place, not insert a second row.
        assert len(repo.list_all()) == 1
|
|
|
|
|
|
# ── SATestRepository ─────────────────────────────────────────────────
|
|
|
|
|
|
class TestSATestRepository:
    """Integration tests for ``SATestRepository`` against the ``db`` fixture."""

    def test_find_by_id(self, db):
        tech = _create_technique(db)
        created = _create_test(db, tech)
        db.commit()
        repo = SATestRepository(db)

        result = repo.find_by_id(created.id)
        assert result is not None
        assert result.id == created.id

    def test_find_by_id_not_found(self, db):
        repo = SATestRepository(db)
        assert repo.find_by_id(uuid.uuid4()) is None

    def test_list_by_technique(self, db):
        tech = _create_technique(db)
        test_a = _create_test(db, tech, name="Test A")
        test_b = _create_test(db, tech, name="Test B")
        # A test under a different technique must be excluded; without it the
        # technique filter is never actually exercised.
        other_tech = _create_technique(db, mitre_id="T1060")
        _create_test(db, other_tech, name="Other")
        db.commit()
        repo = SATestRepository(db)

        results = repo.list_by_technique(tech.id)
        assert len(results) == 2
        assert {r.id for r in results} == {test_a.id, test_b.id}

    def test_list_by_state(self, db):
        tech = _create_technique(db)
        draft = _create_test(db, tech, state=TestState.draft)
        _create_test(db, tech, state=TestState.validated)
        db.commit()
        repo = SATestRepository(db)

        drafts = repo.list_by_state(TestState.draft)
        assert len(drafts) == 1
        # Checking the id proves the filter picked the draft row, not just
        # "some" single row.
        assert drafts[0].id == draft.id

    def test_count_by_technique_and_state(self, db):
        tech = _create_technique(db)
        _create_test(db, tech, state=TestState.draft)
        _create_test(db, tech, state=TestState.draft)
        _create_test(db, tech, state=TestState.validated)
        # Draft under another technique: must not leak into tech's counts.
        other_tech = _create_technique(db, mitre_id="T1060")
        _create_test(db, other_tech, state=TestState.draft)
        db.commit()
        repo = SATestRepository(db)

        counts = repo.count_by_technique_and_state(tech.id)
        assert counts.get(TestState.draft) == 2
        assert counts.get(TestState.validated) == 1

    def test_get_states_and_results(self, db):
        tech = _create_technique(db)
        _create_test(db, tech, state=TestState.validated, detection_result="detected")
        _create_test(db, tech, state=TestState.draft)
        db.commit()
        repo = SATestRepository(db)

        pairs = repo.get_states_and_results_for_technique(tech.id)
        assert len(pairs) == 2
        # NOTE(review): states are compared to plain strings here while other
        # tests use TestState members; this works only if the repository
        # returns raw strings or TestState is a str-valued enum — confirm.
        states = [p[0] for p in pairs]
        assert "validated" in states
        assert "draft" in states
|