Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bfce1a8a0e | |||
| 98e8ca1eef | |||
| f0f59facdb |
@@ -0,0 +1,47 @@
|
|||||||
|
"""Unit of Work — wraps a SQLAlchemy session for explicit transaction control.
|
||||||
|
|
||||||
|
Usage in routers::
|
||||||
|
|
||||||
|
with UnitOfWork(db) as uow:
|
||||||
|
service_a(db, ...)
|
||||||
|
service_b(db, ...)
|
||||||
|
uow.commit() # single commit for the entire operation
|
||||||
|
|
||||||
|
If an exception propagates, ``__exit__`` issues a rollback automatically.
|
||||||
|
Services should **never** call ``db.commit()``; they use ``db.add()`` /
|
||||||
|
``db.flush()`` to stage work and let the caller decide when to commit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
class UnitOfWork:
|
||||||
|
"""Lightweight transaction wrapper around an existing SQLAlchemy session."""
|
||||||
|
|
||||||
|
def __init__(self, session: Session) -> None:
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
# -- context manager -----------------------------------------------------
|
||||||
|
|
||||||
|
def __enter__(self) -> "UnitOfWork":
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
|
if exc_type is not None:
|
||||||
|
self.rollback()
|
||||||
|
|
||||||
|
# -- public API ----------------------------------------------------------
|
||||||
|
|
||||||
|
def commit(self) -> None:
|
||||||
|
"""Flush pending changes and commit the transaction."""
|
||||||
|
self._session.commit()
|
||||||
|
|
||||||
|
def rollback(self) -> None:
|
||||||
|
"""Roll back the current transaction."""
|
||||||
|
self._session.rollback()
|
||||||
|
|
||||||
|
def flush(self) -> None:
|
||||||
|
"""Flush pending changes without committing (useful for getting IDs)."""
|
||||||
|
self._session.flush()
|
||||||
@@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.notification import Notification
|
from app.models.notification import Notification
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.notification import NotificationOut, UnreadCountOut
|
from app.schemas.notification import NotificationOut, UnreadCountOut
|
||||||
@@ -78,12 +79,14 @@ def read_notification(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Mark a single notification as read."""
|
"""Mark a single notification as read."""
|
||||||
success = mark_as_read(db, notification_id, current_user.id)
|
with UnitOfWork(db) as uow:
|
||||||
if not success:
|
success = mark_as_read(db, notification_id, current_user.id)
|
||||||
raise HTTPException(
|
if not success:
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
raise HTTPException(
|
||||||
detail="Notification not found",
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
)
|
detail="Notification not found",
|
||||||
|
)
|
||||||
|
uow.commit()
|
||||||
notif = db.query(Notification).filter(Notification.id == notification_id).first()
|
notif = db.query(Notification).filter(Notification.id == notification_id).first()
|
||||||
return notif
|
return notif
|
||||||
|
|
||||||
@@ -99,5 +102,7 @@ def read_all_notifications(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Mark all notifications for the current user as read."""
|
"""Mark all notifications for the current user as read."""
|
||||||
count = mark_all_as_read(db, current_user.id)
|
with UnitOfWork(db) as uow:
|
||||||
|
count = mark_all_as_read(db, current_user.id)
|
||||||
|
uow.commit()
|
||||||
return {"detail": f"Marked {count} notifications as read"}
|
return {"detail": f"Marked {count} notifications as read"}
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from app.schemas.test import (
|
|||||||
TestRemediationUpdate,
|
TestRemediationUpdate,
|
||||||
)
|
)
|
||||||
from app.schemas.test_template import TestTemplateInstantiate
|
from app.schemas.test_template import TestTemplateInstantiate
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
from app.services.status_service import recalculate_technique_status
|
from app.services.status_service import recalculate_technique_status
|
||||||
from app.services.test_workflow_service import (
|
from app.services.test_workflow_service import (
|
||||||
@@ -434,7 +435,9 @@ def start_execution(
|
|||||||
):
|
):
|
||||||
"""Move a test from ``draft`` to ``red_executing``."""
|
"""Move a test from ``draft`` to ``red_executing``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_start_execution(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_start_execution(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -452,7 +455,9 @@ def submit_red(
|
|||||||
):
|
):
|
||||||
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
|
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_submit_red(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_submit_red(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -470,7 +475,9 @@ def submit_blue(
|
|||||||
):
|
):
|
||||||
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
|
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_submit_blue(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_submit_blue(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -488,7 +495,9 @@ def pause_timer(
|
|||||||
):
|
):
|
||||||
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
|
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_pause_timer(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_pause_timer(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -506,7 +515,9 @@ def resume_timer(
|
|||||||
):
|
):
|
||||||
"""Resume the paused timer for the current phase."""
|
"""Resume the paused timer for the current phase."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_resume_timer(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_resume_timer(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -525,16 +536,15 @@ def validate_red(
|
|||||||
):
|
):
|
||||||
"""Red Lead approves or rejects the red side of a test."""
|
"""Red Lead approves or rejects the red side of a test."""
|
||||||
test = _get_test_with_technique(db, test_id)
|
test = _get_test_with_technique(db, test_id)
|
||||||
test = wf_validate_red(
|
with UnitOfWork(db) as uow:
|
||||||
db, test, current_user,
|
test = wf_validate_red(
|
||||||
validation_status=payload.red_validation_status,
|
db, test, current_user,
|
||||||
notes=payload.red_validation_notes,
|
validation_status=payload.red_validation_status,
|
||||||
)
|
notes=payload.red_validation_notes,
|
||||||
|
)
|
||||||
# Recalculate technique status if test reached a terminal state
|
if test.state in (TestState.validated, TestState.rejected):
|
||||||
if test.state in (TestState.validated, TestState.rejected):
|
recalculate_technique_status(db, test.technique)
|
||||||
recalculate_technique_status(db, test.technique)
|
uow.commit()
|
||||||
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -553,16 +563,15 @@ def validate_blue(
|
|||||||
):
|
):
|
||||||
"""Blue Lead approves or rejects the blue side of a test."""
|
"""Blue Lead approves or rejects the blue side of a test."""
|
||||||
test = _get_test_with_technique(db, test_id)
|
test = _get_test_with_technique(db, test_id)
|
||||||
test = wf_validate_blue(
|
with UnitOfWork(db) as uow:
|
||||||
db, test, current_user,
|
test = wf_validate_blue(
|
||||||
validation_status=payload.blue_validation_status,
|
db, test, current_user,
|
||||||
notes=payload.blue_validation_notes,
|
validation_status=payload.blue_validation_status,
|
||||||
)
|
notes=payload.blue_validation_notes,
|
||||||
|
)
|
||||||
# Recalculate technique status if test reached a terminal state
|
if test.state in (TestState.validated, TestState.rejected):
|
||||||
if test.state in (TestState.validated, TestState.rejected):
|
recalculate_technique_status(db, test.technique)
|
||||||
recalculate_technique_status(db, test.technique)
|
uow.commit()
|
||||||
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -580,7 +589,9 @@ def reopen(
|
|||||||
):
|
):
|
||||||
"""Reopen a rejected test, moving it back to ``draft``."""
|
"""Reopen a rejected test, moving it back to ``draft``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = _get_test_or_404(db, test_id)
|
||||||
test = wf_reopen(db, test, current_user)
|
with UnitOfWork(db) as uow:
|
||||||
|
test = wf_reopen(db, test, current_user)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
return test
|
return test
|
||||||
|
|
||||||
@@ -610,25 +621,23 @@ def update_remediation(
|
|||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(test, field, value)
|
setattr(test, field, value)
|
||||||
|
|
||||||
db.commit()
|
with UnitOfWork(db) as uow:
|
||||||
|
log_action(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
action="update_remediation",
|
||||||
|
entity_type="test",
|
||||||
|
entity_id=test.id,
|
||||||
|
details={"updated_fields": list(update_data.keys())},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_status = update_data.get("remediation_status")
|
||||||
|
if new_status == "completed" and old_remediation_status != "completed":
|
||||||
|
wf_handle_remediation(db, test, current_user)
|
||||||
|
|
||||||
|
uow.commit()
|
||||||
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
|
||||||
db,
|
|
||||||
user_id=current_user.id,
|
|
||||||
action="update_remediation",
|
|
||||||
entity_type="test",
|
|
||||||
entity_id=test.id,
|
|
||||||
details={"updated_fields": list(update_data.keys())},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Auto-create retest when remediation is marked completed
|
|
||||||
new_status = update_data.get("remediation_status")
|
|
||||||
if new_status == "completed" and old_remediation_status != "completed":
|
|
||||||
retest = wf_handle_remediation(db, test, current_user)
|
|
||||||
if retest:
|
|
||||||
db.refresh(test)
|
|
||||||
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
|
|
||||||
Provides helpers for generating notifications automatically when test
|
Provides helpers for generating notifications automatically when test
|
||||||
state changes occur, plus CRUD for the notifications API.
|
state changes occur, plus CRUD for the notifications API.
|
||||||
|
|
||||||
|
Functions in this module stage changes via ``db.add()`` / ``db.flush()``
|
||||||
|
but do **not** commit. The caller is responsible for committing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
@@ -38,8 +41,7 @@ def create_notification(
|
|||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
)
|
)
|
||||||
db.add(notif)
|
db.add(notif)
|
||||||
db.commit()
|
db.flush()
|
||||||
db.refresh(notif)
|
|
||||||
return notif
|
return notif
|
||||||
|
|
||||||
|
|
||||||
@@ -53,7 +55,6 @@ def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) ->
|
|||||||
if notif is None:
|
if notif is None:
|
||||||
return False
|
return False
|
||||||
notif.read = True
|
notif.read = True
|
||||||
db.commit()
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@@ -64,7 +65,6 @@ def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
|
|||||||
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
|
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
|
||||||
.update({"read": True})
|
.update({"read": True})
|
||||||
)
|
)
|
||||||
db.commit()
|
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
@@ -88,7 +88,6 @@ def cleanup_old_notifications(db: Session, days: int = 90) -> int:
|
|||||||
)
|
)
|
||||||
.delete()
|
.delete()
|
||||||
)
|
)
|
||||||
db.commit()
|
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,16 @@
|
|||||||
|
|
||||||
Uses configurable weights from Settings to compute coverage scores with
|
Uses configurable weights from Settings to compute coverage scores with
|
||||||
detailed breakdowns.
|
detailed breakdowns.
|
||||||
|
|
||||||
|
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
|
||||||
|
fixed number of aggregated queries so that organisation-wide calculations
|
||||||
|
never produce N+1 traffic.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import case, func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@@ -20,7 +24,219 @@ from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
|||||||
from app.models.enums import TestState, TestResult
|
from app.models.enums import TestState, TestResult
|
||||||
|
|
||||||
|
|
||||||
# ── Technique-level scoring ──────────────────────────────────────────
|
# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_empty_stats():
|
||||||
|
return {
|
||||||
|
"validated": 0,
|
||||||
|
"detected": 0,
|
||||||
|
"platforms": set(),
|
||||||
|
"latest_validated_at": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def bulk_technique_scores(db: Session) -> dict:
|
||||||
|
"""Pre-fetch all scoring data and compute per-technique scores in memory.
|
||||||
|
|
||||||
|
Executes exactly 5 queries regardless of technique count:
|
||||||
|
Q1 — Test aggregates per technique (validated / detected / platforms / freshness)
|
||||||
|
Q2 — Detection rules per mitre_id
|
||||||
|
Q3 — Triggered rules per mitre_id
|
||||||
|
Q4 — D3FEND mapping counts per technique
|
||||||
|
Q5 — All techniques
|
||||||
|
|
||||||
|
Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``.
|
||||||
|
"""
|
||||||
|
w_tests = settings.SCORING_WEIGHT_TESTS
|
||||||
|
w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
|
||||||
|
w_d3fend = settings.SCORING_WEIGHT_D3FEND
|
||||||
|
w_freshness = settings.SCORING_WEIGHT_FRESHNESS
|
||||||
|
w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
||||||
|
|
||||||
|
# Q1: test stats grouped by technique_id
|
||||||
|
test_rows = (
|
||||||
|
db.query(
|
||||||
|
Test.technique_id,
|
||||||
|
func.count(Test.id).label("validated_count"),
|
||||||
|
func.count(
|
||||||
|
case((Test.detection_result == TestResult.detected, Test.id))
|
||||||
|
).label("detected_count"),
|
||||||
|
func.max(Test.red_validated_at).label("latest_validated_at"),
|
||||||
|
func.count(func.distinct(Test.platform)).label("platform_count"),
|
||||||
|
)
|
||||||
|
.filter(Test.state == TestState.validated)
|
||||||
|
.group_by(Test.technique_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
test_stats: dict = {}
|
||||||
|
for row in test_rows:
|
||||||
|
test_stats[row.technique_id] = {
|
||||||
|
"validated": row.validated_count,
|
||||||
|
"detected": row.detected_count,
|
||||||
|
"latest_validated_at": row.latest_validated_at,
|
||||||
|
"platform_count": row.platform_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Q2: active detection rules per mitre_id
|
||||||
|
rule_rows = (
|
||||||
|
db.query(
|
||||||
|
DetectionRule.mitre_technique_id,
|
||||||
|
func.count(DetectionRule.id).label("total"),
|
||||||
|
)
|
||||||
|
.filter(DetectionRule.is_active == True) # noqa: E712
|
||||||
|
.group_by(DetectionRule.mitre_technique_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
rules_by_mitre: dict[str, int] = {r.mitre_technique_id: r.total for r in rule_rows}
|
||||||
|
|
||||||
|
# Q3: triggered rules per mitre_id
|
||||||
|
triggered_rows = (
|
||||||
|
db.query(
|
||||||
|
DetectionRule.mitre_technique_id,
|
||||||
|
func.count(TestDetectionResult.id).label("triggered"),
|
||||||
|
)
|
||||||
|
.join(DetectionRule, DetectionRule.id == TestDetectionResult.detection_rule_id)
|
||||||
|
.filter(TestDetectionResult.triggered == True) # noqa: E712
|
||||||
|
.group_by(DetectionRule.mitre_technique_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
triggered_by_mitre: dict[str, int] = {
|
||||||
|
r.mitre_technique_id: r.triggered for r in triggered_rows
|
||||||
|
}
|
||||||
|
|
||||||
|
# Q4: D3FEND mapping counts per technique
|
||||||
|
d3fend_rows = (
|
||||||
|
db.query(
|
||||||
|
DefensiveTechniqueMapping.attack_technique_id,
|
||||||
|
func.count(DefensiveTechniqueMapping.id).label("total"),
|
||||||
|
)
|
||||||
|
.group_by(DefensiveTechniqueMapping.attack_technique_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
d3fend_by_tech: dict = {r.attack_technique_id: r.total for r in d3fend_rows}
|
||||||
|
|
||||||
|
# Q5: all techniques
|
||||||
|
techniques = db.query(Technique).all()
|
||||||
|
|
||||||
|
now = datetime.utcnow()
|
||||||
|
results: dict = {}
|
||||||
|
|
||||||
|
for tech in techniques:
|
||||||
|
ts = test_stats.get(tech.id, {})
|
||||||
|
validated = ts.get("validated", 0)
|
||||||
|
detected = ts.get("detected", 0)
|
||||||
|
latest_at = ts.get("latest_validated_at")
|
||||||
|
plat_count = ts.get("platform_count", 0)
|
||||||
|
|
||||||
|
breakdown = {}
|
||||||
|
|
||||||
|
# 1. Tests validated with detection
|
||||||
|
if validated > 0:
|
||||||
|
test_ratio = detected / validated
|
||||||
|
test_score = round(test_ratio * w_tests, 1)
|
||||||
|
else:
|
||||||
|
test_ratio = 0
|
||||||
|
test_score = 0
|
||||||
|
breakdown["tests_validated"] = {
|
||||||
|
"score": test_score,
|
||||||
|
"max": w_tests,
|
||||||
|
"detail": (
|
||||||
|
f"{detected}/{validated} tests detected"
|
||||||
|
if validated else "No validated tests"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. Detection rules
|
||||||
|
total_rules = rules_by_mitre.get(tech.mitre_id, 0)
|
||||||
|
triggered_rules = triggered_by_mitre.get(tech.mitre_id, 0)
|
||||||
|
if total_rules > 0:
|
||||||
|
detection_ratio = min(triggered_rules / total_rules, 1.0)
|
||||||
|
detection_score = round(detection_ratio * w_detection, 1)
|
||||||
|
else:
|
||||||
|
detection_ratio = 0
|
||||||
|
detection_score = 0
|
||||||
|
breakdown["detection_rules"] = {
|
||||||
|
"score": detection_score,
|
||||||
|
"max": w_detection,
|
||||||
|
"detail": (
|
||||||
|
f"{triggered_rules}/{total_rules} rules triggered"
|
||||||
|
if total_rules > 0 else "No detection rules available"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 3. D3FEND coverage
|
||||||
|
total_cm = d3fend_by_tech.get(tech.id, 0)
|
||||||
|
if total_cm > 0 and detected > 0:
|
||||||
|
verified_cm = min(detected, total_cm)
|
||||||
|
d3fend_score = round((verified_cm / total_cm) * w_d3fend, 1)
|
||||||
|
else:
|
||||||
|
verified_cm = 0
|
||||||
|
d3fend_score = 0
|
||||||
|
breakdown["d3fend_coverage"] = {
|
||||||
|
"score": d3fend_score,
|
||||||
|
"max": w_d3fend,
|
||||||
|
"detail": (
|
||||||
|
f"{verified_cm}/{total_cm} countermeasures"
|
||||||
|
if total_cm > 0 else "No D3FEND mappings"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. Freshness
|
||||||
|
if latest_at:
|
||||||
|
days_ago = (now - latest_at).days
|
||||||
|
if days_ago < 90:
|
||||||
|
freshness_pct = 1.0
|
||||||
|
elif days_ago < 180:
|
||||||
|
freshness_pct = 0.5
|
||||||
|
else:
|
||||||
|
freshness_pct = 0.0
|
||||||
|
freshness_score = round(freshness_pct * w_freshness, 1)
|
||||||
|
freshness_detail = f"Last test {days_ago} days ago"
|
||||||
|
else:
|
||||||
|
freshness_score = 0
|
||||||
|
freshness_detail = "No validated tests"
|
||||||
|
breakdown["freshness"] = {
|
||||||
|
"score": freshness_score,
|
||||||
|
"max": w_freshness,
|
||||||
|
"detail": freshness_detail,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. Platform diversity
|
||||||
|
available = tech.platforms or []
|
||||||
|
total_platforms = len(available) if available else 3
|
||||||
|
if total_platforms > 0 and plat_count > 0:
|
||||||
|
diversity_score = round(
|
||||||
|
min(plat_count / total_platforms, 1.0) * w_diversity, 1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
diversity_score = 0
|
||||||
|
breakdown["platform_diversity"] = {
|
||||||
|
"score": diversity_score,
|
||||||
|
"max": w_diversity,
|
||||||
|
"detail": (
|
||||||
|
f"{plat_count}/{total_platforms} platforms covered"
|
||||||
|
if plat_count > 0 else "No platforms tested"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
total = min(
|
||||||
|
test_score + detection_score + d3fend_score
|
||||||
|
+ freshness_score + diversity_score,
|
||||||
|
100,
|
||||||
|
)
|
||||||
|
results[tech.id] = {
|
||||||
|
"total_score": round(total, 1),
|
||||||
|
"breakdown": breakdown,
|
||||||
|
"mitre_id": tech.mitre_id,
|
||||||
|
"tactic": tech.tactic,
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# ── Technique-level scoring (single technique — preserved API) ────────
|
||||||
|
|
||||||
|
|
||||||
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||||
@@ -73,7 +289,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
|||||||
db.query(func.count(DetectionRule.id))
|
db.query(func.count(DetectionRule.id))
|
||||||
.filter(
|
.filter(
|
||||||
DetectionRule.mitre_technique_id == technique.mitre_id,
|
DetectionRule.mitre_technique_id == technique.mitre_id,
|
||||||
DetectionRule.is_active == True,
|
DetectionRule.is_active == True, # noqa: E712
|
||||||
)
|
)
|
||||||
.scalar()
|
.scalar()
|
||||||
) or 0
|
) or 0
|
||||||
@@ -88,7 +304,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
|||||||
)
|
)
|
||||||
.filter(
|
.filter(
|
||||||
DetectionRule.mitre_technique_id == technique.mitre_id,
|
DetectionRule.mitre_technique_id == technique.mitre_id,
|
||||||
TestDetectionResult.triggered == True,
|
TestDetectionResult.triggered == True, # noqa: E712
|
||||||
)
|
)
|
||||||
.scalar()
|
.scalar()
|
||||||
) or 0
|
) or 0
|
||||||
@@ -114,11 +330,8 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
|||||||
.scalar()
|
.scalar()
|
||||||
) or 0
|
) or 0
|
||||||
|
|
||||||
# Consider a countermeasure "verified" if we have validated tests
|
|
||||||
# with detection for the technique (simplified heuristic)
|
|
||||||
verified_countermeasures = 0
|
verified_countermeasures = 0
|
||||||
if total_countermeasures > 0 and len(detected_tests) > 0:
|
if total_countermeasures > 0 and len(detected_tests) > 0:
|
||||||
# Rough heuristic: each detected test validates ~1 countermeasure
|
|
||||||
verified_countermeasures = min(len(detected_tests), total_countermeasures)
|
verified_countermeasures = min(len(detected_tests), total_countermeasures)
|
||||||
d3fend_ratio = verified_countermeasures / total_countermeasures
|
d3fend_ratio = verified_countermeasures / total_countermeasures
|
||||||
d3fend_score = round(d3fend_ratio * w_d3fend, 1)
|
d3fend_score = round(d3fend_ratio * w_d3fend, 1)
|
||||||
@@ -135,7 +348,6 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# ── 4. Freshness ──────────────────────────────────────────────
|
# ── 4. Freshness ──────────────────────────────────────────────
|
||||||
# Most recent validated test date
|
|
||||||
most_recent_test = (
|
most_recent_test = (
|
||||||
db.query(func.max(Test.red_validated_at))
|
db.query(func.max(Test.red_validated_at))
|
||||||
.filter(
|
.filter(
|
||||||
@@ -169,7 +381,7 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
|||||||
|
|
||||||
# ── 5. Platform diversity ─────────────────────────────────────
|
# ── 5. Platform diversity ─────────────────────────────────────
|
||||||
available_platforms = technique.platforms or []
|
available_platforms = technique.platforms or []
|
||||||
total_platforms = len(available_platforms) if available_platforms else 3 # default 3
|
total_platforms = len(available_platforms) if available_platforms else 3
|
||||||
|
|
||||||
tested_platforms = set()
|
tested_platforms = set()
|
||||||
for t in validated_tests:
|
for t in validated_tests:
|
||||||
@@ -208,30 +420,19 @@ def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
|||||||
|
|
||||||
def calculate_tactic_score(tactic: str, db: Session) -> dict:
|
def calculate_tactic_score(tactic: str, db: Session) -> dict:
|
||||||
"""Calculate average score for all techniques in a tactic."""
|
"""Calculate average score for all techniques in a tactic."""
|
||||||
techniques = (
|
scores_map = bulk_technique_scores(db)
|
||||||
db.query(Technique)
|
|
||||||
.filter(Technique.tactic.ilike(f"%{tactic}%"))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not techniques:
|
matching = [
|
||||||
return {
|
v["total_score"]
|
||||||
"tactic": tactic,
|
for v in scores_map.values()
|
||||||
"average_score": 0,
|
if v.get("tactic") and tactic.lower() in v["tactic"].lower()
|
||||||
"techniques_count": 0,
|
]
|
||||||
"techniques_scored": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
scores = []
|
|
||||||
for tech in techniques:
|
|
||||||
result = calculate_technique_score(tech, db)
|
|
||||||
scores.append(result["total_score"])
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"tactic": tactic,
|
"tactic": tactic,
|
||||||
"average_score": round(sum(scores) / len(scores), 1) if scores else 0,
|
"average_score": round(sum(matching) / len(matching), 1) if matching else 0,
|
||||||
"techniques_count": len(techniques),
|
"techniques_count": len(matching),
|
||||||
"techniques_scored": len([s for s in scores if s > 0]),
|
"techniques_scored": len([s for s in matching if s > 0]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -244,14 +445,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
|
|||||||
if not actor:
|
if not actor:
|
||||||
return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0}
|
return {"total_score": 0, "techniques_count": 0, "techniques_covered": 0}
|
||||||
|
|
||||||
# Get all techniques used by this actor
|
|
||||||
actor_techniques = (
|
actor_techniques = (
|
||||||
db.query(ThreatActorTechnique)
|
db.query(ThreatActorTechnique)
|
||||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
technique_ids = [at.technique_id for at in actor_techniques]
|
technique_ids = {at.technique_id for at in actor_techniques}
|
||||||
if not technique_ids:
|
if not technique_ids:
|
||||||
return {
|
return {
|
||||||
"actor_id": str(actor.id),
|
"actor_id": str(actor.id),
|
||||||
@@ -262,23 +462,21 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
|
|||||||
"techniques_detail": [],
|
"techniques_detail": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
techniques = (
|
scores_map = bulk_technique_scores(db)
|
||||||
db.query(Technique)
|
|
||||||
.filter(Technique.id.in_(technique_ids))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
scores = []
|
scores = []
|
||||||
details = []
|
details = []
|
||||||
for tech in techniques:
|
for tid in technique_ids:
|
||||||
result = calculate_technique_score(tech, db)
|
entry = scores_map.get(tid)
|
||||||
score = result["total_score"]
|
if not entry:
|
||||||
|
continue
|
||||||
|
score = entry["total_score"]
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
details.append({
|
details.append({
|
||||||
"mitre_id": tech.mitre_id,
|
"mitre_id": entry["mitre_id"],
|
||||||
"name": tech.name,
|
"name": entry.get("name", ""),
|
||||||
"score": score,
|
"score": score,
|
||||||
"breakdown": result["breakdown"],
|
"breakdown": entry["breakdown"],
|
||||||
})
|
})
|
||||||
|
|
||||||
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
|
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
|
||||||
@@ -287,7 +485,7 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
|
|||||||
"actor_id": str(actor.id),
|
"actor_id": str(actor.id),
|
||||||
"actor_name": actor.name,
|
"actor_name": actor.name,
|
||||||
"total_score": avg_score,
|
"total_score": avg_score,
|
||||||
"techniques_count": len(techniques),
|
"techniques_count": len(technique_ids),
|
||||||
"techniques_covered": len([s for s in scores if s > 50]),
|
"techniques_covered": len([s for s in scores if s > 50]),
|
||||||
"techniques_detail": details,
|
"techniques_detail": details,
|
||||||
}
|
}
|
||||||
@@ -297,10 +495,13 @@ def calculate_actor_coverage_score(actor_id: str, db: Session) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def calculate_organization_score(db: Session) -> dict:
|
def calculate_organization_score(db: Session) -> dict:
|
||||||
"""Calculate the overall organization security score."""
|
"""Calculate the overall organization security score.
|
||||||
# All techniques
|
|
||||||
all_techniques = db.query(Technique).all()
|
Uses ``bulk_technique_scores`` to compute all technique scores in
|
||||||
total_count = len(all_techniques)
|
5 aggregated queries instead of N*5.
|
||||||
|
"""
|
||||||
|
scores_map = bulk_technique_scores(db)
|
||||||
|
total_count = len(scores_map)
|
||||||
|
|
||||||
if total_count == 0:
|
if total_count == 0:
|
||||||
return {
|
return {
|
||||||
@@ -313,27 +514,16 @@ def calculate_organization_score(db: Session) -> dict:
|
|||||||
"techniques_total": 0,
|
"techniques_total": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Calculate scores for all techniques (with caching for performance)
|
all_scores = [v["total_score"] for v in scores_map.values()]
|
||||||
all_scores = []
|
|
||||||
evaluated_count = 0
|
|
||||||
|
|
||||||
for tech in all_techniques:
|
|
||||||
result = calculate_technique_score(tech, db)
|
|
||||||
score = result["total_score"]
|
|
||||||
all_scores.append(score)
|
|
||||||
if score > 0:
|
|
||||||
evaluated_count += 1
|
|
||||||
|
|
||||||
# Total coverage: average of all evaluated techniques
|
|
||||||
evaluated_scores = [s for s in all_scores if s > 0]
|
evaluated_scores = [s for s in all_scores if s > 0]
|
||||||
|
evaluated_count = len(evaluated_scores)
|
||||||
|
|
||||||
total_coverage = (
|
total_coverage = (
|
||||||
round(sum(evaluated_scores) / len(evaluated_scores), 1)
|
round(sum(evaluated_scores) / len(evaluated_scores), 1)
|
||||||
if evaluated_scores
|
if evaluated_scores else 0
|
||||||
else 0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Critical coverage: techniques with high-severity templates
|
# Critical coverage: techniques with high/critical severity templates
|
||||||
# (simplified: techniques that have tests are "critical")
|
|
||||||
from app.models.test_template import TestTemplate
|
from app.models.test_template import TestTemplate
|
||||||
|
|
||||||
critical_mitre_ids = set(
|
critical_mitre_ids = set(
|
||||||
@@ -344,38 +534,35 @@ def calculate_organization_score(db: Session) -> dict:
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
critical_techniques = [
|
critical_scores = [
|
||||||
t for t in all_techniques if t.mitre_id in critical_mitre_ids
|
v["total_score"]
|
||||||
|
for v in scores_map.values()
|
||||||
|
if v.get("mitre_id") in critical_mitre_ids
|
||||||
]
|
]
|
||||||
if critical_techniques:
|
critical_coverage = (
|
||||||
critical_scores = []
|
round(sum(critical_scores) / len(critical_scores), 1)
|
||||||
for tech in critical_techniques:
|
if critical_scores else 0
|
||||||
result = calculate_technique_score(tech, db)
|
)
|
||||||
critical_scores.append(result["total_score"])
|
|
||||||
critical_coverage = round(sum(critical_scores) / len(critical_scores), 1)
|
|
||||||
else:
|
|
||||||
critical_coverage = 0
|
|
||||||
|
|
||||||
# Detection maturity: based on detection rule coverage
|
# Detection maturity (2 scalar queries — already efficient)
|
||||||
total_rules = (
|
total_rules = (
|
||||||
db.query(func.count(DetectionRule.id))
|
db.query(func.count(DetectionRule.id))
|
||||||
.filter(DetectionRule.is_active == True)
|
.filter(DetectionRule.is_active == True) # noqa: E712
|
||||||
.scalar()
|
.scalar()
|
||||||
) or 0
|
) or 0
|
||||||
triggered_total = (
|
triggered_total = (
|
||||||
db.query(func.count(TestDetectionResult.id))
|
db.query(func.count(TestDetectionResult.id))
|
||||||
.filter(TestDetectionResult.triggered == True)
|
.filter(TestDetectionResult.triggered == True) # noqa: E712
|
||||||
.scalar()
|
.scalar()
|
||||||
) or 0
|
) or 0
|
||||||
|
|
||||||
detection_maturity = (
|
detection_maturity = (
|
||||||
round((triggered_total / total_rules) * 100, 1)
|
round((triggered_total / total_rules) * 100, 1)
|
||||||
if total_rules > 0
|
if total_rules > 0 else 0
|
||||||
else 0
|
|
||||||
)
|
)
|
||||||
detection_maturity = min(detection_maturity, 100)
|
detection_maturity = min(detection_maturity, 100)
|
||||||
|
|
||||||
# Response readiness: based on remediation completion
|
# Response readiness (2 scalar queries — already efficient)
|
||||||
remediation_total = (
|
remediation_total = (
|
||||||
db.query(func.count(Test.id))
|
db.query(func.count(Test.id))
|
||||||
.filter(Test.remediation_status.isnot(None))
|
.filter(Test.remediation_status.isnot(None))
|
||||||
@@ -389,11 +576,9 @@ def calculate_organization_score(db: Session) -> dict:
|
|||||||
|
|
||||||
response_readiness = (
|
response_readiness = (
|
||||||
round((remediation_completed / remediation_total) * 100, 1)
|
round((remediation_completed / remediation_total) * 100, 1)
|
||||||
if remediation_total > 0
|
if remediation_total > 0 else 0
|
||||||
else 0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Overall score: weighted average of sub-scores
|
|
||||||
overall = round(
|
overall = round(
|
||||||
total_coverage * 0.4
|
total_coverage * 0.4
|
||||||
+ critical_coverage * 0.25
|
+ critical_coverage * 0.25
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
|
|
||||||
Provides point-in-time coverage captures with normalised per-technique
|
Provides point-in-time coverage captures with normalised per-technique
|
||||||
storage and temporal comparison between any two snapshots.
|
storage and temporal comparison between any two snapshots.
|
||||||
|
|
||||||
|
Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
|
||||||
|
number of SQL queries regardless of technique count.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -14,7 +17,10 @@ from sqlalchemy.orm import Session
|
|||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||||
from app.models.enums import TechniqueStatus
|
from app.models.enums import TechniqueStatus
|
||||||
from app.services.scoring_service import calculate_technique_score, calculate_organization_score
|
from app.services.scoring_service import (
|
||||||
|
bulk_technique_scores,
|
||||||
|
calculate_organization_score,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -31,14 +37,16 @@ def create_snapshot(
|
|||||||
) -> CoverageSnapshot:
|
) -> CoverageSnapshot:
|
||||||
"""Capture the current coverage state into a new snapshot.
|
"""Capture the current coverage state into a new snapshot.
|
||||||
|
|
||||||
1. Fetch every technique with its status and score.
|
1. Bulk-fetch all technique scores in 5 aggregated queries.
|
||||||
2. Compute aggregate counts.
|
2. Walk the already-loaded techniques to count statuses.
|
||||||
3. Persist a ``CoverageSnapshot`` with normalised
|
3. Compute the org score from the same bulk data.
|
||||||
|
4. Persist a ``CoverageSnapshot`` with normalised
|
||||||
``SnapshotTechniqueState`` rows.
|
``SnapshotTechniqueState`` rows.
|
||||||
"""
|
"""
|
||||||
|
scores_map = bulk_technique_scores(db)
|
||||||
|
|
||||||
techniques = db.query(Technique).all()
|
techniques = db.query(Technique).all()
|
||||||
|
|
||||||
# Aggregate counters
|
|
||||||
validated_count = 0
|
validated_count = 0
|
||||||
partial_count = 0
|
partial_count = 0
|
||||||
not_covered_count = 0
|
not_covered_count = 0
|
||||||
@@ -54,7 +62,6 @@ def create_snapshot(
|
|||||||
else (tech.status_global or "not_evaluated")
|
else (tech.status_global or "not_evaluated")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Count by status
|
|
||||||
if status_value == "validated":
|
if status_value == "validated":
|
||||||
validated_count += 1
|
validated_count += 1
|
||||||
elif status_value == "partial":
|
elif status_value == "partial":
|
||||||
@@ -66,20 +73,17 @@ def create_snapshot(
|
|||||||
else:
|
else:
|
||||||
not_evaluated_count += 1
|
not_evaluated_count += 1
|
||||||
|
|
||||||
# Compute technique score
|
entry = scores_map.get(tech.id, {})
|
||||||
score_data = calculate_technique_score(tech, db)
|
|
||||||
technique_rows.append({
|
technique_rows.append({
|
||||||
"technique_id": tech.id,
|
"technique_id": tech.id,
|
||||||
"mitre_id": tech.mitre_id,
|
"mitre_id": tech.mitre_id,
|
||||||
"status": status_value,
|
"status": status_value,
|
||||||
"score": score_data["total_score"],
|
"score": entry.get("total_score", 0),
|
||||||
})
|
})
|
||||||
|
|
||||||
# Organization score
|
|
||||||
org_data = calculate_organization_score(db)
|
org_data = calculate_organization_score(db)
|
||||||
org_score = org_data.get("overall_score", 0)
|
org_score = org_data.get("overall_score", 0)
|
||||||
|
|
||||||
# Create the snapshot
|
|
||||||
snapshot = CoverageSnapshot(
|
snapshot = CoverageSnapshot(
|
||||||
name=name,
|
name=name,
|
||||||
organization_score=org_score,
|
organization_score=org_score,
|
||||||
@@ -92,9 +96,8 @@ def create_snapshot(
|
|||||||
created_by=user_id,
|
created_by=user_id,
|
||||||
)
|
)
|
||||||
db.add(snapshot)
|
db.add(snapshot)
|
||||||
db.flush() # get snapshot.id
|
db.flush()
|
||||||
|
|
||||||
# Create normalised technique state rows
|
|
||||||
for row in technique_rows:
|
for row in technique_rows:
|
||||||
state = SnapshotTechniqueState(
|
state = SnapshotTechniqueState(
|
||||||
snapshot_id=snapshot.id,
|
snapshot_id=snapshot.id,
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ based on the state and result of its associated tests.
|
|||||||
V2 rules account for dual Red/Blue validation and use
|
V2 rules account for dual Red/Blue validation and use
|
||||||
``detection_result`` (filled by Blue Team) instead of the legacy
|
``detection_result`` (filled by Blue Team) instead of the legacy
|
||||||
``result`` field.
|
``result`` field.
|
||||||
|
|
||||||
|
This function mutates the technique but does **not** commit.
|
||||||
|
The caller is responsible for committing the session.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -42,5 +45,3 @@ def recalculate_technique_status(db: Session, technique: Technique) -> None:
|
|||||||
technique.status_global = TechniqueStatus.partial
|
technique.status_global = TechniqueStatus.partial
|
||||||
else:
|
else:
|
||||||
technique.status_global = TechniqueStatus.in_progress
|
technique.status_global = TechniqueStatus.in_progress
|
||||||
|
|
||||||
db.commit()
|
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ for each step in the test lifecycle:
|
|||||||
↓
|
↓
|
||||||
rejected → draft
|
rejected → draft
|
||||||
|
|
||||||
Every public function validates the transition, mutates the test, writes an
|
Every public function validates the transition, mutates the test, and writes
|
||||||
audit-log entry, and commits the session.
|
an audit-log entry. The caller (router) is responsible for committing the
|
||||||
|
session via the Unit of Work pattern.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -122,7 +123,6 @@ def start_execution(db: Session, test: Test, user: User) -> Test:
|
|||||||
)
|
)
|
||||||
test.execution_date = now
|
test.execution_date = now
|
||||||
test.red_started_at = now
|
test.red_started_at = now
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -161,7 +161,6 @@ def submit_red_evidence(db: Session, test: Test, user: User) -> Test:
|
|||||||
# Start Blue Team timer
|
# Start Blue Team timer
|
||||||
test.blue_started_at = now
|
test.blue_started_at = now
|
||||||
test.blue_paused_seconds = 0
|
test.blue_paused_seconds = 0
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -196,7 +195,6 @@ def submit_blue_evidence(db: Session, test: Test, user: User) -> Test:
|
|||||||
description=f"Blue Team evaluation: {test.name}",
|
description=f"Blue Team evaluation: {test.name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -222,7 +220,6 @@ def pause_timer(db: Session, test: Test, user: User) -> Test:
|
|||||||
entity_id=test.id,
|
entity_id=test.id,
|
||||||
details={"state": test.state.value},
|
details={"state": test.state.value},
|
||||||
)
|
)
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -252,7 +249,6 @@ def resume_timer(db: Session, test: Test, user: User) -> Test:
|
|||||||
entity_id=test.id,
|
entity_id=test.id,
|
||||||
details={"paused_seconds": paused_seconds, "state": test.state.value},
|
details={"paused_seconds": paused_seconds, "state": test.state.value},
|
||||||
)
|
)
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -421,14 +417,12 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
|||||||
|
|
||||||
if red_status == "rejected" or blue_status == "rejected":
|
if red_status == "rejected" or blue_status == "rejected":
|
||||||
test.state = TestState.rejected
|
test.state = TestState.rejected
|
||||||
db.commit()
|
|
||||||
try:
|
try:
|
||||||
notify_test_state_change(db, test, "rejected")
|
notify_test_state_change(db, test, "rejected")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True)
|
logger.warning("Notification failed for test %s (rejected): %s", test.id, e, exc_info=True)
|
||||||
elif red_status == "approved" and blue_status == "approved":
|
elif red_status == "approved" and blue_status == "approved":
|
||||||
test.state = TestState.validated
|
test.state = TestState.validated
|
||||||
db.commit()
|
|
||||||
# Invalidate cached scores — a validation changes org-level numbers
|
# Invalidate cached scores — a validation changes org-level numbers
|
||||||
try:
|
try:
|
||||||
from app.services.score_cache import invalidate
|
from app.services.score_cache import invalidate
|
||||||
@@ -439,10 +433,6 @@ def check_dual_validation(db: Session, test: Test) -> Test:
|
|||||||
notify_test_state_change(db, test, "validated")
|
notify_test_state_change(db, test, "validated")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True)
|
logger.warning("Notification failed for test %s (validated): %s", test.id, e, exc_info=True)
|
||||||
else:
|
|
||||||
# One side hasn't voted yet — stay in_review, just flush
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return test
|
return test
|
||||||
|
|
||||||
|
|
||||||
@@ -533,8 +523,7 @@ def handle_remediation_completed(db: Session, test: Test, user: User) -> Test |
|
|||||||
entity_id=retest.id,
|
entity_id=retest.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.commit()
|
db.flush()
|
||||||
db.refresh(retest)
|
|
||||||
return retest
|
return retest
|
||||||
|
|
||||||
|
|
||||||
@@ -599,5 +588,4 @@ def reopen_test(db: Session, test: Test, user: User) -> Test:
|
|||||||
test.red_paused_seconds = 0
|
test.red_paused_seconds = 0
|
||||||
test.blue_paused_seconds = 0
|
test.blue_paused_seconds = 0
|
||||||
|
|
||||||
db.commit()
|
|
||||||
return test
|
return test
|
||||||
|
|||||||
Reference in New Issue
Block a user