feat(phase-32): add automated tests V3 for data sources, scoring, campaigns and snapshots (T-235 to T-237)
This commit is contained in:
@@ -1,12 +1,54 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
|
||||
from app.config import settings
|
||||
|
||||
engine = create_engine(settings.DATABASE_URL)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base = declarative_base()
|
||||
|
||||
# Engine and session factory are created lazily so that tests can
|
||||
# override DATABASE_URL via environment *before* any import triggers
|
||||
# the real PostgreSQL engine creation (which requires psycopg2).
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
||||
|
||||
def _get_engine():
|
||||
global _engine
|
||||
if _engine is None:
|
||||
from app.config import settings
|
||||
_engine = create_engine(settings.DATABASE_URL)
|
||||
return _engine
|
||||
|
||||
|
||||
def _get_session_factory():
|
||||
global _SessionLocal
|
||||
if _SessionLocal is None:
|
||||
_SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=_get_engine()
|
||||
)
|
||||
return _SessionLocal
|
||||
|
||||
|
||||
class _LazySessionLocal:
|
||||
"""Proxy so ``SessionLocal()`` keeps working as before but the real
|
||||
sessionmaker is only created on first call."""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return _get_session_factory()(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(_get_session_factory(), name)
|
||||
|
||||
|
||||
SessionLocal = _LazySessionLocal()
|
||||
|
||||
|
||||
class _EngineProxy:
|
||||
"""Thin proxy so ``from app.database import engine`` still works."""
|
||||
def __getattr__(self, name):
|
||||
return getattr(_get_engine(), name)
|
||||
|
||||
|
||||
engine = _EngineProxy() # type: ignore[assignment]
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
|
||||
@@ -3,3 +3,5 @@ testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_functions = test_*
|
||||
addopts = -v --tb=short
|
||||
markers =
|
||||
integration: marks tests as integration tests (deselect with '-m "not integration"')
|
||||
|
||||
@@ -1,16 +1,45 @@
|
||||
"""Pytest fixtures and configuration for backend tests."""
|
||||
"""Pytest fixtures and configuration for backend tests.
|
||||
|
||||
The conftest intentionally avoids importing ``app.main`` at module level
|
||||
because that triggers heavy side-effect imports (boto3, APScheduler, etc.)
|
||||
which are NOT needed for unit tests. The ``client`` fixture lazily imports
|
||||
the FastAPI app only when actually requested.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Set DATABASE_URL to SQLite *before* any app module is imported so that
|
||||
# the lazy engine in app.database never tries to connect to PostgreSQL.
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite:///:memory:")
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import JSON, String, Text, create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.main import app
|
||||
from app.database import Base, get_db
|
||||
from app.database import Base
|
||||
|
||||
# ── Patch PostgreSQL-specific column types so SQLite can handle them ─────
|
||||
# Must run BEFORE importing models, because column type objects are
|
||||
# instantiated at class-definition time.
|
||||
from sqlalchemy.dialects.postgresql import UUID as PG_UUID, JSONB as PG_JSONB
|
||||
|
||||
# Tell SQLAlchemy: when compiling for SQLite, render JSONB as plain JSON
|
||||
# and PostgreSQL UUID as CHAR(32).
|
||||
from sqlalchemy.dialects.sqlite.base import SQLiteTypeCompiler
|
||||
|
||||
if not hasattr(SQLiteTypeCompiler, "visit_JSONB"):
|
||||
SQLiteTypeCompiler.visit_JSONB = lambda self, type_, **kw: "JSON"
|
||||
|
||||
if not hasattr(SQLiteTypeCompiler, "visit_UUID"):
|
||||
SQLiteTypeCompiler.visit_UUID = lambda self, type_, **kw: "CHAR(32)"
|
||||
|
||||
from app.auth import hash_password
|
||||
from app.models.user import User
|
||||
|
||||
# ── Import all models so Base.metadata knows about every table ──────────
|
||||
import app.models # noqa: F401 — triggers model registration via __init__
|
||||
|
||||
# Use in-memory SQLite for tests
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
|
||||
|
||||
@@ -19,6 +48,14 @@ engine = create_engine(
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
# SQLite needs PRAGMA foreign_keys to enforce FK constraints
|
||||
@event.listens_for(engine, "connect")
|
||||
def _set_sqlite_pragma(dbapi_conn, connection_record):
|
||||
cursor = dbapi_conn.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
@@ -43,13 +80,21 @@ def db():
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(db):
|
||||
"""Create a test client with database override."""
|
||||
"""Create a test client with database override.
|
||||
|
||||
Imports ``app.main`` lazily to avoid pulling in boto3 / APScheduler
|
||||
when only the ``db`` fixture is needed.
|
||||
"""
|
||||
from app.main import app
|
||||
from app.database import get_db
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
44
backend/tests/fixtures/sample_caldera_ability.yml
vendored
Normal file
44
backend/tests/fixtures/sample_caldera_ability.yml
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
---
|
||||
id: caldera-test-001
|
||||
name: Get System Info
|
||||
description: Collect basic system information using whoami and systeminfo commands
|
||||
tactic: discovery
|
||||
technique:
|
||||
attack_id: T1082
|
||||
name: System Information Discovery
|
||||
platforms:
|
||||
windows:
|
||||
psh:
|
||||
command: |
|
||||
whoami /all
|
||||
systeminfo
|
||||
cleanup: ""
|
||||
cmd:
|
||||
command: |
|
||||
whoami
|
||||
systeminfo
|
||||
linux:
|
||||
sh:
|
||||
command: |
|
||||
uname -a
|
||||
cat /etc/os-release
|
||||
cleanup: ""
|
||||
---
|
||||
id: caldera-test-002
|
||||
name: List Network Connections
|
||||
description: Enumerate active network connections and listening ports
|
||||
tactic: discovery
|
||||
technique:
|
||||
attack_id: T1049
|
||||
name: System Network Connections Discovery
|
||||
platforms:
|
||||
windows:
|
||||
psh:
|
||||
command: |
|
||||
Get-NetTCPConnection | Select-Object LocalAddress, LocalPort, RemoteAddress, RemotePort, State
|
||||
cleanup: ""
|
||||
linux:
|
||||
sh:
|
||||
command: |
|
||||
netstat -tulnp 2>/dev/null || ss -tulnp
|
||||
cleanup: ""
|
||||
36
backend/tests/fixtures/sample_elastic_rule.toml
vendored
Normal file
36
backend/tests/fixtures/sample_elastic_rule.toml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
[metadata]
|
||||
creation_date = "2025/01/15"
|
||||
updated_date = "2025/06/01"
|
||||
maturity = "production"
|
||||
|
||||
[rule]
|
||||
author = ["Test Author"]
|
||||
description = "Detects the creation of a scheduled task via schtasks.exe, which is commonly used by adversaries for persistence."
|
||||
name = "Scheduled Task Created via Schtasks"
|
||||
severity = "medium"
|
||||
type = "eql"
|
||||
language = "eql"
|
||||
query = '''
|
||||
process where process.name : "schtasks.exe" and
|
||||
process.args : ("/create", "-create") and
|
||||
process.args : ("/sc", "-sc") and
|
||||
not process.parent.executable : ("C:\\Program Files\\*", "C:\\Program Files (x86)\\*")
|
||||
'''
|
||||
risk_score = 47
|
||||
rule_id = "test-elastic-001"
|
||||
tags = ["Persistence", "Windows"]
|
||||
|
||||
[[rule.threat]]
|
||||
framework = "MITRE ATT&CK"
|
||||
[[rule.threat.technique]]
|
||||
id = "T1053"
|
||||
name = "Scheduled Task/Job"
|
||||
reference = "https://attack.mitre.org/techniques/T1053/"
|
||||
[[rule.threat.technique.subtechnique]]
|
||||
id = "T1053.005"
|
||||
name = "Scheduled Task"
|
||||
reference = "https://attack.mitre.org/techniques/T1053/005/"
|
||||
[rule.threat.tactic]
|
||||
id = "TA0003"
|
||||
name = "Persistence"
|
||||
reference = "https://attack.mitre.org/tactics/TA0003/"
|
||||
26
backend/tests/fixtures/sample_lolbas_entry.yml
vendored
Normal file
26
backend/tests/fixtures/sample_lolbas_entry.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
Name: Mshta.exe
|
||||
Description: Used to execute .HTA files
|
||||
Author: Test Author
|
||||
Created: 2025-01-15
|
||||
Commands:
|
||||
- Command: mshta.exe evilfile.hta
|
||||
Description: Open an HTA file from disk
|
||||
Usecase: Execute arbitrary HTA scripts
|
||||
Category: Execute
|
||||
Privileges: User
|
||||
MitreID: T1218.005
|
||||
OperatingSystem: Windows 10, Windows 11
|
||||
- Command: mshta.exe vbscript:Execute("CreateObject(""Wscript.Shell"").Run(""calc.exe"")")
|
||||
Description: Execute VBScript via mshta
|
||||
Usecase: Execute inline VBScript
|
||||
Category: Execute
|
||||
Privileges: User
|
||||
MitreID: T1059.005
|
||||
OperatingSystem: Windows 10, Windows 11
|
||||
Full_Path:
|
||||
- Path: C:\Windows\System32\mshta.exe
|
||||
- Path: C:\Windows\SysWOW64\mshta.exe
|
||||
Detection:
|
||||
- Sigma: https://github.com/SigmaHQ/sigma/blob/master/rules/windows/process_creation/proc_creation_win_mshta.yml
|
||||
Resources:
|
||||
- Link: https://lolbas-project.github.io/#/mshta
|
||||
27
backend/tests/fixtures/sample_sigma_rule.yml
vendored
Normal file
27
backend/tests/fixtures/sample_sigma_rule.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
title: Windows PowerShell Execution Policy Bypass
|
||||
id: 1f21ec3f-810d-4b0e-8045-322202e22b4b
|
||||
status: stable
|
||||
description: Detects attempts to bypass PowerShell execution policy
|
||||
author: Test Author
|
||||
date: 2025/01/15
|
||||
references:
|
||||
- https://example.com/sigma-test
|
||||
logsource:
|
||||
category: process_creation
|
||||
product: windows
|
||||
detection:
|
||||
selection:
|
||||
CommandLine|contains:
|
||||
- '-ExecutionPolicy Bypass'
|
||||
- '-ep bypass'
|
||||
- 'Set-ExecutionPolicy Bypass'
|
||||
condition: selection
|
||||
falsepositives:
|
||||
- Legitimate admin scripts
|
||||
- CI/CD pipelines
|
||||
level: high
|
||||
tags:
|
||||
- attack.execution
|
||||
- attack.t1059.001
|
||||
- attack.defense_evasion
|
||||
- attack.t1562.001
|
||||
112
backend/tests/fixtures/sample_stix_bundle.json
vendored
Normal file
112
backend/tests/fixtures/sample_stix_bundle.json
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
{
|
||||
"type": "bundle",
|
||||
"id": "bundle--test-001",
|
||||
"spec_version": "2.0",
|
||||
"objects": [
|
||||
{
|
||||
"type": "intrusion-set",
|
||||
"id": "intrusion-set--test-apt1",
|
||||
"name": "APT1",
|
||||
"aliases": ["Comment Crew", "Comment Panda"],
|
||||
"description": "APT1 is a Chinese cyber espionage group attributed to PLA Unit 61398.",
|
||||
"first_seen": "2006-06-01T00:00:00Z",
|
||||
"last_seen": "2023-12-31T00:00:00Z",
|
||||
"external_references": [
|
||||
{
|
||||
"source_name": "mitre-attack",
|
||||
"url": "https://attack.mitre.org/groups/G0006/",
|
||||
"external_id": "G0006"
|
||||
},
|
||||
{
|
||||
"source_name": "Mandiant Report",
|
||||
"url": "https://www.mandiant.com/resources/apt1-exposing-one-of-chinas-cyber-espionage-units",
|
||||
"description": "Mandiant APT1 Report"
|
||||
}
|
||||
],
|
||||
"created": "2017-05-31T21:31:48.664Z",
|
||||
"modified": "2023-03-22T03:52:18.000Z"
|
||||
},
|
||||
{
|
||||
"type": "intrusion-set",
|
||||
"id": "intrusion-set--test-apt28",
|
||||
"name": "APT28",
|
||||
"aliases": ["Fancy Bear", "Sofacy", "Pawn Storm"],
|
||||
"description": "APT28 is a threat group attributed to Russia's GRU military intelligence.",
|
||||
"first_seen": "2004-01-01T00:00:00Z",
|
||||
"last_seen": "2024-06-30T00:00:00Z",
|
||||
"external_references": [
|
||||
{
|
||||
"source_name": "mitre-attack",
|
||||
"url": "https://attack.mitre.org/groups/G0007/",
|
||||
"external_id": "G0007"
|
||||
}
|
||||
],
|
||||
"created": "2017-05-31T21:31:48.664Z",
|
||||
"modified": "2024-01-15T00:00:00.000Z"
|
||||
},
|
||||
{
|
||||
"type": "attack-pattern",
|
||||
"id": "attack-pattern--test-t1566",
|
||||
"name": "Phishing",
|
||||
"external_references": [
|
||||
{
|
||||
"source_name": "mitre-attack",
|
||||
"url": "https://attack.mitre.org/techniques/T1566/",
|
||||
"external_id": "T1566"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "attack-pattern",
|
||||
"id": "attack-pattern--test-t1059",
|
||||
"name": "Command and Scripting Interpreter",
|
||||
"external_references": [
|
||||
{
|
||||
"source_name": "mitre-attack",
|
||||
"url": "https://attack.mitre.org/techniques/T1059/",
|
||||
"external_id": "T1059"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "attack-pattern",
|
||||
"id": "attack-pattern--test-t1078",
|
||||
"name": "Valid Accounts",
|
||||
"external_references": [
|
||||
{
|
||||
"source_name": "mitre-attack",
|
||||
"url": "https://attack.mitre.org/techniques/T1078/",
|
||||
"external_id": "T1078"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"id": "relationship--test-r1",
|
||||
"relationship_type": "uses",
|
||||
"source_ref": "intrusion-set--test-apt1",
|
||||
"target_ref": "attack-pattern--test-t1566"
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"id": "relationship--test-r2",
|
||||
"relationship_type": "uses",
|
||||
"source_ref": "intrusion-set--test-apt1",
|
||||
"target_ref": "attack-pattern--test-t1059"
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"id": "relationship--test-r3",
|
||||
"relationship_type": "uses",
|
||||
"source_ref": "intrusion-set--test-apt28",
|
||||
"target_ref": "attack-pattern--test-t1566"
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"id": "relationship--test-r4",
|
||||
"relationship_type": "uses",
|
||||
"source_ref": "intrusion-set--test-apt28",
|
||||
"target_ref": "attack-pattern--test-t1078"
|
||||
}
|
||||
]
|
||||
}
|
||||
464
backend/tests/test_campaigns_and_snapshots.py
Normal file
464
backend/tests/test_campaigns_and_snapshots.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""Tests for campaigns, snapshots, and re-testing — T-237.
|
||||
|
||||
Uses the in-memory SQLite test database from conftest.py.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.models.enums import TestState, TestResult, TechniqueStatus
|
||||
from app.services.campaign_service import (
|
||||
validate_no_circular_dependency,
|
||||
get_campaign_progress,
|
||||
)
|
||||
from app.services.campaign_scheduler_service import (
|
||||
calculate_next_run,
|
||||
check_and_run_recurring_campaigns,
|
||||
)
|
||||
from app.services.snapshot_service import (
|
||||
create_snapshot,
|
||||
compare_snapshots,
|
||||
cleanup_old_snapshots,
|
||||
)
|
||||
from app.services.test_workflow_service import (
|
||||
handle_remediation_completed,
|
||||
get_retest_chain,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def techniques(db):
|
||||
"""Create a set of techniques for testing."""
|
||||
techs = []
|
||||
for mid, name, status in [
|
||||
("T1059", "Command Line", TechniqueStatus.validated),
|
||||
("T1078", "Valid Accounts", TechniqueStatus.partial),
|
||||
("T1053", "Scheduled Tasks", TechniqueStatus.not_covered),
|
||||
]:
|
||||
tech = Technique(
|
||||
mitre_id=mid,
|
||||
name=name,
|
||||
tactic="execution",
|
||||
platforms=["windows"],
|
||||
status_global=status,
|
||||
)
|
||||
db.add(tech)
|
||||
techs.append(tech)
|
||||
db.commit()
|
||||
for t in techs:
|
||||
db.refresh(t)
|
||||
return techs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def campaign_with_tests(db, techniques, admin_user):
|
||||
"""Create a campaign with ordered tests."""
|
||||
campaign = Campaign(
|
||||
name="Test Campaign",
|
||||
type="custom",
|
||||
status="draft",
|
||||
created_by=admin_user.id,
|
||||
)
|
||||
db.add(campaign)
|
||||
db.flush()
|
||||
|
||||
tests = []
|
||||
for i, tech in enumerate(techniques):
|
||||
test = Test(
|
||||
technique_id=tech.id,
|
||||
name=f"Test for {tech.mitre_id}",
|
||||
state=TestState.draft,
|
||||
created_by=admin_user.id,
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
tests.append(test)
|
||||
|
||||
ct = CampaignTest(
|
||||
campaign_id=campaign.id,
|
||||
test_id=test.id,
|
||||
order_index=i,
|
||||
phase="execution",
|
||||
)
|
||||
db.add(ct)
|
||||
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
return {"campaign": campaign, "tests": tests}
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Campaign Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestCampaigns:
|
||||
|
||||
def test_create_campaign_with_tests(self, db, campaign_with_tests):
|
||||
"""CRUD básico de campaña con tests ordenados."""
|
||||
campaign = campaign_with_tests["campaign"]
|
||||
assert campaign.name == "Test Campaign"
|
||||
assert campaign.status == "draft"
|
||||
|
||||
cts = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == campaign.id)
|
||||
.order_by(CampaignTest.order_index)
|
||||
.all()
|
||||
)
|
||||
assert len(cts) == 3
|
||||
assert cts[0].order_index == 0
|
||||
assert cts[1].order_index == 1
|
||||
assert cts[2].order_index == 2
|
||||
|
||||
def test_campaign_progress_calculation(self, db, campaign_with_tests):
|
||||
"""Progreso se calcula según estado de tests."""
|
||||
campaign = campaign_with_tests["campaign"]
|
||||
tests = campaign_with_tests["tests"]
|
||||
|
||||
# Initially all draft → 0% complete
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
assert progress["total"] == 3
|
||||
assert progress["completion_pct"] == 0.0
|
||||
|
||||
# Validate one test
|
||||
tests[0].state = TestState.validated
|
||||
db.commit()
|
||||
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
assert progress["completion_pct"] == pytest.approx(33.3, abs=0.1)
|
||||
|
||||
def test_circular_dependency_prevention(self, db, campaign_with_tests):
|
||||
"""Intentar crear dependencia circular en campaign_tests falla."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
campaign = campaign_with_tests["campaign"]
|
||||
cts = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == campaign.id)
|
||||
.order_by(CampaignTest.order_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create A -> B dependency
|
||||
cts[1].depends_on = cts[0].id
|
||||
db.commit()
|
||||
|
||||
# Try to create B -> A (circular)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_no_circular_dependency(
|
||||
db, campaign.id, cts[0].id, cts[1].id
|
||||
)
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
def test_campaign_scheduling_next_run(self):
|
||||
"""next_run_at se calcula correctamente para weekly/monthly/quarterly."""
|
||||
base = datetime(2026, 1, 1)
|
||||
|
||||
weekly = calculate_next_run(base, "weekly")
|
||||
assert weekly == datetime(2026, 1, 8)
|
||||
|
||||
monthly = calculate_next_run(base, "monthly")
|
||||
assert monthly == datetime(2026, 1, 31)
|
||||
|
||||
quarterly = calculate_next_run(base, "quarterly")
|
||||
assert quarterly == datetime(2026, 4, 1)
|
||||
|
||||
def test_campaign_cloning(self, db, campaign_with_tests, admin_user):
|
||||
"""Clonación de campaña recurrente crea tests nuevos con datos correctos."""
|
||||
campaign = campaign_with_tests["campaign"]
|
||||
original_tests = campaign_with_tests["tests"]
|
||||
|
||||
# Set up as recurring
|
||||
campaign.is_recurring = True
|
||||
campaign.recurrence_pattern = "monthly"
|
||||
campaign.next_run_at = datetime.utcnow() - timedelta(hours=1) # Due now
|
||||
db.commit()
|
||||
|
||||
# Run the scheduler
|
||||
spawned = check_and_run_recurring_campaigns(db)
|
||||
assert spawned == 1
|
||||
|
||||
# Find child campaign
|
||||
child = (
|
||||
db.query(Campaign)
|
||||
.filter(Campaign.parent_campaign_id == campaign.id)
|
||||
.first()
|
||||
)
|
||||
assert child is not None
|
||||
assert "Run" in child.name
|
||||
assert child.status == "active"
|
||||
|
||||
# Check child tests are fresh copies (new IDs, draft state)
|
||||
child_cts = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == child.id)
|
||||
.all()
|
||||
)
|
||||
assert len(child_cts) == len(original_tests)
|
||||
|
||||
child_test_ids = {ct.test_id for ct in child_cts}
|
||||
original_test_ids = {t.id for t in original_tests}
|
||||
assert child_test_ids.isdisjoint(original_test_ids) # All new IDs
|
||||
|
||||
for ct in child_cts:
|
||||
test = db.query(Test).filter(Test.id == ct.test_id).first()
|
||||
assert test.state == TestState.draft
|
||||
|
||||
# Check parent was updated
|
||||
db.refresh(campaign)
|
||||
assert campaign.last_run_at is not None
|
||||
assert campaign.next_run_at > datetime.utcnow()
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Snapshot Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestSnapshots:
|
||||
|
||||
def test_create_snapshot(self, db, techniques, admin_user):
|
||||
"""Snapshot captura estado actual correctamente."""
|
||||
snapshot = create_snapshot(db, name="Test Snapshot", user_id=admin_user.id)
|
||||
|
||||
assert snapshot is not None
|
||||
assert snapshot.name == "Test Snapshot"
|
||||
assert snapshot.total_techniques == len(techniques)
|
||||
assert snapshot.created_by == admin_user.id
|
||||
assert snapshot.organization_score >= 0
|
||||
|
||||
# Verify per-technique states
|
||||
states = (
|
||||
db.query(SnapshotTechniqueState)
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snapshot.id)
|
||||
.all()
|
||||
)
|
||||
assert len(states) == len(techniques)
|
||||
|
||||
mitre_ids = {s.mitre_id for s in states}
|
||||
assert "T1059" in mitre_ids
|
||||
assert "T1078" in mitre_ids
|
||||
assert "T1053" in mitre_ids
|
||||
|
||||
def test_compare_snapshots_improvements(self, db, techniques, admin_user):
|
||||
"""Comparación detecta técnicas que mejoraron."""
|
||||
# Create snapshot A
|
||||
snap_a = create_snapshot(db, name="Before")
|
||||
|
||||
# Improve a technique
|
||||
tech = db.query(Technique).filter(Technique.mitre_id == "T1053").first()
|
||||
tech.status_global = TechniqueStatus.validated
|
||||
db.commit()
|
||||
|
||||
# Create snapshot B
|
||||
snap_b = create_snapshot(db, name="After")
|
||||
|
||||
result = compare_snapshots(db, snap_a.id, snap_b.id)
|
||||
|
||||
assert result["score_delta"] is not None
|
||||
assert result["summary"]["improved_count"] >= 0
|
||||
assert isinstance(result["improved"], list)
|
||||
assert isinstance(result["worsened"], list)
|
||||
assert result["unchanged_count"] >= 0
|
||||
|
||||
def test_compare_snapshots_regressions(self, db, techniques, admin_user):
|
||||
"""Comparación detecta técnicas que empeoraron."""
|
||||
# Create snapshot A
|
||||
snap_a = create_snapshot(db, name="Before Regression")
|
||||
|
||||
# Worsen a technique
|
||||
tech = db.query(Technique).filter(Technique.mitre_id == "T1059").first()
|
||||
tech.status_global = TechniqueStatus.not_covered
|
||||
db.commit()
|
||||
|
||||
snap_b = create_snapshot(db, name="After Regression")
|
||||
|
||||
result = compare_snapshots(db, snap_a.id, snap_b.id)
|
||||
assert result["summary"]["worsened_count"] >= 0
|
||||
|
||||
def test_snapshot_cleanup(self, db, techniques, admin_user):
|
||||
"""Cleanup mantiene solo los últimos N snapshots."""
|
||||
# Create 5 snapshots
|
||||
for i in range(5):
|
||||
create_snapshot(db, name=f"Snapshot {i}")
|
||||
|
||||
total_before = db.query(CoverageSnapshot).count()
|
||||
assert total_before == 5
|
||||
|
||||
# Cleanup keeping only 3
|
||||
deleted = cleanup_old_snapshots(db, keep_last=3)
|
||||
assert deleted == 2
|
||||
|
||||
total_after = db.query(CoverageSnapshot).count()
|
||||
assert total_after == 3
|
||||
|
||||
def test_snapshot_normalized_storage(self, db, techniques, admin_user):
|
||||
"""Verificar que el almacenamiento normalizado funciona correctamente."""
|
||||
snapshot = create_snapshot(db, name="Normalized Check")
|
||||
|
||||
# Each technique should have exactly one SnapshotTechniqueState row
|
||||
for tech in techniques:
|
||||
states = (
|
||||
db.query(SnapshotTechniqueState)
|
||||
.filter(
|
||||
SnapshotTechniqueState.snapshot_id == snapshot.id,
|
||||
SnapshotTechniqueState.technique_id == tech.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
assert len(states) == 1
|
||||
state = states[0]
|
||||
assert state.mitre_id == tech.mitre_id
|
||||
assert state.status is not None
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Re-testing Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestRetesting:
|
||||
|
||||
def test_retest_created_on_remediation(self, db, techniques, admin_user):
|
||||
"""Completar remediación crea retest automáticamente."""
|
||||
test = Test(
|
||||
technique_id=techniques[0].id,
|
||||
name="Original Test",
|
||||
state=TestState.validated,
|
||||
remediation_status="completed",
|
||||
created_by=admin_user.id,
|
||||
)
|
||||
db.add(test)
|
||||
db.commit()
|
||||
db.refresh(test)
|
||||
|
||||
retest = handle_remediation_completed(db, test, admin_user)
|
||||
assert retest is not None
|
||||
assert retest.retest_of == test.id
|
||||
assert retest.retest_count == 1
|
||||
assert retest.state == TestState.draft
|
||||
assert retest.technique_id == test.technique_id
|
||||
|
||||
def test_retest_points_to_original(self, db, techniques, admin_user):
|
||||
"""Retest de un retest apunta al test original, no al intermedio."""
|
||||
original = Test(
|
||||
technique_id=techniques[0].id,
|
||||
name="Original",
|
||||
state=TestState.validated,
|
||||
remediation_status="completed",
|
||||
created_by=admin_user.id,
|
||||
retest_count=0,
|
||||
)
|
||||
db.add(original)
|
||||
db.commit()
|
||||
db.refresh(original)
|
||||
|
||||
# First retest
|
||||
retest1 = handle_remediation_completed(db, original, admin_user)
|
||||
assert retest1 is not None
|
||||
assert retest1.retest_of == original.id
|
||||
|
||||
# Simulate completing remediation on retest1
|
||||
retest1.state = TestState.validated
|
||||
retest1.remediation_status = "completed"
|
||||
db.commit()
|
||||
db.refresh(retest1)
|
||||
|
||||
# Second retest — should point to ORIGINAL, not retest1
|
||||
retest2 = handle_remediation_completed(db, retest1, admin_user)
|
||||
assert retest2 is not None
|
||||
assert retest2.retest_of == original.id # Points to original!
|
||||
assert retest2.retest_count == 2
|
||||
|
||||
def test_retest_max_limit(self, db, techniques, admin_user):
|
||||
"""Al alcanzar MAX_RETEST_COUNT no se crea retest."""
|
||||
from app.config import settings
|
||||
|
||||
test = Test(
|
||||
technique_id=techniques[0].id,
|
||||
name="Max Retests Test",
|
||||
state=TestState.validated,
|
||||
remediation_status="completed",
|
||||
created_by=admin_user.id,
|
||||
retest_count=settings.MAX_RETEST_COUNT, # Already at max
|
||||
)
|
||||
db.add(test)
|
||||
db.commit()
|
||||
db.refresh(test)
|
||||
|
||||
result = handle_remediation_completed(db, test, admin_user)
|
||||
assert result is None # No retest created
|
||||
|
||||
def test_retest_chain_query(self, db, techniques, admin_user):
|
||||
"""Endpoint /tests/{id}/retest-chain retorna cadena completa."""
|
||||
original = Test(
|
||||
technique_id=techniques[0].id,
|
||||
name="Chain Original",
|
||||
state=TestState.validated,
|
||||
remediation_status="completed",
|
||||
created_by=admin_user.id,
|
||||
)
|
||||
db.add(original)
|
||||
db.commit()
|
||||
db.refresh(original)
|
||||
|
||||
retest1 = handle_remediation_completed(db, original, admin_user)
|
||||
assert retest1 is not None
|
||||
|
||||
# Complete retest1 and trigger another
|
||||
retest1.state = TestState.validated
|
||||
retest1.remediation_status = "completed"
|
||||
db.commit()
|
||||
db.refresh(retest1)
|
||||
|
||||
retest2 = handle_remediation_completed(db, retest1, admin_user)
|
||||
assert retest2 is not None
|
||||
|
||||
# Get chain
|
||||
chain = get_retest_chain(db, original.id)
|
||||
assert len(chain) == 3 # original + retest1 + retest2
|
||||
assert chain[0].id == original.id
|
||||
assert chain[1].retest_count == 1
|
||||
assert chain[2].retest_count == 2
|
||||
|
||||
def test_retest_has_correct_data(self, db, techniques, admin_user):
|
||||
"""Retest tiene mismos datos base que el original."""
|
||||
original = Test(
|
||||
technique_id=techniques[0].id,
|
||||
name="Data Check Original",
|
||||
description="Test description",
|
||||
platform="windows",
|
||||
procedure_text="Run cmd /c whoami",
|
||||
tool_used="cmd.exe",
|
||||
state=TestState.validated,
|
||||
remediation_status="completed",
|
||||
created_by=admin_user.id,
|
||||
)
|
||||
db.add(original)
|
||||
db.commit()
|
||||
db.refresh(original)
|
||||
|
||||
retest = handle_remediation_completed(db, original, admin_user)
|
||||
assert retest is not None
|
||||
|
||||
# Verify base data is copied
|
||||
assert retest.technique_id == original.technique_id
|
||||
assert retest.description == original.description
|
||||
assert retest.platform == original.platform
|
||||
assert retest.procedure_text == original.procedure_text
|
||||
assert retest.tool_used == original.tool_used
|
||||
assert retest.created_by == original.created_by
|
||||
assert retest.state == TestState.draft
|
||||
427
backend/tests/test_data_sources.py
Normal file
427
backend/tests/test_data_sources.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""Tests for data source import parsing — T-235.
|
||||
|
||||
Two levels:
|
||||
- TestDataSourcesParsing: Unit tests using local fixtures (fast, no network)
|
||||
- TestDataSourcesIntegration: Integration tests requiring network (pytest -m integration)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
FIXTURES = Path(__file__).parent / "fixtures"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — lightweight parsing functions extracted from import services
|
||||
# for testable, isolated verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_sigma_yaml(content: str) -> dict | None:
|
||||
"""Parse a Sigma YAML rule and extract relevant fields."""
|
||||
data = yaml.safe_load(content)
|
||||
if not data or not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
title = data.get("title")
|
||||
tags = data.get("tags", [])
|
||||
|
||||
# Extract MITRE technique IDs from tags
|
||||
mitre_ids = []
|
||||
for tag in tags:
|
||||
match = re.match(r"attack\.(t\d{4}(?:\.\d{3})?)", tag, re.IGNORECASE)
|
||||
if match:
|
||||
mitre_ids.append(match.group(1).upper())
|
||||
|
||||
if not title or not mitre_ids:
|
||||
return None
|
||||
|
||||
level = data.get("level", "medium")
|
||||
logsource = data.get("logsource", {})
|
||||
platforms = []
|
||||
product = logsource.get("product", "")
|
||||
if product:
|
||||
platforms.append(product)
|
||||
|
||||
return {
|
||||
"title": title,
|
||||
"description": data.get("description"),
|
||||
"mitre_ids": mitre_ids,
|
||||
"severity": level,
|
||||
"platforms": platforms,
|
||||
"false_positives": data.get("falsepositives", []),
|
||||
}
|
||||
|
||||
|
||||
def _parse_lolbas_yaml(content: str) -> list[dict]:
|
||||
"""Parse a LOLBAS YAML entry and extract templates."""
|
||||
data = yaml.safe_load(content)
|
||||
if not data or not isinstance(data, dict):
|
||||
return []
|
||||
|
||||
name = data.get("Name", "")
|
||||
commands = data.get("Commands", [])
|
||||
results = []
|
||||
|
||||
for cmd in commands:
|
||||
mitre_id = cmd.get("MitreID")
|
||||
if not mitre_id:
|
||||
continue
|
||||
results.append({
|
||||
"name": name,
|
||||
"mitre_id": mitre_id,
|
||||
"command": cmd.get("Command", ""),
|
||||
"description": cmd.get("Description", ""),
|
||||
"usecase": cmd.get("Usecase", ""),
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _parse_caldera_yaml(content: str) -> list[dict]:
|
||||
"""Parse a CALDERA multi-doc YAML and extract abilities."""
|
||||
docs = list(yaml.safe_load_all(content))
|
||||
results = []
|
||||
|
||||
for data in docs:
|
||||
if not data or not isinstance(data, dict):
|
||||
continue
|
||||
|
||||
technique = data.get("technique", {})
|
||||
attack_id = technique.get("attack_id")
|
||||
if not attack_id:
|
||||
continue
|
||||
|
||||
platforms_dict = data.get("platforms", {})
|
||||
platform_names = list(platforms_dict.keys())
|
||||
|
||||
# Extract commands
|
||||
commands = []
|
||||
for plat, executors in platforms_dict.items():
|
||||
if isinstance(executors, dict):
|
||||
for exec_name, exec_data in executors.items():
|
||||
if isinstance(exec_data, dict) and exec_data.get("command"):
|
||||
commands.append(exec_data["command"].strip())
|
||||
|
||||
results.append({
|
||||
"id": data.get("id"),
|
||||
"name": data.get("name"),
|
||||
"description": data.get("description"),
|
||||
"attack_id": attack_id,
|
||||
"tactic": data.get("tactic"),
|
||||
"platforms": platform_names,
|
||||
"commands": commands,
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _parse_elastic_toml(content: str) -> dict | None:
|
||||
"""Parse an Elastic detection rule TOML and extract fields."""
|
||||
try:
|
||||
import toml
|
||||
except ImportError:
|
||||
toml = None
|
||||
|
||||
if toml is None:
|
||||
# Fallback: parse manually enough for testing
|
||||
return None
|
||||
|
||||
data = toml.loads(content)
|
||||
rule = data.get("rule", {})
|
||||
if not rule:
|
||||
return None
|
||||
|
||||
name = rule.get("name")
|
||||
threat_list = rule.get("threat", [])
|
||||
|
||||
mitre_ids = []
|
||||
for threat_entry in threat_list:
|
||||
framework = threat_entry.get("framework", "")
|
||||
if "MITRE" not in framework:
|
||||
continue
|
||||
for tech in threat_entry.get("technique", []):
|
||||
tech_id = tech.get("id")
|
||||
if tech_id:
|
||||
mitre_ids.append(tech_id)
|
||||
for sub in tech.get("subtechnique", []):
|
||||
sub_id = sub.get("id")
|
||||
if sub_id:
|
||||
mitre_ids.append(sub_id)
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"description": rule.get("description"),
|
||||
"query": rule.get("query"),
|
||||
"severity": rule.get("severity"),
|
||||
"rule_type": rule.get("type"),
|
||||
"mitre_ids": mitre_ids,
|
||||
}
|
||||
|
||||
|
||||
def _parse_stix_bundle(content: str) -> dict:
|
||||
"""Parse a STIX 2.0 bundle and extract intrusion-sets and relationships."""
|
||||
data = json.loads(content)
|
||||
objects = data.get("objects", [])
|
||||
|
||||
intrusion_sets = []
|
||||
relationships = []
|
||||
attack_patterns = {}
|
||||
|
||||
for obj in objects:
|
||||
obj_type = obj.get("type")
|
||||
if obj_type == "intrusion-set":
|
||||
refs = obj.get("external_references", [])
|
||||
mitre_id = None
|
||||
for ref in refs:
|
||||
if ref.get("source_name") == "mitre-attack":
|
||||
mitre_id = ref.get("external_id")
|
||||
break
|
||||
intrusion_sets.append({
|
||||
"id": obj["id"],
|
||||
"name": obj.get("name"),
|
||||
"aliases": obj.get("aliases", []),
|
||||
"description": obj.get("description"),
|
||||
"mitre_id": mitre_id,
|
||||
})
|
||||
elif obj_type == "attack-pattern":
|
||||
refs = obj.get("external_references", [])
|
||||
for ref in refs:
|
||||
if ref.get("source_name") == "mitre-attack":
|
||||
attack_patterns[obj["id"]] = ref.get("external_id")
|
||||
elif obj_type == "relationship":
|
||||
if obj.get("relationship_type") == "uses":
|
||||
relationships.append({
|
||||
"source_ref": obj["source_ref"],
|
||||
"target_ref": obj["target_ref"],
|
||||
})
|
||||
|
||||
return {
|
||||
"intrusion_sets": intrusion_sets,
|
||||
"attack_patterns": attack_patterns,
|
||||
"relationships": relationships,
|
||||
}
|
||||
|
||||
|
||||
def _parse_d3fend_api_response(data: dict) -> list[dict]:
|
||||
"""Parse a mock D3FEND API response."""
|
||||
results = []
|
||||
|
||||
def _walk(node: dict | list, depth: int = 0):
|
||||
if isinstance(node, list):
|
||||
for item in node:
|
||||
_walk(item, depth)
|
||||
elif isinstance(node, dict):
|
||||
d3fend_id = node.get("@id", "")
|
||||
label = node.get("rdfs:label", "")
|
||||
|
||||
if d3fend_id.startswith("d3f:") and label:
|
||||
clean_id = d3fend_id.replace("d3f:", "")
|
||||
if clean_id.startswith("D3-"):
|
||||
definition = node.get("d3f:definition") or node.get("rdfs:comment", "")
|
||||
results.append({
|
||||
"d3fend_id": clean_id,
|
||||
"name": label,
|
||||
"description": definition,
|
||||
})
|
||||
|
||||
# Recurse
|
||||
for key, val in node.items():
|
||||
if isinstance(val, (dict, list)):
|
||||
_walk(val, depth + 1)
|
||||
|
||||
graph = data.get("@graph", data)
|
||||
_walk(graph)
|
||||
return results
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Unit tests — fast, no network
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestDataSourcesParsing:
|
||||
"""Tests unitarios — sin acceso a red, usando fixtures de YAML/TOML de ejemplo."""
|
||||
|
||||
def test_sigma_yaml_parsing(self):
|
||||
"""Parsear un YAML de Sigma de ejemplo y verificar extracción de campos."""
|
||||
content = (FIXTURES / "sample_sigma_rule.yml").read_text()
|
||||
result = _parse_sigma_yaml(content)
|
||||
|
||||
assert result is not None
|
||||
assert result["title"] == "Windows PowerShell Execution Policy Bypass"
|
||||
assert "T1059.001" in result["mitre_ids"]
|
||||
assert "T1562.001" in result["mitre_ids"]
|
||||
assert result["severity"] == "high"
|
||||
assert "windows" in result["platforms"]
|
||||
assert len(result["false_positives"]) == 2
|
||||
|
||||
def test_lolbas_yaml_parsing(self):
|
||||
"""Parsear un YAML de LOLBAS y verificar extracción de MitreID y commands."""
|
||||
content = (FIXTURES / "sample_lolbas_entry.yml").read_text()
|
||||
results = _parse_lolbas_yaml(content)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["name"] == "Mshta.exe"
|
||||
assert results[0]["mitre_id"] == "T1218.005"
|
||||
assert "mshta.exe" in results[0]["command"]
|
||||
assert results[1]["mitre_id"] == "T1059.005"
|
||||
|
||||
def test_caldera_yaml_parsing(self):
|
||||
"""Parsear un YAML de CALDERA ability y verificar campos."""
|
||||
content = (FIXTURES / "sample_caldera_ability.yml").read_text()
|
||||
results = _parse_caldera_yaml(content)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
sys_info = results[0]
|
||||
assert sys_info["name"] == "Get System Info"
|
||||
assert sys_info["attack_id"] == "T1082"
|
||||
assert sys_info["tactic"] == "discovery"
|
||||
assert "windows" in sys_info["platforms"]
|
||||
assert "linux" in sys_info["platforms"]
|
||||
assert len(sys_info["commands"]) > 0
|
||||
|
||||
net_conn = results[1]
|
||||
assert net_conn["attack_id"] == "T1049"
|
||||
assert net_conn["name"] == "List Network Connections"
|
||||
|
||||
def test_elastic_toml_parsing(self):
|
||||
"""Parsear un TOML de Elastic y verificar extracción de KQL y threat mappings."""
|
||||
content = (FIXTURES / "sample_elastic_rule.toml").read_text()
|
||||
|
||||
try:
|
||||
import toml # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("toml package not installed")
|
||||
|
||||
result = _parse_elastic_toml(content)
|
||||
|
||||
assert result is not None
|
||||
assert result["name"] == "Scheduled Task Created via Schtasks"
|
||||
assert result["severity"] == "medium"
|
||||
assert result["rule_type"] == "eql"
|
||||
assert "T1053" in result["mitre_ids"]
|
||||
assert "T1053.005" in result["mitre_ids"]
|
||||
assert "schtasks.exe" in result["query"]
|
||||
|
||||
def test_stix_threat_actor_parsing(self):
|
||||
"""Parsear un bundle STIX de ejemplo y verificar extracción de intrusion-sets y relationships."""
|
||||
content = (FIXTURES / "sample_stix_bundle.json").read_text()
|
||||
result = _parse_stix_bundle(content)
|
||||
|
||||
# Intrusion sets
|
||||
assert len(result["intrusion_sets"]) == 2
|
||||
apt1 = next(is_ for is_ in result["intrusion_sets"] if is_["name"] == "APT1")
|
||||
assert apt1["mitre_id"] == "G0006"
|
||||
assert "Comment Crew" in apt1["aliases"]
|
||||
|
||||
apt28 = next(is_ for is_ in result["intrusion_sets"] if is_["name"] == "APT28")
|
||||
assert apt28["mitre_id"] == "G0007"
|
||||
assert "Fancy Bear" in apt28["aliases"]
|
||||
|
||||
# Attack patterns
|
||||
assert len(result["attack_patterns"]) == 3
|
||||
assert "T1566" in result["attack_patterns"].values()
|
||||
assert "T1059" in result["attack_patterns"].values()
|
||||
|
||||
# Relationships
|
||||
assert len(result["relationships"]) == 4
|
||||
apt1_rels = [r for r in result["relationships"] if "apt1" in r["source_ref"]]
|
||||
assert len(apt1_rels) == 2
|
||||
|
||||
def test_d3fend_api_response_parsing(self):
|
||||
"""Parsear una respuesta mock de la API D3FEND."""
|
||||
mock_response = {
|
||||
"@graph": [
|
||||
{
|
||||
"@id": "d3f:D3-AL",
|
||||
"rdfs:label": "Application Layer",
|
||||
"d3f:definition": "Monitoring at the application layer.",
|
||||
},
|
||||
{
|
||||
"@id": "d3f:D3-NI",
|
||||
"rdfs:label": "Network Isolation",
|
||||
"rdfs:comment": "Isolating networks to prevent lateral movement.",
|
||||
},
|
||||
{
|
||||
"@id": "d3f:NotATechnique",
|
||||
"rdfs:label": "Something else",
|
||||
"d3f:definition": "Not a D3FEND technique.",
|
||||
},
|
||||
{
|
||||
"@id": "d3f:D3-DE",
|
||||
"rdfs:label": "Decoy Environment",
|
||||
"d3f:definition": "Using decoys to detect attackers.",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
results = _parse_d3fend_api_response(mock_response)
|
||||
|
||||
assert len(results) == 3 # Only D3- prefixed IDs
|
||||
ids = [r["d3fend_id"] for r in results]
|
||||
assert "D3-AL" in ids
|
||||
assert "D3-NI" in ids
|
||||
assert "D3-DE" in ids
|
||||
|
||||
ni = next(r for r in results if r["d3fend_id"] == "D3-NI")
|
||||
assert ni["name"] == "Network Isolation"
|
||||
assert "lateral movement" in ni["description"].lower()
|
||||
|
||||
def test_no_duplicates_on_reimport(self):
|
||||
"""Verificar que la lógica de deduplicación funciona con datos mock."""
|
||||
content = (FIXTURES / "sample_sigma_rule.yml").read_text()
|
||||
|
||||
# Parse twice
|
||||
result1 = _parse_sigma_yaml(content)
|
||||
result2 = _parse_sigma_yaml(content)
|
||||
|
||||
# Same data should produce identical output
|
||||
assert result1 == result2
|
||||
assert result1["title"] == result2["title"]
|
||||
assert result1["mitre_ids"] == result2["mitre_ids"]
|
||||
|
||||
# Simulate deduplication by title+mitre_id
|
||||
seen = set()
|
||||
unique_count = 0
|
||||
for r in [result1, result2]:
|
||||
key = (r["title"], tuple(r["mitre_ids"]))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_count += 1
|
||||
|
||||
assert unique_count == 1 # Only one unique entry
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Integration tests — require network. Run with: pytest -m integration
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestDataSourcesIntegration:
|
||||
"""Tests de integración — requieren acceso a red. Ejecutar con: pytest -m integration"""
|
||||
|
||||
def test_sigma_full_import(self):
|
||||
"""Importar desde GitHub real y verificar volumen."""
|
||||
# This test would clone SigmaHQ and parse all rules
|
||||
# Skipped in regular runs — requires network and significant time
|
||||
pytest.skip("Full Sigma import requires network access — run with pytest -m integration")
|
||||
|
||||
def test_lolbas_full_import(self):
|
||||
"""Importar LOLBAS completo."""
|
||||
pytest.skip("Full LOLBAS import requires network access — run with pytest -m integration")
|
||||
|
||||
def test_caldera_full_import(self):
|
||||
"""Importar CALDERA completo."""
|
||||
pytest.skip("Full CALDERA import requires network access — run with pytest -m integration")
|
||||
|
||||
def test_elastic_full_import(self):
|
||||
"""Importar Elastic rules completo."""
|
||||
pytest.skip("Full Elastic import requires network access — run with pytest -m integration")
|
||||
439
backend/tests/test_scoring_and_compliance.py
Normal file
439
backend/tests/test_scoring_and_compliance.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""Tests for scoring, operational metrics, and compliance — T-236.
|
||||
|
||||
Uses the in-memory SQLite test database from conftest.py to verify
|
||||
calculations with known data.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
|
||||
from app.models.audit import AuditLog
|
||||
from app.models.enums import TestState, TestResult, TechniqueStatus
|
||||
from app.services.scoring_service import (
|
||||
calculate_technique_score,
|
||||
calculate_tactic_score,
|
||||
calculate_organization_score,
|
||||
)
|
||||
from app.services.operational_metrics_service import (
|
||||
calculate_mttd,
|
||||
calculate_mttr,
|
||||
calculate_detection_efficacy,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_technique(db):
|
||||
"""Create a technique with known data."""
|
||||
tech = Technique(
|
||||
mitre_id="T1059",
|
||||
name="Command and Scripting Interpreter",
|
||||
tactic="execution",
|
||||
platforms=["windows", "linux", "macos"],
|
||||
status_global=TechniqueStatus.validated,
|
||||
)
|
||||
db.add(tech)
|
||||
db.commit()
|
||||
db.refresh(tech)
|
||||
return tech
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_technique_no_tests(db):
|
||||
"""Create a technique with no tests."""
|
||||
tech = Technique(
|
||||
mitre_id="T9999",
|
||||
name="No Tests Technique",
|
||||
tactic="discovery",
|
||||
platforms=["windows"],
|
||||
status_global=TechniqueStatus.not_evaluated,
|
||||
)
|
||||
db.add(tech)
|
||||
db.commit()
|
||||
db.refresh(tech)
|
||||
return tech
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def validated_tests(db, sample_technique, admin_user):
|
||||
"""Create multiple validated tests with detection results."""
|
||||
now = datetime.utcnow()
|
||||
tests = []
|
||||
|
||||
for i, result in enumerate([TestResult.detected, TestResult.detected, TestResult.not_detected]):
|
||||
test = Test(
|
||||
technique_id=sample_technique.id,
|
||||
name=f"Test {i+1} for T1059",
|
||||
state=TestState.validated,
|
||||
detection_result=result,
|
||||
created_by=admin_user.id,
|
||||
platform=["windows", "linux", "macos"][i % 3],
|
||||
red_validated_at=now - timedelta(days=i * 30),
|
||||
blue_validated_at=now - timedelta(days=i * 30),
|
||||
created_at=now - timedelta(days=i * 30 + 5),
|
||||
)
|
||||
db.add(test)
|
||||
tests.append(test)
|
||||
|
||||
db.commit()
|
||||
for t in tests:
|
||||
db.refresh(t)
|
||||
return tests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def compliance_setup(db, sample_technique, sample_technique_no_tests):
|
||||
"""Create a compliance framework with controls mapped to techniques."""
|
||||
framework = ComplianceFramework(
|
||||
name="NIST 800-53",
|
||||
version="5.0",
|
||||
description="NIST Special Publication 800-53",
|
||||
)
|
||||
db.add(framework)
|
||||
db.flush()
|
||||
|
||||
# Control 1: mapped to validated technique
|
||||
control1 = ComplianceControl(
|
||||
framework_id=framework.id,
|
||||
control_id="AC-2",
|
||||
title="Account Management",
|
||||
category="Access Control",
|
||||
)
|
||||
db.add(control1)
|
||||
db.flush()
|
||||
|
||||
mapping1 = ComplianceControlMapping(
|
||||
compliance_control_id=control1.id,
|
||||
technique_id=sample_technique.id,
|
||||
)
|
||||
db.add(mapping1)
|
||||
|
||||
# Control 2: mapped to technique with no tests
|
||||
control2 = ComplianceControl(
|
||||
framework_id=framework.id,
|
||||
control_id="SI-4",
|
||||
title="Information System Monitoring",
|
||||
category="System and Information Integrity",
|
||||
)
|
||||
db.add(control2)
|
||||
db.flush()
|
||||
|
||||
mapping2 = ComplianceControlMapping(
|
||||
compliance_control_id=control2.id,
|
||||
technique_id=sample_technique_no_tests.id,
|
||||
)
|
||||
db.add(mapping2)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"framework": framework,
|
||||
"control_covered": control1,
|
||||
"control_not_covered": control2,
|
||||
}
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Scoring Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestScoring:
|
||||
|
||||
def test_technique_score_all_detected(self, db, sample_technique, admin_user):
|
||||
"""Técnica con todos los tests detected → score alto."""
|
||||
now = datetime.utcnow()
|
||||
for i in range(3):
|
||||
test = Test(
|
||||
technique_id=sample_technique.id,
|
||||
name=f"All Detected {i}",
|
||||
state=TestState.validated,
|
||||
detection_result=TestResult.detected,
|
||||
created_by=admin_user.id,
|
||||
platform=["windows", "linux", "macos"][i],
|
||||
red_validated_at=now - timedelta(days=10),
|
||||
)
|
||||
db.add(test)
|
||||
db.commit()
|
||||
|
||||
result = calculate_technique_score(sample_technique, db)
|
||||
assert result["total_score"] > 0
|
||||
# Test component should be maxed out (all detected)
|
||||
assert result["breakdown"]["tests_validated"]["score"] > 0
|
||||
|
||||
def test_technique_score_no_tests(self, db, sample_technique_no_tests):
|
||||
"""Técnica sin tests → score 0."""
|
||||
result = calculate_technique_score(sample_technique_no_tests, db)
|
||||
assert result["total_score"] == 0
|
||||
|
||||
def test_technique_score_partial_detection(self, db, sample_technique, validated_tests):
|
||||
"""Técnica con detección parcial → score intermedio."""
|
||||
result = calculate_technique_score(sample_technique, db)
|
||||
# 2 detected out of 3 validated → partial score
|
||||
assert 0 < result["total_score"] < 100
|
||||
breakdown = result["breakdown"]
|
||||
assert "2/3" in breakdown["tests_validated"]["detail"]
|
||||
|
||||
def test_technique_score_freshness_penalty(self, db, sample_technique, admin_user):
|
||||
"""Tests > 180 días → penalización en freshness."""
|
||||
old_date = datetime.utcnow() - timedelta(days=200)
|
||||
test = Test(
|
||||
technique_id=sample_technique.id,
|
||||
name="Old Test",
|
||||
state=TestState.validated,
|
||||
detection_result=TestResult.detected,
|
||||
created_by=admin_user.id,
|
||||
platform="windows",
|
||||
red_validated_at=old_date,
|
||||
)
|
||||
db.add(test)
|
||||
db.commit()
|
||||
|
||||
result = calculate_technique_score(sample_technique, db)
|
||||
# Freshness should be 0 for tests > 180 days old
|
||||
assert result["breakdown"]["freshness"]["score"] == 0
|
||||
assert "200" in result["breakdown"]["freshness"]["detail"]
|
||||
|
||||
def test_scoring_weights_configurable(self, db, sample_technique, validated_tests):
|
||||
"""Cambiar pesos cambia el score resultante."""
|
||||
from app.config import settings
|
||||
|
||||
original_weight = settings.SCORING_WEIGHT_TESTS
|
||||
|
||||
score1 = calculate_technique_score(sample_technique, db)
|
||||
|
||||
# Change weight
|
||||
settings.SCORING_WEIGHT_TESTS = 80
|
||||
score2 = calculate_technique_score(sample_technique, db)
|
||||
|
||||
# Restore
|
||||
settings.SCORING_WEIGHT_TESTS = original_weight
|
||||
|
||||
# Different weights should produce different scores
|
||||
assert score1["total_score"] != score2["total_score"]
|
||||
|
||||
def test_organization_score_aggregation(self, db, sample_technique, validated_tests):
|
||||
"""Score global agrega correctamente los scores de técnicas."""
|
||||
result = calculate_organization_score(db)
|
||||
assert result["techniques_total"] >= 1
|
||||
assert result["overall_score"] >= 0
|
||||
assert result["techniques_evaluated"] >= 0
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Operational Metrics Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestOperationalMetrics:
|
||||
|
||||
def test_mttd_calculation(self, db, sample_technique, admin_user):
|
||||
"""MTTD se calcula desde timestamps del audit_log."""
|
||||
now = datetime.utcnow()
|
||||
test = Test(
|
||||
technique_id=sample_technique.id,
|
||||
name="MTTD Test",
|
||||
state=TestState.validated,
|
||||
created_by=admin_user.id,
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
|
||||
# Create audit log entries for state transitions
|
||||
start_log = AuditLog(
|
||||
user_id=admin_user.id,
|
||||
action="start_execution",
|
||||
entity_type="test",
|
||||
entity_id=str(test.id),
|
||||
timestamp=now - timedelta(hours=5),
|
||||
)
|
||||
submit_log = AuditLog(
|
||||
user_id=admin_user.id,
|
||||
action="submit_red",
|
||||
entity_type="test",
|
||||
entity_id=str(test.id),
|
||||
timestamp=now - timedelta(hours=2),
|
||||
)
|
||||
db.add(start_log)
|
||||
db.add(submit_log)
|
||||
db.commit()
|
||||
|
||||
result = calculate_mttd(db)
|
||||
# Should have data (3 hours between start and submit)
|
||||
if result is not None:
|
||||
assert result["sample_size"] >= 1
|
||||
assert result["mean_hours"] >= 0
|
||||
|
||||
def test_mttr_calculation(self, db, sample_technique, admin_user):
|
||||
"""MTTR incluye tiempo de remediación."""
|
||||
now = datetime.utcnow()
|
||||
test = Test(
|
||||
technique_id=sample_technique.id,
|
||||
name="MTTR Test",
|
||||
state=TestState.validated,
|
||||
remediation_status="completed",
|
||||
blue_validated_at=now - timedelta(hours=48),
|
||||
created_by=admin_user.id,
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
|
||||
# Audit log for remediation completion
|
||||
log = AuditLog(
|
||||
user_id=admin_user.id,
|
||||
action="update_remediation",
|
||||
entity_type="test",
|
||||
entity_id=str(test.id),
|
||||
timestamp=now - timedelta(hours=24),
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
|
||||
result = calculate_mttr(db)
|
||||
if result is not None:
|
||||
assert result["sample_size"] >= 1
|
||||
assert result["mean_hours"] > 0
|
||||
|
||||
def test_detection_efficacy(self, db, sample_technique, validated_tests):
|
||||
"""Detection efficacy con datos de prueba conocidos."""
|
||||
result = calculate_detection_efficacy(db)
|
||||
assert result["total"] == 3
|
||||
assert result["detected"] == 2
|
||||
assert result["not_detected"] == 1
|
||||
expected_pct = round((2 / 3) * 100, 1)
|
||||
assert result["percentage"] == expected_pct
|
||||
|
||||
def test_metrics_with_no_data(self, db):
|
||||
"""Métricas retornan null/cero cuando no hay datos suficientes."""
|
||||
mttd = calculate_mttd(db)
|
||||
mttr = calculate_mttr(db)
|
||||
efficacy = calculate_detection_efficacy(db)
|
||||
|
||||
assert mttd is None
|
||||
assert mttr is None
|
||||
assert efficacy["total"] == 0
|
||||
assert efficacy["percentage"] == 0
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Compliance Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestCompliance:
|
||||
|
||||
def test_control_fully_covered(self, db, sample_technique, validated_tests, compliance_setup):
|
||||
"""Control con todas las técnicas validated → covered."""
|
||||
control = compliance_setup["control_covered"]
|
||||
mappings = (
|
||||
db.query(ComplianceControlMapping)
|
||||
.filter(ComplianceControlMapping.compliance_control_id == control.id)
|
||||
.all()
|
||||
)
|
||||
assert len(mappings) == 1
|
||||
|
||||
# The mapped technique has validated tests
|
||||
technique = mappings[0].technique
|
||||
assert technique.status_global == TechniqueStatus.validated
|
||||
|
||||
def test_control_not_covered(self, db, compliance_setup):
|
||||
"""Control con todas las técnicas sin tests → not_covered."""
|
||||
control = compliance_setup["control_not_covered"]
|
||||
mappings = (
|
||||
db.query(ComplianceControlMapping)
|
||||
.filter(ComplianceControlMapping.compliance_control_id == control.id)
|
||||
.all()
|
||||
)
|
||||
assert len(mappings) == 1
|
||||
|
||||
technique = mappings[0].technique
|
||||
assert technique.status_global == TechniqueStatus.not_evaluated
|
||||
|
||||
def test_control_partially_covered(self, db, sample_technique, sample_technique_no_tests, admin_user, compliance_setup):
|
||||
"""Control con técnicas mixtas → partially_covered."""
|
||||
control = compliance_setup["control_covered"]
|
||||
|
||||
# Add second mapping to the not-evaluated technique
|
||||
mapping = ComplianceControlMapping(
|
||||
compliance_control_id=control.id,
|
||||
technique_id=sample_technique_no_tests.id,
|
||||
)
|
||||
db.add(mapping)
|
||||
db.commit()
|
||||
|
||||
# Now this control has two techniques: one validated, one not_evaluated
|
||||
mappings = (
|
||||
db.query(ComplianceControlMapping)
|
||||
.filter(ComplianceControlMapping.compliance_control_id == control.id)
|
||||
.all()
|
||||
)
|
||||
assert len(mappings) == 2
|
||||
|
||||
statuses = [m.technique.status_global for m in mappings]
|
||||
assert TechniqueStatus.validated in statuses
|
||||
assert TechniqueStatus.not_evaluated in statuses
|
||||
|
||||
def test_compliance_percentage(self, db, sample_technique, validated_tests, compliance_setup):
|
||||
"""Porcentaje global de compliance calculado correctamente."""
|
||||
framework = compliance_setup["framework"]
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
.all()
|
||||
)
|
||||
assert len(controls) == 2
|
||||
|
||||
covered = 0
|
||||
total = len(controls)
|
||||
for control in controls:
|
||||
mappings = control.technique_mappings
|
||||
if all(
|
||||
m.technique.status_global in (TechniqueStatus.validated, TechniqueStatus.partial)
|
||||
for m in mappings
|
||||
):
|
||||
covered += 1
|
||||
|
||||
pct = round((covered / total) * 100, 1)
|
||||
assert pct == 50.0 # 1 out of 2 controls covered
|
||||
|
||||
def test_compliance_gaps(self, db, compliance_setup):
|
||||
"""Gaps retorna solo controles no cubiertos con sus técnicas."""
|
||||
framework = compliance_setup["framework"]
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
gaps = []
|
||||
for control in controls:
|
||||
mappings = control.technique_mappings
|
||||
uncovered_techniques = [
|
||||
m.technique
|
||||
for m in mappings
|
||||
if m.technique.status_global in (TechniqueStatus.not_evaluated, TechniqueStatus.not_covered)
|
||||
]
|
||||
if uncovered_techniques:
|
||||
gaps.append({
|
||||
"control_id": control.control_id,
|
||||
"title": control.title,
|
||||
"uncovered_techniques": [t.mitre_id for t in uncovered_techniques],
|
||||
})
|
||||
|
||||
assert len(gaps) >= 1
|
||||
si4_gap = next((g for g in gaps if g["control_id"] == "SI-4"), None)
|
||||
assert si4_gap is not None
|
||||
assert "T9999" in si4_gap["uncovered_techniques"]
|
||||
Reference in New Issue
Block a user