Compare commits

...

3 Commits

Author SHA1 Message Date
kitos bfce1a8a0e refactor(core): introduce Unit of Work and remove commits from services
Aegis CI / lint-and-test (push) Has been cancelled
- Add UnitOfWork context manager in domain/unit_of_work.py with commit/rollback/flush API and auto-rollback on exception

- Remove all db.commit() from test_workflow_service (8 calls), notification_service (4 calls), status_service (1 call)

- Services now only stage changes via db.add/db.flush; caller owns the transaction boundary

- Update routers/tests.py: wrap 9 workflow endpoints in UnitOfWork context managers

- Update routers/notifications.py: wrap mark_as_read and mark_all_as_read in UnitOfWork
2026-02-18 12:51:55 +01:00
kitos 98e8ca1eef perf(snapshot): remove N+1 queries in snapshot generation
- Replace per-technique calculate_technique_score loop with bulk_technique_scores() from scoring_service

- Snapshot creation now runs ~10 fixed queries instead of N*5+N*5 (was ~2000+ for 200 techniques)
2026-02-18 12:22:24 +01:00
kitos f0f59facdb perf(scoring): eliminate N+1 in organization score calculation
- Add bulk_technique_scores() that pre-fetches all scoring data in 5 aggregated GROUP BY queries instead of N*5 per-technique queries

- Rewrite calculate_organization_score to use bulk data (N*5+5 queries -> 10 fixed queries)

- Rewrite calculate_tactic_score and calculate_actor_coverage_score to use bulk data

