diff --git a/backend/app/services/snapshot_service.py b/backend/app/services/snapshot_service.py index 7cf5385..e04f222 100644 --- a/backend/app/services/snapshot_service.py +++ b/backend/app/services/snapshot_service.py @@ -2,6 +2,9 @@ Provides point-in-time coverage captures with normalised per-technique storage and temporal comparison between any two snapshots. + +Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed +number of SQL queries regardless of technique count. """ import logging @@ -14,7 +17,10 @@ from sqlalchemy.orm import Session from app.models.technique import Technique from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.models.enums import TechniqueStatus -from app.services.scoring_service import calculate_technique_score, calculate_organization_score +from app.services.scoring_service import ( + bulk_technique_scores, + calculate_organization_score, +) logger = logging.getLogger(__name__) @@ -31,14 +37,16 @@ def create_snapshot( ) -> CoverageSnapshot: """Capture the current coverage state into a new snapshot. - 1. Fetch every technique with its status and score. - 2. Compute aggregate counts. - 3. Persist a ``CoverageSnapshot`` with normalised + 1. Bulk-fetch all technique scores in 5 aggregated queries. + 2. Walk the already-loaded techniques to count statuses. + 3. Compute the org score from the same bulk data. + 4. Persist a ``CoverageSnapshot`` with normalised ``SnapshotTechniqueState`` rows. """ + scores_map = bulk_technique_scores(db) + techniques = db.query(Technique).all() - # Aggregate counters validated_count = 0 partial_count = 0 not_covered_count = 0 @@ -54,7 +62,6 @@ def create_snapshot( else (tech.status_global or "not_evaluated") ) - # Count by status if status_value == "validated": validated_count += 1 elif status_value == "partial": @@ -66,20 +73,17 @@ def create_snapshot( else: not_evaluated_count += 1 - # Compute technique score - score_data = calculate_technique_score(tech, db) + entry = scores_map.get(tech.id, {}) technique_rows.append({ "technique_id": tech.id, "mitre_id": tech.mitre_id, "status": status_value, - "score": score_data["total_score"], + "score": entry.get("total_score", 0), }) - # Organization score org_data = calculate_organization_score(db) org_score = org_data.get("overall_score", 0) - # Create the snapshot snapshot = CoverageSnapshot( name=name, organization_score=org_score, @@ -92,9 +96,8 @@ def create_snapshot( created_by=user_id, ) db.add(snapshot) - db.flush() # get snapshot.id + db.flush() - # Create normalised technique state rows for row in technique_rows: state = SnapshotTechniqueState( snapshot_id=snapshot.id,