refactor(detection-rules): extract query/business logic to detection_rule_service, router is thin HTTP adapter

This commit is contained in:
2026-02-19 17:39:31 +01:00
parent d305db8794
commit 560fc0c9f0
7 changed files with 5853 additions and 282 deletions

View File

@@ -1,31 +1,32 @@
"""Detection rules endpoints — listing, filtering, and template association.
Thin HTTP adapter: delegates all query and business logic to detection_rule_service.
Provides endpoints for browsing detection rules, querying rules by technique,
and managing the template ↔ detection rule associations.
"""
import logging
import uuid
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role
from app.models.user import User
from app.models.detection_rule import DetectionRule
from app.models.test_template import TestTemplate
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.services.detection_rule_service import (
list_rules,
get_rules_for_template,
auto_associate_rules,
get_rules_for_test,
evaluate_rule,
)
# ---------------------------------------------------------------------------
# Pydantic schemas for request validation
# ---------------------------------------------------------------------------
# ── Pydantic schemas for request validation ────────────────────────────
class DetectionRuleEvaluate(BaseModel):
"""Payload for evaluating a detection rule against a test."""
@@ -34,14 +35,12 @@ class DetectionRuleEvaluate(BaseModel):
triggered: Optional[bool] = None
notes: Optional[str] = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
# ---------------------------------------------------------------------------
# GET /detection-rules — List with filters
# ---------------------------------------------------------------------------
# ── GET /detection-rules — List with filters ───────────────────────────
@router.get("")
def list_detection_rules(
@@ -55,54 +54,19 @@ def list_detection_rules(
current_user: User = Depends(get_current_user),
):
"""List detection rules with optional filters and pagination."""
query = db.query(DetectionRule).filter(DetectionRule.is_active == True) # noqa: E712
if technique:
query = query.filter(DetectionRule.mitre_technique_id == technique)
if source:
query = query.filter(DetectionRule.source == source)
if severity:
query = query.filter(DetectionRule.severity == severity)
if search:
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
DetectionRule.title.ilike(pattern)
| DetectionRule.description.ilike(pattern)
)
total = query.count()
items = query.order_by(DetectionRule.mitre_technique_id, DetectionRule.title).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(r.id),
"mitre_technique_id": r.mitre_technique_id,
"title": r.title,
"description": r.description,
"source": r.source,
"source_url": r.source_url,
"rule_format": r.rule_format,
"severity": r.severity,
"platforms": r.platforms or [],
"log_sources": r.log_sources,
"is_active": r.is_active,
}
for r in items
],
}
return list_rules(
db,
technique=technique,
source=source,
severity=severity,
search=search,
offset=offset,
limit=limit,
)
# ---------------------------------------------------------------------------
# GET /test-templates/{id}/detection-rules — Rules for a template
# ---------------------------------------------------------------------------
# ── GET /detection-rules/for-template/{template_id} ────────────────────
@router.get("/for-template/{template_id}")
def get_detection_rules_for_template(
@@ -111,46 +75,11 @@ def get_detection_rules_for_template(
current_user: User = Depends(get_current_user),
):
"""Get detection rules associated with a test template."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Test template not found")
associations = (
db.query(TestTemplateDetectionRule)
.filter(TestTemplateDetectionRule.test_template_id == template_id)
.all()
)
rules = []
for assoc in associations:
r = assoc.detection_rule
rules.append({
"id": str(r.id),
"mitre_technique_id": r.mitre_technique_id,
"title": r.title,
"description": r.description,
"source": r.source,
"source_url": r.source_url,
"rule_content": r.rule_content,
"rule_format": r.rule_format,
"severity": r.severity,
"platforms": r.platforms or [],
"log_sources": r.log_sources,
"is_primary": assoc.is_primary,
})
return {
"template_id": str(template.id),
"template_name": template.name,
"mitre_technique_id": template.mitre_technique_id,
"rules": rules,
"total": len(rules),
}
return get_rules_for_template(db, template_id)
# ---------------------------------------------------------------------------
# POST /detection-rules/auto-associate — Auto-link templates ↔ rules
# ---------------------------------------------------------------------------
# ── POST /detection-rules/auto-associate ────────────────────────────────
@router.post("/auto-associate")
def auto_associate_detection_rules(
@@ -163,60 +92,11 @@ def auto_associate_detection_rules(
technique and create associations. Rules with severity >= high are marked
as primary.
"""
templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all() # noqa: E712
rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all() # noqa: E712
# Index rules by technique
rules_by_technique: dict[str, list] = {}
for rule in rules:
tid = rule.mitre_technique_id
if tid not in rules_by_technique:
rules_by_technique[tid] = []
rules_by_technique[tid].append(rule)
created = 0
skipped = 0
high_severities = {"high", "critical"}
for template in templates:
matching_rules = rules_by_technique.get(template.mitre_technique_id, [])
for rule in matching_rules:
# Check if association already exists
existing = (
db.query(TestTemplateDetectionRule)
.filter(
TestTemplateDetectionRule.test_template_id == template.id,
TestTemplateDetectionRule.detection_rule_id == rule.id,
)
.first()
)
if existing:
skipped += 1
continue
is_primary = (rule.severity or "").lower() in high_severities
assoc = TestTemplateDetectionRule(
test_template_id=template.id,
detection_rule_id=rule.id,
is_primary=is_primary,
)
db.add(assoc)
created += 1
db.commit()
total = db.query(TestTemplateDetectionRule).count()
return {
"created": created,
"skipped": skipped,
"total_associations": total,
}
return auto_associate_rules(db)
# ---------------------------------------------------------------------------
# GET /detection-rules/for-test/{test_id} — Rules + results for a test
# ---------------------------------------------------------------------------
# ── GET /detection-rules/for-test/{test_id} ──────────────────────────────
@router.get("/for-test/{test_id}")
def get_detection_rules_for_test(
@@ -229,83 +109,11 @@ def get_detection_rules_for_test(
Finds rules by matching the test's technique_id to detection rules,
and returns any existing evaluation results.
"""
from app.models.test import Test
from app.models.technique import Technique
test = db.query(Test).filter(Test.id == test_id).first()
if not test:
raise HTTPException(status_code=404, detail="Test not found")
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
if not technique:
raise HTTPException(status_code=404, detail="Technique not found")
# Get detection rules for this technique
rules = (
db.query(DetectionRule)
.filter(
DetectionRule.mitre_technique_id == technique.mitre_id,
DetectionRule.is_active == True, # noqa: E712
)
.order_by(DetectionRule.severity.desc(), DetectionRule.title)
.all()
)
# Get existing results for this test
existing_results = (
db.query(TestDetectionResult)
.filter(TestDetectionResult.test_id == test_id)
.all()
)
results_map = {str(r.detection_rule_id): r for r in existing_results}
items = []
triggered_count = 0
evaluated_count = 0
for rule in rules:
result = results_map.get(str(rule.id))
triggered = result.triggered if result else None
notes = result.notes if result else None
evaluated_at = result.evaluated_at.isoformat() if result and result.evaluated_at else None
if triggered is not None:
evaluated_count += 1
if triggered:
triggered_count += 1
items.append({
"id": str(rule.id),
"mitre_technique_id": rule.mitre_technique_id,
"title": rule.title,
"description": rule.description,
"source": rule.source,
"source_url": rule.source_url,
"rule_content": rule.rule_content,
"rule_format": rule.rule_format,
"severity": rule.severity,
"platforms": rule.platforms or [],
"log_sources": rule.log_sources,
"triggered": triggered,
"notes": notes,
"evaluated_at": evaluated_at,
"result_id": str(result.id) if result else None,
})
return {
"test_id": str(test.id),
"mitre_technique_id": technique.mitre_id,
"rules": items,
"total": len(items),
"evaluated": evaluated_count,
"triggered": triggered_count,
"detection_rate": round(triggered_count / evaluated_count * 100, 1) if evaluated_count > 0 else 0,
}
return get_rules_for_test(db, test_id)
# ---------------------------------------------------------------------------
# POST /detection-rules/evaluate — Save detection result for a rule
# ---------------------------------------------------------------------------
# ── POST /detection-rules/evaluate ──────────────────────────────────────
@router.post("/evaluate")
def evaluate_detection_rule(
@@ -314,60 +122,11 @@ def evaluate_detection_rule(
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Save or update the evaluation result for a detection rule on a test."""
test_id = payload.test_id
detection_rule_id = payload.detection_rule_id
triggered = payload.triggered
notes = payload.notes
# Check test exists
from app.models.test import Test
test = db.query(Test).filter(Test.id == test_id).first()
if not test:
raise HTTPException(status_code=404, detail="Test not found")
# Check rule exists
rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first()
if not rule:
raise HTTPException(status_code=404, detail="Detection rule not found")
# Upsert result
existing = (
db.query(TestDetectionResult)
.filter(
TestDetectionResult.test_id == test_id,
TestDetectionResult.detection_rule_id == detection_rule_id,
)
.first()
return evaluate_rule(
db,
test_id=payload.test_id,
detection_rule_id=payload.detection_rule_id,
triggered=payload.triggered,
notes=payload.notes,
evaluator_id=current_user.id,
)
if existing:
existing.triggered = triggered
existing.notes = notes
existing.evaluated_by = current_user.id
existing.evaluated_at = datetime.utcnow()
db.commit()
db.refresh(existing)
return {
"id": str(existing.id),
"triggered": existing.triggered,
"notes": existing.notes,
"evaluated_at": existing.evaluated_at.isoformat() if existing.evaluated_at else None,
}
else:
result = TestDetectionResult(
test_id=test_id,
detection_rule_id=detection_rule_id,
triggered=triggered,
notes=notes,
evaluated_by=current_user.id,
evaluated_at=datetime.utcnow(),
)
db.add(result)
db.commit()
db.refresh(result)
return {
"id": str(result.id),
"triggered": result.triggered,
"notes": result.notes,
"evaluated_at": result.evaluated_at.isoformat() if result.evaluated_at else None,
}

View File

@@ -0,0 +1,319 @@
"""Detection rule data service.
Extracts query and business logic from the detection_rules router so
that the router remains a thin HTTP adapter.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.detection_rule import DetectionRule
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.technique import Technique
from app.utils import escape_like
# ── Public service functions ──────────────────────────────────────────
def list_rules(
db: Session,
*,
technique: str | None = None,
source: str | None = None,
severity: str | None = None,
search: str | None = None,
offset: int = 0,
limit: int = 50,
) -> dict[str, Any]:
"""List detection rules with optional filters and pagination."""
query = db.query(DetectionRule).filter(DetectionRule.is_active == True)
if technique:
query = query.filter(DetectionRule.mitre_technique_id == technique)
if source:
query = query.filter(DetectionRule.source == source)
if severity:
query = query.filter(DetectionRule.severity == severity)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(
DetectionRule.title.ilike(pattern)
| DetectionRule.description.ilike(pattern)
)
total = query.count()
items = (
query.order_by(DetectionRule.mitre_technique_id, DetectionRule.title)
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(r.id),
"mitre_technique_id": r.mitre_technique_id,
"title": r.title,
"description": r.description,
"source": r.source,
"source_url": r.source_url,
"rule_format": r.rule_format,
"severity": r.severity,
"platforms": r.platforms or [],
"log_sources": r.log_sources,
"is_active": r.is_active,
}
for r in items
],
}
def get_rules_for_template(db: Session, template_id: str) -> dict[str, Any]:
"""Get detection rules associated with a test template.
Raises EntityNotFoundError if the template does not exist.
"""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if not template:
raise EntityNotFoundError("Test template", template_id)
associations = (
db.query(TestTemplateDetectionRule)
.filter(TestTemplateDetectionRule.test_template_id == template_id)
.all()
)
rules = []
for assoc in associations:
r = assoc.detection_rule
rules.append({
"id": str(r.id),
"mitre_technique_id": r.mitre_technique_id,
"title": r.title,
"description": r.description,
"source": r.source,
"source_url": r.source_url,
"rule_content": r.rule_content,
"rule_format": r.rule_format,
"severity": r.severity,
"platforms": r.platforms or [],
"log_sources": r.log_sources,
"is_primary": assoc.is_primary,
})
return {
"template_id": str(template.id),
"template_name": template.name,
"mitre_technique_id": template.mitre_technique_id,
"rules": rules,
"total": len(rules),
}
def auto_associate_rules(db: Session) -> dict[str, Any]:
"""Auto-associate test templates with detection rules by MITRE technique ID.
For each active template, finds all active detection rules for the same
technique and creates associations. Rules with severity high/critical
are marked as primary. Performs commit internally.
"""
templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all()
rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all()
rules_by_technique: dict[str, list] = {}
for rule in rules:
tid = rule.mitre_technique_id
if tid not in rules_by_technique:
rules_by_technique[tid] = []
rules_by_technique[tid].append(rule)
created = 0
skipped = 0
high_severities = {"high", "critical"}
for template in templates:
matching_rules = rules_by_technique.get(template.mitre_technique_id, [])
for rule in matching_rules:
existing = (
db.query(TestTemplateDetectionRule)
.filter(
TestTemplateDetectionRule.test_template_id == template.id,
TestTemplateDetectionRule.detection_rule_id == rule.id,
)
.first()
)
if existing:
skipped += 1
continue
is_primary = (rule.severity or "").lower() in high_severities
assoc = TestTemplateDetectionRule(
test_template_id=template.id,
detection_rule_id=rule.id,
is_primary=is_primary,
)
db.add(assoc)
created += 1
db.commit()
total = db.query(TestTemplateDetectionRule).count()
return {
"created": created,
"skipped": skipped,
"total_associations": total,
}
def get_rules_for_test(db: Session, test_id: str) -> dict[str, Any]:
"""Get detection rules relevant to a test, along with their evaluation results.
Finds rules by matching the test's technique to detection rules.
Raises EntityNotFoundError if the test or its technique does not exist.
"""
test = db.query(Test).filter(Test.id == test_id).first()
if not test:
raise EntityNotFoundError("Test", str(test_id))
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
if not technique:
raise EntityNotFoundError("Technique", str(test.technique_id))
rules = (
db.query(DetectionRule)
.filter(
DetectionRule.mitre_technique_id == technique.mitre_id,
DetectionRule.is_active == True,
)
.order_by(DetectionRule.severity.desc(), DetectionRule.title)
.all()
)
existing_results = (
db.query(TestDetectionResult)
.filter(TestDetectionResult.test_id == test_id)
.all()
)
results_map = {str(r.detection_rule_id): r for r in existing_results}
items = []
triggered_count = 0
evaluated_count = 0
for rule in rules:
result = results_map.get(str(rule.id))
triggered = result.triggered if result else None
notes = result.notes if result else None
evaluated_at = result.evaluated_at.isoformat() if result and result.evaluated_at else None
if triggered is not None:
evaluated_count += 1
if triggered:
triggered_count += 1
items.append({
"id": str(rule.id),
"mitre_technique_id": rule.mitre_technique_id,
"title": rule.title,
"description": rule.description,
"source": rule.source,
"source_url": rule.source_url,
"rule_content": rule.rule_content,
"rule_format": rule.rule_format,
"severity": rule.severity,
"platforms": rule.platforms or [],
"log_sources": rule.log_sources,
"triggered": triggered,
"notes": notes,
"evaluated_at": evaluated_at,
"result_id": str(result.id) if result else None,
})
return {
"test_id": str(test.id),
"mitre_technique_id": technique.mitre_id,
"rules": items,
"total": len(items),
"evaluated": evaluated_count,
"triggered": triggered_count,
"detection_rate": round(triggered_count / evaluated_count * 100, 1) if evaluated_count > 0 else 0,
}
def evaluate_rule(
db: Session,
*,
test_id: Any,
detection_rule_id: Any,
triggered: bool | None,
notes: str | None,
evaluator_id: Any,
) -> dict[str, Any]:
"""Save or update the evaluation result for a detection rule on a test.
Raises EntityNotFoundError if the test or detection rule does not exist.
"""
test = db.query(Test).filter(Test.id == test_id).first()
if not test:
raise EntityNotFoundError("Test", str(test_id))
rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first()
if not rule:
raise EntityNotFoundError("Detection rule", str(detection_rule_id))
existing = (
db.query(TestDetectionResult)
.filter(
TestDetectionResult.test_id == test_id,
TestDetectionResult.detection_rule_id == detection_rule_id,
)
.first()
)
if existing:
existing.triggered = triggered
existing.notes = notes
existing.evaluated_by = evaluator_id
existing.evaluated_at = datetime.utcnow()
db.commit()
db.refresh(existing)
return {
"id": str(existing.id),
"triggered": existing.triggered,
"notes": existing.notes,
"evaluated_at": existing.evaluated_at.isoformat() if existing.evaluated_at else None,
}
else:
result = TestDetectionResult(
test_id=test_id,
detection_rule_id=detection_rule_id,
triggered=triggered,
notes=notes,
evaluated_by=evaluator_id,
evaluated_at=datetime.utcnow(),
)
db.add(result)
db.commit()
db.refresh(result)
return {
"id": str(result.id),
"triggered": result.triggered,
"notes": result.notes,
"evaluated_at": result.evaluated_at.isoformat() if result.evaluated_at else None,
}