- Preserve calculate_technique_score single-technique API for router-level calls
2026-02-18 12:18:48 +01:00
8 changed files with 406 additions and 169 deletions
+47
View File
@@ -0,0 +1,47 @@
"""Unit of Work — wraps a SQLAlchemy session for explicit transaction control.
Usage in routers::
with UnitOfWork(db) as uow:
service_a(db, ...)
service_b(db, ...)
uow.commit() # single commit for the entire operation
If an exception propagates, ``__exit__`` issues a rollback automatically.
Services should **never** call ``db.commit()``; they use ``db.add()`` /
``db.flush()`` to stage work and let the caller decide when to commit.
"""
from __future__ import annotations
from sqlalchemy.orm import Session
class UnitOfWork:
"""Lightweight transaction wrapper around an existing SQLAlchemy session."""
def __init__(self, session: Session) -> None:
self._session = session
# -- context manager -----------------------------------------------------
def __enter__(self) -> "UnitOfWork":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
if exc_type is not None:
self.rollback()
# -- public API ----------------------------------------------------------
def commit(self) -> None:
"""Flush pending changes and commit the transaction."""
self._session.commit()
def rollback(self) -> None:
"""Roll back the current transaction."""
self._session.rollback()
def flush(self) -> None:
"""Flush pending changes without committing (useful for getting IDs)."""
self._session.flush()
+12 -7
View File
@@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.domain.unit_of_work import UnitOfWork
from app.models.notification import Notification from app.models.notification import Notification
from app.models.user import User from app.models.user import User
from app.schemas.notification import NotificationOut, UnreadCountOut from app.schemas.notification import NotificationOut, UnreadCountOut
@@ -78,12 +79,14 @@ def read_notification(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Mark a single notification as read.""" """Mark a single notification as read."""
success = mark_as_read(db, notification_id, current_user.id) with UnitOfWork(db) as uow:
if not success: success = mark_as_read(db, notification_id, current_user.id)
raise HTTPException( if not success:
status_code=status.HTTP_404_NOT_FOUND, raise HTTPException(
detail="Notification not found", status_code=status.HTTP_404_NOT_FOUND,
) detail="Notification not found",
)
uow.commit()
notif = db.query(Notification).filter(Notification.id == notification_id).first() notif = db.query(Notification).filter(Notification.id == notification_id).first()
return notif return notif
@@ -99,5 +102,7 @@ def read_all_notifications(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Mark all notifications for the current user as read.""" """Mark all notifications for the current user as read."""
count = mark_all_as_read(db, current_user.id) with UnitOfWork(db) as uow:
count = mark_all_as_read(db, current_user.id)
uow.commit()
return {"detail": f"Marked {count} notifications as read"} return {"detail": f"Marked {count} notifications as read"}
+53 -44
View File
@@ -43,6 +43,7 @@ from app.schemas.test import (
TestRemediationUpdate, TestRemediationUpdate,
) )
from app.schemas.test_template import TestTemplateInstantiate from app.schemas.test_template import TestTemplateInstantiate
from app.domain.unit_of_work import UnitOfWork
from app.services.audit_service import log_action from app.services.audit_service import log_action
from app.services.status_service import recalculate_technique_status from app.services.status_service import recalculate_technique_status
from app.services.test_workflow_service import ( from app.services.test_workflow_service import (
@@ -434,7 +435,9 @@ def start_execution(
): ):
"""Move a test from ``draft`` to ``red_executing``.""" """Move a test from ``draft`` to ``red_executing``."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_start_execution(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_start_execution(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -452,7 +455,9 @@ def submit_red(
): ):
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``.""" """Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_submit_red(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_submit_red(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -470,7 +475,9 @@ def submit_blue(
): ):
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``.""" """Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_submit_blue(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_submit_blue(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -488,7 +495,9 @@ def pause_timer(
): ):
"""Pause the running timer for the current phase (red_executing or blue_evaluating).""" """Pause the running timer for the current phase (red_executing or blue_evaluating)."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_pause_timer(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_pause_timer(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -506,7 +515,9 @@ def resume_timer(
): ):
"""Resume the paused timer for the current phase.""" """Resume the paused timer for the current phase."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_resume_timer(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_resume_timer(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -525,16 +536,15 @@ def validate_red(
): ):
"""Red Lead approves or rejects the red side of a test.""" """Red Lead approves or rejects the red side of a test."""
test = _get_test_with_technique(db, test_id) test = _get_test_with_technique(db, test_id)
test = wf_validate_red( with UnitOfWork(db) as uow:
db, test, current_user, test = wf_validate_red(
validation_status=payload.red_validation_status, db, test, current_user,
notes=payload.red_validation_notes, validation_status=payload.red_validation_status,
) notes=payload.red_validation_notes,
)
# Recalculate technique status if test reached a terminal state if test.state in (TestState.validated, TestState.rejected):
if test.state in (TestState.validated, TestState.rejected): recalculate_technique_status(db, test.technique)
recalculate_technique_status(db, test.technique) uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -553,16 +563,15 @@ def validate_blue(
): ):
"""Blue Lead approves or rejects the blue side of a test.""" """Blue Lead approves or rejects the blue side of a test."""
test = _get_test_with_technique(db, test_id) test = _get_test_with_technique(db, test_id)
test = wf_validate_blue( with UnitOfWork(db) as uow:
db, test, current_user, test = wf_validate_blue(
validation_status=payload.blue_validation_status, db, test, current_user,
notes=payload.blue_validation_notes, validation_status=payload.blue_validation_status,
) notes=payload.blue_validation_notes,
)
# Recalculate technique status if test reached a terminal state if test.state in (TestState.validated, TestState.rejected):
if test.state in (TestState.validated, TestState.rejected): recalculate_technique_status(db, test.technique)
recalculate_technique_status(db, test.technique) uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -580,7 +589,9 @@ def reopen(
): ):
"""Reopen a rejected test, moving it back to ``draft``.""" """Reopen a rejected test, moving it back to ``draft``."""
test = _get_test_or_404(db, test_id) test = _get_test_or_404(db, test_id)
test = wf_reopen(db, test, current_user) with UnitOfWork(db) as uow:
test = wf_reopen(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
return test return test
@@ -610,25 +621,23 @@ def update_remediation(
for field, value in update_data.items(): for field, value in update_data.items():
setattr(test, field, value) setattr(test, field, value)
db.commit() with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="update_remediation",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
new_status = update_data.get("remediation_status")
if new_status == "completed" and old_remediation_status != "completed":
wf_handle_remediation(db, test, current_user)
uow.commit()
db.refresh(test) db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="update_remediation",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
# Auto-create retest when remediation is marked completed
new_status = update_data.get("remediation_status")
if new_status == "completed" and old_remediation_status != "completed":
retest = wf_handle_remediation(db, test, current_user)
if retest:
db.refresh(test)
return test return test
+4 -5
View File
@@ -2,6 +2,9 @@
Provides helpers for generating notifications automatically when test Provides helpers for generating notifications automatically when test
state changes occur, plus CRUD for the notifications API. state changes occur, plus CRUD for the notifications API.
Functions in this module stage changes via ``db.add()`` / ``db.flush()``
but do **not** commit. The caller is responsible for committing.
""" """
import uuid import uuid
@@ -38,8 +41,7 @@ def create_notification(
entity_id=entity_id, entity_id=entity_id,
) )
db.add(notif) db.add(notif)
db.commit() db.flush()
db.refresh(notif)
return notif return notif
@@ -53,7 +55,6 @@ def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) ->
if notif is None: if notif is None:
return False return False
notif.read = True notif.read = True
db.commit()
return True return True
@@ -64,7 +65,6 @@ def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712 .filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
.update({"read": True}) .update({"read": True})
) )
db.commit()
return count return count
@@ -88,7 +88,6 @@ def cleanup_old_notifications(db: Session, days: int = 90) -> int:
) )
.delete() .delete()
) )
db.commit()
return count return count
+267 -82
View File
@@ -2,12 +2,16 @@
Uses configurable weights from Settings to compute coverage scores with Uses configurable weights from Settings to compute coverage scores with
detailed breakdowns. detailed breakdowns.
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
fixed number of aggregated queries so that organisation-wide calculations
never produce N+1 traffic.
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional
from sqlalchemy import func from sqlalchemy import case, func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import settings from app.config import settings
@@ -20,7 +24,219 @@ from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.enums import TestState, TestResult from app.models.enums import TestState, TestResult
# ── Technique-level scoring ────────────────────────────────────────── # ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
def _build_empty_stats():
return {
"validated": 0,
"detected": 0,
"platforms": set(),
"latest_validated_at": None,
}
def bulk_technique_scores(db: Session) -> dict:
"""Pre-fetch all scoring data and compute per-technique scores in memory.
Executes exactly 5 queries regardless of technique count:
Q1 — Test aggregates per technique (validated / detected / platforms / freshness)
Q2 — Detection rules per mitre_id
Q3 — Triggered rules per mitre_id
Q4 — D3FEND mapping counts per technique
Q5 — All techniques
Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``.
"""
w_tests = settings.SCORING_WEIGHT_TESTS
w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
w_d3fend = settings.SCORING_WEIGHT_D3FEND
w_freshness = settings.SCORING_WEIGHT_FRESHNESS
w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
# Q1: test stats grouped by technique_id
test_rows = (
db.query(
Test.technique_id,
func.count(Test.id).label("validated_count"),
func.count(
case((Test.detection_result == TestResult.detected, Test.id))
).label("detected_count"),
func.max(Test.red_validated_at).label("latest_validated_at"),
func.count(func.distinct(Test.platform)).label("platform_count"),
)
.filter(Test.state == TestState.validated)
.group_by(Test.technique_id)
.all()
)
test_stats: dict = {}
for row in test_rows:
test_stats[row.technique_id] = {
"validated": row.validated_count,
"detected": row.detected_count,
"latest_validated_at": row.latest_validated_at,
"platform_count": row.platform_count,
}
# Q2: active detection rules per mitre_id
rule_rows = (
db.query(
DetectionRule.mitre_technique_id,
func.count(DetectionRule.id).label("total"),
)
.filter(DetectionRule.is_active == True) # noqa: E712
.group_by(DetectionRule.mitre_technique_id)
.all()
)
rules_by_mitre: dict[str, int] = {r.mitre_technique_id: r.total for r in rule_rows}
# Q3: triggered rules per mitre_id
triggered_rows = (
db.query(
DetectionRule.mitre_technique_id,
func.count(TestDetectionResult.id).label("triggered"),
)
.join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
.filter(TestDetectionResult.triggered == True) # noqa: E712
.group_by(DetectionRule.mitre_technique_id)
.all()
)
triggered_by_mitre: dict[str, int] = {
r.mitre_technique_id: r.triggered for r in triggered_rows
}
# Q4: D3FEND mapping counts per technique
d3fend_rows = (
db.query(
DefensiveTechniqueMapping.attack_technique_id,
func.count(DefensiveTechniqueMapping.id).label("total"),
)
.group_by(DefensiveTechniqueMapping.attack_technique_id)
.all()
)
d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows}
# Q5: all techniques
techniques = db.query(Technique).all()
now = datetime.utcnow()
results: dict = {}
for tech in techniques:
ts = test_stats.get(tech.id, {})
validated = ts.get("validated", 0)
detected = ts.get("detected", 0)
latest_at = ts.get("latest_validated_at")
plat_count = ts.get("platform_count", 0)
breakdown = {}
# 1. Tests validated with detection
if validated > 0:
test_ratio = detected / validated
test_score = round(test_ratio * w_tests, 1)
else:
test_ratio = 0
test_score = 0
breakdown["tests_validated"] = {
"score": test_score,
"max": w_tests,
"detail": (
f"{detected}/{validated} tests detected"
if validated else "No validated tests"
),
}
# 2. Detection rules
total_rules = rules_by_mitre.get(tech.mitre_id, 0)
triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0)
if total_rules > 0:
detection_ratio = min(triggered_rules / total_rules, 1.0)
detection_score = round(detection_ratio * w_detection, 1)
else:
detection_ratio = 0
detection_score = 0
breakdown["detection_rules"] = {
"score": detection_score,
"max": w_detection,
"detail": (
f"{triggered_rules}/{total_rules} rules triggered"
if total_rules > 0 else "No detection rules available"
),
}
# 3. D3FEND coverage
total_cm = d3fend_by_tech.get(tech.id, 0)
if total_cm > 0 and detected > 0:
verified_cm = min(detected, total_cm)
d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1)
else:
verified_cm = 0
d3fend_score = 0
breakdown["d3fend_coverage"] = {
"score": d3fend_score,
"max": w_d3fend,
"detail": (
f"{verified_cm}/{total_cm} countermeasures"
if total_cm > 0 else "No D3FEND mappings"
),
}
# 4. Freshness
if latest_at:
days_ago = (now - latest_at).days
if days_ago < 90:
freshness_pct = 1.0
elif days_ago < 180:
freshness_pct = 0.5
else:
freshness_pct = 0.0
freshness_score = round(freshness_pct * w_freshness, 1)
freshness_detail = f"Last test {days_ago} days ago"
else:
freshness_score = 0
freshness_detail = "No validated tests"
breakdown["freshness"] = {
"score": freshness_score,
"max": w_freshness,
"detail": freshness_detail,
}
# 5. Platform diversity
available = tech.platforms or []
total_platforms = len(available) if available else 3
if total_platforms > 0 and plat_count > 0:
diversity_score = round(
min(plat_count / total_platforms, 1.0) * w_diversity, 1,
)
else:
diversity_score = 0
breakdown["platform_diversity"] = {
"score": diversity_score,
"max": w_diversity,
"detail": (
f"{plat_count}/{total_platforms} platforms covered"
if plat_count > 0 else "No platforms tested"
),
}
total = min(
test_score + detection_score + d3fend_score
+ freshness_score + diversity_score,
100,
)
results[tech.id] = {
"total_score": round(total, 1),
"breakdown": breakdown,
"mitre_id": tech.mitre_id,
"tactic": tech.tactic,
}
return results
# ── Technique-level scoring (single technique — preserved API) ────────
def calculate_technique_score(technique: Technique, db: Session) -> dict: def calculate_technique_score(technique: Technique, db: Session) -> dict:
@@ -73,7 +289,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
db.query(func.count(DetectionRule.id)) db.query(func.count(DetectionRule.id))
.filter( .filter(
DetectionRule.mitre_technique_id == technique.mitre_id, DetectionRule.mitre_technique_id == technique.mitre_id,
DetectionRule.is_active == True, DetectionRule.is_active == True, # noqa: E712
) )
.scalar() .scalar()
) or 0 ) or 0
@@ -88,7 +304,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
) )
.filter( .filter(
DetectionRule.mitre_technique_id == technique.mitre_id, DetectionRule.mitre_technique_id == technique.mitre_id,
TestDetectionResult.triggered == True, TestDetectionResult.triggered == True, # noqa: E712
) )
.scalar() .scalar()
) or 0 ) or 0
@@ -114,11 +330,8 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
.scalar() .scalar()
) or 0 ) or 0
# Consider a countermeasure "verified" if we have validated tests
# with detection for the technique (simplified heuristic)
verified_countermeasures = 0 verified_countermeasures = 0
if total_countermeasures > 0 and len(detected_tests) > 0: if total_countermeasures > 0 and len(detected_tests) > 0:
# Rough heuristic: each detected test validates ~1 countermeasure
verified_countermeasures = min(len(detected_tests), total_countermeasures) verified_countermeasures = min(len(detected_tests), total_countermeasures)
d3fend_ratio = verified_countermeasures / total_countermeasures d3fend_ratio = verified_countermeasures / total_countermeasures
d3fend_score = round(d3fend_ratio * w_d3fend, 1) d3fend_score = round(d3fend_ratio * w_d3fend, 1)
@@ -135,7 +348,6 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
} }
# ── 4. Freshness ────────────────────────────────────────────── # ── 4. Freshness ──────────────────────────────────────────────
# Most recent validated test date
most_recent_test = ( most_recent_test = (
db.query(func.max(Test.red_validated_at)) db.query(func.max(Test.red_validated_at))
.filter( .filter(
@@ -169,7 +381,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
# ── 5. Platform diversity ───────────────────────────────────── # ── 5. Platform diversity ─────────────────────────────────────
available_platforms = technique.platforms or [] available_platforms = technique.platforms or []
total_platforms = len(available_platforms) if available_platforms else 3 # default 3 total_platforms = len(available_platforms) if available_platforms else 3
tested_platforms = set() tested_platforms = set()
for t in validated_tests: for t in validated_tests:
@@ -208,30 +420,19 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
def calculate_tactic_score(tactic: str, db: Session) -> dict: def calculate_tactic_score(tactic: str, db: Session) -> dict:
"""Calculate average score for all techniques in a tactic.""" """Calculate average score for all techniques in a tactic."""
techniques = ( scores_map = bulk_technique_scores(db)
db.query(Technique)
.filter(Technique.tactic.ilike(f"%{tactic}%"))
.all()
)
if not techniques: matching = [
return { v["total_score"]
"tactic": tactic, for v in scores_map.values()
"average_score": 0, if v.get("tactic") and tactic.lower() in v["tactic"].lower()
"techniques_count": 0, ]
"techniques_scored": 0,
}
scores = []
for tech in techniques:
result = calculate_technique_score(tech, db)
scores.append(result["total_score"])
return { return {
"tactic": tactic, "tactic": tactic,
"average_score": round(sum(scores) / len(scores), 1) if scores else 0, "average_score": round(sum(matching) / len(matching), 1) if matching else 0,
"techniques_count": len(techniques), "techniques_count": len(matching),
"techniques_scored": len([s for s in scores if s > 0]), "techniques_scored": len([s for s in matching if s > 0]),
} }
@@ -244,14 +445,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
if not actor: if not actor:
return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0} return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0}
# Get all techniques used by this actor
actor_techniques = ( actor_techniques = (
db.query(ThreatActorTechnique) db.query(ThreatActorTechnique)
.filter(ThreatActorTechnique.threat_actor_id == actor.id) .filter(ThreatActorTechnique.threat_actor_id == actor.id)
.all() .all()
) )
technique_ids = [at.technique_id for at in actor_techniques] technique_ids = {at.technique_id for at in actor_techniques}
if not technique_ids: if not technique_ids:
return { return {
"actor_id": str(actor.id), "actor_id": str(actor.id),
@@ -262,23 +462,21 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
"techniques_detail": [], "techniques_detail": [],
} }
techniques = ( scores_map = bulk_technique_scores(db)
db.query(Technique)
.filter(Technique.id.in_(technique_ids))
.all()
)
scores = [] scores = []
details = [] details = []
for tech in techniques: for tid in technique_ids:
result = calculate_technique_score(tech, db) entry = scores_map.get(tid)
score = result["total_score"] if not entry:
continue
score = entry["total_score"]
scores.append(score) scores.append(score)
details.append({ details.append({
"mitre_id": tech.mitre_id, "mitre_id": entry["mitre_id"],
"name": tech.name, "name": entry.get("name", ""),
"score": score, "score": score,
"breakdown": result["breakdown"], "breakdown": entry["breakdown"],
}) })
avg_score = round(sum(scores) / len(scores), 1) if scores else 0 avg_score = round(sum(scores) / len(scores), 1) if scores else 0
@@ -287,7 +485,7 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
"actor_id": str(actor.id), "actor_id": str(actor.id),
"actor_name": actor.name, "actor_name": actor.name,
"total_score": avg_score, "total_score": avg_score,
"techniques_count": len(techniques), "techniques_count": len(technique_ids),
"techniques_covered": len([s for s in scores if s > 50]), "techniques_covered": len([s for s in scores if s > 50]),
"techniques_detail": details, "techniques_detail": details,
} }
@@ -297,10 +495,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
def calculate_organization_score(db: Session) -> dict: def calculate_organization_score(db: Session) -> dict:
"""Calculate the overall organization security score.""" """Calculate the overall organization security score.
# All techniques
all_techniques = db.query(Technique).all() Uses ``bulk_technique_scores`` to compute all technique scores in
total_count = len(all_techniques) 5 aggregated queries instead of N*5.
"""
scores_map = bulk_technique_scores(db)
total_count = len(scores_map)
if total_count == 0: if total_count == 0:
return { return {
@@ -313,27 +514,16 @@ def calculate_organization_score(db: Session) -> dict:
"techniques_total": 0, "techniques_total": 0,
} }
# Calculate scores for all techniques (with caching for performance) all_scores = [v["total_score"] for v in scores_map.values()]
all_scores = []
evaluated_count = 0
for tech in all_techniques:
result = calculate_technique_score(tech, db)
score = result["total_score"]
all_scores.append(score)
if score > 0:
evaluated_count += 1
# Total coverage: average of all evaluated techniques
evaluated_scores = [s for s in all_scores if s > 0] evaluated_scores = [s for s in all_scores if s > 0]
evaluated_count = len(evaluated_scores)
total_coverage = ( total_coverage = (
round(sum(evaluated_scores) / len(evaluated_scores), 1) round(sum(evaluated_scores) / len(evaluated_scores), 1)
if evaluated_scores if evaluated_scores else 0
else 0
) )
# Critical coverage: techniques with high-severity templates # Critical coverage: techniques with high/critical severity templates
# (simplified: techniques that have tests are "critical")
from app.models.test_template import TestTemplate from app.models.test_template import TestTemplate
critical_mitre_ids = set( critical_mitre_ids = set(
@@ -344,38 +534,35 @@ def calculate_organization_score(db: Session) -> dict:
.all() .all()
) )
critical_techniques = [ critical_scores = [
t for t in all_techniques if t.mitre_id in critical_mitre_ids v["total_score"]
for v in scores_map.values()
if v.get("mitre_id") in critical_mitre_ids
] ]
if critical_techniques: critical_coverage = (
critical_scores = [] round(sum(critical_scores) / len(critical_scores), 1)
for tech in critical_techniques: if critical_scores else 0
result = calculate_technique_score(tech, db) )
critical_scores.append(result["total_score"])
critical_coverage = round(sum(critical_scores) / len(critical_scores), 1)
else:
critical_coverage = 0
# Detection maturity: based on detection rule coverage # Detection maturity (2 scalar queries — already efficient)
total_rules = ( total_rules = (
db.query(func.count(DetectionRule.id)) db.query(func.count(DetectionRule.id))
.filter(DetectionRule.is_active == True) .filter(DetectionRule.is_active == True) # noqa: E712
.scalar() .scalar()
) or 0 ) or 0
triggered_total = ( triggered_total = (
db.query(func.count(TestDetectionResult.id)) db.query(func.count(TestDetectionResult.id))
.filter(TestDetectionResult.triggered == True) .filter(TestDetectionResult.triggered == True) # noqa: E712
.scalar() .scalar()
) or 0 ) or 0
detection_maturity = ( detection_maturity = (
round((triggered_total / total_rules) * 100, 1) round((triggered_total / total_rules) * 100, 1)
if total_rules > 0 if total_rules > 0 else 0
else 0
) )
detection_maturity = min(detection_maturity, 100) detection_maturity = min(detection_maturity, 100)
# Response readiness: based on remediation completion # Response readiness (2 scalar queries — already efficient)
remediation_total = ( remediation_total = (
db.query(func.count(Test.id)) db.query(func.count(Test.id))
.filter(Test.remediation_status.isnot(None)) .filter(Test.remediation_status.isnot(None))
@@ -389,11 +576,9 @@ def calculate_organization_score(db: Session) -> dict:
response_readiness = ( response_readiness = (
round((remediation_completed / remediation_total) * 100, 1) round((remediation_completed / remediation_total) * 100, 1)
if remediation_total > 0 if remediation_total > 0 else 0
else 0
) )
# Overall score: weighted average of sub-scores
overall = round( overall = round(
total_coverage * 0.4 total_coverage * 0.4
+ critical_coverage * 0.25 + critical_coverage * 0.25
+16 -13
View File
@@ -2,6 +2,9 @@
Provides point-in-time coverage captures with normalised per-technique Provides point-in-time coverage captures with normalised per-technique
storage and temporal comparison between any two snapshots. storage and temporal comparison between any two snapshots.
Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
number of SQL queries regardless of technique count.
""" """
import logging import logging
@@ -14,7 +17,10 @@ from sqlalchemy.orm import Session
from app.models.technique import Technique from app.models.technique import Technique
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus from app.models.enums import TechniqueStatus
from app.services.scoring_service import calculate_technique_score, calculate_organization_score from app.services.scoring_service import (
bulk_technique_scores,
calculate_organization_score,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,14 +37,16 @@ def create_snapshot(
) -> CoverageSnapshot: ) -> CoverageSnapshot:
"""Capture the current coverage state into a new snapshot. """Capture the current coverage state into a new snapshot.
1. Fetch every technique with its status and score. 1. Bulk-fetch all technique scores in 5 aggregated queries.
2. Compute aggregate counts. 2. Walk the already-loaded techniques to count statuses.
3. Persist a ``CoverageSnapshot`` with normalised 3. Compute the org score from the same bulk data.
4. Persist a ``CoverageSnapshot`` with normalised
``SnapshotTechniqueState`` rows. ``SnapshotTechniqueState`` rows.
""" """
scores_map = bulk_technique_scores(db)
techniques = db.query(Technique).all() techniques = db.query(Technique).all()
# Aggregate counters
validated_count = 0 validated_count = 0
partial_count = 0 partial_count = 0
not_covered_count = 0 not_covered_count = 0
@@ -54,7 +62,6 @@ def create_snapshot(
else (tech.status_global or "not_evaluated") else (tech.status_global or "not_evaluated")
) )
# Count by status
if status_value == "validated": if status_value == "validated":
validated_count += 1 validated_count += 1
elif status_value == "partial": elif status_value == "partial":
@@ -66,20 +73,17 @@ def create_snapshot(
else: else:
not_evaluated_count += 1 not_evaluated_count += 1
# Compute technique score entry = scores_map.get(tech.id, {})
score_data = calculate_technique_score(tech, db)
technique_rows.append({ technique_rows.append({
"technique_id": tech.id, "technique_id": tech.id,
"mitre_id": tech.mitre_id, "mitre_id": tech.mitre_id,
"status": status_value, "status": status_value,
"score": score_data["total_score"], "score": entry.get("total_score", 0),
}) })
# Organization score
org_data = calculate_organization_score(db) org_data = calculate_organization_score(db)
org_score = org_data.get("overall_score", 0) org_score = org_data.get("overall_score", 0)
# Create the snapshot
snapshot = CoverageSnapshot( snapshot = CoverageSnapshot(
name=name, name=name,
organization_score=org_score, organization_score=org_score,
@@ -92,9 +96,8 @@ def create_snapshot(
created_by=user_id, created_by=user_id,
) )
db.add(snapshot) db.add(snapshot)
db.flush() # get snapshot.id db.flush()
# Create normalised technique state rows
for row in technique_rows: for row in technique_rows:
state = SnapshotTechniqueState( state = SnapshotTechniqueState(
snapshot_id=snapshot.id, snapshot_id=snapshot.id,
+3 -2
View File
@@ -4,6 +4,9 @@ based on the state and result of its associated tests.
V2 rules account for dual Red/Blue validation and use V2 rules account for dual Red/Blue validation and use
``detection_result`` (filled by Blue Team) instead of the legacy ``detection_result`` (filled by Blue Team) instead of the legacy
``result`` field. ``result`` field.
This function mutates the technique but does **not** commit.
The caller is responsible for committing the session.
""" """
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -42,5 +45,3 @@ def recalculate_technique_status(db: Session, technique: Technique) -> None:
technique.status_global = TechniqueStatus.partial technique.status_global = TechniqueStatus.partial
else: else:
technique.status_global = TechniqueStatus.in_progress technique.status_global = TechniqueStatus.in_progress
db.commit()
+4 -16
View File
@@ -7,8 +7,9 @@ for each step in the test lifecycle:
rejected → draft rejected → draft
Every public function validates the transition, mutates the test, writes an Every public function validates the transition, mutates the test, and writes
audit-log entry, and commits the session. an audit-log entry. The caller (router) is responsible for committing the
session via the Unit of Work pattern.
""" """
import logging import logging
@@ -122,7 +123,6 @@ def start_execution(db: Session, test: Test, user: User) -> Test:
) )
test.execution_date = now test.execution_date = now
test.red_started_at = now test.red_started_at = now
db.commit()
return test return test
@@ -161,7 +161,6 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
# Start Blue Team timer # Start Blue Team timer
test.blue_started_at = now test.blue_started_at = now
test.blue_paused_seconds = 0 test.blue_paused_seconds = 0
db.commit()
return test return test
@@ -196,7 +195,6 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
description=f"Blue Team evaluation: {test.name}", description=f"Blue Team evaluation: {test.name}",
) )
db.commit()
return test return test
@@ -222,7 +220,6 @@ def pause_timer(db: Session, test: Test, user: User) -> Test:
entity_id=test.id, entity_id=test.id,
details={"state": test.state.value}, details={"state": test.state.value},
) )
db.commit()
return test return test
@@ -252,7 +249,6 @@ def resume_timer(db: Session, test: Test, user: User) -> Test:
entity_id=test.id, entity_id=test.id,
details={"paused_seconds": paused_seconds, "state": test.state.value}, details={"paused_seconds": paused_seconds, "state": test.state.value},
) )
db.commit()
return test return test
@@ -421,14 +417,12 @@ def check_dual_validation(db: Session, test: Test) -> Test:
if red_status == "rejected" or blue_status == "rejected": if red_status == "rejected" or blue_status == "rejected":
test.state = TestState.rejected test.state = TestState.rejected
db.commit()
try: try:
notify_test_state_change(db, test, "rejected") notify_test_state_change(db, test, "rejected")
except Exception as e: except Exception as e:
logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True) logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True)
elif red_status == "approved" and blue_status == "approved": elif red_status == "approved" and blue_status == "approved":
test.state = TestState.validated test.state = TestState.validated
db.commit()
# Invalidate cached scores — a validation changes org-level numbers # Invalidate cached scores — a validation changes org-level numbers
try: try:
from app.services.score_cache import invalidate from app.services.score_cache import invalidate
@@ -439,10 +433,6 @@ def check_dual_validation(db: Session, test: Test) -> Test:
notify_test_state_change(db, test, "validated") notify_test_state_change(db, test, "validated")
except Exception as e: except Exception as e:
logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True) logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True)
else:
# One side hasn't voted yet — stay in_review, just flush
db.commit()
return test return test
@@ -533,8 +523,7 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test |
entity_id=retest.id, entity_id=retest.id,
) )
db.commit() db.flush()
db.refresh(retest)
return retest return retest
@@ -599,5 +588,4 @@ def reopen_test(db: Session, test: Test, user: User) -> Test:
test.red_paused_seconds = 0 test.red_paused_seconds = 0
test.blue_paused_seconds = 0 test.blue_paused_seconds = 0
db.commit()
return test return test