feat: extract advanced_metrics, analytics, test_templates, and auth to services (Tier 1 complete)

This commit is contained in:
2026-02-20 14:28:52 +01:00
parent bbc2dddd86
commit 9e22fde746
8 changed files with 579 additions and 422 deletions

View File

@@ -0,0 +1,160 @@
"""Advanced metrics service — coverage by tactic, never-tested, avg validation time, detection trend."""
from __future__ import annotations
from datetime import datetime, timedelta
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from app.models.technique import Technique
from app.models.test import Test
from app.models.enums import TestResult
def get_coverage_by_tactic(db: Session) -> list[dict]:
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
results = (
db.query(
Technique.tactic,
func.count(Technique.id).label("total"),
func.sum(
case((Technique.status_global == "validated", 1), else_=0)
).label("validated"),
func.sum(
case((Technique.status_global == "partial", 1), else_=0)
).label("partial"),
func.sum(
case((Technique.status_global == "not_covered", 1), else_=0)
).label("not_covered"),
func.sum(
case((Technique.status_global == "in_progress", 1), else_=0)
).label("in_progress"),
)
.group_by(Technique.tactic)
.order_by(Technique.tactic)
.all()
)
return [
{
"tactic": r[0] or "Unknown",
"total": r[1],
"validated": int(r[2]),
"partial": int(r[3]),
"not_covered": int(r[4]),
"in_progress": int(r[5]),
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
}
for r in results
]
def get_never_tested_techniques(db: Session) -> list[dict]:
"""Techniques that have never had a test created."""
tested_technique_ids = db.query(Test.technique_id).distinct().subquery()
techniques = (
db.query(Technique)
.filter(~Technique.id.in_(db.query(tested_technique_ids)))
.order_by(Technique.mitre_id)
.all()
)
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"is_subtechnique": t.is_subtechnique,
}
for t in techniques
]
def get_avg_validation_time(db: Session) -> dict:
"""Average time from test creation to validation, computed from validated tests.
Returns overall average and per-phase averages where data is available.
"""
validated_tests = (
db.query(Test)
.filter(Test.state == "validated")
.all()
)
if not validated_tests:
return {
"total_validated": 0,
"avg_total_hours": 0,
"avg_red_phase_hours": 0,
"avg_blue_phase_hours": 0,
}
total_durations = []
red_durations = []
blue_durations = []
for test in validated_tests:
if test.created_at and test.red_validated_at:
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
total_durations.append(total_seconds)
if test.red_started_at and test.blue_started_at:
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
red_paused = test.red_paused_seconds or 0
red_durations.append(max(red_sec - red_paused, 0))
if test.blue_started_at and test.blue_validated_at:
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
blue_paused = test.blue_paused_seconds or 0
blue_durations.append(max(blue_sec - blue_paused, 0))
def avg_hours(durations: list[float]) -> float:
if not durations:
return 0
return round(sum(durations) / len(durations) / 3600, 2)
return {
"total_validated": len(validated_tests),
"avg_total_hours": avg_hours(total_durations),
"avg_red_phase_hours": avg_hours(red_durations),
"avg_blue_phase_hours": avg_hours(blue_durations),
}
def get_detection_rate_trend(db: Session) -> list[dict]:
"""Monthly detection rate trend for the last 12 months."""
now = datetime.utcnow()
months = []
for i in range(11, -1, -1):
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
month_end = month_start + timedelta(days=30)
validated = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
detected = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.detection_result == TestResult.detected,
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
months.append({
"month": month_start.strftime("%Y-%m"),
"validated": validated,
"detected": detected,
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
})
return months

View File

