"""Test template service — framework-agnostic CRUD and queries.

Mutating helpers in this module stage changes on the given session but do
NOT commit; the caller owns the transaction boundary.
"""

from __future__ import annotations

import uuid

from sqlalchemy import func, or_
from sqlalchemy.orm import Session

from app.domain.errors import EntityNotFoundError
from app.models.test_template import TestTemplate
from app.utils import escape_like


def list_templates(
    db: Session,
    *,
    source: str | None = None,
    platform: str | None = None,
    severity: str | None = None,
    mitre_technique_id: str | None = None,
    search: str | None = None,
    is_active: bool | None = None,
    offset: int = 0,
    limit: int = 50,
) -> list[TestTemplate]:
    """Return a paginated, filterable list of test templates.

    String filters that are empty or None are ignored (no filter applied).

    Args:
        db: Session used to run the query.
        source: Exact match on ``TestTemplate.source``.
        platform: Case-insensitive substring match on ``TestTemplate.platform``.
        severity: Exact match on ``TestTemplate.severity``.
        mitre_technique_id: Exact match on ``TestTemplate.mitre_technique_id``.
        search: Case-insensitive substring matched against ``name`` OR
            ``description``.
        is_active: When not None, keep only templates whose ``is_active``
            equals this value (so ``False`` selects inactive templates).
        offset: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        Matching templates ordered by MITRE technique ID, then name.
    """
    query = db.query(TestTemplate)

    # Compare against None explicitly so that is_active=False still filters.
    if is_active is not None:
        query = query.filter(TestTemplate.is_active == is_active)
    if source:
        query = query.filter(TestTemplate.source == source)
    if platform:
        # escape_like presumably neutralises LIKE wildcards in user input so
        # it is matched literally — NOTE(review): confirm in app.utils.
        query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
    if severity:
        query = query.filter(TestTemplate.severity == severity)
    if mitre_technique_id:
        query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
    if search:
        pattern = f"%{escape_like(search)}%"
        query = query.filter(
            or_(
                TestTemplate.name.ilike(pattern),
                TestTemplate.description.ilike(pattern),
            )
        )

    return (
        query.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
        .offset(offset)
        .limit(limit)
        .all()
    )


def get_template_stats(db: Session) -> dict:
    """Return catalog statistics.

    Returns:
        A dict with keys:

        - ``"total"``: number of templates.
        - ``"active"``: number of templates with ``is_active == True``.
        - ``"inactive"``: ``total - active`` (every row not counted as active).
        - ``"by_source"``: ``{source: count}`` over ACTIVE templates only.
        - ``"by_platform"``: ``{platform: count}`` over ACTIVE templates only;
          rows whose platform is NULL or empty are reported together under
          ``"unspecified"``.
    """
    total = db.query(func.count(TestTemplate.id)).scalar() or 0
    active = (
        db.query(func.count(TestTemplate.id))
        .filter(TestTemplate.is_active == True)  # noqa: E712
        .scalar()
    ) or 0
    inactive = total - active

    source_rows = (
        db.query(TestTemplate.source, func.count(TestTemplate.id))
        .filter(TestTemplate.is_active == True)  # noqa: E712
        .group_by(TestTemplate.source)
        .all()
    )
    by_source = {source: cnt for source, cnt in source_rows}

    platform_rows = (
        db.query(TestTemplate.platform, func.count(TestTemplate.id))
        .filter(TestTemplate.is_active == True)  # noqa: E712
        .group_by(TestTemplate.platform)
        .all()
    )
    # NULL and "" are distinct GROUP BY groups but share the "unspecified"
    # label, so accumulate instead of overwriting one group's count.
    by_platform: dict = {}
    for platform, cnt in platform_rows:
        key = platform or "unspecified"
        by_platform[key] = by_platform.get(key, 0) + cnt

    return {
        "total": total,
        "active": active,
        "inactive": inactive,
        "by_source": by_source,
        "by_platform": by_platform,
    }


def bulk_activate(db: Session, *, activate: bool) -> int:
    """Set all templates to active or inactive. Does NOT commit.

    Only rows whose ``is_active`` differs from ``activate`` are updated, so the
    return value is the number of rows actually changed, not the catalog size.
    NOTE(review): if ``is_active`` is nullable, NULL rows never satisfy the
    ``!=`` filter under SQL three-valued logic and are left untouched.

    Args:
        db: Session on which the bulk UPDATE is issued.
        activate: Target value for ``is_active``.

    Returns:
        Count of affected rows as reported by ``Query.update``.
    """
    return (
        db.query(TestTemplate)
        .filter(TestTemplate.is_active != activate)
        .update({TestTemplate.is_active: activate})
    )


def get_templates_by_technique(db: Session, mitre_id: str) -> list[TestTemplate]:
    """Return all active templates mapped to a specific MITRE technique.

    Args:
        db: Session used to run the query.
        mitre_id: Exact value matched against ``TestTemplate.mitre_technique_id``.

    Returns:
        Active matching templates ordered by name.
    """
    return (
        db.query(TestTemplate)
        .filter(
            TestTemplate.mitre_technique_id == mitre_id,
            TestTemplate.is_active == True,  # noqa: E712
        )
        .order_by(TestTemplate.name)
        .all()
    )


def get_template_or_raise(db: Session, template_id: uuid.UUID) -> TestTemplate:
    """Return a template by ID.

    Raises:
        EntityNotFoundError: if no template has ``id == template_id``.
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if template is None:
        raise EntityNotFoundError("Test template", str(template_id))
    return template


def create_template(db: Session, **fields: object) -> TestTemplate:
    """Create a test template from keyword args (e.g. payload.model_dump()).

    The new instance is added to the session but NOT committed.

    Returns:
        The newly constructed (pending) template.
    """
    template = TestTemplate(**fields)
    db.add(template)
    return template


def update_template(db: Session, template_id: uuid.UUID, **fields: object) -> TestTemplate:
    """Update an existing template in place. Does NOT commit.

    Keys in ``fields`` that are not attributes of ``TestTemplate`` are
    silently ignored.
    NOTE(review): any existing attribute (including ``id``) can be set here;
    confirm callers only pass whitelisted payload fields.

    Raises:
        EntityNotFoundError: if the template does not exist.
    """
    template = get_template_or_raise(db, template_id)
    for field, value in fields.items():
        if hasattr(template, field):
            setattr(template, field, value)
    return template


def toggle_template_active(db: Session, template_id: uuid.UUID) -> TestTemplate:
    """Flip a template's ``is_active`` flag. Does NOT commit.

    Raises:
        EntityNotFoundError: if the template does not exist.
    """
    template = get_template_or_raise(db, template_id)
    template.is_active = not template.is_active
    return template


def soft_delete_template(db: Session, template_id: uuid.UUID) -> None:
    """Soft-delete a template by setting is_active=False. Does NOT commit.

    Raises:
        EntityNotFoundError: if the template does not exist.
    """
    template = get_template_or_raise(db, template_id)
    template.is_active = False