"""Detection rule data service.

Extracts query and business logic from the detection_rules router so that
the router remains a thin HTTP adapter.

This module is framework-agnostic: no FastAPI imports.
"""

from __future__ import annotations

from collections import defaultdict
from datetime import datetime
from typing import Any
from uuid import UUID

from sqlalchemy.orm import Session

from app.domain.errors import EntityNotFoundError
from app.models.detection_rule import DetectionRule
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_detection_result import TestDetectionResult
from app.models.test_template import TestTemplate
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.utils import escape_like

# Severities (compared case-insensitively) whose rules are marked as primary
# when templates are auto-associated with detection rules.
_PRIMARY_SEVERITIES = frozenset({"high", "critical"})

# Sort rank per severity, most severe first. Ordering by the raw column value
# would be alphabetical (which puts "critical" last), so rank explicitly.
_SEVERITY_RANK: dict[str, int] = {
    "critical": 0,
    "high": 1,
    "medium": 2,
    "low": 3,
    "informational": 4,
    "info": 4,
}
# Rank given to missing or unrecognized severities: after every listed one.
_UNRANKED_SEVERITY = max(_SEVERITY_RANK.values()) + 1


# ── Private helpers ───────────────────────────────────────────────────


def _severity_rank(severity: str | None) -> int:
    """Return the sort rank of *severity*; a lower rank means more severe."""
    return _SEVERITY_RANK.get((severity or "").lower(), _UNRANKED_SEVERITY)


def _rule_to_dict(rule: DetectionRule, *, include_content: bool) -> dict[str, Any]:
    """Serialize the detection-rule fields shared by this module's responses.

    ``rule_content`` is included only when *include_content* is true (the
    list endpoint omits it). Key order matches the original per-endpoint
    dict literals, so serialized JSON is unchanged.
    """
    data: dict[str, Any] = {
        "id": str(rule.id),
        "mitre_technique_id": rule.mitre_technique_id,
        "title": rule.title,
        "description": rule.description,
        "source": rule.source,
        "source_url": rule.source_url,
    }
    if include_content:
        data["rule_content"] = rule.rule_content
    data["rule_format"] = rule.rule_format
    data["severity"] = rule.severity
    data["platforms"] = rule.platforms or []
    data["log_sources"] = rule.log_sources
    return data


# ── Public service functions ──────────────────────────────────────────


def list_rules(
    db: Session,
    *,
    technique: str | None = None,
    source: str | None = None,
    severity: str | None = None,
    search: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> dict[str, Any]:
    """List active detection rules with optional filters and pagination.

    Args:
        db: Database session.
        technique: If given, only rules whose ``mitre_technique_id`` equals it.
        source: If given, only rules whose ``source`` equals it.
        severity: If given, only rules whose ``severity`` equals it.
        search: If given, case-insensitive substring match (``ILIKE``) against
            the rule title or description.
        offset: Number of matching rules to skip.
        limit: Maximum number of rules to return.

    Returns:
        A dict with ``total`` (match count before pagination), ``offset``,
        ``limit`` and ``items``. Items are ordered by technique ID, then title.
    """
    # `== True` (not `is True`) is required so SQLAlchemy builds a SQL expression.
    query = db.query(DetectionRule).filter(DetectionRule.is_active == True)  # noqa: E712
    if technique:
        query = query.filter(DetectionRule.mitre_technique_id == technique)
    if source:
        query = query.filter(DetectionRule.source == source)
    if severity:
        query = query.filter(DetectionRule.severity == severity)
    if search:
        # User input goes through escape_like so it is matched as text;
        # presumably it escapes LIKE wildcards (% and _) — confirm in app.utils.
        pattern = f"%{escape_like(search)}%"
        query = query.filter(
            DetectionRule.title.ilike(pattern) | DetectionRule.description.ilike(pattern)
        )

    total = query.count()
    page = (
        query.order_by(DetectionRule.mitre_technique_id, DetectionRule.title)
        .offset(offset)
        .limit(limit)
        .all()
    )
    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": [
            {**_rule_to_dict(rule, include_content=False), "is_active": rule.is_active}
            for rule in page
        ],
    }


