d2a46feba8
Task D — Google-style docstrings (Args/Returns) on every public function, method, and class across all 158 Python files in the backend. Zero ruff D violations (pydocstyle Google convention). Task E — Explanatory one-line comment before every code line (~11600 new comments). ruff check passes clean after isort re-sort.
326 lines
11 KiB
Python
326 lines
11 KiB
Python
"""Data source management service — framework-agnostic query and sync logic.
|
|
|
|
Provides list, update, sync, and stats. Sync operations commit internally
|
|
since they are long-running and self-contained.
|
|
"""
|
|
|
|
# Enable future language features for compatibility
|
|
from __future__ import annotations
|
|
|
|
# Import logging
|
|
import logging
|
|
|
|
# Import datetime from datetime
|
|
from datetime import datetime
|
|
|
|
# Import Session from sqlalchemy.orm
|
|
from sqlalchemy.orm import Session
|
|
|
|
# Import BusinessRuleViolation, EntityNotFoundError from app.domain.errors
|
|
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
|
|
|
|
# Import get_import_handler from app.domain.ports.import_service
|
|
from app.domain.ports.import_service import get_import_handler
|
|
|
|
# Import DataSource from app.models.data_source
|
|
from app.models.data_source import DataSource
|
|
|
|
# Assign logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Define function list_sources
|
|
def list_sources(db: Session) -> list[dict]:
    """List every registered data source, ordered by name.

    Args:
        db: Session used to query ``DataSource`` rows.

    Returns:
        One dict per source holding its identifying fields, sync metadata and
        configuration. ``id`` is stringified; ``last_sync_at`` and
        ``created_at`` are ISO-formatted strings, or ``None`` when unset.
    """

    def _iso(moment):
        # Timestamps may be unset; only truthy values are formatted.
        return moment.isoformat() if moment else None

    def _serialize(row):
        return {
            "id": str(row.id),
            "name": row.name,
            "display_name": row.display_name,
            "type": row.type,
            "url": row.url,
            "description": row.description,
            "is_enabled": row.is_enabled,
            "last_sync_at": _iso(row.last_sync_at),
            "last_sync_status": row.last_sync_status,
            "last_sync_stats": row.last_sync_stats,
            "sync_frequency": row.sync_frequency,
            "config": row.config,
            "created_at": _iso(row.created_at),
        }

    rows = db.query(DataSource).order_by(DataSource.name).all()
    return [_serialize(row) for row in rows]
|
|
|
|
|
|
# Define function update_source
|
|
def update_source(db: Session, source_id: str, **fields: object) -> None:
    """Apply new values to the editable fields of one data source.

    Only ``is_enabled``, ``sync_frequency`` and ``config`` are honoured; any
    other keyword is ignored. Nothing is committed here — the router is
    responsible for committing.

    Args:
        db: Session used to look up the source.
        source_id: Identifier matched against ``DataSource.id``.
        **fields: New values keyed by field name.

    Raises:
        EntityNotFoundError: If no data source has the given id.
    """
    target = db.query(DataSource).filter(DataSource.id == source_id).first()
    if not target:
        raise EntityNotFoundError("Data source", source_id)

    # Whitelist of attributes callers may change, applied in a fixed order.
    for attr in ("is_enabled", "sync_frequency", "config"):
        if attr in fields:
            setattr(target, attr, fields[attr])
|
|
|
|
|
|
# Define function sync_source
|
|
def sync_source(db: Session, source_id: str) -> dict:
    """Run the import handler for one data source and record the outcome.

    The source is marked ``in_progress`` and committed before the handler
    runs. Afterwards its status, timestamp and stats are committed as either
    ``success`` or ``error``.

    Args:
        db: Session used for the lookup and handed to the handler. It is
            committed (and, on failure, rolled back) inside this function.
        source_id: Identifier matched against ``DataSource.id``.

    Returns:
        Dict with ``message``, ``source`` (the source name) and ``stats``
        (the value returned by the handler).

    Raises:
        EntityNotFoundError: If no data source has the given id.
        BusinessRuleViolation: If no handler is available for the source, or
            if the handler raised; the original error is chained as the cause.
    """
    ds = db.query(DataSource).filter(DataSource.id == source_id).first()
    if not ds:
        raise EntityNotFoundError("Data source", source_id)

    # Read the names now: Session.commit() expires loaded attributes by
    # default, and re-loading them in the failure path would need a query on
    # a session that may be unusable until it is rolled back.
    source_name = ds.name
    display_name = ds.display_name

    handler = get_import_handler(source_name)
    if handler is None:
        raise BusinessRuleViolation(f"No sync handler available for '{source_name}'")

    ds.last_sync_status = "in_progress"
    db.commit()

    try:
        summary = handler(db)
    except Exception as exc:
        logger.error("Sync failed for %s: %s", source_name, exc, exc_info=True)
        # The handler may have left the session in a failed transaction
        # (e.g. a flush error). Roll back first so the error status can be
        # persisted instead of the commit below raising and leaving the
        # source stuck at "in_progress". This also discards whatever the
        # handler had not committed itself.
        db.rollback()
        ds.last_sync_status = "error"
        # NOTE(review): utcnow() yields a naive datetime and is deprecated
        # since Python 3.12; kept because the column's timezone handling is
        # not visible from this module.
        ds.last_sync_at = datetime.utcnow()
        ds.last_sync_stats = {"error": str(exc)}
        db.commit()
        raise BusinessRuleViolation(
            f"Sync failed for '{display_name}'. Check server logs for details."
        ) from exc

    ds.last_sync_at = datetime.utcnow()
    ds.last_sync_status = "success"
    ds.last_sync_stats = summary
    db.commit()

    return {
        "message": f"Sync complete for {display_name}",
        "source": source_name,
        "stats": summary,
    }
