Files
Aegis/backend/app/services/detection_rule_service.py

320 lines
10 KiB
Python

"""Detection rule data service.
Extracts query and business logic from the detection_rules router so
that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.detection_rule import DetectionRule
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.technique import Technique
from app.utils import escape_like
# ── Public service functions ──────────────────────────────────────────
def list_rules(
    db: Session,
    *,
    technique: str | None = None,
    source: str | None = None,
    severity: str | None = None,
    search: str | None = None,
    offset: int = 0,
    limit: int = 50,
) -> dict[str, Any]:
    """List active detection rules with optional filters and pagination.

    Args:
        db: Open SQLAlchemy session.
        technique: If given, only rules whose ``mitre_technique_id`` equals it.
        source: If given, only rules whose ``source`` equals it.
        severity: If given, only rules whose ``severity`` equals it.
        search: If given, a case-insensitive (``ILIKE``) substring match against
            the rule title or description. The text is passed through
            ``escape_like`` first (presumably neutralising ``%``/``_``).
        offset: Number of matching rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        Dict with ``total`` (number of matching rules before pagination), the
        echoed ``offset`` and ``limit``, and ``items`` (the serialized rules of
        the requested page).
    """
    # ``== True`` builds a SQL expression; ``is True`` would not.
    query = db.query(DetectionRule).filter(DetectionRule.is_active == True)  # noqa: E712
    if technique:
        query = query.filter(DetectionRule.mitre_technique_id == technique)
    if source:
        query = query.filter(DetectionRule.source == source)
    if severity:
        query = query.filter(DetectionRule.severity == severity)
    if search:
        pattern = f"%{escape_like(search)}%"
        query = query.filter(
            DetectionRule.title.ilike(pattern)
            | DetectionRule.description.ilike(pattern)
        )

    total = query.count()
    items = (
        query.order_by(
            DetectionRule.mitre_technique_id,
            DetectionRule.title,
            # (technique, title) is not unique. Offset pagination needs a
            # total order, otherwise rows with equal keys can repeat or vanish
            # between pages. The primary key breaks ties deterministically.
            DetectionRule.id,
        )
        .offset(offset)
        .limit(limit)
        .all()
    )

    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "items": [
            {
                "id": str(r.id),
                "mitre_technique_id": r.mitre_technique_id,
                "title": r.title,
                "description": r.description,
                "source": r.source,
                "source_url": r.source_url,
                "rule_format": r.rule_format,
                "severity": r.severity,
                "platforms": r.platforms or [],
                "log_sources": r.log_sources,
                "is_active": r.is_active,
            }
            for r in items
        ],
    }
def get_rules_for_template(db: Session, template_id: str) -> dict[str, Any]:
    """Get the detection rules associated with a test template.

    Args:
        db: Open SQLAlchemy session.
        template_id: Primary key of the ``TestTemplate``.

    Returns:
        Dict with the template's ``template_id``, ``template_name`` and
        ``mitre_technique_id``, the serialized ``rules`` (each including its
        ``rule_content`` and the association's ``is_primary`` flag), and
        ``total`` (number of rules returned). Rules are ordered by title.

    Raises:
        EntityNotFoundError: If the template does not exist.
    """
    template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
    if not template:
        raise EntityNotFoundError("Test template", template_id)

    # Fetch each association together with its rule in one joined query.
    # Touching ``assoc.detection_rule`` per row would, with default lazy
    # loading, issue one extra SELECT per association (N+1). The inner join
    # also skips an association whose rule row is missing, instead of
    # failing on ``None.id``.
    rows = (
        db.query(TestTemplateDetectionRule, DetectionRule)
        .join(
            DetectionRule,
            DetectionRule.id == TestTemplateDetectionRule.detection_rule_id,
        )
        .filter(TestTemplateDetectionRule.test_template_id == template_id)
        .order_by(DetectionRule.title, DetectionRule.id)
        .all()
    )

    rules = [
        {
            "id": str(r.id),
            "mitre_technique_id": r.mitre_technique_id,
            "title": r.title,
            "description": r.description,
            "source": r.source,
            "source_url": r.source_url,
            "rule_content": r.rule_content,
            "rule_format": r.rule_format,
            "severity": r.severity,
            "platforms": r.platforms or [],
            "log_sources": r.log_sources,
            "is_primary": assoc.is_primary,
        }
        for assoc, r in rows
    ]

    return {
        "template_id": str(template.id),
        "template_name": template.name,
        "mitre_technique_id": template.mitre_technique_id,
        "rules": rules,
        "total": len(rules),
    }
def auto_associate_rules(db: Session) -> dict[str, Any]:
    """Auto-associate active test templates with active detection rules.

    Matching is by MITRE technique ID. For every active template, each active
    rule with the same ``mitre_technique_id`` is linked through a new
    ``TestTemplateDetectionRule`` row, unless that (template, rule) pair is
    already associated. A new association gets ``is_primary=True`` when the
    rule's severity is "high" or "critical" (case-insensitive). Templates or
    rules without a technique ID are never matched. The session is committed
    before returning.

    Args:
        db: Open SQLAlchemy session. It is committed by this function.

    Returns:
        Dict with ``created`` (new associations added), ``skipped`` (matching
        pairs that already existed) and ``total_associations`` (row count of
        the association table after the commit).
    """
    templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all()  # noqa: E712
    rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all()  # noqa: E712

    # Group rules by technique. A missing technique ID is not a match key.
    # Otherwise templates and rules lacking one would be linked to each other.
    rules_by_technique: dict[str, list[DetectionRule]] = {}
    for rule in rules:
        if rule.mitre_technique_id is None:
            continue
        rules_by_technique.setdefault(rule.mitre_technique_id, []).append(rule)

    # Load all existing (template, rule) pairs once. This replaces a separate
    # existence SELECT per candidate pair (O(templates x rules) round trips).
    # IDs are normalised to str so the comparison does not depend on the
    # driver's ID type.
    existing_pairs = {
        (str(template_id), str(rule_id))
        for template_id, rule_id in db.query(
            TestTemplateDetectionRule.test_template_id,
            TestTemplateDetectionRule.detection_rule_id,
        ).all()
    }

    high_severities = frozenset({"high", "critical"})
    created = 0
    skipped = 0
    for template in templates:
        for rule in rules_by_technique.get(template.mitre_technique_id, []):
            pair = (str(template.id), str(rule.id))
            if pair in existing_pairs:
                skipped += 1
                continue
            db.add(
                TestTemplateDetectionRule(
                    test_template_id=template.id,
                    detection_rule_id=rule.id,
                    is_primary=(rule.severity or "").lower() in high_severities,
                )
            )
            existing_pairs.add(pair)
            created += 1

    db.commit()
    total = db.query(TestTemplateDetectionRule).count()
    return {
        "created": created,
        "skipped": skipped,
        "total_associations": total,
    }
def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]:
    """Get the detection rules relevant to a test, with their evaluation results.

    Relevant rules are the active ``DetectionRule`` rows whose
    ``mitre_technique_id`` equals the test technique's ``mitre_id``. They are
    ordered by severity (critical, high, medium, low, informational, then any
    other value), then by title.

    Args:
        db: Open SQLAlchemy session.
        test_id: Primary key of the ``Test``.

    Returns:
        Dict with ``test_id``, ``mitre_technique_id``, ``rules`` (each
        serialized rule merged with its ``TestDetectionResult`` for this test,
        if any: ``triggered``, ``notes``, ``evaluated_at`` as ISO string,
        ``result_id``), ``total``, ``evaluated`` (rules whose ``triggered`` is
        not ``None``), ``triggered`` (rules with a truthy ``triggered``) and
        ``detection_rate`` (triggered / evaluated * 100, rounded to one
        decimal; ``0`` when nothing has been evaluated).

    Raises:
        EntityNotFoundError: If the test, or the technique it references,
            does not exist.
    """
    test = db.query(Test).filter(Test.id == test_id).first()
    if not test:
        raise EntityNotFoundError("Test", str(test_id))
    technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
    if not technique:
        raise EntityNotFoundError("Technique", str(test.technique_id))

    rules = (
        db.query(DetectionRule)
        .filter(
            DetectionRule.mitre_technique_id == technique.mitre_id,
            DetectionRule.is_active == True,  # noqa: E712
        )
        .order_by(DetectionRule.title)
        .all()
    )
    # ORDER BY severity DESC on string severities sorts alphabetically
    # (medium > low > informational > high > critical), putting critical
    # rules last. Severity appears to be free text (TODO confirm the column
    # type), so rank it explicitly here. list.sort is stable, so the DB's
    # title order is kept within each severity level.
    severity_rank = {
        "critical": 0,
        "high": 1,
        "medium": 2,
        "low": 3,
        "informational": 4,
        "info": 4,
    }
    unranked = 5  # unknown or missing severities sort after all known ones
    rules.sort(
        key=lambda rule: severity_rank.get((rule.severity or "").lower(), unranked)
    )

    existing_results = (
        db.query(TestDetectionResult)
        .filter(TestDetectionResult.test_id == test_id)
        .all()
    )
    # Key by str so lookups don't depend on the driver's ID type.
    results_map = {str(r.detection_rule_id): r for r in existing_results}

    items = []
    triggered_count = 0
    evaluated_count = 0
    for rule in rules:
        result = results_map.get(str(rule.id))
        triggered = result.triggered if result else None
        notes = result.notes if result else None
        evaluated_at = (
            result.evaluated_at.isoformat() if result and result.evaluated_at else None
        )
        # A stored result with triggered=None counts as "not yet evaluated".
        if triggered is not None:
            evaluated_count += 1
            if triggered:
                triggered_count += 1
        items.append({
            "id": str(rule.id),
            "mitre_technique_id": rule.mitre_technique_id,
            "title": rule.title,
            "description": rule.description,
            "source": rule.source,
            "source_url": rule.source_url,
            "rule_content": rule.rule_content,
            "rule_format": rule.rule_format,
            "severity": rule.severity,
            "platforms": rule.platforms or [],
            "log_sources": rule.log_sources,
            "triggered": triggered,
            "notes": notes,
            "evaluated_at": evaluated_at,
            "result_id": str(result.id) if result else None,
        })

    return {
        "test_id": str(test.id),
        "mitre_technique_id": technique.mitre_id,
        "rules": items,
        "total": len(items),
        "evaluated": evaluated_count,
        "triggered": triggered_count,
        "detection_rate": (
            round(triggered_count / evaluated_count * 100, 1) if evaluated_count > 0 else 0
        ),
    }
def evaluate_rule(
    db: Session,
    *,
    test_id: Any,
    detection_rule_id: Any,
    triggered: bool | None,
    notes: str | None,
    evaluator_id: Any,
) -> dict[str, Any]:
    """Record a detection rule's evaluation outcome for a test (upsert).

    If a ``TestDetectionResult`` already exists for the (test, rule) pair, its
    ``triggered``, ``notes``, ``evaluated_by`` and ``evaluated_at`` fields are
    overwritten. Otherwise a new row is created. The session is committed and
    the row refreshed before serializing.

    Args:
        db: Open SQLAlchemy session. It is committed by this function.
        test_id: Primary key of the ``Test``.
        detection_rule_id: Primary key of the ``DetectionRule``.
        triggered: Whether the rule fired; ``None`` leaves it unevaluated.
        notes: Free-form evaluator notes.
        evaluator_id: Stored as ``evaluated_by``.

    Returns:
        Dict with the result's ``id`` (str), ``triggered``, ``notes`` and
        ``evaluated_at`` (ISO string, or ``None``).

    Raises:
        EntityNotFoundError: If the test or the detection rule does not exist.
    """
    if not db.query(Test).filter(Test.id == test_id).first():
        raise EntityNotFoundError("Test", str(test_id))
    if not db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first():
        raise EntityNotFoundError("Detection rule", str(detection_rule_id))

    record = (
        db.query(TestDetectionResult)
        .filter(
            TestDetectionResult.test_id == test_id,
            TestDetectionResult.detection_rule_id == detection_rule_id,
        )
        .first()
    )

    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12 and
    # returns a naive timestamp. Moving to an aware UTC datetime needs a check
    # of the column type and of consumers of the isoformat() output.
    if record:
        record.triggered = triggered
        record.notes = notes
        record.evaluated_by = evaluator_id
        record.evaluated_at = datetime.utcnow()
    else:
        record = TestDetectionResult(
            test_id=test_id,
            detection_rule_id=detection_rule_id,
            triggered=triggered,
            notes=notes,
            evaluated_by=evaluator_id,
            evaluated_at=datetime.utcnow(),
        )
        db.add(record)

    db.commit()
    db.refresh(record)

    stamp = record.evaluated_at
    return {
        "id": str(record.id),
        "triggered": record.triggered,
        "notes": record.notes,
        "evaluated_at": stamp.isoformat() if stamp else None,
    }