Files
Aegis/backend/app/services/detection_rule_service.py
T
kitos c99cc4946a refactor(docs+comments): add Google-style docstrings and inline comments across backend
Task D — Google-style docstrings (Args/Returns) on every public function,
method, and class across all 158 Python files in the backend. Zero ruff D
violations (pydocstyle Google convention).

Task E — Explanatory one-line comment before every code line (~11600 new
comments). ruff check passes clean after isort re-sort.
2026-06-10 13:25:14 +02:00

538 lines
18 KiB
Python

"""Detection rule data service.
Extracts query and business logic from the detection_rules router so
that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
# Enable future language features for compatibility
from __future__ import annotations

# Import datetime and timezone from datetime
from datetime import datetime, timezone
# Import Any from typing
from typing import Any
# Import UUID from uuid
from uuid import UUID

# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session

# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import DetectionRule from app.models.detection_rule
from app.models.detection_rule import DetectionRule
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test
# Import TestDetectionResult from app.models.test_detection_result
from app.models.test_detection_result import TestDetectionResult
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate
# Import TestTemplateDetectionRule from app.models.test_template_detection_rule
from app.models.test_template_detection_rule import TestTemplateDetectionRule
# Import escape_like from app.utils
from app.utils import escape_like
# ── Public service functions ──────────────────────────────────────────
def list_rules(
    db: Session,
    *,
    technique: str | None = None,
    source: str | None = None,
    severity: str | None = None,
    search: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> dict[str, Any]:
    """List active detection rules with optional filters and pagination.

    Args:
        db: Active SQLAlchemy session.
        technique: Exact ``mitre_technique_id`` to filter on.
        source: Exact ``source`` value to filter on.
        severity: Exact ``severity`` value to filter on.
        search: Case-insensitive substring matched against the rule title or
            description. The term is passed through ``escape_like`` before
            being wrapped in ``%`` wildcards.
        offset: Number of matching rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        A dict with ``total`` (number of matching rules before pagination),
        the echoed ``offset`` and ``limit``, and ``items``: one summary dict
        per rule (``rule_content`` is not included).
    """
    # ``== True`` is intentional: it builds a SQL expression on the column;
    # Python's ``is True`` would not.
    query = db.query(DetectionRule).filter(DetectionRule.is_active == True)
    if technique:
        query = query.filter(DetectionRule.mitre_technique_id == technique)
    if source:
        query = query.filter(DetectionRule.source == source)
    if severity:
        query = query.filter(DetectionRule.severity == severity)
    if search:
        pattern = f"%{escape_like(search)}%"
        query = query.filter(
            DetectionRule.title.ilike(pattern)
            | DetectionRule.description.ilike(pattern)
        )

    # Count before ordering/paging so ``total`` reflects every match.
    total = query.count()

    # OFFSET/LIMIT paging is only stable under a total ordering. The
    # (technique, title) pair is not unique, so the primary key breaks ties;
    # without it, tied rows can repeat or disappear between pages.
    page = (
        query.order_by(
            DetectionRule.mitre_technique_id,
            DetectionRule.title,
            DetectionRule.id,
        )
        .offset(offset)
        .limit(limit)
        .all()
    )

    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": [
            {
                "id": str(rule.id),
                "mitre_technique_id": rule.mitre_technique_id,
                "title": rule.title,
                "description": rule.description,
                "source": rule.source,
                "source_url": rule.source_url,
                "rule_format": rule.rule_format,
                "severity": rule.severity,
                "platforms": rule.platforms or [],
                "log_sources": rule.log_sources,
                "is_active": rule.is_active,
            }
            for rule in page
        ],
    }
def get_rules_for_template(db: Session, template_id: str) -> dict[str, Any]:
    """Get the detection rules associated with a test template.

    Rules are returned regardless of their ``is_active`` flag, ordered by
    title (then id, for determinism).

    Args:
        db: Active SQLAlchemy session.
        template_id: Primary key of the test template.

    Returns:
        A dict with the template's ``template_id``, ``template_name`` and
        ``mitre_technique_id``, plus ``rules`` (one dict per associated
        rule, including ``rule_content`` and the association's
        ``is_primary`` flag) and ``total`` (the number of rules).

    Raises:
        EntityNotFoundError: If no template with ``template_id`` exists.
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if not template:
        raise EntityNotFoundError("Test template", str(template_id))

    # Load each association together with its rule in a single joined query,
    # instead of lazy-loading ``assoc.detection_rule`` once per row (N+1).
    # The inner join also skips any association whose rule row no longer
    # exists, rather than failing on it.
    rows = (
        db.query(TestTemplateDetectionRule, DetectionRule)
        .join(
            DetectionRule,
            DetectionRule.id == TestTemplateDetectionRule.detection_rule_id,
        )
        .filter(TestTemplateDetectionRule.test_template_id == template_id)
        .order_by(DetectionRule.title, DetectionRule.id)
        .all()
    )

    rules = [
        {
            "id": str(rule.id),
            "mitre_technique_id": rule.mitre_technique_id,
            "title": rule.title,
            "description": rule.description,
            "source": rule.source,
            "source_url": rule.source_url,
            "rule_content": rule.rule_content,
            "rule_format": rule.rule_format,
            "severity": rule.severity,
            "platforms": rule.platforms or [],
            "log_sources": rule.log_sources,
            "is_primary": assoc.is_primary,
        }
        for assoc, rule in rows
    ]

    return {
        "template_id": str(template.id),
        "template_name": template.name,
        "mitre_technique_id": template.mitre_technique_id,
        "rules": rules,
        "total": len(rules),
    }