def get_rules_for_template(db: Session, template_id: str) -> dict[str, Any]:
    """Get the detection rules associated with a test template.

    Args:
        db: Database session.
        template_id: ID of the test template.

    Returns:
        A dict with the template's id, name and technique, plus ``rules``
        (each with its content and ``is_primary`` flag) and ``total``.

    Raises:
        EntityNotFoundError: If the template does not exist.
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if not template:
        raise EntityNotFoundError("Test template", template_id)

    associations = (
        db.query(TestTemplateDetectionRule)
        .filter(TestTemplateDetectionRule.test_template_id == template_id)
        .all()
    )
    # NOTE(review): if `detection_rule` is a lazily loaded relationship, this
    # issues one query per association; consider joinedload. Unlike
    # list_rules, inactive rules are not filtered out here.
    rules = [
        {
            **_rule_to_dict(assoc.detection_rule, include_content=True),
            "is_primary": assoc.is_primary,
        }
        for assoc in associations
    ]
    return {
        "template_id": str(template.id),
        "template_name": template.name,
        "mitre_technique_id": template.mitre_technique_id,
        "rules": rules,
        "total": len(rules),
    }


def auto_associate_rules(db: Session) -> dict[str, Any]:
    """Auto-associate test templates with detection rules by MITRE technique ID.

    For each active template, links every active detection rule with the same
    technique. Pairs that are already associated are counted as skipped. New
    associations are marked primary when the rule's severity is high or
    critical (case-insensitive). Performs a commit internally.

    Returns:
        ``{"created", "skipped", "total_associations"}``, where
        ``total_associations`` is the table-wide count after the commit.
    """
    templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all()  # noqa: E712
    active_rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all()  # noqa: E712

    rules_by_technique: defaultdict[str, list[DetectionRule]] = defaultdict(list)
    for rule in active_rules:
        rules_by_technique[rule.mitre_technique_id].append(rule)

    # Fetch all existing (template, rule) pairs in one query instead of one
    # SELECT per candidate pair. IDs are normalized to str so membership does
    # not depend on the Python type returned for each ID column.
    existing_pairs = {
        (str(template_id), str(rule_id))
        for template_id, rule_id in db.query(
            TestTemplateDetectionRule.test_template_id,
            TestTemplateDetectionRule.detection_rule_id,
        )
    }

    created = 0
    skipped = 0
    for template in templates:
        for rule in rules_by_technique.get(template.mitre_technique_id, []):
            pair = (str(template.id), str(rule.id))
            if pair in existing_pairs:
                skipped += 1
                continue
            existing_pairs.add(pair)
            db.add(
                TestTemplateDetectionRule(
                    test_template_id=template.id,
                    detection_rule_id=rule.id,
                    is_primary=(rule.severity or "").lower() in _PRIMARY_SEVERITIES,
                )
            )
            created += 1

    db.commit()

    return {
        "created": created,
        "skipped": skipped,
        "total_associations": db.query(TestTemplateDetectionRule).count(),
    }


def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]:
    """Get detection rules relevant to a test, along with their evaluation results.

    Rules are the active detection rules for the test's technique. They are
    ordered most severe first (critical, high, medium, low, informational,
    then unknown), and by title within a severity.

    Args:
        db: Database session.
        test_id: ID of the test.

    Returns:
        A dict with ``rules`` (each with content plus ``triggered``, ``notes``,
        ``evaluated_at`` and ``result_id`` from any saved result), ``total``,
        ``evaluated`` (rules with a non-null ``triggered``), ``triggered``, and
        ``detection_rate`` (triggered / evaluated as a percentage rounded to
        one decimal, or 0 when nothing is evaluated).

    Raises:
        EntityNotFoundError: If the test or its technique does not exist.
    """
    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", str(test_id))

    technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
    if not technique:
        raise EntityNotFoundError("Technique", str(test.technique_id))

    rules = (
        db.query(DetectionRule)
        .filter(
            DetectionRule.mitre_technique_id == technique.mitre_id,
            DetectionRule.is_active == True,  # noqa: E712
        )
        .order_by(DetectionRule.title)
        .all()
    )
    # Ordering by the severity column would be alphabetical, which puts
    # "critical" last. list.sort is stable, so the DB's title order is kept
    # within each severity rank.
    rules.sort(key=lambda rule: _severity_rank(rule.severity))

    saved_results = (
        db.query(TestDetectionResult)
        .filter(TestDetectionResult.test_id == test_id)
        .all()
    )
    results_by_rule = {str(result.detection_rule_id): result for result in saved_results}

    items = []
    evaluated_count = 0
    triggered_count = 0
    for rule in rules:
        result = results_by_rule.get(str(rule.id))
        triggered = result.triggered if result else None
        # Only results with an explicit True/False verdict count as evaluated.
        if triggered is not None:
            evaluated_count += 1
            if triggered:
                triggered_count += 1
        items.append({
            **_rule_to_dict(rule, include_content=True),
            "triggered": triggered,
            "notes": result.notes if result else None,
            "evaluated_at": (
                result.evaluated_at.isoformat() if result and result.evaluated_at else None
            ),
            "result_id": str(result.id) if result else None,
        })

    return {
        "test_id": str(test.id),
        "mitre_technique_id": technique.mitre_id,
        "rules": items,
        "total": len(items),
        "evaluated": evaluated_count,
        "triggered": triggered_count,
        "detection_rate": (
            round(triggered_count / evaluated_count * 100, 1) if evaluated_count > 0 else 0
        ),
    }
def evaluate_rule(
    db: Session,
    *,
    test_id: UUID,
    detection_rule_id: UUID,
    triggered: bool | None,
    notes: str | None,
    evaluator_id: UUID,
) -> dict[str, Any]:
    """Save or update the evaluation result for a detection rule on a test.

    Updates the existing TestDetectionResult for (test_id, detection_rule_id)
    or creates one. Records the evaluator and the current UTC time as a naive
    datetime, then commits.

    Args:
        db: Database session.
        test_id: ID of the test being evaluated.
        detection_rule_id: ID of the detection rule being evaluated.
        triggered: Whether the rule fired; None leaves it unevaluated.
        notes: Free-form evaluator notes.
        evaluator_id: ID of the user recording the evaluation.

    Returns:
        ``{"id", "triggered", "notes", "evaluated_at"}`` for the saved result,
        with ``evaluated_at`` as an ISO-8601 string (or None).

    Raises:
        EntityNotFoundError: If the test or detection rule does not exist.
    """
    # Local import keeps this function independent of the module import list.
    from datetime import timezone

    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", str(test_id))

    rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first()
    if not rule:
        raise EntityNotFoundError("Detection rule", str(detection_rule_id))

    # Naive UTC timestamp, identical to what datetime.utcnow() returned;
    # utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc).replace(tzinfo=None)

    result = (
        db.query(TestDetectionResult)
        .filter(
            TestDetectionResult.test_id == test_id,
            TestDetectionResult.detection_rule_id == detection_rule_id,
        )
        .first()
    )
    if result is None:
        result = TestDetectionResult(
            test_id=test_id,
            detection_rule_id=detection_rule_id,
            triggered=triggered,
            notes=notes,
            evaluated_by=evaluator_id,
            evaluated_at=now,
        )
        db.add(result)
    else:
        result.triggered = triggered
        result.notes = notes
        result.evaluated_by = evaluator_id
        result.evaluated_at = now

    db.commit()
    # Reload so the response reflects what the database actually stored.
    db.refresh(result)

    return {
        "id": str(result.id),
        "triggered": result.triggered,
        "notes": result.notes,
        "evaluated_at": result.evaluated_at.isoformat() if result.evaluated_at else None,
    }