feat: extract advanced_metrics, analytics, test_templates, and auth to services (Tier 1 complete)

This commit is contained in:
2026-02-20 14:28:52 +01:00
parent bbc2dddd86
commit 9e22fde746
8 changed files with 579 additions and 422 deletions

View File

@@ -1,17 +1,12 @@
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time.""" """Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
from datetime import datetime
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy import func, case
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.models.audit import AuditLog
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User from app.models.user import User
from app.services import advanced_metrics_service
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"]) router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
@@ -22,39 +17,7 @@ def coverage_by_tactic(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""Coverage percentage broken down by MITRE ATT&CK tactic.""" """Coverage percentage broken down by MITRE ATT&CK tactic."""
results = ( return advanced_metrics_service.get_coverage_by_tactic(db)
db.query(
Technique.tactic,
func.count(Technique.id).label("total"),
func.sum(
case((Technique.status_global == "validated", 1), else_=0)
).label("validated"),
func.sum(
case((Technique.status_global == "partial", 1), else_=0)
).label("partial"),
func.sum(
case((Technique.status_global == "not_covered", 1), else_=0)
).label("not_covered"),
func.sum(
case((Technique.status_global == "in_progress", 1), else_=0)
).label("in_progress"),
)
.group_by(Technique.tactic)
.order_by(Technique.tactic)
.all()
)
return [
{
"tactic": r[0] or "Unknown",
"total": r[1],
"validated": int(r[2]),
"partial": int(r[3]),
"not_covered": int(r[4]),
"in_progress": int(r[5]),
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
}
for r in results
]
@router.get("/never-tested") @router.get("/never-tested")
@@ -63,24 +26,7 @@ def never_tested_techniques(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""Techniques that have never had a test created.""" """Techniques that have never had a test created."""
tested_technique_ids = ( return advanced_metrics_service.get_never_tested_techniques(db)
db.query(Test.technique_id).distinct().subquery()
)
techniques = (
db.query(Technique)
.filter(~Technique.id.in_(db.query(tested_technique_ids)))
.order_by(Technique.mitre_id)
.all()
)
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"is_subtechnique": t.is_subtechnique,
}
for t in techniques
]
@router.get("/avg-validation-time") @router.get("/avg-validation-time")
@@ -92,50 +38,7 @@ def avg_validation_time(
Returns overall average and per-phase averages where data is available. Returns overall average and per-phase averages where data is available.
""" """
validated_tests = ( return advanced_metrics_service.get_avg_validation_time(db)
db.query(Test)
.filter(Test.state == "validated")
.all()
)
if not validated_tests:
return {
"total_validated": 0,
"avg_total_hours": 0,
"avg_red_phase_hours": 0,
"avg_blue_phase_hours": 0,
}
total_durations = []
red_durations = []
blue_durations = []
for test in validated_tests:
if test.created_at and test.red_validated_at:
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
total_durations.append(total_seconds)
if test.red_started_at and test.blue_started_at:
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
red_paused = test.red_paused_seconds or 0
red_durations.append(max(red_sec - red_paused, 0))
if test.blue_started_at and test.blue_validated_at:
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
blue_paused = test.blue_paused_seconds or 0
blue_durations.append(max(blue_sec - blue_paused, 0))
def avg_hours(durations: list[float]) -> float:
if not durations:
return 0
return round(sum(durations) / len(durations) / 3600, 2)
return {
"total_validated": len(validated_tests),
"avg_total_hours": avg_hours(total_durations),
"avg_red_phase_hours": avg_hours(red_durations),
"avg_blue_phase_hours": avg_hours(blue_durations),
}
@router.get("/detection-rate-trend") @router.get("/detection-rate-trend")
@@ -144,41 +47,4 @@ def detection_rate_trend(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""Monthly detection rate trend for the last 12 months.""" """Monthly detection rate trend for the last 12 months."""
from datetime import timedelta return advanced_metrics_service.get_detection_rate_trend(db)
now = datetime.utcnow()
months = []
for i in range(11, -1, -1):
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
month_end = month_start + timedelta(days=30)
validated = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
detected = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.detection_result == "detected",
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
months.append({
"month": month_start.strftime("%Y-%m"),
"validated": validated,
"detected": detected,
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
})
return months

View File

@@ -5,15 +5,12 @@ directly from URL. All endpoints require authentication.
""" """
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
from app.models.coverage_snapshot import CoverageSnapshot
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User from app.models.user import User
from app.services import analytics_service
router = APIRouter(prefix="/analytics", tags=["analytics"]) router = APIRouter(prefix="/analytics", tags=["analytics"])
@@ -24,22 +21,7 @@ def analytics_coverage(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""Coverage per technique — flat format for BI dashboards.""" """Coverage per technique — flat format for BI dashboards."""
techniques = db.query(Technique).all() return analytics_service.get_coverage_analytics(db)
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"status": t.status_global.value if t.status_global else "not_evaluated",
"is_subtechnique": t.is_subtechnique,
"test_count": len(t.tests) if t.tests else 0,
"review_required": t.review_required,
"last_review_date": (
t.last_review_date.isoformat() if t.last_review_date else None
),
}
for t in techniques
]
@router.get("/tests") @router.get("/tests")
@@ -50,34 +32,9 @@ def analytics_tests(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""All tests with timestamps — flat format for BI dashboards.""" """All tests with timestamps — flat format for BI dashboards."""
query = db.query(Test) return analytics_service.get_tests_analytics(
if date_from: db, date_from=date_from, date_to=date_to
query = query.filter(Test.created_at >= date_from) )
if date_to:
query = query.filter(Test.created_at <= date_to)
tests = query.all()
return [
{
"id": str(t.id),
"technique_id": str(t.technique_id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"detection_result": (
t.detection_result.value if t.detection_result else None
),
"created_at": t.created_at.isoformat() if t.created_at else None,
"execution_date": (
t.execution_date.isoformat() if t.execution_date else None
),
"platform": t.platform,
"tool_used": t.tool_used,
"attack_success": t.attack_success,
"remediation_status": t.remediation_status,
}
for t in tests
]
@router.get("/trends") @router.get("/trends")
@@ -86,23 +43,7 @@ def analytics_trends(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""Historical coverage snapshots for trend visualization.""" """Historical coverage snapshots for trend visualization."""
snapshots = ( return analytics_service.get_trends_analytics(db)
db.query(CoverageSnapshot)
.order_by(CoverageSnapshot.created_at)
.all()
)
return [
{
"date": s.created_at.isoformat() if s.created_at else None,
"name": s.name,
"total_techniques": s.total_techniques,
"validated_count": s.validated_count,
"partial_count": s.partial_count,
"not_covered_count": s.not_covered_count,
"organization_score": s.organization_score,
}
for s in snapshots
]
@router.get("/operators") @router.get("/operators")
@@ -111,17 +52,4 @@ def analytics_operators(
user: User = Depends(require_any_role("red_lead", "blue_lead")), user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Per-operator metrics — for workload management dashboards.""" """Per-operator metrics — for workload management dashboards."""
results = ( return analytics_service.get_operators_analytics(db)
db.query(
User.username,
User.role,
func.count(Test.id).label("test_count"),
)
.outerjoin(Test, Test.created_by == User.id)
.group_by(User.id, User.username, User.role)
.all()
)
return [
{"username": r[0], "role": r[1], "test_count": r[2]}
for r in results
]

View File

@@ -9,7 +9,7 @@ cannot use cookies (e.g. Swagger UI).
import os import os
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, Response, status from fastapi import APIRouter, Cookie, Depends, Request, Response
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter from slowapi import Limiter
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
@@ -17,11 +17,13 @@ from sqlalchemy.orm import Session
from jose import jwt, JWTError from jose import jwt, JWTError
from app.auth import verify_password, hash_password, create_access_token, blacklist_token from app.auth import create_access_token, blacklist_token
from app.config import settings from app.config import settings
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User from app.models.user import User
from app.services.auth_service import authenticate_user, change_password as auth_change_password
from app.schemas.auth import TokenResponse, UserOut from app.schemas.auth import TokenResponse, UserOut
from app.schemas.user import PasswordChange from app.schemas.user import PasswordChange
@@ -56,25 +58,11 @@ def login(
attacks. The token is set as an HttpOnly cookie **and** returned in the attacks. The token is set as an HttpOnly cookie **and** returned in the
JSON body for API/Swagger compatibility. JSON body for API/Swagger compatibility.
""" """
user = db.query(User).filter(User.username == form_data.username).first() user = authenticate_user(
db,
# Constant-time comparison: always run bcrypt verify to prevent username=form_data.username,
# timing-based user enumeration (SEC-005). password=form_data.password,
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy" )
hashed = user.hashed_password if user else _DUMMY_HASH
password_valid = verify_password(form_data.password, hashed)
if user is None or not password_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Incorrect username or password",
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is disabled. Contact an administrator.",
)
access_token = create_access_token(data={"sub": user.username}) access_token = create_access_token(data={"sub": user.username})
@@ -163,14 +151,13 @@ def change_password(
``must_change_password`` flag is cleared so the user can proceed ``must_change_password`` flag is cleared so the user can proceed
normally. normally.
""" """
if not verify_password(body.current_password, current_user.hashed_password): auth_change_password(
raise HTTPException( db,
status_code=status.HTTP_400_BAD_REQUEST, current_user,
detail="Current password is incorrect", current_password=body.current_password,
) new_password=body.new_password,
)
current_user.hashed_password = hash_password(body.new_password) with UnitOfWork(db) as uow:
current_user.must_change_password = False uow.commit()
db.commit()
return {"detail": "Password changed successfully"} return {"detail": "Password changed successfully"}

View File

@@ -25,13 +25,12 @@ Filters (GET /test-templates)
import uuid import uuid
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, Query, status
from sqlalchemy import func, or_
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role from app.dependencies.auth import get_current_user, require_any_role
from app.models.test_template import TestTemplate from app.domain.unit_of_work import UnitOfWork
from app.models.user import User from app.models.user import User
from app.schemas.test_template import ( from app.schemas.test_template import (
TestTemplateCreate, TestTemplateCreate,
@@ -39,6 +38,17 @@ from app.schemas.test_template import (
TestTemplateSummary, TestTemplateSummary,
) )
from app.services.audit_service import log_action from app.services.audit_service import log_action
from app.services.test_template_service import (
bulk_activate,
create_template as create_template_svc,
get_template_or_raise,
get_template_stats,
get_templates_by_technique as templates_by_technique,
list_templates,
soft_delete_template,
toggle_template_active as toggle_template_active_svc,
update_template as update_template_svc,
)
router = APIRouter(prefix="/test-templates", tags=["test-templates"]) router = APIRouter(prefix="/test-templates", tags=["test-templates"])
@@ -49,7 +59,7 @@ router = APIRouter(prefix="/test-templates", tags=["test-templates"])
@router.get("", response_model=list[TestTemplateSummary]) @router.get("", response_model=list[TestTemplateSummary])
def list_templates( def _list_templates_handler(
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"), source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"), platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"), severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
@@ -62,37 +72,17 @@ def list_templates(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return a paginated, filterable list of test templates.""" """Return a paginated, filterable list of test templates."""
query = db.query(TestTemplate) return list_templates(
if is_active is not None: db,
query = query.filter(TestTemplate.is_active == is_active) # noqa: E712 source=source,
platform=platform,
if source: severity=severity,
query = query.filter(TestTemplate.source == source) mitre_technique_id=mitre_technique_id,
if platform: search=search,
from app.utils import escape_like is_active=is_active,
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%")) offset=offset,
if severity: limit=limit,
query = query.filter(TestTemplate.severity == severity)
if mitre_technique_id:
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
if search:
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
TestTemplate.name.ilike(pattern),
TestTemplate.description.ilike(pattern),
)
)
templates = (
query
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
.offset(offset)
.limit(limit)
.all()
) )
return templates
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -105,41 +95,8 @@ def template_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Return catalog statistics: totals by source, platform, active/inactive.""" """Return catalog statistics: active, by_source, by_platform."""
return get_template_stats(db)
total = db.query(func.count(TestTemplate.id)).scalar() or 0
active = (
db.query(func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.scalar()
) or 0
inactive = total - active
# By source
source_rows = (
db.query(TestTemplate.source, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.source)
.all()
)
by_source = {source: cnt for source, cnt in source_rows}
# By platform
platform_rows = (
db.query(TestTemplate.platform, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.platform)
.all()
)
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
return {
"total": total,
"active": active,
"inactive": inactive,
"by_source": by_source,
"by_platform": by_platform,
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -154,21 +111,17 @@ def bulk_activate_templates(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Set all templates to active or inactive.""" """Set all templates to active or inactive."""
count = ( count = bulk_activate(db, activate=activate)
db.query(TestTemplate) with UnitOfWork(db) as uow:
.filter(TestTemplate.is_active != activate) log_action(
.update({TestTemplate.is_active: activate}) db,
) user_id=current_user.id,
db.commit() action="bulk_activate_templates" if activate else "bulk_deactivate_templates",
entity_type="test_template",
log_action( entity_id=None,
db, details={"affected": count, "is_active": activate},
user_id=current_user.id, )
action="bulk_activate_templates" if activate else "bulk_deactivate_templates", uow.commit()
entity_type="test_template",
entity_id=None,
details={"affected": count, "is_active": activate},
)
return { return {
"detail": f"{'Activated' if activate else 'Deactivated'} {count} templates", "detail": f"{'Activated' if activate else 'Deactivated'} {count} templates",
@@ -183,22 +136,13 @@ def bulk_activate_templates(
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary]) @router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
def templates_by_technique( def _templates_by_technique_handler(
mitre_id: str, mitre_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return all active templates mapped to a specific MITRE technique.""" """Return all active templates mapped to a specific MITRE technique."""
templates = ( return templates_by_technique(db, mitre_id)
db.query(TestTemplate)
.filter(
TestTemplate.mitre_technique_id == mitre_id,
TestTemplate.is_active == True, # noqa: E712
)
.order_by(TestTemplate.name)
.all()
)
return templates
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -213,13 +157,7 @@ def get_template(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return full details for a single test template.""" """Return full details for a single test template."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() return get_template_or_raise(db, template_id)
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test template not found",
)
return template
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -238,24 +176,23 @@ def create_template(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Create a custom test template.""" """Create a custom test template."""
template = TestTemplate(**payload.model_dump()) template = create_template_svc(db, **payload.model_dump())
db.add(template) with UnitOfWork(db) as uow:
db.commit() log_action(
db,
user_id=current_user.id,
action="create_test_template",
entity_type="test_template",
entity_id=template.id,
details={
"name": template.name,
"source": template.source,
"mitre_technique_id": template.mitre_technique_id,
},
)
uow.commit()
db.refresh(template) db.refresh(template)
log_action(
db,
user_id=current_user.id,
action="create_test_template",
entity_type="test_template",
entity_id=template.id,
details={
"name": template.name,
"source": template.source,
"mitre_technique_id": template.mitre_technique_id,
},
)
return template return template
@@ -272,29 +209,19 @@ def update_template(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Update fields of an existing test template.""" """Update fields of an existing test template."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
if template is None: with UnitOfWork(db) as uow:
raise HTTPException( log_action(
status_code=status.HTTP_404_NOT_FOUND, db,
detail="Test template not found", user_id=current_user.id,
action="update_test_template",
entity_type="test_template",
entity_id=template.id,
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
) )
uow.commit()
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(template, field, value)
db.commit()
db.refresh(template) db.refresh(template)
log_action(
db,
user_id=current_user.id,
action="update_test_template",
entity_type="test_template",
entity_id=template.id,
details={"updated_fields": list(update_data.keys())},
)
return template return template
@@ -309,27 +236,20 @@ def toggle_template_active(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Toggle a template between active and inactive.""" """Toggle a template between active and inactive (is_active = not is_active)."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() template = toggle_template_active_svc(db, template_id)
if template is None: with UnitOfWork(db) as uow:
raise HTTPException( log_action(
status_code=status.HTTP_404_NOT_FOUND, db,
detail="Test template not found", user_id=current_user.id,
action="toggle_test_template",
entity_type="test_template",
entity_id=template.id,
details={"name": template.name, "is_active": template.is_active},
) )
uow.commit()
template.is_active = not template.is_active
db.commit()
db.refresh(template) db.refresh(template)
log_action(
db,
user_id=current_user.id,
action="toggle_test_template",
entity_type="test_template",
entity_id=template.id,
details={"name": template.name, "is_active": template.is_active},
)
return template return template
@@ -345,23 +265,17 @@ def delete_template(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Soft-delete a test template by setting ``is_active=False``.""" """Soft-delete a test template by setting ``is_active=False``."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first() template = get_template_or_raise(db, template_id)
if template is None: soft_delete_template(db, template_id)
raise HTTPException( with UnitOfWork(db) as uow:
status_code=status.HTTP_404_NOT_FOUND, log_action(
detail="Test template not found", db,
user_id=current_user.id,
action="delete_test_template",
entity_type="test_template",
entity_id=template.id,
details={"name": template.name},
) )
uow.commit()
template.is_active = False
db.commit()
log_action(
db,
user_id=current_user.id,
action="delete_test_template",
entity_type="test_template",
entity_id=template.id,
details={"name": template.name},
)
return {"detail": "Test template deactivated"} return {"detail": "Test template deactivated"}

View File

@@ -0,0 +1,160 @@
"""Advanced metrics service — coverage by tactic, never-tested, avg validation time, detection trend."""
from __future__ import annotations
from datetime import datetime, timedelta
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from app.models.technique import Technique
from app.models.test import Test
from app.models.enums import TestResult
def get_coverage_by_tactic(db: Session) -> list[dict]:
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
results = (
db.query(
Technique.tactic,
func.count(Technique.id).label("total"),
func.sum(
case((Technique.status_global == "validated", 1), else_=0)
).label("validated"),
func.sum(
case((Technique.status_global == "partial", 1), else_=0)
).label("partial"),
func.sum(
case((Technique.status_global == "not_covered", 1), else_=0)
).label("not_covered"),
func.sum(
case((Technique.status_global == "in_progress", 1), else_=0)
).label("in_progress"),
)
.group_by(Technique.tactic)
.order_by(Technique.tactic)
.all()
)
return [
{
"tactic": r[0] or "Unknown",
"total": r[1],
"validated": int(r[2]),
"partial": int(r[3]),
"not_covered": int(r[4]),
"in_progress": int(r[5]),
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
}
for r in results
]
def get_never_tested_techniques(db: Session) -> list[dict]:
"""Techniques that have never had a test created."""
tested_technique_ids = db.query(Test.technique_id).distinct().subquery()
techniques = (
db.query(Technique)
.filter(~Technique.id.in_(db.query(tested_technique_ids)))
.order_by(Technique.mitre_id)
.all()
)
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"is_subtechnique": t.is_subtechnique,
}
for t in techniques
]
def get_avg_validation_time(db: Session) -> dict:
"""Average time from test creation to validation, computed from validated tests.
Returns overall average and per-phase averages where data is available.
"""
validated_tests = (
db.query(Test)
.filter(Test.state == "validated")
.all()
)
if not validated_tests:
return {
"total_validated": 0,
"avg_total_hours": 0,
"avg_red_phase_hours": 0,
"avg_blue_phase_hours": 0,
}
total_durations = []
red_durations = []
blue_durations = []
for test in validated_tests:
if test.created_at and test.red_validated_at:
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
total_durations.append(total_seconds)
if test.red_started_at and test.blue_started_at:
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
red_paused = test.red_paused_seconds or 0
red_durations.append(max(red_sec - red_paused, 0))
if test.blue_started_at and test.blue_validated_at:
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
blue_paused = test.blue_paused_seconds or 0
blue_durations.append(max(blue_sec - blue_paused, 0))
def avg_hours(durations: list[float]) -> float:
if not durations:
return 0
return round(sum(durations) / len(durations) / 3600, 2)
return {
"total_validated": len(validated_tests),
"avg_total_hours": avg_hours(total_durations),
"avg_red_phase_hours": avg_hours(red_durations),
"avg_blue_phase_hours": avg_hours(blue_durations),
}
def get_detection_rate_trend(db: Session) -> list[dict]:
"""Monthly detection rate trend for the last 12 months."""
now = datetime.utcnow()
months = []
for i in range(11, -1, -1):
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
month_end = month_start + timedelta(days=30)
validated = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
detected = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.detection_result == TestResult.detected,
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
months.append({
"month": month_start.strftime("%Y-%m"),
"validated": validated,
"detected": detected,
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
})
return months

View File

@@ -0,0 +1,107 @@
"""Analytics service — flat JSON optimized for PowerBI / BI tools."""
from __future__ import annotations
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.models.coverage_snapshot import CoverageSnapshot
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User
def get_coverage_analytics(db: Session) -> list[dict]:
"""Coverage per technique — flat format for BI dashboards."""
techniques = db.query(Technique).all()
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"status": t.status_global.value if t.status_global else "not_evaluated",
"is_subtechnique": t.is_subtechnique,
"test_count": len(t.tests) if t.tests else 0,
"review_required": t.review_required,
"last_review_date": (
t.last_review_date.isoformat() if t.last_review_date else None
),
}
for t in techniques
]
def get_tests_analytics(
db: Session,
*,
date_from: str | None = None,
date_to: str | None = None,
) -> list[dict]:
"""All tests with timestamps — flat format for BI dashboards."""
query = db.query(Test)
if date_from:
query = query.filter(Test.created_at >= date_from)
if date_to:
query = query.filter(Test.created_at <= date_to)
tests = query.all()
return [
{
"id": str(t.id),
"technique_id": str(t.technique_id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"detection_result": (
t.detection_result.value if t.detection_result else None
),
"created_at": t.created_at.isoformat() if t.created_at else None,
"execution_date": (
t.execution_date.isoformat() if t.execution_date else None
),
"platform": t.platform,
"tool_used": t.tool_used,
"attack_success": t.attack_success,
"remediation_status": t.remediation_status,
}
for t in tests
]
def get_trends_analytics(db: Session) -> list[dict]:
"""Historical coverage snapshots for trend visualization."""
snapshots = (
db.query(CoverageSnapshot)
.order_by(CoverageSnapshot.created_at)
.all()
)
return [
{
"date": s.created_at.isoformat() if s.created_at else None,
"name": s.name,
"total_techniques": s.total_techniques,
"validated_count": s.validated_count,
"partial_count": s.partial_count,
"not_covered_count": s.not_covered_count,
"organization_score": s.organization_score,
}
for s in snapshots
]
def get_operators_analytics(db: Session) -> list[dict]:
"""Per-operator metrics — for workload management dashboards."""
results = (
db.query(
User.username,
User.role,
func.count(Test.id).label("test_count"),
)
.outerjoin(Test, Test.created_by == User.id)
.group_by(User.id, User.username, User.role)
.all()
)
return [
{"username": r[0], "role": r[1], "test_count": r[2]}
for r in results
]

View File

@@ -0,0 +1,45 @@
"""Authentication service — credential validation and password management."""
from __future__ import annotations
from sqlalchemy.orm import Session
from app.auth import hash_password, verify_password
from app.domain.errors import BusinessRuleViolation, PermissionViolation
from app.models.user import User
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
def authenticate_user(db: Session, *, username: str, password: str) -> User:
"""Validate credentials and return the User.
Raises BusinessRuleViolation for invalid credentials.
Raises PermissionViolation for disabled account.
Uses constant-time comparison to prevent timing attacks.
"""
user = db.query(User).filter(User.username == username).first()
hashed = user.hashed_password if user else _DUMMY_HASH
password_valid = verify_password(password, hashed)
if user is None or not password_valid:
raise BusinessRuleViolation("Incorrect username or password")
if not user.is_active:
raise PermissionViolation("Account is disabled. Contact an administrator.")
return user
def change_password(
db: Session,
user: User,
*,
current_password: str,
new_password: str,
) -> None:
"""Change a user's password. Does NOT commit.
Raises BusinessRuleViolation if current password is wrong.
"""
if not verify_password(current_password, user.hashed_password):
raise BusinessRuleViolation("Current password is incorrect")
user.hashed_password = hash_password(new_password)
user.must_change_password = False

View File

@@ -0,0 +1,150 @@
"""Test template service — framework-agnostic CRUD and queries."""
from __future__ import annotations
import uuid
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.test_template import TestTemplate
from app.utils import escape_like
def list_templates(
db: Session,
*,
source: str | None = None,
platform: str | None = None,
severity: str | None = None,
mitre_technique_id: str | None = None,
search: str | None = None,
is_active: bool | None = None,
offset: int = 0,
limit: int = 50,
) -> list:
"""Return paginated, filterable list of test templates."""
query = db.query(TestTemplate)
if is_active is not None:
query = query.filter(TestTemplate.is_active == is_active)
if source:
query = query.filter(TestTemplate.source == source)
if platform:
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
if severity:
query = query.filter(TestTemplate.severity == severity)
if mitre_technique_id:
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
TestTemplate.name.ilike(pattern),
TestTemplate.description.ilike(pattern),
)
)
templates = (
query
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
.offset(offset)
.limit(limit)
.all()
)
return templates
def get_template_stats(db: Session) -> dict:
"""Return catalog statistics: totals by source, platform, active/inactive."""
total = db.query(func.count(TestTemplate.id)).scalar() or 0
active = (
db.query(func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.scalar()
) or 0
inactive = total - active
source_rows = (
db.query(TestTemplate.source, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.source)
.all()
)
by_source = {source: cnt for source, cnt in source_rows}
platform_rows = (
db.query(TestTemplate.platform, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.platform)
.all()
)
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
return {
"total": total,
"active": active,
"inactive": inactive,
"by_source": by_source,
"by_platform": by_platform,
}
def bulk_activate(db: Session, *, activate: bool) -> int:
"""Set all templates to active or inactive. Returns count of affected. Does NOT commit."""
count = (
db.query(TestTemplate)
.filter(TestTemplate.is_active != activate)
.update({TestTemplate.is_active: activate})
)
return count
def get_templates_by_technique(db: Session, mitre_id: str) -> list:
"""Return all active templates mapped to a specific MITRE technique."""
return (
db.query(TestTemplate)
.filter(
TestTemplate.mitre_technique_id == mitre_id,
TestTemplate.is_active == True, # noqa: E712
)
.order_by(TestTemplate.name)
.all()
)
def get_template_or_raise(db: Session, template_id: uuid.UUID) -> TestTemplate:
"""Return a template by ID. Raises EntityNotFoundError if not found."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise EntityNotFoundError("Test template", str(template_id))
return template
def create_template(db: Session, **fields: object) -> TestTemplate:
"""Create a test template from keyword args (e.g. payload.model_dump()). Does NOT commit."""
template = TestTemplate(**fields)
db.add(template)
return template
def update_template(db: Session, template_id: uuid.UUID, **fields: object) -> TestTemplate:
"""Update an existing template. Raises EntityNotFoundError if not found. Does NOT commit."""
template = get_template_or_raise(db, template_id)
for field, value in fields.items():
if hasattr(template, field):
setattr(template, field, value)
return template
def toggle_template_active(db: Session, template_id: uuid.UUID) -> TestTemplate:
"""Toggle template active/inactive. Does NOT commit."""
template = get_template_or_raise(db, template_id)
template.is_active = not template.is_active
return template
def soft_delete_template(db: Session, template_id: uuid.UUID) -> None:
"""Soft-delete a template by setting is_active=False. Does NOT commit."""
template = get_template_or_raise(db, template_id)
template.is_active = False