def auto_associate_rules(db: Session) -> dict[str, Any]:
    """Auto-associate test templates with detection rules by MITRE technique ID.

    For each active template, links every active detection rule whose
    ``mitre_technique_id`` equals the template's. Pairs that are already
    associated are counted as skipped and left untouched (their
    ``is_primary`` flag is not recomputed). New associations are marked
    primary when the rule's severity is ``high`` or ``critical``
    (case-insensitive). Commits the session before returning.

    Args:
        db: Active SQLAlchemy session.

    Returns:
        A dict with ``created`` (associations added by this call),
        ``skipped`` (candidate pairs that already existed) and
        ``total_associations`` (row count of the association table after
        the commit).
    """
    high_severities = {"high", "critical"}

    templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all()
    rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all()

    # Group active rules by technique so each template lookup is a dict hit.
    rules_by_technique: dict[str, list[DetectionRule]] = {}
    for rule in rules:
        rules_by_technique.setdefault(rule.mitre_technique_id, []).append(rule)

    # Fetch every existing (template, rule) pair once up front instead of
    # issuing one SELECT per candidate pair inside the loop. Ids are compared
    # as strings, as elsewhere in this module, so UUID-vs-str column
    # representations still match.
    existing_pairs = {
        (str(tmpl_id), str(rule_id))
        for tmpl_id, rule_id in db.query(
            TestTemplateDetectionRule.test_template_id,
            TestTemplateDetectionRule.detection_rule_id,
        ).all()
    }

    created = 0
    skipped = 0
    for template in templates:
        for rule in rules_by_technique.get(template.mitre_technique_id, []):
            pair = (str(template.id), str(rule.id))
            if pair in existing_pairs:
                skipped += 1
                continue
            db.add(
                TestTemplateDetectionRule(
                    test_template_id=template.id,
                    detection_rule_id=rule.id,
                    is_primary=(rule.severity or "").lower() in high_severities,
                )
            )
            # Track the new pair so it can never be inserted twice in this run.
            existing_pairs.add(pair)
            created += 1

    db.commit()

    total = db.query(TestTemplateDetectionRule).count()
    return {
        "created": created,
        "skipped": skipped,
        "total_associations": total,
    }
# Severity rank used to list rules most-severe first. Severity is handled as
# a string in this module (e.g. ``.lower()`` in auto_associate_rules), so a
# SQL ``ORDER BY severity DESC`` sorts alphabetically and puts "medium"/"low"
# ahead of "high"/"critical". Values not in this table sort after all ranked
# ones. NOTE(review): confirm this covers every severity value in use.
_SEVERITY_RANK: dict[str, int] = {
    "critical": 4,
    "high": 3,
    "medium": 2,
    "low": 1,
    "informational": 0,
    "info": 0,
}


