Files
Aegis/backend/app/services/test_crud_service.py
kitos 4c230caa32
Some checks failed
Aegis CI / lint-and-test (push) Has been cancelled
fix(campaigns): start_date modal + hide future-campaign tests from queue
Backend: activate endpoint returns 409 with structured warning when
start_date is in the future; accepts force=true to bypass.
test_crud_service: always excludes tests from draft campaigns with future
start_date so they do not appear in the team queue prematurely.

Frontend: catches the 409 on activate and shows an amber confirmation modal
with "Keep scheduled" / "Activate now anyway" options.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-04 14:05:58 +02:00

314 lines
9.9 KiB
Python

"""Test CRUD service — list, create, update, and query logic for security tests.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session, joinedload
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.campaign import Campaign, CampaignTest
from app.models.audit import AuditLog
from app.utils import escape_like
def list_tests(
    db: Session,
    *,
    state: str | None = None,
    technique_id: uuid.UUID | None = None,
    platform: str | None = None,
    created_by: uuid.UUID | None = None,
    pending_validation_side: str | None = None,
    not_in_any_campaign: bool = False,
    offset: int = 0,
    limit: int = 50,
) -> list[Test]:
    """Return a paginated list of tests with optional filters.

    Args:
        db: Active SQLAlchemy session.
        state: Only return tests whose ``state`` equals this value.
        technique_id: Only return tests linked to this technique.
        platform: Case-insensitive substring match on ``Test.platform``;
            the input is passed through ``escape_like`` first.
        created_by: Only return tests created by this user.
        pending_validation_side: ``"red"`` or ``"blue"`` restricts results to
            tests in ``in_review`` whose validation status for that side is
            ``"pending"`` or not yet set (NULL). Other values are ignored.
        not_in_any_campaign: When true, exclude tests linked to any campaign.
        offset: Number of rows to skip (pagination).
        limit: Maximum number of rows to return.

    Returns:
        Tests ordered by ``created_at`` descending, with ``technique``
        eager-loaded.

    Tests that belong to a campaign still in 'draft' status AND with a
    start_date in the future are always excluded — they should not appear
    in the team's queue until the campaign is activated on its start date.
    """
    query = db.query(Test).options(joinedload(Test.technique))
    if state:
        query = query.filter(Test.state == state)
    if technique_id:
        query = query.filter(Test.technique_id == technique_id)
    if platform:
        query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
    if created_by:
        query = query.filter(Test.created_by == created_by)
    if pending_validation_side in ("red", "blue"):
        status_col = (
            Test.red_validation_status
            if pending_validation_side == "red"
            else Test.blue_validation_status
        )
        # `col IN ('pending', NULL)` never matches rows where col IS NULL
        # (NULL = NULL is UNKNOWN in SQL), so "not yet validated" has to be
        # expressed with an explicit IS NULL. `|` renders as SQL OR.
        query = query.filter(
            Test.state == TestState.in_review,
            (status_col == "pending") | status_col.is_(None),
        )
    if not_in_any_campaign:
        linked = db.query(CampaignTest.test_id).distinct().subquery()
        query = query.filter(~Test.id.in_(linked))
    # Always hide tests from scheduled campaigns that haven't started yet.
    # A "scheduled-but-not-yet-active" campaign = draft status + start_date in future.
    # NOTE(review): naive UTC timestamp — assumes Campaign.start_date is stored
    # as naive UTC; confirm against the model.
    now = datetime.utcnow()
    future_draft_tests = (
        db.query(CampaignTest.test_id)
        .join(Campaign, Campaign.id == CampaignTest.campaign_id)
        .filter(
            Campaign.status == "draft",
            Campaign.start_date.isnot(None),
            Campaign.start_date > now,
        )
        .distinct()
        .subquery()
    )
    query = query.filter(~Test.id.in_(future_draft_tests))
    return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
def create_test(
    db: Session,
    *,
    technique_id: uuid.UUID,
    creator_id: uuid.UUID,
    **fields: Any,
) -> Test:
    """Add a new draft test attached to an existing technique.

    Args:
        db: Active SQLAlchemy session; the new row is added and flushed.
        technique_id: ID of the technique the test belongs to.
        creator_id: User ID stored in ``created_by``.
        **fields: Additional column values forwarded to the ``Test`` constructor.

    Returns:
        The flushed ``Test`` in ``draft`` state.

    Raises:
        EntityNotFoundError: If no technique has ``technique_id``.

    Does not commit; the caller owns the transaction (UnitOfWork).
    """
    if db.query(Technique).filter(Technique.id == technique_id).first() is None:
        raise EntityNotFoundError("Technique", str(technique_id))

    new_test = Test(
        technique_id=technique_id,
        created_by=creator_id,
        state=TestState.draft,
        # Set explicitly: the DB column has no server-side default.
        created_at=datetime.utcnow(),
        **fields,
    )
    db.add(new_test)
    db.flush()  # emit the INSERT now without committing
    return new_test
def create_test_from_template(
    db: Session,
    *,
    template_id: uuid.UUID,
    technique_id_or_mitre: str,
    creator_id: uuid.UUID,
    # Optional user-edited overrides (take priority over template values)
    name_override: str | None = None,
    description_override: str | None = None,
    platform_override: str | None = None,
    procedure_text_override: str | None = None,
    tool_used_override: str | None = None,
) -> Test:
    """Build a new draft Test from a TestTemplate.

    ``technique_id_or_mitre`` may be a technique UUID string or a MITRE ID
    (e.g. ``T1059.001``). It is first tried as a UUID; if that does not
    parse or matches nothing, it is looked up as ``Technique.mitre_id``.

    Each ``*_override`` argument replaces the corresponding template value
    whenever it is not ``None``. ``remediation_steps`` always comes from the
    template's ``suggested_remediation``.

    Raises:
        EntityNotFoundError: If the template or the technique is not found.

    Does not commit; the caller owns the transaction (UnitOfWork).
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if template is None:
        raise EntityNotFoundError("TestTemplate", str(template_id))

    technique = None
    try:
        as_uuid = uuid.UUID(technique_id_or_mitre)
    except ValueError:
        pass  # not a UUID string; fall through to the MITRE ID lookup
    else:
        technique = db.query(Technique).filter(Technique.id == as_uuid).first()

    if technique is None:
        technique = (
            db.query(Technique)
            .filter(Technique.mitre_id == technique_id_or_mitre)
            .first()
        )
    if technique is None:
        raise EntityNotFoundError("Technique", technique_id_or_mitre)

    new_test = Test(
        technique_id=technique.id,
        name=template.name if name_override is None else name_override,
        description=(
            template.description
            if description_override is None
            else description_override
        ),
        platform=template.platform if platform_override is None else platform_override,
        procedure_text=(
            template.attack_procedure
            if procedure_text_override is None
            else procedure_text_override
        ),
        tool_used=(
            template.tool_suggested
            if tool_used_override is None
            else tool_used_override
        ),
        remediation_steps=template.suggested_remediation,
        created_by=creator_id,
        state=TestState.draft,
        created_at=datetime.utcnow(),  # explicit — DB column has no server default
    )
    db.add(new_test)
    db.flush()
    return new_test
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
    """Load one test with its ``evidences`` and ``technique`` joined-eager-loaded.

    Raises:
        EntityNotFoundError: If no test has ``test_id``.
    """
    eager_opts = (joinedload(Test.evidences), joinedload(Test.technique))
    found = db.query(Test).options(*eager_opts).filter(Test.id == test_id).first()
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
    """Return the test whose primary key is ``test_id``.

    Raises:
        EntityNotFoundError: If no such test exists.
    """
    match = db.query(Test).filter_by(id=test_id).first()
    if match is None:
        raise EntityNotFoundError("Test", str(test_id))
    return match
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
    """Load one test with its ``technique`` relationship joined-eager-loaded.

    Raises:
        EntityNotFoundError: If no test has ``test_id``.
    """
    found = (
        db.query(Test)
        .filter(Test.id == test_id)
        .options(joinedload(Test.technique))
        .first()
    )
    if found is not None:
        return found
    raise EntityNotFoundError("Test", str(test_id))
