perf(snapshot): remove N+1 queries in snapshot generation
- Replace per-technique calculate_technique_score loop with bulk_technique_scores() from scoring_service - Snapshot creation now runs ~10 fixed queries instead of N*5+N*5 (was ~2000+ for 200 techniques)
This commit is contained in:
@@ -2,6 +2,9 @@
|
||||
|
||||
Provides point-in-time coverage captures with normalised per-technique
|
||||
storage and temporal comparison between any two snapshots.
|
||||
|
||||
Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
|
||||
number of SQL queries regardless of technique count.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -14,7 +17,10 @@ from sqlalchemy.orm import Session
|
||||
from app.models.technique import Technique
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.models.enums import TechniqueStatus
|
||||
from app.services.scoring_service import calculate_technique_score, calculate_organization_score
|
||||
from app.services.scoring_service import (
|
||||
bulk_technique_scores,
|
||||
calculate_organization_score,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,14 +37,16 @@ def create_snapshot(
|
||||
) -> CoverageSnapshot:
|
||||
"""Capture the current coverage state into a new snapshot.
|
||||
|
||||
1. Fetch every technique with its status and score.
|
||||
2. Compute aggregate counts.
|
||||
3. Persist a ``CoverageSnapshot`` with normalised
|
||||
1. Bulk-fetch all technique scores in 5 aggregated queries.
|
||||
2. Walk the already-loaded techniques to count statuses.
|
||||
3. Compute the org score from the same bulk data.
|
||||
4. Persist a ``CoverageSnapshot`` with normalised
|
||||
``SnapshotTechniqueState`` rows.
|
||||
"""
|
||||
scores_map = bulk_technique_scores(db)
|
||||
|
||||
techniques = db.query(Technique).all()
|
||||
|
||||
# Aggregate counters
|
||||
validated_count = 0
|
||||
partial_count = 0
|
||||
not_covered_count = 0
|
||||
@@ -54,7 +62,6 @@ def create_snapshot(
|
||||
else (tech.status_global or "not_evaluated")
|
||||
)
|
||||
|
||||
# Count by status
|
||||
if status_value == "validated":
|
||||
validated_count += 1
|
||||
elif status_value == "partial":
|
||||
@@ -66,20 +73,17 @@ def create_snapshot(
|
||||
else:
|
||||
not_evaluated_count += 1
|
||||
|
||||
# Compute technique score
|
||||
score_data = calculate_technique_score(tech, db)
|
||||
entry = scores_map.get(tech.id, {})
|
||||
technique_rows.append({
|
||||
"technique_id": tech.id,
|
||||
"mitre_id": tech.mitre_id,
|
||||
"status": status_value,
|
||||
"score": score_data["total_score"],
|
||||
"score": entry.get("total_score", 0),
|
||||
})
|
||||
|
||||
# Organization score
|
||||
org_data = calculate_organization_score(db)
|
||||
org_score = org_data.get("overall_score", 0)
|
||||
|
||||
# Create the snapshot
|
||||
snapshot = CoverageSnapshot(
|
||||
name=name,
|
||||
organization_score=org_score,
|
||||
@@ -92,9 +96,8 @@ def create_snapshot(
|
||||
created_by=user_id,
|
||||
)
|
||||
db.add(snapshot)
|
||||
db.flush() # get snapshot.id
|
||||
db.flush()
|
||||
|
||||
# Create normalised technique state rows
|
||||
for row in technique_rows:
|
||||
state = SnapshotTechniqueState(
|
||||
snapshot_id=snapshot.id,
|
||||
|
||||
Reference in New Issue
Block a user