refactor(techniques): wire TechniqueRepository into techniques router replacing direct db.query() with repo pattern, domain exceptions, and UnitOfWork

This commit is contained in:
2026-02-19 15:13:52 +01:00
parent 0b65f51d1c
commit 2b6d9090c9
4 changed files with 97 additions and 79 deletions

View File

@@ -46,6 +46,8 @@ class TechniqueEntity:
status_global: TechniqueStatus = TechniqueStatus.not_evaluated
review_required: bool = False
last_review_date: datetime | None = None
mitre_version: str | None = None
mitre_last_modified: datetime | None = None
# -- Factory -----------------------------------------------------------
@@ -94,6 +96,8 @@ class TechniqueEntity:
status_global=status,
review_required=model.review_required or False,
last_review_date=model.last_review_date,
mitre_version=getattr(model, "mitre_version", None),
mitre_last_modified=getattr(model, "mitre_last_modified", None),
)
def apply_to(self, model: Any) -> None:

View File

@@ -163,6 +163,8 @@ class SATechniqueRepository:
existing.platforms = technique.platforms
existing.is_subtechnique = technique.is_subtechnique
existing.parent_mitre_id = technique.parent_mitre_id
existing.mitre_version = technique.mitre_version
existing.mitre_last_modified = technique.mitre_last_modified
self._session.flush()
return TechniqueMapper.to_entity(existing)
else:
@@ -178,6 +180,8 @@ class SATechniqueRepository:
status_global=technique.status_global,
review_required=technique.review_required,
last_review_date=technique.last_review_date,
mitre_version=technique.mitre_version,
mitre_last_modified=technique.mitre_last_modified,
)
self._session.add(model)
self._session.flush()

View File

@@ -1,13 +1,23 @@
"""CRUD router for MITRE ATT&CK Techniques."""
"""CRUD router for MITRE ATT&CK Techniques.
from datetime import datetime
Uses the TechniqueRepository for data access and domain exceptions
for error signaling. The error_handler middleware maps domain
exceptions to HTTP responses automatically.
"""
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.orm import Session, joinedload
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role
from app.models.enums import TechniqueStatus
from app.dependencies.repositories import get_technique_repository
from app.domain.entities.technique import TechniqueEntity
from app.domain.errors import DuplicateEntityError, EntityNotFoundError
from app.domain.enums import TechniqueStatus
from app.domain.unit_of_work import UnitOfWork
from app.infrastructure.persistence.repositories.sa_technique_repository import (
SATechniqueRepository,
)
from app.models.technique import Technique
from app.models.user import User
from app.schemas.technique import (
@@ -34,24 +44,19 @@ def list_techniques(
None, alias="status", description="Filter by global status"
),
review_required: bool | None = Query(None, description="Filter by review flag"),
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(get_current_user),
):
"""Return a lightweight list of techniques, optionally filtered."""
query = db.query(Technique)
if tactic is not None:
query = query.filter(Technique.tactic == tactic)
if status_global is not None:
query = query.filter(Technique.status_global == status_global)
if review_required is not None:
query = query.filter(Technique.review_required == review_required)
return query.order_by(Technique.mitre_id).all()
return repo.list_all(
tactic=tactic,
status=status_global,
review_required=review_required,
)
# ---------------------------------------------------------------------------
# GET /techniques/{mitre_id} — detail (with tests)
# GET /techniques/{mitre_id} — detail (with tests + D3FEND)
# ---------------------------------------------------------------------------
@@ -70,12 +75,8 @@ def get_technique(
)
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique {mitre_id} not found",
)
raise EntityNotFoundError("Technique", mitre_id)
# Build response dict manually to include D3FEND defenses
defenses = get_defenses_for_technique(db, technique.id)
return {
@@ -120,34 +121,35 @@ def get_technique(
def create_technique(
payload: TechniqueCreate,
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_role("admin")),
):
"""Create a new technique manually."""
# Ensure mitre_id is unique
existing = (
db.query(Technique).filter(Technique.mitre_id == payload.mitre_id).first()
)
if existing is not None:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Technique with mitre_id '{payload.mitre_id}' already exists",
)
if repo.exists_by_mitre_id(payload.mitre_id):
raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id)
technique = Technique(**payload.model_dump())
db.add(technique)
db.commit()
db.refresh(technique)
entity = TechniqueEntity.create(
mitre_id=payload.mitre_id,
name=payload.name,
description=payload.description,
tactic=payload.tactic,
platforms=payload.platforms,
)
with UnitOfWork(db) as uow:
saved = repo.save(entity)
uow.commit()
log_action(
db,
user_id=current_user.id,
action="create_technique",
entity_type="technique",
entity_id=technique.id,
details={"mitre_id": technique.mitre_id, "name": technique.name},
entity_id=saved.id,
details={"mitre_id": saved.mitre_id, "name": saved.name},
)
return technique
return saved
# ---------------------------------------------------------------------------
@@ -160,36 +162,32 @@ def update_technique(
mitre_id: str,
payload: TechniqueUpdate,
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_role("admin")),
):
"""Update one or more fields of an existing technique."""
technique = (
db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
)
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique {mitre_id} not found",
)
entity = repo.find_by_mitre_id(mitre_id)
if entity is None:
raise EntityNotFoundError("Technique", mitre_id)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(technique, field, value)
setattr(entity, field, value)
db.commit()
db.refresh(technique)
with UnitOfWork(db) as uow:
saved = repo.save(entity)
uow.commit()
log_action(
db,
user_id=current_user.id,
action="update_technique",
entity_type="technique",
entity_id=technique.id,
entity_id=saved.id,
details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())},
)
return technique
return saved
# ---------------------------------------------------------------------------
@@ -201,6 +199,7 @@ def update_technique(
def review_technique(
mitre_id: str,
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Mark a technique as reviewed.
@@ -208,29 +207,23 @@ def review_technique(
Sets ``review_required`` to *False* and records the current timestamp
in ``last_review_date``.
"""
technique = (
db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
)
entity = repo.find_by_mitre_id(mitre_id)
if entity is None:
raise EntityNotFoundError("Technique", mitre_id)
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique {mitre_id} not found",
)
entity.mark_reviewed()
technique.review_required = False
technique.last_review_date = datetime.utcnow()
db.commit()
db.refresh(technique)
with UnitOfWork(db) as uow:
saved = repo.save(entity)
uow.commit()
log_action(
db,
user_id=current_user.id,
action="review_technique",
entity_type="technique",
entity_id=technique.id,
entity_id=saved.id,
details={"mitre_id": mitre_id},
)
return technique
return saved