def update_test(
    db: Session,
    test_id: uuid.UUID,
    *,
    updater_id: uuid.UUID,
    updater_role: str,
    **fields: Any,
) -> Test:
    """Apply general field updates to a test in ``draft`` or ``rejected`` state.

    Each keyword in ``fields`` is assigned onto the test with ``setattr``.

    Raises:
        EntityNotFoundError: If the test does not exist.
        PermissionViolation: If the updater is neither an admin nor the creator.
        BusinessRuleViolation: If the test is not in draft or rejected state.

    Does not commit; the caller owns the transaction (UnitOfWork).
    """
    test = get_test_or_raise(db, test_id)

    if not (updater_role == "admin" or test.created_by == updater_id):
        raise PermissionViolation(
            "Only the test creator or an admin can update this test"
        )

    editable_states = (TestState.draft, TestState.rejected)
    if test.state not in editable_states:
        raise BusinessRuleViolation(
            f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
        )

    for attr_name, new_value in fields.items():
        setattr(test, attr_name, new_value)
    db.flush()
    return test
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
    """Apply Red Team field updates; only allowed in ``draft`` or ``red_executing``.

    Each keyword in ``fields`` is assigned onto the test with ``setattr``.

    Raises:
        EntityNotFoundError: If the test does not exist.
        BusinessRuleViolation: If the state is not draft or red_executing.

    Does not commit; the caller owns the transaction (UnitOfWork).
    """
    test = get_test_or_raise(db, test_id)

    red_editable_states = (TestState.draft, TestState.red_executing)
    if test.state not in red_editable_states:
        raise BusinessRuleViolation(
            f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)"
        )

    for attr_name, new_value in fields.items():
        setattr(test, attr_name, new_value)
    db.flush()
    return test
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
    """Apply Blue Team field updates; only allowed in ``blue_evaluating``.

    Each keyword in ``fields`` is assigned onto the test with ``setattr``.

    Raises:
        EntityNotFoundError: If the test does not exist.
        BusinessRuleViolation: If the state is not blue_evaluating.

    Does not commit; the caller owns the transaction (UnitOfWork).
    """
    test = get_test_or_raise(db, test_id)

    required_state = TestState.blue_evaluating
    if test.state != required_state:
        raise BusinessRuleViolation(
            f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)"
        )

    for attr_name, new_value in fields.items():
        setattr(test, attr_name, new_value)
    db.flush()
    return test
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
    """Return the audit-log entries for a test, oldest first.

    Each entry is a dict with keys ``id``, ``action``, ``user_id``,
    ``timestamp`` (ISO-8601 string) and ``details``. ``user_id`` and
    ``timestamp`` are ``None`` when the stored value is falsy.

    Raises:
        EntityNotFoundError: If the test does not exist.
    """
    # Existence check only; the returned test itself is not needed.
    get_test_or_raise(db, test_id)

    entries = (
        db.query(AuditLog)
        .filter(AuditLog.entity_type == "test", AuditLog.entity_id == str(test_id))
        .order_by(AuditLog.timestamp.asc())
        .all()
    )

    def _as_entry(entry):
        return {
            "id": str(entry.id),
            "action": entry.action,
            "user_id": None if not entry.user_id else str(entry.user_id),
            "timestamp": None if not entry.timestamp else entry.timestamp.isoformat(),
            "details": entry.details,
        }

    return [_as_entry(entry) for entry in entries]