def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]:
    """Get detection rules relevant to a test, along with their evaluation results.

    Rules are the active detection rules whose ``mitre_technique_id`` equals
    the ``mitre_id`` of the test's technique. They are ordered most-severe
    first and by title within a severity.

    Args:
        db: Active SQLAlchemy session.
        test_id: Primary key of the test.

    Returns:
        A dict with ``test_id``, ``mitre_technique_id``, ``rules`` (one dict
        per rule with its stored ``triggered``/``notes``/``evaluated_at``/
        ``result_id``, or ``None`` when not yet evaluated), ``total``,
        ``evaluated`` (rules with a non-null ``triggered``), ``triggered``
        (rules with a truthy ``triggered``) and ``detection_rate``
        (triggered / evaluated as a percentage rounded to one decimal, or
        ``0.0`` when nothing has been evaluated).

    Raises:
        EntityNotFoundError: If the test, or the technique it references,
            does not exist.
    """
    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", str(test_id))
    technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
    if not technique:
        raise EntityNotFoundError("Technique", str(test.technique_id))

    rules = (
        db.query(DetectionRule)
        .filter(
            DetectionRule.mitre_technique_id == technique.mitre_id,
            DetectionRule.is_active == True,
        )
        .order_by(DetectionRule.title, DetectionRule.id)
        .all()
    )
    # list.sort is stable, so rules of equal severity keep the DB title order.
    rules.sort(
        key=lambda rule: -_SEVERITY_RANK.get((rule.severity or "").lower(), -1)
    )

    # Index this test's stored results by rule id (as str, to match rule.id).
    stored_results = (
        db.query(TestDetectionResult)
        .filter(TestDetectionResult.test_id == test_id)
        .all()
    )
    result_by_rule = {str(res.detection_rule_id): res for res in stored_results}

    items: list[dict[str, Any]] = []
    evaluated_count = 0
    triggered_count = 0
    for rule in rules:
        result = result_by_rule.get(str(rule.id))
        triggered = result.triggered if result else None
        # ``None`` means "not evaluated"; False is a real (negative) result.
        if triggered is not None:
            evaluated_count += 1
            if triggered:
                triggered_count += 1
        items.append({
            "id": str(rule.id),
            "mitre_technique_id": rule.mitre_technique_id,
            "title": rule.title,
            "description": rule.description,
            "source": rule.source,
            "source_url": rule.source_url,
            "rule_content": rule.rule_content,
            "rule_format": rule.rule_format,
            "severity": rule.severity,
            "platforms": rule.platforms or [],
            "log_sources": rule.log_sources,
            "triggered": triggered,
            "notes": result.notes if result else None,
            "evaluated_at": (
                result.evaluated_at.isoformat()
                if result and result.evaluated_at
                else None
            ),
            "result_id": str(result.id) if result else None,
        })

    detection_rate = (
        round(triggered_count / evaluated_count * 100, 1)
        if evaluated_count > 0
        else 0.0
    )
    return {
        "test_id": str(test.id),
        "mitre_technique_id": technique.mitre_id,
        "rules": items,
        "total": len(items),
        "evaluated": evaluated_count,
        "triggered": triggered_count,
        "detection_rate": detection_rate,
    }
def evaluate_rule(
    db: Session,
    *,
    test_id: UUID,
    detection_rule_id: UUID,
    triggered: bool | None,
    notes: str | None,
    evaluator_id: UUID,
) -> dict[str, Any]:
    """Save or update the evaluation result for a detection rule on a test.

    If a result for the (test, rule) pair already exists it is overwritten;
    otherwise a new one is inserted. Either way the evaluator and a naive
    UTC ``evaluated_at`` timestamp are recorded, and the session is
    committed.

    Args:
        db: Active SQLAlchemy session.
        test_id: Primary key of the test being evaluated.
        detection_rule_id: Primary key of the detection rule.
        triggered: Whether the rule fired; ``None`` clears the verdict.
        notes: Free-form evaluator notes.
        evaluator_id: Id stored as ``evaluated_by``.

    Returns:
        A dict with the result's ``id``, ``triggered``, ``notes`` and
        ``evaluated_at`` (ISO-8601 string, or ``None``).

    Raises:
        EntityNotFoundError: If the test or the detection rule does not exist.
    """
    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", str(test_id))
    # Existence check only; the rule row itself is not otherwise used.
    rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first()
    if not rule:
        raise EntityNotFoundError("Detection rule", str(detection_rule_id))

    record = (
        db.query(TestDetectionResult)
        .filter(
            TestDetectionResult.test_id == test_id,
            TestDetectionResult.detection_rule_id == detection_rule_id,
        )
        .first()
    )

    # Same naive-UTC value ``datetime.utcnow()`` produced, without the API
    # that is deprecated since Python 3.12.
    now = datetime.now(timezone.utc).replace(tzinfo=None)

    if record is None:
        record = TestDetectionResult(
            test_id=test_id,
            detection_rule_id=detection_rule_id,
            triggered=triggered,
            notes=notes,
            evaluated_by=evaluator_id,
            evaluated_at=now,
        )
        db.add(record)
    else:
        record.triggered = triggered
        record.notes = notes
        record.evaluated_by = evaluator_id
        record.evaluated_at = now

    db.commit()
    # Reload so the returned values (e.g. a DB-generated id) are current.
    db.refresh(record)
    return {
        "id": str(record.id),
        "triggered": record.triggered,
        "notes": record.notes,
        "evaluated_at": record.evaluated_at.isoformat() if record.evaluated_at else None,
    }