Files
Aegis/backend/app/services/test_crud_service.py
T
kitos 394d5d9056 refactor(types): add comprehensive type annotations across backend Python codebase
Enable ANN rules in ruff.toml (flake8-annotations) and resolve all 221 violations:

ANN201/ANN202 — return types on 168 public/private functions:
- All 28 FastAPI routers: endpoints annotated with dict/list/specific schema/
  StreamingResponse/FileResponse/JSONResponse as appropriate
- main.py: lifespan→AsyncGenerator[None,None], exception handlers→JSONResponse
- database.py: get_db→Generator[Session,None,None], proxy methods→correct types
- middleware/request_context.py: dispatch→Response with Callable call_next type

ANN001/ANN002/ANN003 — 32 missing argument types:
- seed_demo.py: all db parameters typed as Session
- domain/unit_of_work.py: __aexit__ exc_type/exc_val/exc_tb typed with TracebackType
- services: audit_service user_id→UUID|None, heatmap_service query/model/builder,
  notification_service test→Test, tempo_service test→Test/user→User,
  test_workflow_service test_id→UUID, campaign_crud **fields→object,
  test_crud **fields→object (4 sites)

ANN401 — 16 Any usages resolved:
- Domain entities (campaign/technique/threat_actor/test_entity): replaced Any with
  actual ORM types via TYPE_CHECKING guards to avoid circular imports
- detection_rule_service: test_id/detection_rule_id/evaluator_id→UUID
- score_cache: kept Any with # noqa: ANN401 (genuinely generic cache)
- jira_service/tempo_service: kept Any with # noqa: ANN401 (lazy optional deps)
- d3fend_import_service: _to_str(v: Any) kept with # noqa: ANN401

