106 lines
3.4 KiB
Python
106 lines
3.4 KiB
Python
"""Scoring configuration persistence service.

Reads scoring weights from, and writes them to, the ``scoring_config``
table. Falls back to environment-variable defaults (from ``Settings``)
when no row has been persisted yet.

This module is framework-agnostic: no FastAPI imports.
"""
|
|
|
|
from __future__ import annotations

import math
from typing import Any

from sqlalchemy.orm import Session

from app.config import settings
from app.domain.value_objects.scoring_weights import ScoringWeights
from app.models.scoring_config import ScoringConfig
|
|
|
|
|
|
def get_scoring_weights(db: Session) -> ScoringWeights:
    """Load the scoring weights that are currently in effect.

    The row returned by ``db.query(ScoringConfig).first()`` is
    authoritative. The query has no ORDER BY, so if several rows ever
    exist, whichever one the database returns first wins. When the
    table is empty (first run, or a migration has only just created
    it), the defaults configured on ``settings`` are used instead.

    Args:
        db: SQLAlchemy session used for the lookup.

    Returns:
        A ``ScoringWeights`` built either from the persisted row or from
        the ``Settings`` defaults (each default passed through ``float()``).
    """
    config = db.query(ScoringConfig).first()

    if config is None:
        # Nothing persisted yet: fall back to the environment / Settings values.
        return ScoringWeights(
            tests=float(settings.SCORING_WEIGHT_TESTS),
            detection_rules=float(settings.SCORING_WEIGHT_DETECTION_RULES),
            d3fend=float(settings.SCORING_WEIGHT_D3FEND),
            freshness=float(settings.SCORING_WEIGHT_FRESHNESS),
            platform_diversity=float(settings.SCORING_WEIGHT_PLATFORM_DIVERSITY),
        )

    # NOTE(review): unlike the fallback above, column values are passed
    # through without float(); if these columns are Numeric rather than
    # Float they would arrive as Decimal — confirm against the model.
    return ScoringWeights(
        tests=config.weight_tests,
        detection_rules=config.weight_detection_rules,
        d3fend=config.weight_d3fend,
        freshness=config.weight_freshness,
        platform_diversity=config.weight_platform_diversity,
    )
|
|
|
|
|
|
def update_scoring_weights(
    db: Session,
    *,
    tests: float | None = None,
    detection_rules: float | None = None,
    d3fend: float | None = None,
    freshness: float | None = None,
    platform_diversity: float | None = None,
) -> dict[str, Any]:
    """Merge the given weights over the active ones and stage them for saving.

    Any keyword left as ``None`` keeps its active value: the persisted
    row's value, or the ``Settings`` default when nothing is persisted.
    The merged set is built through the ``ScoringWeights`` constructor
    before any column is written, so its validation runs first. If no
    ``scoring_config`` row exists yet, a new one is created and added
    to the session (upsert).

    The session is not committed here; the calling router owns the
    transaction through its UnitOfWork.

    Returns:
        A dict with ``weights`` and ``total`` describing the merged weights.
    """
    active = get_scoring_weights(db)

    merged = ScoringWeights(
        tests=active.tests if tests is None else tests,
        detection_rules=(
            active.detection_rules if detection_rules is None else detection_rules
        ),
        d3fend=active.d3fend if d3fend is None else d3fend,
        freshness=active.freshness if freshness is None else freshness,
        platform_diversity=(
            active.platform_diversity
            if platform_diversity is None
            else platform_diversity
        ),
    )

    config = db.query(ScoringConfig).first()
    if config is None:
        # First save ever: create the row and register it with the session.
        config = ScoringConfig()
        db.add(config)

    config.weight_tests = merged.tests
    config.weight_detection_rules = merged.detection_rules
    config.weight_d3fend = merged.d3fend
    config.weight_freshness = merged.freshness
    config.weight_platform_diversity = merged.platform_diversity

    # Intentionally no db.commit(): the caller (router) commits via UnitOfWork.
    return _weights_dict(merged)
|
|
|
|
|
|
def get_weights_dict(db: Session) -> dict[str, Any]:
    """Serialise the active scoring weights.

    Loads the weights via ``get_scoring_weights`` (persisted row, or the
    ``Settings`` defaults) and returns them in the same dict shape that
    ``update_scoring_weights`` returns.
    """
    active = get_scoring_weights(db)
    return _weights_dict(active)
|
|
|
|
|
|
def _weights_dict(w: ScoringWeights) -> dict[str, Any]:
    """Serialise *w* into ``{"weights": {...}, "total": <float>}``.

    ``weights`` maps each public weight name to its value on *w*.
    ``total`` is computed with :func:`math.fsum`, which returns the
    correctly rounded sum of the values. Plain ``sum()`` accumulates
    rounding error at every step; for example, ``0.1 + 0.2 + 0.3``
    yields ``0.6000000000000001``. With ``sum()``, a reported total
    would drift from the value callers expect.
    """
    weights = {
        "tests": w.tests,
        "detection_rules": w.detection_rules,
        "d3fend": w.d3fend,
        "freshness": w.freshness,
        "platform_diversity": w.platform_diversity,
    }
    return {
        "weights": weights,
        # fsum: correctly rounded float sum, avoids step-wise accumulation error.
        "total": math.fsum(weights.values()),
    }
|