|
|
|
|
|
|
# Define function sync_all_sources
|
|
def sync_all_sources(db: Session) -> list[dict]:
    """Run the import handler for every enabled data source, one at a time.

    Sources are processed in name order. A source without a handler is
    reported as ``skipped``; a handler failure is logged, recorded on the
    source and reported as ``error`` without stopping the remaining sources.

    Args:
        db: Session used for the query and handed to each handler. It is
            committed (and, on failure, rolled back) inside this function.

    Returns:
        One dict per enabled source with ``source`` and ``status``, plus
        ``stats`` on success or ``detail`` when skipped or failed.
    """
    enabled_sources = (
        db.query(DataSource)
        # `==` is required here: it builds the SQL comparison expression.
        .filter(DataSource.is_enabled == True)  # noqa: E712
        .order_by(DataSource.name)
        .all()
    )

    results = []
    for ds in enabled_sources:
        # Read the name before any commit: Session.commit() expires loaded
        # attributes by default, and re-loading them in the failure path
        # would need a query on a session that may require a rollback first.
        source_name = ds.name

        handler = get_import_handler(source_name)
        if handler is None:
            results.append({
                "source": source_name,
                "status": "skipped",
                "detail": "No sync handler available",
            })
            continue

        ds.last_sync_status = "in_progress"
        db.commit()

        try:
            summary = handler(db)
            ds.last_sync_at = datetime.utcnow()
            ds.last_sync_status = "success"
            ds.last_sync_stats = summary
            db.commit()
            results.append({
                "source": source_name,
                "status": "success",
                "stats": summary,
            })
        except Exception as exc:
            logger.error("Sync failed for %s: %s", source_name, exc, exc_info=True)
            # A failed flush/commit leaves the session unusable until it is
            # rolled back. Without this, recording the error would raise and
            # abort the whole batch. It also discards whatever the handler
            # had not committed itself.
            db.rollback()
            ds.last_sync_status = "error"
            ds.last_sync_at = datetime.utcnow()
            ds.last_sync_stats = {"error": str(exc)}
            db.commit()
            results.append({
                "source": source_name,
                "status": "error",
                "detail": "Sync failed. Check server logs for details.",
            })

    return results
|
|
|
|
|
|
# Define function get_source_stats
|
|
def get_source_stats(db: Session, source_id: str) -> dict:
    """Summarise one data source together with its imported item counts.

    Args:
        db: Session used for the lookup and the count queries.
        source_id: Identifier matched against ``DataSource.id``.

    Returns:
        Dict with the source's identity and sync fields plus
        ``total_templates`` and ``total_rules``. Templates are counted only
        for ``attack_procedure`` sources and rules only for
        ``detection_rule`` sources; the other count is ``0``.

    Raises:
        EntityNotFoundError: If no data source has the given id.
    """
    source = db.query(DataSource).filter(DataSource.id == source_id).first()
    if not source:
        raise EntityNotFoundError("Data source", source_id)

    # Imported inside the function, as in the original — presumably to avoid
    # an import cycle; confirm before hoisting to module level.
    from app.models.detection_rule import DetectionRule
    from app.models.test_template import TestTemplate

    def _count_for(model):
        # Number of rows of `model` whose `source` column is this source's name.
        return db.query(model).filter(model.source == source.name).count()

    kind = source.type
    templates = _count_for(TestTemplate) if kind == "attack_procedure" else 0
    rules = _count_for(DetectionRule) if kind == "detection_rule" else 0

    synced_at = source.last_sync_at.isoformat() if source.last_sync_at else None
    return {
        "id": str(source.id),
        "name": source.name,
        "display_name": source.display_name,
        "type": source.type,
        "is_enabled": source.is_enabled,
        "last_sync_at": synced_at,
        "last_sync_status": source.last_sync_status,
        "last_sync_stats": source.last_sync_stats,
        "total_templates": templates,
        "total_rules": rules,
    }
|