perf(snapshot): remove N+1 queries in snapshot generation

- Replace per-technique calculate_technique_score loop with bulk_technique_scores() from scoring_service

- Snapshot creation now runs ~10 fixed queries instead of N*5 + N*5 queries for N techniques (previously ~2000+ for 200 techniques)
This commit is contained in:
2026-02-18 12:22:24 +01:00
parent f0f59facdb
commit 98e8ca1eef

View File

@@ -2,6 +2,9 @@
Provides point-in-time coverage captures with normalised per-technique Provides point-in-time coverage captures with normalised per-technique
storage and temporal comparison between any two snapshots. storage and temporal comparison between any two snapshots.
Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
number of SQL queries regardless of technique count.
""" """
import logging import logging
@@ -14,7 +17,10 @@ from sqlalchemy.orm import Session
from app.models.technique import Technique from app.models.technique import Technique
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus from app.models.enums import TechniqueStatus
from app.services.scoring_service import calculate_technique_score, calculate_organization_score from app.services.scoring_service import (
bulk_technique_scores,
calculate_organization_score,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,14 +37,16 @@ def create_snapshot(
) -> CoverageSnapshot: ) -> CoverageSnapshot:
"""Capture the current coverage state into a new snapshot. """Capture the current coverage state into a new snapshot.
1. Fetch every technique with its status and score. 1. Bulk-fetch all technique scores in 5 aggregated queries.
2. Compute aggregate counts. 2. Walk the already-loaded techniques to count statuses.
3. Persist a ``CoverageSnapshot`` with normalised 3. Compute the org score from the same bulk data.
4. Persist a ``CoverageSnapshot`` with normalised
``SnapshotTechniqueState`` rows. ``SnapshotTechniqueState`` rows.
""" """
scores_map = bulk_technique_scores(db)
techniques = db.query(Technique).all() techniques = db.query(Technique).all()
# Aggregate counters
validated_count = 0 validated_count = 0
partial_count = 0 partial_count = 0
not_covered_count = 0 not_covered_count = 0
@@ -54,7 +62,6 @@ def create_snapshot(
else (tech.status_global or "not_evaluated") else (tech.status_global or "not_evaluated")
) )
# Count by status
if status_value == "validated": if status_value == "validated":
validated_count += 1 validated_count += 1
elif status_value == "partial": elif status_value == "partial":
@@ -66,20 +73,17 @@ def create_snapshot(
else: else:
not_evaluated_count += 1 not_evaluated_count += 1
# Compute technique score entry = scores_map.get(tech.id, {})
score_data = calculate_technique_score(tech, db)
technique_rows.append({ technique_rows.append({
"technique_id": tech.id, "technique_id": tech.id,
"mitre_id": tech.mitre_id, "mitre_id": tech.mitre_id,
"status": status_value, "status": status_value,
"score": score_data["total_score"], "score": entry.get("total_score", 0),
}) })
# Organization score
org_data = calculate_organization_score(db) org_data = calculate_organization_score(db)
org_score = org_data.get("overall_score", 0) org_score = org_data.get("overall_score", 0)
# Create the snapshot
snapshot = CoverageSnapshot( snapshot = CoverageSnapshot(
name=name, name=name,
organization_score=org_score, organization_score=org_score,
@@ -92,9 +96,8 @@ def create_snapshot(
created_by=user_id, created_by=user_id,
) )
db.add(snapshot) db.add(snapshot)
db.flush() # get snapshot.id db.flush()
# Create normalised technique state rows
for row in technique_rows: for row in technique_rows:
state = SnapshotTechniqueState( state = SnapshotTechniqueState(
snapshot_id=snapshot.id, snapshot_id=snapshot.id,