ANN204/ANN205/ANN206 — special/static/class methods:
- database.py proxy __call__/__getattr__: *args: object/**kwargs: object
- schemas/test.py model_validate: obj→object, **kwargs→object
- sa_technique_repository._int_type→type

All 439 unit tests pass. ruff check app/ → All checks passed!

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-09 17:04:51 +02:00

278 lines
8.0 KiB
Python

"""Test CRUD service — list, create, update, and query logic for security tests.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
import uuid
from typing import Any

from sqlalchemy import or_
from sqlalchemy.orm import Session, joinedload

from app.domain.errors import (
    BusinessRuleViolation,
    EntityNotFoundError,
    PermissionViolation,
)
from app.models.audit import AuditLog
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.utils import escape_like
def list_tests(
    db: Session,
    *,
    state: str | None = None,
    technique_id: uuid.UUID | None = None,
    platform: str | None = None,
    created_by: uuid.UUID | None = None,
    pending_validation_side: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> list[Test]:
    """Return a paginated list of tests with optional filters.

    All supplied filters are combined with AND; a falsy filter value is
    ignored. ``platform`` is matched as a case-insensitive substring after
    being passed through ``escape_like``.

    ``pending_validation_side`` ("red" or "blue") restricts the result to
    tests in the in_review state whose validation status for that side is
    either "pending" or not set (NULL). Any other value is ignored.

    Results are ordered by created_at descending, with the technique
    relationship eager-loaded.
    """
    query = db.query(Test).options(joinedload(Test.technique))

    if state:
        query = query.filter(Test.state == state)
    if technique_id:
        query = query.filter(Test.technique_id == technique_id)
    if platform:
        query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
    if created_by:
        query = query.filter(Test.created_by == created_by)

    pending_columns = {
        "red": Test.red_validation_status,
        "blue": Test.blue_validation_status,
    }
    status_column = pending_columns.get(pending_validation_side)
    if status_column is not None:
        # SQL "IN ('pending', NULL)" never matches NULL rows (the comparison
        # evaluates to UNKNOWN), so the unset case needs an explicit IS NULL.
        query = query.filter(
            Test.state == TestState.in_review,
            or_(status_column == "pending", status_column.is_(None)),
        )

    return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
def create_test(
    db: Session,
    *,
    technique_id: uuid.UUID,
    creator_id: uuid.UUID,
    **fields: object,
) -> Test:
    """Build a draft Test attached to an existing technique.

    The new row is added to the session and flushed, but not committed;
    committing is left to the caller.

    Raises EntityNotFoundError when no technique has the given ID.
    """
    technique_lookup = db.query(Technique).filter(Technique.id == technique_id)
    if technique_lookup.first() is None:
        raise EntityNotFoundError("Technique", str(technique_id))

    # Every new test starts in the draft state; remaining columns come
    # straight from the caller-supplied keyword fields.
    new_test = Test(
        technique_id=technique_id,
        created_by=creator_id,
        state=TestState.draft,
        **fields,
    )
    db.add(new_test)
    db.flush()
    return new_test
def create_test_from_template(
    db: Session,
    *,
    template_id: uuid.UUID,
    technique_id_or_mitre: str,
    creator_id: uuid.UUID,
) -> Test:
    """Instantiate a draft Test from a TestTemplate.

    ``technique_id_or_mitre`` can be a UUID string or a MITRE ID
    (e.g. T1059.001). A value that parses as a UUID is first looked up by
    technique ID; if that finds nothing, or the value is not a UUID, it is
    looked up by MITRE ID.

    The new row is added to the session and flushed, but not committed;
    committing is left to the caller.

    Raises EntityNotFoundError if the template or the technique is not found.
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if template is None:
        raise EntityNotFoundError("TestTemplate", str(template_id))

    # Only the UUID parse is guarded: a ValueError coming from the database
    # layer must propagate instead of being mistaken for "not a UUID".
    try:
        technique_uuid = uuid.UUID(technique_id_or_mitre)
    except ValueError:
        technique_uuid = None

    technique = None
    if technique_uuid is not None:
        technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
    if technique is None:
        # Fall back to treating the identifier as a MITRE ID.
        technique = db.query(Technique).filter(
            Technique.mitre_id == technique_id_or_mitre
        ).first()
    if technique is None:
        raise EntityNotFoundError("Technique", technique_id_or_mitre)

    # Copy the template's content onto the corresponding test columns.
    test = Test(
        technique_id=technique.id,
        name=template.name,
        description=template.description,
        platform=template.platform,
        procedure_text=template.attack_procedure,
        tool_used=template.tool_suggested,
        remediation_steps=template.suggested_remediation,
        created_by=creator_id,
        state=TestState.draft,
    )
    db.add(test)
    db.flush()
    return test
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
    """Load a single test together with its evidences (eager-loaded).

    Raises EntityNotFoundError when no test has the given ID.
    """
    detail_query = db.query(Test).options(joinedload(Test.evidences))
    found = detail_query.filter(Test.id == test_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
    """Look up a test by ID, raising EntityNotFoundError when it is missing."""
    found = db.query(Test).filter(Test.id == test_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
    """Load a single test with its technique relationship eager-loaded.

    Raises EntityNotFoundError when no test has the given ID.
    """
    with_technique = db.query(Test).options(joinedload(Test.technique))
    found = with_technique.filter(Test.id == test_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
def update_test(
    db: Session,
    test_id: uuid.UUID,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: object,
) -> Test:
    """Apply general field updates to a test that is in draft or rejected state.

    Each keyword in ``fields`` is assigned onto the test as an attribute.
    The session is flushed but not committed; committing is left to the
    caller.

    Raises EntityNotFoundError if the test does not exist.
    Raises PermissionViolation if the updater is neither the creator nor an admin.
    Raises BusinessRuleViolation if the test is not in draft or rejected state.
    """
    target = get_test_or_raise(db, test_id)

    # Authorization is checked before the state rule.
    if not (updater_role == "admin" or target.created_by == updater_id):
        raise PermissionViolation(
            "Only the test creator or an admin can update this test"
        )

    editable_states = (TestState.draft, TestState.rejected)
    if target.state not in editable_states:
        raise BusinessRuleViolation(
            f"Cannot update a test in '{target.state.value}' state (must be draft or rejected)"
        )

    for attr_name, new_value in fields.items():
        setattr(target, attr_name, new_value)
    db.flush()
    return target
def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
    """Apply Red Team field updates to a test in draft or red_executing state.

    Each keyword in ``fields`` is assigned onto the test as an attribute.
    The session is flushed but not committed; committing is left to the
    caller.

    Raises EntityNotFoundError if the test does not exist.
    Raises BusinessRuleViolation if the state is neither draft nor red_executing.
    """
    target = get_test_or_raise(db, test_id)

    red_editable_states = (TestState.draft, TestState.red_executing)
    if target.state not in red_editable_states:
        raise BusinessRuleViolation(
            f"Cannot update red fields in '{target.state.value}' state "
            "(must be draft or red_executing)"
        )

    for attr_name, new_value in fields.items():
        setattr(target, attr_name, new_value)
    db.flush()
    return target
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
    """Apply Blue Team field updates to a test in blue_evaluating state.

    Each keyword in ``fields`` is assigned onto the test as an attribute.
    The session is flushed but not committed; committing is left to the
    caller.

    Raises EntityNotFoundError if the test does not exist.
    Raises BusinessRuleViolation if the state is not blue_evaluating.
    """
    target = get_test_or_raise(db, test_id)

    current_state = target.state
    if current_state != TestState.blue_evaluating:
        raise BusinessRuleViolation(
            f"Cannot update blue fields in '{current_state.value}' state "
            "(must be blue_evaluating)"
        )

    for attr_name, new_value in fields.items():
        setattr(target, attr_name, new_value)
    db.flush()
    return target
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
    """Build the audit-log history of a test, oldest entry first.

    Each entry is a dict with the keys id, action, user_id, timestamp
    (ISO-8601 string) and details. user_id and timestamp are None when the
    log row has no value for them.

    Raises EntityNotFoundError if the test does not exist.
    """

    def _serialize(entry: AuditLog) -> dict[str, Any]:
        # Convert one audit row to plain values, keeping None for blanks.
        user_ref = str(entry.user_id) if entry.user_id else None
        stamp = entry.timestamp.isoformat() if entry.timestamp else None
        return {
            "id": str(entry.id),
            "action": entry.action,
            "user_id": user_ref,
            "timestamp": stamp,
            "details": entry.details,
        }

    # Existence check only; the returned test itself is not needed.
    get_test_or_raise(db, test_id)

    history_query = db.query(AuditLog).filter(
        AuditLog.entity_type == "test",
        AuditLog.entity_id == str(test_id),
    )
    entries = history_query.order_by(AuditLog.timestamp.asc()).all()
    return [_serialize(entry) for entry in entries]