feat(infra): add repository implementations, mappers, FastAPI wiring, and technique indexes
This commit is contained in:
38
backend/alembic/versions/b026_add_technique_query_indexes.py
Normal file
38
backend/alembic/versions/b026_add_technique_query_indexes.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""add_technique_query_indexes
|
||||||
|
|
||||||
|
Add indexes on techniques table for common query patterns
|
||||||
|
(filter by tactic, filter by status_global) used in heatmap, scoring,
|
||||||
|
and list-all-techniques operations.
|
||||||
|
|
||||||
|
These may already exist if the ORM model auto-created them; the
|
||||||
|
``if_not_exists`` flag makes this migration safe to run regardless.
|
||||||
|
|
||||||
|
Revision ID: b026techidx
|
||||||
|
Revises: b025uqtdr
|
||||||
|
Create Date: 2026-02-18 18:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "b026techidx"
|
||||||
|
down_revision: Union[str, None] = "b025uqtdr"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_techniques_tactic "
|
||||||
|
"ON techniques (tactic)"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_techniques_status_global "
|
||||||
|
"ON techniques (status_global)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_techniques_status_global", table_name="techniques")
|
||||||
|
op.drop_index("ix_techniques_tactic", table_name="techniques")
|
||||||
30
backend/app/dependencies/repositories.py
Normal file
30
backend/app/dependencies/repositories.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""FastAPI dependency providers for repositories.
|
||||||
|
|
||||||
|
Wiring lives ONLY in the presentation layer — use cases and services
|
||||||
|
never know which concrete repository implementation they receive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||||
|
SATechniqueRepository,
|
||||||
|
)
|
||||||
|
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||||
|
SATestRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_technique_repository(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> SATechniqueRepository:
|
||||||
|
"""Provide a TechniqueRepository backed by the current DB session."""
|
||||||
|
return SATechniqueRepository(db)
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_repository(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> SATestRepository:
|
||||||
|
"""Provide a TestRepository backed by the current DB session."""
|
||||||
|
return SATestRepository(db)
|
||||||
0
backend/app/infrastructure/persistence/__init__.py
Normal file
0
backend/app/infrastructure/persistence/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""Technique ORM model <-> domain entity mapper."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
from app.domain.enums import TechniqueStatus
|
||||||
|
|
||||||
|
|
||||||
|
class TechniqueMapper:
|
||||||
|
"""Converts between SQLAlchemy Technique model and TechniqueEntity."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_entity(model: object) -> TechniqueEntity:
|
||||||
|
"""Convert an ORM Technique model to a domain TechniqueEntity."""
|
||||||
|
return TechniqueEntity.from_orm(model)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_model_updates(entity: TechniqueEntity, model: object) -> None:
|
||||||
|
"""Apply entity changes back onto an existing ORM model."""
|
||||||
|
entity.apply_to(model)
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||||
|
SATechniqueRepository,
|
||||||
|
)
|
||||||
|
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||||
|
SATestRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["SATechniqueRepository", "SATestRepository"]
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
"""SQLAlchemy implementation of TechniqueRepository.
|
||||||
|
|
||||||
|
Receives a Session from the caller — does NOT create its own.
|
||||||
|
Does NOT call commit() — the Unit of Work owns that.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
from app.domain.enums import TechniqueStatus, TestState
|
||||||
|
from app.domain.ports.repositories.technique_repository import TechniqueWithCounts
|
||||||
|
from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper
|
||||||
|
from app.models.detection_rule import DetectionRule
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.models.test import Test
|
||||||
|
|
||||||
|
|
||||||
|
class SATechniqueRepository:
|
||||||
|
"""Concrete repository backed by SQLAlchemy."""
|
||||||
|
|
||||||
|
def __init__(self, session: Session) -> None:
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
# -- Single-entity access ----------------------------------------------
|
||||||
|
|
||||||
|
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
|
||||||
|
model = (
|
||||||
|
self._session.query(Technique)
|
||||||
|
.filter(Technique.id == technique_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return TechniqueMapper.to_entity(model) if model else None
|
||||||
|
|
||||||
|
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
|
||||||
|
model = (
|
||||||
|
self._session.query(Technique)
|
||||||
|
.filter(Technique.mitre_id == mitre_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return TechniqueMapper.to_entity(model) if model else None
|
||||||
|
|
||||||
|
# -- List access -------------------------------------------------------
|
||||||
|
|
||||||
|
def list_all(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tactic: str | None = None,
|
||||||
|
status: TechniqueStatus | None = None,
|
||||||
|
review_required: bool | None = None,
|
||||||
|
) -> list[TechniqueEntity]:
|
||||||
|
query = self._session.query(Technique)
|
||||||
|
if tactic is not None:
|
||||||
|
query = query.filter(Technique.tactic == tactic)
|
||||||
|
if status is not None:
|
||||||
|
query = query.filter(Technique.status_global == status)
|
||||||
|
if review_required is not None:
|
||||||
|
query = query.filter(Technique.review_required == review_required)
|
||||||
|
models = query.order_by(Technique.mitre_id).all()
|
||||||
|
return [TechniqueMapper.to_entity(m) for m in models]
|
||||||
|
|
||||||
|
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
|
||||||
|
if not ids:
|
||||||
|
return []
|
||||||
|
models = (
|
||||||
|
self._session.query(Technique)
|
||||||
|
.filter(Technique.id.in_(ids))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [TechniqueMapper.to_entity(m) for m in models]
|
||||||
|
|
||||||
|
# -- Batch queries (for scoring/heatmap) -------------------------------
|
||||||
|
|
||||||
|
def count_by_status(self) -> dict[TechniqueStatus, int]:
|
||||||
|
rows = (
|
||||||
|
self._session.query(
|
||||||
|
Technique.status_global,
|
||||||
|
func.count(Technique.id),
|
||||||
|
)
|
||||||
|
.group_by(Technique.status_global)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
result = {s: 0 for s in TechniqueStatus}
|
||||||
|
for status_val, count in rows:
|
||||||
|
key = (
|
||||||
|
status_val
|
||||||
|
if isinstance(status_val, TechniqueStatus)
|
||||||
|
else TechniqueStatus(status_val)
|
||||||
|
)
|
||||||
|
result[key] = count
|
||||||
|
return result
|
||||||
|
|
||||||
|
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
|
||||||
|
"""Single query replacing the N+1 pattern.
|
||||||
|
|
||||||
|
Returns all techniques with pre-aggregated test and detection
|
||||||
|
rule counts via subqueries.
|
||||||
|
"""
|
||||||
|
test_count_sq = (
|
||||||
|
self._session.query(
|
||||||
|
Test.technique_id,
|
||||||
|
func.count(Test.id).label("test_count"),
|
||||||
|
func.sum(
|
||||||
|
func.cast(Test.state == TestState.validated, self._int_type())
|
||||||
|
).label("validated_count"),
|
||||||
|
)
|
||||||
|
.group_by(Test.technique_id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
rule_count_sq = (
|
||||||
|
self._session.query(
|
||||||
|
DetectionRule.mitre_technique_id,
|
||||||
|
func.count(DetectionRule.id).label("rule_count"),
|
||||||
|
)
|
||||||
|
.group_by(DetectionRule.mitre_technique_id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = (
|
||||||
|
self._session.query(
|
||||||
|
Technique,
|
||||||
|
func.coalesce(test_count_sq.c.test_count, 0),
|
||||||
|
func.coalesce(test_count_sq.c.validated_count, 0),
|
||||||
|
func.coalesce(rule_count_sq.c.rule_count, 0),
|
||||||
|
)
|
||||||
|
.outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id)
|
||||||
|
.outerjoin(
|
||||||
|
rule_count_sq,
|
||||||
|
Technique.mitre_id == rule_count_sq.c.mitre_technique_id,
|
||||||
|
)
|
||||||
|
.order_by(Technique.mitre_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
TechniqueWithCounts(
|
||||||
|
entity=TechniqueMapper.to_entity(tech),
|
||||||
|
test_count=int(tc),
|
||||||
|
validated_test_count=int(vtc),
|
||||||
|
detection_rule_count=int(rc),
|
||||||
|
)
|
||||||
|
for tech, tc, vtc, rc in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
# -- Mutations ---------------------------------------------------------
|
||||||
|
|
||||||
|
def save(self, technique: TechniqueEntity) -> TechniqueEntity:
|
||||||
|
existing = (
|
||||||
|
self._session.query(Technique)
|
||||||
|
.filter(Technique.id == technique.id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
technique.apply_to(existing)
|
||||||
|
existing.mitre_id = technique.mitre_id
|
||||||
|
existing.name = technique.name
|
||||||
|
existing.tactic = technique.tactic
|
||||||
|
existing.description = technique.description
|
||||||
|
existing.platforms = technique.platforms
|
||||||
|
existing.is_subtechnique = technique.is_subtechnique
|
||||||
|
existing.parent_mitre_id = technique.parent_mitre_id
|
||||||
|
self._session.flush()
|
||||||
|
return TechniqueMapper.to_entity(existing)
|
||||||
|
else:
|
||||||
|
model = Technique(
|
||||||
|
id=technique.id,
|
||||||
|
mitre_id=technique.mitre_id,
|
||||||
|
name=technique.name,
|
||||||
|
tactic=technique.tactic,
|
||||||
|
description=technique.description,
|
||||||
|
platforms=technique.platforms,
|
||||||
|
is_subtechnique=technique.is_subtechnique,
|
||||||
|
parent_mitre_id=technique.parent_mitre_id,
|
||||||
|
status_global=technique.status_global,
|
||||||
|
review_required=technique.review_required,
|
||||||
|
last_review_date=technique.last_review_date,
|
||||||
|
)
|
||||||
|
self._session.add(model)
|
||||||
|
self._session.flush()
|
||||||
|
return TechniqueMapper.to_entity(model)
|
||||||
|
|
||||||
|
def exists_by_mitre_id(self, mitre_id: str) -> bool:
|
||||||
|
return (
|
||||||
|
self._session.query(Technique.id)
|
||||||
|
.filter(Technique.mitre_id == mitre_id)
|
||||||
|
.first()
|
||||||
|
) is not None
|
||||||
|
|
||||||
|
# -- Internal ----------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _int_type():
|
||||||
|
"""Return an Integer type for CAST expressions (SQLite-compatible)."""
|
||||||
|
from sqlalchemy import Integer
|
||||||
|
return Integer
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
"""SQLAlchemy implementation of TestRepository."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.enums import TestState
|
||||||
|
from app.models.test import Test
|
||||||
|
|
||||||
|
|
||||||
|
class SATestRepository:
|
||||||
|
"""Concrete test repository backed by SQLAlchemy."""
|
||||||
|
|
||||||
|
def __init__(self, session: Session) -> None:
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
def find_by_id(self, test_id: uuid.UUID) -> Test | None:
|
||||||
|
return (
|
||||||
|
self._session.query(Test)
|
||||||
|
.filter(Test.id == test_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]:
|
||||||
|
return (
|
||||||
|
self._session.query(Test)
|
||||||
|
.filter(Test.technique_id == technique_id)
|
||||||
|
.order_by(Test.created_at)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_by_state(self, state: TestState) -> list[Test]:
|
||||||
|
return (
|
||||||
|
self._session.query(Test)
|
||||||
|
.filter(Test.state == state)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
def count_by_technique_and_state(
|
||||||
|
self,
|
||||||
|
technique_id: uuid.UUID,
|
||||||
|
) -> dict[TestState, int]:
|
||||||
|
rows = (
|
||||||
|
self._session.query(Test.state, func.count(Test.id))
|
||||||
|
.filter(Test.technique_id == technique_id)
|
||||||
|
.group_by(Test.state)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
result: dict[TestState, int] = {}
|
||||||
|
for state_val, count in rows:
|
||||||
|
key = (
|
||||||
|
state_val
|
||||||
|
if isinstance(state_val, TestState)
|
||||||
|
else TestState(state_val)
|
||||||
|
)
|
||||||
|
result[key] = count
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_states_and_results_for_technique(
|
||||||
|
self,
|
||||||
|
technique_id: uuid.UUID,
|
||||||
|
) -> list[tuple[str, str | None]]:
|
||||||
|
"""Return lightweight (state, detection_result) pairs.
|
||||||
|
|
||||||
|
Used by TechniqueEntity.recalculate_status() without loading
|
||||||
|
full Test models.
|
||||||
|
"""
|
||||||
|
rows = (
|
||||||
|
self._session.query(Test.state, Test.detection_result)
|
||||||
|
.filter(Test.technique_id == technique_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
(
|
||||||
|
r.state.value if hasattr(r.state, "value") else str(r.state),
|
||||||
|
(
|
||||||
|
r.detection_result.value
|
||||||
|
if hasattr(r.detection_result, "value")
|
||||||
|
else r.detection_result
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
237
backend/tests/test_repositories.py
Normal file
237
backend/tests/test_repositories.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
"""Integration tests for repository implementations.
|
||||||
|
|
||||||
|
Uses the SQLite-based ``db`` fixture from conftest.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
from app.domain.enums import TechniqueStatus, TestState
|
||||||
|
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||||
|
SATechniqueRepository,
|
||||||
|
)
|
||||||
|
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||||
|
SATestRepository,
|
||||||
|
)
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.models.test import Test
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _create_technique(db, mitre_id="T1059", name="Test Technique", **kwargs):
|
||||||
|
tech = Technique(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
mitre_id=mitre_id,
|
||||||
|
name=name,
|
||||||
|
tactic=kwargs.get("tactic", "execution"),
|
||||||
|
status_global=kwargs.get("status_global", TechniqueStatus.not_evaluated),
|
||||||
|
)
|
||||||
|
db.add(tech)
|
||||||
|
db.flush()
|
||||||
|
return tech
|
||||||
|
|
||||||
|
|
||||||
|
def _create_test(db, technique, state=TestState.draft, **kwargs):
|
||||||
|
t = Test(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
technique_id=technique.id,
|
||||||
|
name=kwargs.get("name", "Test 1"),
|
||||||
|
state=state,
|
||||||
|
detection_result=kwargs.get("detection_result"),
|
||||||
|
)
|
||||||
|
db.add(t)
|
||||||
|
db.flush()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
# ── SATechniqueRepository ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSATechniqueRepository:
|
||||||
|
|
||||||
|
def test_find_by_id_returns_entity(self, db):
|
||||||
|
tech = _create_technique(db)
|
||||||
|
db.commit()
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
|
||||||
|
result = repo.find_by_id(tech.id)
|
||||||
|
assert result is not None
|
||||||
|
assert result.mitre_id == "T1059"
|
||||||
|
assert isinstance(result, TechniqueEntity)
|
||||||
|
|
||||||
|
def test_find_by_id_not_found(self, db):
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
assert repo.find_by_id(uuid.uuid4()) is None
|
||||||
|
|
||||||
|
def test_find_by_mitre_id(self, db):
|
||||||
|
_create_technique(db, mitre_id="T1548")
|
||||||
|
db.commit()
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
|
||||||
|
result = repo.find_by_mitre_id("T1548")
|
||||||
|
assert result is not None
|
||||||
|
assert result.mitre_id == "T1548"
|
||||||
|
|
||||||
|
def test_find_by_mitre_id_not_found(self, db):
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
assert repo.find_by_mitre_id("T9999") is None
|
||||||
|
|
||||||
|
def test_list_all_no_filters(self, db):
|
||||||
|
_create_technique(db, mitre_id="T1059", name="A")
|
||||||
|
_create_technique(db, mitre_id="T1060", name="B")
|
||||||
|
db.commit()
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
|
||||||
|
results = repo.list_all()
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0].mitre_id == "T1059"
|
||||||
|
assert results[1].mitre_id == "T1060"
|
||||||
|
|
||||||
|
def test_list_all_filter_by_tactic(self, db):
|
||||||
|
_create_technique(db, mitre_id="T1059", tactic="execution")
|
||||||
|
_create_technique(db, mitre_id="T1060", tactic="persistence")
|
||||||
|
db.commit()
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
|
||||||
|
results = repo.list_all(tactic="execution")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].mitre_id == "T1059"
|
||||||
|
|
||||||
|
def test_list_all_filter_by_status(self, db):
|
||||||
|
_create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
|
||||||
|
_create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.not_evaluated)
|
||||||
|
db.commit()
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
|
||||||
|
results = repo.list_all(status=TechniqueStatus.validated)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].mitre_id == "T1059"
|
||||||
|
|
||||||
|
def test_list_by_ids(self, db):
|
||||||
|
t1 = _create_technique(db, mitre_id="T1059")
|
||||||
|
_create_technique(db, mitre_id="T1060")
|
||||||
|
db.commit()
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
|
||||||
|
results = repo.list_by_ids([t1.id])
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].mitre_id == "T1059"
|
||||||
|
|
||||||
|
def test_list_by_ids_empty(self, db):
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
assert repo.list_by_ids([]) == []
|
||||||
|
|
||||||
|
def test_count_by_status(self, db):
|
||||||
|
_create_technique(db, mitre_id="T1059", status_global=TechniqueStatus.validated)
|
||||||
|
_create_technique(db, mitre_id="T1060", status_global=TechniqueStatus.validated)
|
||||||
|
_create_technique(db, mitre_id="T1061", status_global=TechniqueStatus.not_evaluated)
|
||||||
|
db.commit()
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
|
||||||
|
counts = repo.count_by_status()
|
||||||
|
assert counts[TechniqueStatus.validated] == 2
|
||||||
|
assert counts[TechniqueStatus.not_evaluated] == 1
|
||||||
|
assert counts[TechniqueStatus.partial] == 0
|
||||||
|
|
||||||
|
def test_exists_by_mitre_id(self, db):
|
||||||
|
_create_technique(db, mitre_id="T1059")
|
||||||
|
db.commit()
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
|
||||||
|
assert repo.exists_by_mitre_id("T1059") is True
|
||||||
|
assert repo.exists_by_mitre_id("T9999") is False
|
||||||
|
|
||||||
|
def test_save_new_technique(self, db):
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
entity = TechniqueEntity.create(
|
||||||
|
mitre_id="T2000",
|
||||||
|
name="New Technique",
|
||||||
|
tactic="discovery",
|
||||||
|
)
|
||||||
|
saved = repo.save(entity)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
assert saved.mitre_id == "T2000"
|
||||||
|
assert repo.find_by_mitre_id("T2000") is not None
|
||||||
|
|
||||||
|
def test_save_updates_existing(self, db):
|
||||||
|
tech = _create_technique(db, mitre_id="T1059")
|
||||||
|
db.commit()
|
||||||
|
repo = SATechniqueRepository(db)
|
||||||
|
|
||||||
|
entity = repo.find_by_id(tech.id)
|
||||||
|
entity.status_global = TechniqueStatus.validated
|
||||||
|
saved = repo.save(entity)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
reloaded = repo.find_by_id(tech.id)
|
||||||
|
assert reloaded.status_global == TechniqueStatus.validated
|
||||||
|
|
||||||
|
|
||||||
|
# ── SATestRepository ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSATestRepository:
|
||||||
|
|
||||||
|
def test_find_by_id(self, db):
|
||||||
|
tech = _create_technique(db)
|
||||||
|
test = _create_test(db, tech)
|
||||||
|
db.commit()
|
||||||
|
repo = SATestRepository(db)
|
||||||
|
|
||||||
|
result = repo.find_by_id(test.id)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == test.id
|
||||||
|
|
||||||
|
def test_find_by_id_not_found(self, db):
|
||||||
|
repo = SATestRepository(db)
|
||||||
|
assert repo.find_by_id(uuid.uuid4()) is None
|
||||||
|
|
||||||
|
def test_list_by_technique(self, db):
|
||||||
|
tech = _create_technique(db)
|
||||||
|
_create_test(db, tech, name="Test A")
|
||||||
|
_create_test(db, tech, name="Test B")
|
||||||
|
db.commit()
|
||||||
|
repo = SATestRepository(db)
|
||||||
|
|
||||||
|
results = repo.list_by_technique(tech.id)
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
def test_list_by_state(self, db):
|
||||||
|
tech = _create_technique(db)
|
||||||
|
_create_test(db, tech, state=TestState.draft)
|
||||||
|
_create_test(db, tech, state=TestState.validated)
|
||||||
|
db.commit()
|
||||||
|
repo = SATestRepository(db)
|
||||||
|
|
||||||
|
drafts = repo.list_by_state(TestState.draft)
|
||||||
|
assert len(drafts) == 1
|
||||||
|
|
||||||
|
def test_count_by_technique_and_state(self, db):
|
||||||
|
tech = _create_technique(db)
|
||||||
|
_create_test(db, tech, state=TestState.draft)
|
||||||
|
_create_test(db, tech, state=TestState.draft)
|
||||||
|
_create_test(db, tech, state=TestState.validated)
|
||||||
|
db.commit()
|
||||||
|
repo = SATestRepository(db)
|
||||||
|
|
||||||
|
counts = repo.count_by_technique_and_state(tech.id)
|
||||||
|
assert counts.get(TestState.draft) == 2
|
||||||
|
assert counts.get(TestState.validated) == 1
|
||||||
|
|
||||||
|
def test_get_states_and_results(self, db):
|
||||||
|
tech = _create_technique(db)
|
||||||
|
_create_test(db, tech, state=TestState.validated, detection_result="detected")
|
||||||
|
_create_test(db, tech, state=TestState.draft)
|
||||||
|
db.commit()
|
||||||
|
repo = SATestRepository(db)
|
||||||
|
|
||||||
|
pairs = repo.get_states_and_results_for_technique(tech.id)
|
||||||
|
assert len(pairs) == 2
|
||||||
|
states = [p[0] for p in pairs]
|
||||||
|
assert "validated" in states
|
||||||
|
assert "draft" in states
|
||||||
Reference in New Issue
Block a user