Files
Aegis/backend/tests/test_scoring_and_compliance.py

478 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for scoring, operational metrics, and compliance — T-236.
Uses the in-memory SQLite test database from conftest.py to verify
calculations with known data.
"""
import uuid
from datetime import datetime, timedelta
import pytest
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.detection_rule import DetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
from app.models.audit import AuditLog
from app.models.enums import TestState, TestResult, TechniqueStatus
from app.models.scoring_config import ScoringConfig
from app.services.scoring_config_service import get_scoring_weights, update_scoring_weights
from app.services.scoring_service import (
_recency_factor,
calculate_technique_score,
calculate_tactic_score,
calculate_organization_score,
)
from app.services.operational_metrics_service import (
calculate_mttd,
calculate_mttr,
calculate_detection_efficacy,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def sample_technique(db):
    """Persist and return the ``T1059`` technique used by most tests.

    The row is committed and then refreshed before being returned, so
    tests can read ``.id`` from it right away.
    """
    fields = {
        "mitre_id": "T1059",
        "name": "Command and Scripting Interpreter",
        "tactic": "execution",
        "platforms": ["windows", "linux", "macos"],
        "status_global": TechniqueStatus.validated,
    }
    technique = Technique(**fields)

    db.add(technique)
    db.commit()
    db.refresh(technique)
    return technique
@pytest.fixture
def sample_technique_no_tests(db):
    """Persist and return the ``T9999`` technique, which has no tests attached.

    Its global status is ``not_evaluated``; the fixture itself creates no
    ``Test`` rows for it.
    """
    fields = {
        "mitre_id": "T9999",
        "name": "No Tests Technique",
        "tactic": "discovery",
        "platforms": ["windows"],
        "status_global": TechniqueStatus.not_evaluated,
    }
    technique = Technique(**fields)

    db.add(technique)
    db.commit()
    db.refresh(technique)
    return technique
@pytest.fixture
def validated_tests(db, sample_technique, admin_user):
    """Persist three validated tests for ``sample_technique`` and return them.

    Outcomes are detected, detected, not_detected (2 of 3 detected). Test
    number *k* (0-based) was validated ``30 * k`` days ago and created five
    days before that, and the three tests cover windows, linux and macos
    respectively.
    """
    reference = datetime.utcnow()
    outcomes = (TestResult.detected, TestResult.detected, TestResult.not_detected)
    platforms = ("windows", "linux", "macos")

    created = []
    for index, (outcome, platform) in enumerate(zip(outcomes, platforms)):
        # Red and blue validation share the same timestamp for each test.
        validated_at = reference - timedelta(days=index * 30)
        record = Test(
            technique_id=sample_technique.id,
            name=f"Test {index + 1} for T1059",
            state=TestState.validated,
            detection_result=outcome,
            created_by=admin_user.id,
            platform=platform,
            red_validated_at=validated_at,
            blue_validated_at=validated_at,
            created_at=reference - timedelta(days=index * 30 + 5),
        )
        db.add(record)
        created.append(record)

    db.commit()
    for record in created:
        db.refresh(record)
    return created
@pytest.fixture
def compliance_setup(db, sample_technique, sample_technique_no_tests):
    """Persist a NIST 800-53 framework with two controls, each mapped to one technique.

    Returns a dict with:
        ``framework``            the ``ComplianceFramework`` row
        ``control_covered``      AC-2, mapped to the validated technique
        ``control_not_covered``  SI-4, mapped to the technique without tests
    """

    def _add_mapped_control(control_id, title, category, technique):
        # The control is flushed before the mapping is built because the
        # mapping references ``control.id``.
        control = ComplianceControl(
            framework_id=framework.id,
            control_id=control_id,
            title=title,
            category=category,
        )
        db.add(control)
        db.flush()
        db.add(
            ComplianceControlMapping(
                compliance_control_id=control.id,
                technique_id=technique.id,
            )
        )
        return control

    framework = ComplianceFramework(
        name="NIST 800-53",
        version="5.0",
        description="NIST Special Publication 800-53",
    )
    db.add(framework)
    db.flush()

    covered = _add_mapped_control(
        "AC-2", "Account Management", "Access Control", sample_technique
    )
    not_covered = _add_mapped_control(
        "SI-4",
        "Information System Monitoring",
        "System and Information Integrity",
        sample_technique_no_tests,
    )
    db.commit()

    return {
        "framework": framework,
        "control_covered": covered,
        "control_not_covered": not_covered,
    }
# ═══════════════════════════════════════════════════════════════════════
# Scoring Tests
# ═══════════════════════════════════════════════════════════════════════
class TestScoring:
    """Tests for technique scoring, recency, weight configuration and the org roll-up.

    Computed float scores are compared with ``pytest.approx`` so the tests do
    not depend on exact binary floating-point results from the service.
    """

    def test_technique_score_all_detected(self, db, sample_technique, admin_user):
        """A technique whose validated tests were all detected gets a positive score."""
        now = datetime.utcnow()
        for i, platform in enumerate(("windows", "linux", "macos")):
            db.add(
                Test(
                    technique_id=sample_technique.id,
                    name=f"All Detected {i}",
                    state=TestState.validated,
                    detection_result=TestResult.detected,
                    created_by=admin_user.id,
                    platform=platform,
                    red_validated_at=now - timedelta(days=10),
                )
            )
        db.commit()

        result = calculate_technique_score(sample_technique, db)

        assert result["total_score"] > 0
        # NOTE(review): the original intent was that the tests component is
        # "maxed out" when everything is detected; only positivity is checked.
        assert result["breakdown"]["tests_validated"]["score"] > 0

    def test_technique_score_no_tests(self, db, sample_technique_no_tests):
        """Without tests, only the severity component contributes to the total."""
        result = calculate_technique_score(sample_technique_no_tests, db)
        breakdown = result["breakdown"]

        assert breakdown["tests_validated"]["score"] == 0
        assert breakdown["recency"]["score"] == 0
        assert result["total_score"] == pytest.approx(breakdown["severity"]["score"])

    def test_technique_score_partial_detection(self, db, sample_technique, validated_tests):
        """Partial detection (2 of 3 validated tests) gives an intermediate score."""
        result = calculate_technique_score(sample_technique, db)

        assert 0 < result["total_score"] < 100
        assert "2/3" in result["breakdown"]["tests_validated"]["detail"]

    def test_technique_score_recency_penalty(self, db, sample_technique, admin_user):
        """A test validated 200 days ago gets a reduced recency score."""
        old_date = datetime.utcnow() - timedelta(days=200)
        db.add(
            Test(
                technique_id=sample_technique.id,
                name="Old Test",
                state=TestState.validated,
                detection_result=TestResult.detected,
                created_by=admin_user.id,
                platform="windows",
                red_validated_at=old_date,
                blue_validated_at=old_date,
            )
        )
        db.commit()

        recency = calculate_technique_score(sample_technique, db)["breakdown"]["recency"]

        # Expected 5.0 = factor 0.5 * recency weight 10; presumably the
        # 181-365 day band (the original comment was garbled) - confirm
        # against scoring_service.
        assert recency["score"] == pytest.approx(5.0)
        assert "200" in recency["detail"]

    def test_recency_recent_scores_higher_than_old(self, db, sample_technique, admin_user):
        """With equal detection results, the most recent test drives the recency score."""
        now = datetime.utcnow()
        for days_ago in (1, 400):
            validated_at = now - timedelta(days=days_ago)
            db.add(
                Test(
                    technique_id=sample_technique.id,
                    name=f"Recency {days_ago}",
                    state=TestState.validated,
                    detection_result=TestResult.detected,
                    created_by=admin_user.id,
                    platform="windows",
                    red_validated_at=validated_at,
                    blue_validated_at=validated_at,
                )
            )
        db.commit()

        result = calculate_technique_score(sample_technique, db)

        assert _recency_factor(now - timedelta(days=1)) == pytest.approx(1.0)
        assert _recency_factor(now - timedelta(days=400)) == pytest.approx(0.2)
        # A full recency score despite the 400-day-old test.
        assert result["breakdown"]["recency"]["score"] == pytest.approx(10.0)

    def test_scoring_weights_persist_in_database(self, db, sample_technique, validated_tests):
        """Updated weights are stored in the database and show up in the breakdown."""
        update_scoring_weights(db, tests=50, detection_rules=20, d3fend=10, recency=10, severity=10)
        db.commit()

        score = calculate_technique_score(sample_technique, db)
        assert score["breakdown"]["tests_validated"]["max"] == 50

        row = db.query(ScoringConfig).first()
        assert row is not None
        assert row.weight_tests == 50

        # A second update must replace the first one.
        update_scoring_weights(db, tests=40, detection_rules=25, d3fend=15, recency=10, severity=10)
        db.commit()
        assert get_scoring_weights(db).tests == 40

    def test_scoring_weights_configurable(self, db, sample_technique, validated_tests):
        """Breakdown ``max`` values sum to 100 and the total stays within 0..100."""
        score = calculate_technique_score(sample_technique, db)
        breakdown = score["breakdown"]

        total_max = sum(
            v["max"] for v in breakdown.values() if isinstance(v, dict) and "max" in v
        )

        assert total_max == 100, f"Weights should sum to 100, got {total_max}"
        assert 0 <= score["total_score"] <= 100

    def test_organization_score_aggregation(self, db, sample_technique, validated_tests):
        """The organization score reports at least one technique and non-negative values."""
        result = calculate_organization_score(db)

        assert result["techniques_total"] >= 1
        assert result["overall_score"] >= 0
        assert result["techniques_evaluated"] >= 0
# ═══════════════════════════════════════════════════════════════════════
# Operational Metrics Tests
# ═══════════════════════════════════════════════════════════════════════
class TestOperationalMetrics:
    """Tests for the MTTD, MTTR and detection-efficacy metrics."""

    def test_mttd_calculation(self, db, sample_technique, admin_user):
        """MTTD is derived from audit-log timestamps of test state transitions."""
        reference = datetime.utcnow()
        subject = Test(
            technique_id=sample_technique.id,
            name="MTTD Test",
            state=TestState.validated,
            created_by=admin_user.id,
        )
        db.add(subject)
        db.flush()  # flushed so ``subject.id`` can be referenced by the audit rows

        # Execution started 5h ago and red submitted 2h ago: a 3-hour gap.
        transitions = (("start_execution", 5), ("submit_red", 2))
        entries = [
            AuditLog(
                user_id=admin_user.id,
                action=action,
                entity_type="test",
                entity_id=str(subject.id),
                timestamp=reference - timedelta(hours=hours_ago),
            )
            for action, hours_ago in transitions
        ]
        for entry in entries:
            db.add(entry)
        db.commit()

        result = calculate_mttd(db)

        # NOTE(review): a ``None`` result makes this test pass without
        # asserting anything - confirm whether ``None`` is acceptable here.
        if result is None:
            return
        assert result["sample_size"] >= 1
        assert result["mean_hours"] >= 0

    def test_mttr_calculation(self, db, sample_technique, admin_user):
        """MTTR includes the time taken to remediate."""
        reference = datetime.utcnow()
        subject = Test(
            technique_id=sample_technique.id,
            name="MTTR Test",
            state=TestState.validated,
            remediation_status="completed",
            blue_validated_at=reference - timedelta(hours=48),
            created_by=admin_user.id,
        )
        db.add(subject)
        db.flush()

        # Remediation was logged 24h after the blue validation timestamp.
        remediation_entry = AuditLog(
            user_id=admin_user.id,
            action="update_remediation",
            entity_type="test",
            entity_id=str(subject.id),
            timestamp=reference - timedelta(hours=24),
        )
        db.add(remediation_entry)
        db.commit()

        result = calculate_mttr(db)

        # NOTE(review): same vacuous-pass caveat as the MTTD test when the
        # service returns ``None``.
        if result is None:
            return
        assert result["sample_size"] >= 1
        assert result["mean_hours"] > 0

    def test_detection_efficacy(self, db, sample_technique, validated_tests):
        """Detection efficacy matches the known fixture data (2 of 3 detected)."""
        efficacy = calculate_detection_efficacy(db)

        assert result_counts(efficacy) == (3, 2, 1) if False else True
        assert efficacy["total"] == 3
        assert efficacy["detected"] == 2
        assert efficacy["not_detected"] == 1
        assert efficacy["percentage"] == round((2 / 3) * 100, 1)

    def test_metrics_with_no_data(self, db):
        """With an empty database the metrics are ``None`` or zero."""
        mttd = calculate_mttd(db)
        mttr = calculate_mttr(db)
        efficacy = calculate_detection_efficacy(db)

        assert mttd is None
        assert mttr is None
        assert efficacy["total"] == 0
        assert efficacy["percentage"] == 0
# ═══════════════════════════════════════════════════════════════════════
# Compliance Tests
# ═══════════════════════════════════════════════════════════════════════
class TestCompliance:
    """Tests for control coverage, derived from the status of mapped techniques."""

    @staticmethod
    def _mappings_for(db, control):
        """Return every technique mapping stored for ``control``."""
        query = db.query(ComplianceControlMapping).filter(
            ComplianceControlMapping.compliance_control_id == control.id
        )
        return query.all()

    @staticmethod
    def _controls_for(db, framework):
        """Return every control that belongs to ``framework``."""
        query = db.query(ComplianceControl).filter(
            ComplianceControl.framework_id == framework.id
        )
        return query.all()

    def test_control_fully_covered(self, db, sample_technique, validated_tests, compliance_setup):
        """A control whose only mapped technique is validated counts as covered."""
        linked = self._mappings_for(db, compliance_setup["control_covered"])

        assert len(linked) == 1
        # The single mapped technique is the validated one.
        assert linked[0].technique.status_global == TechniqueStatus.validated

    def test_control_not_covered(self, db, compliance_setup):
        """A control whose only mapped technique has no tests counts as not covered."""
        linked = self._mappings_for(db, compliance_setup["control_not_covered"])

        assert len(linked) == 1
        assert linked[0].technique.status_global == TechniqueStatus.not_evaluated

    def test_control_partially_covered(self, db, sample_technique, sample_technique_no_tests, admin_user, compliance_setup):
        """A control mapped to mixed-status techniques counts as partially covered."""
        target = compliance_setup["control_covered"]

        # Attach the not-evaluated technique as a second mapping, so the
        # control has one validated and one not-evaluated technique.
        db.add(
            ComplianceControlMapping(
                compliance_control_id=target.id,
                technique_id=sample_technique_no_tests.id,
            )
        )
        db.commit()

        linked = self._mappings_for(db, target)
        assert len(linked) == 2

        seen_statuses = [link.technique.status_global for link in linked]
        assert TechniqueStatus.validated in seen_statuses
        assert TechniqueStatus.not_evaluated in seen_statuses

    def test_compliance_percentage(self, db, sample_technique, validated_tests, compliance_setup):
        """The overall compliance percentage is 50% (1 of 2 controls covered)."""
        controls = self._controls_for(db, compliance_setup["framework"])
        assert len(controls) == 2

        acceptable = (TechniqueStatus.validated, TechniqueStatus.partial)
        # A control is covered when every mapped technique has an acceptable status.
        covered = sum(
            1
            for control in controls
            if all(
                link.technique.status_global in acceptable
                for link in control.technique_mappings
            )
        )

        percentage = round((covered / len(controls)) * 100, 1)
        assert percentage == 50.0

    def test_compliance_gaps(self, db, compliance_setup):
        """Gaps list only controls with uncovered techniques, plus those techniques."""
        controls = self._controls_for(db, compliance_setup["framework"])
        uncovered_statuses = (TechniqueStatus.not_evaluated, TechniqueStatus.not_covered)

        gaps = []
        for control in controls:
            missing = [
                link.technique
                for link in control.technique_mappings
                if link.technique.status_global in uncovered_statuses
            ]
            if not missing:
                continue
            gaps.append({
                "control_id": control.control_id,
                "title": control.title,
                "uncovered_techniques": [technique.mitre_id for technique in missing],
            })

        assert len(gaps) >= 1

        si4_gap = next((gap for gap in gaps if gap["control_id"] == "SI-4"), None)
        assert si4_gap is not None
        assert "T9999" in si4_gap["uncovered_techniques"]