Files
Aegis/backend/app/services/test_crud_service.py
kitos c467459b51
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
fix(campaigns): filter existing-test picker to draft + not in any campaign
Backend: add not_in_any_campaign filter to list_tests (subquery on
CampaignTest) and expose it as a query param on GET /tests.
Frontend: the 'Existing Test' tab now requests only
  state=draft & not_in_any_campaign=true
so tests already linked to any campaign or not in draft state
are never shown.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-29 09:55:02 +02:00

293 lines
9.1 KiB
Python

"""Test CRUD service — list, create, update, and query logic for security tests.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
import uuid
from datetime import datetime
from typing import Any

from sqlalchemy import exists, or_
from sqlalchemy.orm import Session, joinedload

from app.domain.errors import (
    BusinessRuleViolation,
    EntityNotFoundError,
    PermissionViolation,
)
from app.models.audit import AuditLog
from app.models.campaign import CampaignTest
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.utils import escape_like
def list_tests(
    db: Session,
    *,
    state: str | None = None,
    technique_id: uuid.UUID | None = None,
    platform: str | None = None,
    created_by: uuid.UUID | None = None,
    pending_validation_side: str | None = None,
    not_in_any_campaign: bool = False,
    offset: int = 0,
    limit: int = 50,
) -> list[Test]:
    """Return a paginated list of tests with optional filters.

    Results are ordered newest first (by ``created_at``) and have the related
    technique eager-loaded.

    Args:
        db: Active SQLAlchemy session.
        state: If truthy, only tests whose ``state`` equals this value.
        technique_id: If truthy, only tests linked to this technique.
        platform: If truthy, case-insensitive substring match on
            ``Test.platform``; the input is passed through ``escape_like``
            before being embedded in the LIKE pattern.
        created_by: If truthy, only tests created by this user.
        pending_validation_side: ``"red"`` or ``"blue"`` restricts the result
            to ``in_review`` tests whose validation status on that side is
            ``"pending"`` or not set yet (NULL). Any other value is ignored.
        not_in_any_campaign: If true, exclude every test referenced by a
            ``CampaignTest`` row.
        offset: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        The matching ``Test`` rows.
    """
    query = db.query(Test).options(joinedload(Test.technique))
    if state:
        query = query.filter(Test.state == state)
    if technique_id:
        query = query.filter(Test.technique_id == technique_id)
    if platform:
        query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
    if created_by:
        query = query.filter(Test.created_by == created_by)

    if pending_validation_side in ("red", "blue"):
        status_col = (
            Test.red_validation_status
            if pending_validation_side == "red"
            else Test.blue_validation_status
        )
        # ``col IN ('pending', NULL)`` never matches a NULL column in SQL
        # (NULL IN (...) evaluates to NULL, not true), so rows that have not
        # been validated yet need an explicit IS NULL branch.
        query = query.filter(
            Test.state == TestState.in_review,
            or_(status_col == "pending", status_col.is_(None)),
        )

    if not_in_any_campaign:
        # Correlated NOT EXISTS instead of NOT IN (subquery): NOT IN yields no
        # rows at all should the subquery ever return a NULL test_id, and
        # SQLAlchemy warns when a Subquery object is coerced into IN().
        query = query.filter(~exists().where(CampaignTest.test_id == Test.id))

    return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
def create_test(
    db: Session,
    *,
    technique_id: uuid.UUID,
    creator_id: uuid.UUID,
    **fields: Any,
) -> Test:
    """Add a new draft test attached to an existing technique.

    Extra keyword arguments in ``fields`` are forwarded verbatim to the
    ``Test`` constructor; they must not repeat ``technique_id``,
    ``created_by``, ``state`` or ``created_at`` (Python would raise
    ``TypeError`` for a duplicate keyword).

    The row is added to the session and flushed, but not committed; the
    caller owns the transaction (UnitOfWork).

    Raises:
        EntityNotFoundError: No technique exists with ``technique_id``.
    """
    parent = db.query(Technique).filter(Technique.id == technique_id).first()
    if parent is None:
        raise EntityNotFoundError("Technique", str(technique_id))

    new_test = Test(
        technique_id=technique_id,
        created_by=creator_id,
        state=TestState.draft,
        # Set explicitly because the DB column has no server default.
        created_at=datetime.utcnow(),
        **fields,
    )
    db.add(new_test)
    db.flush()
    return new_test
def create_test_from_template(
    db: Session,
    *,
    template_id: uuid.UUID,
    technique_id_or_mitre: str,
    creator_id: uuid.UUID,
    # Optional user-edited overrides (take priority over template values)
    name_override: str | None = None,
    description_override: str | None = None,
    platform_override: str | None = None,
    procedure_text_override: str | None = None,
    tool_used_override: str | None = None,
) -> Test:
    """Instantiate a draft Test from a TestTemplate.

    ``technique_id_or_mitre`` may be a technique UUID string or a MITRE ID
    (e.g. ``T1059.001``). A value that parses as a UUID is first looked up by
    ``Technique.id``. If it does not parse, or no row matches, it is looked
    up by ``Technique.mitre_id``.

    Each ``*_override`` argument that is not ``None`` replaces the
    corresponding template value; an empty string still counts as an
    override. ``remediation_steps`` always comes from the template's
    ``suggested_remediation``.

    The new row is added and flushed but not committed; the caller commits
    via UnitOfWork.

    Raises:
        EntityNotFoundError: The template or the technique does not exist.
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if template is None:
        raise EntityNotFoundError("TestTemplate", str(template_id))

    technique = None
    try:
        technique_uuid = uuid.UUID(technique_id_or_mitre)
    except ValueError:
        # Not a UUID string (e.g. a MITRE ID); resolved by mitre_id below.
        pass
    else:
        # Kept outside the try body so a ValueError raised by the query layer
        # is not mistaken for "input is not a UUID" and silently swallowed.
        technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
    if technique is None:
        technique = (
            db.query(Technique)
            .filter(Technique.mitre_id == technique_id_or_mitre)
            .first()
        )
    if technique is None:
        raise EntityNotFoundError("Technique", technique_id_or_mitre)

    test = Test(
        technique_id=technique.id,
        name=name_override if name_override is not None else template.name,
        description=description_override if description_override is not None else template.description,
        platform=platform_override if platform_override is not None else template.platform,
        procedure_text=procedure_text_override if procedure_text_override is not None else template.attack_procedure,
        tool_used=tool_used_override if tool_used_override is not None else template.tool_suggested,
        remediation_steps=template.suggested_remediation,
        created_by=creator_id,
        state=TestState.draft,
        created_at=datetime.utcnow(),  # explicit — DB column has no server default
    )
    db.add(test)
    db.flush()
    return test
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
    """Load a single test with its evidences and technique eager-loaded.

    Both relationships are requested via ``joinedload``.

    Raises:
        EntityNotFoundError: No test exists with ``test_id``.
    """
    eager_options = (joinedload(Test.evidences), joinedload(Test.technique))
    found = db.query(Test).options(*eager_options).filter(Test.id == test_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
    """Return the test identified by ``test_id`` (no eager loading).

    Raises:
        EntityNotFoundError: No test exists with ``test_id``.
    """
    match = db.query(Test).filter(Test.id == test_id).first()
    if match is not None:
        return match
    raise EntityNotFoundError("Test", str(test_id))
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
    """Return the test identified by ``test_id`` with its technique joined-loaded.

    Raises:
        EntityNotFoundError: No test exists with ``test_id``.
    """
    with_technique = joinedload(Test.technique)
    match = db.query(Test).options(with_technique).filter(Test.id == test_id).first()
    if match is not None:
        return match
    raise EntityNotFoundError("Test", str(test_id))
def update_test(
    db: Session,
    test_id: uuid.UUID,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: Any,
) -> Test:
    """Apply general field updates to a test that is in draft or rejected state.

    Every key/value in ``fields`` is assigned onto the test with ``setattr``;
    the session is then flushed but not committed (caller uses UnitOfWork).
    The permission check runs before the state check.

    Raises:
        EntityNotFoundError: No test exists with ``test_id``.
        PermissionViolation: The updater is neither an admin nor the creator.
        BusinessRuleViolation: The test is not in draft or rejected state.
    """
    target = get_test_or_raise(db, test_id)

    if not (updater_role == "admin" or target.created_by == updater_id):
        raise PermissionViolation(
            "Only the test creator or an admin can update this test"
        )

    editable_states = (TestState.draft, TestState.rejected)
    if target.state not in editable_states:
        raise BusinessRuleViolation(
            f"Cannot update a test in '{target.state.value}' state (must be draft or rejected)"
        )

    for attr_name, new_value in fields.items():
        setattr(target, attr_name, new_value)
    db.flush()
    return target
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
    """Assign Red Team fields on a test in draft or red_executing state.

    Every key/value in ``fields`` is set on the test with ``setattr``; the
    session is flushed but not committed (caller uses UnitOfWork). No
    permission check is performed here.

    Raises:
        EntityNotFoundError: No test exists with ``test_id``.
        BusinessRuleViolation: The test is not in draft or red_executing state.
    """
    target = get_test_or_raise(db, test_id)

    red_editable = (TestState.draft, TestState.red_executing)
    if target.state not in red_editable:
        raise BusinessRuleViolation(
            f"Cannot update red fields in '{target.state.value}' state "
            "(must be draft or red_executing)"
        )

    for attr_name, new_value in fields.items():
        setattr(target, attr_name, new_value)
    db.flush()
    return target
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
    """Assign Blue Team fields on a test in blue_evaluating state.

    Every key/value in ``fields`` is set on the test with ``setattr``; the
    session is flushed but not committed (caller uses UnitOfWork). No
    permission check is performed here.

    Raises:
        EntityNotFoundError: No test exists with ``test_id``.
        BusinessRuleViolation: The test is not in blue_evaluating state.
    """
    target = get_test_or_raise(db, test_id)

    if target.state != TestState.blue_evaluating:
        raise BusinessRuleViolation(
            f"Cannot update blue fields in '{target.state.value}' state "
            "(must be blue_evaluating)"
        )

    for attr_name, new_value in fields.items():
        setattr(target, attr_name, new_value)
    db.flush()
    return target
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
    """Return the audit-log entries for a test, oldest first.

    Only rows with ``entity_type == "test"`` and ``entity_id`` equal to the
    stringified ``test_id`` are included, ordered by ``timestamp`` ascending.
    Each entry is a dict with keys ``id``, ``action``, ``user_id`` (string
    or ``None``), ``timestamp`` (ISO-8601 string or ``None``) and ``details``.

    Raises:
        EntityNotFoundError: No test exists with ``test_id``.
    """
    # Existence check only; the returned test object is not needed.
    get_test_or_raise(db, test_id)

    entries = (
        db.query(AuditLog)
        .filter(AuditLog.entity_type == "test")
        .filter(AuditLog.entity_id == str(test_id))
        .order_by(AuditLog.timestamp.asc())
        .all()
    )

    timeline: list[dict[str, Any]] = []
    for entry in entries:
        timeline.append(
            {
                "id": str(entry.id),
                "action": entry.action,
                "user_id": str(entry.user_id) if entry.user_id else None,
                "timestamp": entry.timestamp.isoformat() if entry.timestamp else None,
                "details": entry.details,
            }
        )
    return timeline