# Aegis/backend/app/services/test_crud_service.py
"""Test CRUD service — list, create, update, and query logic for security tests.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
# Import uuid
import uuid
from datetime import datetime
from typing import Any
# Import Session, joinedload from sqlalchemy.orm
from sqlalchemy.orm import Session, joinedload
# Import from app.domain.errors
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.campaign import Campaign, CampaignTest
from app.models.audit import AuditLog
# Import TestState from app.models.enums
from app.models.enums import TestState
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import escape_like from app.utils
from app.utils import escape_like
def list_tests(
    db: Session,
    *,
    state: str | None = None,
    technique_id: uuid.UUID | None = None,
    platform: str | None = None,
    created_by: uuid.UUID | None = None,
    pending_validation_side: str | None = None,
    not_in_any_campaign: bool = False,
    offset: int = 0,
    limit: int = 50,
) -> list[Test]:
    """Return a paginated list of tests with optional filters.

    Tests that belong to a campaign still in 'draft' status AND with a
    start_date in the future are always excluded — they should not appear
    in the team's queue until the campaign is activated on its start date.

    Args:
        db (Session): Active SQLAlchemy database session.
        state (str | None): When truthy, only tests whose ``state`` equals it.
        technique_id (uuid.UUID | None): When given, only tests linked to
            this technique.
        platform (str | None): When truthy, case-insensitive substring match
            (``ILIKE``) on ``Test.platform``; the value is run through
            ``escape_like`` before being wrapped in ``%...%``.
        created_by (uuid.UUID | None): When given, only tests created by
            this user.
        pending_validation_side (str | None): ``"red"`` or ``"blue"``
            restricts results to ``in_review`` tests whose validation status
            for that side is ``"pending"`` or NULL. Any other value is ignored.
        not_in_any_campaign (bool): When True, exclude tests linked to any
            campaign through ``CampaignTest``.
        offset (int): Number of rows to skip.
        limit (int): Maximum number of rows to return.

    Returns:
        list[Test]: Matching tests ordered by ``created_at`` descending, with
        the ``technique`` relationship eager-loaded.
    """
    query = db.query(Test).options(joinedload(Test.technique))

    if state:
        query = query.filter(Test.state == state)
    if technique_id is not None:
        query = query.filter(Test.technique_id == technique_id)
    if platform:
        query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
    if created_by is not None:
        query = query.filter(Test.created_by == created_by)

    # Pick the validation-status column for the requested side (None for any
    # other value, which leaves the query unfiltered on this axis).
    validation_columns = {
        "red": Test.red_validation_status,
        "blue": Test.blue_validation_status,
    }
    status_column = validation_columns.get(pending_validation_side)
    if status_column is not None:
        # ``col IN ('pending', NULL)`` never matches a NULL column in SQL
        # (NULL compared with anything yields NULL, not TRUE), so rows that
        # were never validated need an explicit IS NULL alongside 'pending'.
        query = query.filter(
            Test.state == TestState.in_review,
            (status_column == "pending") | status_column.is_(None),
        )

    if not_in_any_campaign:
        linked = db.query(CampaignTest.test_id).distinct().subquery()
        query = query.filter(~Test.id.in_(linked))

    # Always hide tests from scheduled campaigns that haven't started yet.
    # A "scheduled-but-not-yet-active" campaign = draft status + start_date in future.
    # NOTE(review): naive UTC timestamp; assumes Campaign.start_date is stored
    # as naive UTC — confirm against the model/migrations.
    now = datetime.utcnow()
    future_draft_tests = (
        db.query(CampaignTest.test_id)
        .join(Campaign, Campaign.id == CampaignTest.campaign_id)
        .filter(
            Campaign.status == "draft",
            Campaign.start_date.isnot(None),
            Campaign.start_date > now,
        )
        .distinct()
        .subquery()
    )
    query = query.filter(~Test.id.in_(future_draft_tests))

    return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
def create_test(
    db: Session,
    *,
    technique_id: uuid.UUID,
    creator_id: uuid.UUID,
    **fields: object,
) -> Test:
    """Create a draft test attached to an existing technique.

    The new row is added to the session and flushed, but the transaction is
    left open: the caller's UnitOfWork is responsible for committing.

    Args:
        db (Session): Active SQLAlchemy database session.
        technique_id (uuid.UUID): Technique the new test belongs to.
        creator_id (uuid.UUID): User recorded in ``created_by``.
        **fields (object): Extra ``Test`` constructor keywords (e.g. ``name``,
            ``platform``, ``description``). ``technique_id``, ``created_by``,
            ``state`` and ``created_at`` are set by this function, so passing
            them again raises ``TypeError`` (duplicate keyword argument).

    Returns:
        Test: The new test in ``TestState.draft``, flushed but not committed.

    Raises:
        EntityNotFoundError: If no technique has id ``technique_id``.
    """
    matching_technique = (
        db.query(Technique).filter(Technique.id == technique_id).first()
    )
    if matching_technique is None:
        raise EntityNotFoundError("Technique", str(technique_id))

    new_test = Test(
        technique_id=technique_id,
        created_by=creator_id,
        state=TestState.draft,
        # Set explicitly: the DB column has no server default.
        created_at=datetime.utcnow(),
        **fields,
    )
    db.add(new_test)
    db.flush()
    return new_test
def create_test_from_template(
    db: Session,
    *,
    template_id: uuid.UUID,
    technique_id_or_mitre: str,
    creator_id: uuid.UUID,
    # Optional user-edited overrides (take priority over template values)
    name_override: str | None = None,
    description_override: str | None = None,
    platform_override: str | None = None,
    procedure_text_override: str | None = None,
    tool_used_override: str | None = None,
) -> Test:
    """Instantiate a draft Test from a TestTemplate.

    The technique is resolved in two steps: if ``technique_id_or_mitre``
    parses as a UUID it is first looked up by ``Technique.id``; if that
    yields nothing (or the string is not a UUID) it is looked up by
    ``Technique.mitre_id`` (e.g. ``"T1059.001"``).

    Each ``*_override`` argument, when not None, replaces the corresponding
    template value. ``remediation_steps`` is always copied from the
    template's ``suggested_remediation``.

    Does not commit; caller uses UnitOfWork.

    Args:
        db (Session): Active SQLAlchemy database session.
        template_id (uuid.UUID): UUID of the template to instantiate.
        technique_id_or_mitre (str): UUID string or MITRE technique ID
            (e.g. ``"T1059.001"``) identifying the target technique.
        creator_id (uuid.UUID): UUID of the user creating the test.
        name_override (str | None): Replaces ``template.name``.
        description_override (str | None): Replaces ``template.description``.
        platform_override (str | None): Replaces ``template.platform``.
        procedure_text_override (str | None): Replaces
            ``template.attack_procedure``.
        tool_used_override (str | None): Replaces ``template.tool_suggested``.

    Returns:
        Test: The newly created test populated from template fields, flushed
        but not committed.

    Raises:
        EntityNotFoundError: If the template or the technique is not found.
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if template is None:
        raise EntityNotFoundError("TestTemplate", str(template_id))

    technique = None
    # Keep only the UUID parse inside the try: ValueError here means "not a
    # UUID, fall back to MITRE ID" and must not hide errors from the query.
    try:
        technique_uuid = uuid.UUID(technique_id_or_mitre)
    except ValueError:
        pass
    else:
        technique = db.query(Technique).filter(Technique.id == technique_uuid).first()

    if technique is None:
        technique = (
            db.query(Technique)
            .filter(Technique.mitre_id == technique_id_or_mitre)
            .first()
        )
    if technique is None:
        raise EntityNotFoundError("Technique", technique_id_or_mitre)

    test = Test(
        technique_id=technique.id,
        name=name_override if name_override is not None else template.name,
        description=(
            description_override
            if description_override is not None
            else template.description
        ),
        platform=platform_override if platform_override is not None else template.platform,
        procedure_text=(
            procedure_text_override
            if procedure_text_override is not None
            else template.attack_procedure
        ),
        tool_used=(
            tool_used_override
            if tool_used_override is not None
            else template.tool_suggested
        ),
        remediation_steps=template.suggested_remediation,
        created_by=creator_id,
        state=TestState.draft,
        created_at=datetime.utcnow(),  # explicit — DB column has no server default
    )
    db.add(test)
    db.flush()
    return test
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
    """Load one test together with its evidences and technique.

    Both relationships are eager-loaded with ``joinedload`` so they are
    available without further lazy-load queries.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): Primary key of the test to load.

    Returns:
        Test: The test with ``evidences`` and ``technique`` loaded.

    Raises:
        EntityNotFoundError: If no test has id ``test_id``.
    """
    eager_loads = (joinedload(Test.evidences), joinedload(Test.technique))
    found = (
        db.query(Test)
        .options(*eager_loads)
        .filter(Test.id == test_id)
        .first()
    )
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
    """Look up a test by primary key, failing loudly when it is absent.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): Primary key of the test.

    Returns:
        Test: The matching test ORM object (no relationships eager-loaded).

    Raises:
        EntityNotFoundError: If no test has id ``test_id``.
    """
    lookup = db.query(Test).filter(Test.id == test_id)
    found = lookup.first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
    """Look up a test by primary key with its technique eager-loaded.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): Primary key of the test.

    Returns:
        Test: The test with the ``technique`` relationship loaded via
        ``joinedload``.

    Raises:
        EntityNotFoundError: If no test has id ``test_id``.
    """
    with_technique = db.query(Test).options(joinedload(Test.technique))
    found = with_technique.filter(Test.id == test_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
def update_test(
    db: Session,
    test_id: uuid.UUID,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: object,
) -> Test:
    """Apply general field edits to a test that is still draft or rejected.

    The permission check runs before the state check, so a non-owner gets
    ``PermissionViolation`` regardless of the test's state. Does not commit;
    caller uses UnitOfWork.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): Primary key of the test to edit.
        updater_id (uuid.UUID): User performing the edit.
        updater_role (str): The updater's role; ``"admin"`` may edit any
            test, everyone else only tests they created.
        **fields (object): Attribute names and new values, assigned with
            ``setattr`` as given (no allow-list is applied here).

    Returns:
        Test: The edited test, flushed but not committed.

    Raises:
        EntityNotFoundError: If the test does not exist.
        PermissionViolation: If the updater is neither admin nor creator.
        BusinessRuleViolation: If the test is not draft or rejected.
    """
    test = get_test_or_raise(db, test_id)

    if not (updater_role == "admin" or test.created_by == updater_id):
        raise PermissionViolation(
            "Only the test creator or an admin can update this test"
        )

    editable_states = (TestState.draft, TestState.rejected)
    if test.state not in editable_states:
        raise BusinessRuleViolation(
            f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
        )

    for attr_name, new_value in fields.items():
        setattr(test, attr_name, new_value)
    db.flush()
    return test
def update_test_red(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
    """Apply Red Team field edits while the test is draft or red_executing.

    Does not commit; caller uses UnitOfWork.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): Primary key of the test to edit.
        **fields (object): Attribute names and new values, assigned with
            ``setattr`` as given (no allow-list is applied here).

    Returns:
        Test: The edited test, flushed but not committed.

    Raises:
        EntityNotFoundError: If the test does not exist.
        BusinessRuleViolation: If the state is not draft or red_executing.
    """
    test = get_test_or_raise(db, test_id)

    red_editable = (TestState.draft, TestState.red_executing)
    if test.state not in red_editable:
        current = test.state.value
        raise BusinessRuleViolation(
            f"Cannot update red fields in '{current}' state (must be draft or red_executing)"
        )

    for attr_name in fields:
        setattr(test, attr_name, fields[attr_name])
    db.flush()
    return test
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: object) -> Test:
    """Apply Blue Team field edits while the test is blue_evaluating.

    Does not commit; caller uses UnitOfWork.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): Primary key of the test to edit.
        **fields (object): Attribute names and new values, assigned with
            ``setattr`` as given (no allow-list is applied here).

    Returns:
        Test: The edited test, flushed but not committed.

    Raises:
        EntityNotFoundError: If the test does not exist.
        BusinessRuleViolation: If the state is not blue_evaluating.
    """
    test = get_test_or_raise(db, test_id)

    if test.state == TestState.blue_evaluating:
        for attr_name, new_value in fields.items():
            setattr(test, attr_name, new_value)
        db.flush()
        return test

    current = test.state.value
    raise BusinessRuleViolation(
        f"Cannot update blue fields in '{current}' state (must be blue_evaluating)"
    )
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
    """Return a test's audit-log history, oldest entry first.

    Audit rows are matched on ``entity_type == "test"`` and on ``entity_id``
    equal to the string form of ``test_id``.

    Args:
        db (Session): Active SQLAlchemy database session.
        test_id (uuid.UUID): Primary key of the test whose history is wanted.

    Returns:
        list[dict[str, Any]]: One dict per audit entry, ordered by timestamp
        ascending, with keys ``id``, ``action``, ``user_id`` (str or None),
        ``timestamp`` (ISO-8601 string or None) and ``details``.

    Raises:
        EntityNotFoundError: If the test does not exist.
    """
    # Existence check only; the returned test object is not needed.
    get_test_or_raise(db, test_id)

    entity_key = str(test_id)
    history = (
        db.query(AuditLog)
        .filter(AuditLog.entity_type == "test", AuditLog.entity_id == entity_key)
        .order_by(AuditLog.timestamp.asc())
        .all()
    )

    def _as_dict(entry: AuditLog) -> dict[str, Any]:
        # Serialize one audit row into a JSON-friendly mapping.
        return {
            "id": str(entry.id),
            "action": entry.action,
            "user_id": str(entry.user_id) if entry.user_id else None,
            "timestamp": entry.timestamp.isoformat() if entry.timestamp else None,
            "details": entry.details,
        }

    return [_as_dict(entry) for entry in history]