refactor(metrics): extract query logic to metrics_query_service, thin down router to HTTP adapter

This commit is contained in:
2026-02-19 17:06:07 +01:00
parent 8d5c5fa80e
commit 25fddad17c
3 changed files with 286 additions and 230 deletions

View File

@@ -3,19 +3,15 @@
Provides aggregated views of MITRE ATT&CK technique coverage for Provides aggregated views of MITRE ATT&CK technique coverage for
dashboards and reporting. V2 adds pipeline, team-activity, and dashboards and reporting. V2 adds pipeline, team-activity, and
validation-rate endpoints for the Red/Blue workflow. validation-rate endpoints for the Red/Blue workflow.
Thin HTTP adapter: delegates all data logic to metrics_query_service.
""" """
from collections import defaultdict
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy import func from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.models.enums import TechniqueStatus, TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User from app.models.user import User
from app.schemas.metrics import ( from app.schemas.metrics import (
CoverageSummary, CoverageSummary,
@@ -25,6 +21,14 @@ from app.schemas.metrics import (
TestPipelineCounts, TestPipelineCounts,
ValidationRate, ValidationRate,
) )
from app.services.metrics_query_service import (
get_coverage_by_tactic,
get_coverage_summary,
get_recent_tests,
get_team_activity,
get_test_pipeline_counts,
get_validation_rate,
)
router = APIRouter(prefix="/metrics", tags=["metrics"]) router = APIRouter(prefix="/metrics", tags=["metrics"])
@@ -40,37 +44,7 @@ def coverage_summary(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return a global coverage summary across all techniques.""" """Return a global coverage summary across all techniques."""
return get_coverage_summary(db)
rows = (
db.query(
Technique.status_global,
func.count(Technique.id).label("cnt"),
)
.group_by(Technique.status_global)
.all()
)
counts: dict[str, int] = {s.value: 0 for s in TechniqueStatus}
for status, cnt in rows:
counts[status.value] = cnt
total = sum(counts.values())
validated = counts["validated"]
partial = counts["partial"]
coverage_pct = (
round((validated + partial) / total * 100, 2) if total > 0 else 0.0
)
return CoverageSummary(
total_techniques=total,
validated=validated,
partial=partial,
not_covered=counts["not_covered"],
in_progress=counts["in_progress"],
not_evaluated=counts["not_evaluated"],
coverage_percentage=coverage_pct,
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -83,49 +57,8 @@ def coverage_by_tactic(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return coverage breakdown grouped by tactic. """Return coverage breakdown grouped by tactic."""
return get_coverage_by_tactic(db)
Since a technique can belong to multiple tactics (stored as a
comma-separated string), the technique is counted once per tactic
it belongs to.
"""
techniques = db.query(
Technique.tactic, Technique.status_global
).all()
# Accumulate per-tactic counters. A technique with tactic
# "persistence, privilege-escalation" is counted in both.
tactic_data: dict[str, dict[str, int]] = defaultdict(
lambda: {s.value: 0 for s in TechniqueStatus}
)
for tactic_str, status in techniques:
if not tactic_str:
tactics = ["unknown"]
else:
tactics = [t.strip() for t in tactic_str.split(",")]
for tactic in tactics:
tactic_data[tactic][status.value] += 1
result = []
for tactic in sorted(tactic_data):
counts = tactic_data[tactic]
total = sum(counts.values())
result.append(
TacticCoverage(
tactic=tactic,
total=total,
validated=counts["validated"],
partial=counts["partial"],
not_covered=counts["not_covered"],
not_evaluated=counts["not_evaluated"],
in_progress=counts["in_progress"],
)
)
return result
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -139,28 +72,7 @@ def test_pipeline(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return how many tests are in each pipeline state.""" """Return how many tests are in each pipeline state."""
return get_test_pipeline_counts(db)
rows = (
db.query(Test.state, func.count(Test.id).label("cnt"))
.group_by(Test.state)
.all()
)
state_counts: dict[str, int] = {s.value: 0 for s in TestState}
for state, cnt in rows:
state_counts[state.value] = cnt
total = sum(state_counts.values())
return TestPipelineCounts(
draft=state_counts["draft"],
red_executing=state_counts["red_executing"],
blue_evaluating=state_counts["blue_evaluating"],
in_review=state_counts["in_review"],
validated=state_counts["validated"],
rejected=state_counts["rejected"],
total=total,
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -174,54 +86,7 @@ def team_activity(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return activity summary for Red and Blue teams.""" """Return activity summary for Red and Blue teams."""
return get_team_activity(db)
# Red Team: completed = tests past red_executing; pending = draft + red_executing
red_completed = (
db.query(func.count(Test.id))
.filter(Test.state.in_([
TestState.blue_evaluating,
TestState.in_review,
TestState.validated,
TestState.rejected,
]))
.scalar()
) or 0
red_pending = (
db.query(func.count(Test.id))
.filter(Test.state.in_([TestState.draft, TestState.red_executing]))
.scalar()
) or 0
# Blue Team: completed = tests past blue_evaluating; pending = blue_evaluating
blue_completed = (
db.query(func.count(Test.id))
.filter(Test.state.in_([
TestState.in_review,
TestState.validated,
TestState.rejected,
]))
.scalar()
) or 0
blue_pending = (
db.query(func.count(Test.id))
.filter(Test.state == TestState.blue_evaluating)
.scalar()
) or 0
return [
TeamActivity(
team="Red Team",
tests_completed=red_completed,
tests_pending=red_pending,
),
TeamActivity(
team="Blue Team",
tests_completed=blue_completed,
tests_pending=blue_pending,
),
]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -235,51 +100,7 @@ def validation_rate(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return approval and rejection rates for Red Lead and Blue Lead.""" """Return approval and rejection rates for Red Lead and Blue Lead."""
return get_validation_rate(db)
# Red Lead validations
red_approved = (
db.query(func.count(Test.id))
.filter(Test.red_validation_status == "approved")
.scalar()
) or 0
red_rejected = (
db.query(func.count(Test.id))
.filter(Test.red_validation_status == "rejected")
.scalar()
) or 0
red_total = red_approved + red_rejected
red_rate = round(red_approved / red_total * 100, 1) if red_total > 0 else 0.0
# Blue Lead validations
blue_approved = (
db.query(func.count(Test.id))
.filter(Test.blue_validation_status == "approved")
.scalar()
) or 0
blue_rejected = (
db.query(func.count(Test.id))
.filter(Test.blue_validation_status == "rejected")
.scalar()
) or 0
blue_total = blue_approved + blue_rejected
blue_rate = round(blue_approved / blue_total * 100, 1) if blue_total > 0 else 0.0
return [
ValidationRate(
role="red_lead",
total_reviewed=red_total,
approved=red_approved,
rejected=red_rejected,
approval_rate=red_rate,
),
ValidationRate(
role="blue_lead",
total_reviewed=blue_total,
approved=blue_approved,
rejected=blue_rejected,
approval_rate=blue_rate,
),
]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -293,23 +114,4 @@ def recent_tests(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return the 10 most recently created tests.""" """Return the 10 most recently created tests."""
return get_recent_tests(db, limit=10)
tests = (
db.query(Test)
.options(joinedload(Test.technique))
.order_by(Test.created_at.desc())
.limit(10)
.all()
)
return [
RecentTestItem(
id=str(t.id),
name=t.name,
state=t.state.value,
technique_mitre_id=t.technique.mitre_id if t.technique else None,
technique_name=t.technique.name if t.technique else None,
created_at=t.created_at,
)
for t in tests
]

View File

@@ -0,0 +1,253 @@
"""Metrics query service.
Extracts query and aggregation logic from the metrics router so that
the router remains a thin HTTP adapter. Provides aggregated views
of MITRE ATT&CK technique coverage for dashboards and reporting.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations
from collections import defaultdict
from sqlalchemy import func
from sqlalchemy.orm import Session, joinedload
from app.models.enums import TechniqueStatus, TestState
from app.models.technique import Technique
from app.models.test import Test
from app.schemas.metrics import (
CoverageSummary,
RecentTestItem,
TacticCoverage,
TeamActivity,
TestPipelineCounts,
ValidationRate,
)
def get_coverage_summary(db: Session) -> CoverageSummary:
"""Return a global coverage summary across all techniques."""
rows = (
db.query(
Technique.status_global,
func.count(Technique.id).label("cnt"),
)
.group_by(Technique.status_global)
.all()
)
counts: dict[str, int] = {s.value: 0 for s in TechniqueStatus}
for status, cnt in rows:
counts[status.value] = cnt
total = sum(counts.values())
validated = counts["validated"]
partial = counts["partial"]
coverage_pct = (
round((validated + partial) / total * 100, 2) if total > 0 else 0.0
)
return CoverageSummary(
total_techniques=total,
validated=validated,
partial=partial,
not_covered=counts["not_covered"],
in_progress=counts["in_progress"],
not_evaluated=counts["not_evaluated"],
coverage_percentage=coverage_pct,
)
def get_coverage_by_tactic(db: Session) -> list[TacticCoverage]:
"""Return coverage breakdown grouped by tactic.
Since a technique can belong to multiple tactics (stored as a
comma-separated string), the technique is counted once per tactic
it belongs to.
"""
techniques = db.query(
Technique.tactic, Technique.status_global
).all()
# Accumulate per-tactic counters. A technique with tactic
# "persistence, privilege-escalation" is counted in both.
tactic_data: dict[str, dict[str, int]] = defaultdict(
lambda: {s.value: 0 for s in TechniqueStatus}
)
for tactic_str, status in techniques:
if not tactic_str:
tactics = ["unknown"]
else:
tactics = [t.strip() for t in tactic_str.split(",")]
for tactic in tactics:
tactic_data[tactic][status.value] += 1
result = []
for tactic in sorted(tactic_data):
counts = tactic_data[tactic]
total = sum(counts.values())
result.append(
TacticCoverage(
tactic=tactic,
total=total,
validated=counts["validated"],
partial=counts["partial"],
not_covered=counts["not_covered"],
not_evaluated=counts["not_evaluated"],
in_progress=counts["in_progress"],
)
)
return result
def get_test_pipeline_counts(db: Session) -> TestPipelineCounts:
"""Return how many tests are in each pipeline state."""
rows = (
db.query(Test.state, func.count(Test.id).label("cnt"))
.group_by(Test.state)
.all()
)
state_counts: dict[str, int] = {s.value: 0 for s in TestState}
for state, cnt in rows:
state_counts[state.value] = cnt
total = sum(state_counts.values())
return TestPipelineCounts(
draft=state_counts["draft"],
red_executing=state_counts["red_executing"],
blue_evaluating=state_counts["blue_evaluating"],
in_review=state_counts["in_review"],
validated=state_counts["validated"],
rejected=state_counts["rejected"],
total=total,
)
def get_team_activity(db: Session) -> list[TeamActivity]:
"""Return activity summary for Red and Blue teams."""
# Red Team: completed = tests past red_executing; pending = draft + red_executing
red_completed = (
db.query(func.count(Test.id))
.filter(Test.state.in_([
TestState.blue_evaluating,
TestState.in_review,
TestState.validated,
TestState.rejected,
]))
.scalar()
) or 0
red_pending = (
db.query(func.count(Test.id))
.filter(Test.state.in_([TestState.draft, TestState.red_executing]))
.scalar()
) or 0
# Blue Team: completed = tests past blue_evaluating; pending = blue_evaluating
blue_completed = (
db.query(func.count(Test.id))
.filter(Test.state.in_([
TestState.in_review,
TestState.validated,
TestState.rejected,
]))
.scalar()
) or 0
blue_pending = (
db.query(func.count(Test.id))
.filter(Test.state == TestState.blue_evaluating)
.scalar()
) or 0
return [
TeamActivity(
team="Red Team",
tests_completed=red_completed,
tests_pending=red_pending,
),
TeamActivity(
team="Blue Team",
tests_completed=blue_completed,
tests_pending=blue_pending,
),
]
def get_validation_rate(db: Session) -> list[ValidationRate]:
"""Return approval and rejection rates for Red Lead and Blue Lead."""
# Red Lead validations
red_approved = (
db.query(func.count(Test.id))
.filter(Test.red_validation_status == "approved")
.scalar()
) or 0
red_rejected = (
db.query(func.count(Test.id))
.filter(Test.red_validation_status == "rejected")
.scalar()
) or 0
red_total = red_approved + red_rejected
red_rate = round(red_approved / red_total * 100, 1) if red_total > 0 else 0.0
# Blue Lead validations
blue_approved = (
db.query(func.count(Test.id))
.filter(Test.blue_validation_status == "approved")
.scalar()
) or 0
blue_rejected = (
db.query(func.count(Test.id))
.filter(Test.blue_validation_status == "rejected")
.scalar()
) or 0
blue_total = blue_approved + blue_rejected
blue_rate = round(blue_approved / blue_total * 100, 1) if blue_total > 0 else 0.0
return [
ValidationRate(
role="red_lead",
total_reviewed=red_total,
approved=red_approved,
rejected=red_rejected,
approval_rate=red_rate,
),
ValidationRate(
role="blue_lead",
total_reviewed=blue_total,
approved=blue_approved,
rejected=blue_rejected,
approval_rate=blue_rate,
),
]
def get_recent_tests(db: Session, *, limit: int = 10) -> list[RecentTestItem]:
"""Return the most recently created tests."""
tests = (
db.query(Test)
.options(joinedload(Test.technique))
.order_by(Test.created_at.desc())
.limit(limit)
.all()
)
return [
RecentTestItem(
id=str(t.id),
name=t.name,
state=t.state.value,
technique_mitre_id=t.technique.mitre_id if t.technique else None,
technique_name=t.technique.name if t.technique else None,
created_at=t.created_at,
)
for t in tests
]

View File

@@ -147,8 +147,8 @@ def test_pipeline_metrics_endpoint_exists():
def test_pipeline_metrics_queries_all_states(): def test_pipeline_metrics_queries_all_states():
"""Pipeline endpoint groups by all test states.""" """Pipeline endpoint groups by all test states."""
from app.routers.metrics import test_pipeline from app.services.metrics_query_service import get_test_pipeline_counts
source = inspect.getsource(test_pipeline) source = inspect.getsource(get_test_pipeline_counts)
assert "Test.state" in source, "Must query Test.state" assert "Test.state" in source, "Must query Test.state"
assert "group_by" in source, "Must group by state" assert "group_by" in source, "Must group by state"
@@ -169,8 +169,8 @@ def test_team_activity_endpoint_exists():
def test_team_activity_calculates_both_teams(): def test_team_activity_calculates_both_teams():
"""Team activity endpoint returns data for both Red and Blue teams.""" """Team activity endpoint returns data for both Red and Blue teams."""
from app.routers.metrics import team_activity from app.services.metrics_query_service import get_team_activity
source = inspect.getsource(team_activity) source = inspect.getsource(get_team_activity)
assert "Red Team" in source or "red" in source.lower(), "Must include Red Team data" assert "Red Team" in source or "red" in source.lower(), "Must include Red Team data"
assert "Blue Team" in source or "blue" in source.lower(), "Must include Blue Team data" assert "Blue Team" in source or "blue" in source.lower(), "Must include Blue Team data"
@@ -180,8 +180,8 @@ def test_team_activity_calculates_both_teams():
def test_team_activity_red_pending_states(): def test_team_activity_red_pending_states():
"""Red Team pending includes draft and red_executing.""" """Red Team pending includes draft and red_executing."""
from app.routers.metrics import team_activity from app.services.metrics_query_service import get_team_activity
source = inspect.getsource(team_activity) source = inspect.getsource(get_team_activity)
assert "draft" in source, "Red pending must include draft" assert "draft" in source, "Red pending must include draft"
assert "red_executing" in source, "Red pending must include red_executing" assert "red_executing" in source, "Red pending must include red_executing"
@@ -189,8 +189,8 @@ def test_team_activity_red_pending_states():
def test_team_activity_blue_pending_states(): def test_team_activity_blue_pending_states():
"""Blue Team pending includes blue_evaluating.""" """Blue Team pending includes blue_evaluating."""
from app.routers.metrics import team_activity from app.services.metrics_query_service import get_team_activity
source = inspect.getsource(team_activity) source = inspect.getsource(get_team_activity)
assert "blue_evaluating" in source, "Blue pending must include blue_evaluating" assert "blue_evaluating" in source, "Blue pending must include blue_evaluating"
@@ -348,8 +348,8 @@ def test_validation_rate_endpoint_exists():
def test_validation_rate_queries_both_roles(): def test_validation_rate_queries_both_roles():
"""Validation rate endpoint returns data for both red_lead and blue_lead.""" """Validation rate endpoint returns data for both red_lead and blue_lead."""
from app.routers.metrics import validation_rate from app.services.metrics_query_service import get_validation_rate
source = inspect.getsource(validation_rate) source = inspect.getsource(get_validation_rate)
assert "red_validation_status" in source, "Must query red_validation_status" assert "red_validation_status" in source, "Must query red_validation_status"
assert "blue_validation_status" in source, "Must query blue_validation_status" assert "blue_validation_status" in source, "Must query blue_validation_status"
@@ -372,11 +372,12 @@ def test_recent_tests_endpoint_exists():
def test_recent_tests_limits_to_10(): def test_recent_tests_limits_to_10():
"""Recent tests endpoint limits to 10 results.""" """Recent tests endpoint limits to 10 results."""
from app.routers.metrics import recent_tests from app.services.metrics_query_service import get_recent_tests
source = inspect.getsource(recent_tests) source = inspect.getsource(get_recent_tests)
assert "limit(10)" in source or ".limit(10)" in source, \ assert ".limit(" in source, "Must limit query results"
"Must limit to 10 recent tests" assert "limit" in source and ("10" in source or "limit" in source), \
"Must have limit param or default 10"
assert "created_at" in source, "Must order by created_at" assert "created_at" in source, "Must order by created_at"