refactor(detection-rules): extract query/business logic to detection_rule_service, router is thin HTTP adapter

This commit is contained in:
2026-02-19 17:39:31 +01:00
parent d305db8794
commit 560fc0c9f0
7 changed files with 5853 additions and 282 deletions

View File

@@ -1,31 +1,32 @@
"""Detection rules endpoints — listing, filtering, and template association.
Thin HTTP adapter: delegates all query and business logic to detection_rule_service.
Provides endpoints for browsing detection rules, querying rules by technique,
and managing the template ↔ detection rule associations.
"""
import logging
import uuid
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role
from app.models.user import User
from app.models.detection_rule import DetectionRule
from app.models.test_template import TestTemplate
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.services.detection_rule_service import (
list_rules,
get_rules_for_template,
auto_associate_rules,
get_rules_for_test,
evaluate_rule,
)
# ---------------------------------------------------------------------------
# Pydantic schemas for request validation
# ---------------------------------------------------------------------------
# ── Pydantic schemas for request validation ────────────────────────────
class DetectionRuleEvaluate(BaseModel):
"""Payload for evaluating a detection rule against a test."""
@@ -34,14 +35,12 @@ class DetectionRuleEvaluate(BaseModel):
triggered: Optional[bool] = None
notes: Optional[str] = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
# ---------------------------------------------------------------------------
# GET /detection-rules — List with filters
# ---------------------------------------------------------------------------
# ── GET /detection-rules — List with filters ───────────────────────────
@router.get("")
def list_detection_rules(
@@ -55,54 +54,19 @@ def list_detection_rules(
current_user: User = Depends(get_current_user),
):
"""List detection rules with optional filters and pagination."""
query = db.query(DetectionRule).filter(DetectionRule.is_active == True) # noqa: E712
if technique:
query = query.filter(DetectionRule.mitre_technique_id == technique)
if source:
query = query.filter(DetectionRule.source == source)
if severity:
query = query.filter(DetectionRule.severity == severity)
if search:
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
DetectionRule.title.ilike(pattern)
| DetectionRule.description.ilike(pattern)
)
total = query.count()
items = query.order_by(DetectionRule.mitre_technique_id, DetectionRule.title).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(r.id),
"mitre_technique_id": r.mitre_technique_id,
"title": r.title,
"description": r.description,
"source": r.source,
"source_url": r.source_url,
"rule_format": r.rule_format,
"severity": r.severity,
"platforms": r.platforms or [],
"log_sources": r.log_sources,
"is_active": r.is_active,
}
for r in items
],
}
return list_rules(
db,
technique=technique,
source=source,
severity=severity,
search=search,
offset=offset,
limit=limit,
)
# ---------------------------------------------------------------------------
# GET /test-templates/{id}/detection-rules — Rules for a template
# ---------------------------------------------------------------------------
# ── GET /detection-rules/for-template/{template_id} ────────────────────
@router.get("/for-template/{template_id}")
def get_detection_rules_for_template(
@@ -111,46 +75,11 @@ def get_detection_rules_for_template(
current_user: User = Depends(get_current_user),
):
"""Get detection rules associated with a test template."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Test template not found")
associations = (
db.query(TestTemplateDetectionRule)
.filter(TestTemplateDetectionRule.test_template_id == template_id)
.all()
)
rules = []
for assoc in associations:
r = assoc.detection_rule
rules.append({
"id": str(r.id),
"mitre_technique_id": r.mitre_technique_id,
"title": r.title,
"description": r.description,
"source": r.source,
"source_url": r.source_url,
"rule_content": r.rule_content,
"rule_format": r.rule_format,
"severity": r.severity,
"platforms": r.platforms or [],
"log_sources": r.log_sources,
"is_primary": assoc.is_primary,
})
return {
"template_id": str(template.id),
"template_name": template.name,
"mitre_technique_id": template.mitre_technique_id,
"rules": rules,
"total": len(rules),
}
return get_rules_for_template(db, template_id)
# ---------------------------------------------------------------------------
# POST /detection-rules/auto-associate — Auto-link templates ↔ rules
# ---------------------------------------------------------------------------
# ── POST /detection-rules/auto-associate ────────────────────────────────
@router.post("/auto-associate")
def auto_associate_detection_rules(
@@ -163,60 +92,11 @@ def auto_associate_detection_rules(
technique and create associations. Rules with severity >= high are marked
as primary.
"""
templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all() # noqa: E712
rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all() # noqa: E712
# Index rules by technique
rules_by_technique: dict[str, list] = {}
for rule in rules:
tid = rule.mitre_technique_id
if tid not in rules_by_technique:
rules_by_technique[tid] = []
rules_by_technique[tid].append(rule)
created = 0
skipped = 0
high_severities = {"high", "critical"}
for template in templates:
matching_rules = rules_by_technique.get(template.mitre_technique_id, [])
for rule in matching_rules:
# Check if association already exists
existing = (
db.query(TestTemplateDetectionRule)
.filter(
TestTemplateDetectionRule.test_template_id == template.id,
TestTemplateDetectionRule.detection_rule_id == rule.id,
)
.first()
)
if existing:
skipped += 1
continue
is_primary = (rule.severity or "").lower() in high_severities
assoc = TestTemplateDetectionRule(
test_template_id=template.id,
detection_rule_id=rule.id,
is_primary=is_primary,
)
db.add(assoc)
created += 1
db.commit()
total = db.query(TestTemplateDetectionRule).count()
return {
"created": created,
"skipped": skipped,
"total_associations": total,
}
return auto_associate_rules(db)
# ---------------------------------------------------------------------------
# GET /detection-rules/for-test/{test_id} — Rules + results for a test
# ---------------------------------------------------------------------------
# ── GET /detection-rules/for-test/{test_id} ──────────────────────────────
@router.get("/for-test/{test_id}")
def get_detection_rules_for_test(
@@ -229,83 +109,11 @@ def get_detection_rules_for_test(
Finds rules by matching the test's technique_id to detection rules,
and returns any existing evaluation results.
"""
from app.models.test import Test
from app.models.technique import Technique
test = db.query(Test).filter(Test.id == test_id).first()
if not test:
raise HTTPException(status_code=404, detail="Test not found")
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
if not technique:
raise HTTPException(status_code=404, detail="Technique not found")
# Get detection rules for this technique
rules = (
db.query(DetectionRule)
.filter(
DetectionRule.mitre_technique_id == technique.mitre_id,
DetectionRule.is_active == True, # noqa: E712
)
.order_by(DetectionRule.severity.desc(), DetectionRule.title)
.all()
)
# Get existing results for this test
existing_results = (
db.query(TestDetectionResult)
.filter(TestDetectionResult.test_id == test_id)
.all()
)
results_map = {str(r.detection_rule_id): r for r in existing_results}
items = []
triggered_count = 0
evaluated_count = 0
for rule in rules:
result = results_map.get(str(rule.id))
triggered = result.triggered if result else None
notes = result.notes if result else None
evaluated_at = result.evaluated_at.isoformat() if result and result.evaluated_at else None
if triggered is not None:
evaluated_count += 1
if triggered:
triggered_count += 1
items.append({
"id": str(rule.id),
"mitre_technique_id": rule.mitre_technique_id,
"title": rule.title,
"description": rule.description,
"source": rule.source,
"source_url": rule.source_url,
"rule_content": rule.rule_content,
"rule_format": rule.rule_format,
"severity": rule.severity,
"platforms": rule.platforms or [],
"log_sources": rule.log_sources,
"triggered": triggered,
"notes": notes,
"evaluated_at": evaluated_at,
"result_id": str(result.id) if result else None,
})
return {
"test_id": str(test.id),
"mitre_technique_id": technique.mitre_id,
"rules": items,
"total": len(items),
"evaluated": evaluated_count,
"triggered": triggered_count,
"detection_rate": round(triggered_count / evaluated_count * 100, 1) if evaluated_count > 0 else 0,
}
return get_rules_for_test(db, test_id)
# ---------------------------------------------------------------------------
# POST /detection-rules/evaluate — Save detection result for a rule
# ---------------------------------------------------------------------------
# ── POST /detection-rules/evaluate ──────────────────────────────────────
@router.post("/evaluate")
def evaluate_detection_rule(
@@ -314,60 +122,11 @@ def evaluate_detection_rule(
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Save or update the evaluation result for a detection rule on a test."""
test_id = payload.test_id
detection_rule_id = payload.detection_rule_id
triggered = payload.triggered
notes = payload.notes
# Check test exists
from app.models.test import Test
test = db.query(Test).filter(Test.id == test_id).first()
if not test:
raise HTTPException(status_code=404, detail="Test not found")
# Check rule exists
rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first()
if not rule:
raise HTTPException(status_code=404, detail="Detection rule not found")
# Upsert result
existing = (
db.query(TestDetectionResult)
.filter(
TestDetectionResult.test_id == test_id,
TestDetectionResult.detection_rule_id == detection_rule_id,
)
.first()
return evaluate_rule(
db,
test_id=payload.test_id,
detection_rule_id=payload.detection_rule_id,
triggered=payload.triggered,
notes=payload.notes,
evaluator_id=current_user.id,
)
if existing:
existing.triggered = triggered
existing.notes = notes
existing.evaluated_by = current_user.id
existing.evaluated_at = datetime.utcnow()
db.commit()
db.refresh(existing)
return {
"id": str(existing.id),
"triggered": existing.triggered,
"notes": existing.notes,
"evaluated_at": existing.evaluated_at.isoformat() if existing.evaluated_at else None,
}
else:
result = TestDetectionResult(
test_id=test_id,
detection_rule_id=detection_rule_id,
triggered=triggered,
notes=notes,
evaluated_by=current_user.id,
evaluated_at=datetime.utcnow(),
)
db.add(result)
db.commit()
db.refresh(result)
return {
"id": str(result.id),
"triggered": result.triggered,
"notes": result.notes,
"evaluated_at": result.evaluated_at.isoformat() if result.evaluated_at else None,
}