@@ -0,0 +1,107 @@
"""Analytics service — flat JSON optimized for PowerBI / BI tools."""
from __future__ import annotations
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.models.coverage_snapshot import CoverageSnapshot
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User
def get_coverage_analytics(db: Session) -> list[dict]:
"""Coverage per technique — flat format for BI dashboards."""
techniques = db.query(Technique).all()
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"status": t.status_global.value if t.status_global else "not_evaluated",
"is_subtechnique": t.is_subtechnique,
"test_count": len(t.tests) if t.tests else 0,
"review_required": t.review_required,
"last_review_date": (
t.last_review_date.isoformat() if t.last_review_date else None
),
}
for t in techniques
]
def get_tests_analytics(
db: Session,
*,
date_from: str | None = None,
date_to: str | None = None,
) -> list[dict]:
"""All tests with timestamps — flat format for BI dashboards."""
query = db.query(Test)
if date_from:
query = query.filter(Test.created_at >= date_from)
if date_to:
query = query.filter(Test.created_at <= date_to)
tests = query.all()
return [
{
"id": str(t.id),
"technique_id": str(t.technique_id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"detection_result": (
t.detection_result.value if t.detection_result else None
),
"created_at": t.created_at.isoformat() if t.created_at else None,
"execution_date": (
t.execution_date.isoformat() if t.execution_date else None
),
"platform": t.platform,
"tool_used": t.tool_used,
"attack_success": t.attack_success,
"remediation_status": t.remediation_status,
}
for t in tests
]
def get_trends_analytics(db: Session) -> list[dict]:
"""Historical coverage snapshots for trend visualization."""
snapshots = (
db.query(CoverageSnapshot)
.order_by(CoverageSnapshot.created_at)
.all()
)
return [
{
"date": s.created_at.isoformat() if s.created_at else None,
"name": s.name,
"total_techniques": s.total_techniques,
"validated_count": s.validated_count,
"partial_count": s.partial_count,
"not_covered_count": s.not_covered_count,
"organization_score": s.organization_score,
}
for s in snapshots
]
def get_operators_analytics(db: Session) -> list[dict]:
"""Per-operator metrics — for workload management dashboards."""
results = (
db.query(
User.username,
User.role,
func.count(Test.id).label("test_count"),
)
.outerjoin(Test, Test.created_by == User.id)
.group_by(User.id, User.username, User.role)
.all()
)
return [
{"username": r[0], "role": r[1], "test_count": r[2]}
for r in results
]

View File

@@ -0,0 +1,45 @@
"""Authentication service — credential validation and password management."""
from __future__ import annotations
from sqlalchemy.orm import Session
from app.auth import hash_password, verify_password
from app.domain.errors import BusinessRuleViolation, PermissionViolation
from app.models.user import User
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
def authenticate_user(db: Session, *, username: str, password: str) -> User:
"""Validate credentials and return the User.
Raises BusinessRuleViolation for invalid credentials.
Raises PermissionViolation for disabled account.
Uses constant-time comparison to prevent timing attacks.
"""
user = db.query(User).filter(User.username == username).first()
hashed = user.hashed_password if user else _DUMMY_HASH
password_valid = verify_password(password, hashed)
if user is None or not password_valid:
raise BusinessRuleViolation("Incorrect username or password")
if not user.is_active:
raise PermissionViolation("Account is disabled. Contact an administrator.")
return user
def change_password(
db: Session,
user: User,
*,
current_password: str,
new_password: str,
) -> None:
"""Change a user's password. Does NOT commit.
Raises BusinessRuleViolation if current password is wrong.
"""
if not verify_password(current_password, user.hashed_password):
raise BusinessRuleViolation("Current password is incorrect")
user.hashed_password = hash_password(new_password)
user.must_change_password = False

View File

@@ -0,0 +1,150 @@
"""Test template service — framework-agnostic CRUD and queries."""
from __future__ import annotations
import uuid
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.test_template import TestTemplate
from app.utils import escape_like
def list_templates(
db: Session,
*,
source: str | None = None,
platform: str | None = None,
severity: str | None = None,
mitre_technique_id: str | None = None,
search: str | None = None,
is_active: bool | None = None,
offset: int = 0,
limit: int = 50,
) -> list:
"""Return paginated, filterable list of test templates."""
query = db.query(TestTemplate)
if is_active is not None:
query = query.filter(TestTemplate.is_active == is_active)
if source:
query = query.filter(TestTemplate.source == source)
if platform:
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
if severity:
query = query.filter(TestTemplate.severity == severity)
if mitre_technique_id:
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
TestTemplate.name.ilike(pattern),
TestTemplate.description.ilike(pattern),
)
)
templates = (
query
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
.offset(offset)
.limit(limit)
.all()
)
return templates
def get_template_stats(db: Session) -> dict:
"""Return catalog statistics: totals by source, platform, active/inactive."""
total = db.query(func.count(TestTemplate.id)).scalar() or 0
active = (
db.query(func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.scalar()
) or 0
inactive = total - active
source_rows = (
db.query(TestTemplate.source, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.source)
.all()
)
by_source = {source: cnt for source, cnt in source_rows}
platform_rows = (
db.query(TestTemplate.platform, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.platform)
.all()
)
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
return {
"total": total,
"active": active,
"inactive": inactive,
"by_source": by_source,
"by_platform": by_platform,
}
def bulk_activate(db: Session, *, activate: bool) -> int:
"""Set all templates to active or inactive. Returns count of affected. Does NOT commit."""
count = (
db.query(TestTemplate)
.filter(TestTemplate.is_active != activate)
.update({TestTemplate.is_active: activate})
)
return count
def get_templates_by_technique(db: Session, mitre_id: str) -> list:
"""Return all active templates mapped to a specific MITRE technique."""
return (
db.query(TestTemplate)
.filter(
TestTemplate.mitre_technique_id == mitre_id,
TestTemplate.is_active == True, # noqa: E712
)
.order_by(TestTemplate.name)
.all()
)
def get_template_or_raise(db: Session, template_id: uuid.UUID) -> TestTemplate:
"""Return a template by ID. Raises EntityNotFoundError if not found."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise EntityNotFoundError("Test template", str(template_id))
return template
def create_template(db: Session, **fields: object) -> TestTemplate:
"""Create a test template from keyword args (e.g. payload.model_dump()). Does NOT commit."""
template = TestTemplate(**fields)
db.add(template)
return template
def update_template(db: Session, template_id: uuid.UUID, **fields: object) -> TestTemplate:
"""Update an existing template. Raises EntityNotFoundError if not found. Does NOT commit."""
template = get_template_or_raise(db, template_id)
for field, value in fields.items():
if hasattr(template, field):
setattr(template, field, value)
return template
def toggle_template_active(db: Session, template_id: uuid.UUID) -> TestTemplate:
"""Toggle template active/inactive. Does NOT commit."""
template = get_template_or_raise(db, template_id)
template.is_active = not template.is_active
return template
def soft_delete_template(db: Session, template_id: uuid.UUID) -> None:
"""Soft-delete a template by setting is_active=False. Does NOT commit."""
template = get_template_or_raise(db, template_id)
template.is_active = False