Compare commits
181 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 102370897f | |||
| a43d73cac8 | |||
| 5f8a196df3 | |||
| 7703c36ed7 | |||
| 72983a022b | |||
| 802e8f862b | |||
| 93b4a700e6 | |||
| cbaa0deedd | |||
| a4cdc06534 | |||
| e3e79be35a | |||
| cfc48ccd2b | |||
| 421b786953 | |||
| 14a56a6057 | |||
| 20cdb70f57 | |||
| 05898e2cee | |||
| 15eda30b75 | |||
| 019924f78c | |||
| 910c198545 | |||
| 131817cc81 | |||
| baac07d43a | |||
| 4c230caa32 | |||
| f8418bc7ea | |||
| 498536f3f1 | |||
| bea5a8e781 | |||
| c62dafbc1f | |||
| 3db9809be5 | |||
| 7c6aaeda30 | |||
| 1dcff4ad20 | |||
| 0b82d96bcc | |||
| 460faf9935 | |||
| 02ff89401c | |||
| 4e20bfa835 | |||
| 46ff79e695 | |||
| 9f86c205be | |||
| 61e6037e97 | |||
| 2de95a3082 | |||
| 74ca8dc53a | |||
| ad5cd26363 | |||
| fc3b413a83 | |||
| 9f1c4c28c9 | |||
| ea8c48755f | |||
| 5684484fdf | |||
| 06e8effaa4 | |||
| 56d49f6de7 | |||
| 688e843e03 | |||
| e03a222ab0 | |||
| f53500bcb5 | |||
| 9e36b683fa | |||
| b33562a34e | |||
| 757d99d22a | |||
| d896f2761d | |||
| 2bbc65993c | |||
| 46722aec19 | |||
| eee0560aeb | |||
| 922fb251da | |||
| b4a264f2bd | |||
| 2b41b191bd | |||
| a518c06653 | |||
| 61e705ece4 | |||
| 2bfcc7e58c | |||
| 7e4a44bbde | |||
| ba75baeb7d | |||
| 71141d9901 | |||
| 646ac7146e | |||
| 0d4c105aa3 | |||
| a566834e08 | |||
| 51c506a86d | |||
| b98a539d93 | |||
| 65c34c3374 | |||
| 2f1ef7545d | |||
| b39a4fec14 | |||
| 07c6164ceb | |||
| f590a00006 | |||
| 8a542f912d | |||
| e49eca0b24 | |||
| 7d856bef43 | |||
| 70b5c833d4 | |||
| 6c8a1317fd | |||
| 9310652944 | |||
| 193c48d031 | |||
| 416b31a5b6 | |||
| 843b545df3 | |||
| 2238ca671b | |||
| e9aa473a6b | |||
| bd0493aade | |||
| d7d11dfdf5 | |||
| 1b513b050e | |||
| 727b8af7fd | |||
| c467459b51 | |||
| b19ecc0d5f | |||
| 2910aea6b2 | |||
| 20075305a5 | |||
| 4881825fea | |||
| de093778f6 | |||
| 34340a67eb | |||
| db208b9f5c | |||
| a8542512b4 | |||
| 1120d8f2ce | |||
| 2eed763f9e | |||
| 2865846db2 | |||
| 8b035b5c5c | |||
| b248c2816e | |||
| fa8e7f311b | |||
| 2371318e9e | |||
| 8024f32954 | |||
| 45b13bccde | |||
| 2e5b47a4a2 | |||
| 664210be3d | |||
| d3baa9c032 | |||
| 986e91a88a | |||
| cf5332f522 | |||
| 2ee74bf6c9 | |||
| 0830b36cd6 | |||
| e623a0887d | |||
| 0955f35015 | |||
| 7111debd8f | |||
| c886b6e8bb | |||
| d8a0b0c449 | |||
| 27184627f8 | |||
| 323964ed9d | |||
| eeee17d260 | |||
| 43c8b241dc | |||
| 398e279116 | |||
| 0e6cec4d07 | |||
| 44ef4129a5 | |||
| bd0586d296 | |||
| 84a6590e17 | |||
| 69d92f500a | |||
| 2337abe55e | |||
| 4a64ac1c8b | |||
| f17f0a8c10 | |||
| 5f6a098e6b | |||
| a04d5308ab | |||
| 48a936d426 | |||
| 513a7b488b | |||
| fd4a625760 | |||
| 217c4c88b2 | |||
| f316a249cc | |||
| 2675a4b7c2 | |||
| c780ad1e78 | |||
| 8bed3abc08 | |||
| c45eed2801 | |||
| cba9bfbab9 | |||
| 43ef4ea6a0 | |||
| 6f4901b611 | |||
| f36c633d16 | |||
| fc16675cf2 | |||
| d05aa94a01 | |||
| 97349a1d13 | |||
| cfbf6a6ede | |||
| d4b147da7c | |||
| d81fc04b8f | |||
| ab591d30c4 | |||
| 41a0c536bb | |||
| 7fae4783a2 | |||
| 084ea4c0b2 | |||
| 362a17aa1b | |||
| 0febbc67f1 | |||
| 852adb6e4d | |||
| 4fba4152d9 | |||
| 9546ef8bc8 | |||
| e550ebb30f | |||
| 5e18db48d3 | |||
| 4f5370db89 | |||
| 080ce56de7 | |||
| 4ece2293ec | |||
| f97b9e96b7 | |||
| 36fe4aa250 | |||
| a8b4518485 | |||
| 89a951c2a2 | |||
| 9a020f97ef | |||
| 1fe150963c | |||
| 0e1b8e2b39 | |||
| 93ebcf2b86 | |||
| c1e06d4c0a | |||
| d6df7fdc09 | |||
| 7312f9664b | |||
| 63da22b77e | |||
| fd476ce460 | |||
| 60183f704c | |||
| 2495423790 |
@@ -7,6 +7,10 @@ RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
libpq-dev \
|
||||
curl \
|
||||
pkg-config \
|
||||
libxml2-dev \
|
||||
libxmlsec1-dev \
|
||||
libxmlsec1-openssl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements first for better caching
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Phase 6.1: webhook_configs table.
|
||||
|
||||
Revision ID: b031phase6
|
||||
Revises: b030phase5
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b031phase6"
|
||||
down_revision: Union[str, None] = "b030phase5"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"webhook_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("url", sa.Text, nullable=False),
|
||||
sa.Column("secret", sa.String(256), nullable=True),
|
||||
sa.Column("events", postgresql.JSONB, nullable=False, server_default="[]"),
|
||||
sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"),
|
||||
sa.Column("created_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("last_triggered_at", sa.DateTime, nullable=True),
|
||||
sa.Column("failure_count", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("webhook_configs")
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Phase 7.2: user notification_preferences and jira_account_id columns.
|
||||
|
||||
Revision ID: b032phase7
|
||||
Revises: b031phase6
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b032phase7"
|
||||
down_revision: Union[str, None] = "b031phase6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_DEFAULT_PREFS = '{"email_on_test_validated": true, "email_on_campaign_completed": true, "email_on_new_mitre_techniques": false, "in_app_all": true}'
|
||||
|
||||
def _column_names(table: str) -> set[str]:
|
||||
bind = op.get_bind()
|
||||
insp = sa.inspect(bind)
|
||||
return {c["name"] for c in insp.get_columns(table)}
|
||||
|
||||
def upgrade() -> None:
|
||||
user_cols = _column_names("users")
|
||||
if "notification_preferences" not in user_cols:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("notification_preferences", postgresql.JSONB, nullable=True, server_default=_DEFAULT_PREFS),
|
||||
)
|
||||
if "jira_account_id" not in user_cols:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("jira_account_id", sa.String(100), nullable=True),
|
||||
)
|
||||
|
||||
def downgrade() -> None:
|
||||
user_cols = _column_names("users")
|
||||
if "jira_account_id" in user_cols:
|
||||
op.drop_column("users", "jira_account_id")
|
||||
if "notification_preferences" in user_cols:
|
||||
op.drop_column("users", "notification_preferences")
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Phase 8: system_configs table for runtime configuration.
|
||||
|
||||
Revision ID: b033syscfg
|
||||
Revises: b032phase7
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b033syscfg"
|
||||
down_revision: Union[str, None] = "b032phase7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _table_exists(name: str) -> bool:
|
||||
bind = op.get_bind()
|
||||
insp = sa.inspect(bind)
|
||||
return name in insp.get_table_names()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if not _table_exists("system_configs"):
|
||||
op.create_table(
|
||||
"system_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("key", sa.String(200), unique=True, nullable=False),
|
||||
sa.Column("value", sa.Text, nullable=True),
|
||||
sa.Column("description", sa.String(500), nullable=True),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
)
|
||||
op.create_index("ix_system_configs_key", "system_configs", ["key"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if _table_exists("system_configs"):
|
||||
op.drop_index("ix_system_configs_key", table_name="system_configs")
|
||||
op.drop_table("system_configs")
|
||||
@@ -0,0 +1,174 @@
|
||||
"""Phase 8: Detection Lifecycle Management tables.
|
||||
|
||||
Revision ID: b034dlm
|
||||
Revises: b033syscfg
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b034dlm"
|
||||
down_revision: Union[str, None] = "b033syscfg"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _table_exists(name: str) -> bool:
|
||||
bind = op.get_bind()
|
||||
insp = sa.inspect(bind)
|
||||
return name in insp.get_table_names()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if not _table_exists("detection_assets"):
|
||||
op.create_table(
|
||||
"detection_assets",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("name", sa.String(500), nullable=False),
|
||||
sa.Column("description", sa.Text),
|
||||
sa.Column("asset_type", sa.String(50), nullable=False),
|
||||
sa.Column("platform", sa.String(100)),
|
||||
sa.Column("rule_content", sa.Text),
|
||||
sa.Column("rule_language", sa.String(50)),
|
||||
sa.Column("rule_repository_url", sa.Text),
|
||||
sa.Column("rule_file_path", sa.String(500)),
|
||||
sa.Column("rule_version", sa.String(50)),
|
||||
sa.Column("rule_hash", sa.String(64)),
|
||||
sa.Column("last_rule_change_at", sa.DateTime),
|
||||
sa.Column("log_source_name", sa.String(200)),
|
||||
sa.Column("log_source_version", sa.String(50)),
|
||||
sa.Column("log_source_config", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("infrastructure_hash", sa.String(64)),
|
||||
sa.Column("infrastructure_details", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("health_status", sa.String(20), server_default="untested", nullable=False),
|
||||
sa.Column("last_alert_at", sa.DateTime),
|
||||
sa.Column("alert_count_30d", sa.Integer, server_default="0"),
|
||||
sa.Column("false_positive_rate", sa.Float),
|
||||
sa.Column("expected_alert_frequency", sa.String(50)),
|
||||
sa.Column("owner_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("backup_owner_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("team", sa.String(100)),
|
||||
sa.Column("is_active", sa.Boolean, server_default="true", nullable=False),
|
||||
sa.Column("tags", postgresql.JSONB, server_default="[]"),
|
||||
sa.Column("asset_metadata", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("created_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
op.create_index("ix_detection_assets_platform", "detection_assets", ["platform"])
|
||||
op.create_index("ix_detection_assets_health_status", "detection_assets", ["health_status"])
|
||||
op.create_index("ix_detection_assets_owner_id", "detection_assets", ["owner_id"])
|
||||
|
||||
if not _table_exists("detection_technique_mappings"):
|
||||
op.create_table(
|
||||
"detection_technique_mappings",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("detection_asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("coverage_type", sa.String(50), server_default="detect"),
|
||||
sa.Column("confidence_level", sa.String(20), server_default="medium"),
|
||||
sa.Column("notes", sa.Text),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
op.create_index("ix_detection_technique_mappings_technique_id", "detection_technique_mappings", ["technique_id"])
|
||||
op.create_index("ix_detection_technique_mappings_asset_id", "detection_technique_mappings", ["detection_asset_id"])
|
||||
|
||||
if not _table_exists("detection_validations"):
|
||||
op.create_table(
|
||||
"detection_validations",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("detection_asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="SET NULL")),
|
||||
sa.Column("test_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("tests.id", ondelete="SET NULL")),
|
||||
sa.Column("validated_at", sa.DateTime),
|
||||
sa.Column("expires_at", sa.DateTime, nullable=False),
|
||||
sa.Column("is_valid", sa.Boolean, server_default="true", nullable=False),
|
||||
sa.Column("validation_result", sa.String(50)),
|
||||
sa.Column("validation_method", sa.String(100)),
|
||||
sa.Column("rule_hash_at_validation", sa.String(64)),
|
||||
sa.Column("log_source_version_at_validation", sa.String(50)),
|
||||
sa.Column("infrastructure_hash_at_validation", sa.String(64)),
|
||||
sa.Column("environment_snapshot", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("invalidated_at", sa.DateTime),
|
||||
sa.Column("invalidation_reason", sa.String(50)),
|
||||
sa.Column("invalidation_details", sa.Text),
|
||||
sa.Column("invalidated_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("validated_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=False),
|
||||
sa.Column("integrity_hash", sa.String(64)),
|
||||
sa.Column("notes", sa.Text),
|
||||
sa.Column("evidence_ids", postgresql.JSONB, server_default="[]"),
|
||||
)
|
||||
op.create_index("ix_detection_validations_asset_id_valid", "detection_validations", ["detection_asset_id", "is_valid"])
|
||||
op.create_index("ix_detection_validations_expires_at", "detection_validations", ["expires_at"])
|
||||
|
||||
if not _table_exists("technique_confidence_scores"):
|
||||
op.create_table(
|
||||
"technique_confidence_scores",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False, unique=True),
|
||||
sa.Column("confidence_level", sa.String(20), server_default="unknown"),
|
||||
sa.Column("confidence_score", sa.Float, server_default="0.0"),
|
||||
sa.Column("detection_count", sa.Integer, server_default="0"),
|
||||
sa.Column("valid_detection_count", sa.Integer, server_default="0"),
|
||||
sa.Column("last_validated_at", sa.DateTime),
|
||||
sa.Column("next_validation_due", sa.DateTime),
|
||||
sa.Column("last_recalculated_at", sa.DateTime),
|
||||
sa.Column("recency_factor", sa.Float, server_default="0.0"),
|
||||
sa.Column("coverage_factor", sa.Float, server_default="0.0"),
|
||||
sa.Column("health_factor", sa.Float, server_default="0.0"),
|
||||
sa.Column("diversity_factor", sa.Float, server_default="0.0"),
|
||||
sa.Column("score_breakdown", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("risk_factors", postgresql.JSONB, server_default="[]"),
|
||||
sa.Column("updated_at", sa.DateTime),
|
||||
)
|
||||
op.create_index("ix_technique_confidence_scores_technique_id", "technique_confidence_scores", ["technique_id"])
|
||||
op.create_index("ix_technique_confidence_scores_confidence_level", "technique_confidence_scores", ["confidence_level"])
|
||||
|
||||
if not _table_exists("infrastructure_change_logs"):
|
||||
op.create_table(
|
||||
"infrastructure_change_logs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("change_type", sa.String(100), nullable=False),
|
||||
sa.Column("description", sa.Text, nullable=False),
|
||||
sa.Column("affected_platforms", postgresql.JSONB, server_default="[]"),
|
||||
sa.Column("affected_log_sources", postgresql.JSONB, server_default="[]"),
|
||||
sa.Column("change_date", sa.DateTime),
|
||||
sa.Column("reported_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("auto_invalidate", sa.Boolean, server_default="true"),
|
||||
sa.Column("invalidated_count", sa.Integer, server_default="0"),
|
||||
sa.Column("change_metadata", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("created_at", sa.DateTime),
|
||||
)
|
||||
op.create_index("ix_infrastructure_change_logs_change_date", "infrastructure_change_logs", ["change_date"])
|
||||
|
||||
if not _table_exists("decay_policies"):
|
||||
op.create_table(
|
||||
"decay_policies",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("description", sa.Text),
|
||||
sa.Column("applies_to_platform", sa.String(100)),
|
||||
sa.Column("applies_to_asset_type", sa.String(50)),
|
||||
sa.Column("applies_to_tactic", sa.String(100)),
|
||||
sa.Column("fresh_days", sa.Integer, server_default="90"),
|
||||
sa.Column("aging_days", sa.Integer, server_default="180"),
|
||||
sa.Column("stale_days", sa.Integer, server_default="365"),
|
||||
sa.Column("default_validity_days", sa.Integer, server_default="180"),
|
||||
sa.Column("silent_threshold_days", sa.Integer, server_default="30"),
|
||||
sa.Column("noisy_threshold_daily", sa.Integer, server_default="100"),
|
||||
sa.Column("recency_weight", sa.Float, server_default="0.3"),
|
||||
sa.Column("coverage_weight", sa.Float, server_default="0.3"),
|
||||
sa.Column("health_weight", sa.Float, server_default="0.25"),
|
||||
sa.Column("diversity_weight", sa.Float, server_default="0.15"),
|
||||
sa.Column("is_default", sa.Boolean, server_default="false"),
|
||||
sa.Column("is_active", sa.Boolean, server_default="true"),
|
||||
sa.Column("created_at", sa.DateTime),
|
||||
sa.Column("updated_at", sa.DateTime),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ["decay_policies", "infrastructure_change_logs", "technique_confidence_scores", "detection_validations", "detection_technique_mappings", "detection_assets"]:
|
||||
if _table_exists(table):
|
||||
op.drop_table(table)
|
||||
@@ -0,0 +1,118 @@
|
||||
"""Phase 9: Ownership & Revalidation Queue
|
||||
|
||||
Revision ID: b035ownerq
|
||||
Revises: b034dlm
|
||||
Create Date: 2026-05-19
|
||||
|
||||
Uses raw SQL for all DDL to avoid SQLAlchemy before_create hook issues
|
||||
with existing enum types.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "b035ownerq"
|
||||
down_revision: Union[str, None] = "b034dlm"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── Enums (idempotent) ────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE queue_priority AS ENUM ('critical', 'high', 'medium', 'low');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE queue_status AS ENUM ('pending', 'in_progress', 'completed', 'dismissed');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE queue_reason AS ENUM (
|
||||
'validation_expired', 'infra_change', 'osint_alert',
|
||||
'mitre_update', 'rule_modified', 'low_confidence', 'manual');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$
|
||||
"""))
|
||||
|
||||
# ── technique_ownerships ──────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS technique_ownerships (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
technique_id UUID NOT NULL UNIQUE
|
||||
REFERENCES techniques(id) ON DELETE CASCADE,
|
||||
owner_id UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
backup_owner_id UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
team VARCHAR(200),
|
||||
notes TEXT,
|
||||
assigned_at TIMESTAMP,
|
||||
assigned_by UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_techown_owner_id ON technique_ownerships (owner_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_techown_technique_id ON technique_ownerships (technique_id)"
|
||||
))
|
||||
|
||||
# ── revalidation_queue_items ──────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS revalidation_queue_items (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
technique_id UUID
|
||||
REFERENCES techniques(id) ON DELETE CASCADE,
|
||||
detection_asset_id UUID
|
||||
REFERENCES detection_assets(id) ON DELETE CASCADE,
|
||||
priority queue_priority NOT NULL DEFAULT 'medium',
|
||||
reason queue_reason NOT NULL,
|
||||
reason_detail TEXT,
|
||||
status queue_status NOT NULL DEFAULT 'pending',
|
||||
assigned_to UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
due_date TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
completed_at TIMESTAMP,
|
||||
dismissed_at TIMESTAMP,
|
||||
completed_by UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
extra JSONB
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_status ON revalidation_queue_items (status)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_priority ON revalidation_queue_items (priority)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_assigned_to ON revalidation_queue_items (assigned_to)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_technique_id ON revalidation_queue_items (technique_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_asset_id ON revalidation_queue_items (detection_asset_id)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS revalidation_queue_items"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS technique_ownerships"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS queue_reason"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS queue_status"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS queue_priority"))
|
||||
@@ -0,0 +1,184 @@
|
||||
"""Phase 10: Attack Paths & Advanced Purple Team
|
||||
|
||||
Revision ID: b036atk
|
||||
Revises: b035ownerq
|
||||
Create Date: 2026-05-19
|
||||
|
||||
Uses raw SQL to avoid SQLAlchemy DDL hook issues with enum types.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "b036atk"
|
||||
down_revision: Union[str, None] = "b035ownerq"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── Enums ─────────────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE execution_status AS ENUM
|
||||
('planned','in_progress','completed','aborted');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL; END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE step_result_status AS ENUM
|
||||
('pending','executing','detected','not_detected','skipped');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL; END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE timeline_actor_side AS ENUM ('red','blue','system');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL; END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE timeline_entry_type AS ENUM
|
||||
('action','detection','note','phase_transition','flag');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL; END $$
|
||||
"""))
|
||||
|
||||
# ── attack_paths ──────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_paths (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(300) NOT NULL,
|
||||
description TEXT,
|
||||
objective TEXT,
|
||||
is_template BOOLEAN DEFAULT FALSE,
|
||||
threat_actor_id UUID REFERENCES threat_actors(id) ON DELETE SET NULL,
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
tags JSONB,
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_attack_paths_created_by ON attack_paths (created_by)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_attack_paths_is_template ON attack_paths (is_template)"
|
||||
))
|
||||
|
||||
# ── attack_path_steps ─────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_path_steps (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
attack_path_id UUID NOT NULL REFERENCES attack_paths(id) ON DELETE CASCADE,
|
||||
order_index INTEGER NOT NULL DEFAULT 0,
|
||||
kill_chain_phase VARCHAR(60),
|
||||
technique_id UUID REFERENCES techniques(id) ON DELETE SET NULL,
|
||||
test_id UUID REFERENCES tests(id) ON DELETE SET NULL,
|
||||
name VARCHAR(300),
|
||||
description TEXT,
|
||||
expected_detection BOOLEAN DEFAULT TRUE,
|
||||
notes TEXT
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_steps_path_id ON attack_path_steps (attack_path_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_steps_technique_id ON attack_path_steps (technique_id)"
|
||||
))
|
||||
|
||||
# ── attack_path_executions ────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_path_executions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
attack_path_id UUID NOT NULL REFERENCES attack_paths(id) ON DELETE CASCADE,
|
||||
status execution_status NOT NULL DEFAULT 'planned',
|
||||
environment VARCHAR(100),
|
||||
red_team_lead UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
blue_team_lead UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
started_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
notes TEXT,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
-- kill-chain metrics (populated on completion)
|
||||
total_steps INTEGER,
|
||||
detected_steps INTEGER,
|
||||
not_detected_steps INTEGER,
|
||||
skipped_steps INTEGER,
|
||||
detection_rate FLOAT,
|
||||
mttd_seconds FLOAT,
|
||||
furthest_undetected_step INTEGER
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_exec_path_id ON attack_path_executions (attack_path_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_exec_status ON attack_path_executions (status)"
|
||||
))
|
||||
|
||||
# ── attack_path_step_results ──────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_path_step_results (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
execution_id UUID NOT NULL
|
||||
REFERENCES attack_path_executions(id) ON DELETE CASCADE,
|
||||
step_id UUID NOT NULL
|
||||
REFERENCES attack_path_steps(id) ON DELETE CASCADE,
|
||||
step_order INTEGER NOT NULL DEFAULT 0,
|
||||
status step_result_status NOT NULL DEFAULT 'pending',
|
||||
executed_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
executed_at TIMESTAMP,
|
||||
detected_at TIMESTAMP,
|
||||
time_to_detect_seconds FLOAT,
|
||||
detection_asset_id UUID
|
||||
REFERENCES detection_assets(id) ON DELETE SET NULL,
|
||||
notes TEXT,
|
||||
evidence_ids JSONB
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_stepres_exec ON attack_path_step_results (execution_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_stepres_step ON attack_path_step_results (step_id)"
|
||||
))
|
||||
|
||||
# ── attack_path_timeline ──────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_path_timeline (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
execution_id UUID NOT NULL
|
||||
REFERENCES attack_path_executions(id) ON DELETE CASCADE,
|
||||
step_id UUID REFERENCES attack_path_steps(id) ON DELETE SET NULL,
|
||||
timestamp TIMESTAMP NOT NULL DEFAULT now(),
|
||||
actor_side timeline_actor_side NOT NULL,
|
||||
actor_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
entry_type timeline_entry_type NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
extra JSONB
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_timeline_execution_id ON attack_path_timeline (execution_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_timeline_timestamp ON attack_path_timeline (timestamp)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_timeline"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_step_results"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_executions"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_steps"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_paths"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS timeline_entry_type"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS timeline_actor_side"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS step_result_status"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS execution_status"))
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Phase 11: Knowledge Management — Playbooks + Lessons Learned
|
||||
|
||||
Revision ID: b037know
|
||||
Revises: b036atk
|
||||
Create Date: 2026-05-20
|
||||
|
||||
Uses raw SQL to bypass SQLAlchemy DDL hooks (no enum types — string columns
|
||||
with Pydantic-layer validation instead, so no PostgreSQL enums needed).
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "b037know"
|
||||
down_revision: Union[str, None] = "b036atk"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── playbooks ──────────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS playbooks (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
technique_id UUID NOT NULL REFERENCES techniques(id) ON DELETE CASCADE,
|
||||
playbook_type VARCHAR(32) NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
version INTEGER NOT NULL DEFAULT 1,
|
||||
tools JSONB,
|
||||
prerequisites JSONB,
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
updated_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now(),
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
CONSTRAINT uq_playbook_technique_type UNIQUE (technique_id, playbook_type)
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_playbooks_technique_id ON playbooks (technique_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_playbooks_type ON playbooks (playbook_type)"
|
||||
))
|
||||
|
||||
# ── playbook_versions ──────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS playbook_versions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
playbook_id UUID NOT NULL REFERENCES playbooks(id) ON DELETE CASCADE,
|
||||
version INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
tools JSONB,
|
||||
prerequisites JSONB,
|
||||
changed_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
change_note VARCHAR(500),
|
||||
created_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_pb_versions_playbook_id ON playbook_versions (playbook_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_pb_versions_version ON playbook_versions (playbook_id, version)"
|
||||
))
|
||||
|
||||
# ── lessons_learned ────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS lessons_learned (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
title VARCHAR(255) NOT NULL,
|
||||
what_happened TEXT NOT NULL DEFAULT '',
|
||||
root_cause TEXT NOT NULL DEFAULT '',
|
||||
fix_applied TEXT,
|
||||
severity VARCHAR(16) NOT NULL DEFAULT 'medium',
|
||||
entity_type VARCHAR(32) NOT NULL DEFAULT 'manual',
|
||||
entity_id UUID,
|
||||
technique_ids JSONB,
|
||||
tags JSONB,
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now(),
|
||||
is_active BOOLEAN DEFAULT TRUE
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ll_entity ON lessons_learned (entity_type, entity_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ll_severity ON lessons_learned (severity)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ll_created_by ON lessons_learned (created_by)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS lessons_learned"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS playbook_versions"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS playbooks"))
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Phase 12: Risk Intelligence — technique_risk_profiles table
|
||||
|
||||
Revision ID: b038risk
|
||||
Revises: b037know
|
||||
Create Date: 2026-05-20
|
||||
|
||||
Uses raw SQL to bypass SQLAlchemy DDL hooks.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "b038risk"
|
||||
down_revision: Union[str, None] = "b037know"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS technique_risk_profiles (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
technique_id UUID NOT NULL REFERENCES techniques(id) ON DELETE CASCADE,
|
||||
risk_score FLOAT NOT NULL DEFAULT 0.0,
|
||||
likelihood FLOAT NOT NULL DEFAULT 0.0,
|
||||
impact FLOAT NOT NULL DEFAULT 0.0,
|
||||
risk_level VARCHAR(16) NOT NULL DEFAULT 'info',
|
||||
detection_gap FLOAT NOT NULL DEFAULT 1.0,
|
||||
threat_actor_count INTEGER NOT NULL DEFAULT 0,
|
||||
osint_signal_count INTEGER NOT NULL DEFAULT 0,
|
||||
test_fail_count INTEGER NOT NULL DEFAULT 0,
|
||||
test_total_count INTEGER NOT NULL DEFAULT 0,
|
||||
test_failure_rate FLOAT NOT NULL DEFAULT 0.0,
|
||||
confidence_level FLOAT NOT NULL DEFAULT 0.0,
|
||||
scoring_breakdown JSONB,
|
||||
recommendations JSONB,
|
||||
computed_at TIMESTAMP DEFAULT now(),
|
||||
is_stale BOOLEAN DEFAULT TRUE,
|
||||
CONSTRAINT uq_risk_profile_technique UNIQUE (technique_id)
|
||||
)
|
||||
"""))
|
||||
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_risk_profiles_risk_score "
|
||||
"ON technique_risk_profiles (risk_score)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_risk_profiles_risk_level "
|
||||
"ON technique_risk_profiles (risk_level)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_risk_profiles_stale "
|
||||
"ON technique_risk_profiles (is_stale)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS technique_risk_profiles"))
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Phase 13: Executive Dashboard — posture_snapshots table.
|
||||
|
||||
Revision ID: b039exec
|
||||
Revises: b038risk
|
||||
Create Date: 2026-05-20
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b039exec"
|
||||
down_revision = "b038risk"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS posture_snapshots (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
snapshot_date DATE NOT NULL,
|
||||
|
||||
-- Coverage
|
||||
total_techniques INTEGER NOT NULL DEFAULT 0,
|
||||
validated_count INTEGER NOT NULL DEFAULT 0,
|
||||
partial_count INTEGER NOT NULL DEFAULT 0,
|
||||
not_covered_count INTEGER NOT NULL DEFAULT 0,
|
||||
coverage_pct FLOAT NOT NULL DEFAULT 0.0,
|
||||
|
||||
-- Risk
|
||||
avg_risk_score FLOAT NOT NULL DEFAULT 0.0,
|
||||
critical_count INTEGER NOT NULL DEFAULT 0,
|
||||
high_count INTEGER NOT NULL DEFAULT 0,
|
||||
medium_count INTEGER NOT NULL DEFAULT 0,
|
||||
low_count INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
-- Operations
|
||||
open_queue_items INTEGER NOT NULL DEFAULT 0,
|
||||
orphan_techniques INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
-- Knowledge
|
||||
playbook_count INTEGER NOT NULL DEFAULT 0,
|
||||
lesson_count INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
-- MTTD
|
||||
mttd_avg_seconds FLOAT,
|
||||
executions_30d INTEGER NOT NULL DEFAULT 0,
|
||||
detection_rate_30d FLOAT,
|
||||
|
||||
-- Meta
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now(),
|
||||
extra JSONB
|
||||
)
|
||||
"""))
|
||||
|
||||
# Unique constraint: one snapshot per calendar day
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE posture_snapshots
|
||||
ADD CONSTRAINT uq_posture_snapshot_date UNIQUE (snapshot_date);
|
||||
EXCEPTION WHEN duplicate_table THEN NULL;
|
||||
WHEN duplicate_object THEN NULL;
|
||||
END $$
|
||||
"""))
|
||||
|
||||
# Index for date-range trend queries
|
||||
conn.execute(sa.text("""
|
||||
CREATE INDEX IF NOT EXISTS ix_posture_snapshots_date
|
||||
ON posture_snapshots (snapshot_date)
|
||||
"""))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS posture_snapshots CASCADE"))
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Phase 14: Enterprise Readiness — api_keys and sso_configs tables.
|
||||
|
||||
Revision ID: b040ent
|
||||
Revises: b039exec
|
||||
Create Date: 2026-05-20
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b040ent"
|
||||
down_revision = "b039exec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── api_keys ──────────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(200) NOT NULL,
|
||||
description TEXT,
|
||||
key_prefix VARCHAR(13) NOT NULL,
|
||||
key_hash VARCHAR(64) NOT NULL UNIQUE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
scopes JSONB NOT NULL DEFAULT '["read"]',
|
||||
last_used_at TIMESTAMP WITHOUT TIME ZONE,
|
||||
expires_at TIMESTAMP WITHOUT TIME ZONE,
|
||||
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_api_keys_user_id ON api_keys (user_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_api_keys_key_hash ON api_keys (key_hash)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_api_keys_active ON api_keys (is_active)"
|
||||
))
|
||||
|
||||
# ── sso_configs ───────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS sso_configs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
is_enabled BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
provider_name VARCHAR(200),
|
||||
sp_entity_id VARCHAR(500),
|
||||
sp_acs_url VARCHAR(500),
|
||||
sp_slo_url VARCHAR(500),
|
||||
sp_certificate TEXT,
|
||||
sp_private_key TEXT,
|
||||
idp_entity_id VARCHAR(500),
|
||||
idp_sso_url VARCHAR(500),
|
||||
idp_slo_url VARCHAR(500),
|
||||
idp_certificate TEXT,
|
||||
attr_email VARCHAR(200) DEFAULT 'email',
|
||||
attr_username VARCHAR(200) DEFAULT 'username',
|
||||
attr_role VARCHAR(200) DEFAULT 'role',
|
||||
default_role VARCHAR(50) DEFAULT 'viewer',
|
||||
auto_provision BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now(),
|
||||
updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS api_keys CASCADE"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS sso_configs CASCADE"))
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Phase 13: Operational Alerts — alert_rules and alert_instances tables.
|
||||
|
||||
Revision ID: b041alerts
|
||||
Revises: b040ent
|
||||
Create Date: 2026-05-21
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b041alerts"
|
||||
down_revision = "b040ent"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── alert_rules ───────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS alert_rules (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(300) NOT NULL,
|
||||
description TEXT,
|
||||
rule_type VARCHAR(50) NOT NULL,
|
||||
severity VARCHAR(20) NOT NULL DEFAULT 'medium',
|
||||
is_enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
is_system BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
config JSONB NOT NULL DEFAULT '{}',
|
||||
notify_in_app BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
notify_webhook BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
webhook_id UUID REFERENCES webhook_configs(id) ON DELETE SET NULL,
|
||||
cooldown_hours INTEGER NOT NULL DEFAULT 24,
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now(),
|
||||
last_fired_at TIMESTAMP WITHOUT TIME ZONE
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_rules_type ON alert_rules (rule_type)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_rules_enabled ON alert_rules (is_enabled)"
|
||||
))
|
||||
|
||||
# ── alert_instances ───────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS alert_instances (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
rule_id UUID REFERENCES alert_rules(id) ON DELETE SET NULL,
|
||||
rule_name VARCHAR(300) NOT NULL,
|
||||
rule_type VARCHAR(50) NOT NULL,
|
||||
severity VARCHAR(20) NOT NULL,
|
||||
title VARCHAR(500) NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
details JSONB,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'open',
|
||||
acknowledged_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
acknowledged_at TIMESTAMP WITHOUT TIME ZONE,
|
||||
resolved_at TIMESTAMP WITHOUT TIME ZONE,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_instances_rule_id ON alert_instances (rule_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_instances_status ON alert_instances (status)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_instances_severity ON alert_instances (severity)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_instances_created ON alert_instances (created_at)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS alert_instances CASCADE"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS alert_rules CASCADE"))
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Add jira_api_token to users table.
|
||||
|
||||
Revision ID: b042
|
||||
Revises: b041_operational_alerts
|
||||
Create Date: 2026-05-26
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b042"
|
||||
down_revision = "b041alerts"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("jira_api_token", sa.String(500), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("users", "jira_api_token")
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Add jira_email to users table.
|
||||
|
||||
Allows each user to specify a separate email for Jira authentication,
|
||||
independent of their Aegis account email.
|
||||
|
||||
Revision ID: b043
|
||||
Revises: b042
|
||||
Create Date: 2026-05-26
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b043"
|
||||
down_revision = "b042"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("jira_email", sa.String(255), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("users", "jira_email")
|
||||
@@ -0,0 +1,25 @@
|
||||
"""add tempo_api_token to users
|
||||
|
||||
Revision ID: b044
|
||||
Revises: b043
|
||||
Create Date: 2026-05-27
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b044"
|
||||
down_revision = "b043"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("tempo_api_token", sa.String(500), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("users", "tempo_api_token")
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Add blue_work_started_at to tests table."""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b045"
|
||||
down_revision = "b044"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column("tests", sa.Column("blue_work_started_at", sa.DateTime(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("tests", "blue_work_started_at")
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Add 'disputed' value to teststate enum.
|
||||
|
||||
Revision ID: b046
|
||||
Revises: b045
|
||||
Create Date: 2026-06-03
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "b046"
|
||||
down_revision = "b045"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("ALTER TYPE teststate ADD VALUE IF NOT EXISTS 'disputed'")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Add start_date to campaigns.
|
||||
|
||||
Revision ID: b047
|
||||
Revises: b046
|
||||
Create Date: 2026-06-03
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b047"
|
||||
down_revision = "b046"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"campaigns",
|
||||
sa.Column("start_date", sa.DateTime(), nullable=True),
|
||||
)
|
||||
op.create_index("ix_campaigns_start_date", "campaigns", ["start_date"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_campaigns_start_date", table_name="campaigns")
|
||||
op.drop_column("campaigns", "start_date")
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Add evaluation_imports table.
|
||||
|
||||
Revision ID: b048
|
||||
Revises: b047
|
||||
Create Date: 2026-06-05
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
revision = "b048"
|
||||
down_revision = "b047"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"evaluation_imports",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("adversary_name", sa.String, nullable=False),
|
||||
sa.Column("adversary_display", sa.String, nullable=False),
|
||||
sa.Column("eval_round", sa.Integer, nullable=False),
|
||||
sa.Column("imported_at", sa.DateTime, nullable=False),
|
||||
sa.Column("imported_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=True),
|
||||
sa.Column("tests_created", sa.Integer, default=0),
|
||||
sa.Column("techniques_covered", sa.Integer, default=0),
|
||||
sa.Column("status", sa.String, default="completed"),
|
||||
sa.Column("notes", sa.Text, nullable=True),
|
||||
)
|
||||
op.create_index("ix_evaluation_imports_adversary", "evaluation_imports", ["adversary_name"])
|
||||
op.create_index("ix_evaluation_imports_round", "evaluation_imports", ["eval_round"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_evaluation_imports_round", table_name="evaluation_imports")
|
||||
op.drop_index("ix_evaluation_imports_adversary", table_name="evaluation_imports")
|
||||
op.drop_table("evaluation_imports")
|
||||
+23
-3
@@ -20,7 +20,7 @@ class Settings(BaseSettings):
|
||||
# so tokens survive restarts.
|
||||
SECRET_KEY: str = ""
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 480 # 8 hours — /auth/refresh extends active sessions
|
||||
|
||||
# ── Redis ─────────────────────────────────────────────────────────
|
||||
REDIS_URL: str = "redis://redis:6379/0"
|
||||
@@ -36,6 +36,10 @@ class Settings(BaseSettings):
|
||||
|
||||
# ── MinIO / S3 ───────────────────────────────────────────────────
|
||||
MINIO_ENDPOINT: str = "minio:9000"
|
||||
# Public hostname used in presigned URLs returned to browsers.
|
||||
# In production set this to <server-ip>:9000 (or a public FQDN) so
|
||||
# the browser can reach MinIO directly. Defaults to MINIO_ENDPOINT.
|
||||
MINIO_PUBLIC_ENDPOINT: str = ""
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "evidence"
|
||||
@@ -51,14 +55,20 @@ class Settings(BaseSettings):
|
||||
JIRA_API_TOKEN: str = ""
|
||||
JIRA_IS_CLOUD: bool = True
|
||||
JIRA_DEFAULT_PROJECT: str = ""
|
||||
JIRA_ISSUE_TYPE_TEST: str = "Task"
|
||||
JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic"
|
||||
JIRA_ISSUE_TYPE_TEST: str = "Task" # tests (campaign or standalone)
|
||||
JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic" # campaigns (under Initiative)
|
||||
# Jira custom field ID for "Start date" — Jira Cloud team-managed: customfield_10015
|
||||
# Override with the correct field ID for your Jira instance if different.
|
||||
JIRA_START_DATE_FIELD: str = "customfield_10015"
|
||||
|
||||
# ── Tempo Integration ─────────────────────────────────────────────
|
||||
TEMPO_ENABLED: bool = False
|
||||
TEMPO_API_TOKEN: str = ""
|
||||
TEMPO_API_VERSION: int = 4
|
||||
TEMPO_DEFAULT_WORK_TYPE: str = "Red Team"
|
||||
# Tempo API base URL — use https://api.eu.tempo.io/4 for EU workspaces.
|
||||
# Can also be set via system_configs key "tempo.base_url" at runtime.
|
||||
TEMPO_BASE_URL: str = "" # empty → falls back to https://api.tempo.io/4
|
||||
|
||||
# ── OSINT / Intelligence ────────────────────────────────────────
|
||||
NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s
|
||||
@@ -70,6 +80,16 @@ class Settings(BaseSettings):
|
||||
COMPANY_NAME: str = "Organization"
|
||||
COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png"
|
||||
|
||||
# ── Email / SMTP ──────────────────────────────────────────────────
|
||||
SMTP_ENABLED: bool = False
|
||||
SMTP_HOST: str = ""
|
||||
SMTP_PORT: int = 587
|
||||
SMTP_USERNAME: str = ""
|
||||
SMTP_PASSWORD: str = ""
|
||||
SMTP_FROM_EMAIL: str = "aegis@company.com"
|
||||
SMTP_USE_TLS: bool = True
|
||||
PLATFORM_URL: str = "http://localhost:5173" # base URL for links in emails
|
||||
|
||||
# ── Scoring weights (must sum to 100) ────────────────────────────
|
||||
SCORING_WEIGHT_TESTS: int = 40
|
||||
SCORING_WEIGHT_DETECTION_RULES: int = 25
|
||||
|
||||
@@ -4,6 +4,7 @@ Authentication and RBAC dependencies for FastAPI.
|
||||
Provides:
|
||||
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
|
||||
Authorization header (fallback), fetches user from DB, raises 401 on failure.
|
||||
Also accepts Aegis API keys (``aegis_…`` prefix) as Bearer tokens.
|
||||
- ``require_role``: factory that returns a dependency enforcing a specific role
|
||||
(admins always pass).
|
||||
"""
|
||||
@@ -19,6 +20,7 @@ from app import auth as auth_lib
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.api_key import KEY_PREFIX
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth2 scheme (reads Authorization header — used as fallback / Swagger UI)
|
||||
@@ -68,6 +70,15 @@ async def get_current_user(
|
||||
if token is None:
|
||||
raise credentials_exception
|
||||
|
||||
# ── API Key path (Bearer token starts with "aegis_") ──────────────────
|
||||
if token.startswith(KEY_PREFIX):
|
||||
from app.services.api_key_service import authenticate_raw_key
|
||||
user = authenticate_raw_key(db, token)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
# ── JWT path ──────────────────────────────────────────────────────────
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
@@ -112,11 +123,27 @@ async def require_password_changed(
|
||||
return current_user
|
||||
|
||||
|
||||
def _check_api_key_scope(user: User, required_scope: str) -> None:
|
||||
"""Raise 403 if the request was authenticated via an API key that lacks *required_scope*.
|
||||
|
||||
When authenticated via JWT (browser session), ``_api_key_scopes`` is not set
|
||||
and the check is skipped — full access is granted based on role alone.
|
||||
"""
|
||||
key_scopes = getattr(user, "_api_key_scopes", None)
|
||||
if key_scopes is not None and required_scope not in key_scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"API key scope '{required_scope}' required for this operation",
|
||||
)
|
||||
|
||||
|
||||
def require_role(required_role: str):
|
||||
"""Return a FastAPI dependency that enforces *required_role*.
|
||||
|
||||
The dependency allows the request to proceed when
|
||||
``user.role == required_role`` **or** ``user.role == "admin"``.
|
||||
Also enforces API key scopes: admin-role endpoints require the ``admin``
|
||||
scope; all other role-restricted endpoints require ``write``.
|
||||
Otherwise it raises :class:`~fastapi.HTTPException` **403**.
|
||||
"""
|
||||
|
||||
@@ -128,6 +155,8 @@ def require_role(required_role: str):
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
scope = "admin" if required_role == "admin" else "write"
|
||||
_check_api_key_scope(current_user, scope)
|
||||
return current_user
|
||||
|
||||
return role_checker
|
||||
@@ -136,7 +165,11 @@ def require_role(required_role: str):
|
||||
def require_any_role(*roles: str):
|
||||
"""Return a FastAPI dependency that enforces **any** of the given *roles*.
|
||||
|
||||
Admins always pass. Usage example::
|
||||
Admins always pass. Also enforces API key scopes: if the only accepted
|
||||
role is ``admin``, the key must carry the ``admin`` scope; otherwise the
|
||||
``write`` scope is required.
|
||||
|
||||
Usage example::
|
||||
|
||||
@router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))])
|
||||
"""
|
||||
@@ -149,6 +182,27 @@ def require_any_role(*roles: str):
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
scope = "admin" if set(roles) == {"admin"} else "write"
|
||||
_check_api_key_scope(current_user, scope)
|
||||
return current_user
|
||||
|
||||
return role_checker
|
||||
|
||||
|
||||
def require_scope(scope: str):
|
||||
"""Return a dependency that enforces the API key carries *scope*.
|
||||
|
||||
JWT-authenticated requests (browser sessions) bypass this check entirely.
|
||||
Use on mutation endpoints that don't already use ``require_role`` /
|
||||
``require_any_role``::
|
||||
|
||||
@router.post("/resource", dependencies=[Depends(require_scope("write"))])
|
||||
"""
|
||||
|
||||
async def scope_checker(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
_check_api_key_scope(current_user, scope)
|
||||
return current_user
|
||||
|
||||
return scope_checker
|
||||
|
||||
@@ -115,17 +115,26 @@ class TechniqueEntity:
|
||||
) -> TechniqueStatus:
|
||||
"""Recompute ``status_global`` from a list of (state, detection_result) pairs.
|
||||
|
||||
Rules (v2):
|
||||
Rules (v3):
|
||||
1. No tests -> not_evaluated
|
||||
2. All validated -> inspect detection results:
|
||||
- All detected -> validated
|
||||
- Any partially_detected -> partial
|
||||
- Otherwise -> not_covered
|
||||
3. Some validated, others in progress -> partial
|
||||
4. All in intermediate states -> in_progress
|
||||
2. All tests validated -> inspect detection results:
|
||||
a. All detected AND ≥ 2 validated tests -> validated
|
||||
b. All detected but only 1 validated test -> partial
|
||||
(single test is not enough evidence for full coverage)
|
||||
c. Any partially_detected -> partial
|
||||
d. Otherwise (no detected results) -> not_covered
|
||||
3. Some validated, others in intermediate states -> partial
|
||||
4. All tests in intermediate states (draft/executing/evaluating/review/rejected)
|
||||
-> in_progress
|
||||
|
||||
Minimum validated count for "validated": 2 tests.
|
||||
With only 1 validated+detected test the technique is "partial" to
|
||||
signal that more testing is recommended.
|
||||
|
||||
Returns the new status (also set on the entity).
|
||||
"""
|
||||
_MIN_VALIDATED_FOR_FULL = 2 # require ≥ N validated tests for "validated"
|
||||
|
||||
tests = [
|
||||
_TestSnapshot(
|
||||
state=s if isinstance(s, TestState) else TestState(s),
|
||||
@@ -137,9 +146,14 @@ class TechniqueEntity:
|
||||
if not tests:
|
||||
self.status_global = TechniqueStatus.not_evaluated
|
||||
elif all(t.state == TestState.validated for t in tests):
|
||||
validated_count = len(tests)
|
||||
results = [t.detection_result for t in tests if t.detection_result]
|
||||
if results and all(r == TestResult.detected or r == "detected" for r in results):
|
||||
self.status_global = TechniqueStatus.validated
|
||||
# Need at least _MIN_VALIDATED_FOR_FULL tests for "validated"
|
||||
if validated_count >= _MIN_VALIDATED_FOR_FULL:
|
||||
self.status_global = TechniqueStatus.validated
|
||||
else:
|
||||
self.status_global = TechniqueStatus.partial
|
||||
elif any(
|
||||
r == TestResult.partially_detected or r == "partially_detected"
|
||||
for r in results
|
||||
|
||||
@@ -24,6 +24,7 @@ class TestState(str, enum.Enum):
|
||||
in_review = "in_review"
|
||||
validated = "validated"
|
||||
rejected = "rejected"
|
||||
disputed = "disputed" # one lead approved, the other rejected
|
||||
|
||||
|
||||
class TeamSide(str, enum.Enum):
|
||||
|
||||
@@ -45,13 +45,15 @@ class TestState(str, enum.Enum):
|
||||
in_review = "in_review"
|
||||
validated = "validated"
|
||||
rejected = "rejected"
|
||||
disputed = "disputed" # one lead approved, the other rejected
|
||||
|
||||
|
||||
VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
|
||||
TestState.draft: [TestState.red_executing],
|
||||
TestState.red_executing: [TestState.blue_evaluating],
|
||||
TestState.blue_evaluating: [TestState.in_review],
|
||||
TestState.in_review: [TestState.validated, TestState.rejected],
|
||||
TestState.in_review: [TestState.validated, TestState.rejected, TestState.disputed],
|
||||
TestState.disputed: [TestState.validated, TestState.rejected],
|
||||
TestState.rejected: [TestState.draft],
|
||||
TestState.validated: [],
|
||||
}
|
||||
@@ -314,21 +316,21 @@ class TestEntity:
|
||||
def check_dual_validation(self) -> None:
|
||||
"""Evaluate both leads' votes and advance state if appropriate.
|
||||
|
||||
- Both **approved** -> ``validated``
|
||||
- Either **rejected** -> ``rejected``
|
||||
- Otherwise no change (waiting for the other lead).
|
||||
Rules (v2 — consensus required):
|
||||
- Both **approved** -> ``validated``
|
||||
- Both **rejected** -> ``rejected``
|
||||
- One approved + one rejected -> ``disputed`` (conflict, needs discussion)
|
||||
- Otherwise (one or both still pending) -> no change
|
||||
|
||||
Called automatically by :meth:`validate_red` and :meth:`validate_blue`.
|
||||
Also available as a standalone entry point for backward compatibility
|
||||
when validation fields are set externally.
|
||||
"""
|
||||
self._check_dual_validation()
|
||||
|
||||
def _assert_in_review(self, side: str) -> None:
|
||||
if self.state != TestState.in_review:
|
||||
if self.state not in (TestState.in_review, TestState.disputed):
|
||||
raise InvalidOperationError(
|
||||
f"Cannot validate {side} side while test is in "
|
||||
f"'{self.state.value}' state (must be in_review)"
|
||||
f"'{self.state.value}' state (must be in_review or disputed)"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -339,11 +341,19 @@ class TestEntity:
|
||||
)
|
||||
|
||||
def _check_dual_validation(self) -> None:
|
||||
"""If both leads have voted, advance to validated or rejected."""
|
||||
"""Advance the test state once both leads have voted."""
|
||||
r, b = self.red_validation_status, self.blue_validation_status
|
||||
if r == "rejected" or b == "rejected":
|
||||
self.state = TestState.rejected
|
||||
self._events.append(DomainEvent("dual_validation_rejected"))
|
||||
elif r == "approved" and b == "approved":
|
||||
|
||||
if r == "approved" and b == "approved":
|
||||
self.state = TestState.validated
|
||||
self._events.append(DomainEvent("dual_validation_approved"))
|
||||
|
||||
elif r == "rejected" and b == "rejected":
|
||||
# Full consensus to reject
|
||||
self.state = TestState.rejected
|
||||
self._events.append(DomainEvent("dual_validation_rejected"))
|
||||
|
||||
elif (r == "approved" and b == "rejected") or (r == "rejected" and b == "approved"):
|
||||
# Conflict: one approves, one rejects → needs discussion
|
||||
self.state = TestState.disputed
|
||||
self._events.append(DomainEvent("dual_validation_disputed"))
|
||||
|
||||
@@ -11,6 +11,7 @@ sessions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
|
||||
@@ -41,11 +42,13 @@ scheduler = BackgroundScheduler()
|
||||
|
||||
def _run_mitre_sync() -> None:
|
||||
"""Execute a MITRE sync inside its own DB session."""
|
||||
from app.services.webhook_service import dispatch_webhook
|
||||
logger.info("Scheduled MITRE sync job starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
summary = sync_mitre(db)
|
||||
logger.info("Scheduled MITRE sync job finished — %s", summary)
|
||||
dispatch_webhook("mitre.synced", {"created": summary.get("created", 0), "updated": summary.get("updated", 0)})
|
||||
except Exception:
|
||||
logger.exception("Scheduled MITRE sync job failed")
|
||||
finally:
|
||||
@@ -98,6 +101,96 @@ def _run_recurring_campaigns() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_scheduled_campaign_activation() -> None:
|
||||
"""Auto-activate campaigns whose start_date has arrived.
|
||||
|
||||
Finds all campaigns in 'draft' state with a start_date <= now,
|
||||
activates them, creates Jira tickets, and notifies the red_tech team.
|
||||
Runs every hour so campaigns activate within ~1 hour of their scheduled time.
|
||||
"""
|
||||
logger.info("Scheduled campaign auto-activation check starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from datetime import datetime as _dt
|
||||
from app.models.campaign import Campaign
|
||||
from app.models.user import User
|
||||
from app.services.campaign_crud_service import activate_campaign as _activate
|
||||
from app.services.notification_service import notify_role
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
now = _dt.utcnow()
|
||||
due_campaigns = (
|
||||
db.query(Campaign)
|
||||
.filter(
|
||||
Campaign.status == "draft",
|
||||
Campaign.start_date != None, # noqa: E711
|
||||
Campaign.start_date <= now,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
activated = 0
|
||||
for campaign in due_campaigns:
|
||||
try:
|
||||
_activate(db, str(campaign.id))
|
||||
notify_role(
|
||||
db,
|
||||
role="red_tech",
|
||||
type="campaign_activated",
|
||||
title="Campaign auto-activated",
|
||||
message=f'Campaign "{campaign.name}" has been automatically activated on its scheduled start date.',
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
)
|
||||
log_action(
|
||||
db,
|
||||
user_id=None,
|
||||
action="auto_activate_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"name": campaign.name, "start_date": str(campaign.start_date)},
|
||||
)
|
||||
|
||||
# Create Jira tickets non-fatally
|
||||
try:
|
||||
from app.services.jira_service import (
|
||||
auto_create_campaign_issue,
|
||||
auto_create_test_issue,
|
||||
get_campaign_jira_key,
|
||||
get_test_jira_key,
|
||||
)
|
||||
# Use first admin user as actor for Jira auth
|
||||
admin_user = db.query(User).filter(User.role == "admin").first()
|
||||
if admin_user:
|
||||
db.refresh(campaign)
|
||||
campaign_jira_key = get_campaign_jira_key(db, str(campaign.id))
|
||||
if not campaign_jira_key:
|
||||
campaign_jira_key = auto_create_campaign_issue(db, campaign, admin_user)
|
||||
if campaign_jira_key:
|
||||
for ct in campaign.campaign_tests:
|
||||
if ct.test and not get_test_jira_key(db, ct.test.id):
|
||||
auto_create_test_issue(
|
||||
db, ct.test, admin_user,
|
||||
parent_ticket_override=campaign_jira_key,
|
||||
campaign_start_date=campaign.start_date,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Jira auto-create failed for auto-activated campaign %s", campaign.id)
|
||||
|
||||
db.commit()
|
||||
activated += 1
|
||||
logger.info("Auto-activated campaign %s (%s)", campaign.id, campaign.name)
|
||||
except Exception:
|
||||
logger.exception("Failed to auto-activate campaign %s", campaign.id)
|
||||
db.rollback()
|
||||
|
||||
logger.info("Campaign auto-activation check finished — activated %d campaigns", activated)
|
||||
except Exception:
|
||||
logger.exception("Campaign auto-activation job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_intel_scan() -> None:
|
||||
"""Execute an intel scan inside its own DB session."""
|
||||
logger.info("Scheduled intel scan job starting...")
|
||||
@@ -111,6 +204,83 @@ def _run_intel_scan() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_evaluation_round_check() -> None:
|
||||
"""Weekly job: check if a new ATT&CK Evaluation round is available.
|
||||
|
||||
If a new round is found it is imported automatically and an admin
|
||||
notification is created so the team knows new baseline data is available.
|
||||
"""
|
||||
logger.info("ATT&CK Evaluations new-round check starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.services.attck_evaluations_service import check_for_new_round, import_evaluation_round
|
||||
from app.models.user import User as UserModel
|
||||
|
||||
result = check_for_new_round(db)
|
||||
if result.get("error"):
|
||||
logger.warning("ATT&CK Evaluations check failed: %s", result["error"])
|
||||
return
|
||||
|
||||
if not result.get("new_round_available"):
|
||||
logger.info(
|
||||
"ATT&CK Evaluations check — latest round '%s' already imported",
|
||||
result.get("latest_round", {}).get("display_name", "?"),
|
||||
)
|
||||
return
|
||||
|
||||
latest = result["latest_round"]
|
||||
logger.info(
|
||||
"New ATT&CK Evaluation round detected: %s (round %d) — starting auto-import",
|
||||
latest["display_name"], latest["eval_round"],
|
||||
)
|
||||
|
||||
# Use the first admin user as the importer (system action)
|
||||
admin = db.query(UserModel).filter(UserModel.role == "admin").first()
|
||||
if not admin:
|
||||
logger.warning("ATT&CK Evaluations auto-import: no admin user found — skipping")
|
||||
return
|
||||
|
||||
summary = import_evaluation_round(
|
||||
db,
|
||||
latest["name"],
|
||||
latest["display_name"],
|
||||
latest["eval_round"],
|
||||
admin,
|
||||
)
|
||||
logger.info(
|
||||
"ATT&CK Evaluations auto-import complete — round %d (%s): %d tests created",
|
||||
latest["eval_round"], latest["display_name"], summary["created"],
|
||||
)
|
||||
|
||||
# Notify all admins
|
||||
try:
|
||||
from app.services.notification_service import create_notification
|
||||
admins = db.query(UserModel).filter(UserModel.role == "admin").all()
|
||||
for adm in admins:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=adm.id,
|
||||
title="New ATT&CK Evaluation round imported",
|
||||
message=(
|
||||
f"Round {latest['eval_round']} — {latest['display_name']} — "
|
||||
f"has been automatically imported. "
|
||||
f"{summary['created']} tests created in In Review state. "
|
||||
f"Blue Leads must validate each result before it counts as coverage."
|
||||
),
|
||||
notification_type="eval_import",
|
||||
entity_type="evaluation",
|
||||
entity_id=None,
|
||||
)
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.warning("Failed to send eval import notifications", exc_info=True)
|
||||
|
||||
except Exception:
|
||||
logger.exception("ATT&CK Evaluations round check job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_osint_enrichment() -> None:
|
||||
"""Execute weekly OSINT enrichment inside its own DB session."""
|
||||
logger.info("Scheduled OSINT enrichment job starting...")
|
||||
@@ -124,6 +294,61 @@ def _run_osint_enrichment() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
_FREQUENCY_INTERVALS: dict[str, timedelta] = {
|
||||
"daily": timedelta(days=1),
|
||||
"weekly": timedelta(weeks=1),
|
||||
"monthly": timedelta(days=30),
|
||||
}
|
||||
|
||||
|
||||
def _run_data_sources_sync() -> None:
|
||||
"""Check all enabled data sources and sync those that are overdue."""
|
||||
logger.info("Scheduled data sources sync check starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.data_source import DataSource
|
||||
from app.services.data_source_service import sync_source
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
sources = (
|
||||
db.query(DataSource)
|
||||
.filter(DataSource.is_enabled == True) # noqa: E712
|
||||
.all()
|
||||
)
|
||||
synced = 0
|
||||
for ds in sources:
|
||||
freq = ds.sync_frequency
|
||||
if not freq or freq == "manual":
|
||||
continue
|
||||
interval = _FREQUENCY_INTERVALS.get(freq)
|
||||
if interval is None:
|
||||
continue
|
||||
last = ds.last_sync_at
|
||||
if last is None:
|
||||
# Never synced — run it now
|
||||
overdue = True
|
||||
else:
|
||||
# Make last timezone-aware if needed
|
||||
if last.tzinfo is None:
|
||||
last = last.replace(tzinfo=timezone.utc)
|
||||
overdue = now - last >= interval
|
||||
if overdue:
|
||||
logger.info(
|
||||
"Data source '%s' is overdue (freq=%s, last=%s) — syncing",
|
||||
ds.name, freq, last,
|
||||
)
|
||||
try:
|
||||
sync_source(db, str(ds.id))
|
||||
synced += 1
|
||||
except Exception:
|
||||
logger.exception("Failed to sync data source '%s'", ds.name)
|
||||
logger.info("Data sources sync check finished — %d source(s) synced", synced)
|
||||
except Exception:
|
||||
logger.exception("Data sources sync check failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_stale_detection() -> None:
|
||||
"""Execute daily stale coverage detection inside its own DB session."""
|
||||
logger.info("Scheduled stale coverage detection starting...")
|
||||
@@ -137,6 +362,53 @@ def _run_stale_detection() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_decay_engine() -> None:
|
||||
"""Execute the decay engine inside its own DB session."""
|
||||
logger.info("Scheduled decay engine job starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.services.decay_engine_service import run_decay_engine
|
||||
results = run_decay_engine(db)
|
||||
logger.info("Decay engine job finished — %s", results)
|
||||
except Exception:
|
||||
logger.exception("Decay engine job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_queue_generation() -> None:
|
||||
"""Generate revalidation queue items for analysts — runs after decay engine."""
|
||||
logger.info("Scheduled revalidation queue generation starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.services.revalidation_queue_service import generate_queue_items
|
||||
results = generate_queue_items(db)
|
||||
logger.info("Queue generation finished — %s", results)
|
||||
except Exception:
|
||||
logger.exception("Queue generation job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_alert_evaluation() -> None:
|
||||
"""Evaluate all enabled operational alert rules (hourly)."""
|
||||
logger.info("Scheduled alert evaluation job starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.services.operational_alert_service import evaluate_all_rules
|
||||
result = evaluate_all_rules(db)
|
||||
logger.info(
|
||||
"Alert evaluation finished — %d rules, %d alerts fired in %.3fs",
|
||||
result["rules_evaluated"],
|
||||
result["alerts_fired"],
|
||||
result["duration_seconds"],
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Alert evaluation job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler bootstrap
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -186,6 +458,14 @@ def start_scheduler() -> None:
|
||||
name="Weekly coverage snapshot (Sundays 00:00)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_scheduled_campaign_activation,
|
||||
trigger="interval",
|
||||
hours=1,
|
||||
id="scheduled_campaign_activation",
|
||||
name="Auto-activate campaigns on start_date (hourly)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_recurring_campaigns,
|
||||
trigger="interval",
|
||||
@@ -226,11 +506,56 @@ def start_scheduler() -> None:
|
||||
name="Data retention policies (daily)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_data_sources_sync,
|
||||
trigger="interval",
|
||||
hours=6,
|
||||
id="data_sources_sync",
|
||||
name="Data sources auto-sync (every 6h)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_decay_engine,
|
||||
trigger="cron",
|
||||
hour=2,
|
||||
minute=0,
|
||||
id="decay_engine",
|
||||
name="Detection decay engine (daily 02:00)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_queue_generation,
|
||||
trigger="cron",
|
||||
hour=2,
|
||||
minute=30,
|
||||
id="queue_generation",
|
||||
name="Revalidation queue generation (daily 02:30)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_alert_evaluation,
|
||||
trigger="interval",
|
||||
hours=1,
|
||||
id="alert_evaluation",
|
||||
name="Operational alert evaluation (hourly)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_evaluation_round_check,
|
||||
trigger="cron",
|
||||
day_of_week="mon",
|
||||
hour=6,
|
||||
minute=0,
|
||||
id="attck_evaluation_check",
|
||||
name="ATT&CK Evaluations new-round check (Mondays 06:00)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info(
|
||||
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
||||
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
|
||||
"recurring_campaigns (daily), jira_sync (1h), "
|
||||
"osint_enrichment (weekly), stale_detection (daily), "
|
||||
"retention_policies (daily)"
|
||||
"retention_policies (daily), data_sources_sync (6h), "
|
||||
"alert_evaluation (1h), attck_evaluation_check (Mondays 06:00)"
|
||||
)
|
||||
|
||||
@@ -37,6 +37,18 @@ from app.routers import professional_reports as professional_reports_router
|
||||
from app.routers import analytics as analytics_router
|
||||
from app.routers import advanced_metrics as advanced_metrics_router
|
||||
from app.routers import osint as osint_router
|
||||
from app.routers import webhooks as webhooks_router
|
||||
from app.routers import detection_lifecycle as detection_lifecycle_router
|
||||
from app.routers import intel as intel_router
|
||||
from app.routers import admin_config as admin_config_router
|
||||
from app.routers import ownership as ownership_router
|
||||
from app.routers import attack_paths as attack_paths_router
|
||||
from app.routers import knowledge as knowledge_router
|
||||
from app.routers import risk_intelligence as risk_router
|
||||
from app.routers import executive_dashboard as dashboard_router
|
||||
from app.routers import api_keys as api_keys_router
|
||||
from app.routers import sso as sso_router
|
||||
from app.routers import operational_alerts as alerts_router
|
||||
from app.domain.errors import DomainError
|
||||
from app.middleware.error_handler import domain_exception_handler
|
||||
from app.middleware.request_context import RequestContextMiddleware
|
||||
@@ -57,6 +69,25 @@ async def lifespan(app: FastAPI):
|
||||
"""Startup / shutdown logic."""
|
||||
ensure_bucket_exists()
|
||||
start_scheduler()
|
||||
# Seed decay policies
|
||||
from app.database import SessionLocal
|
||||
from app.seed_decay_policies import seed_decay_policies
|
||||
db = SessionLocal()
|
||||
try:
|
||||
seed_decay_policies(db)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
# Seed operational alert system rules
|
||||
db2 = SessionLocal()
|
||||
try:
|
||||
from app.services.operational_alert_service import seed_system_rules
|
||||
seed_system_rules(db2)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db2.close()
|
||||
yield
|
||||
# Graceful shutdown of the background scheduler
|
||||
scheduler.shutdown(wait=False)
|
||||
@@ -77,6 +108,24 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
app.add_middleware(RequestContextMiddleware)
|
||||
|
||||
|
||||
# ── No-cache middleware for all /api/ responses ───────────────────────────
|
||||
# Prevents Cloudflare and browser caches from storing API responses,
|
||||
# which would cause stale/empty data to be served after backend restarts.
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
|
||||
class NoCacheAPIMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
if request.url.path.startswith("/api/"):
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
return response
|
||||
|
||||
app.add_middleware(NoCacheAPIMiddleware)
|
||||
|
||||
|
||||
# ── Domain exception → HTTP mapping ──────────────────────────────────────
|
||||
app.add_exception_handler(DomainError, domain_exception_handler)
|
||||
|
||||
@@ -116,6 +165,8 @@ app.include_router(heatmap_router.router, prefix="/api/v1")
|
||||
app.include_router(scores_router.router, prefix="/api/v1")
|
||||
app.include_router(operational_metrics_router.router, prefix="/api/v1")
|
||||
app.include_router(compliance_router.router, prefix="/api/v1")
|
||||
app.include_router(intel_router.router, prefix="/api/v1")
|
||||
app.include_router(admin_config_router.router, prefix="/api/v1")
|
||||
app.include_router(snapshots_router.router, prefix="/api/v1")
|
||||
app.include_router(jira_router.router, prefix="/api/v1")
|
||||
app.include_router(worklogs_router.router, prefix="/api/v1")
|
||||
@@ -123,6 +174,16 @@ app.include_router(professional_reports_router.router, prefix="/api/v1")
|
||||
app.include_router(analytics_router.router, prefix="/api/v1")
|
||||
app.include_router(advanced_metrics_router.router, prefix="/api/v1")
|
||||
app.include_router(osint_router.router, prefix="/api/v1")
|
||||
app.include_router(webhooks_router.router, prefix="/api/v1")
|
||||
app.include_router(detection_lifecycle_router.router, prefix="/api/v1")
|
||||
app.include_router(ownership_router.router, prefix="/api/v1")
|
||||
app.include_router(attack_paths_router.router, prefix="/api/v1")
|
||||
app.include_router(knowledge_router.router, prefix="/api/v1")
|
||||
app.include_router(risk_router.router, prefix="/api/v1")
|
||||
app.include_router(dashboard_router.router, prefix="/api/v1")
|
||||
app.include_router(api_keys_router.router, prefix="/api/v1")
|
||||
app.include_router(sso_router.router, prefix="/api/v1")
|
||||
app.include_router(alerts_router.router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health", include_in_schema=False)
|
||||
|
||||
@@ -21,6 +21,29 @@ from app.models.worklog import Worklog
|
||||
from app.models.osint_item import OsintItem
|
||||
from app.models.scoring_config import ScoringConfig
|
||||
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
|
||||
from app.models.webhook_config import WebhookConfig
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.models.detection_lifecycle import (
|
||||
DetectionAsset, DetectionTechniqueMapping, DetectionValidation,
|
||||
TechniqueConfidenceScore, InfrastructureChangeLog,
|
||||
DetectionConfidence, DetectionHealthStatus, InvalidationReason,
|
||||
)
|
||||
from app.models.decay_policy import DecayPolicy
|
||||
from app.models.ownership_queue import (
|
||||
TechniqueOwnership, RevalidationQueueItem,
|
||||
QueuePriority, QueueStatus, QueueReason,
|
||||
)
|
||||
from app.models.attack_path import (
|
||||
AttackPath, AttackPathStep, AttackPathExecution,
|
||||
AttackPathStepResult, TimelineEntry,
|
||||
ExecutionStatus, StepResultStatus, TimelineActorSide, TimelineEntryType,
|
||||
)
|
||||
from app.models.knowledge import Playbook, PlaybookVersion, LessonLearned
|
||||
from app.models.risk_intelligence import TechniqueRiskProfile
|
||||
from app.models.executive_dashboard import PostureSnapshot
|
||||
from app.models.api_key import ApiKey
|
||||
from app.models.sso_config import SsoConfig
|
||||
from app.models.operational_alert import AlertRule, AlertInstance
|
||||
|
||||
__all__ = [
|
||||
"User", "Technique", "Test", "TestTemplate", "Evidence",
|
||||
@@ -34,4 +57,19 @@ __all__ = [
|
||||
"JiraLink", "JiraLinkEntityType", "JiraSyncDirection",
|
||||
"Worklog", "OsintItem", "ScoringConfig",
|
||||
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
||||
"WebhookConfig", "SystemConfig",
|
||||
"DetectionAsset", "DetectionTechniqueMapping", "DetectionValidation",
|
||||
"TechniqueConfidenceScore", "InfrastructureChangeLog", "DecayPolicy",
|
||||
"TechniqueOwnership", "RevalidationQueueItem",
|
||||
"QueuePriority", "QueueStatus", "QueueReason",
|
||||
"AttackPath", "AttackPathStep", "AttackPathExecution",
|
||||
"AttackPathStepResult", "TimelineEntry",
|
||||
"ExecutionStatus", "StepResultStatus", "TimelineActorSide", "TimelineEntryType",
|
||||
"Playbook", "PlaybookVersion", "LessonLearned",
|
||||
"TechniqueRiskProfile",
|
||||
"PostureSnapshot",
|
||||
"ApiKey",
|
||||
"SsoConfig",
|
||||
"AlertRule",
|
||||
"AlertInstance",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Phase 14: API Key model for programmatic access."""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
# ── Key generation constants ──────────────────────────────────────────────────
|
||||
KEY_PREFIX = "aegis_"
|
||||
KEY_BYTES = 32 # 32 random bytes → 64 hex chars → 70-char key total
|
||||
DISPLAY_LEN = 12 # chars stored as prefix for UI display
|
||||
|
||||
|
||||
def generate_raw_key() -> str:
|
||||
"""Generate a fresh raw API key (must be shown to user only once)."""
|
||||
return KEY_PREFIX + secrets.token_hex(KEY_BYTES)
|
||||
|
||||
|
||||
def hash_key(raw_key: str) -> str:
|
||||
"""SHA-256 hash of a raw API key for secure storage."""
|
||||
return hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
|
||||
|
||||
def key_prefix_display(raw_key: str) -> str:
|
||||
"""First DISPLAY_LEN characters of the raw key (safe for UI)."""
|
||||
return raw_key[:DISPLAY_LEN]
|
||||
|
||||
|
||||
# ── Valid scopes ──────────────────────────────────────────────────────────────
|
||||
VALID_SCOPES = {"read", "write", "admin"}
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
"""
|
||||
Scoped API key for programmatic / BI / SOAR access.
|
||||
|
||||
The full raw key is **never stored** — only a SHA-256 hash.
|
||||
The first 12 characters (``key_prefix``) are retained for display.
|
||||
"""
|
||||
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(200), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Display only — never use for auth
|
||||
key_prefix = Column(String(DISPLAY_LEN + 1), nullable=False)
|
||||
|
||||
# Auth token — SHA-256 of the full raw key
|
||||
key_hash = Column(String(64), nullable=False, unique=True)
|
||||
|
||||
# Owner
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Permissions
|
||||
scopes = Column(JSONB, nullable=False, default=["read"]) # ["read","write","admin"]
|
||||
|
||||
# Lifecycle
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
is_active = Column(Boolean, nullable=False, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
user = relationship("User", foreign_keys=[user_id])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_api_keys_user_id", "user_id"),
|
||||
Index("ix_api_keys_key_hash", "key_hash"),
|
||||
Index("ix_api_keys_active", "is_active"),
|
||||
)
|
||||
@@ -0,0 +1,253 @@
|
||||
"""Phase 10: Attack Paths & Advanced Purple Team models."""
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, Column, DateTime, Enum, Float, ForeignKey,
|
||||
Index, Integer, String, Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ExecutionStatus(str, enum.Enum):
|
||||
planned = "planned"
|
||||
in_progress = "in_progress"
|
||||
completed = "completed"
|
||||
aborted = "aborted"
|
||||
|
||||
|
||||
class StepResultStatus(str, enum.Enum):
|
||||
pending = "pending"
|
||||
executing = "executing"
|
||||
detected = "detected"
|
||||
not_detected = "not_detected"
|
||||
skipped = "skipped"
|
||||
|
||||
|
||||
class TimelineActorSide(str, enum.Enum):
|
||||
red = "red"
|
||||
blue = "blue"
|
||||
system = "system"
|
||||
|
||||
|
||||
class TimelineEntryType(str, enum.Enum):
|
||||
action = "action"
|
||||
detection = "detection"
|
||||
note = "note"
|
||||
phase_transition = "phase_transition"
|
||||
flag = "flag"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AttackPath(Base):
|
||||
"""
|
||||
A reusable attack scenario composed of ordered kill-chain steps.
|
||||
Can be a template (shared) or a one-off scenario.
|
||||
"""
|
||||
__tablename__ = "attack_paths"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(300), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
objective = Column(Text, nullable=True) # what the attacker aims to achieve
|
||||
is_template = Column(Boolean, default=False) # reusable template flag
|
||||
threat_actor_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("threat_actors.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
tags = Column(JSONB, nullable=True, default=list)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
steps = relationship(
|
||||
"AttackPathStep", back_populates="attack_path",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="AttackPathStep.order_index",
|
||||
)
|
||||
executions = relationship("AttackPathExecution", back_populates="attack_path")
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
threat_actor = relationship("ThreatActor", foreign_keys=[threat_actor_id])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_attack_paths_created_by", "created_by"),
|
||||
Index("ix_attack_paths_is_template", "is_template"),
|
||||
)
|
||||
|
||||
|
||||
class AttackPathStep(Base):
|
||||
"""One step in an attack path — maps to a kill-chain phase + technique."""
|
||||
|
||||
__tablename__ = "attack_path_steps"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
attack_path_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_paths.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
order_index = Column(Integer, nullable=False, default=0)
|
||||
kill_chain_phase = Column(String(60), nullable=True) # initial_access, execution, …
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
test_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("tests.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
name = Column(String(300), nullable=True) # human label for the step
|
||||
description = Column(Text, nullable=True)
|
||||
expected_detection = Column(Boolean, default=True) # do we expect blue to detect this?
|
||||
notes = Column(Text, nullable=True)
|
||||
|
||||
attack_path = relationship("AttackPath", back_populates="steps")
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
test = relationship("Test", foreign_keys=[test_id])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_ap_steps_path_id", "attack_path_id"),
|
||||
Index("ix_ap_steps_technique_id", "technique_id"),
|
||||
)
|
||||
|
||||
|
||||
class AttackPathExecution(Base):
|
||||
"""
|
||||
A single run of an attack path.
|
||||
Tracks Red/Blue participants, timing, and aggregated kill-chain metrics.
|
||||
"""
|
||||
__tablename__ = "attack_path_executions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
attack_path_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_paths.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
status = Column(
|
||||
Enum(ExecutionStatus, name="execution_status"), nullable=False,
|
||||
default=ExecutionStatus.planned,
|
||||
)
|
||||
environment = Column(String(100), nullable=True) # prod, staging, lab
|
||||
red_team_lead = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
blue_team_lead = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
started_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
started_at = Column(DateTime, nullable=True)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
notes = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# ── Computed kill-chain metrics (written on complete) ─────────────────
|
||||
total_steps = Column(Integer, nullable=True)
|
||||
detected_steps = Column(Integer, nullable=True)
|
||||
not_detected_steps = Column(Integer, nullable=True)
|
||||
skipped_steps = Column(Integer, nullable=True)
|
||||
detection_rate = Column(Float, nullable=True) # 0.0–1.0
|
||||
mttd_seconds = Column(Float, nullable=True) # mean time to detect (avg across detected)
|
||||
furthest_undetected_step = Column(Integer, nullable=True) # order_index of deepest undetected step
|
||||
|
||||
attack_path = relationship("AttackPath", back_populates="executions")
|
||||
step_results = relationship(
|
||||
"AttackPathStepResult", back_populates="execution",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="AttackPathStepResult.step_order",
|
||||
)
|
||||
timeline = relationship(
|
||||
"TimelineEntry", back_populates="execution",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="TimelineEntry.timestamp",
|
||||
)
|
||||
red_lead_user = relationship("User", foreign_keys=[red_team_lead])
|
||||
blue_lead_user = relationship("User", foreign_keys=[blue_team_lead])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_ap_exec_path_id", "attack_path_id"),
|
||||
Index("ix_ap_exec_status", "status"),
|
||||
)
|
||||
|
||||
|
||||
class AttackPathStepResult(Base):
|
||||
"""Result of executing one step in an attack path execution."""
|
||||
|
||||
__tablename__ = "attack_path_step_results"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
execution_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_path_executions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
step_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_path_steps.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
step_order = Column(Integer, nullable=False, default=0) # denormalized for sorting
|
||||
status = Column(
|
||||
Enum(StepResultStatus, name="step_result_status"), nullable=False,
|
||||
default=StepResultStatus.pending,
|
||||
)
|
||||
executed_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
executed_at = Column(DateTime, nullable=True)
|
||||
detected_at = Column(DateTime, nullable=True)
|
||||
time_to_detect_seconds = Column(Float, nullable=True)
|
||||
detection_asset_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("detection_assets.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
notes = Column(Text, nullable=True)
|
||||
evidence_ids = Column(JSONB, nullable=True, default=list)
|
||||
|
||||
execution = relationship("AttackPathExecution", back_populates="step_results")
|
||||
step = relationship("AttackPathStep")
|
||||
detection_asset = relationship("DetectionAsset", foreign_keys=[detection_asset_id])
|
||||
executor = relationship("User", foreign_keys=[executed_by])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_ap_stepres_execution_id", "execution_id"),
|
||||
Index("ix_ap_stepres_step_id", "step_id"),
|
||||
)
|
||||
|
||||
|
||||
class TimelineEntry(Base):
|
||||
"""Timestamped Red/Blue action during an execution — used for MTTD/MTTR."""
|
||||
|
||||
__tablename__ = "attack_path_timeline"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
execution_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_path_executions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
step_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_path_steps.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
timestamp = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||
actor_side = Column(
|
||||
Enum(TimelineActorSide, name="timeline_actor_side"), nullable=False,
|
||||
)
|
||||
actor_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
entry_type = Column(
|
||||
Enum(TimelineEntryType, name="timeline_entry_type"), nullable=False,
|
||||
)
|
||||
content = Column(Text, nullable=False)
|
||||
extra = Column(JSONB, nullable=True)
|
||||
|
||||
execution = relationship("AttackPathExecution", back_populates="timeline")
|
||||
actor = relationship("User", foreign_keys=[actor_id])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_timeline_execution_id", "execution_id"),
|
||||
Index("ix_timeline_timestamp", "timestamp"),
|
||||
)
|
||||
@@ -48,6 +48,7 @@ class Campaign(Base):
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
start_date = Column(DateTime, nullable=True) # campaign won't activate before this date
|
||||
scheduled_at = Column(DateTime, nullable=True)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
target_platform = Column(String, nullable=True)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Decay Policy model — configurable detection validity rules."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, Integer, Float, Boolean, DateTime, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class DecayPolicy(Base):
|
||||
__tablename__ = "decay_policies"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(200), nullable=False)
|
||||
description = Column(Text)
|
||||
applies_to_platform = Column(String(100))
|
||||
applies_to_asset_type = Column(String(50))
|
||||
applies_to_tactic = Column(String(100))
|
||||
fresh_days = Column(Integer, default=90, server_default='90')
|
||||
aging_days = Column(Integer, default=180, server_default='180')
|
||||
stale_days = Column(Integer, default=365, server_default='365')
|
||||
default_validity_days = Column(Integer, default=180, server_default='180')
|
||||
silent_threshold_days = Column(Integer, default=30, server_default='30')
|
||||
noisy_threshold_daily = Column(Integer, default=100, server_default='100')
|
||||
recency_weight = Column(Float, default=0.3, server_default='0.3')
|
||||
coverage_weight = Column(Float, default=0.3, server_default='0.3')
|
||||
health_weight = Column(Float, default=0.25, server_default='0.25')
|
||||
diversity_weight = Column(Float, default=0.15, server_default='0.15')
|
||||
is_default = Column(Boolean, default=False, server_default='false')
|
||||
is_active = Column(Boolean, default=True, server_default='true')
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Detection Lifecycle Management models."""
|
||||
|
||||
import uuid
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from sqlalchemy import (
|
||||
Column, String, Integer, Float, Boolean, DateTime,
|
||||
ForeignKey, Text, Enum as SQLEnum
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class DetectionConfidence(str, enum.Enum):
|
||||
fresh = "fresh"
|
||||
aging = "aging"
|
||||
stale = "stale"
|
||||
broken = "broken"
|
||||
unknown = "unknown"
|
||||
|
||||
|
||||
class DetectionHealthStatus(str, enum.Enum):
|
||||
healthy = "healthy"
|
||||
silent = "silent"
|
||||
noisy = "noisy"
|
||||
orphan = "orphan"
|
||||
deprecated = "deprecated"
|
||||
untested = "untested"
|
||||
|
||||
|
||||
class InvalidationReason(str, enum.Enum):
|
||||
time_decay = "time_decay"
|
||||
mitre_update = "mitre_update"
|
||||
log_source_change = "log_source_change"
|
||||
siem_update = "siem_update"
|
||||
edr_update = "edr_update"
|
||||
infrastructure_change = "infrastructure_change"
|
||||
parser_change = "parser_change"
|
||||
manual = "manual"
|
||||
rule_modified = "rule_modified"
|
||||
|
||||
|
||||
class DetectionAsset(Base):
|
||||
__tablename__ = "detection_assets"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(500), nullable=False)
|
||||
description = Column(Text)
|
||||
asset_type = Column(String(50), nullable=False)
|
||||
platform = Column(String(100))
|
||||
rule_content = Column(Text)
|
||||
rule_language = Column(String(50))
|
||||
rule_repository_url = Column(Text)
|
||||
rule_file_path = Column(String(500))
|
||||
rule_version = Column(String(50))
|
||||
rule_hash = Column(String(64))
|
||||
last_rule_change_at = Column(DateTime)
|
||||
log_source_name = Column(String(200))
|
||||
log_source_version = Column(String(50))
|
||||
log_source_config = Column(JSONB, server_default='{}')
|
||||
infrastructure_hash = Column(String(64))
|
||||
infrastructure_details = Column(JSONB, server_default='{}')
|
||||
health_status = Column(
|
||||
SQLEnum(DetectionHealthStatus, name="detectionhealthstatus"),
|
||||
default=DetectionHealthStatus.untested,
|
||||
nullable=False,
|
||||
server_default="untested",
|
||||
)
|
||||
last_alert_at = Column(DateTime)
|
||||
alert_count_30d = Column(Integer, default=0, server_default='0')
|
||||
false_positive_rate = Column(Float)
|
||||
expected_alert_frequency = Column(String(50))
|
||||
owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
backup_owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
team = Column(String(100))
|
||||
is_active = Column(Boolean, default=True, nullable=False, server_default='true')
|
||||
tags = Column(JSONB, server_default='[]')
|
||||
asset_metadata = Column(JSONB, server_default='{}')
|
||||
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default='now()')
|
||||
updated_at = Column(DateTime(timezone=True), server_default='now()')
|
||||
|
||||
technique_mappings = relationship("DetectionTechniqueMapping", back_populates="detection_asset", cascade="all, delete-orphan")
|
||||
validations = relationship("DetectionValidation", back_populates="detection_asset", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class DetectionTechniqueMapping(Base):
|
||||
__tablename__ = "detection_technique_mappings"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
detection_asset_id = Column(UUID(as_uuid=True), ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False)
|
||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False)
|
||||
coverage_type = Column(String(50), default="detect", server_default="detect")
|
||||
confidence_level = Column(String(20), default="medium", server_default="medium")
|
||||
notes = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default='now()')
|
||||
|
||||
detection_asset = relationship("DetectionAsset", back_populates="technique_mappings")
|
||||
|
||||
|
||||
class DetectionValidation(Base):
|
||||
__tablename__ = "detection_validations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
detection_asset_id = Column(UUID(as_uuid=True), ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False)
|
||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="SET NULL"), nullable=True)
|
||||
test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id", ondelete="SET NULL"), nullable=True)
|
||||
validated_at = Column(DateTime, default=datetime.utcnow)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
is_valid = Column(Boolean, default=True, nullable=False, server_default='true')
|
||||
validation_result = Column(String(50))
|
||||
validation_method = Column(String(100))
|
||||
rule_hash_at_validation = Column(String(64))
|
||||
log_source_version_at_validation = Column(String(50))
|
||||
infrastructure_hash_at_validation = Column(String(64))
|
||||
environment_snapshot = Column(JSONB, server_default='{}')
|
||||
invalidated_at = Column(DateTime)
|
||||
invalidation_reason = Column(SQLEnum(InvalidationReason, name="invalidationreason"), nullable=True)
|
||||
invalidation_details = Column(Text)
|
||||
invalidated_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=False)
|
||||
integrity_hash = Column(String(64))
|
||||
notes = Column(Text)
|
||||
evidence_ids = Column(JSONB, server_default='[]')
|
||||
|
||||
detection_asset = relationship("DetectionAsset", back_populates="validations")
|
||||
|
||||
|
||||
class TechniqueConfidenceScore(Base):
|
||||
__tablename__ = "technique_confidence_scores"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False, unique=True)
|
||||
confidence_level = Column(
|
||||
SQLEnum(DetectionConfidence, name="detectionconfidence"),
|
||||
default=DetectionConfidence.unknown,
|
||||
server_default="unknown",
|
||||
)
|
||||
confidence_score = Column(Float, default=0.0, server_default='0.0')
|
||||
detection_count = Column(Integer, default=0, server_default='0')
|
||||
valid_detection_count = Column(Integer, default=0, server_default='0')
|
||||
last_validated_at = Column(DateTime)
|
||||
next_validation_due = Column(DateTime)
|
||||
last_recalculated_at = Column(DateTime, default=datetime.utcnow)
|
||||
recency_factor = Column(Float, default=0.0, server_default='0.0')
|
||||
coverage_factor = Column(Float, default=0.0, server_default='0.0')
|
||||
health_factor = Column(Float, default=0.0, server_default='0.0')
|
||||
diversity_factor = Column(Float, default=0.0, server_default='0.0')
|
||||
score_breakdown = Column(JSONB, server_default='{}')
|
||||
risk_factors = Column(JSONB, server_default='[]')
|
||||
updated_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
class InfrastructureChangeLog(Base):
|
||||
__tablename__ = "infrastructure_change_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
change_type = Column(String(100), nullable=False)
|
||||
description = Column(Text, nullable=False)
|
||||
affected_platforms = Column(JSONB, server_default='[]')
|
||||
affected_log_sources = Column(JSONB, server_default='[]')
|
||||
change_date = Column(DateTime, default=datetime.utcnow)
|
||||
reported_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
auto_invalidate = Column(Boolean, default=True, server_default='true')
|
||||
invalidated_count = Column(Integer, default=0, server_default='0')
|
||||
change_metadata = Column(JSONB, server_default='{}')
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,34 @@
|
||||
"""SQLAlchemy model for tracking imported ATT&CK Evaluation rounds."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, ForeignKey, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class EvaluationImport(Base):
|
||||
"""Tracks which ATT&CK Evaluation rounds have been imported into the platform.
|
||||
|
||||
Each row represents one vendor+adversary combination that has been processed
|
||||
and turned into Test records. Used to avoid duplicate imports and to show
|
||||
the admin panel which rounds are available vs imported.
|
||||
"""
|
||||
__tablename__ = "evaluation_imports"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
adversary_name = Column(String, nullable=False) # "apt29", "turla"
|
||||
adversary_display = Column(String, nullable=False) # "APT29", "Turla"
|
||||
eval_round = Column(Integer, nullable=False) # 1, 2, 3 …
|
||||
imported_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||
imported_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
tests_created = Column(Integer, default=0)
|
||||
techniques_covered = Column(Integer, default=0)
|
||||
status = Column(String, default="completed") # "completed" | "failed"
|
||||
notes = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_evaluation_imports_adversary", "adversary_name"),
|
||||
Index("ix_evaluation_imports_round", "eval_round"),
|
||||
)
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Phase 13: Executive Dashboard — PostureSnapshot model."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, Column, Date, DateTime, Float, ForeignKey,
|
||||
Index, Integer, UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class PostureSnapshot(Base):
|
||||
"""
|
||||
Daily point-in-time capture of the organisation's security posture.
|
||||
|
||||
Aggregates data from all phases (coverage, risk, ownership, knowledge,
|
||||
attack-paths) into a single row that can be trended over time.
|
||||
"""
|
||||
|
||||
__tablename__ = "posture_snapshots"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
snapshot_date = Column(Date, nullable=False) # one per calendar day
|
||||
|
||||
# ── Coverage ──────────────────────────────────────────────────────────────
|
||||
total_techniques = Column(Integer, nullable=False, default=0)
|
||||
validated_count = Column(Integer, nullable=False, default=0)
|
||||
partial_count = Column(Integer, nullable=False, default=0)
|
||||
not_covered_count = Column(Integer, nullable=False, default=0)
|
||||
coverage_pct = Column(Float, nullable=False, default=0.0) # 0–100
|
||||
|
||||
# ── Risk ─────────────────────────────────────────────────────────────────
|
||||
avg_risk_score = Column(Float, nullable=False, default=0.0)
|
||||
critical_count = Column(Integer, nullable=False, default=0)
|
||||
high_count = Column(Integer, nullable=False, default=0)
|
||||
medium_count = Column(Integer, nullable=False, default=0)
|
||||
low_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# ── Operations ────────────────────────────────────────────────────────────
|
||||
open_queue_items = Column(Integer, nullable=False, default=0)
|
||||
orphan_techniques = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# ── Knowledge ─────────────────────────────────────────────────────────────
|
||||
playbook_count = Column(Integer, nullable=False, default=0)
|
||||
lesson_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# ── MTTD (from attack-path executions completed in last 30 d) ────────────
|
||||
mttd_avg_seconds = Column(Float, nullable=True) # None if no data
|
||||
executions_30d = Column(Integer, nullable=False, default=0)
|
||||
detection_rate_30d = Column(Float, nullable=True) # avg across executions
|
||||
|
||||
# ── Meta ─────────────────────────────────────────────────────────────────
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
extra = Column(JSONB, nullable=True) # full breakdown / by-tactic
|
||||
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("snapshot_date", name="uq_posture_snapshot_date"),
|
||||
Index("ix_posture_snapshots_date", "snapshot_date"),
|
||||
)
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Phase 11: Knowledge Management models — Playbooks + Lessons Learned."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, Column, DateTime, ForeignKey,
|
||||
Index, Integer, String, Text, UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# ── Playbooks ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class Playbook(Base):
|
||||
"""
|
||||
Structured runbook for a specific technique and playbook type.
|
||||
|
||||
playbook_type: attack | detect | investigate | respond | hunt
|
||||
One playbook per (technique, type). Edits increment ``version``
|
||||
and save a snapshot to ``PlaybookVersion``.
|
||||
"""
|
||||
__tablename__ = "playbooks"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
playbook_type = Column(String(32), nullable=False) # attack/detect/investigate/respond/hunt
|
||||
title = Column(String(255), nullable=False)
|
||||
content = Column(Text, nullable=False, default="")
|
||||
version = Column(Integer, default=1, nullable=False)
|
||||
tools = Column(JSONB, default=list) # list of tool name strings
|
||||
prerequisites = Column(JSONB, default=list) # list of prerequisite strings
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
updated_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Relationships
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
updater = relationship("User", foreign_keys=[updated_by])
|
||||
versions = relationship(
|
||||
"PlaybookVersion", back_populates="playbook",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="PlaybookVersion.version.desc()",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("technique_id", "playbook_type", name="uq_playbook_technique_type"),
|
||||
Index("ix_playbooks_technique_id", "technique_id"),
|
||||
Index("ix_playbooks_type", "playbook_type"),
|
||||
)
|
||||
|
||||
|
||||
class PlaybookVersion(Base):
|
||||
"""Immutable snapshot of a playbook at a given version number."""
|
||||
|
||||
__tablename__ = "playbook_versions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
playbook_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
version = Column(Integer, nullable=False)
|
||||
title = Column(String(255), nullable=False)
|
||||
content = Column(Text, nullable=False, default="")
|
||||
tools = Column(JSONB, default=list)
|
||||
prerequisites = Column(JSONB, default=list)
|
||||
changed_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
change_note = Column(String(500), nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
playbook = relationship("Playbook", back_populates="versions")
|
||||
changer = relationship("User", foreign_keys=[changed_by])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_pb_versions_playbook_id", "playbook_id"),
|
||||
Index("ix_pb_versions_version", "playbook_id", "version"),
|
||||
)
|
||||
|
||||
|
||||
# ── Lessons Learned ────────────────────────────────────────────────────────────
|
||||
|
||||
class LessonLearned(Base):
|
||||
"""
|
||||
Immutable post-mortem record linked to a test, campaign, attack-path or
|
||||
created manually.
|
||||
|
||||
severity: critical | high | medium | low | info
|
||||
entity_type: test | campaign | attack_path | manual
|
||||
"""
|
||||
__tablename__ = "lessons_learned"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
title = Column(String(255), nullable=False)
|
||||
what_happened = Column(Text, nullable=False, default="")
|
||||
root_cause = Column(Text, nullable=False, default="")
|
||||
fix_applied = Column(Text, nullable=True)
|
||||
severity = Column(String(16), nullable=False, default="medium")
|
||||
entity_type = Column(String(32), nullable=False, default="manual")
|
||||
entity_id = Column(UUID(as_uuid=True), nullable=True)
|
||||
technique_ids = Column(JSONB, default=list) # list of UUID strings
|
||||
tags = Column(JSONB, default=list)
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
is_active = Column(Boolean, default=True) # soft-delete (admin only)
|
||||
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_ll_entity", "entity_type", "entity_id"),
|
||||
Index("ix_ll_severity", "severity"),
|
||||
Index("ix_ll_created_by", "created_by"),
|
||||
)
|
||||
@@ -0,0 +1,144 @@
|
||||
"""Phase 13: Operational Alerts — AlertRule and AlertInstance models."""
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, Column, DateTime, ForeignKey,
|
||||
Index, Integer, String, Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# ── Enumerations ──────────────────────────────────────────────────────────────
|
||||
|
||||
class AlertSeverity(str, enum.Enum):
|
||||
critical = "critical"
|
||||
high = "high"
|
||||
medium = "medium"
|
||||
low = "low"
|
||||
info = "info"
|
||||
|
||||
|
||||
class AlertStatus(str, enum.Enum):
|
||||
open = "open"
|
||||
acknowledged = "acknowledged"
|
||||
resolved = "resolved"
|
||||
dismissed = "dismissed"
|
||||
|
||||
|
||||
class AlertRuleType(str, enum.Enum):
|
||||
high_risk = "high_risk" # risk_score >= threshold
|
||||
stale_technique = "stale_technique" # not validated in N days
|
||||
coverage_regression = "coverage_regression" # coverage_pct dropped
|
||||
low_coverage = "low_coverage" # coverage below min
|
||||
expiry_wave = "expiry_wave" # many pending queue items
|
||||
new_technique = "new_technique" # new MITRE techniques added
|
||||
orphan_spike = "orphan_spike" # many unowned techniques
|
||||
custom = "custom" # future extension placeholder
|
||||
|
||||
|
||||
# ── AlertRule ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class AlertRule(Base):
|
||||
"""
|
||||
Defines a condition that, when satisfied, fires an AlertInstance.
|
||||
|
||||
System rules (is_system=True) are seeded at startup and cannot be deleted.
|
||||
Custom rules (is_system=False) can be created by admins.
|
||||
"""
|
||||
|
||||
__tablename__ = "alert_rules"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(300), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
rule_type = Column(String(50), nullable=False)
|
||||
severity = Column(String(20), nullable=False, default=AlertSeverity.medium.value)
|
||||
is_enabled = Column(Boolean, nullable=False, default=True)
|
||||
is_system = Column(Boolean, nullable=False, default=False) # seeded, not deletable
|
||||
|
||||
# Rule-specific thresholds/config (varies by rule_type)
|
||||
config = Column(JSONB, nullable=False, default={})
|
||||
|
||||
# Delivery
|
||||
notify_in_app = Column(Boolean, nullable=False, default=True)
|
||||
notify_webhook = Column(Boolean, nullable=False, default=False)
|
||||
webhook_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("webhook_configs.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Cooldown — don't re-fire within N hours of last firing
|
||||
cooldown_hours = Column(Integer, nullable=False, default=24)
|
||||
|
||||
# Meta
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_fired_at = Column(DateTime, nullable=True)
|
||||
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
instances = relationship("AlertInstance", back_populates="rule",
|
||||
cascade="all, delete-orphan")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_alert_rules_type", "rule_type"),
|
||||
Index("ix_alert_rules_enabled", "is_enabled"),
|
||||
)
|
||||
|
||||
|
||||
# ── AlertInstance ─────────────────────────────────────────────────────────────
|
||||
|
||||
class AlertInstance(Base):
|
||||
"""
|
||||
A single firing of an AlertRule.
|
||||
|
||||
Transitions: open → acknowledged → resolved
|
||||
open → dismissed
|
||||
"""
|
||||
|
||||
__tablename__ = "alert_instances"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
rule_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("alert_rules.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
# Denormalised fields kept for history even after rule deletion
|
||||
rule_name = Column(String(300), nullable=False)
|
||||
rule_type = Column(String(50), nullable=False)
|
||||
severity = Column(String(20), nullable=False)
|
||||
|
||||
title = Column(String(500), nullable=False)
|
||||
message = Column(Text, nullable=False)
|
||||
details = Column(JSONB, nullable=True) # structured context
|
||||
|
||||
status = Column(String(20), nullable=False, default=AlertStatus.open.value)
|
||||
acknowledged_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
acknowledged_at = Column(DateTime, nullable=True)
|
||||
resolved_at = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
rule = relationship("AlertRule", back_populates="instances")
|
||||
acknowledger = relationship("User", foreign_keys=[acknowledged_by])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_alert_instances_rule_id", "rule_id"),
|
||||
Index("ix_alert_instances_status", "status"),
|
||||
Index("ix_alert_instances_severity", "severity"),
|
||||
Index("ix_alert_instances_created", "created_at"),
|
||||
)
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Phase 9: Ownership & Revalidation Queue models."""
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class QueuePriority(str, enum.Enum):
|
||||
critical = "critical"
|
||||
high = "high"
|
||||
medium = "medium"
|
||||
low = "low"
|
||||
|
||||
|
||||
class QueueStatus(str, enum.Enum):
|
||||
pending = "pending"
|
||||
in_progress = "in_progress"
|
||||
completed = "completed"
|
||||
dismissed = "dismissed"
|
||||
|
||||
|
||||
class QueueReason(str, enum.Enum):
|
||||
validation_expired = "validation_expired"
|
||||
infra_change = "infra_change"
|
||||
osint_alert = "osint_alert"
|
||||
mitre_update = "mitre_update"
|
||||
rule_modified = "rule_modified"
|
||||
low_confidence = "low_confidence"
|
||||
manual = "manual"
|
||||
|
||||
|
||||
class TechniqueOwnership(Base):
|
||||
"""Ownership assignment for a MITRE technique."""
|
||||
|
||||
__tablename__ = "technique_ownerships"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
)
|
||||
owner_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
backup_owner_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
team = Column(String(200), nullable=True)
|
||||
notes = Column(Text, nullable=True)
|
||||
assigned_at = Column(DateTime, nullable=True)
|
||||
assigned_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
owner = relationship("User", foreign_keys=[owner_id])
|
||||
backup_owner = relationship("User", foreign_keys=[backup_owner_id])
|
||||
|
||||
|
||||
class RevalidationQueueItem(Base):
|
||||
"""A prioritised work item for the analyst's daily queue."""
|
||||
|
||||
__tablename__ = "revalidation_queue_items"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
)
|
||||
detection_asset_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("detection_assets.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
priority = Column(
|
||||
Enum(QueuePriority, name="queue_priority"),
|
||||
nullable=False,
|
||||
default=QueuePriority.medium,
|
||||
)
|
||||
reason = Column(
|
||||
Enum(QueueReason, name="queue_reason"),
|
||||
nullable=False,
|
||||
)
|
||||
reason_detail = Column(Text, nullable=True)
|
||||
status = Column(
|
||||
Enum(QueueStatus, name="queue_status"),
|
||||
nullable=False,
|
||||
default=QueueStatus.pending,
|
||||
)
|
||||
|
||||
assigned_to = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
due_date = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
dismissed_at = Column(DateTime, nullable=True)
|
||||
completed_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
extra = Column(JSONB, nullable=True) # arbitrary metadata
|
||||
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
detection_asset = relationship("DetectionAsset", foreign_keys=[detection_asset_id])
|
||||
assignee = relationship("User", foreign_keys=[assigned_to])
|
||||
|
||||
|
||||
# Indexes
|
||||
Index("ix_rqueue_status", RevalidationQueueItem.status)
|
||||
Index("ix_rqueue_priority", RevalidationQueueItem.priority)
|
||||
Index("ix_rqueue_assigned_to", RevalidationQueueItem.assigned_to)
|
||||
Index("ix_rqueue_technique_id", RevalidationQueueItem.technique_id)
|
||||
Index("ix_rqueue_asset_id", RevalidationQueueItem.detection_asset_id)
|
||||
Index("ix_techown_owner_id", TechniqueOwnership.owner_id)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Phase 12: Risk Intelligence model — per-technique risk scoring."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, Column, DateTime, Float, ForeignKey,
|
||||
Index, Integer, String, UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class TechniqueRiskProfile(Base):
|
||||
"""
|
||||
Aggregated risk profile for one technique.
|
||||
|
||||
Combines four weighted factors:
|
||||
• detection_gap (35 %) — 0=fully covered → 1=no coverage
|
||||
• threat_actor_rel (30 %) — normalised actor count
|
||||
• osint_signals (20 %) — normalised recent OSINT items (30 d)
|
||||
• test_failure_rate (15 %) — proportion of tests where blue didn't detect
|
||||
|
||||
risk_score = weighted sum × 100 → 0–100
|
||||
risk_level: critical ≥75 | high ≥50 | medium ≥25 | low ≥10 | info
|
||||
"""
|
||||
|
||||
__tablename__ = "technique_risk_profiles"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# ── Computed scores ───────────────────────────────────────────────────────
|
||||
risk_score = Column(Float, nullable=False, default=0.0) # 0–100
|
||||
likelihood = Column(Float, nullable=False, default=0.0) # 0–100
|
||||
impact = Column(Float, nullable=False, default=0.0) # 0–100
|
||||
risk_level = Column(String(16), nullable=False, default="info")
|
||||
|
||||
# ── Raw factor values ─────────────────────────────────────────────────────
|
||||
detection_gap = Column(Float, nullable=False, default=1.0) # 0–1
|
||||
threat_actor_count = Column(Integer, nullable=False, default=0)
|
||||
osint_signal_count = Column(Integer, nullable=False, default=0) # last 30 d
|
||||
test_fail_count = Column(Integer, nullable=False, default=0)
|
||||
test_total_count = Column(Integer, nullable=False, default=0)
|
||||
test_failure_rate = Column(Float, nullable=False, default=0.0) # 0–1
|
||||
confidence_level = Column(Float, nullable=False, default=0.0) # DLC 0–1
|
||||
|
||||
# ── Rich detail ──────────────────────────────────────────────────────────
|
||||
scoring_breakdown = Column(JSONB, nullable=True) # per-factor contributions
|
||||
recommendations = Column(JSONB, nullable=True) # list[str]
|
||||
|
||||
# ── Meta ─────────────────────────────────────────────────────────────────
|
||||
computed_at = Column(DateTime, default=datetime.utcnow)
|
||||
is_stale = Column(Boolean, default=True)
|
||||
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("technique_id", name="uq_risk_profile_technique"),
|
||||
Index("ix_risk_profiles_risk_score", "risk_score"),
|
||||
Index("ix_risk_profiles_risk_level", "risk_level"),
|
||||
Index("ix_risk_profiles_stale", "is_stale"),
|
||||
)
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Phase 14: SSO / SAML 2.0 configuration model."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class SsoConfig(Base):
|
||||
"""
|
||||
SAML 2.0 Identity Provider configuration.
|
||||
|
||||
Exactly one row is expected (use upsert). The SP metadata endpoint
|
||||
reads from this row to generate XML for IdP registration.
|
||||
"""
|
||||
|
||||
__tablename__ = "sso_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
is_enabled = Column(Boolean, nullable=False, default=False)
|
||||
provider_name = Column(String(200), nullable=True) # e.g., "Okta", "Azure AD"
|
||||
|
||||
# ── Service Provider (Aegis) settings ────────────────────────────────────
|
||||
sp_entity_id = Column(String(500), nullable=True) # e.g., https://aegis.co/api/v1/sso/metadata
|
||||
sp_acs_url = Column(String(500), nullable=True) # Assertion Consumer Service URL
|
||||
sp_slo_url = Column(String(500), nullable=True) # Single Logout URL (optional)
|
||||
sp_certificate = Column(Text, nullable=True) # SP public cert for signed requests
|
||||
sp_private_key = Column(Text, nullable=True) # SP private key (stored encrypted in future)
|
||||
|
||||
# ── Identity Provider settings ────────────────────────────────────────────
|
||||
idp_entity_id = Column(String(500), nullable=True)
|
||||
idp_sso_url = Column(String(500), nullable=True) # IdP redirect/POST binding URL
|
||||
idp_slo_url = Column(String(500), nullable=True) # IdP SLO URL
|
||||
idp_certificate = Column(Text, nullable=True) # IdP X.509 cert for response validation
|
||||
|
||||
# ── Attribute mapping ─────────────────────────────────────────────────────
|
||||
# SAML attribute name → Aegis field
|
||||
attr_email = Column(String(200), nullable=True, default="email")
|
||||
attr_username = Column(String(200), nullable=True, default="username")
|
||||
attr_role = Column(String(200), nullable=True, default="role")
|
||||
default_role = Column(String(50), nullable=True, default="viewer")
|
||||
auto_provision = Column(Boolean, nullable=False, default=True) # create user on first login
|
||||
|
||||
# ── Meta ─────────────────────────────────────────────────────────────────
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""SystemConfig model — runtime key-value configuration store."""
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class SystemConfig(Base):
|
||||
"""Generic key-value store for runtime system configuration.
|
||||
|
||||
Currently used for:
|
||||
- SMTP email settings (overrides .env values when present)
|
||||
|
||||
Keys are namespaced by convention: ``smtp.host``, ``smtp.port``, etc.
|
||||
"""
|
||||
|
||||
__tablename__ = "system_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
key = Column(String(200), unique=True, nullable=False, index=True)
|
||||
value = Column(Text, nullable=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
@@ -50,6 +50,7 @@ class Test(Base):
|
||||
# ── Phase timing fields (for automatic Tempo worklogs) ──────────
|
||||
red_started_at = Column(DateTime, nullable=True)
|
||||
blue_started_at = Column(DateTime, nullable=True)
|
||||
blue_work_started_at = Column(DateTime, nullable=True) # when blue tech picks up (Tempo start)
|
||||
paused_at = Column(DateTime, nullable=True)
|
||||
red_paused_seconds = Column(Integer, default=0)
|
||||
blue_paused_seconds = Column(Integer, default=0)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.database import Base
|
||||
|
||||
@@ -28,3 +28,8 @@ class User(Base):
|
||||
must_change_password = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
last_login = Column(DateTime, nullable=True)
|
||||
notification_preferences = Column(JSONB, nullable=True, server_default='{"email_on_test_validated": true, "email_on_campaign_completed": true, "email_on_new_mitre_techniques": false, "in_app_all": true}')
|
||||
jira_account_id = Column(String(100), nullable=True)
|
||||
jira_api_token = Column(String(500), nullable=True) # personal Atlassian token
|
||||
jira_email = Column(String(255), nullable=True) # Atlassian email (overrides account email)
|
||||
tempo_api_token = Column(String(500), nullable=True) # personal Tempo API token
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
"""WebhookConfig model — outbound HTTP notification endpoints."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text, ForeignKey, func
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from app.database import Base
|
||||
|
||||
class WebhookConfig(Base):
|
||||
__tablename__ = "webhook_configs"
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(200), nullable=False)
|
||||
url = Column(Text, nullable=False)
|
||||
secret = Column(String(256), nullable=True) # HMAC signature key
|
||||
events = Column(JSONB, nullable=False, server_default="[]") # list of event types
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
last_triggered_at = Column(DateTime, nullable=True)
|
||||
failure_count = Column(Integer, default=0, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
@@ -0,0 +1,340 @@
|
||||
"""Admin configuration export/import — single-file migration bundle.
|
||||
|
||||
GET /admin/export-config — download JSON bundle (admin only)
|
||||
POST /admin/import-config — upload JSON bundle and restore (admin only)
|
||||
|
||||
What is exported (and what is NOT):
|
||||
✓ system_configs — email / jira settings (passwords REDACTED)
|
||||
✓ webhook_configs — notification webhooks (secrets REDACTED)
|
||||
✓ sso_configs — SAML/SSO config (private keys REDACTED)
|
||||
✓ scoring_config — technique scoring weights
|
||||
✓ test_templates — CUSTOM templates only (source='custom')
|
||||
✓ users — username / email / role (no passwords / tokens)
|
||||
✗ atomic/sigma/elastic templates, techniques, tests, campaigns, reports
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth import hash_password
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
from app.models.scoring_config import ScoringConfig
|
||||
from app.models.sso_config import SsoConfig
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.user import User
|
||||
from app.models.webhook_config import WebhookConfig
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
# Keys whose values contain secrets and must be redacted in the export
|
||||
_REDACTED_KEYS = {
|
||||
"smtp.password",
|
||||
"jira.api_token",
|
||||
"jira.password",
|
||||
"tempo.api_token",
|
||||
}
|
||||
|
||||
_EXPORT_VERSION = "1.0"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _redact(key: str, value: Any) -> Any:
|
||||
if key in _REDACTED_KEYS:
|
||||
return "[REDACTED]"
|
||||
return value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /admin/export-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/export-config")
|
||||
def export_config(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Export all platform configuration as a downloadable JSON bundle."""
|
||||
|
||||
# ── 1. system_configs ────────────────────────────────────────────
|
||||
system_configs = [
|
||||
{
|
||||
"key": r.key,
|
||||
"value": _redact(r.key, r.value),
|
||||
"description": r.description,
|
||||
}
|
||||
for r in db.query(SystemConfig).order_by(SystemConfig.key).all()
|
||||
]
|
||||
|
||||
# ── 2. webhook_configs ───────────────────────────────────────────
|
||||
webhooks = [
|
||||
{
|
||||
"name": w.name,
|
||||
"url": w.url,
|
||||
"secret": "[REDACTED]" if w.secret else None,
|
||||
"events": w.events or [],
|
||||
"is_active": w.is_active,
|
||||
}
|
||||
for w in db.query(WebhookConfig).order_by(WebhookConfig.name).all()
|
||||
]
|
||||
|
||||
# ── 3. SSO config (single row) ───────────────────────────────────
|
||||
sso_row = db.query(SsoConfig).first()
|
||||
sso = None
|
||||
if sso_row:
|
||||
sso = {
|
||||
"is_enabled": sso_row.is_enabled,
|
||||
"provider_name": sso_row.provider_name,
|
||||
"sp_entity_id": sso_row.sp_entity_id,
|
||||
"sp_acs_url": sso_row.sp_acs_url,
|
||||
"sp_slo_url": sso_row.sp_slo_url,
|
||||
"sp_certificate": sso_row.sp_certificate,
|
||||
"sp_private_key": "[REDACTED]", # never export private keys
|
||||
"idp_entity_id": sso_row.idp_entity_id,
|
||||
"idp_sso_url": getattr(sso_row, "idp_sso_url", None),
|
||||
"idp_slo_url": getattr(sso_row, "idp_slo_url", None),
|
||||
"idp_certificate": getattr(sso_row, "idp_certificate", None),
|
||||
"attr_email": getattr(sso_row, "attr_email", None),
|
||||
"attr_username": getattr(sso_row, "attr_username", None),
|
||||
"attr_role": getattr(sso_row, "attr_role", None),
|
||||
"default_role": getattr(sso_row, "default_role", None),
|
||||
"auto_provision": getattr(sso_row, "auto_provision", False),
|
||||
}
|
||||
|
||||
# ── 4. Scoring config (single row) ──────────────────────────────
|
||||
sc = db.query(ScoringConfig).first()
|
||||
scoring = None
|
||||
if sc:
|
||||
scoring = {
|
||||
"weight_tests": sc.weight_tests,
|
||||
"weight_detection_rules": sc.weight_detection_rules,
|
||||
"weight_d3fend": sc.weight_d3fend,
|
||||
"weight_recency": sc.weight_recency,
|
||||
"weight_severity": sc.weight_severity,
|
||||
}
|
||||
|
||||
# ── 5. Custom test templates only ───────────────────────────────
|
||||
templates = [
|
||||
{
|
||||
"mitre_technique_id": t.mitre_technique_id,
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"source": t.source,
|
||||
"source_url": t.source_url,
|
||||
"attack_procedure": t.attack_procedure,
|
||||
"expected_detection": t.expected_detection,
|
||||
"platform": t.platform,
|
||||
"tool_suggested": t.tool_suggested,
|
||||
"severity": t.severity,
|
||||
"suggested_remediation": t.suggested_remediation,
|
||||
"is_active": t.is_active,
|
||||
}
|
||||
for t in db.query(TestTemplate).filter(TestTemplate.source == "custom").all()
|
||||
]
|
||||
|
||||
# ── 6. Users (sanitized — no passwords/tokens) ───────────────────
|
||||
users = [
|
||||
{
|
||||
"username": u.username,
|
||||
"email": u.email if hasattr(u, "email") else None,
|
||||
"role": u.role,
|
||||
"is_active": u.is_active,
|
||||
"must_change_password": True, # force password reset on new instance
|
||||
}
|
||||
for u in db.query(User).order_by(User.username).all()
|
||||
]
|
||||
|
||||
bundle = {
|
||||
"_meta": {
|
||||
"version": _EXPORT_VERSION,
|
||||
"exported_at": datetime.utcnow().isoformat() + "Z",
|
||||
"exported_by": current_user.username,
|
||||
"note": (
|
||||
"Sensitive values (passwords, API tokens, private keys) are REDACTED. "
|
||||
"Re-enter them manually after import. "
|
||||
"User passwords are NOT exported — users must reset passwords on first login."
|
||||
),
|
||||
},
|
||||
"system_configs": system_configs,
|
||||
"webhooks": webhooks,
|
||||
"sso": sso,
|
||||
"scoring": scoring,
|
||||
"custom_templates": templates,
|
||||
"users": users,
|
||||
}
|
||||
|
||||
filename = f"aegis-config-{datetime.utcnow().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
return JSONResponse(
|
||||
content=bundle,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"X-Export-Version": _EXPORT_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /admin/import-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/import-config")
|
||||
async def import_config(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Restore platform configuration from a previously exported JSON bundle.
|
||||
|
||||
Idempotent: safe to run multiple times. Existing records are updated,
|
||||
missing ones are created. REDACTED values are skipped (left as-is).
|
||||
User passwords are set to a random temp value with must_change_password=True.
|
||||
"""
|
||||
try:
|
||||
bundle = await request.json()
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
||||
|
||||
meta = bundle.get("_meta", {})
|
||||
version = meta.get("version", "unknown")
|
||||
summary: dict[str, int] = {
|
||||
"system_configs": 0,
|
||||
"webhooks": 0,
|
||||
"custom_templates": 0,
|
||||
"users_created": 0,
|
||||
"users_updated": 0,
|
||||
}
|
||||
|
||||
# ── 1. system_configs ────────────────────────────────────────────
|
||||
for item in bundle.get("system_configs", []):
|
||||
key = item.get("key")
|
||||
value = item.get("value")
|
||||
if not key or value == "[REDACTED]":
|
||||
continue
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if row:
|
||||
row.value = value
|
||||
row.description = item.get("description") or row.description
|
||||
else:
|
||||
db.add(SystemConfig(key=key, value=value, description=item.get("description")))
|
||||
summary["system_configs"] += 1
|
||||
|
||||
# ── 2. webhooks ──────────────────────────────────────────────────
|
||||
for item in bundle.get("webhooks", []):
|
||||
name = item.get("name")
|
||||
url = item.get("url")
|
||||
if not name or not url:
|
||||
continue
|
||||
existing = db.query(WebhookConfig).filter(WebhookConfig.name == name).first()
|
||||
if existing:
|
||||
existing.url = url
|
||||
existing.events = item.get("events", [])
|
||||
existing.is_active = item.get("is_active", True)
|
||||
existing.failure_count = 0
|
||||
else:
|
||||
db.add(WebhookConfig(
|
||||
name=name,
|
||||
url=url,
|
||||
secret=None, # never restore secrets
|
||||
events=item.get("events", []),
|
||||
is_active=item.get("is_active", True),
|
||||
created_by=current_user.id,
|
||||
failure_count=0,
|
||||
))
|
||||
summary["webhooks"] += 1
|
||||
|
||||
# ── 3. SSO config ────────────────────────────────────────────────
|
||||
sso_data = bundle.get("sso")
|
||||
if sso_data:
|
||||
sso_row = db.query(SsoConfig).first()
|
||||
if sso_row:
|
||||
for field, val in sso_data.items():
|
||||
if val == "[REDACTED]":
|
||||
continue
|
||||
if hasattr(sso_row, field):
|
||||
setattr(sso_row, field, val)
|
||||
else:
|
||||
clean = {k: v for k, v in sso_data.items() if v != "[REDACTED]"}
|
||||
clean.pop("sp_private_key", None)
|
||||
db.add(SsoConfig(**clean))
|
||||
|
||||
# ── 4. Scoring config ────────────────────────────────────────────
|
||||
scoring_data = bundle.get("scoring")
|
||||
if scoring_data:
|
||||
sc = db.query(ScoringConfig).first()
|
||||
if sc:
|
||||
for field, val in scoring_data.items():
|
||||
if hasattr(sc, field) and val is not None:
|
||||
setattr(sc, field, val)
|
||||
else:
|
||||
db.add(ScoringConfig(**scoring_data))
|
||||
|
||||
# ── 5. Custom templates ──────────────────────────────────────────
|
||||
for item in bundle.get("custom_templates", []):
|
||||
name = item.get("name")
|
||||
mitre_id = item.get("mitre_technique_id")
|
||||
if not name or not mitre_id:
|
||||
continue
|
||||
existing = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.name == name, TestTemplate.source == "custom")
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
for field, val in item.items():
|
||||
if hasattr(existing, field):
|
||||
setattr(existing, field, val)
|
||||
else:
|
||||
db.add(TestTemplate(**{k: v for k, v in item.items()
|
||||
if k not in ("id", "created_at")}))
|
||||
summary["custom_templates"] += 1
|
||||
|
||||
# ── 6. Users ─────────────────────────────────────────────────────
|
||||
import secrets as _secrets
|
||||
for item in bundle.get("users", []):
|
||||
username = item.get("username")
|
||||
if not username:
|
||||
continue
|
||||
existing = db.query(User).filter(User.username == username).first()
|
||||
if existing:
|
||||
existing.role = item.get("role", existing.role)
|
||||
existing.is_active = item.get("is_active", existing.is_active)
|
||||
summary["users_updated"] += 1
|
||||
else:
|
||||
# Create with random temp password — user must reset on login
|
||||
temp_pw = _secrets.token_urlsafe(16) + "Aa1!"
|
||||
new_user = User(
|
||||
username=username,
|
||||
hashed_password=hash_password(temp_pw),
|
||||
role=item.get("role", "viewer"),
|
||||
is_active=item.get("is_active", True),
|
||||
must_change_password=True,
|
||||
)
|
||||
if item.get("email") and hasattr(User, "email"):
|
||||
new_user.email = item["email"]
|
||||
db.add(new_user)
|
||||
summary["users_created"] += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"imported_from_version": version,
|
||||
"summary": summary,
|
||||
"warnings": [
|
||||
"REDACTED values were skipped — re-enter passwords/tokens manually.",
|
||||
"All imported users have must_change_password=True.",
|
||||
"SSO private key was not restored — re-upload it manually.",
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Phase 14: API Key management router."""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
from app.schemas.api_key_schema import (
|
||||
ApiKeyCreate, ApiKeyCreated, ApiKeyOut, ApiKeyUpdate,
|
||||
)
|
||||
import app.services.api_key_service as svc
|
||||
|
||||
router = APIRouter(prefix="/api-keys", tags=["API Keys"])
|
||||
|
||||
|
||||
@router.post("", response_model=ApiKeyCreated, status_code=201)
|
||||
def create_key(
|
||||
body: ApiKeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Create a scoped API key.
|
||||
|
||||
The ``raw_key`` field in the response is shown **exactly once** and
|
||||
cannot be retrieved later. Store it securely.
|
||||
"""
|
||||
key, raw_key = svc.create_api_key(
|
||||
db,
|
||||
user_id = current_user.id,
|
||||
name = body.name,
|
||||
scopes = body.scopes,
|
||||
description = body.description,
|
||||
expires_at = body.expires_at,
|
||||
)
|
||||
out = ApiKeyOut.model_validate(key)
|
||||
return ApiKeyCreated(**out.model_dump(), raw_key=raw_key)
|
||||
|
||||
|
||||
@router.get("", response_model=List[ApiKeyOut])
|
||||
def list_keys(
|
||||
include_inactive: bool = Query(False),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List API keys owned by the current user."""
|
||||
# Admins can see all keys; others only see their own
|
||||
user_id = None if current_user.role == "admin" else current_user.id
|
||||
return svc.list_api_keys(db, user_id=user_id, include_inactive=include_inactive)
|
||||
|
||||
|
||||
@router.get("/{key_id}", response_model=ApiKeyOut)
|
||||
def get_key(
|
||||
key_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a single API key (owner or admin)."""
|
||||
user_id = None if current_user.role == "admin" else current_user.id
|
||||
return svc.get_api_key(db, key_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.patch("/{key_id}", response_model=ApiKeyOut)
|
||||
def update_key(
|
||||
key_id: UUID,
|
||||
body: ApiKeyUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Update name, description, scopes, expiry, or active status."""
|
||||
user_id = None if current_user.role == "admin" else current_user.id
|
||||
return svc.update_api_key(
|
||||
db, key_id, user_id,
|
||||
name = body.name,
|
||||
description = body.description,
|
||||
scopes = body.scopes,
|
||||
expires_at = body.expires_at,
|
||||
is_active = body.is_active,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{key_id}/revoke", response_model=ApiKeyOut)
|
||||
def revoke_key(
|
||||
key_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Revoke an API key (soft-delete — sets is_active=False)."""
|
||||
user_id = None if current_user.role == "admin" else current_user.id
|
||||
return svc.revoke_api_key(db, key_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.delete("/{key_id}", status_code=204)
|
||||
def delete_key(
|
||||
key_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Permanently delete an API key (admin only)."""
|
||||
svc.delete_api_key(db, key_id)
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Phase 10: Attack Paths & Advanced Purple Team router."""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.schemas.attack_path_schema import (
|
||||
AttackPathCreate, AttackPathUpdate, AttackPathOut,
|
||||
AttackPathStepCreate, AttackPathStepUpdate, AttackPathStepOut,
|
||||
ExecutionCreate, ExecutionOut,
|
||||
StepExecuteRequest, StepResultOut,
|
||||
TimelineEntryCreate, TimelineEntryOut,
|
||||
KillChainMetrics,
|
||||
)
|
||||
from app.services import attack_path_service as svc
|
||||
|
||||
router = APIRouter(prefix="/attack-paths", tags=["attack-paths"])
|
||||
|
||||
|
||||
# ── Attack Paths CRUD ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("", response_model=AttackPathOut, status_code=201)
|
||||
def create_attack_path(
|
||||
body: AttackPathCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.create_attack_path(db, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("", response_model=list[AttackPathOut])
|
||||
def list_attack_paths(
|
||||
is_template: Optional[bool] = None,
|
||||
technique_id: Optional[UUID] = None,
|
||||
is_active: Optional[bool] = True,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
paths = svc.list_attack_paths(db, is_template=is_template,
|
||||
technique_id=technique_id, is_active=is_active)
|
||||
# Inject step_count
|
||||
result = []
|
||||
for p in paths:
|
||||
d = AttackPathOut.model_validate(p)
|
||||
d.step_count = len(p.steps)
|
||||
result.append(d)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{path_id}", response_model=AttackPathOut)
|
||||
def get_attack_path(
|
||||
path_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
p = svc.get_attack_path(db, path_id)
|
||||
d = AttackPathOut.model_validate(p)
|
||||
d.step_count = len(p.steps)
|
||||
return d
|
||||
|
||||
|
||||
@router.patch("/{path_id}", response_model=AttackPathOut)
|
||||
def update_attack_path(
|
||||
path_id: UUID,
|
||||
body: AttackPathUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.update_attack_path(db, path_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.delete("/{path_id}", status_code=204)
|
||||
def delete_attack_path(
|
||||
path_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
svc.delete_attack_path(db, path_id, user.id)
|
||||
|
||||
|
||||
# ── Steps ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{path_id}/steps", response_model=list[AttackPathStepOut])
|
||||
def list_steps(
|
||||
path_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
path = svc.get_attack_path(db, path_id)
|
||||
return path.steps
|
||||
|
||||
|
||||
@router.post("/{path_id}/steps", response_model=AttackPathStepOut, status_code=201)
|
||||
def add_step(
|
||||
path_id: UUID,
|
||||
body: AttackPathStepCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.add_step(db, path_id, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.patch("/{path_id}/steps/{step_id}", response_model=AttackPathStepOut)
|
||||
def update_step(
|
||||
path_id: UUID,
|
||||
step_id: UUID,
|
||||
body: AttackPathStepUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.update_step(db, step_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.delete("/{path_id}/steps/{step_id}", status_code=204)
|
||||
def delete_step(
|
||||
path_id: UUID,
|
||||
step_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
svc.delete_step(db, step_id, user.id)
|
||||
|
||||
|
||||
@router.post("/{path_id}/steps/reorder", response_model=list[AttackPathStepOut])
|
||||
def reorder_steps(
|
||||
path_id: UUID,
|
||||
step_ids: list[UUID],
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Pass an ordered list of step UUIDs to reorder the steps."""
|
||||
return svc.reorder_steps(db, path_id, step_ids, user.id)
|
||||
|
||||
|
||||
# ── Executions ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/{path_id}/executions", response_model=ExecutionOut, status_code=201)
|
||||
def create_execution(
|
||||
path_id: UUID,
|
||||
body: ExecutionCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.create_execution(db, path_id, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/{path_id}/executions", response_model=list[ExecutionOut])
|
||||
def list_executions(
|
||||
path_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.list_executions(db, path_id)
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}", response_model=ExecutionOut)
|
||||
def get_execution(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.get_execution(db, execution_id)
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/start", response_model=ExecutionOut)
|
||||
def start_execution(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.start_execution(db, execution_id, user.id)
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/steps/{step_id}", response_model=StepResultOut)
|
||||
def execute_step(
|
||||
execution_id: UUID,
|
||||
step_id: UUID,
|
||||
body: StepExecuteRequest,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Record the result of one step (detected / not_detected / skipped)."""
|
||||
return svc.execute_step(db, execution_id, step_id, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}/steps", response_model=list[StepResultOut])
|
||||
def list_step_results(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
ex = svc.get_execution(db, execution_id)
|
||||
return ex.step_results
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/complete", response_model=ExecutionOut)
|
||||
def complete_execution(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Mark execution as complete and compute kill-chain metrics."""
|
||||
return svc.complete_execution(db, execution_id, user.id)
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/abort", response_model=ExecutionOut)
|
||||
def abort_execution(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return svc.abort_execution(db, execution_id, user.id)
|
||||
|
||||
|
||||
# ── Timeline ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/executions/{execution_id}/timeline",
|
||||
response_model=TimelineEntryOut, status_code=201)
|
||||
def add_timeline_entry(
|
||||
execution_id: UUID,
|
||||
body: TimelineEntryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.add_timeline_entry(db, execution_id, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}/timeline", response_model=list[TimelineEntryOut])
|
||||
def get_timeline(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.get_timeline(db, execution_id)
|
||||
|
||||
|
||||
# ── Kill-Chain Metrics ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/executions/{execution_id}/metrics")
|
||||
def get_metrics(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Return full kill-chain metrics for a completed (or partial) execution."""
|
||||
return svc.get_kill_chain_metrics(db, execution_id)
|
||||
@@ -34,7 +34,16 @@ from app.schemas.user import PasswordChange
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
# SECURE_COOKIES desacopla la seguridad de la cookie del entorno de ejecucion.
|
||||
# Por defecto activo en produccion; ponlo en "false" para servidores HTTP.
|
||||
_aegis_env = os.environ.get("AEGIS_ENV", "development").lower()
|
||||
_secure_cookie_env = os.environ.get("SECURE_COOKIES", "auto").lower()
|
||||
if _secure_cookie_env == "false":
|
||||
_IS_HTTPS = False
|
||||
elif _secure_cookie_env == "true":
|
||||
_IS_HTTPS = True
|
||||
else: # "auto" — activo solo si AEGIS_ENV=production
|
||||
_IS_HTTPS = _aegis_env == "production"
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
|
||||
@@ -146,6 +155,57 @@ def logout(
|
||||
return {"detail": "Logged out"}
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
def refresh_token(
|
||||
response: Response,
|
||||
aegis_token: str | None = Cookie(None),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Issue a new access token if the current one is valid.
|
||||
|
||||
Called automatically by the frontend when it detects an expired
|
||||
session while the user is actively using the app. If the current
|
||||
cookie token is still valid (not blacklisted, not expired), a fresh
|
||||
token is issued and the cookie is renewed — keeping the session alive
|
||||
without requiring re-authentication.
|
||||
"""
|
||||
if not aegis_token:
|
||||
raise PermissionViolation("No active session")
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
aegis_token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM],
|
||||
)
|
||||
except JWTError:
|
||||
raise PermissionViolation("Session expired — please log in again")
|
||||
|
||||
username: str | None = payload.get("sub")
|
||||
if not username:
|
||||
raise PermissionViolation("Invalid session")
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if user is None or not user.is_active:
|
||||
raise PermissionViolation("Account not found or disabled")
|
||||
|
||||
if getattr(user, "must_change_password", False):
|
||||
raise PermissionViolation("Password change required before refreshing session")
|
||||
|
||||
# Issue a fresh token with a new expiry
|
||||
new_token = create_access_token(data={"sub": user.username})
|
||||
response.set_cookie(
|
||||
key=_COOKIE_NAME,
|
||||
value=new_token,
|
||||
httponly=True,
|
||||
secure=_IS_HTTPS,
|
||||
samesite="strict",
|
||||
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
path="/",
|
||||
)
|
||||
return TokenResponse(access_token=new_token)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def read_current_user(current_user: User = Depends(get_current_user)):
|
||||
"""Return the profile of the currently authenticated user."""
|
||||
|
||||
@@ -6,6 +6,7 @@ test ordering, progress tracking, and threat actor integration.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
@@ -15,12 +16,15 @@ from pydantic import BaseModel, Field
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.test import Test
|
||||
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||
from app.services.campaign_crud_service import (
|
||||
add_test_to_campaign as crud_add_test,
|
||||
activate_campaign as crud_activate,
|
||||
complete_campaign as crud_complete,
|
||||
create_campaign as crud_create,
|
||||
delete_campaign as crud_delete,
|
||||
get_campaign_detail as crud_get_detail,
|
||||
get_campaign_history as crud_get_history,
|
||||
get_campaign_progress_data as crud_get_progress,
|
||||
@@ -33,6 +37,7 @@ from app.services.campaign_crud_service import (
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.notification_service import notify_role
|
||||
from app.services.webhook_service import dispatch_webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,6 +54,7 @@ class CampaignCreate(BaseModel):
|
||||
target_platform: Optional[str] = None
|
||||
tags: Optional[list[str]] = Field(default_factory=list)
|
||||
scheduled_at: Optional[str] = None
|
||||
start_date: Optional[str] = None # ISO date — campaign won't activate before this
|
||||
|
||||
class CampaignUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
@@ -57,6 +63,7 @@ class CampaignUpdate(BaseModel):
|
||||
target_platform: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
scheduled_at: Optional[str] = None
|
||||
start_date: Optional[str] = None # ISO date — can be updated while still in draft
|
||||
|
||||
class AddTestPayload(BaseModel):
|
||||
test_id: str
|
||||
@@ -120,13 +127,15 @@ def create_campaign(
|
||||
target_platform=payload.target_platform,
|
||||
tags=payload.tags,
|
||||
scheduled_at=payload.scheduled_at,
|
||||
start_date=payload.start_date,
|
||||
)
|
||||
campaign_id = result["id"]
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="create_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=result["id"],
|
||||
entity_id=campaign_id,
|
||||
details={"name": payload.name, "type": payload.type},
|
||||
)
|
||||
uow.commit()
|
||||
@@ -182,6 +191,37 @@ def update_campaign(
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /campaigns/{id} — Delete campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/{campaign_id}", status_code=204)
|
||||
def delete_campaign(
|
||||
campaign_id: str,
|
||||
delete_tests: bool = Query(False, description="Also delete associated tests"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a campaign. Only draft campaigns can be deleted (admins can delete any)."""
|
||||
with UnitOfWork(db) as uow:
|
||||
crud_delete(
|
||||
db,
|
||||
campaign_id,
|
||||
deleter_id=current_user.id,
|
||||
deleter_role=current_user.role,
|
||||
delete_tests=delete_tests,
|
||||
)
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="delete_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign_id,
|
||||
details={"delete_tests": delete_tests},
|
||||
)
|
||||
uow.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/{id}/tests — Add test to campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -204,6 +244,7 @@ def add_test_to_campaign(
|
||||
phase=payload.phase,
|
||||
)
|
||||
uow.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -232,10 +273,35 @@ def remove_test_from_campaign(
|
||||
@router.post("/{campaign_id}/activate")
|
||||
def activate_campaign(
|
||||
campaign_id: str,
|
||||
force: bool = Query(False, description="Activate even if start_date is in the future"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Activate a campaign, moving it from draft to active."""
|
||||
"""Activate a campaign, moving it from draft to active.
|
||||
|
||||
If the campaign has a start_date in the future and force=False, returns a 409
|
||||
with a warning so the frontend can show a confirmation modal. If force=True,
|
||||
activates immediately regardless of start_date.
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
campaign_obj = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if campaign_obj and campaign_obj.start_date and not force:
|
||||
now = datetime.utcnow()
|
||||
if campaign_obj.start_date > now:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"code": "start_date_in_future",
|
||||
"start_date": campaign_obj.start_date.strftime("%Y-%m-%d"),
|
||||
"message": (
|
||||
f"This campaign is scheduled to start on "
|
||||
f"{campaign_obj.start_date.strftime('%d %b %Y')}. "
|
||||
f"It will activate automatically on that date. "
|
||||
f"Do you want to activate it now anyway?"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
with UnitOfWork(db) as uow:
|
||||
campaign = crud_activate(db, campaign_id)
|
||||
notify_role(
|
||||
@@ -258,6 +324,33 @@ def activate_campaign(
|
||||
uow.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
# Create Jira tickets for campaign and tests at activation time (non-fatal).
|
||||
# Campaign ticket is created here if it doesn't already exist (deferred from creation).
|
||||
try:
|
||||
from app.services.jira_service import (
|
||||
auto_create_campaign_issue,
|
||||
auto_create_test_issue,
|
||||
get_campaign_jira_key,
|
||||
get_test_jira_key,
|
||||
)
|
||||
campaign_jira_key = get_campaign_jira_key(db, campaign_id)
|
||||
if not campaign_jira_key:
|
||||
campaign_jira_key = auto_create_campaign_issue(db, campaign, current_user)
|
||||
if campaign_jira_key:
|
||||
for ct in campaign.campaign_tests:
|
||||
if ct.test and not get_test_jira_key(db, ct.test.id):
|
||||
auto_create_test_issue(
|
||||
db, ct.test, current_user,
|
||||
parent_ticket_override=campaign_jira_key,
|
||||
campaign_start_date=campaign.start_date,
|
||||
)
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Jira ticket creation failed during activation of campaign %s",
|
||||
campaign_id,
|
||||
)
|
||||
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
@@ -284,6 +377,7 @@ def complete_campaign(
|
||||
)
|
||||
uow.commit()
|
||||
db.refresh(campaign)
|
||||
dispatch_webhook("campaign.completed", {"campaign_id": str(campaign.id), "name": campaign.name})
|
||||
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
@@ -306,9 +400,14 @@ def get_campaign_progress_endpoint(
|
||||
# POST /campaigns/from-threat-actor/{actor_id} — Auto-generate campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GenerateFromActorPayload(BaseModel):
|
||||
start_date: Optional[str] = None # ISO date YYYY-MM-DD
|
||||
|
||||
|
||||
@router.post("/from-threat-actor/{actor_id}", status_code=201)
|
||||
def generate_campaign_from_actor(
|
||||
actor_id: str,
|
||||
payload: GenerateFromActorPayload = GenerateFromActorPayload(),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
@@ -317,10 +416,14 @@ def generate_campaign_from_actor(
|
||||
Creates tests from the best available templates and orders them
|
||||
by kill chain phase.
|
||||
"""
|
||||
start_date_parsed = (
|
||||
datetime.fromisoformat(payload.start_date) if payload.start_date else None
|
||||
)
|
||||
campaign = generate_campaign_from_threat_actor(
|
||||
db,
|
||||
uuid.UUID(actor_id),
|
||||
current_user,
|
||||
start_date=start_date_parsed,
|
||||
)
|
||||
|
||||
with UnitOfWork(db) as uow:
|
||||
@@ -392,3 +495,97 @@ def get_campaign_history(
|
||||
):
|
||||
"""List all child campaigns (execution history) of a recurring campaign."""
|
||||
return crud_get_history(db, campaign_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns/{id}/timing-summary — Aggregated timing across campaign tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seconds_between(start: datetime | None, end: datetime | None) -> int:
|
||||
"""Return elapsed seconds between two datetimes; 0 if either is None."""
|
||||
if not start or not end:
|
||||
return 0
|
||||
diff = (end - start).total_seconds()
|
||||
return max(0, int(diff))
|
||||
|
||||
|
||||
@router.get("/{campaign_id}/timing-summary")
|
||||
def get_campaign_timing_summary(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return aggregated Red/Blue timing metrics for all tests in a campaign.
|
||||
|
||||
For each test we calculate:
|
||||
- red_execution_secs : red_started_at → blue_started_at (minus red_paused_seconds)
|
||||
- blue_queue_secs : blue_started_at → blue_work_started_at (waiting for Blue pick-up)
|
||||
- blue_evaluation_secs: blue_work_started_at → first validation timestamp (minus blue_paused_seconds)
|
||||
- total_secs : sum of the three phases
|
||||
|
||||
Returns totals + per-test breakdown.
|
||||
"""
|
||||
# Load campaign
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
# Load all tests for this campaign
|
||||
test_ids = [
|
||||
ct.test_id
|
||||
for ct in db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign.id).all()
|
||||
]
|
||||
tests = db.query(Test).filter(Test.id.in_(test_ids)).all() if test_ids else []
|
||||
|
||||
breakdown = []
|
||||
total_red = 0
|
||||
total_queue = 0
|
||||
total_blue = 0
|
||||
|
||||
for t in tests:
|
||||
# Red execution: from start-execution to submit-to-blue, minus paused time
|
||||
red_secs = max(
|
||||
0,
|
||||
_seconds_between(t.red_started_at, t.blue_started_at) - (t.red_paused_seconds or 0),
|
||||
)
|
||||
|
||||
# Blue queue: from receiving the test to actually starting evaluation
|
||||
queue_secs = _seconds_between(t.blue_started_at, t.blue_work_started_at)
|
||||
|
||||
# Blue evaluation: from starting evaluation to first validation, minus paused time
|
||||
eval_end = t.red_validated_at or t.blue_validated_at
|
||||
blue_secs = max(
|
||||
0,
|
||||
_seconds_between(t.blue_work_started_at, eval_end) - (t.blue_paused_seconds or 0),
|
||||
)
|
||||
|
||||
total_red += red_secs
|
||||
total_queue += queue_secs
|
||||
total_blue += blue_secs
|
||||
|
||||
breakdown.append({
|
||||
"test_id": str(t.id),
|
||||
"test_name": t.name,
|
||||
"state": t.state.value if t.state else None,
|
||||
"red_execution_secs": red_secs,
|
||||
"blue_queue_secs": queue_secs,
|
||||
"blue_evaluation_secs": blue_secs,
|
||||
"total_secs": red_secs + queue_secs + blue_secs,
|
||||
"has_timing": bool(t.red_started_at),
|
||||
})
|
||||
|
||||
total_secs = total_red + total_queue + total_blue
|
||||
|
||||
return {
|
||||
"campaign_id": campaign_id,
|
||||
"campaign_name": campaign.name,
|
||||
"tests_total": len(tests),
|
||||
"tests_with_timing": sum(1 for b in breakdown if b["has_timing"]),
|
||||
"red_execution_secs": total_red,
|
||||
"blue_queue_secs": total_queue,
|
||||
"blue_evaluation_secs": total_blue,
|
||||
"total_secs": total_secs,
|
||||
"breakdown": sorted(breakdown, key=lambda x: -(x["total_secs"])),
|
||||
}
|
||||
|
||||
@@ -22,6 +22,9 @@ from app.services.compliance_service import (
|
||||
from app.services.compliance_import_service import (
|
||||
import_nist_800_53_mappings,
|
||||
import_cis_controls_v8_mappings,
|
||||
import_dora_mappings,
|
||||
import_iso_27001_mappings,
|
||||
import_iso_42001_mappings,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||
@@ -119,3 +122,33 @@ def import_cis(
|
||||
"""Import CIS Controls v8 mappings (admin only)."""
|
||||
result = import_cis_controls_v8_mappings(db)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/import/dora")
|
||||
def import_dora(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import DORA (EU 2022/2554) compliance mappings (admin only)."""
|
||||
result = import_dora_mappings(db)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/import/iso-27001")
|
||||
def import_iso27001(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import ISO/IEC 27001:2022 Annex A compliance mappings (admin only)."""
|
||||
result = import_iso_27001_mappings(db)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/import/iso-42001")
|
||||
def import_iso42001(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import ISO/IEC 42001:2023 AI Management System compliance mappings (admin only)."""
|
||||
result = import_iso_42001_mappings(db)
|
||||
return result
|
||||
|
||||
@@ -0,0 +1,319 @@
|
||||
"""Detection Lifecycle Management router."""
|
||||
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.models.detection_lifecycle import (
|
||||
DetectionAsset, DetectionTechniqueMapping, DetectionValidation,
|
||||
TechniqueConfidenceScore, InfrastructureChangeLog,
|
||||
)
|
||||
from app.schemas.detection_lifecycle_schema import (
|
||||
DetectionAssetCreate, DetectionAssetUpdate, DetectionAssetOut,
|
||||
DetectionValidationCreate, DetectionValidationOut,
|
||||
TechniqueConfidenceOut,
|
||||
InfrastructureChangeCreate, InfrastructureChangeOut,
|
||||
)
|
||||
from app.services import detection_asset_service, decay_engine_service, audit_service
|
||||
|
||||
router = APIRouter(prefix="/detection-lifecycle", tags=["detection-lifecycle"])
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
# ── Detection Assets ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/assets", response_model=DetectionAssetOut, status_code=201)
|
||||
def create_asset(body: DetectionAssetCreate, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
asset = detection_asset_service.create_detection_asset(db, body.model_dump(), user.id)
|
||||
return asset
|
||||
|
||||
|
||||
@router.get("/assets", response_model=list[DetectionAssetOut])
|
||||
def list_assets(
|
||||
platform: Optional[str] = None,
|
||||
asset_type: Optional[str] = None,
|
||||
health_status: Optional[str] = None,
|
||||
technique_id: Optional[UUID] = None,
|
||||
is_active: Optional[bool] = True,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return detection_asset_service.list_assets(db, platform=platform, asset_type=asset_type, health_status=health_status, technique_id=technique_id, is_active=is_active)
|
||||
|
||||
|
||||
@router.get("/assets/{asset_id}", response_model=DetectionAssetOut)
|
||||
def get_asset(asset_id: UUID, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
return detection_asset_service.get_asset_with_details(db, asset_id)
|
||||
|
||||
|
||||
@router.patch("/assets/{asset_id}", response_model=DetectionAssetOut)
|
||||
def update_asset(asset_id: UUID, body: DetectionAssetUpdate, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
return detection_asset_service.update_detection_asset(db, asset_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.delete("/assets/{asset_id}", status_code=204)
|
||||
def delete_asset(asset_id: UUID, db: Session = Depends(get_db), user=Depends(require_any_role("red_lead", "blue_lead"))):
|
||||
asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first()
|
||||
if not asset:
|
||||
raise EntityNotFoundError("DetectionAsset", str(asset_id))
|
||||
asset.is_active = False
|
||||
db.commit()
|
||||
|
||||
|
||||
# ── Technique Mappings ───────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/assets/{asset_id}/techniques/{technique_id}")
|
||||
def map_technique(
|
||||
asset_id: UUID, technique_id: UUID,
|
||||
coverage_type: str = Query("detect"),
|
||||
confidence_level: str = Query("medium"),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
# Validate asset exists
|
||||
asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first()
|
||||
if not asset:
|
||||
raise EntityNotFoundError("DetectionAsset", str(asset_id))
|
||||
|
||||
# Prevent duplicate mappings
|
||||
existing = db.query(DetectionTechniqueMapping).filter(
|
||||
DetectionTechniqueMapping.detection_asset_id == asset_id,
|
||||
DetectionTechniqueMapping.technique_id == technique_id,
|
||||
).first()
|
||||
if existing:
|
||||
# Update coverage/confidence on existing mapping instead of duplicating
|
||||
existing.coverage_type = coverage_type
|
||||
existing.confidence_level = confidence_level
|
||||
db.commit()
|
||||
return {"message": "Technique mapping updated", "mapping_id": str(existing.id)}
|
||||
|
||||
mapping = DetectionTechniqueMapping(
|
||||
detection_asset_id=asset_id, technique_id=technique_id,
|
||||
coverage_type=coverage_type, confidence_level=confidence_level,
|
||||
)
|
||||
db.add(mapping)
|
||||
db.commit()
|
||||
return {"message": "Technique mapped", "mapping_id": str(mapping.id)}
|
||||
|
||||
|
||||
@router.get("/techniques/{technique_id}/detections")
|
||||
def get_technique_detections(technique_id: UUID, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
return detection_asset_service.get_technique_detection_summary(db, technique_id)
|
||||
|
||||
|
||||
# ── Validations ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/validations", response_model=DetectionValidationOut, status_code=201)
|
||||
def create_validation(body: DetectionValidationCreate, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
asset = db.query(DetectionAsset).filter(DetectionAsset.id == body.detection_asset_id).first()
|
||||
if not asset:
|
||||
raise EntityNotFoundError("DetectionAsset", str(body.detection_asset_id))
|
||||
|
||||
now = _now()
|
||||
validation = DetectionValidation(
|
||||
detection_asset_id=body.detection_asset_id,
|
||||
technique_id=body.technique_id,
|
||||
test_id=body.test_id,
|
||||
validation_result=body.validation_result,
|
||||
validation_method=body.validation_method,
|
||||
notes=body.notes,
|
||||
evidence_ids=[str(e) for e in (body.evidence_ids or [])],
|
||||
validated_by=user.id,
|
||||
validated_at=now,
|
||||
expires_at=now + timedelta(days=body.validity_days),
|
||||
rule_hash_at_validation=asset.rule_hash,
|
||||
log_source_version_at_validation=asset.log_source_version,
|
||||
infrastructure_hash_at_validation=asset.infrastructure_hash,
|
||||
)
|
||||
data = f"{validation.detection_asset_id}:{validation.validated_by}:{validation.validation_result}:{validation.validated_at}"
|
||||
validation.integrity_hash = hashlib.sha256(data.encode()).hexdigest()
|
||||
|
||||
db.add(validation)
|
||||
db.commit()
|
||||
db.refresh(validation)
|
||||
|
||||
if body.technique_id:
|
||||
decay_engine_service.calculate_confidence_for_technique(db, body.technique_id)
|
||||
|
||||
audit_service.log_action(db, user.id, "DETECTION_VALIDATED", "detection_validation", str(validation.id),
|
||||
details={"asset_id": str(body.detection_asset_id), "result": body.validation_result, "validity_days": body.validity_days})
|
||||
|
||||
return validation
|
||||
|
||||
|
||||
@router.get("/validations", response_model=list[DetectionValidationOut])
|
||||
def list_validations(
|
||||
asset_id: Optional[UUID] = None,
|
||||
technique_id: Optional[UUID] = None,
|
||||
is_valid: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
query = db.query(DetectionValidation)
|
||||
if asset_id:
|
||||
query = query.filter(DetectionValidation.detection_asset_id == asset_id)
|
||||
if technique_id:
|
||||
query = query.filter(DetectionValidation.technique_id == technique_id)
|
||||
if is_valid is not None:
|
||||
query = query.filter(DetectionValidation.is_valid == is_valid)
|
||||
return query.order_by(DetectionValidation.validated_at.desc()).all()
|
||||
|
||||
|
||||
@router.post("/validations/{validation_id}/invalidate")
|
||||
def invalidate_validation(
|
||||
validation_id: UUID,
|
||||
reason: str = Query(...),
|
||||
details: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead")),
|
||||
):
|
||||
validation = db.query(DetectionValidation).filter(DetectionValidation.id == validation_id).first()
|
||||
if not validation:
|
||||
raise EntityNotFoundError("DetectionValidation", str(validation_id))
|
||||
|
||||
from app.models.detection_lifecycle import InvalidationReason
|
||||
try:
|
||||
reason_enum = InvalidationReason(reason)
|
||||
except ValueError:
|
||||
reason_enum = InvalidationReason.manual
|
||||
|
||||
validation.is_valid = False
|
||||
validation.invalidated_at = _now()
|
||||
validation.invalidation_reason = reason_enum
|
||||
validation.invalidation_details = details
|
||||
validation.invalidated_by = user.id
|
||||
db.commit()
|
||||
return {"message": "Validation invalidated"}
|
||||
|
||||
|
||||
# ── Confidence Scores ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/confidence", response_model=list[TechniqueConfidenceOut])
|
||||
def list_confidence_scores(
|
||||
confidence_level: Optional[str] = None,
|
||||
min_score: Optional[float] = None,
|
||||
max_score: Optional[float] = None,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
query = db.query(TechniqueConfidenceScore)
|
||||
if confidence_level:
|
||||
query = query.filter(TechniqueConfidenceScore.confidence_level == confidence_level)
|
||||
if min_score is not None:
|
||||
query = query.filter(TechniqueConfidenceScore.confidence_score >= min_score)
|
||||
if max_score is not None:
|
||||
query = query.filter(TechniqueConfidenceScore.confidence_score <= max_score)
|
||||
return query.order_by(TechniqueConfidenceScore.confidence_score.asc()).all()
|
||||
|
||||
|
||||
@router.get("/confidence/{technique_id}", response_model=TechniqueConfidenceOut)
|
||||
def get_technique_confidence(
|
||||
technique_id: UUID,
|
||||
recalculate: bool = Query(False),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
if recalculate:
|
||||
return decay_engine_service.calculate_confidence_for_technique(db, technique_id)
|
||||
score = db.query(TechniqueConfidenceScore).filter(TechniqueConfidenceScore.technique_id == technique_id).first()
|
||||
if not score:
|
||||
return decay_engine_service.calculate_confidence_for_technique(db, technique_id)
|
||||
return score
|
||||
|
||||
|
||||
# ── Infrastructure Changes ───────────────────────────────────────────────────
|
||||
|
||||
@router.post("/infrastructure-changes", response_model=InfrastructureChangeOut, status_code=201)
|
||||
def report_infrastructure_change(
|
||||
body: InfrastructureChangeCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead")),
|
||||
):
|
||||
change = InfrastructureChangeLog(
|
||||
change_type=body.change_type,
|
||||
description=body.description,
|
||||
affected_platforms=body.affected_platforms,
|
||||
affected_log_sources=body.affected_log_sources,
|
||||
change_date=body.change_date or _now(),
|
||||
auto_invalidate=body.auto_invalidate,
|
||||
reported_by=user.id,
|
||||
)
|
||||
db.add(change)
|
||||
db.commit()
|
||||
db.refresh(change)
|
||||
|
||||
if change.auto_invalidate:
|
||||
decay_engine_service.process_infrastructure_change(db, change.id)
|
||||
db.refresh(change)
|
||||
|
||||
audit_service.log_action(db, user.id, "INFRASTRUCTURE_CHANGE_REPORTED", "infrastructure_change", str(change.id),
|
||||
details={"type": body.change_type, "invalidated_count": change.invalidated_count})
|
||||
|
||||
return change
|
||||
|
||||
|
||||
@router.get("/infrastructure-changes", response_model=list[InfrastructureChangeOut])
|
||||
def list_infrastructure_changes(
|
||||
days: int = Query(90, ge=1, le=730),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
cutoff = _now() - timedelta(days=days)
|
||||
return db.query(InfrastructureChangeLog).filter(InfrastructureChangeLog.change_date >= cutoff).order_by(InfrastructureChangeLog.change_date.desc()).all()
|
||||
|
||||
|
||||
# ── Decay Engine Control ─────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/decay-engine/run")
|
||||
def trigger_decay_engine(db: Session = Depends(get_db), user=Depends(require_any_role("admin"))):
|
||||
results = decay_engine_service.run_decay_engine(db)
|
||||
return {"message": "Decay engine completed", "results": results}
|
||||
|
||||
|
||||
# ── Dashboard ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/dashboard")
|
||||
def lifecycle_dashboard(db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
now = _now()
|
||||
|
||||
health_dist = dict(
|
||||
db.query(DetectionAsset.health_status, func.count(DetectionAsset.id))
|
||||
.filter(DetectionAsset.is_active == True)
|
||||
.group_by(DetectionAsset.health_status)
|
||||
.all()
|
||||
)
|
||||
confidence_dist = dict(
|
||||
db.query(TechniqueConfidenceScore.confidence_level, func.count(TechniqueConfidenceScore.id))
|
||||
.group_by(TechniqueConfidenceScore.confidence_level)
|
||||
.all()
|
||||
)
|
||||
expiring_soon = db.query(func.count(DetectionValidation.id)).filter(
|
||||
DetectionValidation.is_valid == True,
|
||||
DetectionValidation.expires_at <= (now + timedelta(days=7)),
|
||||
).scalar() or 0
|
||||
|
||||
total_assets = db.query(func.count(DetectionAsset.id)).filter(DetectionAsset.is_active == True).scalar() or 0
|
||||
total_valid = db.query(func.count(DetectionValidation.id)).filter(DetectionValidation.is_valid == True).scalar() or 0
|
||||
recent_changes = db.query(func.count(InfrastructureChangeLog.id)).filter(
|
||||
InfrastructureChangeLog.change_date >= (now - timedelta(days=30))
|
||||
).scalar() or 0
|
||||
|
||||
return {
|
||||
"total_detection_assets": total_assets,
|
||||
"total_valid_validations": total_valid,
|
||||
"health_distribution": {k.value if hasattr(k, "value") else str(k): v for k, v in health_dist.items()},
|
||||
"confidence_distribution": {k.value if hasattr(k, "value") else str(k): v for k, v in confidence_dist.items()},
|
||||
"validations_expiring_7d": expiring_soon,
|
||||
"infrastructure_changes_30d": recent_changes,
|
||||
}
|
||||
@@ -4,7 +4,8 @@ Endpoints
|
||||
---------
|
||||
POST /tests/{test_id}/evidence — upload evidence (with team=red/blue)
|
||||
GET /tests/{test_id}/evidence — list evidences (filterable by team)
|
||||
GET /evidence/{id} — presigned download URL
|
||||
GET /evidence/{id} — metadata + download_url
|
||||
GET /evidence/{id}/file — proxy download (streams file through backend)
|
||||
DELETE /evidence/{id} — delete evidence (only in editable states)
|
||||
|
||||
Access Control
|
||||
@@ -20,11 +21,14 @@ Access Control
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import uuid as _uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
@@ -45,7 +49,9 @@ from app.services.evidence_service import (
|
||||
validate_upload_permission,
|
||||
)
|
||||
from app.limiter import limiter
|
||||
from app.storage import get_presigned_url, upload_file
|
||||
from app.storage import download_file, upload_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["evidence"])
|
||||
|
||||
@@ -55,7 +61,11 @@ router = APIRouter(tags=["evidence"])
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
"""Convert an ORM ``Evidence`` to the API schema, injecting a presigned URL."""
|
||||
"""Convert an ORM ``Evidence`` to the API schema.
|
||||
|
||||
``download_url`` points to the backend proxy endpoint so the browser
|
||||
never needs direct access to MinIO.
|
||||
"""
|
||||
return EvidenceOut(
|
||||
id=evidence.id,
|
||||
test_id=evidence.test_id,
|
||||
@@ -65,7 +75,7 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
uploaded_at=evidence.uploaded_at,
|
||||
team=evidence.team,
|
||||
notes=evidence.notes,
|
||||
download_url=get_presigned_url(evidence.file_path),
|
||||
download_url=f"/api/v1/evidence/{evidence.id}/file",
|
||||
)
|
||||
|
||||
|
||||
@@ -119,6 +129,7 @@ async def upload_evidence(
|
||||
file_path=key,
|
||||
sha256_hash=sha256,
|
||||
uploaded_by=current_user.id,
|
||||
uploaded_at=datetime.utcnow(), # set explicitly — DB column has no server default
|
||||
team=team,
|
||||
notes=notes,
|
||||
)
|
||||
@@ -140,9 +151,43 @@ async def upload_evidence(
|
||||
uow.commit()
|
||||
db.refresh(evidence)
|
||||
|
||||
# 7. Attach to Jira ticket if one exists (non-fatal)
|
||||
_attach_evidence_to_jira(db, test_id, content, safe_name, current_user)
|
||||
|
||||
return _evidence_to_out(evidence)
|
||||
|
||||
|
||||
def _attach_evidence_to_jira(
|
||||
db,
|
||||
test_id: _uuid.UUID,
|
||||
content: bytes,
|
||||
file_name: str,
|
||||
actor,
|
||||
) -> None:
|
||||
"""Attach uploaded evidence to the linked Jira ticket (non-fatal)."""
|
||||
try:
|
||||
from app.services.jira_service import get_test_jira_key, get_user_jira_client, has_jira_configured
|
||||
if not has_jira_configured(actor, db):
|
||||
return
|
||||
issue_key = get_test_jira_key(db, test_id)
|
||||
if not issue_key:
|
||||
return
|
||||
import io
|
||||
jira = get_user_jira_client(actor, db)
|
||||
buf = io.BytesIO(content)
|
||||
buf.name = file_name # requests uses .name as the multipart filename
|
||||
jira.add_attachment_object(issue_key, buf)
|
||||
import logging
|
||||
logging.getLogger(__name__).info(
|
||||
"Attached evidence '%s' to Jira ticket %s", file_name, issue_key
|
||||
)
|
||||
except Exception as exc:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to attach evidence '%s' to Jira: %s", file_name, exc, exc_info=True
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /tests/{test_id}/evidence — list (with optional team filter)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -162,7 +207,7 @@ def list_evidence(
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /evidence/{id} — presigned download URL
|
||||
# GET /evidence/{id} — metadata + proxy download URL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -172,11 +217,48 @@ def get_evidence(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return evidence metadata together with a presigned download URL."""
|
||||
"""Return evidence metadata. ``download_url`` is a backend proxy URL."""
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
return _evidence_to_out(evidence)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /evidence/{id}/file — proxy download (streams file via backend)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/evidence/{evidence_id}/file")
|
||||
def download_evidence_file(
|
||||
evidence_id: _uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Stream the evidence file through the backend.
|
||||
|
||||
The browser calls this endpoint (authenticated via JWT cookie/header).
|
||||
The backend fetches the file from MinIO internally and streams it back,
|
||||
so MinIO never needs to be publicly accessible.
|
||||
"""
|
||||
import mimetypes
|
||||
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
content = download_file(evidence.file_path)
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(evidence.file_name)
|
||||
if not mime_type:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
safe_name = evidence.file_name.replace('"', '\\"')
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type=mime_type,
|
||||
headers={
|
||||
"Content-Disposition": f'inline; filename="{safe_name}"',
|
||||
"Content-Length": str(len(content)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /evidence/{id} — delete evidence (editable states only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Phase 13: Executive Dashboard router."""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.schemas.executive_dashboard_schema import (
|
||||
PostureSnapshotOut,
|
||||
ExecutiveSummary,
|
||||
KpiBlock,
|
||||
CoverageByTactic,
|
||||
PostureHistoryEntry,
|
||||
ActivityEntry,
|
||||
)
|
||||
import app.services.executive_dashboard_service as svc
|
||||
|
||||
router = APIRouter(prefix="/dashboard", tags=["Executive Dashboard"])
|
||||
|
||||
|
||||
@router.get("/executive", response_model=ExecutiveSummary)
|
||||
def executive_view(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Full executive view — snapshot, 30-day trends, top risks,
|
||||
coverage by tactic, and recent activity feed.
|
||||
"""
|
||||
data = svc.get_executive_summary(db)
|
||||
snap = data["snapshot"]
|
||||
return ExecutiveSummary(
|
||||
snapshot=PostureSnapshotOut.model_validate(snap),
|
||||
coverage_trend=data["coverage_trend"],
|
||||
risk_trend=data["risk_trend"],
|
||||
top_risks=data["top_risks"],
|
||||
coverage_by_tactic=data["coverage_by_tactic"],
|
||||
recent_activity=data["recent_activity"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/kpis", response_model=KpiBlock)
|
||||
def kpis(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Compact KPI block — live aggregation without persisting a snapshot."""
|
||||
live = svc.get_live_kpis(db)
|
||||
|
||||
# Try to find today's snapshot id; fall back to None
|
||||
from datetime import date
|
||||
from app.models.executive_dashboard import PostureSnapshot
|
||||
today_snap = db.query(PostureSnapshot).filter(
|
||||
PostureSnapshot.snapshot_date == date.today()
|
||||
).first()
|
||||
|
||||
return KpiBlock(
|
||||
coverage_pct=live["coverage_pct"],
|
||||
avg_risk_score=live["avg_risk_score"],
|
||||
critical_count=live["critical_count"],
|
||||
open_queue_items=live["open_queue_items"],
|
||||
orphan_techniques=live["orphan_techniques"],
|
||||
mttd_avg_seconds=live.get("mttd_avg_seconds"),
|
||||
detection_rate_30d=live.get("detection_rate_30d"),
|
||||
playbook_count=live["playbook_count"],
|
||||
lesson_count=live["lesson_count"],
|
||||
snapshot_date=live["snapshot_date"],
|
||||
snapshot_id=today_snap.id if today_snap else None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/coverage-by-tactic", response_model=List[CoverageByTactic])
|
||||
def coverage_by_tactic(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Per-tactic validated / partial / not_covered breakdown."""
|
||||
return svc.get_coverage_by_tactic(db)
|
||||
|
||||
|
||||
@router.get("/posture-history", response_model=List[PostureHistoryEntry])
|
||||
def posture_history(
|
||||
days: int = Query(30, ge=1, le=365),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Historical posture snapshots for trend charts (default last 30 days)."""
|
||||
snaps = svc.get_posture_history(db, days=days)
|
||||
return [
|
||||
PostureHistoryEntry(
|
||||
snapshot_date=s.snapshot_date,
|
||||
coverage_pct=s.coverage_pct,
|
||||
avg_risk_score=s.avg_risk_score,
|
||||
critical_count=s.critical_count,
|
||||
open_queue_items=s.open_queue_items,
|
||||
)
|
||||
for s in snaps
|
||||
]
|
||||
|
||||
|
||||
@router.post("/posture-snapshot", response_model=PostureSnapshotOut, status_code=201)
|
||||
def create_posture_snapshot(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""
|
||||
Take (or refresh) today's posture snapshot — admin / leads only.
|
||||
Aggregates live data from all phases into a single PostureSnapshot row.
|
||||
"""
|
||||
snap = svc.take_posture_snapshot(db, created_by=user.id)
|
||||
return PostureSnapshotOut.model_validate(snap)
|
||||
|
||||
|
||||
@router.get("/activity", response_model=List[ActivityEntry])
|
||||
def recent_activity(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Recent activity feed — tests, attack-path executions, OSINT signals."""
|
||||
return svc.get_recent_activity(db, limit=limit)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Intel items endpoints — list and manage threat intelligence items."""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.intel import IntelItem
|
||||
from app.models.user import User
|
||||
|
||||
router = APIRouter(prefix="/intel", tags=["intel"])
|
||||
|
||||
|
||||
class IntelItemOut(BaseModel):
|
||||
id: uuid.UUID
|
||||
technique_id: Optional[uuid.UUID] = None
|
||||
url: str
|
||||
title: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
detected_at: Optional[str] = None
|
||||
reviewed: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@router.get("/items", response_model=list[IntelItemOut])
|
||||
def list_intel_items(
|
||||
technique_id: Optional[uuid.UUID] = Query(None, description="Filter by technique"),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List threat intelligence items, optionally filtered by technique."""
|
||||
query = db.query(IntelItem).order_by(IntelItem.detected_at.desc())
|
||||
if technique_id:
|
||||
query = query.filter(IntelItem.technique_id == technique_id)
|
||||
items = query.limit(limit).all()
|
||||
return [
|
||||
IntelItemOut(
|
||||
id=item.id,
|
||||
technique_id=item.technique_id,
|
||||
url=item.url,
|
||||
title=item.title,
|
||||
source=item.source,
|
||||
detected_at=item.detected_at.isoformat() if item.detected_at else None,
|
||||
reviewed=item.reviewed,
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
@@ -72,14 +72,16 @@ def create_link(
|
||||
def list_links(
|
||||
entity_type: Optional[JiraLinkEntityType] = None,
|
||||
entity_id: Optional[UUID] = None,
|
||||
entity_ids: Optional[list[UUID]] = Query(default=None, description="Filter by multiple entity IDs"),
|
||||
db: Session = Depends(get_db),
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List Jira links, optionally filtered by entity."""
|
||||
"""List Jira links, optionally filtered by entity or a list of entity IDs."""
|
||||
return jira_service.list_links(
|
||||
db,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
entity_ids=entity_ids,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Phase 11: Knowledge Management router — Playbooks + Lessons Learned."""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.schemas.knowledge_schema import (
|
||||
PlaybookCreate, PlaybookUpdate, PlaybookOut, PlaybookVersionOut,
|
||||
LessonLearnedCreate, LessonLearnedUpdate, LessonLearnedOut,
|
||||
)
|
||||
from app.services import playbook_service as pb_svc
|
||||
from app.services import lesson_learned_service as ll_svc
|
||||
|
||||
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# Playbooks
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@router.get("/playbooks", response_model=List[PlaybookOut])
|
||||
def list_playbooks(
|
||||
technique_id: Optional[UUID] = None,
|
||||
playbook_type: Optional[str] = None,
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return pb_svc.list_playbooks(
|
||||
db,
|
||||
technique_id=technique_id,
|
||||
playbook_type=playbook_type,
|
||||
include_inactive=include_inactive,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/playbooks", response_model=PlaybookOut, status_code=201)
|
||||
def create_playbook(
|
||||
body: PlaybookCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return pb_svc.create_playbook(db, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/playbooks/{playbook_id}", response_model=PlaybookOut)
|
||||
def get_playbook(
|
||||
playbook_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return pb_svc.get_playbook(db, playbook_id)
|
||||
|
||||
|
||||
@router.patch("/playbooks/{playbook_id}", response_model=PlaybookOut)
|
||||
def update_playbook(
|
||||
playbook_id: UUID,
|
||||
body: PlaybookUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return pb_svc.update_playbook(db, playbook_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.delete("/playbooks/{playbook_id}", status_code=204)
|
||||
def delete_playbook(
|
||||
playbook_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
pb_svc.delete_playbook(db, playbook_id, user.id)
|
||||
|
||||
|
||||
# ── Versions ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/playbooks/{playbook_id}/versions", response_model=List[PlaybookVersionOut])
|
||||
def list_versions(
|
||||
playbook_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return pb_svc.get_playbook_versions(db, playbook_id)
|
||||
|
||||
|
||||
@router.post("/playbooks/{playbook_id}/restore/{version}", response_model=PlaybookOut)
|
||||
def restore_version(
|
||||
playbook_id: UUID,
|
||||
version: int,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Roll the playbook back to a specific historical version."""
|
||||
return pb_svc.restore_version(db, playbook_id, version, user.id)
|
||||
|
||||
|
||||
# ── By technique (convenience) ────────────────────────────────────────────────
|
||||
|
||||
@router.get(
|
||||
"/techniques/{technique_id}/playbooks",
|
||||
response_model=List[PlaybookOut],
|
||||
)
|
||||
def playbooks_for_technique(
|
||||
technique_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""List all active playbooks for a specific technique."""
|
||||
return pb_svc.list_playbooks(db, technique_id=technique_id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/techniques/{technique_id}/playbooks/{playbook_type}",
|
||||
response_model=PlaybookOut,
|
||||
)
|
||||
def get_playbook_by_technique_type(
|
||||
technique_id: UUID,
|
||||
playbook_type: str,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
pb = pb_svc.get_playbook_by_technique_type(db, technique_id, playbook_type)
|
||||
if not pb:
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
raise EntityNotFoundError("Playbook", f"{technique_id}/{playbook_type}")
|
||||
return pb
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# Lessons Learned
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@router.get("/lessons", response_model=List[LessonLearnedOut])
|
||||
def list_lessons(
|
||||
entity_type: Optional[str] = None,
|
||||
entity_id: Optional[UUID] = None,
|
||||
severity: Optional[str] = None,
|
||||
tag: Optional[str] = None,
|
||||
technique_id: Optional[str] = None,
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return ll_svc.list_lessons_learned(
|
||||
db,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
severity=severity,
|
||||
tag=tag,
|
||||
technique_id=technique_id,
|
||||
include_inactive=include_inactive,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/lessons", response_model=LessonLearnedOut, status_code=201)
|
||||
def create_lesson(
|
||||
body: LessonLearnedCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return ll_svc.create_lesson_learned(db, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/lessons/{lesson_id}", response_model=LessonLearnedOut)
|
||||
def get_lesson(
|
||||
lesson_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return ll_svc.get_lesson_learned(db, lesson_id)
|
||||
|
||||
|
||||
@router.patch("/lessons/{lesson_id}", response_model=LessonLearnedOut)
|
||||
def update_lesson(
|
||||
lesson_id: UUID,
|
||||
body: LessonLearnedUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return ll_svc.update_lesson_learned(
|
||||
db, lesson_id, body.model_dump(exclude_unset=True), user.id
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/lessons/{lesson_id}", status_code=204)
|
||||
def delete_lesson(
|
||||
lesson_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Soft-delete a lesson (admin / lead only)."""
|
||||
ll_svc.delete_lesson_learned(db, lesson_id, user.id)
|
||||
|
||||
|
||||
# ── Stats ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def knowledge_stats(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Summary counts: total playbooks, lessons by severity, playbooks by type."""
|
||||
return ll_svc.get_knowledge_stats(db)
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Phase 13: Operational Alerts router."""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
from app.schemas.operational_alert_schema import (
|
||||
AlertRuleCreate, AlertRuleOut, AlertRuleUpdate,
|
||||
AlertInstanceOut, EvaluationResult, AlertSummary,
|
||||
)
|
||||
import app.services.operational_alert_service as svc
|
||||
|
||||
router = APIRouter(prefix="/alerts", tags=["Operational Alerts"])
|
||||
|
||||
|
||||
# ── Evaluation ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/evaluate", response_model=EvaluationResult, status_code=202)
|
||||
def evaluate_rules(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""
|
||||
Run the alert evaluation engine against all enabled rules.
|
||||
|
||||
Fires AlertInstances for rules whose conditions are met and are not in cooldown.
|
||||
Admin / leads only.
|
||||
"""
|
||||
result = svc.evaluate_all_rules(db)
|
||||
return EvaluationResult(
|
||||
rules_evaluated = result["rules_evaluated"],
|
||||
alerts_fired = result["alerts_fired"],
|
||||
alerts = [AlertInstanceOut.model_validate(a) for a in result["alerts"]],
|
||||
duration_seconds = result["duration_seconds"],
|
||||
)
|
||||
|
||||
|
||||
# ── Alert instances ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("", response_model=List[AlertInstanceOut])
|
||||
def list_alerts(
|
||||
status: Optional[str] = Query(None),
|
||||
severity: Optional[str] = Query(None),
|
||||
rule_type: Optional[str] = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""List alert instances with optional filters."""
|
||||
return svc.list_instances(db, status=status, severity=severity,
|
||||
rule_type=rule_type, limit=limit, offset=offset)
|
||||
|
||||
|
||||
@router.get("/summary", response_model=AlertSummary)
|
||||
def alert_summary(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Aggregate counts by status, severity, and rule type."""
|
||||
data = svc.get_summary(db)
|
||||
return AlertSummary(
|
||||
total_open = data["total_open"],
|
||||
total_acknowledged = data["total_acknowledged"],
|
||||
total_resolved = data["total_resolved"],
|
||||
by_severity = data["by_severity"],
|
||||
by_rule_type = data["by_rule_type"],
|
||||
recent_alerts = [AlertInstanceOut.model_validate(a) for a in data["recent_alerts"]],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{alert_id}", response_model=AlertInstanceOut)
|
||||
def get_alert(
|
||||
alert_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Get a single alert instance."""
|
||||
return svc.get_instance(db, alert_id)
|
||||
|
||||
|
||||
@router.post("/{alert_id}/acknowledge", response_model=AlertInstanceOut)
|
||||
def acknowledge_alert(
|
||||
alert_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Acknowledge an open alert (admin / lead roles only)."""
|
||||
return svc.acknowledge(db, alert_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/{alert_id}/resolve", response_model=AlertInstanceOut)
|
||||
def resolve_alert(
|
||||
alert_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Mark an alert as resolved (admin / lead roles only)."""
|
||||
return svc.resolve(db, alert_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/{alert_id}/dismiss", response_model=AlertInstanceOut)
|
||||
def dismiss_alert(
|
||||
alert_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Dismiss an alert (admin / lead roles only — won't re-fire until cooldown resets)."""
|
||||
return svc.dismiss(db, alert_id, current_user.id)
|
||||
|
||||
|
||||
# ── Alert rules ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/rules/list", response_model=List[AlertRuleOut])
|
||||
def list_rules(
|
||||
rule_type: Optional[str] = Query(None),
|
||||
include_disabled: bool = Query(False),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""List alert rules (all users can read; admins/leads manage them)."""
|
||||
return svc.list_rules(db, rule_type=rule_type, include_disabled=include_disabled)
|
||||
|
||||
|
||||
@router.post("/rules", response_model=AlertRuleOut, status_code=201)
|
||||
def create_rule(
|
||||
body: AlertRuleCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Create a custom alert rule."""
|
||||
return svc.create_rule(
|
||||
db,
|
||||
created_by = current_user.id,
|
||||
name = body.name,
|
||||
description = body.description,
|
||||
rule_type = body.rule_type,
|
||||
severity = body.severity,
|
||||
config = body.config,
|
||||
notify_in_app = body.notify_in_app,
|
||||
notify_webhook = body.notify_webhook,
|
||||
webhook_id = body.webhook_id,
|
||||
cooldown_hours = body.cooldown_hours,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/rules/{rule_id}", response_model=AlertRuleOut)
|
||||
def get_rule(
|
||||
rule_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Get a single alert rule."""
|
||||
return svc.get_rule(db, rule_id)
|
||||
|
||||
|
||||
@router.patch("/rules/{rule_id}", response_model=AlertRuleOut)
|
||||
def update_rule(
|
||||
rule_id: UUID,
|
||||
body: AlertRuleUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Update an alert rule (enable/disable, thresholds, cooldown)."""
|
||||
return svc.update_rule(
|
||||
db, rule_id,
|
||||
name = body.name,
|
||||
description = body.description,
|
||||
severity = body.severity,
|
||||
is_enabled = body.is_enabled,
|
||||
config = body.config,
|
||||
notify_in_app = body.notify_in_app,
|
||||
notify_webhook = body.notify_webhook,
|
||||
webhook_id = body.webhook_id,
|
||||
cooldown_hours = body.cooldown_hours,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/rules/{rule_id}", status_code=204)
|
||||
def delete_rule(
|
||||
rule_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Delete a custom alert rule (system rules cannot be deleted)."""
|
||||
svc.delete_rule(db, rule_id)
|
||||
@@ -0,0 +1,216 @@
|
||||
"""Phase 9: Ownership & Daily Operations router."""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.schemas.ownership_queue_schema import (
|
||||
TechniqueOwnershipSet, TechniqueOwnershipOut,
|
||||
DetectionAssetOwnershipPatch,
|
||||
BulkAssignRequest, BulkAssignResult,
|
||||
QueueItemCreate, QueueItemPatch, QueueItemOut,
|
||||
AnalystDashboard,
|
||||
)
|
||||
from app.services import ownership_service, revalidation_queue_service
|
||||
from app.models.ownership_queue import RevalidationQueueItem
|
||||
|
||||
router = APIRouter(prefix="/ownership", tags=["ownership"])
|
||||
|
||||
|
||||
# ── Technique Ownership ───────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/techniques/{technique_id}", response_model=TechniqueOwnershipOut)
|
||||
def get_technique_ownership(
|
||||
technique_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
ownership = ownership_service.get_technique_ownership(db, technique_id)
|
||||
if not ownership:
|
||||
raise EntityNotFoundError("TechniqueOwnership", str(technique_id))
|
||||
return ownership
|
||||
|
||||
|
||||
@router.put("/techniques/{technique_id}", response_model=TechniqueOwnershipOut)
|
||||
def set_technique_ownership(
|
||||
technique_id: UUID,
|
||||
body: TechniqueOwnershipSet,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead", "red_lead")),
|
||||
):
|
||||
return ownership_service.set_technique_ownership(
|
||||
db, technique_id,
|
||||
owner_id=body.owner_id,
|
||||
backup_owner_id=body.backup_owner_id,
|
||||
team=body.team,
|
||||
notes=body.notes,
|
||||
assigned_by=user.id,
|
||||
)
|
||||
|
||||
|
||||
# ── Detection Asset Ownership ─────────────────────────────────────────────────
|
||||
|
||||
@router.patch("/assets/{asset_id}", response_model=dict)
|
||||
def set_asset_ownership(
|
||||
asset_id: UUID,
|
||||
body: DetectionAssetOwnershipPatch,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead")),
|
||||
):
|
||||
ownership_service.set_asset_ownership(
|
||||
db, asset_id,
|
||||
owner_id=body.owner_id,
|
||||
backup_owner_id=body.backup_owner_id,
|
||||
team=body.team,
|
||||
user_id=user.id,
|
||||
)
|
||||
return {"message": "Asset ownership updated"}
|
||||
|
||||
|
||||
# ── Orphan Reports ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/orphans/techniques", response_model=list[dict])
|
||||
def orphan_techniques(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Return techniques with no assigned owner."""
|
||||
return ownership_service.get_orphan_techniques(db)
|
||||
|
||||
|
||||
@router.get("/orphans/assets", response_model=list[dict])
|
||||
def orphan_assets(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Return active detection assets with no assigned owner."""
|
||||
return ownership_service.get_orphan_assets(db)
|
||||
|
||||
|
||||
# ── Bulk Assignment ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/bulk-assign", response_model=BulkAssignResult)
|
||||
def bulk_assign(
|
||||
body: BulkAssignRequest,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead", "red_lead")),
|
||||
):
|
||||
"""
|
||||
Bulk-assign ownership.
|
||||
- If `tactic` is set → assigns technique ownership for all techniques of that tactic.
|
||||
- If `platform` is set → assigns asset ownership for all assets on that platform.
|
||||
At least one of tactic/platform must be provided.
|
||||
"""
|
||||
if not body.tactic and not body.platform:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=422, detail="Provide at least one of: tactic, platform")
|
||||
|
||||
if body.tactic:
|
||||
result = ownership_service.bulk_assign_techniques_by_tactic(
|
||||
db, body.tactic,
|
||||
owner_id=body.owner_id,
|
||||
backup_owner_id=body.backup_owner_id,
|
||||
team=body.team,
|
||||
overwrite=body.overwrite,
|
||||
user_id=user.id,
|
||||
)
|
||||
else:
|
||||
result = ownership_service.bulk_assign_assets_by_platform(
|
||||
db, body.platform,
|
||||
owner_id=body.owner_id,
|
||||
backup_owner_id=body.backup_owner_id,
|
||||
team=body.team,
|
||||
overwrite=body.overwrite,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
return BulkAssignResult(**result)
|
||||
|
||||
|
||||
# ── Revalidation Queue ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/queue", response_model=list[QueueItemOut])
|
||||
def list_queue(
|
||||
status: Optional[str] = Query(None),
|
||||
priority: Optional[str] = Query(None),
|
||||
reason: Optional[str] = Query(None),
|
||||
assigned_to: Optional[UUID] = Query(None),
|
||||
technique_id: Optional[UUID] = Query(None),
|
||||
detection_asset_id: Optional[UUID] = Query(None),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return revalidation_queue_service.list_queue(
|
||||
db, status=status, priority=priority, reason=reason,
|
||||
assigned_to=assigned_to, technique_id=technique_id,
|
||||
detection_asset_id=detection_asset_id, limit=limit, offset=offset,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/queue", response_model=QueueItemOut, status_code=201)
|
||||
def create_queue_item(
|
||||
body: QueueItemCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return revalidation_queue_service.create_queue_item(db, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.patch("/queue/{item_id}", response_model=QueueItemOut)
|
||||
def update_queue_item(
|
||||
item_id: UUID,
|
||||
body: QueueItemPatch,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return revalidation_queue_service.update_queue_item(db, item_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.post("/queue/generate", response_model=dict)
|
||||
def generate_queue(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead")),
|
||||
):
|
||||
"""Scan the system and create new revalidation queue items."""
|
||||
return revalidation_queue_service.generate_queue_items(db)
|
||||
|
||||
|
||||
# ── Analyst Dashboard ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/analyst-dashboard")
|
||||
def analyst_dashboard(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Personalised daily workday view: my queue, expiring validations, infra changes, low-confidence techniques."""
|
||||
dashboard = revalidation_queue_service.get_analyst_dashboard(db, user.id)
|
||||
|
||||
# Serialize queue items to dicts (ORM objects → plain dicts)
|
||||
def _item_to_dict(item: RevalidationQueueItem) -> dict:
|
||||
return {
|
||||
"id": str(item.id),
|
||||
"technique_id": str(item.technique_id) if item.technique_id else None,
|
||||
"detection_asset_id": str(item.detection_asset_id) if item.detection_asset_id else None,
|
||||
"priority": item.priority.value if hasattr(item.priority, "value") else item.priority,
|
||||
"reason": item.reason.value if hasattr(item.reason, "value") else item.reason,
|
||||
"reason_detail": item.reason_detail,
|
||||
"status": item.status.value if hasattr(item.status, "value") else item.status,
|
||||
"assigned_to": str(item.assigned_to) if item.assigned_to else None,
|
||||
"due_date": item.due_date.isoformat() if item.due_date else None,
|
||||
"created_at": item.created_at.isoformat() if item.created_at else None,
|
||||
}
|
||||
|
||||
return {
|
||||
"my_pending_items": [_item_to_dict(i) for i in dashboard["my_pending_items"]],
|
||||
"expiring_validations_7d": dashboard["expiring_validations_7d"],
|
||||
"recent_infra_changes": dashboard["recent_infra_changes"],
|
||||
"my_low_confidence_techniques": dashboard["my_low_confidence_techniques"],
|
||||
"summary": dashboard["summary"],
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
"""Phase 12: Risk Intelligence router."""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.schemas.risk_schema import (
|
||||
TechniqueRiskProfileOut,
|
||||
RiskSummary,
|
||||
ComputeResult,
|
||||
)
|
||||
from app.services import risk_intelligence_service as svc
|
||||
|
||||
router = APIRouter(prefix="/risk", tags=["risk-intelligence"])
|
||||
|
||||
|
||||
# ── Compute ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/compute", response_model=ComputeResult, status_code=202)
|
||||
def compute_all(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Recompute risk scores for ALL techniques (admin / leads only)."""
|
||||
result = svc.compute_all_risk_scores(db)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/profiles/{technique_id}/compute", response_model=TechniqueRiskProfileOut)
|
||||
def compute_one(
|
||||
technique_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Compute (or refresh) the risk profile for a single technique."""
|
||||
return svc.compute_technique_risk(db, technique_id)
|
||||
|
||||
|
||||
# ── Read ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/profiles", response_model=List[TechniqueRiskProfileOut])
|
||||
def list_profiles(
|
||||
risk_level: Optional[str] = None,
|
||||
min_score: Optional[float] = None,
|
||||
max_score: Optional[float] = None,
|
||||
stale_only: bool = False,
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""List risk profiles with optional filters."""
|
||||
return svc.list_risk_profiles(
|
||||
db,
|
||||
risk_level=risk_level,
|
||||
min_score=min_score,
|
||||
max_score=max_score,
|
||||
stale_only=stale_only,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/profiles/{technique_id}", response_model=TechniqueRiskProfileOut)
|
||||
def get_profile(
|
||||
technique_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Get the current risk profile for a technique."""
|
||||
return svc.get_risk_profile(db, technique_id)
|
||||
|
||||
|
||||
@router.get("/matrix")
|
||||
def risk_matrix(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""All profiled techniques with likelihood/impact coordinates for matrix view."""
|
||||
return svc.get_risk_matrix(db)
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
def risk_summary(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Aggregate risk statistics: counts by level, average score, top risks."""
|
||||
return svc.get_risk_summary(db)
|
||||
|
||||
|
||||
@router.get("/recommendations")
|
||||
def recommendations(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Prioritised list of techniques with actionable recommendations."""
|
||||
return svc.get_recommendations(db, limit=limit)
|
||||
|
||||
|
||||
@router.get("/top")
|
||||
def top_risks(
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Top N highest-risk techniques (sorted by risk score desc)."""
|
||||
profiles = svc.list_risk_profiles(db, limit=limit)
|
||||
return profiles
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Phase 14: SSO / SAML 2.0 router."""
|
||||
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app import auth as auth_lib
|
||||
from app.schemas.sso_schema import (
|
||||
SsoConfigCreate, SsoConfigOut, SsoLoginInitResponse, SsoStatusResponse,
|
||||
)
|
||||
import app.services.sso_service as svc
|
||||
|
||||
router = APIRouter(prefix="/sso", tags=["SSO"])
|
||||
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
# Mirror the same SECURE_COOKIES logic used in the auth router so that
|
||||
# SAML-authenticated sessions respect the deployment's HTTPS configuration.
|
||||
_aegis_env = os.environ.get("AEGIS_ENV", "development").lower()
|
||||
_secure_cookie_env = os.environ.get("SECURE_COOKIES", "auto").lower()
|
||||
if _secure_cookie_env == "false":
|
||||
_IS_HTTPS = False
|
||||
elif _secure_cookie_env == "true":
|
||||
_IS_HTTPS = True
|
||||
else: # "auto" — active only when AEGIS_ENV=production
|
||||
_IS_HTTPS = _aegis_env == "production"
|
||||
|
||||
_COOKIE_OPTS = {"httponly": True, "samesite": "lax", "secure": _IS_HTTPS}
|
||||
|
||||
|
||||
# ── Public ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/status", response_model=SsoStatusResponse)
|
||||
def sso_status(db: Session = Depends(get_db)):
|
||||
"""Return whether SSO is enabled and configured (public — for login page)."""
|
||||
return svc.get_status(db)
|
||||
|
||||
|
||||
@router.get("/metadata", response_model=None)
|
||||
def sp_metadata(db: Session = Depends(get_db)):
|
||||
"""
|
||||
Return the Service Provider SAML metadata XML.
|
||||
|
||||
Upload this XML to your IdP (Okta, Azure AD, etc.) to register Aegis.
|
||||
"""
|
||||
try:
|
||||
xml = svc.get_sp_metadata(db)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc))
|
||||
return Response(content=xml, media_type="application/xml")
|
||||
|
||||
|
||||
@router.get("/login")
|
||||
def sso_login(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Initiate SAML login — redirects the browser to the IdP.
|
||||
|
||||
The IdP will POST the SAML Response to ``/sso/callback`` after authentication.
|
||||
"""
|
||||
request_data = {
|
||||
"https": request.url.scheme == "https",
|
||||
"http_host": request.url.hostname,
|
||||
"path": request.url.path,
|
||||
"port": str(request.url.port or (443 if request.url.scheme == "https" else 80)),
|
||||
"get_data": dict(request.query_params),
|
||||
"post_data": {},
|
||||
"query_string": str(request.url.query),
|
||||
}
|
||||
try:
|
||||
result = svc.initiate_login(db, request_data)
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc))
|
||||
return RedirectResponse(url=result["redirect_url"])
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def sso_callback(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
SAML Assertion Consumer Service (ACS) endpoint.
|
||||
|
||||
The IdP POSTs the SAML Response here. On success, sets the aegis_token
|
||||
cookie and redirects to the frontend.
|
||||
"""
|
||||
form = await request.form()
|
||||
request_data = {
|
||||
"https": request.url.scheme == "https",
|
||||
"http_host": request.url.hostname,
|
||||
"path": request.url.path,
|
||||
"port": str(request.url.port or (443 if request.url.scheme == "https" else 80)),
|
||||
"get_data": dict(request.query_params),
|
||||
"post_data": dict(form),
|
||||
"query_string": str(request.url.query),
|
||||
}
|
||||
try:
|
||||
user = svc.process_callback(db, request_data)
|
||||
except (ValueError, RuntimeError) as exc:
|
||||
raise HTTPException(status_code=401, detail=str(exc))
|
||||
|
||||
access_token = auth_lib.create_access_token({"sub": user.username})
|
||||
response = RedirectResponse(url="/", status_code=302)
|
||||
response.set_cookie(_COOKIE_NAME, access_token, **_COOKIE_OPTS)
|
||||
return response
|
||||
|
||||
|
||||
# ── Admin configuration ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/config", response_model=SsoConfigOut)
|
||||
def get_sso_config(
|
||||
db: Session = Depends(get_db),
|
||||
_user=Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Return the current SSO configuration (admin only)."""
|
||||
cfg = svc.get_config(db)
|
||||
if not cfg:
|
||||
raise HTTPException(status_code=404, detail="SSO not configured yet")
|
||||
return SsoConfigOut.model_validate(cfg)
|
||||
|
||||
|
||||
@router.put("/config", response_model=SsoConfigOut)
|
||||
def upsert_sso_config(
|
||||
body: SsoConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
_user=Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Create or replace the SSO configuration (admin only)."""
|
||||
cfg = svc.upsert_config(db, **body.model_dump(exclude_unset=False))
|
||||
return SsoConfigOut.model_validate(cfg)
|
||||
+695
-11
@@ -3,15 +3,20 @@
|
||||
Provides manual triggers for background operations such as the MITRE
|
||||
ATT&CK synchronisation, intel scanning, Atomic Red Team import, and
|
||||
scheduler health introspection.
|
||||
|
||||
Also exposes email configuration CRUD (admin only) that writes to the
|
||||
system_configs table so settings survive container restarts.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.database import SessionLocal, get_db
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
from app.models.user import User
|
||||
from app.services.mitre_sync_service import sync_mitre
|
||||
from app.services.intel_service import scan_intel
|
||||
@@ -24,25 +29,101 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic schemas for email config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EmailConfigOut(BaseModel):
|
||||
enabled: bool
|
||||
host: str
|
||||
port: int
|
||||
username: str
|
||||
from_email: str
|
||||
use_tls: bool
|
||||
# password is never returned
|
||||
|
||||
|
||||
class EmailConfigUpdate(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
from_email: Optional[str] = None
|
||||
use_tls: Optional[bool] = None
|
||||
|
||||
|
||||
class EmailTestRequest(BaseModel):
|
||||
to: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for system_configs CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SMTP_KEYS = {
|
||||
"enabled": "smtp.enabled",
|
||||
"host": "smtp.host",
|
||||
"port": "smtp.port",
|
||||
"username": "smtp.username",
|
||||
"password": "smtp.password",
|
||||
"from_email": "smtp.from_email",
|
||||
"use_tls": "smtp.use_tls",
|
||||
}
|
||||
|
||||
|
||||
def _upsert_config(db: Session, key: str, value: str) -> None:
|
||||
from app.models.system_config import SystemConfig # lazy import avoids circular
|
||||
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if row:
|
||||
row.value = value
|
||||
else:
|
||||
row = SystemConfig(key=key, value=value)
|
||||
db.add(row)
|
||||
|
||||
|
||||
def _read_email_config_from_db(db: Session) -> dict:
|
||||
"""Return a dict with resolved email settings (DB overrides env)."""
|
||||
from app.services.email_service import _get_smtp_config
|
||||
|
||||
return _get_smtp_config(db)
|
||||
|
||||
|
||||
def _bg_mitre_sync() -> None:
|
||||
"""Run MITRE sync in a background task with its own DB session."""
|
||||
logger.info("Background MITRE sync task starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
summary = sync_mitre(db)
|
||||
logger.info("Background MITRE sync task finished — %s", summary)
|
||||
except Exception:
|
||||
logger.exception("Background MITRE sync task failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/sync-mitre")
|
||||
@limiter.limit("2/hour")
|
||||
def trigger_mitre_sync(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Manually trigger a MITRE ATT&CK synchronisation.
|
||||
"""Manually trigger a MITRE ATT&CK synchronisation in the background.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
|
||||
Returns a JSON object with the sync summary including the count of
|
||||
new and updated techniques.
|
||||
Returns immediately — the sync runs asynchronously. Poll
|
||||
``/system/scheduler-status`` for progress, or check server logs.
|
||||
"""
|
||||
summary = sync_mitre(db)
|
||||
background_tasks.add_task(_bg_mitre_sync)
|
||||
return {
|
||||
"message": "MITRE sync completed",
|
||||
"new": summary["created"],
|
||||
"updated": summary["updated"],
|
||||
"message": "MITRE sync started in background",
|
||||
"status": "started",
|
||||
"new": 0,
|
||||
"updated": 0,
|
||||
}
|
||||
|
||||
|
||||
@@ -118,3 +199,606 @@ def scheduler_status(
|
||||
for job in jobs
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Jira config endpoints (admin only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JiraConfigOut(BaseModel):
|
||||
enabled: bool
|
||||
url: str
|
||||
project_key: str
|
||||
parent_ticket: str
|
||||
parent_ticket_standalone: str # parent for tests not in a campaign
|
||||
# Credentials are never returned
|
||||
|
||||
|
||||
class JiraConfigUpdate(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
url: Optional[str] = None
|
||||
project_key: Optional[str] = None
|
||||
parent_ticket: Optional[str] = None
|
||||
parent_ticket_standalone: Optional[str] = None
|
||||
|
||||
|
||||
_JIRA_KEYS = {
|
||||
"enabled": "jira.enabled",
|
||||
"url": "jira.url",
|
||||
"project_key": "jira.project_key",
|
||||
"parent_ticket": "jira.parent_ticket",
|
||||
"parent_ticket_standalone": "jira.parent_ticket_standalone",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/jira-config", response_model=JiraConfigOut)
|
||||
def get_jira_config(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return current Jira configuration (merged DB + env).
|
||||
|
||||
**Requires** the ``admin`` role. Credentials are never returned.
|
||||
"""
|
||||
from app.services.jira_service import (
|
||||
get_jira_url, get_jira_project_key, is_jira_enabled,
|
||||
get_jira_parent_ticket, get_jira_parent_ticket_standalone,
|
||||
)
|
||||
|
||||
return JiraConfigOut(
|
||||
enabled=is_jira_enabled(db),
|
||||
url=get_jira_url(db) or "",
|
||||
project_key=get_jira_project_key(db) or "",
|
||||
parent_ticket=get_jira_parent_ticket(db) or "",
|
||||
parent_ticket_standalone=get_jira_parent_ticket_standalone(db) or "",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/jira-config", response_model=JiraConfigOut)
|
||||
def update_jira_config(
|
||||
payload: JiraConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update Jira configuration and persist to DB.
|
||||
|
||||
**Requires** the ``admin`` role. Only provided fields are updated.
|
||||
"""
|
||||
from app.services.jira_service import (
|
||||
upsert_jira_config, get_jira_url, get_jira_project_key, is_jira_enabled,
|
||||
get_jira_parent_ticket, get_jira_parent_ticket_standalone,
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, val in update_data.items():
|
||||
db_key = _JIRA_KEYS.get(field)
|
||||
if db_key:
|
||||
upsert_jira_config(db, db_key, str(val))
|
||||
db.commit()
|
||||
|
||||
return JiraConfigOut(
|
||||
enabled=is_jira_enabled(db),
|
||||
url=get_jira_url(db) or "",
|
||||
project_key=get_jira_project_key(db) or "",
|
||||
parent_ticket=get_jira_parent_ticket(db) or "",
|
||||
parent_ticket_standalone=get_jira_parent_ticket_standalone(db) or "",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/jira-test")
|
||||
def test_jira_connection(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Test the Jira connection using the current user's credentials.
|
||||
|
||||
Requires the admin to have a personal Jira API token configured in their
|
||||
profile settings.
|
||||
|
||||
Always returns HTTP 200 with a ``status`` field so Cloudflare never
|
||||
replaces the response with its own error page.
|
||||
"""
|
||||
from app.services.jira_service import get_user_jira_client, get_jira_url, _effective_jira_email
|
||||
|
||||
jira_url = get_jira_url(db)
|
||||
if not jira_url:
|
||||
return {"status": "error", "message": "Jira URL is not configured. Set it in System Settings → Jira Configuration.", "jira_url": ""}
|
||||
|
||||
auth_email = _effective_jira_email(current_user)
|
||||
|
||||
try:
|
||||
jira = get_user_jira_client(current_user, db)
|
||||
# 10-second timeout so we never block Cloudflare into a 524
|
||||
try:
|
||||
jira._session.timeout = 10 # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
myself = jira.myself()
|
||||
logger.info("Jira myself() response keys: %s", list(myself.keys()) if isinstance(myself, dict) else type(myself))
|
||||
# Use displayName → emailAddress → name → the auth email as fallback
|
||||
connected_as = (
|
||||
(myself.get("displayName") if isinstance(myself, dict) else None)
|
||||
or (myself.get("emailAddress") if isinstance(myself, dict) else None)
|
||||
or (myself.get("name") if isinstance(myself, dict) else None)
|
||||
or auth_email
|
||||
or "authenticated"
|
||||
)
|
||||
return {
|
||||
"status": "ok",
|
||||
"connected_as": connected_as,
|
||||
"jira_url": jira_url,
|
||||
}
|
||||
except Exception as exc:
|
||||
err = str(exc)
|
||||
# Always return HTTP 200 with status="error" so Cloudflare never
|
||||
# intercepts the response and the frontend always sees our message.
|
||||
if "Expecting value" in err or "line 1 column 1" in err:
|
||||
msg = (
|
||||
"Jira returned a non-JSON response. "
|
||||
"Verify the URL (e.g. https://company.atlassian.net), "
|
||||
"email and API token."
|
||||
)
|
||||
elif "401" in err or "Unauthorized" in err:
|
||||
msg = (
|
||||
"Authentication failed (401). "
|
||||
f"Check that the Atlassian email ({auth_email or 'not set'}) "
|
||||
"and API token are correct. The token must be an Atlassian API token "
|
||||
"(not your account password)."
|
||||
)
|
||||
elif "403" in err or "Forbidden" in err:
|
||||
msg = "Access denied (403). The token may not have permission for this Jira project."
|
||||
elif "timed out" in err.lower() or "timeout" in err.lower():
|
||||
msg = "Connection timed out. Check that the Jira URL is reachable from the server."
|
||||
elif "not configured" in err.lower():
|
||||
msg = err
|
||||
else:
|
||||
msg = f"Jira connection failed: {err}"
|
||||
logger.warning("Jira test connection failed: %s", err)
|
||||
return {"status": "error", "message": msg, "jira_url": jira_url}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /system/tempo-test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/tempo-test")
|
||||
def test_tempo_connection(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Test the current user's personal Tempo connection.
|
||||
|
||||
Uses the Tempo API token stored in the user's profile (not a global token).
|
||||
Always returns HTTP 200 with a ``status`` field so Cloudflare never
|
||||
intercepts the response.
|
||||
"""
|
||||
from app.services.tempo_service import has_tempo_configured
|
||||
|
||||
tempo_token = getattr(current_user, "tempo_api_token", None)
|
||||
if not tempo_token:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"No Tempo API token configured. "
|
||||
"Add it in Settings → Profile → Tempo Integration."
|
||||
),
|
||||
}
|
||||
|
||||
jira_account_id = getattr(current_user, "jira_account_id", None)
|
||||
if not jira_account_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"No Jira Account ID configured. "
|
||||
"Set it in Settings → Profile → Jira Integration → Account ID."
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
from tempoapiclient import client_v4 as tempo_client
|
||||
tempo = tempo_client.Tempo(auth_token=tempo_token)
|
||||
# search_worklogs by authorId is the correct v4 method; use a tight
|
||||
# date range so we fetch almost nothing but still verify connectivity.
|
||||
worklogs = tempo.search_worklogs(
|
||||
dateFrom="2024-01-01",
|
||||
dateTo="2024-01-02",
|
||||
authorIds=[jira_account_id],
|
||||
)
|
||||
count = len(worklogs) if isinstance(worklogs, list) else "n/a"
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": f"Tempo connected successfully. Account ID: {jira_account_id}",
|
||||
"worklogs_found": count,
|
||||
}
|
||||
except Exception as exc:
|
||||
err = str(exc)
|
||||
if "401" in err or "Unauthorized" in err:
|
||||
msg = (
|
||||
f"Authentication failed (401). "
|
||||
f"Check your Tempo API token — obtain it at "
|
||||
f"Jira → Apps → Tempo → Settings → API Integration."
|
||||
)
|
||||
elif "403" in err or "Forbidden" in err:
|
||||
msg = "Access denied (403). The Tempo token lacks the required permissions."
|
||||
elif "404" in err or "not found" in err.lower():
|
||||
msg = (
|
||||
f"Account ID not found (404). "
|
||||
f"The value '{jira_account_id}' may be wrong — see the instructions "
|
||||
f"below to find your correct Atlassian Account ID."
|
||||
)
|
||||
else:
|
||||
msg = f"Tempo connection failed: {err}"
|
||||
logger.warning(
|
||||
"Tempo test connection failed for user %s (account_id=%s): %s",
|
||||
current_user.username, jira_account_id, err,
|
||||
)
|
||||
return {"status": "error", "message": msg}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /system/email-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/email-config", response_model=EmailConfigOut)
|
||||
def get_email_config(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return current SMTP email configuration (merged DB + env).
|
||||
|
||||
**Requires** the ``admin`` role. Password is never returned.
|
||||
"""
|
||||
cfg = _read_email_config_from_db(db)
|
||||
return EmailConfigOut(
|
||||
enabled=cfg["enabled"],
|
||||
host=cfg["host"],
|
||||
port=cfg["port"],
|
||||
username=cfg["username"],
|
||||
from_email=cfg["from_email"],
|
||||
use_tls=cfg["use_tls"],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /system/email-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/email-config", response_model=EmailConfigOut)
|
||||
def update_email_config(
|
||||
payload: EmailConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update SMTP email configuration and persist to DB.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
Only provided fields are updated (partial update).
|
||||
"""
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, val in update_data.items():
|
||||
db_key = _SMTP_KEYS.get(field)
|
||||
if db_key:
|
||||
_upsert_config(db, db_key, str(val))
|
||||
db.commit()
|
||||
|
||||
cfg = _read_email_config_from_db(db)
|
||||
return EmailConfigOut(
|
||||
enabled=cfg["enabled"],
|
||||
host=cfg["host"],
|
||||
port=cfg["port"],
|
||||
username=cfg["username"],
|
||||
from_email=cfg["from_email"],
|
||||
use_tls=cfg["use_tls"],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /system/email-test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ATT&CK Evaluations endpoints (admin only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/attck-evaluations/rounds")
|
||||
def list_evaluation_rounds(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return all public CrowdStrike ENTERPRISE evaluation rounds with import status.
|
||||
|
||||
Each entry includes whether it has already been imported into this platform.
|
||||
"""
|
||||
from app.services.attck_evaluations_service import fetch_rounds_with_status
|
||||
from app.models.evaluation_import import EvaluationImport
|
||||
|
||||
status_info = fetch_rounds_with_status()
|
||||
rounds = status_info["rounds"]
|
||||
|
||||
imported = {
|
||||
row.adversary_name.lower(): row
|
||||
for row in db.query(EvaluationImport).filter(EvaluationImport.status == "completed").all()
|
||||
}
|
||||
|
||||
round_list = [
|
||||
{
|
||||
"name": r["name"],
|
||||
"display_name": r.get("display_name", r["name"]),
|
||||
"eval_round": r["eval_round"],
|
||||
"imported": r["name"].lower() in imported,
|
||||
"imported_at": imported[r["name"].lower()].imported_at.isoformat()
|
||||
if r["name"].lower() in imported else None,
|
||||
"tests_created": imported[r["name"].lower()].tests_created
|
||||
if r["name"].lower() in imported else None,
|
||||
"techniques_covered": imported[r["name"].lower()].techniques_covered
|
||||
if r["name"].lower() in imported else None,
|
||||
}
|
||||
for r in rounds
|
||||
]
|
||||
|
||||
return {
|
||||
"rounds": round_list,
|
||||
"api_reachable": status_info["api_reachable"],
|
||||
"api_error": status_info.get("api_error"),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/attck-evaluations/import")
|
||||
def import_evaluation_round(
|
||||
payload: dict,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import a specific ATT&CK Evaluation round for CrowdStrike.
|
||||
|
||||
Body: { "adversary_name": "apt29", "adversary_display": "APT29", "eval_round": 2 }
|
||||
|
||||
Creates tests in ``in_review`` state — Blue Leads must validate each
|
||||
result before it counts as real coverage.
|
||||
"""
|
||||
from app.services.attck_evaluations_service import import_evaluation_round as _import
|
||||
|
||||
adversary_name = payload.get("adversary_name", "")
|
||||
adversary_display = payload.get("adversary_display", adversary_name)
|
||||
eval_round = payload.get("eval_round", 0)
|
||||
|
||||
if not adversary_name or not eval_round:
|
||||
raise HTTPException(status_code=400, detail="adversary_name and eval_round are required")
|
||||
|
||||
try:
|
||||
summary = _import(db, adversary_name, adversary_display, eval_round, current_user)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc))
|
||||
except Exception as exc:
|
||||
logger.error("ATT&CK Evaluation import failed: %s", exc, exc_info=True)
|
||||
raise HTTPException(status_code=502, detail=f"Import failed: {exc}")
|
||||
|
||||
return {
|
||||
"message": f"Import complete — {summary['created']} tests created",
|
||||
**summary,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/attck-evaluations/import-latest")
|
||||
def import_latest_evaluation(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import the latest available CrowdStrike evaluation round automatically.
|
||||
|
||||
Returns 409 if the latest round was already imported.
|
||||
"""
|
||||
from app.services.attck_evaluations_service import get_latest_round, import_evaluation_round as _import
|
||||
|
||||
try:
|
||||
latest = get_latest_round()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Could not reach MITRE Evaluations API: {exc}")
|
||||
|
||||
try:
|
||||
summary = _import(
|
||||
db,
|
||||
latest["name"],
|
||||
latest.get("display_name", latest["name"]),
|
||||
latest["eval_round"],
|
||||
current_user,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc))
|
||||
except Exception as exc:
|
||||
logger.error("ATT&CK Evaluation import failed: %s", exc, exc_info=True)
|
||||
raise HTTPException(status_code=502, detail=f"Import failed: {exc}")
|
||||
|
||||
return {
|
||||
"message": f"Import complete — {summary['created']} tests created",
|
||||
**summary,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/attck-evaluations/check-new")
|
||||
def check_new_evaluation_round(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Check if a new ATT&CK Evaluation round is available that hasn't been imported yet."""
|
||||
from app.services.attck_evaluations_service import check_for_new_round
|
||||
|
||||
return check_for_new_round(db)
|
||||
|
||||
|
||||
@router.post("/attck-evaluations/bulk-approve")
|
||||
def bulk_approve_evaluation_tests(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Bulk-approve all Blue Team validation for ATT&CK Evaluation imported tests.
|
||||
|
||||
Finds every test in ``in_review`` state whose name starts with ``[EVAL R``
|
||||
and approves the Blue Team side. Because all evaluation imports pre-approve
|
||||
the Red Team side, this moves every matched test to ``validated`` state.
|
||||
|
||||
**Important caveats** (enforced by UI warnings before this is called):
|
||||
- Results come from a controlled MITRE lab, NOT the organisation's env.
|
||||
- Validated tests will immediately affect coverage metrics and dashboards.
|
||||
- Blue Leads should still spot-check high-priority techniques individually.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from app.models.test import Test
|
||||
from app.models.enums import TestState
|
||||
from app.models.technique import Technique
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Find all pending evaluation tests
|
||||
pending = (
|
||||
db.query(Test)
|
||||
.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.name.like("[EVAL R%"),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not pending:
|
||||
return {
|
||||
"approved": 0,
|
||||
"techniques_recalculated": 0,
|
||||
"message": "No pending evaluation tests found — nothing to approve.",
|
||||
}
|
||||
|
||||
now = datetime.utcnow()
|
||||
affected_technique_ids: set = set()
|
||||
|
||||
for test in pending:
|
||||
# Approve blue side
|
||||
test.blue_validation_status = "approved"
|
||||
test.blue_validated_by = current_user.id
|
||||
test.blue_validated_at = now
|
||||
test.blue_validation_notes = (
|
||||
"Bulk-approved via ATT&CK Evaluations admin panel. "
|
||||
"Source: MITRE lab environment — not organisational detection."
|
||||
)
|
||||
|
||||
# Red side was pre-approved during import → move to validated
|
||||
if test.red_validation_status == "approved":
|
||||
test.state = TestState.validated
|
||||
# else stays in_review (shouldn't happen for eval imports, but be safe)
|
||||
|
||||
if test.technique_id:
|
||||
affected_technique_ids.add(test.technique_id)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="bulk_eval_approve",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={"source": "attck_evaluation_bulk_approve"},
|
||||
)
|
||||
|
||||
db.flush()
|
||||
|
||||
# Recalculate coverage for every touched technique
|
||||
for tech_id in affected_technique_ids:
|
||||
tech = db.query(Technique).filter(Technique.id == tech_id).first()
|
||||
if tech:
|
||||
recalculate_technique_status(db, tech)
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
"Bulk eval approval: %d tests validated, %d techniques recalculated (by %s)",
|
||||
len(pending), len(affected_technique_ids), current_user.email,
|
||||
)
|
||||
|
||||
return {
|
||||
"approved": len(pending),
|
||||
"techniques_recalculated": len(affected_technique_ids),
|
||||
"message": (
|
||||
f"{len(pending)} evaluation tests approved and moved to Validated. "
|
||||
f"{len(affected_technique_ids)} technique statuses recalculated."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/attck-evaluations/pending-count")
|
||||
def get_pending_evaluation_count(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return the number of imported evaluation tests still awaiting Blue approval."""
|
||||
from app.models.test import Test
|
||||
from app.models.enums import TestState
|
||||
|
||||
count = (
|
||||
db.query(Test)
|
||||
.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.name.like("[EVAL R%"),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
return {"pending": count}
|
||||
|
||||
|
||||
@router.post("/attck-evaluations/re-enrich")
|
||||
def re_enrich_evaluation_round(
|
||||
payload: dict,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Re-enrich already-imported evaluation tests with rich data from the MITRE API.
|
||||
|
||||
Updates procedure_text (attack path + criteria), description (data sources +
|
||||
substep references) and red_summary — without changing detection results,
|
||||
state or validation status.
|
||||
|
||||
Body: { "adversary_name": "turla", "adversary_display": "Turla", "eval_round": 5 }
|
||||
|
||||
Useful to upgrade tests that were imported before the enrichment feature
|
||||
was added.
|
||||
"""
|
||||
from app.services.attck_evaluations_service import re_enrich_evaluation_round as _re_enrich
|
||||
|
||||
adversary_name = payload.get("adversary_name", "")
|
||||
adversary_display = payload.get("adversary_display", adversary_name)
|
||||
eval_round = payload.get("eval_round", 0)
|
||||
|
||||
if not adversary_name or not eval_round:
|
||||
raise HTTPException(status_code=400, detail="adversary_name and eval_round are required")
|
||||
|
||||
try:
|
||||
summary = _re_enrich(db, adversary_name, adversary_display, eval_round, current_user)
|
||||
except Exception as exc:
|
||||
logger.error("ATT&CK Evaluation re-enrich failed: %s", exc, exc_info=True)
|
||||
raise HTTPException(status_code=502, detail=f"Re-enrich failed: {exc}")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
@router.post("/email-test")
|
||||
def send_test_email(
|
||||
payload: EmailTestRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Send a test email to verify SMTP configuration.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
Returns 200 on success, 502 if sending fails.
|
||||
"""
|
||||
from app.services.email_service import send_test_email as _send_test
|
||||
|
||||
ok = _send_test(payload.to, db=db)
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Failed to send test email. Check SMTP configuration and server logs.",
|
||||
)
|
||||
return {"detail": f"Test email sent to {payload.to}"}
|
||||
|
||||
@@ -31,6 +31,7 @@ from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.technique import Technique
|
||||
from app.models.user import User
|
||||
from app.schemas.test_template import (
|
||||
TestTemplateCreate,
|
||||
@@ -178,6 +179,15 @@ def create_template(
|
||||
"""Create a custom test template."""
|
||||
template = create_template_svc(db, **payload.model_dump())
|
||||
with UnitOfWork(db) as uow:
|
||||
# Flag the associated technique for review — new template available
|
||||
if template.mitre_technique_id:
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == template.mitre_technique_id)
|
||||
.first()
|
||||
)
|
||||
if technique:
|
||||
technique.review_required = True
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
|
||||
@@ -11,6 +11,7 @@ PATCH /tests/{id}/red — Red Team updates (draft, red_executing)
|
||||
PATCH /tests/{id}/blue — Blue Team updates (blue_evaluating)
|
||||
POST /tests/{id}/start-execution — draft → red_executing
|
||||
POST /tests/{id}/submit-red — red_executing → blue_evaluating
|
||||
POST /tests/{id}/start-blue-work — blue tech picks up (sets Tempo timer)
|
||||
POST /tests/{id}/submit-blue — blue_evaluating → in_review
|
||||
POST /tests/{id}/validate-red — Red Lead validates
|
||||
POST /tests/{id}/validate-blue — Blue Lead validates
|
||||
@@ -18,17 +19,26 @@ POST /tests/{id}/reopen — rejected → draft
|
||||
GET /tests/{id}/timeline — audit-log history for this test
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||
from app.domain.enums import DataClassification
|
||||
from app.limiter import limiter
|
||||
from app.models.enums import TestState
|
||||
from app.models.enums import TestState, TestResult, TeamSide
|
||||
from app.models.evidence import Evidence
|
||||
from app.storage import upload_file
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
from app.schemas.test import (
|
||||
TestCreate,
|
||||
@@ -45,6 +55,7 @@ from app.schemas.test_template import TestTemplateInstantiate
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
from app.services.webhook_service import dispatch_webhook
|
||||
from app.services.test_crud_service import (
|
||||
create_test as crud_create_test,
|
||||
create_test_from_template as crud_create_from_template,
|
||||
@@ -61,6 +72,7 @@ from app.services.test_workflow_service import (
|
||||
start_execution as wf_start_execution,
|
||||
submit_red_evidence as wf_submit_red,
|
||||
submit_blue_evidence as wf_submit_blue,
|
||||
start_blue_work as wf_start_blue_work,
|
||||
validate_as_red_lead as wf_validate_red,
|
||||
validate_as_blue_lead as wf_validate_blue,
|
||||
reopen_test as wf_reopen,
|
||||
@@ -87,6 +99,9 @@ def list_tests(
|
||||
pending_validation_side: Optional[str] = Query(
|
||||
None, description="Filter in_review tests pending validation on 'red' or 'blue' side"
|
||||
),
|
||||
not_in_any_campaign: bool = Query(
|
||||
False, description="Only return tests not linked to any campaign"
|
||||
),
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -100,6 +115,7 @@ def list_tests(
|
||||
platform=platform,
|
||||
created_by=created_by,
|
||||
pending_validation_side=pending_validation_side,
|
||||
not_in_any_campaign=not_in_any_campaign,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
@@ -144,6 +160,14 @@ def create_test(
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
|
||||
# Auto-create Jira ticket (non-fatal — any failure is logged, not raised)
|
||||
try:
|
||||
from app.services.jira_service import auto_create_test_issue
|
||||
auto_create_test_issue(db, test, current_user)
|
||||
db.commit()
|
||||
except Exception:
|
||||
pass # jira_service already logs warnings internally
|
||||
|
||||
return test
|
||||
|
||||
|
||||
@@ -174,6 +198,11 @@ def create_test_from_template(
|
||||
template_id=payload.template_id,
|
||||
technique_id_or_mitre=payload.technique_id,
|
||||
creator_id=current_user.id,
|
||||
name_override=payload.name,
|
||||
description_override=payload.description,
|
||||
platform_override=payload.platform,
|
||||
procedure_text_override=payload.procedure_text,
|
||||
tool_used_override=payload.tool_used,
|
||||
)
|
||||
log_action(
|
||||
db,
|
||||
@@ -190,6 +219,14 @@ def create_test_from_template(
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
|
||||
# Auto-create Jira ticket (non-fatal)
|
||||
try:
|
||||
from app.services.jira_service import auto_create_test_issue
|
||||
auto_create_test_issue(db, test, current_user)
|
||||
db.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return test
|
||||
|
||||
|
||||
@@ -398,6 +435,26 @@ def submit_blue(
|
||||
return test
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{id}/start-blue-work — blue tech picks up test for evaluation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{test_id}/start-blue-work", response_model=TestOut)
|
||||
def start_blue_work(
|
||||
test_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||
):
|
||||
"""Blue tech picks up the test to start evaluating. Sets the Tempo timer start."""
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_start_blue_work(db, test, current_user)
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
return test
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{id}/pause-timer — pause the active phase timer
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -460,8 +517,15 @@ def validate_red(
|
||||
)
|
||||
if test.state in (TestState.validated, TestState.rejected):
|
||||
recalculate_technique_status(db, test.technique)
|
||||
# Flag technique for review — coverage changed
|
||||
if test.technique:
|
||||
test.technique.review_required = True
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
if test.state == TestState.validated:
|
||||
dispatch_webhook("test.validated", {"test_id": str(test.id), "technique_id": str(test.technique_id), "result": test.result.value if test.result else None})
|
||||
elif test.state == TestState.rejected:
|
||||
dispatch_webhook("test.rejected", {"test_id": str(test.id), "technique_id": str(test.technique_id)})
|
||||
return test
|
||||
|
||||
|
||||
@@ -487,8 +551,15 @@ def validate_blue(
|
||||
)
|
||||
if test.state in (TestState.validated, TestState.rejected):
|
||||
recalculate_technique_status(db, test.technique)
|
||||
# Flag technique for review — coverage changed
|
||||
if test.technique:
|
||||
test.technique.review_required = True
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
if test.state == TestState.validated:
|
||||
dispatch_webhook("test.validated", {"test_id": str(test.id), "technique_id": str(test.technique_id), "result": test.result.value if test.result else None})
|
||||
elif test.state == TestState.rejected:
|
||||
dispatch_webhook("test.rejected", {"test_id": str(test.id), "technique_id": str(test.technique_id)})
|
||||
return test
|
||||
|
||||
|
||||
@@ -602,3 +673,366 @@ def get_retest_chain(
|
||||
}
|
||||
for t in chain
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{id}/sync-tempo — manual Tempo sync for red execution worklog
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{test_id}/sync-tempo")
|
||||
def sync_tempo(
|
||||
test_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Manually sync this test's red team execution worklog(s) to Tempo.
|
||||
|
||||
Useful when the automatic sync failed at phase completion (e.g. Tempo
|
||||
was not yet configured). Only red_team_execution worklogs are eligible.
|
||||
Already-synced worklogs are skipped. Returns a summary of what happened.
|
||||
"""
|
||||
from datetime import datetime as _dt
|
||||
from app.models.worklog import Worklog
|
||||
from app.services.tempo_service import auto_log_test_worklog
|
||||
from app.services.test_crud_service import get_test_or_raise as _get
|
||||
|
||||
test = _get(db, test_id)
|
||||
|
||||
worklogs = (
|
||||
db.query(Worklog)
|
||||
.filter(
|
||||
Worklog.entity_type == "test",
|
||||
Worklog.entity_id == test_id,
|
||||
Worklog.activity_type == "red_team_execution",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not worklogs:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No red team execution worklog found for this test.",
|
||||
)
|
||||
|
||||
results = []
|
||||
for wl in worklogs:
|
||||
if wl.tempo_synced:
|
||||
results.append({"worklog_id": str(wl.id), "status": "already_synced"})
|
||||
continue
|
||||
|
||||
try:
|
||||
result = auto_log_test_worklog(
|
||||
db=db,
|
||||
test=test,
|
||||
user=current_user,
|
||||
activity_type=wl.activity_type,
|
||||
duration_seconds=wl.duration_seconds,
|
||||
)
|
||||
if result and isinstance(result, dict):
|
||||
wl.tempo_synced = _dt.utcnow()
|
||||
wl.tempo_worklog_id = str(result.get("tempoWorklogId", ""))
|
||||
db.commit()
|
||||
results.append({"worklog_id": str(wl.id), "status": "synced"})
|
||||
else:
|
||||
results.append({
|
||||
"worklog_id": str(wl.id),
|
||||
"status": "skipped",
|
||||
"detail": "Tempo not configured or conditions not met.",
|
||||
})
|
||||
except Exception as exc:
|
||||
results.append({
|
||||
"worklog_id": str(wl.id),
|
||||
"status": "error",
|
||||
"detail": str(exc),
|
||||
})
|
||||
|
||||
return {"results": results}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{id}/request-discussion — disputed: confirm vote + notify other lead
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{test_id}/request-discussion")
|
||||
def request_discussion(
|
||||
test_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
|
||||
):
|
||||
"""Called when the approving lead confirms their vote in a disputed test.
|
||||
|
||||
Sends a notification to the other lead (who rejected) asking them to
|
||||
discuss and resolve the conflict. The test remains in 'disputed' state.
|
||||
"""
|
||||
from app.models.enums import TestState as ModelTestState
|
||||
from app.models.user import User as UserModel
|
||||
from app.services.notification_service import create_notification
|
||||
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
|
||||
if test.state.value != "disputed":
|
||||
from app.domain.errors import BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Test is not in disputed state")
|
||||
|
||||
role = current_user.role
|
||||
|
||||
# Identify who the "other lead" is (the one who rejected)
|
||||
if (role in ("red_lead", "admin")) and test.red_validation_status == "approved":
|
||||
# Red approved, Blue rejected → notify Blue Lead who rejected
|
||||
rejector_id = test.blue_validated_by
|
||||
rejector_label = "Blue Lead"
|
||||
requester_label = "Red Lead"
|
||||
elif (role in ("blue_lead", "admin")) and test.blue_validation_status == "approved":
|
||||
# Blue approved, Red rejected → notify Red Lead who rejected
|
||||
rejector_id = test.red_validated_by
|
||||
rejector_label = "Red Lead"
|
||||
requester_label = "Blue Lead"
|
||||
else:
|
||||
from app.domain.errors import BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
"The conflict state is inconsistent — no approving lead found"
|
||||
)
|
||||
|
||||
# Look up the rejecting lead's full info for the response
|
||||
rejector = (
|
||||
db.query(UserModel).filter(UserModel.id == rejector_id).first()
|
||||
if rejector_id else None
|
||||
)
|
||||
rejector_name = rejector.username if rejector else rejector_label
|
||||
rejector_email = getattr(rejector, "email", None) if rejector else None
|
||||
|
||||
# Notify the rejecting lead
|
||||
if rejector_id:
|
||||
try:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=rejector_id,
|
||||
type="validation_conflict",
|
||||
title="Discussion requested on disputed test",
|
||||
message=(
|
||||
f"{requester_label} ({current_user.username}) is confirming their approval "
|
||||
f"of test '{test.name}' and wants to discuss your rejection with you. "
|
||||
f"Please reach out to resolve the disagreement."
|
||||
),
|
||||
entity_type="test",
|
||||
entity_id=str(test.id),
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to send discussion notification: %s", e
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="request_dispute_discussion",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={"test_name": test.name, "rejector": rejector_name},
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"status": "notification_sent",
|
||||
"message": f"Discussion request sent to {rejector_name}",
|
||||
"rejector_username": rejector_name,
|
||||
"rejector_email": rejector_email,
|
||||
"rejector_role": rejector_label,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/import-rt — bulk import from a real Red Team engagement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_ALLOWED_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"}
|
||||
_MAX_EVIDENCE_BYTES = 10 * 1024 * 1024 # 10 MB decoded per image
|
||||
|
||||
|
||||
class RTEvidenceEntry(BaseModel):
|
||||
filename: str # e.g. "screenshot_edr.png"
|
||||
data: str # base64-encoded image content
|
||||
caption: Optional[str] = None # optional description shown as evidence notes
|
||||
|
||||
|
||||
class RTTechniqueEntry(BaseModel):
|
||||
mitre_id: str
|
||||
result: str # "detected" | "not_detected" | "partially_detected"
|
||||
attack_success: bool = True
|
||||
platform: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
evidence: list[RTEvidenceEntry] # REQUIRED — at least one image per technique
|
||||
|
||||
|
||||
class RTImportPayload(BaseModel):
|
||||
name: str # engagement name, e.g. "Red Team Q1 2024"
|
||||
date: Optional[str] = None # ISO date string
|
||||
description: Optional[str] = None
|
||||
operator: Optional[str] = None # team / company that ran the RT
|
||||
techniques: list[RTTechniqueEntry]
|
||||
|
||||
|
||||
@router.post("/import-rt", status_code=status.HTTP_201_CREATED)
|
||||
def import_rt(
|
||||
payload: RTImportPayload,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead")),
|
||||
):
|
||||
"""Import results from a real Red Team engagement.
|
||||
|
||||
Creates one Test record per technique in ``validated`` state (bypassing
|
||||
the normal Red/Blue workflow) and immediately recalculates coverage metrics.
|
||||
Requires ``red_lead`` or ``admin`` role.
|
||||
"""
|
||||
# Pre-validate: every technique must include at least one evidence image
|
||||
for entry in payload.techniques:
|
||||
if not entry.evidence:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=(
|
||||
f"Technique {entry.mitre_id} is missing evidence. "
|
||||
"At least one screenshot or image is required per technique."
|
||||
),
|
||||
)
|
||||
|
||||
# Execution date from payload or now
|
||||
exec_date_str = payload.date or datetime.utcnow().date().isoformat()
|
||||
|
||||
# Result string → TestResult enum
|
||||
_result_map = {
|
||||
"detected": TestResult.detected,
|
||||
"not_detected": TestResult.not_detected,
|
||||
"partially_detected": TestResult.partially_detected,
|
||||
}
|
||||
|
||||
created: list[dict[str, Any]] = []
|
||||
skipped: list[dict[str, str]] = []
|
||||
affected_technique_ids: set = set()
|
||||
|
||||
with UnitOfWork(db) as uow:
|
||||
for entry in payload.techniques:
|
||||
# Find technique
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == entry.mitre_id.upper())
|
||||
.first()
|
||||
)
|
||||
if technique is None:
|
||||
skipped.append({"mitre_id": entry.mitre_id, "reason": "Technique not found"})
|
||||
continue
|
||||
|
||||
detection_result = _result_map.get(entry.result)
|
||||
if detection_result is None:
|
||||
skipped.append({"mitre_id": entry.mitre_id, "reason": f"Unknown result value '{entry.result}'"})
|
||||
continue
|
||||
|
||||
test_name = f"[RT] {payload.name} — {technique.name}"
|
||||
|
||||
# Build red_summary from notes + engagement metadata
|
||||
parts = []
|
||||
if payload.operator:
|
||||
parts.append(f"Operator: {payload.operator}")
|
||||
parts.append(f"Engagement date: {exec_date_str}")
|
||||
if entry.notes:
|
||||
parts.append(f"\n{entry.notes}")
|
||||
red_summary_text = "\n".join(parts)
|
||||
|
||||
# RT pre-validates the Red side (they ran it), but Blue Lead
|
||||
# must still validate the detection result before it counts.
|
||||
# State = in_review so it appears in the Blue Lead's validation queue.
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name=test_name,
|
||||
description=payload.description,
|
||||
platform=entry.platform,
|
||||
procedure_text=entry.notes,
|
||||
created_by=current_user.id,
|
||||
state=TestState.in_review,
|
||||
# Red team — approved by the RT operator
|
||||
attack_success=entry.attack_success,
|
||||
red_summary=red_summary_text,
|
||||
red_validation_status="approved",
|
||||
red_validated_by=current_user.id,
|
||||
red_validated_at=datetime.utcnow(),
|
||||
# Blue team — pre-fill the detection result but leave
|
||||
# validation_status pending so Blue Lead must confirm
|
||||
detection_result=detection_result,
|
||||
blue_validation_status=None,
|
||||
# Timing
|
||||
execution_date=exec_date_str,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
|
||||
# ── Store evidence images ──────────────────────────────
|
||||
evidence_count = 0
|
||||
for ev in entry.evidence:
|
||||
safe_name = os.path.basename(ev.filename) or "evidence.png"
|
||||
ext = os.path.splitext(safe_name)[1].lower()
|
||||
if ext not in _ALLOWED_IMAGE_EXTS:
|
||||
# Skip non-image files silently (log warning)
|
||||
continue
|
||||
try:
|
||||
img_bytes = base64.b64decode(ev.data)
|
||||
except Exception:
|
||||
continue # malformed base64 — skip
|
||||
if len(img_bytes) > _MAX_EVIDENCE_BYTES:
|
||||
continue # over size limit — skip
|
||||
sha256 = hashlib.sha256(img_bytes).hexdigest()
|
||||
key = f"{test.id}/{uuid.uuid4()}_{safe_name}"
|
||||
try:
|
||||
upload_file(img_bytes, key)
|
||||
except Exception:
|
||||
continue # storage error — skip but don't abort
|
||||
evidence_obj = Evidence(
|
||||
test_id=test.id,
|
||||
file_name=safe_name,
|
||||
file_path=key,
|
||||
sha256_hash=sha256,
|
||||
uploaded_by=current_user.id,
|
||||
uploaded_at=datetime.utcnow(),
|
||||
team=TeamSide.red,
|
||||
notes=ev.caption,
|
||||
)
|
||||
db.add(evidence_obj)
|
||||
evidence_count += 1
|
||||
|
||||
affected_technique_ids.add(technique.id)
|
||||
created.append({
|
||||
"mitre_id": entry.mitre_id,
|
||||
"test_name": test_name,
|
||||
"result": entry.result,
|
||||
"attack_success": entry.attack_success,
|
||||
"evidence_attached": evidence_count,
|
||||
})
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="rt_import_test",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={"engagement": payload.name, "mitre_id": entry.mitre_id},
|
||||
)
|
||||
|
||||
# Recalculate coverage for all affected techniques
|
||||
for tech_id in affected_technique_ids:
|
||||
tech = db.query(Technique).filter(Technique.id == tech_id).first()
|
||||
if tech:
|
||||
recalculate_technique_status(db, tech)
|
||||
|
||||
uow.commit()
|
||||
|
||||
return {
|
||||
"created": len(created),
|
||||
"skipped": len(skipped),
|
||||
"items": created,
|
||||
"warnings": skipped,
|
||||
"engagement": payload.name,
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,8 @@ from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserOut
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserOut, UserPreferencesUpdate
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.user_service import (
|
||||
create_user,
|
||||
@@ -21,6 +22,47 @@ from app.services.user_service import (
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /users/me/preferences — update current user preferences
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/me/preferences", response_model=UserOut)
|
||||
def update_my_preferences(
|
||||
payload: UserPreferencesUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Update the current user's notification preferences, Jira account ID and Jira API token.
|
||||
|
||||
Send ``jira_api_token: ""`` to clear a previously stored token.
|
||||
The token is never returned in any response.
|
||||
"""
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field in ("jira_api_token", "jira_email", "tempo_api_token"):
|
||||
# Empty string means "clear the value"
|
||||
setattr(current_user, field, value if value else None)
|
||||
else:
|
||||
setattr(current_user, field, value)
|
||||
db.commit()
|
||||
db.refresh(current_user)
|
||||
return current_user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users/me — get current user's own profile
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def get_me(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return the currently authenticated user's profile."""
|
||||
return current_user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users — list all users
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Webhook configuration CRUD router — admin only.
|
||||
|
||||
Endpoints
|
||||
---------
|
||||
GET /webhooks — list all webhook configs
|
||||
POST /webhooks — create a new webhook config
|
||||
GET /webhooks/{id} — get a single webhook config
|
||||
PATCH /webhooks/{id} — update a webhook config
|
||||
DELETE /webhooks/{id} — hard-delete a webhook config
|
||||
POST /webhooks/{id}/test — send a test ping
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_any_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.schemas.webhook import WebhookConfigCreate, WebhookConfigOut, WebhookConfigUpdate
|
||||
from app.services.webhook_service import (
|
||||
create_webhook,
|
||||
delete_webhook,
|
||||
dispatch_webhook,
|
||||
get_webhook_or_raise,
|
||||
list_webhooks,
|
||||
update_webhook,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
|
||||
|
||||
|
||||
def _mask_secret(wh) -> WebhookConfigOut:
|
||||
"""Return a WebhookConfigOut with the secret masked."""
|
||||
out = WebhookConfigOut.model_validate(wh)
|
||||
if wh.secret:
|
||||
out.secret = "***"
|
||||
else:
|
||||
out.secret = None
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /webhooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("", response_model=list[WebhookConfigOut])
|
||||
def list_webhooks_route(
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Return all webhook configurations. **Requires admin role.**"""
|
||||
webhooks = list_webhooks(db, offset=offset, limit=limit)
|
||||
return [_mask_secret(wh) for wh in webhooks]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /webhooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("", response_model=WebhookConfigOut, status_code=status.HTTP_201_CREATED)
|
||||
def create_webhook_route(
|
||||
payload: WebhookConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Create a new webhook configuration. **Requires admin role.**"""
|
||||
with UnitOfWork(db) as uow:
|
||||
wh = create_webhook(db, created_by=current_user.id, payload=payload)
|
||||
uow.commit()
|
||||
db.refresh(wh)
|
||||
return _mask_secret(wh)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /webhooks/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{webhook_id}", response_model=WebhookConfigOut)
|
||||
def get_webhook_route(
|
||||
webhook_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Return a single webhook configuration. **Requires admin role.**"""
|
||||
wh = get_webhook_or_raise(db, webhook_id)
|
||||
return _mask_secret(wh)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /webhooks/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/{webhook_id}", response_model=WebhookConfigOut)
|
||||
def update_webhook_route(
|
||||
webhook_id: uuid.UUID,
|
||||
payload: WebhookConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Update one or more fields of a webhook configuration. **Requires admin role.**"""
|
||||
with UnitOfWork(db) as uow:
|
||||
wh = update_webhook(db, webhook_id, payload)
|
||||
uow.commit()
|
||||
db.refresh(wh)
|
||||
return _mask_secret(wh)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /webhooks/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.delete("/{webhook_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_webhook_route(
|
||||
webhook_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Hard-delete a webhook configuration. **Requires admin role.**"""
|
||||
with UnitOfWork(db) as uow:
|
||||
delete_webhook(db, webhook_id)
|
||||
uow.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /webhooks/{id}/test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{webhook_id}/test", status_code=status.HTTP_202_ACCEPTED)
|
||||
def test_webhook_route(
|
||||
webhook_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Send a test ping to the webhook endpoint. **Requires admin role.**"""
|
||||
# Verify the webhook exists before dispatching
|
||||
get_webhook_or_raise(db, webhook_id)
|
||||
dispatch_webhook("webhook.test", {"webhook_id": str(webhook_id), "message": "Test ping from Aegis"})
|
||||
return {"detail": "Test ping dispatched"}
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Phase 14: API Key Pydantic schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.models.api_key import VALID_SCOPES
|
||||
|
||||
|
||||
class ApiKeyCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
description: Optional[str] = None
|
||||
scopes: List[str] = Field(default=["read"])
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
@field_validator("scopes")
|
||||
@classmethod
|
||||
def validate_scopes(cls, v: list) -> list:
|
||||
invalid = set(v) - VALID_SCOPES
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid scopes: {invalid}. Valid: {VALID_SCOPES}")
|
||||
if not v:
|
||||
raise ValueError("At least one scope is required")
|
||||
return v
|
||||
|
||||
|
||||
class ApiKeyOut(BaseModel):
|
||||
"""Safe representation — never exposes key_hash."""
|
||||
id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
key_prefix: str
|
||||
user_id: UUID
|
||||
scopes: List[str]
|
||||
last_used_at: Optional[datetime] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
is_active: bool
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ApiKeyCreated(ApiKeyOut):
|
||||
"""Returned only once at creation — includes the raw key."""
|
||||
raw_key: str = Field(..., description="The full API key — shown only this once.")
|
||||
|
||||
|
||||
class ApiKeyUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
description: Optional[str] = None
|
||||
scopes: Optional[List[str]] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
@field_validator("scopes")
|
||||
@classmethod
|
||||
def validate_scopes(cls, v: Optional[list]) -> Optional[list]:
|
||||
if v is None:
|
||||
return v
|
||||
invalid = set(v) - VALID_SCOPES
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid scopes: {invalid}")
|
||||
return v
|
||||
@@ -0,0 +1,230 @@
|
||||
"""Pydantic schemas for Phase 10: Attack Paths & Advanced Purple Team."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
VALID_KILL_CHAIN_PHASES = [
|
||||
"reconnaissance", "resource_development", "initial_access", "execution",
|
||||
"persistence", "privilege_escalation", "defense_evasion", "credential_access",
|
||||
"discovery", "lateral_movement", "collection", "command_and_control",
|
||||
"exfiltration", "impact",
|
||||
]
|
||||
|
||||
|
||||
# ── Attack Path ───────────────────────────────────────────────────────────────
|
||||
|
||||
class AttackPathCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
objective: Optional[str] = None
|
||||
is_template: bool = False
|
||||
threat_actor_id: Optional[UUID] = None
|
||||
tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class AttackPathUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
objective: Optional[str] = None
|
||||
is_template: Optional[bool] = None
|
||||
threat_actor_id: Optional[UUID] = None
|
||||
tags: Optional[list[str]] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class AttackPathOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
objective: Optional[str] = None
|
||||
is_template: bool
|
||||
threat_actor_id: Optional[UUID] = None
|
||||
created_by: Optional[UUID] = None
|
||||
tags: Optional[list] = None
|
||||
is_active: bool
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
step_count: Optional[int] = None # injected by service
|
||||
|
||||
|
||||
# ── Attack Path Step ──────────────────────────────────────────────────────────
|
||||
|
||||
class AttackPathStepCreate(BaseModel):
|
||||
order_index: int = 0
|
||||
kill_chain_phase: Optional[str] = None
|
||||
technique_id: Optional[UUID] = None
|
||||
test_id: Optional[UUID] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
expected_detection: bool = True
|
||||
notes: Optional[str] = None
|
||||
|
||||
@field_validator("kill_chain_phase")
|
||||
@classmethod
|
||||
def validate_phase(cls, v):
|
||||
if v is not None and v not in VALID_KILL_CHAIN_PHASES:
|
||||
raise ValueError(f"Invalid kill_chain_phase '{v}'. Valid: {VALID_KILL_CHAIN_PHASES}")
|
||||
return v
|
||||
|
||||
|
||||
class AttackPathStepUpdate(BaseModel):
|
||||
order_index: Optional[int] = None
|
||||
kill_chain_phase: Optional[str] = None
|
||||
technique_id: Optional[UUID] = None
|
||||
test_id: Optional[UUID] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
expected_detection: Optional[bool] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
@field_validator("kill_chain_phase")
|
||||
@classmethod
|
||||
def validate_phase(cls, v):
|
||||
if v is not None and v not in VALID_KILL_CHAIN_PHASES:
|
||||
raise ValueError(f"Invalid kill_chain_phase '{v}'.")
|
||||
return v
|
||||
|
||||
|
||||
class AttackPathStepOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
attack_path_id: UUID
|
||||
order_index: int
|
||||
kill_chain_phase: Optional[str] = None
|
||||
technique_id: Optional[UUID] = None
|
||||
test_id: Optional[UUID] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
expected_detection: bool
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
# ── Execution ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class ExecutionCreate(BaseModel):
|
||||
environment: Optional[str] = None
|
||||
red_team_lead: Optional[UUID] = None
|
||||
blue_team_lead: Optional[UUID] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class ExecutionOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
attack_path_id: UUID
|
||||
status: str
|
||||
environment: Optional[str] = None
|
||||
red_team_lead: Optional[UUID] = None
|
||||
blue_team_lead: Optional[UUID] = None
|
||||
started_by: Optional[UUID] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
notes: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
# metrics
|
||||
total_steps: Optional[int] = None
|
||||
detected_steps: Optional[int] = None
|
||||
not_detected_steps: Optional[int] = None
|
||||
skipped_steps: Optional[int] = None
|
||||
detection_rate: Optional[float] = None
|
||||
mttd_seconds: Optional[float] = None
|
||||
furthest_undetected_step: Optional[int] = None
|
||||
|
||||
|
||||
# ── Step Result ───────────────────────────────────────────────────────────────
|
||||
|
||||
class StepExecuteRequest(BaseModel):
|
||||
status: str # detected / not_detected / skipped
|
||||
executed_at: Optional[datetime] = None
|
||||
detected_at: Optional[datetime] = None
|
||||
detection_asset_id: Optional[UUID] = None
|
||||
notes: Optional[str] = None
|
||||
evidence_ids: Optional[list[UUID]] = None
|
||||
|
||||
@field_validator("status")
|
||||
@classmethod
|
||||
def validate_status(cls, v):
|
||||
valid = ("detected", "not_detected", "skipped", "executing")
|
||||
if v not in valid:
|
||||
raise ValueError(f"status must be one of {valid}")
|
||||
return v
|
||||
|
||||
|
||||
class StepResultOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
execution_id: UUID
|
||||
step_id: UUID
|
||||
step_order: int
|
||||
status: str
|
||||
executed_by: Optional[UUID] = None
|
||||
executed_at: Optional[datetime] = None
|
||||
detected_at: Optional[datetime] = None
|
||||
time_to_detect_seconds: Optional[float] = None
|
||||
detection_asset_id: Optional[UUID] = None
|
||||
notes: Optional[str] = None
|
||||
evidence_ids: Optional[list] = None
|
||||
|
||||
|
||||
# ── Timeline ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class TimelineEntryCreate(BaseModel):
|
||||
actor_side: str
|
||||
entry_type: str
|
||||
content: str
|
||||
step_id: Optional[UUID] = None
|
||||
timestamp: Optional[datetime] = None
|
||||
extra: Optional[dict] = None
|
||||
|
||||
@field_validator("actor_side")
|
||||
@classmethod
|
||||
def validate_side(cls, v):
|
||||
if v not in ("red", "blue", "system"):
|
||||
raise ValueError("actor_side must be red, blue or system")
|
||||
return v
|
||||
|
||||
@field_validator("entry_type")
|
||||
@classmethod
|
||||
def validate_type(cls, v):
|
||||
valid = ("action", "detection", "note", "phase_transition", "flag")
|
||||
if v not in valid:
|
||||
raise ValueError(f"entry_type must be one of {valid}")
|
||||
return v
|
||||
|
||||
|
||||
class TimelineEntryOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
execution_id: UUID
|
||||
step_id: Optional[UUID] = None
|
||||
timestamp: datetime
|
||||
actor_side: str
|
||||
actor_id: Optional[UUID] = None
|
||||
entry_type: str
|
||||
content: str
|
||||
extra: Optional[dict] = None
|
||||
|
||||
|
||||
# ── Metrics ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class KillChainMetrics(BaseModel):
|
||||
execution_id: UUID
|
||||
total_steps: int
|
||||
detected_steps: int
|
||||
not_detected_steps: int
|
||||
skipped_steps: int
|
||||
detection_rate: float # 0.0–1.0
|
||||
mttd_seconds: Optional[float] # mean time to detect
|
||||
furthest_undetected_step: Optional[int]
|
||||
furthest_undetected_phase: Optional[str]
|
||||
step_breakdown: list[dict] # per-step detail
|
||||
phase_summary: dict # detection rate per kill-chain phase
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -16,7 +16,7 @@ class AuditLogOut(BaseModel):
|
||||
action: str
|
||||
entity_type: str | None = None
|
||||
entity_id: str | None = None
|
||||
timestamp: datetime
|
||||
timestamp: Optional[datetime] = None
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Pydantic schemas for Detection Lifecycle endpoints."""
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from app.models.detection_lifecycle import (
|
||||
DetectionConfidence, DetectionHealthStatus, InvalidationReason
|
||||
)
|
||||
|
||||
|
||||
class DetectionAssetCreate(BaseModel):
|
||||
name: str = Field(..., min_length=3, max_length=500)
|
||||
description: Optional[str] = None
|
||||
asset_type: str = Field(..., pattern=r'^(siem_rule|edr_rule|sigma_rule|yara_rule|spl_query|kql_query|custom_script)$')
|
||||
platform: Optional[str] = None
|
||||
rule_content: Optional[str] = None
|
||||
rule_language: Optional[str] = None
|
||||
rule_repository_url: Optional[str] = None
|
||||
rule_file_path: Optional[str] = None
|
||||
rule_version: Optional[str] = None
|
||||
log_source_name: Optional[str] = None
|
||||
log_source_version: Optional[str] = None
|
||||
log_source_config: Optional[dict] = Field(default_factory=dict)
|
||||
infrastructure_details: Optional[dict] = Field(default_factory=dict)
|
||||
expected_alert_frequency: Optional[str] = None
|
||||
tags: Optional[list[str]] = Field(default_factory=list)
|
||||
technique_ids: Optional[list[UUID]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DetectionAssetUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
rule_content: Optional[str] = None
|
||||
rule_version: Optional[str] = None
|
||||
log_source_version: Optional[str] = None
|
||||
infrastructure_details: Optional[dict] = None
|
||||
expected_alert_frequency: Optional[str] = None
|
||||
health_status: Optional[DetectionHealthStatus] = None
|
||||
last_alert_at: Optional[datetime] = None
|
||||
alert_count_30d: Optional[int] = None
|
||||
false_positive_rate: Optional[float] = None
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class DetectionAssetOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
asset_type: str
|
||||
platform: Optional[str] = None
|
||||
rule_language: Optional[str] = None
|
||||
rule_version: Optional[str] = None
|
||||
rule_hash: Optional[str] = None
|
||||
health_status: DetectionHealthStatus
|
||||
last_alert_at: Optional[datetime] = None
|
||||
alert_count_30d: int
|
||||
false_positive_rate: Optional[float] = None
|
||||
expected_alert_frequency: Optional[str] = None
|
||||
owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
is_active: bool
|
||||
tags: list = Field(default_factory=list)
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class DetectionValidationCreate(BaseModel):
|
||||
detection_asset_id: UUID
|
||||
technique_id: Optional[UUID] = None
|
||||
test_id: Optional[UUID] = None
|
||||
validation_result: str = Field(..., pattern=r'^(detected|not_detected|partial|error)$')
|
||||
validation_method: str
|
||||
notes: Optional[str] = None
|
||||
evidence_ids: Optional[list[UUID]] = Field(default_factory=list)
|
||||
validity_days: int = Field(default=180, ge=30, le=730)
|
||||
|
||||
|
||||
class DetectionValidationOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
detection_asset_id: UUID
|
||||
technique_id: Optional[UUID] = None
|
||||
validated_at: Optional[datetime] = None
|
||||
expires_at: datetime
|
||||
is_valid: bool
|
||||
validation_result: Optional[str] = None
|
||||
validation_method: Optional[str] = None
|
||||
invalidated_at: Optional[datetime] = None
|
||||
invalidation_reason: Optional[InvalidationReason] = None
|
||||
validated_by: Optional[UUID] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class TechniqueConfidenceOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
technique_id: UUID
|
||||
confidence_level: DetectionConfidence
|
||||
confidence_score: float
|
||||
detection_count: int
|
||||
valid_detection_count: int
|
||||
last_validated_at: Optional[datetime] = None
|
||||
next_validation_due: Optional[datetime] = None
|
||||
recency_factor: float
|
||||
coverage_factor: float
|
||||
health_factor: float
|
||||
diversity_factor: float
|
||||
risk_factors: list = Field(default_factory=list)
|
||||
|
||||
|
||||
class InfrastructureChangeCreate(BaseModel):
|
||||
change_type: str
|
||||
description: str = Field(..., min_length=10)
|
||||
affected_platforms: list[str] = Field(default_factory=list)
|
||||
affected_log_sources: list[str] = Field(default_factory=list)
|
||||
change_date: Optional[datetime] = None
|
||||
auto_invalidate: bool = True
|
||||
|
||||
|
||||
class InfrastructureChangeOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
change_type: str
|
||||
description: str
|
||||
affected_platforms: list = Field(default_factory=list)
|
||||
affected_log_sources: list = Field(default_factory=list)
|
||||
change_date: Optional[datetime] = None
|
||||
auto_invalidate: bool
|
||||
invalidated_count: int
|
||||
reported_by: Optional[UUID] = None
|
||||
created_at: Optional[datetime] = None
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Phase 13: Executive Dashboard — Pydantic schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PostureSnapshotOut(BaseModel):
|
||||
id: UUID
|
||||
snapshot_date: date
|
||||
|
||||
# Coverage
|
||||
total_techniques: int
|
||||
validated_count: int
|
||||
partial_count: int
|
||||
not_covered_count: int
|
||||
coverage_pct: float
|
||||
|
||||
# Risk
|
||||
avg_risk_score: float
|
||||
critical_count: int
|
||||
high_count: int
|
||||
medium_count: int
|
||||
low_count: int
|
||||
|
||||
# Operations
|
||||
open_queue_items: int
|
||||
orphan_techniques: int
|
||||
|
||||
# Knowledge
|
||||
playbook_count: int
|
||||
lesson_count: int
|
||||
|
||||
# MTTD
|
||||
mttd_avg_seconds: Optional[float] = None
|
||||
executions_30d: int
|
||||
detection_rate_30d: Optional[float] = None
|
||||
|
||||
# Meta
|
||||
created_by: Optional[UUID] = None
|
||||
created_at: Optional[datetime] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ExecutiveSummary(BaseModel):
|
||||
"""Full executive view — current posture + trends."""
|
||||
snapshot: PostureSnapshotOut
|
||||
coverage_trend: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Last 30-day coverage_pct series [{date, value}]",
|
||||
)
|
||||
risk_trend: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Last 30-day avg_risk_score series [{date, value}]",
|
||||
)
|
||||
top_risks: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Top 5 highest-risk techniques",
|
||||
)
|
||||
coverage_by_tactic: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Per-tactic validated/partial/not_covered counts",
|
||||
)
|
||||
recent_activity: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Most-recent events (tests, paths, queue changes)",
|
||||
)
|
||||
|
||||
|
||||
class KpiBlock(BaseModel):
|
||||
"""Compact KPI block for a dashboard header."""
|
||||
coverage_pct: float
|
||||
avg_risk_score: float
|
||||
critical_count: int
|
||||
open_queue_items: int
|
||||
orphan_techniques: int
|
||||
mttd_avg_seconds: Optional[float] = None
|
||||
detection_rate_30d: Optional[float] = None
|
||||
playbook_count: int
|
||||
lesson_count: int
|
||||
snapshot_date: date
|
||||
snapshot_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class CoverageByTactic(BaseModel):
|
||||
tactic: str
|
||||
total: int
|
||||
validated: int
|
||||
partial: int
|
||||
not_covered: int
|
||||
coverage_pct: float
|
||||
|
||||
|
||||
class PostureHistoryEntry(BaseModel):
|
||||
snapshot_date: date
|
||||
coverage_pct: float
|
||||
avg_risk_score: float
|
||||
critical_count: int
|
||||
open_queue_items: int
|
||||
|
||||
|
||||
class ActivityEntry(BaseModel):
|
||||
ts: datetime
|
||||
category: str # "test" | "attack_path" | "queue" | "osint"
|
||||
title: str
|
||||
detail: Optional[str] = None
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Phase 11: Knowledge Management schemas — Playbooks + Lessons Learned."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
# ── Constants ─────────────────────────────────────────────────────────────────
|
||||
|
||||
VALID_PLAYBOOK_TYPES = ["attack", "detect", "investigate", "respond", "hunt"]
|
||||
VALID_SEVERITIES = ["critical", "high", "medium", "low", "info"]
|
||||
VALID_ENTITY_TYPES = ["test", "campaign", "attack_path", "manual"]
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# Playbook schemas
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class PlaybookCreate(BaseModel):
|
||||
technique_id: UUID
|
||||
playbook_type: str
|
||||
title: str
|
||||
content: str = ""
|
||||
tools: List[str] = []
|
||||
prerequisites: List[str] = []
|
||||
change_note: Optional[str] = None
|
||||
|
||||
@field_validator("playbook_type")
|
||||
@classmethod
|
||||
def validate_playbook_type(cls, v: str) -> str:
|
||||
if v not in VALID_PLAYBOOK_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid playbook_type '{v}'. Must be one of: {VALID_PLAYBOOK_TYPES}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class PlaybookUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
tools: Optional[List[str]] = None
|
||||
prerequisites: Optional[List[str]] = None
|
||||
change_note: Optional[str] = None
|
||||
|
||||
|
||||
class PlaybookVersionOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
playbook_id: UUID
|
||||
version: int
|
||||
title: str
|
||||
content: str
|
||||
tools: List[str] = []
|
||||
prerequisites: List[str] = []
|
||||
changed_by: Optional[UUID]
|
||||
change_note: Optional[str]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class PlaybookOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
technique_id: UUID
|
||||
playbook_type: str
|
||||
title: str
|
||||
content: str
|
||||
version: int
|
||||
tools: List[str] = []
|
||||
prerequisites: List[str] = []
|
||||
created_by: Optional[UUID]
|
||||
updated_by: Optional[UUID]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
is_active: bool
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# Lesson Learned schemas
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class LessonLearnedCreate(BaseModel):
|
||||
title: str
|
||||
what_happened: str
|
||||
root_cause: str
|
||||
fix_applied: Optional[str] = None
|
||||
severity: str = "medium"
|
||||
entity_type: str = "manual"
|
||||
entity_id: Optional[UUID] = None
|
||||
technique_ids: List[str] = []
|
||||
tags: List[str] = []
|
||||
|
||||
@field_validator("severity")
|
||||
@classmethod
|
||||
def validate_severity(cls, v: str) -> str:
|
||||
if v not in VALID_SEVERITIES:
|
||||
raise ValueError(
|
||||
f"Invalid severity '{v}'. Must be one of: {VALID_SEVERITIES}"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("entity_type")
|
||||
@classmethod
|
||||
def validate_entity_type(cls, v: str) -> str:
|
||||
if v not in VALID_ENTITY_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid entity_type '{v}'. Must be one of: {VALID_ENTITY_TYPES}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class LessonLearnedUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
what_happened: Optional[str] = None
|
||||
root_cause: Optional[str] = None
|
||||
fix_applied: Optional[str] = None
|
||||
severity: Optional[str] = None
|
||||
technique_ids: Optional[List[str]] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
@field_validator("severity")
|
||||
@classmethod
|
||||
def validate_severity(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is not None and v not in VALID_SEVERITIES:
|
||||
raise ValueError(
|
||||
f"Invalid severity '{v}'. Must be one of: {VALID_SEVERITIES}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class LessonLearnedOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
title: str
|
||||
what_happened: str
|
||||
root_cause: str
|
||||
fix_applied: Optional[str]
|
||||
severity: str
|
||||
entity_type: str
|
||||
entity_id: Optional[UUID]
|
||||
technique_ids: List[str] = []
|
||||
tags: List[str] = []
|
||||
created_by: Optional[UUID]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
is_active: bool
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Phase 13: Operational Alerts — Pydantic schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.models.operational_alert import AlertRuleType, AlertSeverity, AlertStatus
|
||||
|
||||
VALID_SEVERITIES = {s.value for s in AlertSeverity}
|
||||
VALID_STATUSES = {s.value for s in AlertStatus}
|
||||
VALID_RULE_TYPES = {r.value for r in AlertRuleType}
|
||||
|
||||
|
||||
# ── AlertRule schemas ─────────────────────────────────────────────────────────
|
||||
|
||||
class AlertRuleCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=300)
|
||||
description: Optional[str] = None
|
||||
rule_type: str
|
||||
severity: str = "medium"
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
notify_in_app: bool = True
|
||||
notify_webhook: bool = False
|
||||
webhook_id: Optional[UUID] = None
|
||||
cooldown_hours: int = Field(24, ge=0, le=8760)
|
||||
|
||||
@field_validator("rule_type")
|
||||
@classmethod
|
||||
def validate_rule_type(cls, v: str) -> str:
|
||||
if v not in VALID_RULE_TYPES:
|
||||
raise ValueError(f"Invalid rule_type. Valid: {VALID_RULE_TYPES}")
|
||||
return v
|
||||
|
||||
@field_validator("severity")
|
||||
@classmethod
|
||||
def validate_severity(cls, v: str) -> str:
|
||||
if v not in VALID_SEVERITIES:
|
||||
raise ValueError(f"Invalid severity. Valid: {VALID_SEVERITIES}")
|
||||
return v
|
||||
|
||||
|
||||
class AlertRuleUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=300)
|
||||
description: Optional[str] = None
|
||||
severity: Optional[str] = None
|
||||
is_enabled: Optional[bool] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
notify_in_app: Optional[bool] = None
|
||||
notify_webhook: Optional[bool] = None
|
||||
webhook_id: Optional[UUID] = None
|
||||
cooldown_hours: Optional[int] = Field(None, ge=0, le=8760)
|
||||
|
||||
@field_validator("severity")
|
||||
@classmethod
|
||||
def validate_severity(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is not None and v not in VALID_SEVERITIES:
|
||||
raise ValueError(f"Invalid severity. Valid: {VALID_SEVERITIES}")
|
||||
return v
|
||||
|
||||
|
||||
class AlertRuleOut(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
rule_type: str
|
||||
severity: str
|
||||
is_enabled: bool
|
||||
is_system: bool
|
||||
config: Dict[str, Any]
|
||||
notify_in_app: bool
|
||||
notify_webhook: bool
|
||||
webhook_id: Optional[UUID] = None
|
||||
cooldown_hours: int
|
||||
created_by: Optional[UUID] = None
|
||||
created_at: Optional[datetime] = None
|
||||
last_fired_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ── AlertInstance schemas ─────────────────────────────────────────────────────
|
||||
|
||||
class AlertInstanceOut(BaseModel):
|
||||
id: UUID
|
||||
rule_id: Optional[UUID] = None
|
||||
rule_name: str
|
||||
rule_type: str
|
||||
severity: str
|
||||
title: str
|
||||
message: str
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
status: str
|
||||
acknowledged_by: Optional[UUID] = None
|
||||
acknowledged_at: Optional[datetime] = None
|
||||
resolved_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ── Evaluation result ─────────────────────────────────────────────────────────
|
||||
|
||||
class EvaluationResult(BaseModel):
|
||||
rules_evaluated: int
|
||||
alerts_fired: int
|
||||
alerts: List[AlertInstanceOut] = Field(default_factory=list)
|
||||
duration_seconds: float
|
||||
|
||||
|
||||
# ── Summary ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class AlertSummary(BaseModel):
|
||||
total_open: int
|
||||
total_acknowledged: int
|
||||
total_resolved: int
|
||||
by_severity: Dict[str, int]
|
||||
by_rule_type: Dict[str, int]
|
||||
recent_alerts: List[AlertInstanceOut] = Field(default_factory=list)
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Pydantic schemas for Phase 9: Ownership & Revalidation Queue."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
|
||||
# ── Technique Ownership ───────────────────────────────────────────────────────
|
||||
|
||||
class TechniqueOwnershipSet(BaseModel):
|
||||
"""Set (create or replace) ownership for a technique."""
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class TechniqueOwnershipOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
technique_id: UUID
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
assigned_at: Optional[datetime] = None
|
||||
assigned_by: Optional[UUID] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class DetectionAssetOwnershipPatch(BaseModel):
|
||||
"""Update ownership fields on a detection asset."""
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
|
||||
|
||||
# ── Bulk Assignment ───────────────────────────────────────────────────────────
|
||||
|
||||
class BulkAssignRequest(BaseModel):
|
||||
"""Bulk-assign ownership by tactic, platform, or team filter."""
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
# Filters — at least one must be set
|
||||
tactic: Optional[str] = None # assign all techniques with this tactic
|
||||
platform: Optional[str] = None # assign all detection assets with this platform
|
||||
overwrite: bool = False # overwrite existing assignments
|
||||
|
||||
|
||||
class BulkAssignResult(BaseModel):
|
||||
assigned_count: int
|
||||
skipped_count: int
|
||||
target_type: str # "technique" or "detection_asset"
|
||||
|
||||
|
||||
# ── Revalidation Queue ────────────────────────────────────────────────────────
|
||||
|
||||
class QueueItemPatch(BaseModel):
|
||||
"""Update a revalidation queue item."""
|
||||
status: Optional[str] = None
|
||||
assigned_to: Optional[UUID] = None
|
||||
priority: Optional[str] = None
|
||||
due_date: Optional[datetime] = None
|
||||
|
||||
@field_validator("status")
|
||||
@classmethod
|
||||
def validate_status(cls, v):
|
||||
from app.models.ownership_queue import QueueStatus
|
||||
if v is not None:
|
||||
try:
|
||||
QueueStatus(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid status: {v}")
|
||||
return v
|
||||
|
||||
@field_validator("priority")
|
||||
@classmethod
|
||||
def validate_priority(cls, v):
|
||||
from app.models.ownership_queue import QueuePriority
|
||||
if v is not None:
|
||||
try:
|
||||
QueuePriority(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid priority: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class QueueItemCreate(BaseModel):
|
||||
"""Manually create a queue item."""
|
||||
technique_id: Optional[UUID] = None
|
||||
detection_asset_id: Optional[UUID] = None
|
||||
priority: str = "medium"
|
||||
reason: str = "manual"
|
||||
reason_detail: Optional[str] = None
|
||||
assigned_to: Optional[UUID] = None
|
||||
due_date: Optional[datetime] = None
|
||||
|
||||
@field_validator("reason")
|
||||
@classmethod
|
||||
def validate_reason(cls, v):
|
||||
from app.models.ownership_queue import QueueReason
|
||||
try:
|
||||
QueueReason(v)
|
||||
except ValueError:
|
||||
valid = [e.value for e in QueueReason]
|
||||
raise ValueError(f"Invalid reason '{v}'. Must be one of: {valid}")
|
||||
return v
|
||||
|
||||
@field_validator("priority")
|
||||
@classmethod
|
||||
def validate_priority(cls, v):
|
||||
from app.models.ownership_queue import QueuePriority
|
||||
try:
|
||||
QueuePriority(v)
|
||||
except ValueError:
|
||||
valid = [e.value for e in QueuePriority]
|
||||
raise ValueError(f"Invalid priority '{v}'. Must be one of: {valid}")
|
||||
return v
|
||||
|
||||
|
||||
class QueueItemOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
technique_id: Optional[UUID] = None
|
||||
detection_asset_id: Optional[UUID] = None
|
||||
priority: str
|
||||
reason: str
|
||||
reason_detail: Optional[str] = None
|
||||
status: str
|
||||
assigned_to: Optional[UUID] = None
|
||||
due_date: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
dismissed_at: Optional[datetime] = None
|
||||
completed_by: Optional[UUID] = None
|
||||
extra: Optional[dict] = None
|
||||
|
||||
|
||||
# ── Analyst Dashboard ─────────────────────────────────────────────────────────
|
||||
|
||||
class AnalystDashboard(BaseModel):
|
||||
"""Personalised daily workday view for an analyst."""
|
||||
my_pending_items: list[QueueItemOut]
|
||||
expiring_validations_7d: list[dict]
|
||||
recent_infra_changes: list[dict]
|
||||
my_low_confidence_techniques: list[dict]
|
||||
summary: dict
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Phase 12: Risk Intelligence schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
VALID_RISK_LEVELS = ["critical", "high", "medium", "low", "info"]
|
||||
|
||||
|
||||
class TechniqueRiskProfileOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
technique_id: UUID
|
||||
risk_score: float
|
||||
likelihood: float
|
||||
impact: float
|
||||
risk_level: str
|
||||
detection_gap: float
|
||||
threat_actor_count: int
|
||||
osint_signal_count: int
|
||||
test_fail_count: int
|
||||
test_total_count: int
|
||||
test_failure_rate: float
|
||||
confidence_level: float
|
||||
scoring_breakdown: Optional[Dict[str, Any]]
|
||||
recommendations: Optional[List[str]]
|
||||
computed_at: datetime
|
||||
is_stale: bool
|
||||
|
||||
|
||||
class RiskMatrixEntry(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
technique_id: UUID
|
||||
technique_name: Optional[str] = None
|
||||
technique_tid: Optional[str] = None # e.g. "T1059"
|
||||
risk_score: float
|
||||
likelihood: float
|
||||
impact: float
|
||||
risk_level: str
|
||||
detection_gap: float
|
||||
computed_at: datetime
|
||||
|
||||
|
||||
class RiskSummary(BaseModel):
|
||||
total_techniques: int
|
||||
scored_techniques: int
|
||||
stale_count: int
|
||||
by_level: Dict[str, int] # {"critical": 3, "high": 12, ...}
|
||||
avg_risk_score: float
|
||||
top_risks: List[RiskMatrixEntry]
|
||||
|
||||
|
||||
class RecommendationItem(BaseModel):
|
||||
technique_id: UUID
|
||||
technique_name: Optional[str] = None
|
||||
technique_tid: Optional[str] = None
|
||||
risk_level: str
|
||||
risk_score: float
|
||||
recommendations: List[str]
|
||||
priority: int # 1 = highest
|
||||
|
||||
|
||||
class ComputeResult(BaseModel):
|
||||
computed: int
|
||||
skipped: int
|
||||
errors: int
|
||||
duration_seconds: float
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Phase 14: SSO / SAML 2.0 Pydantic schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SsoConfigCreate(BaseModel):
|
||||
is_enabled: bool = False
|
||||
provider_name: Optional[str] = None
|
||||
|
||||
# SP settings (auto-derived if not provided)
|
||||
sp_entity_id: Optional[str] = None
|
||||
sp_acs_url: Optional[str] = None
|
||||
sp_slo_url: Optional[str] = None
|
||||
sp_certificate: Optional[str] = None
|
||||
sp_private_key: Optional[str] = None
|
||||
|
||||
# IdP settings
|
||||
idp_entity_id: Optional[str] = None
|
||||
idp_sso_url: Optional[str] = None
|
||||
idp_slo_url: Optional[str] = None
|
||||
idp_certificate: Optional[str] = None
|
||||
|
||||
# Attribute mapping
|
||||
attr_email: Optional[str] = "email"
|
||||
attr_username: Optional[str] = "username"
|
||||
attr_role: Optional[str] = "role"
|
||||
default_role: Optional[str] = "viewer"
|
||||
auto_provision: bool = True
|
||||
|
||||
|
||||
class SsoConfigUpdate(SsoConfigCreate):
|
||||
"""All fields optional for partial updates."""
|
||||
pass
|
||||
|
||||
|
||||
class SsoConfigOut(BaseModel):
|
||||
id: UUID
|
||||
is_enabled: bool
|
||||
provider_name: Optional[str] = None
|
||||
|
||||
sp_entity_id: Optional[str] = None
|
||||
sp_acs_url: Optional[str] = None
|
||||
sp_slo_url: Optional[str] = None
|
||||
sp_certificate: Optional[str] = None
|
||||
# sp_private_key is intentionally OMITTED from responses
|
||||
|
||||
idp_entity_id: Optional[str] = None
|
||||
idp_sso_url: Optional[str] = None
|
||||
idp_slo_url: Optional[str] = None
|
||||
idp_certificate: Optional[str] = None
|
||||
|
||||
attr_email: Optional[str] = None
|
||||
attr_username: Optional[str] = None
|
||||
attr_role: Optional[str] = None
|
||||
default_role: Optional[str] = None
|
||||
auto_provision: bool = True
|
||||
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SsoLoginInitResponse(BaseModel):
|
||||
redirect_url: str = Field(..., description="URL to redirect the browser to for IdP login")
|
||||
request_id: str = Field(..., description="SAML AuthnRequest ID for validation")
|
||||
|
||||
|
||||
class SsoStatusResponse(BaseModel):
|
||||
enabled: bool
|
||||
provider_name: Optional[str] = None
|
||||
configured: bool = Field(..., description="True if IdP settings are present")
|
||||
login_url: Optional[str] = None # /sso/login URL
|
||||
@@ -3,10 +3,11 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from app.domain.enums import DataClassification
|
||||
from app.models.enums import TestResult, TestState
|
||||
from app.schemas.evidence import EvidenceOut
|
||||
|
||||
|
||||
# ── Create ──────────────────────────────────────────────────────────
|
||||
@@ -147,6 +148,7 @@ class TestOut(BaseModel):
|
||||
# Phase timing fields (for Tempo worklogs)
|
||||
red_started_at: datetime | None = None
|
||||
blue_started_at: datetime | None = None
|
||||
blue_work_started_at: datetime | None = None
|
||||
paused_at: datetime | None = None
|
||||
red_paused_seconds: int = 0
|
||||
blue_paused_seconds: int = 0
|
||||
@@ -165,12 +167,64 @@ class TestOut(BaseModel):
|
||||
technique_mitre_id: str | None = None
|
||||
technique_name: str | None = None
|
||||
|
||||
# Evidences split by team (populated from the ORM relationship)
|
||||
red_evidences: list[EvidenceOut] = []
|
||||
blue_evidences: list[EvidenceOut] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def model_validate(cls, obj, **kwargs):
|
||||
"""Override to populate technique fields from the relationship."""
|
||||
if hasattr(obj, "technique") and obj.technique is not None:
|
||||
obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id
|
||||
obj.__dict__["technique_name"] = obj.technique.name
|
||||
return super().model_validate(obj, **kwargs)
|
||||
def _populate_derived_fields(cls, obj):
|
||||
"""Populate technique and evidence fields from ORM relationships.
|
||||
|
||||
Uses ``@model_validator(mode='before')`` so it is called by Pydantic's
|
||||
internal Rust validation pipeline, including FastAPI's TypeAdapter path.
|
||||
A plain ``model_validate`` classmethod override is **not** invoked by
|
||||
FastAPI's response serialisation in Pydantic v2 — only registered
|
||||
validators are guaranteed to run.
|
||||
|
||||
Evidences are only processed when the relationship was **explicitly loaded**
|
||||
(via joinedload or prior access). Accessing ``obj.evidences`` blindly on a
|
||||
session-expired ORM object triggers a lazy query that fails on mutation
|
||||
endpoints that do not joinload the relationship. We inspect ``obj.__dict__``
|
||||
directly — SQLAlchemy stores loaded relationships there; if the key is absent
|
||||
the relationship is unloaded and we leave the lists empty (the frontend
|
||||
invalidates and refetches the detail endpoint, which *does* joinload).
|
||||
"""
|
||||
if not hasattr(obj, "__dict__"):
|
||||
return obj
|
||||
|
||||
# Technique info (lazy-load is fine here: session is still open on GET)
|
||||
try:
|
||||
if hasattr(obj, "technique") and obj.technique is not None:
|
||||
obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id
|
||||
obj.__dict__["technique_name"] = obj.technique.name
|
||||
except Exception:
|
||||
pass # DetachedInstanceError or similar — leave technique fields None
|
||||
|
||||
# Only split evidences when they are already in memory (loaded via joinedload)
|
||||
raw_evs = obj.__dict__.get("evidences")
|
||||
if raw_evs is not None:
|
||||
red_evs: list[EvidenceOut] = []
|
||||
blue_evs: list[EvidenceOut] = []
|
||||
for ev in raw_evs:
|
||||
ev_out = EvidenceOut(
|
||||
id=ev.id,
|
||||
test_id=ev.test_id,
|
||||
file_name=ev.file_name,
|
||||
sha256_hash=ev.sha256_hash,
|
||||
uploaded_by=ev.uploaded_by,
|
||||
uploaded_at=ev.uploaded_at,
|
||||
team=ev.team,
|
||||
notes=ev.notes,
|
||||
download_url=f"/api/v1/evidence/{ev.id}/file",
|
||||
)
|
||||
if ev.team and ev.team.value == "blue":
|
||||
blue_evs.append(ev_out)
|
||||
else:
|
||||
red_evs.append(ev_out)
|
||||
obj.__dict__["red_evidences"] = red_evs
|
||||
obj.__dict__["blue_evidences"] = blue_evs
|
||||
|
||||
return obj
|
||||
|
||||
@@ -72,7 +72,17 @@ class TestTemplateSummary(BaseModel):
|
||||
|
||||
|
||||
class TestTemplateInstantiate(BaseModel):
|
||||
"""Payload to create a real test from an existing template."""
|
||||
"""Payload to create a real test from an existing template.
|
||||
|
||||
Optional override fields take precedence over the template values when provided.
|
||||
"""
|
||||
|
||||
template_id: uuid.UUID
|
||||
technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001")
|
||||
|
||||
# User-editable overrides (if omitted the template value is used)
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
platform: str | None = None
|
||||
procedure_text: str | None = None
|
||||
tool_used: str | None = None
|
||||
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator, model_validator
|
||||
|
||||
|
||||
# ── Username policy ─────────────────────────────────────────────────
|
||||
@@ -121,6 +121,22 @@ class PasswordChange(BaseModel):
|
||||
return _validate_password_strength(v)
|
||||
|
||||
|
||||
class UserPreferencesUpdate(BaseModel):
|
||||
"""Payload for updating current user's notification preferences and Jira/Tempo settings."""
|
||||
|
||||
notification_preferences: dict | None = None
|
||||
jira_account_id: str | None = None
|
||||
# Personal Jira API token (Atlassian token) — write-only.
|
||||
# Set to empty string "" to clear the token.
|
||||
jira_api_token: str | None = None
|
||||
# Atlassian email for Jira auth — overrides account email.
|
||||
# Set to empty string "" to clear (falls back to account email).
|
||||
jira_email: str | None = None
|
||||
# Personal Tempo API token — write-only.
|
||||
# Set to empty string "" to clear the token.
|
||||
tempo_api_token: str | None = None
|
||||
|
||||
|
||||
class UserOut(BaseModel):
|
||||
"""Complete representation returned by the API."""
|
||||
|
||||
@@ -132,5 +148,26 @@ class UserOut(BaseModel):
|
||||
must_change_password: bool = True
|
||||
created_at: datetime | None = None
|
||||
last_login: datetime | None = None
|
||||
notification_preferences: dict | None = None
|
||||
jira_account_id: str | None = None
|
||||
jira_email: str | None = None
|
||||
# Read from ORM but NEVER exposed in responses — used only to derive *_token_set flags.
|
||||
jira_api_token: str | None = Field(default=None, exclude=True)
|
||||
tempo_api_token: str | None = Field(default=None, exclude=True)
|
||||
# True when the user has the respective token stored.
|
||||
jira_token_set: bool = False
|
||||
tempo_token_set: bool = False
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _derive_token_set_flags(self) -> "UserOut":
|
||||
"""Derive *_token_set booleans from the (excluded) raw token fields.
|
||||
|
||||
Uses @model_validator(mode='after') so Pydantic's Rust core calls it
|
||||
during FastAPI response serialisation — model_validate() overrides are
|
||||
bypassed by FastAPI's __pydantic_validator__.validate_python() path.
|
||||
"""
|
||||
self.jira_token_set = bool(self.jira_api_token)
|
||||
self.tempo_token_set = bool(self.tempo_api_token)
|
||||
return self
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Pydantic schemas for Webhook endpoints."""
|
||||
import ipaddress
|
||||
import socket
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
# RFC-5735 / RFC-1918 / RFC-3927 — ranges that must never be webhook targets
|
||||
_BLOCKED_NETWORKS = [
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("169.254.0.0/16"), # link-local / AWS IMDS
|
||||
ipaddress.ip_network("127.0.0.0/8"), # loopback
|
||||
ipaddress.ip_network("::1/128"), # IPv6 loopback
|
||||
ipaddress.ip_network("fc00::/7"), # IPv6 ULA
|
||||
]
|
||||
|
||||
|
||||
def _validate_webhook_url(url: str) -> str:
|
||||
"""Reject URLs that point to private/reserved addresses (SSRF prevention)."""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError("Webhook URL must use http or https")
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
raise ValueError("Webhook URL must include a hostname")
|
||||
|
||||
# Resolve hostname to IP(s) and reject any private/reserved address
|
||||
try:
|
||||
infos = socket.getaddrinfo(hostname, None)
|
||||
for info in infos:
|
||||
raw_ip = info[4][0]
|
||||
try:
|
||||
ip_obj = ipaddress.ip_address(raw_ip)
|
||||
except ValueError:
|
||||
continue
|
||||
for network in _BLOCKED_NETWORKS:
|
||||
if ip_obj in network:
|
||||
raise ValueError(
|
||||
f"Webhook URL resolves to a private/reserved address ({raw_ip}) "
|
||||
"and cannot be used"
|
||||
)
|
||||
except OSError:
|
||||
# DNS resolution failure — allow (will fail at dispatch time)
|
||||
pass
|
||||
|
||||
return url
|
||||
|
||||
|
||||
class WebhookConfigCreate(BaseModel):
|
||||
name: str
|
||||
url: str
|
||||
secret: str | None = None
|
||||
events: list[str] = []
|
||||
is_active: bool = True
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_must_be_external(cls, v: str) -> str:
|
||||
return _validate_webhook_url(v)
|
||||
|
||||
|
||||
class WebhookConfigUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
url: str | None = None
|
||||
secret: str | None = None
|
||||
events: list[str] | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_must_be_external(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
return _validate_webhook_url(v)
|
||||
|
||||
class WebhookConfigOut(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
url: str
|
||||
secret: str | None = None # masked on read
|
||||
events: list[str]
|
||||
is_active: bool
|
||||
created_by: uuid.UUID | None = None
|
||||
last_triggered_at: datetime | None = None
|
||||
failure_count: int
|
||||
created_at: datetime | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Seed default decay policies."""
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.decay_policy import DecayPolicy
|
||||
|
||||
|
||||
def seed_decay_policies(db: Session) -> None:
|
||||
existing = db.query(DecayPolicy).filter(DecayPolicy.is_default == True).first()
|
||||
if existing:
|
||||
return
|
||||
|
||||
now = datetime.utcnow()
|
||||
default_policy = DecayPolicy(
|
||||
name="Default Decay Policy",
|
||||
description="Standard: Fresh 90d, Aging 91-180d, Stale 181-365d.",
|
||||
fresh_days=90, aging_days=180, stale_days=365,
|
||||
default_validity_days=180, silent_threshold_days=30,
|
||||
noisy_threshold_daily=100,
|
||||
recency_weight=0.30, coverage_weight=0.30,
|
||||
health_weight=0.25, diversity_weight=0.15,
|
||||
is_default=True, is_active=True,
|
||||
created_at=now, updated_at=now,
|
||||
)
|
||||
db.add(default_policy)
|
||||
|
||||
critical_policy = DecayPolicy(
|
||||
name="Critical Techniques Policy",
|
||||
description="Stricter: Fresh 60d, Aging 90d, Stale 180d.",
|
||||
applies_to_tactic="initial-access",
|
||||
fresh_days=60, aging_days=90, stale_days=180,
|
||||
default_validity_days=90, silent_threshold_days=14,
|
||||
noisy_threshold_daily=50,
|
||||
recency_weight=0.35, coverage_weight=0.30,
|
||||
health_weight=0.25, diversity_weight=0.10,
|
||||
is_default=False, is_active=True,
|
||||
created_at=now, updated_at=now,
|
||||
)
|
||||
db.add(critical_policy)
|
||||
db.commit()
|
||||
@@ -0,0 +1,155 @@
|
||||
"""Phase 14: API Key service — create, list, revoke, authenticate."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import EntityNotFoundError, DuplicateEntityError
|
||||
from app.models.api_key import ApiKey, generate_raw_key, hash_key, key_prefix_display
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
# ── Create ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_api_key(
|
||||
db: Session,
|
||||
user_id: UUID,
|
||||
name: str,
|
||||
scopes: List[str],
|
||||
description: Optional[str] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
) -> tuple[ApiKey, str]:
|
||||
"""
|
||||
Create a new API key.
|
||||
|
||||
Returns ``(ApiKey, raw_key)`` — the raw_key must be shown to the user
|
||||
immediately and is never retrievable again.
|
||||
"""
|
||||
raw_key = generate_raw_key()
|
||||
key_hash = hash_key(raw_key)
|
||||
prefix = key_prefix_display(raw_key)
|
||||
|
||||
# Detect accidental collision (astronomically unlikely)
|
||||
if db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first():
|
||||
raise DuplicateEntityError("ApiKey", "key_hash", "<collision>")
|
||||
|
||||
key = ApiKey(
|
||||
name = name,
|
||||
description = description,
|
||||
key_prefix = prefix,
|
||||
key_hash = key_hash,
|
||||
user_id = user_id,
|
||||
scopes = scopes,
|
||||
expires_at = expires_at,
|
||||
)
|
||||
db.add(key)
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
return key, raw_key
|
||||
|
||||
|
||||
# ── Read ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def list_api_keys(
|
||||
db: Session,
|
||||
user_id: Optional[UUID] = None,
|
||||
include_inactive: bool = False,
|
||||
) -> List[ApiKey]:
|
||||
q = db.query(ApiKey)
|
||||
if user_id is not None:
|
||||
q = q.filter(ApiKey.user_id == user_id)
|
||||
if not include_inactive:
|
||||
q = q.filter(ApiKey.is_active == True)
|
||||
return q.order_by(ApiKey.created_at.desc()).all()
|
||||
|
||||
|
||||
def get_api_key(db: Session, key_id: UUID, user_id: Optional[UUID] = None) -> ApiKey:
|
||||
q = db.query(ApiKey).filter(ApiKey.id == key_id)
|
||||
if user_id is not None:
|
||||
q = q.filter(ApiKey.user_id == user_id)
|
||||
key = q.first()
|
||||
if not key:
|
||||
raise EntityNotFoundError("ApiKey", str(key_id))
|
||||
return key
|
||||
|
||||
|
||||
# ── Update / Revoke ───────────────────────────────────────────────────────────
|
||||
|
||||
def update_api_key(
|
||||
db: Session,
|
||||
key_id: UUID,
|
||||
user_id: Optional[UUID] = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
scopes: Optional[List[str]] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
) -> ApiKey:
|
||||
key = get_api_key(db, key_id, user_id)
|
||||
if name is not None:
|
||||
key.name = name
|
||||
if description is not None:
|
||||
key.description = description
|
||||
if scopes is not None:
|
||||
key.scopes = scopes
|
||||
if expires_at is not None:
|
||||
key.expires_at = expires_at
|
||||
if is_active is not None:
|
||||
key.is_active = is_active
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
return key
|
||||
|
||||
|
||||
def revoke_api_key(
|
||||
db: Session,
|
||||
key_id: UUID,
|
||||
user_id: Optional[UUID] = None,
|
||||
) -> ApiKey:
|
||||
"""Soft-revoke: set is_active = False."""
|
||||
return update_api_key(db, key_id, user_id, is_active=False)
|
||||
|
||||
|
||||
def delete_api_key(db: Session, key_id: UUID, user_id: Optional[UUID] = None) -> None:
|
||||
"""Hard delete — use revoke instead for audit trail."""
|
||||
key = get_api_key(db, key_id, user_id)
|
||||
db.delete(key)
|
||||
db.commit()
|
||||
|
||||
|
||||
# ── Authentication ────────────────────────────────────────────────────────────
|
||||
|
||||
def authenticate_raw_key(db: Session, raw_key: str) -> Optional[User]:
|
||||
"""
|
||||
Verify a raw API key.
|
||||
|
||||
Returns the owning User if the key is valid, active, and not expired.
|
||||
Updates ``last_used_at`` (throttled to once per request — always updates).
|
||||
Returns None on any failure.
|
||||
"""
|
||||
h = hash_key(raw_key)
|
||||
key: Optional[ApiKey] = db.query(ApiKey).filter(ApiKey.key_hash == h).first()
|
||||
|
||||
if key is None or not key.is_active:
|
||||
return None
|
||||
if key.expires_at and key.expires_at < datetime.utcnow():
|
||||
return None
|
||||
|
||||
# Update last_used_at
|
||||
key.last_used_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
user: Optional[User] = db.query(User).filter(User.id == key.user_id).first()
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
|
||||
# Attach the key's scopes to the user instance so scope-enforcement
|
||||
# dependencies can verify them without an additional DB query.
|
||||
# _api_key_scopes=None means "full user access" (JWT path).
|
||||
user._api_key_scopes = key.scopes or []
|
||||
return user
|
||||
@@ -35,6 +35,7 @@ import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -218,6 +219,7 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
|
||||
created = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
for item in parsed_tests:
|
||||
if item["atomic_test_id"] in existing_ids:
|
||||
@@ -238,8 +240,14 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
)
|
||||
db.add(template)
|
||||
existing_ids.add(item["atomic_test_id"])
|
||||
new_technique_ids.add(item["technique_id"])
|
||||
created += 1
|
||||
|
||||
if new_technique_ids:
|
||||
db.query(Technique).filter(
|
||||
Technique.mitre_id.in_(new_technique_ids)
|
||||
).update({"review_required": True}, synchronize_session=False)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Count distinct YAML files by technique_id
|
||||
|
||||
@@ -0,0 +1,553 @@
|
||||
"""Phase 10: Attack Path CRUD service."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.models.attack_path import (
|
||||
AttackPath, AttackPathStep, AttackPathExecution,
|
||||
AttackPathStepResult, TimelineEntry,
|
||||
ExecutionStatus, StepResultStatus, TimelineActorSide, TimelineEntryType,
|
||||
)
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.services import audit_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
# ── Attack Path CRUD ──────────────────────────────────────────────────────────
|
||||
|
||||
def create_attack_path(db: Session, data: dict, user_id: UUID) -> AttackPath:
|
||||
path = AttackPath(
|
||||
name=data["name"],
|
||||
description=data.get("description"),
|
||||
objective=data.get("objective"),
|
||||
is_template=data.get("is_template", False),
|
||||
threat_actor_id=data.get("threat_actor_id"),
|
||||
tags=data.get("tags") or [],
|
||||
created_by=user_id,
|
||||
)
|
||||
db.add(path)
|
||||
db.commit()
|
||||
db.refresh(path)
|
||||
audit_service.log_action(
|
||||
db, user_id, "ATTACK_PATH_CREATED", "attack_path", str(path.id),
|
||||
details={"name": path.name, "is_template": path.is_template},
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def get_attack_path(db: Session, path_id: UUID) -> AttackPath:
|
||||
path = (
|
||||
db.query(AttackPath)
|
||||
.options(joinedload(AttackPath.steps))
|
||||
.filter(AttackPath.id == path_id)
|
||||
.first()
|
||||
)
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
return path
|
||||
|
||||
|
||||
def list_attack_paths(
|
||||
db: Session,
|
||||
is_template: Optional[bool] = None,
|
||||
technique_id: Optional[UUID] = None,
|
||||
is_active: Optional[bool] = True,
|
||||
) -> list[AttackPath]:
|
||||
q = db.query(AttackPath)
|
||||
if is_active is not None:
|
||||
q = q.filter(AttackPath.is_active == is_active)
|
||||
if is_template is not None:
|
||||
q = q.filter(AttackPath.is_template == is_template)
|
||||
if technique_id:
|
||||
q = q.join(AttackPathStep).filter(AttackPathStep.technique_id == technique_id)
|
||||
return q.order_by(AttackPath.created_at.desc()).all()
|
||||
|
||||
|
||||
def update_attack_path(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPath:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
for k, v in data.items():
|
||||
if v is not None and hasattr(path, k):
|
||||
setattr(path, k, v)
|
||||
path.updated_at = _now()
|
||||
db.commit()
|
||||
db.refresh(path)
|
||||
return path
|
||||
|
||||
|
||||
def delete_attack_path(db: Session, path_id: UUID, user_id: UUID) -> None:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
path.is_active = False
|
||||
path.updated_at = _now()
|
||||
db.commit()
|
||||
audit_service.log_action(db, user_id, "ATTACK_PATH_ARCHIVED", "attack_path", str(path_id))
|
||||
|
||||
|
||||
# ── Steps CRUD ────────────────────────────────────────────────────────────────
|
||||
|
||||
def add_step(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPathStep:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
|
||||
# Auto-assign order_index if not provided
|
||||
if data.get("order_index") is None:
|
||||
max_idx = db.query(AttackPathStep).filter(
|
||||
AttackPathStep.attack_path_id == path_id
|
||||
).count()
|
||||
data["order_index"] = max_idx
|
||||
|
||||
step = AttackPathStep(
|
||||
attack_path_id=path_id,
|
||||
order_index=data.get("order_index", 0),
|
||||
kill_chain_phase=data.get("kill_chain_phase"),
|
||||
technique_id=data.get("technique_id"),
|
||||
test_id=data.get("test_id"),
|
||||
name=data.get("name"),
|
||||
description=data.get("description"),
|
||||
expected_detection=data.get("expected_detection", True),
|
||||
notes=data.get("notes"),
|
||||
)
|
||||
db.add(step)
|
||||
path.updated_at = _now()
|
||||
db.commit()
|
||||
db.refresh(step)
|
||||
return step
|
||||
|
||||
|
||||
def update_step(db: Session, step_id: UUID, data: dict, user_id: UUID) -> AttackPathStep:
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
|
||||
if not step:
|
||||
raise EntityNotFoundError("AttackPathStep", str(step_id))
|
||||
for k, v in data.items():
|
||||
if v is not None and hasattr(step, k):
|
||||
setattr(step, k, v)
|
||||
db.commit()
|
||||
db.refresh(step)
|
||||
return step
|
||||
|
||||
|
||||
def delete_step(db: Session, step_id: UUID, user_id: UUID) -> None:
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
|
||||
if not step:
|
||||
raise EntityNotFoundError("AttackPathStep", str(step_id))
|
||||
db.delete(step)
|
||||
db.commit()
|
||||
|
||||
|
||||
def reorder_steps(db: Session, path_id: UUID, step_ids: list[UUID], user_id: UUID) -> list[AttackPathStep]:
|
||||
"""Reorder steps by providing ordered list of step IDs."""
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
|
||||
for idx, step_id in enumerate(step_ids):
|
||||
db.query(AttackPathStep).filter(
|
||||
AttackPathStep.id == step_id,
|
||||
AttackPathStep.attack_path_id == path_id,
|
||||
).update({"order_index": idx})
|
||||
|
||||
path.updated_at = _now()
|
||||
db.commit()
|
||||
|
||||
return (
|
||||
db.query(AttackPathStep)
|
||||
.filter(AttackPathStep.attack_path_id == path_id)
|
||||
.order_by(AttackPathStep.order_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
# ── Executions ────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_execution(
|
||||
db: Session, path_id: UUID, data: dict, user_id: UUID
|
||||
) -> AttackPathExecution:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
|
||||
execution = AttackPathExecution(
|
||||
attack_path_id=path_id,
|
||||
status=ExecutionStatus.planned,
|
||||
environment=data.get("environment"),
|
||||
red_team_lead=data.get("red_team_lead"),
|
||||
blue_team_lead=data.get("blue_team_lead"),
|
||||
notes=data.get("notes"),
|
||||
started_by=user_id,
|
||||
)
|
||||
db.add(execution)
|
||||
db.flush()
|
||||
|
||||
# Pre-create pending step results for every step in the path
|
||||
steps = (
|
||||
db.query(AttackPathStep)
|
||||
.filter(AttackPathStep.attack_path_id == path_id)
|
||||
.order_by(AttackPathStep.order_index)
|
||||
.all()
|
||||
)
|
||||
for step in steps:
|
||||
result = AttackPathStepResult(
|
||||
execution_id=execution.id,
|
||||
step_id=step.id,
|
||||
step_order=step.order_index,
|
||||
status=StepResultStatus.pending,
|
||||
)
|
||||
db.add(result)
|
||||
|
||||
db.commit()
|
||||
db.refresh(execution)
|
||||
|
||||
# Auto-add system timeline entry
|
||||
_add_system_entry(
|
||||
db, execution.id,
|
||||
entry_type=TimelineEntryType.phase_transition,
|
||||
content=f"Execution created for '{path.name}' with {len(steps)} steps.",
|
||||
)
|
||||
|
||||
audit_service.log_action(
|
||||
db, user_id, "ATTACK_PATH_EXECUTION_STARTED", "attack_path_execution",
|
||||
str(execution.id),
|
||||
details={"path_id": str(path_id), "path_name": path.name, "steps": len(steps)},
|
||||
)
|
||||
return execution
|
||||
|
||||
|
||||
def get_execution(db: Session, execution_id: UUID) -> AttackPathExecution:
|
||||
ex = (
|
||||
db.query(AttackPathExecution)
|
||||
.options(
|
||||
joinedload(AttackPathExecution.step_results),
|
||||
joinedload(AttackPathExecution.timeline),
|
||||
)
|
||||
.filter(AttackPathExecution.id == execution_id)
|
||||
.first()
|
||||
)
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
return ex
|
||||
|
||||
|
||||
def list_executions(db: Session, path_id: UUID) -> list[AttackPathExecution]:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
return (
|
||||
db.query(AttackPathExecution)
|
||||
.filter(AttackPathExecution.attack_path_id == path_id)
|
||||
.order_by(AttackPathExecution.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def start_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
if ex.status not in (ExecutionStatus.planned,):
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(400, "Execution is not in 'planned' state")
|
||||
ex.status = ExecutionStatus.in_progress
|
||||
ex.started_at = _now()
|
||||
db.commit()
|
||||
db.refresh(ex)
|
||||
_add_system_entry(db, execution_id, TimelineEntryType.phase_transition,
|
||||
"Execution started.", actor_id=user_id, actor_side=TimelineActorSide.system)
|
||||
return ex
|
||||
|
||||
|
||||
# ── Step Execution ────────────────────────────────────────────────────────────
|
||||
|
||||
def execute_step(
|
||||
db: Session,
|
||||
execution_id: UUID,
|
||||
step_id: UUID,
|
||||
data: dict,
|
||||
user_id: UUID,
|
||||
) -> AttackPathStepResult:
|
||||
"""Record the result of executing one step."""
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
if ex.status not in (ExecutionStatus.in_progress, ExecutionStatus.planned):
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(400, "Execution must be in_progress to record step results")
|
||||
|
||||
# Auto-start if still planned
|
||||
if ex.status == ExecutionStatus.planned:
|
||||
ex.status = ExecutionStatus.in_progress
|
||||
ex.started_at = _now()
|
||||
|
||||
result = (
|
||||
db.query(AttackPathStepResult)
|
||||
.filter(
|
||||
AttackPathStepResult.execution_id == execution_id,
|
||||
AttackPathStepResult.step_id == step_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not result:
|
||||
# Create on-the-fly if step was added after execution started
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
|
||||
if not step:
|
||||
raise EntityNotFoundError("AttackPathStep", str(step_id))
|
||||
result = AttackPathStepResult(
|
||||
execution_id=execution_id,
|
||||
step_id=step_id,
|
||||
step_order=step.order_index,
|
||||
)
|
||||
db.add(result)
|
||||
|
||||
now = _now()
|
||||
new_status = StepResultStatus(data["status"])
|
||||
result.status = new_status
|
||||
result.executed_by = user_id
|
||||
result.executed_at = data.get("executed_at") or now
|
||||
result.notes = data.get("notes")
|
||||
result.evidence_ids = [str(e) for e in (data.get("evidence_ids") or [])]
|
||||
result.detection_asset_id = data.get("detection_asset_id")
|
||||
|
||||
if new_status == StepResultStatus.detected:
|
||||
result.detected_at = data.get("detected_at") or now
|
||||
if result.executed_at:
|
||||
delta = (result.detected_at - result.executed_at).total_seconds()
|
||||
result.time_to_detect_seconds = max(0.0, delta)
|
||||
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
|
||||
# Add timeline entry
|
||||
step_obj = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
|
||||
step_name = step_obj.name or (step_obj.kill_chain_phase or "Unknown step")
|
||||
actor_side = TimelineActorSide.red if new_status != StepResultStatus.detected else TimelineActorSide.blue
|
||||
entry_type = (
|
||||
TimelineEntryType.detection if new_status == StepResultStatus.detected
|
||||
else TimelineEntryType.action
|
||||
)
|
||||
content = (
|
||||
f"Step '{step_name}' marked as {new_status.value}."
|
||||
+ (f" Detected in {result.time_to_detect_seconds:.0f}s." if result.time_to_detect_seconds else "")
|
||||
)
|
||||
_add_system_entry(
|
||||
db, execution_id, entry_type, content,
|
||||
actor_id=user_id, actor_side=actor_side, step_id=step_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ── Completion & Metrics ──────────────────────────────────────────────────────
|
||||
|
||||
def complete_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
|
||||
"""Mark execution complete and compute all kill-chain metrics."""
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
|
||||
results = (
|
||||
db.query(AttackPathStepResult)
|
||||
.filter(AttackPathStepResult.execution_id == execution_id)
|
||||
.order_by(AttackPathStepResult.step_order)
|
||||
.all()
|
||||
)
|
||||
|
||||
total = len(results)
|
||||
detected = sum(1 for r in results if r.status == StepResultStatus.detected)
|
||||
not_detected = sum(1 for r in results if r.status == StepResultStatus.not_detected)
|
||||
skipped = sum(1 for r in results if r.status == StepResultStatus.skipped)
|
||||
|
||||
detection_rate = (detected / total) if total > 0 else 0.0
|
||||
|
||||
ttds = [r.time_to_detect_seconds for r in results
|
||||
if r.time_to_detect_seconds is not None]
|
||||
mttd = (sum(ttds) / len(ttds)) if ttds else None
|
||||
|
||||
# Furthest undetected step (highest order_index with not_detected status)
|
||||
undetected = [r for r in results if r.status == StepResultStatus.not_detected]
|
||||
furthest = max((r.step_order for r in undetected), default=None)
|
||||
|
||||
ex.status = ExecutionStatus.completed
|
||||
ex.completed_at = _now()
|
||||
ex.total_steps = total
|
||||
ex.detected_steps = detected
|
||||
ex.not_detected_steps = not_detected
|
||||
ex.skipped_steps = skipped
|
||||
ex.detection_rate = round(detection_rate, 4)
|
||||
ex.mttd_seconds = round(mttd, 1) if mttd is not None else None
|
||||
ex.furthest_undetected_step = furthest
|
||||
|
||||
db.commit()
|
||||
db.refresh(ex)
|
||||
|
||||
_add_system_entry(
|
||||
db, execution_id, TimelineEntryType.phase_transition,
|
||||
f"Execution completed. Detection rate: {detection_rate:.0%}. "
|
||||
f"Detected {detected}/{total} steps. "
|
||||
+ (f"MTTD: {mttd:.0f}s." if mttd else ""),
|
||||
actor_id=user_id, actor_side=TimelineActorSide.system,
|
||||
)
|
||||
|
||||
audit_service.log_action(
|
||||
db, user_id, "ATTACK_PATH_EXECUTION_COMPLETED", "attack_path_execution",
|
||||
str(execution_id),
|
||||
details={"detection_rate": detection_rate, "mttd_seconds": mttd,
|
||||
"detected": detected, "total": total},
|
||||
)
|
||||
return ex
|
||||
|
||||
|
||||
def abort_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
ex.status = ExecutionStatus.aborted
|
||||
ex.completed_at = _now()
|
||||
db.commit()
|
||||
db.refresh(ex)
|
||||
_add_system_entry(db, execution_id, TimelineEntryType.flag, "Execution aborted.",
|
||||
actor_id=user_id, actor_side=TimelineActorSide.system)
|
||||
return ex
|
||||
|
||||
|
||||
# ── Timeline ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def add_timeline_entry(
|
||||
db: Session, execution_id: UUID, data: dict, user_id: UUID
|
||||
) -> TimelineEntry:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
|
||||
entry = TimelineEntry(
|
||||
execution_id=execution_id,
|
||||
step_id=data.get("step_id"),
|
||||
timestamp=data.get("timestamp") or _now(),
|
||||
actor_side=TimelineActorSide(data["actor_side"]),
|
||||
actor_id=user_id,
|
||||
entry_type=TimelineEntryType(data["entry_type"]),
|
||||
content=data["content"],
|
||||
extra=data.get("extra"),
|
||||
)
|
||||
db.add(entry)
|
||||
db.commit()
|
||||
db.refresh(entry)
|
||||
return entry
|
||||
|
||||
|
||||
def get_timeline(db: Session, execution_id: UUID) -> list[TimelineEntry]:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
return (
|
||||
db.query(TimelineEntry)
|
||||
.filter(TimelineEntry.execution_id == execution_id)
|
||||
.order_by(TimelineEntry.timestamp.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
# ── Kill-Chain Metrics ────────────────────────────────────────────────────────
|
||||
|
||||
def get_kill_chain_metrics(db: Session, execution_id: UUID) -> dict:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
|
||||
results = (
|
||||
db.query(AttackPathStepResult)
|
||||
.filter(AttackPathStepResult.execution_id == execution_id)
|
||||
.order_by(AttackPathStepResult.step_order)
|
||||
.all()
|
||||
)
|
||||
|
||||
step_breakdown = []
|
||||
phase_detected: dict[str, list] = {}
|
||||
|
||||
for r in results:
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == r.step_id).first()
|
||||
phase = step.kill_chain_phase if step else None
|
||||
entry = {
|
||||
"step_id": str(r.step_id),
|
||||
"step_order": r.step_order,
|
||||
"step_name": step.name if step else None,
|
||||
"kill_chain_phase": phase,
|
||||
"status": r.status.value if hasattr(r.status, "value") else r.status,
|
||||
"executed_at": r.executed_at.isoformat() if r.executed_at else None,
|
||||
"detected_at": r.detected_at.isoformat() if r.detected_at else None,
|
||||
"time_to_detect_seconds": r.time_to_detect_seconds,
|
||||
"detection_asset_id": str(r.detection_asset_id) if r.detection_asset_id else None,
|
||||
}
|
||||
step_breakdown.append(entry)
|
||||
if phase:
|
||||
phase_detected.setdefault(phase, []).append(
|
||||
r.status == StepResultStatus.detected
|
||||
)
|
||||
|
||||
phase_summary = {
|
||||
phase: {
|
||||
"total": len(v),
|
||||
"detected": sum(v),
|
||||
"detection_rate": round(sum(v) / len(v), 3) if v else 0.0,
|
||||
}
|
||||
for phase, v in phase_detected.items()
|
||||
}
|
||||
|
||||
# Furthest undetected phase
|
||||
furthest_undetected_phase = None
|
||||
if ex.furthest_undetected_step is not None:
|
||||
for r in reversed(results):
|
||||
if r.step_order == ex.furthest_undetected_step:
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == r.step_id).first()
|
||||
if step:
|
||||
furthest_undetected_phase = step.kill_chain_phase
|
||||
break
|
||||
|
||||
return {
|
||||
"execution_id": str(execution_id),
|
||||
"total_steps": ex.total_steps or len(results),
|
||||
"detected_steps": ex.detected_steps or 0,
|
||||
"not_detected_steps": ex.not_detected_steps or 0,
|
||||
"skipped_steps": ex.skipped_steps or 0,
|
||||
"detection_rate": ex.detection_rate or 0.0,
|
||||
"mttd_seconds": ex.mttd_seconds,
|
||||
"furthest_undetected_step": ex.furthest_undetected_step,
|
||||
"furthest_undetected_phase": furthest_undetected_phase,
|
||||
"step_breakdown": step_breakdown,
|
||||
"phase_summary": phase_summary,
|
||||
}
|
||||
|
||||
|
||||
# ── Helper ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _add_system_entry(
|
||||
db: Session,
|
||||
execution_id: UUID,
|
||||
entry_type: TimelineEntryType,
|
||||
content: str,
|
||||
actor_id: Optional[UUID] = None,
|
||||
actor_side: TimelineActorSide = TimelineActorSide.system,
|
||||
step_id: Optional[UUID] = None,
|
||||
) -> None:
|
||||
entry = TimelineEntry(
|
||||
execution_id=execution_id,
|
||||
step_id=step_id,
|
||||
timestamp=_now(),
|
||||
actor_side=actor_side,
|
||||
actor_id=actor_id,
|
||||
entry_type=entry_type,
|
||||
content=content,
|
||||
)
|
||||
db.add(entry)
|
||||
db.commit()
|
||||
@@ -0,0 +1,798 @@
|
||||
"""ATT&CK Evaluations importer — fetches real CrowdStrike detection results
|
||||
from MITRE Engenuity's public API and seeds the platform with validated tests.
|
||||
|
||||
Data source
|
||||
-----------
|
||||
https://evals.mitre.org/api/
|
||||
- /participants/ → list of vendors + rounds they completed
|
||||
- /results/?participant=crowdstrike&domain=ENTERPRISE
|
||||
→ per-substep detection results per adversary
|
||||
|
||||
Detection level mapping (MITRE → Aegis)
|
||||
---------------------------------------
|
||||
Technique / Specific Behavior → detected (correctly identified ATT&CK technique)
|
||||
Tactic → partially_detected (behavior noted but not categorized)
|
||||
General / IOC / MSSP → partially_detected (anomaly detected, not ATT&CK-mapped)
|
||||
Telemetry → partially_detected (raw data only — marginal detection)
|
||||
None / N/A → not_detected
|
||||
|
||||
All imported tests are created in ``in_review`` state so Blue Leads must
|
||||
confirm each result before it counts as real coverage for the organisation.
|
||||
|
||||
Important caveats stored in every test's description
|
||||
------------------------------------------------------
|
||||
"Source: MITRE ATT&CK Evaluation (Round N — Adversary). Results reflect
|
||||
CrowdStrike Falcon in a controlled lab environment, NOT this organisation's
|
||||
deployment. Validate detection in your own environment before approving."
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.enums import TestState, TestResult
|
||||
from app.models.evaluation_import import EvaluationImport
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BASE = "https://evals.mitre.org"
|
||||
_TIMEOUT = 30 # seconds per HTTP call
|
||||
_VENDOR = "crowdstrike"
|
||||
_DOMAIN = "ENTERPRISE"
|
||||
|
||||
# Browser-like headers to bypass Cloudflare bot protection on evals.mitre.org
|
||||
_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/124.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
"Referer": "https://evals.mitre.org/",
|
||||
"Origin": "https://evals.mitre.org",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback: hardcoded public CrowdStrike ENTERPRISE rounds
|
||||
# Used when evals.mitre.org API is unreachable (Cloudflare 502, outage, etc.)
|
||||
#
|
||||
# Names use the EXACT slugs the live API returns (hyphens, not underscores).
|
||||
# Verified from live API response on 2025-06-05.
|
||||
# CrowdStrike did NOT participate in Round 6 (OilRig) — not included.
|
||||
# ---------------------------------------------------------------------------
|
||||
_FALLBACK_ROUNDS: list[dict[str, Any]] = [
|
||||
{
|
||||
"name": "apt3",
|
||||
"display_name": "APT3",
|
||||
"eval_round": 1,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "apt29",
|
||||
"display_name": "APT29",
|
||||
"eval_round": 2,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "carbanak-fin7",
|
||||
"display_name": "Carbanak+FIN7",
|
||||
"eval_round": 3,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "wizard-spider-sandworm",
|
||||
"display_name": "Wizard Spider + Sandworm",
|
||||
"eval_round": 4,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "turla",
|
||||
"display_name": "Turla",
|
||||
"eval_round": 5,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "er7",
|
||||
"display_name": "Enterprise 2025",
|
||||
"eval_round": 7,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
]
|
||||
|
||||
# Detection type → quality score (higher = better)
|
||||
_DETECTION_SCORE: dict[str, int] = {
|
||||
"none": 0,
|
||||
"n/a": 0,
|
||||
"telemetry": 1,
|
||||
"mssp": 2,
|
||||
"general": 2,
|
||||
"ioc": 2,
|
||||
"tactic": 3,
|
||||
"technique": 4,
|
||||
"specific behavior": 4,
|
||||
}
|
||||
|
||||
|
||||
def _score(detection_type: str) -> int:
|
||||
key = (detection_type or "").lower().strip()
|
||||
for pattern, score in _DETECTION_SCORE.items():
|
||||
if pattern in key:
|
||||
return score
|
||||
return 0
|
||||
|
||||
|
||||
def _score_to_result(score: int) -> TestResult:
|
||||
if score >= 4:
|
||||
return TestResult.detected
|
||||
if score >= 1:
|
||||
return TestResult.partially_detected
|
||||
return TestResult.not_detected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def fetch_rounds_with_status() -> dict[str, Any]:
|
||||
"""Fetch CrowdStrike ENTERPRISE rounds and report whether the live API was reachable.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"rounds": [{"name": ..., "display_name": ..., "eval_round": ...}, ...],
|
||||
"api_reachable": True | False,
|
||||
"api_error": None | "<error message>",
|
||||
}
|
||||
"""
|
||||
try:
|
||||
session = requests.Session()
|
||||
session.headers.update(_HEADERS)
|
||||
resp = session.get(f"{_BASE}/api/participants/", timeout=_TIMEOUT)
|
||||
resp.raise_for_status()
|
||||
participants = resp.json()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"evals.mitre.org API unreachable (%s) — using hardcoded fallback round list.",
|
||||
exc,
|
||||
)
|
||||
return {
|
||||
"rounds": list(_FALLBACK_ROUNDS),
|
||||
"api_reachable": False,
|
||||
"api_error": str(exc),
|
||||
}
|
||||
|
||||
crowdstrike = next(
|
||||
(p for p in participants if p.get("name", "").lower() == _VENDOR),
|
||||
None,
|
||||
)
|
||||
if not crowdstrike:
|
||||
logger.warning("Vendor '%s' not found in live data — using fallback.", _VENDOR)
|
||||
return {
|
||||
"rounds": list(_FALLBACK_ROUNDS),
|
||||
"api_reachable": True, # API was reachable, vendor just wasn't listed
|
||||
"api_error": f"Vendor '{_VENDOR}' not found in participants list",
|
||||
}
|
||||
|
||||
rounds = [
|
||||
adv
|
||||
for adv in crowdstrike.get("adversaries_completed", [])
|
||||
if adv.get("domain", "").upper() == _DOMAIN
|
||||
and adv.get("status", "").upper() == "PUBLIC"
|
||||
]
|
||||
rounds.sort(key=lambda x: x.get("eval_round", 0))
|
||||
return {
|
||||
"rounds": rounds if rounds else list(_FALLBACK_ROUNDS),
|
||||
"api_reachable": True,
|
||||
"api_error": None,
|
||||
}
|
||||
|
||||
|
||||
def fetch_available_rounds() -> list[dict[str, Any]]:
|
||||
"""Return all evaluation rounds CrowdStrike has completed (ENTERPRISE only).
|
||||
|
||||
Each dict has: name, display_name, eval_round.
|
||||
Sorted by eval_round ascending.
|
||||
|
||||
Falls back to ``_FALLBACK_ROUNDS`` if the live API is unreachable.
|
||||
"""
|
||||
return fetch_rounds_with_status()["rounds"]
|
||||
|
||||
|
||||
def get_latest_round() -> dict[str, Any]:
|
||||
"""Return the most recent PUBLIC ENTERPRISE round CrowdStrike participated in."""
|
||||
rounds = fetch_available_rounds()
|
||||
if not rounds:
|
||||
raise ValueError("No public Enterprise evaluation rounds found for CrowdStrike")
|
||||
return rounds[-1]
|
||||
|
||||
|
||||
def fetch_results_for_adversary(adversary_name: str) -> list[dict[str, Any]]:
|
||||
"""Fetch all per-substep detection results for a specific adversary round.
|
||||
|
||||
Returns a flat list of substep dicts, each containing:
|
||||
technique_id, technique_name, tactic_id, best_score, detection_type, note.
|
||||
"""
|
||||
url = f"{_BASE}/api/results/?participant={_VENDOR}&domain={_DOMAIN}"
|
||||
try:
|
||||
session = requests.Session()
|
||||
session.headers.update(_HEADERS)
|
||||
resp = session.get(url, timeout=_TIMEOUT)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except Exception as exc:
|
||||
logger.error("Failed to fetch ATT&CK Evaluations results: %s", exc)
|
||||
raise
|
||||
|
||||
# The results endpoint returns a LIST of vendor objects:
|
||||
# [{"name": "crowdstrike", "adversaries": [{"Adversary_Name": "apt3", ...}, ...]}, ...]
|
||||
# (not a dict — hence the explicit vendor lookup below)
|
||||
if isinstance(data, list):
|
||||
vendor_entry = next(
|
||||
(v for v in data if isinstance(v, dict) and v.get("name", "").lower() == _VENDOR),
|
||||
None,
|
||||
)
|
||||
if not vendor_entry:
|
||||
raise ValueError(
|
||||
f"Vendor '{_VENDOR}' not found in results response. "
|
||||
f"Available vendors: {[v.get('name') for v in data if isinstance(v, dict)]}"
|
||||
)
|
||||
adversaries = vendor_entry.get("adversaries", [])
|
||||
else:
|
||||
# Fallback for legacy dict-shaped response (just in case API changes again)
|
||||
adversaries = data.get("adversaries", [])
|
||||
|
||||
target = next(
|
||||
(a for a in adversaries if a.get("Adversary_Name", "").lower() == adversary_name.lower()),
|
||||
None,
|
||||
)
|
||||
if not target:
|
||||
raise ValueError(
|
||||
f"Adversary '{adversary_name}' not found in results. "
|
||||
f"Available: {[a.get('Adversary_Name') for a in adversaries]}"
|
||||
)
|
||||
|
||||
substeps: list[dict[str, Any]] = []
|
||||
|
||||
scenarios = target.get("Detections_By_Step", {})
|
||||
for scenario_name, scenario_data in scenarios.items():
|
||||
for step in scenario_data.get("Steps", []):
|
||||
step_num = step.get("Step_Num", "")
|
||||
step_name = step.get("Step_Name", "")
|
||||
# Strip HTML tags from the Step.Description narrative
|
||||
step_desc_raw = step.get("Description") or ""
|
||||
step_description = re.sub(r"<[^>]+>", " ", step_desc_raw)
|
||||
step_description = re.sub(r"\s+", " ", step_description).strip()
|
||||
|
||||
for substep in step.get("Substeps", []):
|
||||
# Prefer sub-technique over technique
|
||||
sub = substep.get("Subtechnique") or {}
|
||||
tech = substep.get("Technique") or {}
|
||||
tactic = substep.get("Tactic") or {}
|
||||
|
||||
technique_id = (
|
||||
sub.get("Subtechnique_Id")
|
||||
or tech.get("Technique_Id")
|
||||
or ""
|
||||
).strip()
|
||||
technique_name = (
|
||||
sub.get("Subtechnique_Name")
|
||||
or tech.get("Technique_Name")
|
||||
or "Unknown"
|
||||
).strip()
|
||||
|
||||
if not technique_id:
|
||||
continue
|
||||
|
||||
detections = substep.get("Detections", [])
|
||||
best_score = 0
|
||||
best_type = "None"
|
||||
best_note = ""
|
||||
for det in detections:
|
||||
dtype = det.get("Detection_Type", "None")
|
||||
s = _score(dtype)
|
||||
if s > best_score:
|
||||
best_score = s
|
||||
best_type = dtype
|
||||
best_note = det.get("Detection_Note", "")
|
||||
|
||||
# Collect all unique data sources from screenshots across all detections
|
||||
data_sources: list[str] = sorted({
|
||||
src
|
||||
for det in detections
|
||||
for sc in det.get("Screenshots", [])
|
||||
for src in sc.get("Data_Sources", [])
|
||||
})
|
||||
|
||||
substeps.append(
|
||||
{
|
||||
"technique_id": technique_id,
|
||||
"technique_name": technique_name,
|
||||
"tactic_id": tactic.get("Tactic_Id", ""),
|
||||
"tactic_name": tactic.get("Tactic_Name", ""),
|
||||
"best_score": best_score,
|
||||
"detection_type": best_type,
|
||||
"note": best_note,
|
||||
# Enrichment fields from the API
|
||||
"scenario_name": scenario_name,
|
||||
"step_num": step_num,
|
||||
"step_name": step_name,
|
||||
"step_description": step_description,
|
||||
"substep_ref": substep.get("Substep", ""),
|
||||
"criteria": (substep.get("Criteria") or "").strip(),
|
||||
"data_sources": data_sources,
|
||||
}
|
||||
)
|
||||
|
||||
return substeps
|
||||
|
||||
|
||||
def _aggregate_by_technique(substeps: list[dict]) -> dict[str, dict]:
|
||||
"""Aggregate substep results per technique.
|
||||
|
||||
- Deduplicates substeps by (substep_ref, criteria) — prevents duplicates
|
||||
that arise when adversaries with multiple scenarios (e.g. Wizard Spider +
|
||||
Sandworm) repeat the same substep across a "combined" replay scenario.
|
||||
- Groups unique occurrences by scenario_name so the narrative can show
|
||||
"Wizard Spider scenario" vs "Sandworm scenario" separately.
|
||||
- Tracks best detection score across all unique substeps.
|
||||
"""
|
||||
by_technique: dict[str, dict] = {}
|
||||
|
||||
for sub in substeps:
|
||||
tid = sub["technique_id"]
|
||||
if tid not in by_technique:
|
||||
by_technique[tid] = {
|
||||
**sub,
|
||||
"occurrences": [], # flat list of unique occurrences
|
||||
"_seen_keys": set(), # (substep_ref, criteria) dedup set
|
||||
}
|
||||
|
||||
agg = by_technique[tid]
|
||||
|
||||
# Deduplication key: same substep_ref + same criteria text = duplicate
|
||||
dedup_key = (sub["substep_ref"], sub["criteria"])
|
||||
if dedup_key in agg["_seen_keys"]:
|
||||
continue
|
||||
agg["_seen_keys"].add(dedup_key)
|
||||
|
||||
agg["occurrences"].append({
|
||||
"scenario_name": sub["scenario_name"],
|
||||
"substep_ref": sub["substep_ref"],
|
||||
"step_num": sub["step_num"],
|
||||
"step_name": sub["step_name"],
|
||||
"step_description": sub["step_description"],
|
||||
"criteria": sub["criteria"],
|
||||
"data_sources": sub["data_sources"],
|
||||
"detection_type": sub["detection_type"],
|
||||
"best_score": sub["best_score"],
|
||||
"note": sub["note"],
|
||||
})
|
||||
|
||||
# Promote best detection score
|
||||
if sub["best_score"] > agg["best_score"]:
|
||||
agg["best_score"] = sub["best_score"]
|
||||
agg["detection_type"] = sub["detection_type"]
|
||||
agg["note"] = sub["note"]
|
||||
agg["tactic_id"] = sub["tactic_id"]
|
||||
agg["tactic_name"] = sub["tactic_name"]
|
||||
|
||||
# Clean up internal dedup sets before returning
|
||||
for agg in by_technique.values():
|
||||
agg.pop("_seen_keys", None)
|
||||
|
||||
return by_technique
|
||||
|
||||
|
||||
def _group_occurrences_by_scenario(occurrences: list[dict]) -> dict[str, list[dict]]:
|
||||
"""Group a technique's occurrences by scenario, preserving insertion order."""
|
||||
grouped: dict[str, list[dict]] = {}
|
||||
for occ in occurrences:
|
||||
sc = occ.get("scenario_name", "Scenario_1")
|
||||
grouped.setdefault(sc, []).append(occ)
|
||||
return grouped
|
||||
|
||||
|
||||
def _build_procedure_text(agg: dict, adversary_display: str, eval_round: int) -> str:
|
||||
"""Build a rich attack-path narrative for the Test.procedure_text field.
|
||||
|
||||
Groups substeps by scenario so adversaries with multiple threat groups
|
||||
(e.g. Wizard Spider + Sandworm with 3 scenarios) are clearly separated.
|
||||
Includes Step.Description narrative for context.
|
||||
"""
|
||||
occurrences = agg.get("occurrences", [])
|
||||
if not occurrences:
|
||||
return (
|
||||
f"MITRE ATT&CK Evaluation simulation using {adversary_display} TTPs. "
|
||||
f"See evaluation report at https://evals.mitre.org for full details."
|
||||
)
|
||||
|
||||
lines: list[str] = [f"ATT&CK Evaluation R{eval_round} — {adversary_display}", ""]
|
||||
|
||||
grouped = _group_occurrences_by_scenario(occurrences)
|
||||
scenario_count = len(grouped)
|
||||
|
||||
for sc_name, sc_occs in grouped.items():
|
||||
# Scenario header — only shown when there are multiple scenarios
|
||||
if scenario_count > 1:
|
||||
idx = sc_name.replace("Scenario_", "Scenario ")
|
||||
lines.append(f"=== {idx} ===")
|
||||
|
||||
# Within each scenario, group by step to emit description once per step
|
||||
seen_step_descs: set = set()
|
||||
for occ in sc_occs:
|
||||
step_num = occ.get("step_num", "")
|
||||
step_name = occ.get("step_name", "")
|
||||
step_desc = occ.get("step_description", "")
|
||||
# Use (step_num or step_name) as dedup key for descriptions
|
||||
step_key = str(step_num) if step_num else step_name
|
||||
|
||||
if step_key and step_key not in seen_step_descs:
|
||||
seen_step_descs.add(step_key)
|
||||
header = f"Step {step_num} — {step_name}:" if step_num else f"— {step_name}:"
|
||||
lines.append(header)
|
||||
if step_desc:
|
||||
truncated = step_desc[:450] + ("…" if len(step_desc) > 450 else "")
|
||||
lines.append(truncated)
|
||||
|
||||
ref = occ.get("substep_ref", "")
|
||||
criteria = occ.get("criteria", "")
|
||||
det = occ.get("detection_type", "")
|
||||
if criteria:
|
||||
tag = f" [{ref}]" if ref else " •"
|
||||
det_tag = f" [{det}]" if det and det.lower() not in ("none", "") else ""
|
||||
lines.append(f"{tag}{det_tag} {criteria}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines).rstrip()
|
||||
|
||||
|
||||
def _build_description(agg: dict, adversary_display: str, eval_round: int) -> str:
|
||||
"""Build Test.description with source metadata, detection guidance and warning.
|
||||
|
||||
The 'criteria' field from the MITRE API describes what each substep does AND
|
||||
what should be detected, so it doubles as blue-team detection guidance.
|
||||
"""
|
||||
occurrences = agg.get("occurrences", [])
|
||||
|
||||
# Collect all unique data sources across every unique occurrence
|
||||
all_data_sources: list[str] = sorted({
|
||||
src
|
||||
for occ in occurrences
|
||||
for src in occ.get("data_sources", [])
|
||||
})
|
||||
|
||||
header = (
|
||||
f"Source: MITRE ATT&CK Evaluation — Round {eval_round} ({adversary_display}).\n"
|
||||
f"Vendor: CrowdStrike Falcon.\n"
|
||||
f"Detection type achieved: {agg['detection_type']}."
|
||||
)
|
||||
|
||||
ds_section = ""
|
||||
if all_data_sources:
|
||||
ds_section = "\n\nData sources observed:\n" + "\n".join(
|
||||
f" • {ds}" for ds in all_data_sources
|
||||
)
|
||||
|
||||
# Detection guidance — what criteria were observed (blue team can use these as IOCs)
|
||||
det_lines: list[str] = []
|
||||
grouped = _group_occurrences_by_scenario(occurrences)
|
||||
for sc_name, sc_occs in grouped.items():
|
||||
scenario_label = f"[{sc_name}] " if len(grouped) > 1 else ""
|
||||
for occ in sc_occs:
|
||||
ref = occ.get("substep_ref", "")
|
||||
step_name = occ.get("step_name", "")
|
||||
criteria = occ.get("criteria", "")
|
||||
det_type = occ.get("detection_type", "")
|
||||
if criteria:
|
||||
label = f"[{ref}]" if ref else "•"
|
||||
step_label = f" ({step_name})" if step_name else ""
|
||||
det_label = f" — {det_type}" if det_type and det_type.lower() not in ("none", "") else ""
|
||||
det_lines.append(f" {scenario_label}{label}{step_label}{det_label}: {criteria}")
|
||||
|
||||
det_section = ""
|
||||
if det_lines:
|
||||
det_section = "\n\nDetection criteria (what to look for):\n" + "\n".join(det_lines)
|
||||
|
||||
warning = (
|
||||
f"\n\n⚠️ IMPORTANT: These results reflect CrowdStrike Falcon performance in a "
|
||||
f"controlled MITRE lab environment against a simulated {adversary_display} "
|
||||
f"adversary. They do NOT represent your organisation's actual detection "
|
||||
f"capability. Validate in your own environment before approving."
|
||||
)
|
||||
|
||||
note_section = f"\n\nMITRE note: {agg['note']}" if agg.get("note") else ""
|
||||
|
||||
return header + ds_section + det_section + warning + note_section
|
||||
|
||||
|
||||
def _build_red_summary(agg: dict, adversary_display: str, eval_round: int) -> str:
|
||||
"""Build the Red Team summary for the Test.red_summary field."""
|
||||
occurrences = agg.get("occurrences", [])
|
||||
|
||||
lines = [
|
||||
f"MITRE ATT&CK Evaluation — Round {eval_round} ({adversary_display})",
|
||||
f"Vendor: CrowdStrike Falcon",
|
||||
f"Best detection level: {agg['detection_type']}",
|
||||
f"Tactic: {agg['tactic_name']} ({agg['tactic_id']})",
|
||||
f"Unique substeps: {len(occurrences)}",
|
||||
]
|
||||
|
||||
if occurrences:
|
||||
lines.append("")
|
||||
grouped = _group_occurrences_by_scenario(occurrences)
|
||||
for sc_name, sc_occs in grouped.items():
|
||||
if len(grouped) > 1:
|
||||
lines.append(f"{sc_name}:")
|
||||
for occ in sc_occs:
|
||||
ref = occ.get("substep_ref", "")
|
||||
criteria = occ.get("criteria", "")
|
||||
step_name = occ.get("step_name", "")
|
||||
det = occ.get("detection_type", "")
|
||||
if criteria:
|
||||
tag = f" [{ref}]" if ref else " •"
|
||||
step_tag = f" {step_name} —" if step_name else ""
|
||||
det_tag = f" [{det}]" if det and det.lower() not in ("none", "") else ""
|
||||
lines.append(f"{tag}{step_tag}{det_tag} {criteria}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main import function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def import_evaluation_round(
|
||||
db: Session,
|
||||
adversary_name: str,
|
||||
adversary_display: str,
|
||||
eval_round: int,
|
||||
current_user: User,
|
||||
) -> dict[str, Any]:
|
||||
"""Import a single ATT&CK Evaluation round for CrowdStrike into the platform.
|
||||
|
||||
Creates one Test per unique technique with the best detection result
|
||||
observed across all substeps for that technique. All tests land in
|
||||
``in_review`` state — Blue Leads must confirm before they count as coverage.
|
||||
|
||||
Returns a summary dict: created, skipped, techniques_covered.
|
||||
Raises if the round was already imported (idempotency guard).
|
||||
"""
|
||||
# Idempotency — refuse duplicate imports
|
||||
existing = (
|
||||
db.query(EvaluationImport)
|
||||
.filter(
|
||||
EvaluationImport.adversary_name == adversary_name.lower(),
|
||||
EvaluationImport.status == "completed",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(
|
||||
f"Round '{adversary_display}' (round {eval_round}) was already imported "
|
||||
f"on {existing.imported_at.date()}. Re-import is not allowed."
|
||||
)
|
||||
|
||||
# Fetch and aggregate substep results
|
||||
substeps = fetch_results_for_adversary(adversary_name)
|
||||
by_technique = _aggregate_by_technique(substeps)
|
||||
|
||||
created = 0
|
||||
skipped = 0
|
||||
affected_technique_ids: set = set()
|
||||
|
||||
for mitre_id, agg in by_technique.items():
|
||||
# Look up the technique in our DB
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == mitre_id.upper())
|
||||
.first()
|
||||
)
|
||||
if technique is None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
detection_result = _score_to_result(agg["best_score"])
|
||||
|
||||
description = _build_description(agg, adversary_display, eval_round)
|
||||
red_summary = _build_red_summary(agg, adversary_display, eval_round)
|
||||
procedure_text = _build_procedure_text(agg, adversary_display, eval_round)
|
||||
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name=f"[EVAL R{eval_round}] {adversary_display} — {technique.name}",
|
||||
description=description,
|
||||
platform=None,
|
||||
procedure_text=procedure_text,
|
||||
created_by=current_user.id,
|
||||
state=TestState.in_review,
|
||||
attack_success=True,
|
||||
red_summary=red_summary,
|
||||
red_validation_status="approved",
|
||||
red_validated_by=current_user.id,
|
||||
red_validated_at=datetime.utcnow(),
|
||||
detection_result=detection_result,
|
||||
blue_validation_status=None,
|
||||
execution_date=datetime.utcnow(),
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="eval_import_test",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={
|
||||
"adversary": adversary_name,
|
||||
"eval_round": eval_round,
|
||||
"mitre_id": mitre_id,
|
||||
"detection_type": agg["detection_type"],
|
||||
},
|
||||
)
|
||||
|
||||
affected_technique_ids.add(technique.id)
|
||||
created += 1
|
||||
|
||||
# Recalculate coverage for all touched techniques
|
||||
for tech_id in affected_technique_ids:
|
||||
tech = db.query(Technique).filter(Technique.id == tech_id).first()
|
||||
if tech:
|
||||
recalculate_technique_status(db, tech)
|
||||
|
||||
# Record the import
|
||||
record = EvaluationImport(
|
||||
id=uuid.uuid4(),
|
||||
adversary_name=adversary_name.lower(),
|
||||
adversary_display=adversary_display,
|
||||
eval_round=eval_round,
|
||||
imported_at=datetime.utcnow(),
|
||||
imported_by=current_user.id,
|
||||
tests_created=created,
|
||||
techniques_covered=len(affected_technique_ids),
|
||||
status="completed",
|
||||
notes=f"Skipped {skipped} techniques not found in local DB.",
|
||||
)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
"ATT&CK Evaluation import complete — round %d (%s): %d tests created, %d skipped",
|
||||
eval_round, adversary_display, created, skipped,
|
||||
)
|
||||
return {
|
||||
"created": created,
|
||||
"skipped": skipped,
|
||||
"techniques_covered": len(affected_technique_ids),
|
||||
"adversary": adversary_display,
|
||||
"eval_round": eval_round,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# New-round check (used by the weekly scheduler)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def check_for_new_round(db: Session) -> dict[str, Any]:
|
||||
"""Check if a new evaluation round is available that hasn't been imported yet.
|
||||
|
||||
Returns:
|
||||
{"new_round_available": bool, "latest_round": dict | None, "already_imported": bool}
|
||||
"""
|
||||
try:
|
||||
latest = get_latest_round()
|
||||
except Exception as exc:
|
||||
logger.warning("Could not check for new ATT&CK Evaluation round: %s", exc)
|
||||
return {"new_round_available": False, "latest_round": None, "error": str(exc)}
|
||||
|
||||
already = (
|
||||
db.query(EvaluationImport)
|
||||
.filter(
|
||||
EvaluationImport.adversary_name == latest["name"].lower(),
|
||||
EvaluationImport.status == "completed",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return {
|
||||
"new_round_available": already is None,
|
||||
"already_imported": already is not None,
|
||||
"latest_round": {
|
||||
"name": latest["name"],
|
||||
"display_name": latest.get("display_name", latest["name"]),
|
||||
"eval_round": latest["eval_round"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Re-enrich existing tests with richer API data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def re_enrich_evaluation_round(
|
||||
db: Session,
|
||||
adversary_name: str,
|
||||
adversary_display: str,
|
||||
eval_round: int,
|
||||
current_user: User,
|
||||
) -> dict[str, Any]:
|
||||
"""Update procedure_text / description / red_summary on already-imported tests
|
||||
for a given round using the enriched API data (attack path, criteria, data sources).
|
||||
|
||||
This is non-destructive — it only updates the three narrative fields and does
|
||||
not change detection results, state, or validation status.
|
||||
"""
|
||||
# Fetch & aggregate (same flow as import)
|
||||
substeps = fetch_results_for_adversary(adversary_name)
|
||||
by_technique = _aggregate_by_technique(substeps)
|
||||
|
||||
updated = 0
|
||||
skipped = 0
|
||||
|
||||
for mitre_id, agg in by_technique.items():
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == mitre_id.upper())
|
||||
.first()
|
||||
)
|
||||
if technique is None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Find the existing test for this round + technique
|
||||
existing_test = (
|
||||
db.query(Test)
|
||||
.filter(
|
||||
Test.technique_id == technique.id,
|
||||
Test.name.like(f"[EVAL R{eval_round}]%"),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not existing_test:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
existing_test.description = _build_description(agg, adversary_display, eval_round)
|
||||
existing_test.red_summary = _build_red_summary(agg, adversary_display, eval_round)
|
||||
existing_test.procedure_text = _build_procedure_text(agg, adversary_display, eval_round)
|
||||
updated += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
"Re-enrichment complete — round %d (%s): %d tests updated, %d skipped",
|
||||
eval_round, adversary_display, updated, skipped,
|
||||
)
|
||||
return {
|
||||
"updated": updated,
|
||||
"skipped": skipped,
|
||||
"adversary": adversary_display,
|
||||
"eval_round": eval_round,
|
||||
"message": (
|
||||
f"Re-enriched {updated} tests for {adversary_display} (Round {eval_round}) "
|
||||
f"with attack path, criteria and data sources from MITRE API."
|
||||
),
|
||||
}
|
||||
@@ -58,6 +58,7 @@ def log_action(
|
||||
ip_address=ip or None,
|
||||
user_agent=ua or None,
|
||||
session_id=session_id,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(entry)
|
||||
db.flush()
|
||||
|
||||
@@ -41,5 +41,9 @@ def change_password(
|
||||
"""
|
||||
if not verify_password(current_password, user.hashed_password):
|
||||
raise BusinessRuleViolation("Current password is incorrect")
|
||||
if verify_password(new_password, user.hashed_password):
|
||||
raise BusinessRuleViolation(
|
||||
"New password must be different from the current password"
|
||||
)
|
||||
user.hashed_password = hash_password(new_password)
|
||||
user.must_change_password = False
|
||||
|
||||
@@ -35,6 +35,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -237,6 +238,7 @@ def sync(db: Session) -> dict:
|
||||
|
||||
created = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
for item in parsed:
|
||||
if item["atomic_test_id"] in existing_ids:
|
||||
@@ -257,8 +259,14 @@ def sync(db: Session) -> dict:
|
||||
)
|
||||
db.add(template)
|
||||
existing_ids.add(item["atomic_test_id"])
|
||||
new_technique_ids.add(item["mitre_technique_id"])
|
||||
created += 1
|
||||
|
||||
if new_technique_ids:
|
||||
db.query(Technique).filter(
|
||||
Technique.mitre_id.in_(new_technique_ids)
|
||||
).update({"review_required": True}, synchronize_session=False)
|
||||
|
||||
db.commit()
|
||||
|
||||
summary = {
|
||||
|
||||
@@ -25,6 +25,7 @@ from app.services.campaign_service import (
|
||||
TACTIC_TO_PHASE,
|
||||
)
|
||||
from app.services.campaign_scheduler_service import calculate_next_run
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
|
||||
|
||||
# ── Serialization helpers ────────────────────────────────────────────────
|
||||
@@ -71,6 +72,7 @@ def serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
||||
"start_date": campaign.start_date.isoformat() if campaign.start_date else None,
|
||||
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
||||
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
@@ -99,6 +101,7 @@ def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||
"status": campaign.status,
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"start_date": campaign.start_date.isoformat() if campaign.start_date else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
@@ -159,6 +162,7 @@ def create_campaign(
|
||||
target_platform: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
scheduled_at: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create a new campaign. Does not commit; caller commits."""
|
||||
campaign = Campaign(
|
||||
@@ -170,6 +174,7 @@ def create_campaign(
|
||||
tags=tags or [],
|
||||
created_by=creator_id,
|
||||
scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None,
|
||||
start_date=datetime.fromisoformat(start_date) if start_date else None,
|
||||
)
|
||||
db.add(campaign)
|
||||
db.flush()
|
||||
@@ -212,6 +217,8 @@ def update_campaign(
|
||||
|
||||
if "scheduled_at" in fields and fields["scheduled_at"]:
|
||||
fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
|
||||
if "start_date" in fields and fields["start_date"]:
|
||||
fields["start_date"] = datetime.fromisoformat(fields["start_date"])
|
||||
|
||||
for field, value in fields.items():
|
||||
setattr(campaign, field, value)
|
||||
@@ -319,9 +326,28 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s
|
||||
for dep in dependents:
|
||||
dep.depends_on = None
|
||||
|
||||
# Keep a reference to the underlying test before deleting the join record
|
||||
test_id = ct.test_id
|
||||
technique_id = None
|
||||
test_obj = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test_obj:
|
||||
technique_id = test_obj.technique_id
|
||||
|
||||
db.delete(ct)
|
||||
db.flush()
|
||||
|
||||
# Also delete the actual test record (it was created for this campaign)
|
||||
if test_obj:
|
||||
db.delete(test_obj)
|
||||
db.flush()
|
||||
|
||||
# Recalculate technique status_global so coverage metrics stay consistent
|
||||
if technique_id:
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
if technique:
|
||||
recalculate_technique_status(db, technique)
|
||||
db.flush()
|
||||
|
||||
|
||||
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||
"""Activate a campaign, moving it from draft to active.
|
||||
@@ -425,6 +451,72 @@ def schedule_campaign(
|
||||
return campaign
|
||||
|
||||
|
||||
def delete_campaign(
|
||||
db: Session,
|
||||
campaign_id: str,
|
||||
*,
|
||||
deleter_id: uuid.UUID,
|
||||
deleter_role: str,
|
||||
delete_tests: bool = False,
|
||||
) -> None:
|
||||
"""Delete a campaign.
|
||||
|
||||
Only draft campaigns can be deleted unless the caller is admin.
|
||||
If delete_tests=True, the associated Test objects are also deleted.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
if campaign.status != "draft" and deleter_role != "admin":
|
||||
raise BusinessRuleViolation("Only draft campaigns can be deleted")
|
||||
|
||||
if str(campaign.created_by) != str(deleter_id) and deleter_role != "admin":
|
||||
raise PermissionViolation("Only the creator or admin can delete this campaign")
|
||||
|
||||
# Collect test IDs before removing associations
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).all()
|
||||
)
|
||||
test_ids = [ct.test_id for ct in campaign_tests]
|
||||
|
||||
# Remove CampaignTest join rows (clear depends_on refs first to avoid FK cycles)
|
||||
for ct in campaign_tests:
|
||||
ct.depends_on = None
|
||||
db.flush()
|
||||
for ct in campaign_tests:
|
||||
db.delete(ct)
|
||||
db.flush()
|
||||
|
||||
# Optionally delete the associated tests
|
||||
if delete_tests:
|
||||
affected_technique_ids: set = set()
|
||||
for test_id in test_ids:
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test:
|
||||
if test.technique_id:
|
||||
affected_technique_ids.add(test.technique_id)
|
||||
db.delete(test)
|
||||
db.flush()
|
||||
# Recalculate status_global for every affected technique so the
|
||||
# coverage metrics stay consistent after test deletion.
|
||||
for tech_id in affected_technique_ids:
|
||||
technique = db.query(Technique).filter(Technique.id == tech_id).first()
|
||||
if technique:
|
||||
recalculate_technique_status(db, technique)
|
||||
db.flush()
|
||||
|
||||
# Null-out parent_campaign_id on child campaigns to avoid FK violation
|
||||
db.query(Campaign).filter(Campaign.parent_campaign_id == campaign.id).update(
|
||||
{"parent_campaign_id": None}
|
||||
)
|
||||
db.flush()
|
||||
|
||||
db.delete(campaign)
|
||||
db.flush()
|
||||
|
||||
|
||||
def get_campaign_history(db: Session, campaign_id: str) -> dict:
|
||||
"""List all child campaigns (execution history) of a recurring campaign.
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ threat actors, and progress calculation.
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -106,6 +107,8 @@ def generate_campaign_from_threat_actor(
|
||||
db: Session,
|
||||
actor_id: uuid.UUID,
|
||||
user: User,
|
||||
*,
|
||||
start_date: Optional[datetime] = None,
|
||||
) -> Campaign:
|
||||
"""Auto-generate a campaign from a threat actor's uncovered techniques.
|
||||
|
||||
@@ -146,6 +149,7 @@ def generate_campaign_from_threat_actor(
|
||||
status="draft",
|
||||
created_by=user.id,
|
||||
tags=[actor.name, "auto-generated"],
|
||||
start_date=start_date,
|
||||
)
|
||||
db.add(campaign)
|
||||
db.flush() # Get campaign.id
|
||||
@@ -181,6 +185,7 @@ def generate_campaign_from_threat_actor(
|
||||
tool_used=template.tool_suggested,
|
||||
created_by=user.id,
|
||||
state=TestState.draft,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(test)
|
||||
db.flush() # Get test.id
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -62,6 +62,7 @@ def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, An
|
||||
return {
|
||||
"control_id": control.control_id,
|
||||
"title": control.title,
|
||||
"description": control.description,
|
||||
"category": control.category,
|
||||
"status": "not_evaluated",
|
||||
"score": 0,
|
||||
@@ -104,6 +105,7 @@ def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, An
|
||||
return {
|
||||
"control_id": control.control_id,
|
||||
"title": control.title,
|
||||
"description": control.description,
|
||||
"category": control.category,
|
||||
"status": status,
|
||||
"score": avg_score,
|
||||
|
||||
@@ -0,0 +1,259 @@
|
||||
"""Decay Engine — calculates confidence scores and expires validations."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.detection_lifecycle import (
|
||||
DetectionAsset, DetectionValidation,
|
||||
DetectionTechniqueMapping, TechniqueConfidenceScore,
|
||||
DetectionConfidence, DetectionHealthStatus,
|
||||
InfrastructureChangeLog,
|
||||
)
|
||||
from app.models.decay_policy import DecayPolicy
|
||||
from app.models.technique import Technique
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
def get_applicable_policy(db: Session, platform: Optional[str] = None, asset_type: Optional[str] = None, tactic: Optional[str] = None) -> DecayPolicy:
|
||||
query = db.query(DecayPolicy).filter(DecayPolicy.is_active == True)
|
||||
if platform:
|
||||
specific = query.filter(DecayPolicy.applies_to_platform == platform).first()
|
||||
if specific:
|
||||
return specific
|
||||
if asset_type:
|
||||
specific = query.filter(DecayPolicy.applies_to_asset_type == asset_type).first()
|
||||
if specific:
|
||||
return specific
|
||||
if tactic:
|
||||
specific = query.filter(DecayPolicy.applies_to_tactic == tactic).first()
|
||||
if specific:
|
||||
return specific
|
||||
default_policy = query.filter(DecayPolicy.is_default == True).first()
|
||||
if default_policy:
|
||||
return default_policy
|
||||
# Return an in-memory default if no DB policy exists
|
||||
p = DecayPolicy()
|
||||
p.fresh_days = 90
|
||||
p.aging_days = 180
|
||||
p.stale_days = 365
|
||||
p.recency_weight = 0.30
|
||||
p.coverage_weight = 0.30
|
||||
p.health_weight = 0.25
|
||||
p.diversity_weight = 0.15
|
||||
return p
|
||||
|
||||
|
||||
def calculate_confidence_for_technique(db: Session, technique_id: UUID) -> Optional[TechniqueConfidenceScore]:
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
if not technique:
|
||||
return None
|
||||
|
||||
policy = get_applicable_policy(db, tactic=technique.tactic)
|
||||
mappings = db.query(DetectionTechniqueMapping).filter(DetectionTechniqueMapping.technique_id == technique_id).all()
|
||||
asset_ids = [m.detection_asset_id for m in mappings]
|
||||
|
||||
if not asset_ids:
|
||||
return _create_or_update_score(db, technique_id,
|
||||
confidence_level=DetectionConfidence.unknown, confidence_score=0.0,
|
||||
factors={"recency": 0.0, "coverage": 0.0, "health": 0.0, "diversity": 0.0},
|
||||
risk_factors=["no_detection_assets"], detection_count=0, valid_count=0,
|
||||
)
|
||||
|
||||
assets = db.query(DetectionAsset).filter(DetectionAsset.id.in_(asset_ids), DetectionAsset.is_active == True).all()
|
||||
now = _now()
|
||||
|
||||
valid_validations = db.query(DetectionValidation).filter(
|
||||
DetectionValidation.detection_asset_id.in_(asset_ids),
|
||||
DetectionValidation.is_valid == True,
|
||||
DetectionValidation.expires_at > now,
|
||||
).all()
|
||||
|
||||
recency_factor = 0.0
|
||||
last_validated = None
|
||||
if valid_validations:
|
||||
most_recent = max(v.validated_at for v in valid_validations)
|
||||
# Strip tzinfo if present so arithmetic stays consistent with naive UTC
|
||||
if most_recent.tzinfo is not None:
|
||||
most_recent = most_recent.replace(tzinfo=None)
|
||||
last_validated = most_recent
|
||||
days_since = (now - most_recent).days
|
||||
if days_since <= policy.fresh_days:
|
||||
recency_factor = 1.0
|
||||
elif days_since <= policy.aging_days:
|
||||
range_days = policy.aging_days - policy.fresh_days
|
||||
elapsed = days_since - policy.fresh_days
|
||||
recency_factor = 1.0 - (elapsed / range_days) * 0.4
|
||||
elif days_since <= policy.stale_days:
|
||||
range_days = policy.stale_days - policy.aging_days
|
||||
elapsed = days_since - policy.aging_days
|
||||
recency_factor = 0.6 - (elapsed / range_days) * 0.4
|
||||
else:
|
||||
recency_factor = max(0.1, 0.2 - ((days_since - policy.stale_days) / 365) * 0.1)
|
||||
|
||||
active_count = len(assets)
|
||||
valid_count = len(set(v.detection_asset_id for v in valid_validations))
|
||||
|
||||
if active_count == 0:
|
||||
coverage_factor = 0.0
|
||||
elif valid_count >= 3:
|
||||
coverage_factor = 1.0
|
||||
elif valid_count >= 2:
|
||||
coverage_factor = 0.8
|
||||
elif valid_count >= 1:
|
||||
coverage_factor = 0.5
|
||||
else:
|
||||
coverage_factor = 0.1
|
||||
|
||||
health_scores = {
|
||||
DetectionHealthStatus.healthy: 1.0,
|
||||
DetectionHealthStatus.silent: 0.4,
|
||||
DetectionHealthStatus.noisy: 0.6,
|
||||
DetectionHealthStatus.orphan: 0.3,
|
||||
DetectionHealthStatus.deprecated: 0.0,
|
||||
DetectionHealthStatus.untested: 0.2,
|
||||
}
|
||||
health_factor = sum(health_scores.get(a.health_status, 0.2) for a in assets) / max(len(assets), 1)
|
||||
|
||||
platforms = set(a.platform for a in assets if a.platform)
|
||||
asset_types = set(a.asset_type for a in assets)
|
||||
diversity_factor = min(1.0, len(platforms) * 0.3 + len(asset_types) * 0.2)
|
||||
|
||||
confidence_score = (
|
||||
recency_factor * policy.recency_weight +
|
||||
coverage_factor * policy.coverage_weight +
|
||||
health_factor * policy.health_weight +
|
||||
diversity_factor * policy.diversity_weight
|
||||
) * 100
|
||||
|
||||
if confidence_score >= 75:
|
||||
confidence_level = DetectionConfidence.fresh
|
||||
elif confidence_score >= 50:
|
||||
confidence_level = DetectionConfidence.aging
|
||||
elif confidence_score >= 25:
|
||||
confidence_level = DetectionConfidence.stale
|
||||
elif confidence_score > 0:
|
||||
confidence_level = DetectionConfidence.broken
|
||||
else:
|
||||
confidence_level = DetectionConfidence.unknown
|
||||
|
||||
risk_factors = []
|
||||
if len(platforms) <= 1:
|
||||
risk_factors.append("single_platform")
|
||||
if valid_count == 0:
|
||||
risk_factors.append("no_valid_detections")
|
||||
if any(a.health_status == DetectionHealthStatus.silent for a in assets):
|
||||
risk_factors.append("silent_rules_present")
|
||||
if any(a.health_status == DetectionHealthStatus.orphan for a in assets):
|
||||
risk_factors.append("orphan_rules_present")
|
||||
if recency_factor < 0.5:
|
||||
risk_factors.append("stale_validation")
|
||||
if len(assets) < 2:
|
||||
risk_factors.append("low_detection_diversity")
|
||||
|
||||
next_due = None
|
||||
if valid_validations:
|
||||
earliest_expiry = min(v.expires_at for v in valid_validations)
|
||||
next_due = earliest_expiry
|
||||
|
||||
return _create_or_update_score(
|
||||
db, technique_id,
|
||||
confidence_level=confidence_level,
|
||||
confidence_score=round(confidence_score, 1),
|
||||
factors={"recency": round(recency_factor, 3), "coverage": round(coverage_factor, 3), "health": round(health_factor, 3), "diversity": round(diversity_factor, 3)},
|
||||
risk_factors=risk_factors,
|
||||
detection_count=active_count,
|
||||
valid_count=valid_count,
|
||||
last_validated=last_validated,
|
||||
next_due=next_due,
|
||||
)
|
||||
|
||||
|
||||
def _create_or_update_score(db: Session, technique_id: UUID, **kwargs) -> TechniqueConfidenceScore:
|
||||
score = db.query(TechniqueConfidenceScore).filter(TechniqueConfidenceScore.technique_id == technique_id).first()
|
||||
if not score:
|
||||
score = TechniqueConfidenceScore(technique_id=technique_id)
|
||||
db.add(score)
|
||||
|
||||
score.confidence_level = kwargs["confidence_level"]
|
||||
score.confidence_score = kwargs["confidence_score"]
|
||||
score.detection_count = kwargs["detection_count"]
|
||||
score.valid_detection_count = kwargs["valid_count"]
|
||||
score.recency_factor = kwargs["factors"]["recency"]
|
||||
score.coverage_factor = kwargs["factors"]["coverage"]
|
||||
score.health_factor = kwargs["factors"]["health"]
|
||||
score.diversity_factor = kwargs["factors"]["diversity"]
|
||||
score.risk_factors = kwargs["risk_factors"]
|
||||
score.score_breakdown = kwargs["factors"]
|
||||
score.last_validated_at = kwargs.get("last_validated")
|
||||
score.next_validation_due = kwargs.get("next_due")
|
||||
score.last_recalculated_at = _now()
|
||||
score.updated_at = _now()
|
||||
|
||||
db.commit()
|
||||
db.refresh(score)
|
||||
return score
|
||||
|
||||
|
||||
def run_decay_engine(db: Session) -> dict:
|
||||
techniques = db.query(Technique).all()
|
||||
results = {"total_techniques": len(techniques), "fresh": 0, "aging": 0, "stale": 0, "broken": 0, "unknown": 0, "validations_expired": 0}
|
||||
now = _now()
|
||||
|
||||
# Expire stale validations
|
||||
expired = db.query(DetectionValidation).filter(
|
||||
DetectionValidation.is_valid == True,
|
||||
DetectionValidation.expires_at <= now,
|
||||
).all()
|
||||
from app.models.detection_lifecycle import InvalidationReason
|
||||
for v in expired:
|
||||
v.is_valid = False
|
||||
v.invalidated_at = now
|
||||
v.invalidation_reason = InvalidationReason.time_decay
|
||||
results["validations_expired"] = len(expired)
|
||||
if expired:
|
||||
db.commit()
|
||||
|
||||
for technique in techniques:
|
||||
score = calculate_confidence_for_technique(db, technique.id)
|
||||
if score:
|
||||
level = score.confidence_level.value
|
||||
results[level] = results.get(level, 0) + 1
|
||||
|
||||
logger.info("Decay engine completed: %s", results)
|
||||
return results
|
||||
|
||||
|
||||
def process_infrastructure_change(db: Session, change_id: UUID) -> int:
|
||||
change = db.query(InfrastructureChangeLog).filter(InfrastructureChangeLog.id == change_id).first()
|
||||
if not change or not change.auto_invalidate:
|
||||
return 0
|
||||
|
||||
query = db.query(DetectionAsset).filter(DetectionAsset.is_active == True)
|
||||
if change.affected_platforms:
|
||||
query = query.filter(DetectionAsset.platform.in_(change.affected_platforms))
|
||||
|
||||
affected_assets = query.all()
|
||||
total_invalidated = 0
|
||||
|
||||
from app.services.detection_asset_service import invalidate_validations_for_asset
|
||||
for asset in affected_assets:
|
||||
if change.affected_log_sources:
|
||||
asset_log_source = asset.log_source_name or ""
|
||||
if not any(ls in asset_log_source for ls in change.affected_log_sources):
|
||||
continue
|
||||
count = invalidate_validations_for_asset(db, asset.id, change.reported_by, "infrastructure_change")
|
||||
total_invalidated += count
|
||||
|
||||
change.invalidated_count = total_invalidated
|
||||
db.commit()
|
||||
logger.info("Infrastructure change %s: invalidated %d validations", change_id, total_invalidated)
|
||||
return total_invalidated
|
||||
@@ -0,0 +1,211 @@
|
||||
"""Detection Asset CRUD service with auto-hash and change detection."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.models.detection_lifecycle import (
|
||||
DetectionAsset, DetectionTechniqueMapping,
|
||||
DetectionValidation, DetectionHealthStatus, InvalidationReason
|
||||
)
|
||||
from app.models.technique import Technique
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.services import audit_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _compute_rule_hash(content: str) -> str:
|
||||
normalized = content.strip().replace('\r\n', '\n')
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
def create_detection_asset(db: Session, data: dict, user_id: UUID) -> DetectionAsset:
|
||||
technique_ids = data.pop("technique_ids", []) or []
|
||||
# Remove None values so defaults apply
|
||||
data = {k: v for k, v in data.items() if v is not None or k in ("log_source_config", "infrastructure_details", "tags")}
|
||||
|
||||
asset = DetectionAsset(**data, created_by=user_id)
|
||||
|
||||
if asset.rule_content:
|
||||
asset.rule_hash = _compute_rule_hash(asset.rule_content)
|
||||
asset.last_rule_change_at = _now()
|
||||
|
||||
if asset.infrastructure_details:
|
||||
infra_str = str(sorted(asset.infrastructure_details.items()))
|
||||
asset.infrastructure_hash = hashlib.sha256(infra_str.encode()).hexdigest()
|
||||
|
||||
db.add(asset)
|
||||
db.flush()
|
||||
|
||||
for tech_id in technique_ids:
|
||||
technique = db.query(Technique).filter(Technique.id == tech_id).first()
|
||||
if technique:
|
||||
mapping = DetectionTechniqueMapping(
|
||||
detection_asset_id=asset.id,
|
||||
technique_id=tech_id,
|
||||
)
|
||||
db.add(mapping)
|
||||
|
||||
db.commit()
|
||||
db.refresh(asset)
|
||||
|
||||
audit_service.log_action(
|
||||
db, user_id, "DETECTION_ASSET_CREATED", "detection_asset", str(asset.id),
|
||||
details={"name": asset.name, "type": asset.asset_type, "platform": asset.platform, "technique_count": len(technique_ids)},
|
||||
)
|
||||
return asset
|
||||
|
||||
|
||||
def update_detection_asset(db: Session, asset_id: UUID, data: dict, user_id: UUID) -> DetectionAsset:
|
||||
asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first()
|
||||
if not asset:
|
||||
raise EntityNotFoundError("DetectionAsset", str(asset_id))
|
||||
|
||||
changes = {}
|
||||
rule_changed = False
|
||||
|
||||
for key, value in data.items():
|
||||
if value is not None and hasattr(asset, key):
|
||||
old_value = getattr(asset, key)
|
||||
if old_value != value:
|
||||
changes[key] = {"old": str(old_value), "new": str(value)}
|
||||
setattr(asset, key, value)
|
||||
|
||||
if "rule_content" in data and data["rule_content"]:
|
||||
new_hash = _compute_rule_hash(data["rule_content"])
|
||||
if new_hash != asset.rule_hash:
|
||||
rule_changed = True
|
||||
asset.rule_hash = new_hash
|
||||
asset.last_rule_change_at = _now()
|
||||
|
||||
if "infrastructure_details" in data and data["infrastructure_details"]:
|
||||
infra_str = str(sorted(data["infrastructure_details"].items()))
|
||||
new_hash = hashlib.sha256(infra_str.encode()).hexdigest()
|
||||
if new_hash != asset.infrastructure_hash:
|
||||
asset.infrastructure_hash = new_hash
|
||||
changes["infrastructure_hash_changed"] = True
|
||||
|
||||
asset.updated_at = _now()
|
||||
db.commit()
|
||||
db.refresh(asset)
|
||||
|
||||
if changes:
|
||||
audit_service.log_action(
|
||||
db, user_id, "DETECTION_ASSET_UPDATED", "detection_asset", str(asset.id),
|
||||
details={"changes": changes, "rule_changed": rule_changed},
|
||||
)
|
||||
|
||||
if rule_changed:
|
||||
invalidate_validations_for_asset(db, asset.id, user_id, "rule_modified")
|
||||
|
||||
return asset
|
||||
|
||||
|
||||
def invalidate_validations_for_asset(db: Session, asset_id: UUID, user_id: UUID, reason: str) -> int:
|
||||
try:
|
||||
reason_enum = InvalidationReason(reason)
|
||||
except ValueError:
|
||||
reason_enum = InvalidationReason.manual
|
||||
|
||||
validations = db.query(DetectionValidation).filter(
|
||||
DetectionValidation.detection_asset_id == asset_id,
|
||||
DetectionValidation.is_valid == True,
|
||||
).all()
|
||||
|
||||
count = 0
|
||||
for v in validations:
|
||||
v.is_valid = False
|
||||
v.invalidated_at = _now()
|
||||
v.invalidation_reason = reason_enum
|
||||
v.invalidated_by = user_id
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
db.commit()
|
||||
logger.info("Invalidated %d validations for asset %s due to %s", count, asset_id, reason)
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def get_asset_with_details(db: Session, asset_id: UUID) -> DetectionAsset:
|
||||
asset = (
|
||||
db.query(DetectionAsset)
|
||||
.options(joinedload(DetectionAsset.technique_mappings), joinedload(DetectionAsset.validations))
|
||||
.filter(DetectionAsset.id == asset_id)
|
||||
.first()
|
||||
)
|
||||
if not asset:
|
||||
raise EntityNotFoundError("DetectionAsset", str(asset_id))
|
||||
return asset
|
||||
|
||||
|
||||
def list_assets(
|
||||
db: Session,
|
||||
platform: Optional[str] = None,
|
||||
asset_type: Optional[str] = None,
|
||||
health_status: Optional[str] = None,
|
||||
technique_id: Optional[UUID] = None,
|
||||
is_active: Optional[bool] = True,
|
||||
) -> list:
|
||||
query = db.query(DetectionAsset)
|
||||
if platform:
|
||||
query = query.filter(DetectionAsset.platform == platform)
|
||||
if asset_type:
|
||||
query = query.filter(DetectionAsset.asset_type == asset_type)
|
||||
if health_status:
|
||||
query = query.filter(DetectionAsset.health_status == health_status)
|
||||
if is_active is not None:
|
||||
query = query.filter(DetectionAsset.is_active == is_active)
|
||||
if technique_id:
|
||||
query = query.join(DetectionTechniqueMapping).filter(
|
||||
DetectionTechniqueMapping.technique_id == technique_id
|
||||
)
|
||||
return query.order_by(DetectionAsset.name).all()
|
||||
|
||||
|
||||
def get_technique_detection_summary(db: Session, technique_id: UUID) -> dict:
|
||||
mappings = (
|
||||
db.query(DetectionTechniqueMapping)
|
||||
.options(joinedload(DetectionTechniqueMapping.detection_asset))
|
||||
.filter(DetectionTechniqueMapping.technique_id == technique_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
assets = [m.detection_asset for m in mappings if m.detection_asset]
|
||||
active_assets = [a for a in assets if a.is_active]
|
||||
now = _now()
|
||||
|
||||
valid_count = 0
|
||||
for asset in active_assets:
|
||||
has_valid = db.query(DetectionValidation).filter(
|
||||
DetectionValidation.detection_asset_id == asset.id,
|
||||
DetectionValidation.is_valid == True,
|
||||
DetectionValidation.expires_at > now,
|
||||
).first()
|
||||
if has_valid:
|
||||
valid_count += 1
|
||||
|
||||
health_distribution = {}
|
||||
for asset in active_assets:
|
||||
status = asset.health_status.value if asset.health_status else "unknown"
|
||||
health_distribution[status] = health_distribution.get(status, 0) + 1
|
||||
|
||||
platforms = list(set(a.platform for a in active_assets if a.platform))
|
||||
|
||||
return {
|
||||
"technique_id": str(technique_id),
|
||||
"total_assets": len(active_assets),
|
||||
"validated_assets": valid_count,
|
||||
"health_distribution": health_distribution,
|
||||
"platforms": platforms,
|
||||
"coverage_types": list(set(m.coverage_type for m in mappings if m.coverage_type)),
|
||||
}
|
||||
@@ -34,6 +34,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -316,6 +317,7 @@ def sync(db: Session) -> dict:
|
||||
|
||||
created = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
for item in parsed_rules:
|
||||
if item["source_id"] in existing_ids:
|
||||
@@ -337,8 +339,15 @@ def sync(db: Session) -> dict:
|
||||
)
|
||||
db.add(rule)
|
||||
existing_ids.add(item["source_id"])
|
||||
new_technique_ids.add(item["mitre_technique_id"])
|
||||
created += 1
|
||||
|
||||
# Flag techniques that received new rules for review
|
||||
if new_technique_ids:
|
||||
db.query(Technique).filter(
|
||||
Technique.mitre_id.in_(new_technique_ids)
|
||||
).update({"review_required": True}, synchronize_session=False)
|
||||
|
||||
db.commit()
|
||||
|
||||
summary = {
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Email notification service using SMTP.
|
||||
|
||||
Sending is silently skipped when SMTP_ENABLED=False (default) and no
|
||||
DB config overrides it. All errors are caught and logged — email
|
||||
failures never crash the caller.
|
||||
|
||||
Config priority:
|
||||
1. system_configs table (key ``smtp.*``) — managed via the Settings UI
|
||||
2. .env / environment variables (app.config.settings)
|
||||
"""
|
||||
import logging
|
||||
import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — read effective SMTP config (DB first, env fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_smtp_config(db=None) -> dict:
|
||||
"""Return a dict with resolved SMTP settings.
|
||||
|
||||
When *db* is provided the function looks up ``system_configs`` rows
|
||||
whose key starts with ``smtp.`` and overrides the .env values.
|
||||
"""
|
||||
cfg = {
|
||||
"enabled": settings.SMTP_ENABLED,
|
||||
"host": settings.SMTP_HOST,
|
||||
"port": settings.SMTP_PORT,
|
||||
"username": settings.SMTP_USERNAME,
|
||||
"password": settings.SMTP_PASSWORD,
|
||||
"from_email": settings.SMTP_FROM_EMAIL,
|
||||
"use_tls": settings.SMTP_USE_TLS,
|
||||
}
|
||||
|
||||
if db is not None:
|
||||
try:
|
||||
from app.models.system_config import SystemConfig # avoid circular
|
||||
|
||||
rows = db.query(SystemConfig).filter(
|
||||
SystemConfig.key.like("smtp.%")
|
||||
).all()
|
||||
for row in rows:
|
||||
k = row.key # e.g. "smtp.host"
|
||||
v = row.value
|
||||
if v is None:
|
||||
continue
|
||||
short = k[len("smtp."):] # "host"
|
||||
if short == "enabled":
|
||||
cfg["enabled"] = v.lower() in ("true", "1", "yes")
|
||||
elif short == "host":
|
||||
cfg["host"] = v
|
||||
elif short == "port":
|
||||
try:
|
||||
cfg["port"] = int(v)
|
||||
except ValueError:
|
||||
pass
|
||||
elif short == "username":
|
||||
cfg["username"] = v
|
||||
elif short == "password":
|
||||
cfg["password"] = v
|
||||
elif short == "from_email":
|
||||
cfg["from_email"] = v
|
||||
elif short == "use_tls":
|
||||
cfg["use_tls"] = v.lower() in ("true", "1", "yes")
|
||||
except Exception:
|
||||
logger.exception("Failed to read SMTP config from DB — falling back to env")
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core send
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def send_email(to: str, subject: str, html_body: str, db=None) -> bool:
|
||||
"""Send an HTML email. Returns True on success, False on skip/error.
|
||||
|
||||
Pass *db* to allow runtime config override from system_configs table.
|
||||
"""
|
||||
cfg = _get_smtp_config(db)
|
||||
|
||||
if not cfg["enabled"]:
|
||||
logger.debug("SMTP disabled — skipping email to %s: %s", to, subject)
|
||||
return False
|
||||
if not to:
|
||||
return False
|
||||
try:
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = f"[Aegis] {subject}"
|
||||
msg["From"] = cfg["from_email"]
|
||||
msg["To"] = to
|
||||
msg.attach(MIMEText(html_body, "html"))
|
||||
with smtplib.SMTP(cfg["host"], cfg["port"], timeout=10) as server:
|
||||
if cfg["use_tls"]:
|
||||
server.starttls()
|
||||
if cfg["username"]:
|
||||
server.login(cfg["username"], cfg["password"])
|
||||
server.send_message(msg)
|
||||
logger.info("Email sent to %s: %s", to, subject)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to send email to %s: %s", to, subject)
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Typed senders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def send_test_validated_email(to: str, test_name: str, technique_id: str, test_id: str, db=None) -> bool:
|
||||
"""Notify that a test was validated."""
|
||||
url = f"{settings.PLATFORM_URL}/tests/{test_id}"
|
||||
html = f"""
|
||||
<html><body style="font-family:sans-serif;color:#1a1a2e">
|
||||
<h2 style="color:#22d3ee">✅ Test Validated</h2>
|
||||
<p>Test <strong>{test_name}</strong> for technique <code>{technique_id}</code> has been validated.</p>
|
||||
<p><a href="{url}" style="background:#22d3ee;color:#000;padding:8px 16px;border-radius:4px;text-decoration:none">View Test</a></p>
|
||||
<p style="color:#666;font-size:12px">Aegis ATT&CK Coverage Platform</p>
|
||||
</body></html>"""
|
||||
return send_email(to, f"Test Validated: {test_name}", html, db=db)
|
||||
|
||||
|
||||
def send_campaign_completed_email(to: str, campaign_name: str, campaign_id: str, db=None) -> bool:
|
||||
"""Notify that a campaign was completed."""
|
||||
url = f"{settings.PLATFORM_URL}/campaigns/{campaign_id}"
|
||||
html = f"""
|
||||
<html><body style="font-family:sans-serif;color:#1a1a2e">
|
||||
<h2 style="color:#22d3ee">🎯 Campaign Completed</h2>
|
||||
<p>Campaign <strong>{campaign_name}</strong> has been completed.</p>
|
||||
<p><a href="{url}" style="background:#22d3ee;color:#000;padding:8px 16px;border-radius:4px;text-decoration:none">View Campaign</a></p>
|
||||
<p style="color:#666;font-size:12px">Aegis ATT&CK Coverage Platform</p>
|
||||
</body></html>"""
|
||||
return send_email(to, f"Campaign Completed: {campaign_name}", html, db=db)
|
||||
|
||||
|
||||
def send_new_mitre_techniques_email(to: str, created: int, updated: int, db=None) -> bool:
|
||||
"""Notify of new MITRE techniques after sync."""
|
||||
if created == 0:
|
||||
return False
|
||||
html = f"""
|
||||
<html><body style="font-family:sans-serif;color:#1a1a2e">
|
||||
<h2 style="color:#22d3ee">🔄 MITRE ATT&CK Updated</h2>
|
||||
<p><strong>{created}</strong> new techniques added, <strong>{updated}</strong> updated.</p>
|
||||
<p><a href="{settings.PLATFORM_URL}/techniques" style="background:#22d3ee;color:#000;padding:8px 16px;border-radius:4px;text-decoration:none">View Techniques</a></p>
|
||||
<p style="color:#666;font-size:12px">Aegis ATT&CK Coverage Platform</p>
|
||||
</body></html>"""
|
||||
return send_email(to, f"MITRE ATT&CK Updated: {created} new techniques", html, db=db)
|
||||
|
||||
|
||||
def send_test_email(to: str, db=None) -> bool:
|
||||
"""Send a test/ping email to verify SMTP config."""
|
||||
html = """
|
||||
<html><body style="font-family:sans-serif;color:#1a1a2e">
|
||||
<h2 style="color:#22d3ee">✅ Email Configuration Test</h2>
|
||||
<p>This is a test email from Aegis. If you received this, your SMTP configuration is working correctly.</p>
|
||||
<p style="color:#666;font-size:12px">Aegis ATT&CK Coverage Platform</p>
|
||||
</body></html>"""
|
||||
return send_email(to, "Email Configuration Test", html, db=db)
|
||||
@@ -0,0 +1,361 @@
|
||||
"""Phase 13: Executive Dashboard service — aggregate posture data across all phases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.executive_dashboard import PostureSnapshot
|
||||
from app.models.technique import Technique
|
||||
from app.models.risk_intelligence import TechniqueRiskProfile
|
||||
from app.models.ownership_queue import (
|
||||
TechniqueOwnership, RevalidationQueueItem, QueueStatus,
|
||||
)
|
||||
from app.models.knowledge import Playbook, LessonLearned
|
||||
from app.models.attack_path import AttackPathExecution, ExecutionStatus
|
||||
from app.models.test import Test
|
||||
from app.models.osint_item import OsintItem
|
||||
from app.models.enums import TechniqueStatus
|
||||
|
||||
|
||||
# ── Internal aggregation helpers ──────────────────────────────────────────────
|
||||
|
||||
def _aggregate_coverage(db: Session) -> dict:
|
||||
"""Aggregate technique coverage counts from live data."""
|
||||
techniques = db.query(Technique).all()
|
||||
total = len(techniques)
|
||||
|
||||
counts = {
|
||||
TechniqueStatus.validated: 0,
|
||||
TechniqueStatus.partial: 0,
|
||||
TechniqueStatus.not_covered: 0,
|
||||
}
|
||||
for t in techniques:
|
||||
s = t.status_global
|
||||
if s in counts:
|
||||
counts[s] += 1
|
||||
|
||||
validated = counts[TechniqueStatus.validated]
|
||||
partial = counts[TechniqueStatus.partial]
|
||||
not_covered = total - validated - partial
|
||||
coverage_pct = round((validated + partial * 0.5) / total * 100.0, 2) if total > 0 else 0.0
|
||||
|
||||
return {
|
||||
"total_techniques": total,
|
||||
"validated_count": validated,
|
||||
"partial_count": partial,
|
||||
"not_covered_count": not_covered,
|
||||
"coverage_pct": coverage_pct,
|
||||
}
|
||||
|
||||
|
||||
def _aggregate_risk(db: Session) -> dict:
|
||||
"""Aggregate risk metrics from TechniqueRiskProfile."""
|
||||
profiles = db.query(TechniqueRiskProfile).all()
|
||||
if not profiles:
|
||||
return {
|
||||
"avg_risk_score": 0.0,
|
||||
"critical_count": 0,
|
||||
"high_count": 0,
|
||||
"medium_count": 0,
|
||||
"low_count": 0,
|
||||
}
|
||||
|
||||
by_level = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
|
||||
score_sum = 0.0
|
||||
for p in profiles:
|
||||
score_sum += p.risk_score
|
||||
lvl = p.risk_level or "info"
|
||||
by_level[lvl] = by_level.get(lvl, 0) + 1
|
||||
|
||||
return {
|
||||
"avg_risk_score": round(score_sum / len(profiles), 2),
|
||||
"critical_count": by_level["critical"],
|
||||
"high_count": by_level["high"],
|
||||
"medium_count": by_level["medium"],
|
||||
"low_count": by_level["low"],
|
||||
}
|
||||
|
||||
|
||||
def _aggregate_operations(db: Session) -> dict:
|
||||
"""Aggregate operational queue and orphan counts."""
|
||||
open_queue = db.query(RevalidationQueueItem).filter(
|
||||
RevalidationQueueItem.status.in_([QueueStatus.pending, QueueStatus.in_progress]),
|
||||
).count()
|
||||
|
||||
# Orphan = technique with no ownership record OR owner_id IS NULL
|
||||
owned_technique_ids = (
|
||||
db.query(TechniqueOwnership.technique_id)
|
||||
.filter(TechniqueOwnership.owner_id.isnot(None))
|
||||
.subquery()
|
||||
)
|
||||
total_tech = db.query(Technique).count()
|
||||
owned_count = db.query(TechniqueOwnership).filter(
|
||||
TechniqueOwnership.owner_id.isnot(None)
|
||||
).count()
|
||||
orphans = total_tech - owned_count
|
||||
|
||||
return {
|
||||
"open_queue_items": open_queue,
|
||||
"orphan_techniques": max(orphans, 0),
|
||||
}
|
||||
|
||||
|
||||
def _aggregate_knowledge(db: Session) -> dict:
|
||||
"""Count active playbooks and lessons learned."""
|
||||
playbook_count = db.query(Playbook).filter(Playbook.is_active == True).count()
|
||||
lesson_count = db.query(LessonLearned).filter(LessonLearned.is_active == True).count()
|
||||
return {
|
||||
"playbook_count": playbook_count,
|
||||
"lesson_count": lesson_count,
|
||||
}
|
||||
|
||||
|
||||
def _aggregate_mttd(db: Session) -> dict:
|
||||
"""Aggregate MTTD from completed attack-path executions in the last 30 days."""
|
||||
cutoff = datetime.utcnow() - timedelta(days=30)
|
||||
execs = db.query(AttackPathExecution).filter(
|
||||
AttackPathExecution.status == ExecutionStatus.completed,
|
||||
AttackPathExecution.completed_at >= cutoff,
|
||||
).all()
|
||||
|
||||
count = len(execs)
|
||||
mttd_values = [e.mttd_seconds for e in execs if e.mttd_seconds is not None]
|
||||
dr_values = [e.detection_rate for e in execs if e.detection_rate is not None]
|
||||
|
||||
return {
|
||||
"executions_30d": count,
|
||||
"mttd_avg_seconds": round(sum(mttd_values) / len(mttd_values), 2) if mttd_values else None,
|
||||
"detection_rate_30d": round(sum(dr_values) / len(dr_values), 4) if dr_values else None,
|
||||
}
|
||||
|
||||
|
||||
def _build_extra_breakdown(db: Session) -> dict:
|
||||
"""Build the by-tactic breakdown stored in the `extra` JSONB field."""
|
||||
techniques = db.query(Technique).all()
|
||||
tactic_map: dict = {}
|
||||
for t in techniques:
|
||||
tac = t.tactic or "Unknown"
|
||||
if tac not in tactic_map:
|
||||
tactic_map[tac] = {"total": 0, "validated": 0, "partial": 0, "not_covered": 0}
|
||||
tactic_map[tac]["total"] += 1
|
||||
s = t.status_global
|
||||
if s == TechniqueStatus.validated:
|
||||
tactic_map[tac]["validated"] += 1
|
||||
elif s == TechniqueStatus.partial:
|
||||
tactic_map[tac]["partial"] += 1
|
||||
else:
|
||||
tactic_map[tac]["not_covered"] += 1
|
||||
|
||||
coverage_by_tactic = [
|
||||
{
|
||||
"tactic": tac,
|
||||
"total": v["total"],
|
||||
"validated": v["validated"],
|
||||
"partial": v["partial"],
|
||||
"not_covered": v["not_covered"],
|
||||
"coverage_pct": round(
|
||||
(v["validated"] + v["partial"] * 0.5) / v["total"] * 100.0, 2
|
||||
) if v["total"] > 0 else 0.0,
|
||||
}
|
||||
for tac, v in sorted(tactic_map.items())
|
||||
]
|
||||
return {"coverage_by_tactic": coverage_by_tactic}
|
||||
|
||||
|
||||
# ── Snapshot persistence ───────────────────────────────────────────────────────
|
||||
|
||||
def take_posture_snapshot(
|
||||
db: Session,
|
||||
created_by: Optional[UUID] = None,
|
||||
) -> PostureSnapshot:
|
||||
"""
|
||||
Aggregate all phases and write (or update) today's PostureSnapshot.
|
||||
Upserts on snapshot_date — only one row per calendar day.
|
||||
"""
|
||||
today = date.today()
|
||||
|
||||
coverage = _aggregate_coverage(db)
|
||||
risk = _aggregate_risk(db)
|
||||
operations = _aggregate_operations(db)
|
||||
knowledge = _aggregate_knowledge(db)
|
||||
mttd = _aggregate_mttd(db)
|
||||
extra = _build_extra_breakdown(db)
|
||||
|
||||
existing = db.query(PostureSnapshot).filter(
|
||||
PostureSnapshot.snapshot_date == today,
|
||||
).first()
|
||||
|
||||
values = {
|
||||
**coverage,
|
||||
**risk,
|
||||
**operations,
|
||||
**knowledge,
|
||||
**mttd,
|
||||
"extra": extra,
|
||||
}
|
||||
|
||||
if existing:
|
||||
for k, v in values.items():
|
||||
setattr(existing, k, v)
|
||||
existing.created_by = created_by
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
return existing
|
||||
|
||||
snap = PostureSnapshot(snapshot_date=today, created_by=created_by, **values)
|
||||
db.add(snap)
|
||||
db.commit()
|
||||
db.refresh(snap)
|
||||
return snap
|
||||
|
||||
|
||||
# ── Live / read-only aggregations (no DB write) ───────────────────────────────
|
||||
|
||||
def get_live_kpis(db: Session) -> dict:
|
||||
"""Return current KPIs without persisting a snapshot."""
|
||||
coverage = _aggregate_coverage(db)
|
||||
risk = _aggregate_risk(db)
|
||||
operations = _aggregate_operations(db)
|
||||
knowledge = _aggregate_knowledge(db)
|
||||
mttd = _aggregate_mttd(db)
|
||||
return {**coverage, **risk, **operations, **knowledge, **mttd, "snapshot_date": date.today()}
|
||||
|
||||
|
||||
def get_coverage_by_tactic(db: Session) -> list:
|
||||
"""Per-tactic validated/partial/not_covered breakdown."""
|
||||
extra = _build_extra_breakdown(db)
|
||||
return extra["coverage_by_tactic"]
|
||||
|
||||
|
||||
def get_posture_history(
|
||||
db: Session,
|
||||
days: int = 30,
|
||||
) -> List[PostureSnapshot]:
|
||||
"""Return the last `days` PostureSnapshot rows ordered ascending."""
|
||||
cutoff = date.today() - timedelta(days=days)
|
||||
return (
|
||||
db.query(PostureSnapshot)
|
||||
.filter(PostureSnapshot.snapshot_date >= cutoff)
|
||||
.order_by(PostureSnapshot.snapshot_date.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_top_risks(db: Session, limit: int = 5) -> list:
|
||||
"""Return top-N risk profiles with technique details."""
|
||||
from app.models.risk_intelligence import TechniqueRiskProfile
|
||||
|
||||
rows = (
|
||||
db.query(TechniqueRiskProfile, Technique)
|
||||
.join(Technique, TechniqueRiskProfile.technique_id == Technique.id)
|
||||
.order_by(TechniqueRiskProfile.risk_score.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"technique_id": str(p.technique_id),
|
||||
"technique_name": t.name,
|
||||
"technique_tid": t.mitre_id,
|
||||
"tactic": t.tactic,
|
||||
"risk_score": p.risk_score,
|
||||
"risk_level": p.risk_level,
|
||||
"likelihood": p.likelihood,
|
||||
"impact": p.impact,
|
||||
"detection_gap": p.detection_gap,
|
||||
}
|
||||
for p, t in rows
|
||||
]
|
||||
|
||||
|
||||
def get_recent_activity(db: Session, limit: int = 20) -> list:
|
||||
"""Combine recent events from tests, attack-path executions, queue, and OSINT."""
|
||||
events: list = []
|
||||
|
||||
# Recent test executions (use execution_date, fall back to created_at)
|
||||
recent_tests = (
|
||||
db.query(Test)
|
||||
.filter(Test.result.isnot(None))
|
||||
.order_by(Test.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for t in recent_tests:
|
||||
ts = t.execution_date or t.created_at
|
||||
events.append({
|
||||
"ts": ts,
|
||||
"category": "test",
|
||||
"title": f"Test executed — result: {t.result.value if t.result else 'pending'}",
|
||||
"detail": str(t.id),
|
||||
})
|
||||
|
||||
# Recent completed attack-path executions
|
||||
recent_execs = (
|
||||
db.query(AttackPathExecution)
|
||||
.filter(
|
||||
AttackPathExecution.status == ExecutionStatus.completed,
|
||||
AttackPathExecution.completed_at.isnot(None),
|
||||
)
|
||||
.order_by(AttackPathExecution.completed_at.desc())
|
||||
.limit(limit // 2)
|
||||
.all()
|
||||
)
|
||||
for e in recent_execs:
|
||||
dr = f"{e.detection_rate * 100:.0f}%" if e.detection_rate is not None else "n/a"
|
||||
events.append({
|
||||
"ts": e.completed_at,
|
||||
"category": "attack_path",
|
||||
"title": f"Attack path completed — detection: {dr}",
|
||||
"detail": str(e.id),
|
||||
})
|
||||
|
||||
# Recent OSINT items
|
||||
recent_osint = (
|
||||
db.query(OsintItem)
|
||||
.order_by(OsintItem.discovered_at.desc())
|
||||
.limit(limit // 4)
|
||||
.all()
|
||||
)
|
||||
for o in recent_osint:
|
||||
events.append({
|
||||
"ts": o.discovered_at,
|
||||
"category": "osint",
|
||||
"title": f"OSINT signal: {o.title or 'unknown'}",
|
||||
"detail": str(o.id),
|
||||
})
|
||||
|
||||
# Sort all events descending by timestamp, return top `limit`
|
||||
events.sort(key=lambda x: x["ts"] or datetime.min, reverse=True)
|
||||
return events[:limit]
|
||||
|
||||
|
||||
def get_executive_summary(db: Session) -> dict:
|
||||
"""Full executive view — live KPIs + snapshot + trends + top risks + activity."""
|
||||
# Take (or update) today's snapshot
|
||||
snap = take_posture_snapshot(db)
|
||||
|
||||
# 30-day trend
|
||||
history = get_posture_history(db, days=30)
|
||||
coverage_trend = [
|
||||
{"date": str(s.snapshot_date), "value": s.coverage_pct}
|
||||
for s in history
|
||||
]
|
||||
risk_trend = [
|
||||
{"date": str(s.snapshot_date), "value": s.avg_risk_score}
|
||||
for s in history
|
||||
]
|
||||
|
||||
return {
|
||||
"snapshot": snap,
|
||||
"coverage_trend": coverage_trend,
|
||||
"risk_trend": risk_trend,
|
||||
"top_risks": get_top_risks(db),
|
||||
"coverage_by_tactic": get_coverage_by_tactic(db),
|
||||
"recent_activity": get_recent_activity(db),
|
||||
}
|
||||
@@ -259,36 +259,30 @@ def build_threat_actor_layer(
|
||||
if is_actor_technique and score < min_score:
|
||||
continue
|
||||
|
||||
if is_actor_technique:
|
||||
tc = test_counts.get(tech.id, 0)
|
||||
rc = rule_counts.get(tech.mitre_id, 0)
|
||||
metadata = [
|
||||
{"name": "tests_count", "value": str(tc)},
|
||||
{"name": "detection_rules", "value": str(rc)},
|
||||
]
|
||||
if tech.last_review_date:
|
||||
metadata.append(
|
||||
{"name": "last_validated", "value": tech.last_review_date.strftime("%Y-%m-%d")}
|
||||
)
|
||||
layer["techniques"].append({
|
||||
"techniqueID": tech.mitre_id,
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
"color": _score_to_color(score),
|
||||
"score": score,
|
||||
"comment": f"Used by {actor.name} - Coverage: {tech.status_global.value}",
|
||||
"enabled": True,
|
||||
"metadata": metadata,
|
||||
})
|
||||
else:
|
||||
layer["techniques"].append({
|
||||
"techniqueID": tech.mitre_id,
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
"color": "",
|
||||
"score": 0,
|
||||
"comment": "",
|
||||
"enabled": False,
|
||||
"metadata": [],
|
||||
})
|
||||
# Only include techniques actually used by this actor — skip the rest
|
||||
# so that tactics with no actor techniques are hidden in the matrix.
|
||||
if not is_actor_technique:
|
||||
continue
|
||||
|
||||
tc = test_counts.get(tech.id, 0)
|
||||
rc = rule_counts.get(tech.mitre_id, 0)
|
||||
metadata = [
|
||||
{"name": "tests_count", "value": str(tc)},
|
||||
{"name": "detection_rules", "value": str(rc)},
|
||||
]
|
||||
if tech.last_review_date:
|
||||
metadata.append(
|
||||
{"name": "last_validated", "value": tech.last_review_date.strftime("%Y-%m-%d")}
|
||||
)
|
||||
layer["techniques"].append({
|
||||
"techniqueID": tech.mitre_id,
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
"color": _score_to_color(score),
|
||||
"score": score,
|
||||
"comment": f"Used by {actor.name} - Coverage: {tech.status_global.value}",
|
||||
"enabled": True,
|
||||
"metadata": metadata,
|
||||
})
|
||||
|
||||
return layer
|
||||
|
||||
@@ -300,10 +294,19 @@ def build_detection_rules_layer(
|
||||
tactics: str | None = None,
|
||||
min_score: int = 0,
|
||||
) -> dict:
|
||||
"""Detection rules layer -- score based on rule availability and evaluation ratio."""
|
||||
"""Detection rules layer -- score based on absolute rule count per technique.
|
||||
|
||||
Scoring uses fixed thresholds so the colour reflects real coverage regardless
|
||||
of what other techniques have:
|
||||
0 rules → gray (score 0)
|
||||
1 rule → red (score 25)
|
||||
2 rules → orange (score 50)
|
||||
3 rules → yellow (score 75)
|
||||
4+ rules → green (score 100)
|
||||
"""
|
||||
layer = _build_layer_skeleton(
|
||||
"Detection Rules Coverage",
|
||||
"Coverage of detection rules per technique",
|
||||
"Number of active detection rules per technique",
|
||||
)
|
||||
|
||||
query = _apply_filters(
|
||||
@@ -318,7 +321,6 @@ def build_detection_rules_layer(
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
.all()
|
||||
)
|
||||
max_rules = max(rule_counts.values()) if rule_counts else 1
|
||||
|
||||
evaluated_counts = dict(
|
||||
db.query(DetectionRule.mitre_technique_id, func.count(TestDetectionResult.id))
|
||||
@@ -328,26 +330,28 @@ def build_detection_rules_layer(
|
||||
.all()
|
||||
)
|
||||
|
||||
# 4 rules = full coverage (100). Each rule adds 25 points.
|
||||
RULES_FOR_FULL_COVERAGE = 4
|
||||
|
||||
for tech in techniques:
|
||||
total_rules = rule_counts.get(tech.mitre_id, 0)
|
||||
evaluated_rules = evaluated_counts.get(tech.mitre_id, 0)
|
||||
|
||||
if total_rules > 0:
|
||||
availability_score = min((total_rules / max_rules) * 50, 50)
|
||||
evaluation_score = (evaluated_rules / total_rules) * 50
|
||||
score = int(min(availability_score + evaluation_score, 100))
|
||||
else:
|
||||
score = 0
|
||||
score = min(int((total_rules / RULES_FOR_FULL_COVERAGE) * 100), 100)
|
||||
|
||||
if score < min_score:
|
||||
continue
|
||||
|
||||
rule_word = "rule" if total_rules == 1 else "rules"
|
||||
eval_note = f", {evaluated_rules} evaluated" if evaluated_rules > 0 else ""
|
||||
comment = f"{total_rules} active {rule_word}{eval_note}"
|
||||
|
||||
layer["techniques"].append({
|
||||
"techniqueID": tech.mitre_id,
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
"color": _score_to_color(score),
|
||||
"score": score,
|
||||
"comment": f"{total_rules} rules available, {evaluated_rules} evaluated",
|
||||
"comment": comment,
|
||||
"enabled": True,
|
||||
"metadata": [
|
||||
{"name": "total_rules", "value": str(total_rules)},
|
||||
|
||||
@@ -33,8 +33,8 @@ RSS_FEEDS: list[dict[str, str]] = [
|
||||
"url": "https://www.cisa.gov/cybersecurity-advisories/all.xml",
|
||||
},
|
||||
{
|
||||
"name": "NIST NVD CVE",
|
||||
"url": "https://nvd.nist.gov/feeds/xml/cve/misc/nvd-rss.xml",
|
||||
"name": "SecurityWeek",
|
||||
"url": "https://feeds.feedburner.com/Securityweek",
|
||||
},
|
||||
{
|
||||
"name": "SANS ISC",
|
||||
@@ -57,8 +57,9 @@ RSS_FEEDS: list[dict[str, str]] = [
|
||||
# Timeout for each feed request (seconds)
|
||||
_FEED_TIMEOUT = 15
|
||||
|
||||
# Maximum number of techniques to scan (to keep MVP fast)
|
||||
_MAX_TECHNIQUES = 50
|
||||
# Minimum technique name length for name-based matching
|
||||
# Short names ("Kill", "BITS") produce too many false positives
|
||||
_MIN_NAME_LEN = 8
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -118,25 +119,36 @@ def _fetch_feed(url: str) -> list[dict[str, str]]:
|
||||
return entries
|
||||
|
||||
|
||||
def _build_patterns(technique: Technique) -> list[re.Pattern]:
|
||||
"""Build regex patterns to search feed content for a given technique."""
|
||||
patterns: list[re.Pattern] = []
|
||||
def _build_patterns(technique: Technique) -> tuple[list[re.Pattern], list[re.Pattern]]:
|
||||
"""Build regex patterns for a technique.
|
||||
|
||||
mitre_id = re.escape(technique.mitre_id)
|
||||
patterns.append(re.compile(mitre_id, re.IGNORECASE))
|
||||
Returns two lists:
|
||||
- ``id_patterns``: MITRE ID patterns (high confidence, word-boundary matched)
|
||||
- ``name_patterns``: technique name patterns (lower confidence, long names only)
|
||||
"""
|
||||
id_patterns: list[re.Pattern] = []
|
||||
name_patterns: list[re.Pattern] = []
|
||||
|
||||
# Technique name — match if the full name appears
|
||||
if technique.name and len(technique.name) > 4:
|
||||
# MITRE ID with word boundaries so T1059 doesn't partially match T1059.001
|
||||
mitre_id_escaped = re.escape(technique.mitre_id)
|
||||
id_patterns.append(re.compile(rf"\b{mitre_id_escaped}\b", re.IGNORECASE))
|
||||
|
||||
# Technique name — only for distinctly long names to reduce false positives
|
||||
if technique.name and len(technique.name) >= _MIN_NAME_LEN:
|
||||
name_escaped = re.escape(technique.name)
|
||||
patterns.append(re.compile(name_escaped, re.IGNORECASE))
|
||||
name_patterns.append(re.compile(rf"\b{name_escaped}\b", re.IGNORECASE))
|
||||
|
||||
return patterns
|
||||
return id_patterns, name_patterns
|
||||
|
||||
|
||||
def _entry_matches(entry: dict[str, str], patterns: list[re.Pattern]) -> bool:
|
||||
def _entry_matches(
|
||||
entry: dict[str, str],
|
||||
id_patterns: list[re.Pattern],
|
||||
name_patterns: list[re.Pattern],
|
||||
) -> bool:
|
||||
"""Return True if any pattern matches the entry's title or description."""
|
||||
text = f"{entry.get('title', '')} {entry.get('description', '')}"
|
||||
return any(p.search(text) for p in patterns)
|
||||
return any(p.search(text) for p in id_patterns + name_patterns)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -160,11 +172,10 @@ def scan_intel(db: Session) -> dict:
|
||||
"""
|
||||
logger.info("Intel scan starting...")
|
||||
|
||||
# 1. Load techniques (limit for MVP speed)
|
||||
# 1. Load all active techniques
|
||||
techniques = (
|
||||
db.query(Technique)
|
||||
.order_by(Technique.mitre_id)
|
||||
.limit(_MAX_TECHNIQUES)
|
||||
.all()
|
||||
)
|
||||
logger.info("Scanning %d techniques against %d feeds", len(techniques), len(RSS_FEEDS))
|
||||
@@ -192,10 +203,14 @@ def scan_intel(db: Session) -> dict:
|
||||
techniques_flagged: set[str] = set()
|
||||
|
||||
for technique in techniques:
|
||||
patterns = _build_patterns(technique)
|
||||
id_patterns, name_patterns = _build_patterns(technique)
|
||||
|
||||
for feed_name, entry in all_entries:
|
||||
if not _entry_matches(entry, patterns):
|
||||
if not _entry_matches(entry, id_patterns, name_patterns):
|
||||
continue
|
||||
|
||||
# Skip entries with no title (low-quality)
|
||||
if not entry.get("title", "").strip():
|
||||
continue
|
||||
|
||||
url = entry.get("link", "").strip()
|
||||
|
||||
@@ -1,4 +1,32 @@
|
||||
"""Jira integration service — wraps atlassian-python-api for Jira REST calls."""
|
||||
"""Jira integration service.
|
||||
|
||||
Authentication model
|
||||
--------------------
|
||||
Each Aegis user authenticates to Jira with their own Atlassian email and
|
||||
personal API token. The email used is ``user.jira_email`` when set, falling
|
||||
back to ``user.email`` (the Aegis account email). This lets users specify a
|
||||
separate corporate Atlassian email without changing their Aegis login.
|
||||
The token is stored in ``user.jira_api_token``.
|
||||
|
||||
Admin configuration
|
||||
-------------------
|
||||
The Jira URL and default project key are stored in the ``system_configs``
|
||||
table (keys ``jira.url`` and ``jira.project_key``) so the admin can update
|
||||
them at runtime without redeploying. These values override the legacy
|
||||
``settings.JIRA_URL`` / ``settings.JIRA_DEFAULT_PROJECT`` env-vars which are
|
||||
kept for backwards-compatibility only.
|
||||
|
||||
Lifecycle hooks
|
||||
---------------
|
||||
``push_test_event()`` is the single entry-point called from the test-workflow
|
||||
service on every state transition. It posts a rich comment to the linked
|
||||
Jira issue (if one exists) using the acting user's credentials.
|
||||
|
||||
``auto_create_test_issue()`` is called once after a test is created; it
|
||||
creates the Jira ticket and stores the link.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
@@ -14,31 +42,605 @@ from app.models.campaign import Campaign
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_jira_client = None
|
||||
# ---------------------------------------------------------------------------
|
||||
# System-config helpers (admin-configurable Jira settings)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_JIRA_KEYS = {
|
||||
"url": "jira.url",
|
||||
"project_key": "jira.project_key",
|
||||
"enabled": "jira.enabled",
|
||||
}
|
||||
|
||||
|
||||
def _read_system_config(db: Session, key: str) -> Optional[str]:
|
||||
"""Return a value from system_configs, or None if not set."""
|
||||
from app.models.system_config import SystemConfig # avoid circular at import time
|
||||
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
return row.value if row else None
|
||||
|
||||
|
||||
def get_jira_url(db: Session) -> Optional[str]:
|
||||
"""Return the admin-configured Jira URL, falling back to the env-var."""
|
||||
return _read_system_config(db, _JIRA_KEYS["url"]) or settings.JIRA_URL or None
|
||||
|
||||
|
||||
def get_jira_project_key(db: Session) -> Optional[str]:
|
||||
"""Return the admin-configured default project key, falling back to env-var."""
|
||||
return (
|
||||
_read_system_config(db, _JIRA_KEYS["project_key"])
|
||||
or settings.JIRA_DEFAULT_PROJECT
|
||||
or None
|
||||
)
|
||||
|
||||
|
||||
def is_jira_enabled(db: Session) -> bool:
|
||||
"""Return True if Jira integration is enabled (DB setting or env-var)."""
|
||||
db_val = _read_system_config(db, _JIRA_KEYS["enabled"])
|
||||
if db_val is not None:
|
||||
return db_val.lower() in ("true", "1", "yes")
|
||||
return settings.JIRA_ENABLED
|
||||
|
||||
|
||||
def get_jira_parent_ticket(db: Session) -> Optional[str]:
|
||||
"""Return the configured parent ticket key for campaigns, or None if not set."""
|
||||
return _read_system_config(db, "jira.parent_ticket") or None
|
||||
|
||||
|
||||
def get_jira_parent_ticket_standalone(db: Session) -> Optional[str]:
|
||||
"""Return the parent ticket for standalone tests (not in a campaign).
|
||||
|
||||
Falls back to get_jira_parent_ticket() if not explicitly configured.
|
||||
"""
|
||||
return (
|
||||
_read_system_config(db, "jira.parent_ticket_standalone")
|
||||
or get_jira_parent_ticket(db)
|
||||
)
|
||||
|
||||
|
||||
def upsert_jira_config(db: Session, key: str, value: str) -> None:
|
||||
"""Persist a Jira config key-value pair."""
|
||||
from app.models.system_config import SystemConfig
|
||||
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if row:
|
||||
row.value = value
|
||||
else:
|
||||
db.add(SystemConfig(key=key, value=value))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-user Jira client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _effective_jira_email(user: User) -> Optional[str]:
|
||||
"""Return the email to use for Jira auth: jira_email if set, otherwise email."""
|
||||
return getattr(user, "jira_email", None) or user.email
|
||||
|
||||
|
||||
def get_user_jira_client(user: User, db: Session):
|
||||
"""Build an Atlassian Jira client authenticated as *user*.
|
||||
|
||||
Uses ``user.jira_email`` when set, otherwise falls back to ``user.email``.
|
||||
Raises ``InvalidOperationError`` when configuration is incomplete so
|
||||
callers can surface meaningful error messages.
|
||||
"""
|
||||
jira_url = get_jira_url(db)
|
||||
if not jira_url:
|
||||
raise InvalidOperationError(
|
||||
"Jira URL is not configured. Ask your administrator to set it in "
|
||||
"System Settings → Jira Configuration."
|
||||
)
|
||||
|
||||
auth_email = _effective_jira_email(user)
|
||||
if not auth_email:
|
||||
raise InvalidOperationError(
|
||||
"No email configured for Jira authentication. "
|
||||
"Set a Jira email in Settings → Profile → Jira Integration."
|
||||
)
|
||||
|
||||
if not user.jira_api_token:
|
||||
raise InvalidOperationError(
|
||||
"You have not configured a Jira API token. "
|
||||
"Go to Settings → Profile → Jira Integration and add your personal Atlassian token."
|
||||
)
|
||||
|
||||
from atlassian import Jira
|
||||
|
||||
# Strip trailing slash — the Atlassian library appends paths like
|
||||
# /rest/api/2/myself and a trailing slash causes double-slash URLs.
|
||||
clean_url = jira_url.rstrip("/")
|
||||
|
||||
return Jira(
|
||||
url=clean_url,
|
||||
username=auth_email,
|
||||
password=user.jira_api_token,
|
||||
cloud=True,
|
||||
)
|
||||
|
||||
|
||||
def has_jira_configured(user: User, db: Session) -> bool:
|
||||
"""Return True if *user* has everything needed to call Jira."""
|
||||
return bool(get_jira_url(db) and _effective_jira_email(user) and user.jira_api_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ticket content builders (inspired by the pentest-to-Jira script)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SEVERITY_TO_PRIORITY: dict[str, str] = {
|
||||
"critical": "Highest",
|
||||
"high": "High",
|
||||
"medium": "Medium",
|
||||
"low": "Low",
|
||||
"informational": "Lowest",
|
||||
}
|
||||
|
||||
_STATE_EMOJI: dict[str, str] = {
|
||||
"draft": "📝 Draft",
|
||||
"red_executing": "🔴 Red Team Executing",
|
||||
"blue_evaluating": "🔵 Blue Team Evaluating",
|
||||
"in_review": "📋 In Review",
|
||||
"validated": "✅ Validated",
|
||||
"rejected": "❌ Rejected",
|
||||
}
|
||||
|
||||
|
||||
def _technique_severity(technique: Optional[Technique]) -> str:
|
||||
"""Return a lowercase severity string from the technique, defaulting to medium."""
|
||||
if technique and hasattr(technique, "severity") and technique.severity:
|
||||
return technique.severity.lower()
|
||||
return "medium"
|
||||
|
||||
|
||||
def _build_test_description(test: Test, technique: Optional[Technique]) -> str:
|
||||
"""Build the initial Jira ticket description for a newly created test."""
|
||||
mitre_id = technique.mitre_id if technique else "N/A"
|
||||
tech_name = technique.name if technique else "N/A"
|
||||
tactic = technique.tactic if technique else "N/A"
|
||||
severity = _technique_severity(technique).capitalize()
|
||||
|
||||
lines = [
|
||||
"h2. Aegis Security Test",
|
||||
"",
|
||||
f"*Test Name:* {test.name}",
|
||||
f"*MITRE Technique:* [{mitre_id}|https://attack.mitre.org/techniques/{mitre_id.replace('.', '/')}] — {tech_name}",
|
||||
f"*Tactic:* {tactic}",
|
||||
f"*Platform:* {test.platform or 'N/A'}",
|
||||
f"*Severity:* {severity}",
|
||||
f"*Data Classification:* {test.data_classification or 'N/A'}",
|
||||
"",
|
||||
"h3. Description",
|
||||
test.description or "_No description provided._",
|
||||
"",
|
||||
f"*Tool:* {test.tool_used or 'N/A'}",
|
||||
"",
|
||||
"----",
|
||||
f"_Created via Aegis at {datetime.utcnow().strftime('%Y-%m-%d %H:%M')} UTC_",
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _build_state_comment(
|
||||
test: Test,
|
||||
new_state: str,
|
||||
actor: User,
|
||||
extra: dict | None = None,
|
||||
) -> str:
|
||||
"""Build a Jira comment body for a test state transition."""
|
||||
label = _STATE_EMOJI.get(new_state, new_state)
|
||||
lines = [
|
||||
f"h3. {label}",
|
||||
"",
|
||||
f"*Changed by:* {actor.username} ({actor.email or 'no email'})",
|
||||
f"*At:* {datetime.utcnow().strftime('%Y-%m-%d %H:%M')} UTC",
|
||||
"",
|
||||
]
|
||||
|
||||
if new_state == "red_executing":
|
||||
lines += [
|
||||
"Red Team has started the attack execution.",
|
||||
]
|
||||
|
||||
elif new_state == "blue_evaluating":
|
||||
lines += [
|
||||
"Red Team has finished execution and submitted evidence for Blue Team evaluation.",
|
||||
"",
|
||||
f"*Attack Success:* {test.attack_success if test.attack_success is not None else 'N/A'}",
|
||||
]
|
||||
if test.red_summary:
|
||||
lines += ["", "h4. Red Team Summary", test.red_summary]
|
||||
|
||||
elif new_state == "in_review":
|
||||
lines += [
|
||||
"Blue Team has completed evaluation. Test is awaiting lead validation.",
|
||||
"",
|
||||
f"*Detection Result:* {test.detection_result or 'N/A'}",
|
||||
]
|
||||
if test.blue_summary:
|
||||
lines += ["", "h4. Blue Team Summary", test.blue_summary]
|
||||
if test.remediation_steps:
|
||||
lines += ["", "h4. Remediation Steps", test.remediation_steps]
|
||||
|
||||
elif new_state == "validated":
|
||||
lines += [
|
||||
"Test has been *validated* by both leads.",
|
||||
"",
|
||||
f"*Red Lead Status:* {test.red_validation_status or 'N/A'}",
|
||||
f"*Blue Lead Status:* {test.blue_validation_status or 'N/A'}",
|
||||
]
|
||||
if test.red_validation_notes:
|
||||
lines += ["", f"*Red Lead Notes:* {test.red_validation_notes}"]
|
||||
if test.blue_validation_notes:
|
||||
lines += ["", f"*Blue Lead Notes:* {test.blue_validation_notes}"]
|
||||
|
||||
elif new_state == "rejected":
|
||||
lines += [
|
||||
"Test has been *rejected* and must be reworked.",
|
||||
"",
|
||||
f"*Red Lead Status:* {test.red_validation_status or 'N/A'}",
|
||||
f"*Blue Lead Status:* {test.blue_validation_status or 'N/A'}",
|
||||
]
|
||||
if test.red_validation_notes:
|
||||
lines += ["", f"*Red Lead Notes:* {test.red_validation_notes}"]
|
||||
if test.blue_validation_notes:
|
||||
lines += ["", f"*Blue Lead Notes:* {test.blue_validation_notes}"]
|
||||
|
||||
elif new_state == "draft":
|
||||
lines += ["Test has been reopened for re-execution."]
|
||||
|
||||
# Any caller-supplied extra data
|
||||
if extra:
|
||||
lines.append("")
|
||||
for k, v in extra.items():
|
||||
lines.append(f"*{k}:* {v}")
|
||||
|
||||
lines.append("")
|
||||
lines.append("_Synced from [Aegis|https://aegis.undiamagico.es]_")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public lifecycle hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_campaign_description(campaign) -> str:
|
||||
"""Build the Jira ticket description for a campaign."""
|
||||
lines = [
|
||||
"h2. Aegis Security Campaign",
|
||||
"",
|
||||
f"*Campaign Name:* {campaign.name}",
|
||||
f"*Type:* {campaign.type}",
|
||||
f"*Status:* {campaign.status}",
|
||||
]
|
||||
if campaign.description:
|
||||
lines += ["", "h3. Description", campaign.description]
|
||||
if campaign.tags:
|
||||
lines += ["", f"*Tags:* {', '.join(campaign.tags)}"]
|
||||
lines += [
|
||||
"",
|
||||
"----",
|
||||
f"_Created via Aegis at {datetime.utcnow().strftime('%Y-%m-%d %H:%M')} UTC_",
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_campaign_jira_key(db: Session, campaign_id) -> Optional[str]:
|
||||
"""Return the Jira issue key for a campaign, or None if not linked."""
|
||||
import uuid as _uuid
|
||||
try:
|
||||
cid = _uuid.UUID(str(campaign_id))
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
link = (
|
||||
db.query(JiraLink)
|
||||
.filter(
|
||||
JiraLink.entity_type == JiraLinkEntityType.campaign,
|
||||
JiraLink.entity_id == cid,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return link.jira_issue_key if link else None
|
||||
|
||||
|
||||
def get_test_jira_key(db: Session, test_id) -> Optional[str]:
|
||||
"""Return the Jira issue key for a test, or None if not linked."""
|
||||
import uuid as _uuid
|
||||
try:
|
||||
tid = _uuid.UUID(str(test_id))
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
link = (
|
||||
db.query(JiraLink)
|
||||
.filter(
|
||||
JiraLink.entity_type == JiraLinkEntityType.test,
|
||||
JiraLink.entity_id == tid,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return link.jira_issue_key if link else None
|
||||
|
||||
|
||||
def auto_create_campaign_issue(
|
||||
db: Session,
|
||||
campaign,
|
||||
actor: User,
|
||||
) -> Optional[str]:
|
||||
"""Create a Jira issue for *campaign* under the configured parent ticket.
|
||||
|
||||
Returns the Jira issue key on success, or ``None`` if Jira is not
|
||||
configured for *actor* or if the operation fails (non-fatal).
|
||||
|
||||
Called once right after a campaign is committed to the database.
|
||||
The created ticket is stored as a JiraLink with entity_type=campaign.
|
||||
"""
|
||||
if not has_jira_configured(actor, db):
|
||||
return None
|
||||
|
||||
project_key = get_jira_project_key(db)
|
||||
if not project_key:
|
||||
logger.warning(
|
||||
"Jira project key not configured; skipping auto-create for campaign %s",
|
||||
campaign.id,
|
||||
)
|
||||
return None
|
||||
|
||||
parent_ticket = get_jira_parent_ticket(db)
|
||||
|
||||
try:
|
||||
jira = get_user_jira_client(actor, db)
|
||||
|
||||
fields: dict = {
|
||||
"project": {"key": project_key},
|
||||
"summary": f"[Aegis Campaign] {campaign.name}",
|
||||
"description": _build_campaign_description(campaign),
|
||||
"issuetype": {"name": settings.JIRA_ISSUE_TYPE_CAMPAIGN},
|
||||
"labels": ["aegis", "campaign"],
|
||||
# customfield_10011 = Epic Name (required for Epic type in classic Jira)
|
||||
"customfield_10011": campaign.name,
|
||||
}
|
||||
|
||||
# Set start date: use campaign.start_date if set, otherwise today
|
||||
effective_start = campaign.start_date or campaign.created_at
|
||||
if effective_start:
|
||||
fields[settings.JIRA_START_DATE_FIELD] = effective_start.strftime("%Y-%m-%d")
|
||||
|
||||
# Nest under the configured parent ticket (Initiative, e.g. OFS-20795)
|
||||
if parent_ticket:
|
||||
fields["parent"] = {"key": parent_ticket}
|
||||
|
||||
result = jira.issue_create(fields=fields)
|
||||
issue_key = result["key"]
|
||||
issue_id = result.get("id", "")
|
||||
|
||||
link = JiraLink(
|
||||
entity_type=JiraLinkEntityType.campaign,
|
||||
entity_id=campaign.id,
|
||||
jira_issue_key=issue_key,
|
||||
jira_issue_id=issue_id,
|
||||
jira_project_key=project_key,
|
||||
sync_direction=JiraSyncDirection.aegis_to_jira,
|
||||
created_by=actor.id,
|
||||
)
|
||||
db.add(link)
|
||||
db.flush()
|
||||
|
||||
logger.info("Auto-created Jira issue %s for campaign %s", issue_key, campaign.id)
|
||||
return issue_key
|
||||
|
||||
except Exception as exc:
|
||||
# Non-fatal: Jira failures must never break the campaign creation flow
|
||||
logger.warning(
|
||||
"Failed to auto-create Jira issue for campaign %s: %s",
|
||||
campaign.id, exc, exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def auto_create_test_issue(
|
||||
db: Session,
|
||||
test: Test,
|
||||
actor: User,
|
||||
*,
|
||||
technique: Optional[Technique] = None,
|
||||
parent_ticket_override: Optional[str] = None,
|
||||
campaign_start_date=None, # datetime | None — inherited from campaign when available
|
||||
) -> Optional[str]:
|
||||
"""Create a Jira issue for *test* and store the link.
|
||||
|
||||
Returns the Jira issue key on success, or ``None`` if Jira is not
|
||||
configured for *actor* or if the operation fails (non-fatal).
|
||||
|
||||
Called once right after a test is committed to the database.
|
||||
|
||||
Args:
|
||||
parent_ticket_override: When set, use this as the Jira parent ticket
|
||||
instead of the system-configured parent (e.g. OFS-9107).
|
||||
Use this to nest test tickets under a campaign ticket.
|
||||
"""
|
||||
if not has_jira_configured(actor, db):
|
||||
return None
|
||||
|
||||
project_key = get_jira_project_key(db)
|
||||
if not project_key:
|
||||
logger.warning("Jira project key not configured; skipping auto-create for test %s", test.id)
|
||||
return None
|
||||
|
||||
# Resolve technique if not supplied
|
||||
if technique is None:
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
|
||||
severity = _technique_severity(technique)
|
||||
mitre_id = technique.mitre_id if technique else "N/A"
|
||||
|
||||
try:
|
||||
jira = get_user_jira_client(actor, db)
|
||||
|
||||
# All tests — whether inside a campaign or standalone — are created
|
||||
# as Task. Campaign tests use the campaign Jira key as parent
|
||||
# (passed via parent_ticket_override); standalone tests use the
|
||||
# configured standalone parent ticket (e.g. OFS-20798, which is an
|
||||
# Epic so it can parent Tasks).
|
||||
parent = parent_ticket_override or get_jira_parent_ticket_standalone(db)
|
||||
issue_type = settings.JIRA_ISSUE_TYPE_TEST # always Task
|
||||
|
||||
poc = test.procedure_text or "N/A"
|
||||
fields: dict = {
|
||||
"project": {"key": project_key},
|
||||
"summary": f"[Aegis] {mitre_id} — {test.name}",
|
||||
"description": _build_test_description(test, technique),
|
||||
"issuetype": {"name": issue_type},
|
||||
"labels": ["aegis", "security-test", mitre_id.replace(".", "-")],
|
||||
# customfield_10309 = Proof of Concept field (required by team's Jira config)
|
||||
"customfield_10309": f"{{code}}{poc}{{code}}",
|
||||
}
|
||||
|
||||
# Inherit campaign start date if available, otherwise use today
|
||||
from datetime import date as _date
|
||||
effective_start = campaign_start_date or _date.today()
|
||||
if hasattr(effective_start, "strftime"):
|
||||
fields[settings.JIRA_START_DATE_FIELD] = effective_start.strftime("%Y-%m-%d")
|
||||
|
||||
if parent:
|
||||
fields["parent"] = {"key": parent}
|
||||
|
||||
result = jira.issue_create(fields=fields)
|
||||
issue_key = result["key"]
|
||||
issue_id = result.get("id", "")
|
||||
|
||||
link = JiraLink(
|
||||
entity_type=JiraLinkEntityType.test,
|
||||
entity_id=test.id,
|
||||
jira_issue_key=issue_key,
|
||||
jira_issue_id=issue_id,
|
||||
jira_project_key=project_key,
|
||||
sync_direction=JiraSyncDirection.aegis_to_jira,
|
||||
created_by=actor.id,
|
||||
)
|
||||
db.add(link)
|
||||
db.flush()
|
||||
|
||||
logger.info("Auto-created Jira issue %s for test %s", issue_key, test.id)
|
||||
return issue_key
|
||||
|
||||
except Exception as exc:
|
||||
# Non-fatal: Jira failures must never break the test creation flow
|
||||
logger.warning(
|
||||
"Failed to auto-create Jira issue for test %s: %s",
|
||||
test.id, exc, exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def push_test_event(
|
||||
db: Session,
|
||||
test: Test,
|
||||
actor: User,
|
||||
new_state: str,
|
||||
*,
|
||||
extra: dict | None = None,
|
||||
) -> None:
|
||||
"""Post a lifecycle comment to the Jira issue linked to *test*.
|
||||
|
||||
Called from ``test_workflow_service`` after every state transition.
|
||||
Completely non-fatal — any Jira error is logged and swallowed so it
|
||||
never blocks the test workflow.
|
||||
"""
|
||||
if not has_jira_configured(actor, db):
|
||||
return
|
||||
|
||||
link = (
|
||||
db.query(JiraLink)
|
||||
.filter(
|
||||
JiraLink.entity_type == JiraLinkEntityType.test,
|
||||
JiraLink.entity_id == test.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not link:
|
||||
return
|
||||
|
||||
try:
|
||||
jira = get_user_jira_client(actor, db)
|
||||
comment = _build_state_comment(test, new_state, actor, extra)
|
||||
jira.issue_add_comment(link.jira_issue_key, comment)
|
||||
|
||||
# When the operator starts execution: transition to "In Progress"
|
||||
# and assign the ticket to that operator.
|
||||
if new_state == "red_executing":
|
||||
try:
|
||||
jira.set_issue_status(link.jira_issue_key, "In Progress")
|
||||
logger.info(
|
||||
"Transitioned Jira ticket %s to In Progress", link.jira_issue_key
|
||||
)
|
||||
except Exception as exc_t:
|
||||
logger.warning(
|
||||
"Could not transition %s to In Progress: %s",
|
||||
link.jira_issue_key, exc_t,
|
||||
)
|
||||
jira_account_id = getattr(actor, "jira_account_id", None)
|
||||
if jira_account_id:
|
||||
try:
|
||||
jira.assign_issue(link.jira_issue_key, account_id=jira_account_id)
|
||||
logger.info(
|
||||
"Assigned Jira ticket %s to account %s",
|
||||
link.jira_issue_key, jira_account_id,
|
||||
)
|
||||
except Exception as exc_a:
|
||||
logger.warning(
|
||||
"Could not assign %s to %s: %s",
|
||||
link.jira_issue_key, jira_account_id, exc_a,
|
||||
)
|
||||
|
||||
link.last_synced_at = datetime.utcnow()
|
||||
db.flush()
|
||||
logger.info(
|
||||
"Posted Jira comment to %s for test %s state=%s",
|
||||
link.jira_issue_key, test.id, new_state,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to push Jira event for test %s (state=%s): %s",
|
||||
test.id, new_state, exc, exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy / generic helpers (kept for existing routes)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_jira_client():
|
||||
"""Return a lazily-initialised Jira client, or raise if disabled."""
|
||||
global _jira_client
|
||||
"""Return a shared Jira client using global credentials (legacy path).
|
||||
|
||||
Raises ``InvalidOperationError`` when Jira is disabled or unconfigured.
|
||||
Prefer ``get_user_jira_client()`` for new code.
|
||||
"""
|
||||
if not settings.JIRA_ENABLED:
|
||||
raise InvalidOperationError("Jira integration is not enabled")
|
||||
if _jira_client is None:
|
||||
from atlassian import Jira
|
||||
|
||||
_jira_client = Jira(
|
||||
url=settings.JIRA_URL,
|
||||
username=settings.JIRA_USERNAME,
|
||||
password=settings.JIRA_API_TOKEN,
|
||||
cloud=settings.JIRA_IS_CLOUD,
|
||||
if not settings.JIRA_URL or not settings.JIRA_USERNAME or not settings.JIRA_API_TOKEN:
|
||||
raise InvalidOperationError(
|
||||
"Jira is enabled but JIRA_URL / JIRA_USERNAME / JIRA_API_TOKEN are not set"
|
||||
)
|
||||
return _jira_client
|
||||
from atlassian import Jira
|
||||
|
||||
return Jira(
|
||||
url=settings.JIRA_URL,
|
||||
username=settings.JIRA_USERNAME,
|
||||
password=settings.JIRA_API_TOKEN,
|
||||
cloud=settings.JIRA_IS_CLOUD,
|
||||
)
|
||||
|
||||
|
||||
def search_jira_issues(query: str, max_results: int = 10) -> list[dict]:
|
||||
"""Search Jira issues by JQL or free text."""
|
||||
"""Search Jira issues by JQL or free text (uses global credentials)."""
|
||||
jira = get_jira_client()
|
||||
jql = query if "=" in query or "~" in query else f'summary ~ "{query}"'
|
||||
results = jira.jql(jql, limit=max_results)
|
||||
@@ -62,7 +664,7 @@ def create_jira_issue(
|
||||
labels: Optional[list[str]] = None,
|
||||
custom_fields: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Create a Jira issue and return its key + id."""
|
||||
"""Create a Jira issue and return its key + id (uses global credentials)."""
|
||||
jira = get_jira_client()
|
||||
fields: dict = {
|
||||
"project": {"key": project_key},
|
||||
@@ -80,7 +682,7 @@ def create_jira_issue(
|
||||
|
||||
|
||||
def sync_jira_to_aegis(db: Session, link: JiraLink) -> None:
|
||||
"""Pull current status from Jira into the local link record."""
|
||||
"""Pull current status from Jira into the local link record (global creds)."""
|
||||
jira = get_jira_client()
|
||||
issue = jira.issue(link.jira_issue_key)
|
||||
fields = issue.get("fields", {})
|
||||
@@ -93,7 +695,7 @@ def sync_jira_to_aegis(db: Session, link: JiraLink) -> None:
|
||||
|
||||
|
||||
def sync_aegis_to_jira(db: Session, link: JiraLink, entity_data: dict) -> None:
|
||||
"""Push an Aegis status update as a Jira comment."""
|
||||
"""Push an Aegis status update as a Jira comment (global creds)."""
|
||||
jira = get_jira_client()
|
||||
comment_body = _build_sync_comment(entity_data)
|
||||
jira.issue_add_comment(link.jira_issue_key, comment_body)
|
||||
@@ -102,7 +704,6 @@ def sync_aegis_to_jira(db: Session, link: JiraLink, entity_data: dict) -> None:
|
||||
|
||||
|
||||
def _build_sync_comment(data: dict) -> str:
|
||||
"""Build a formatted Jira comment from entity data."""
|
||||
lines = ["h3. Aegis Sync Update", ""]
|
||||
for key, value in data.items():
|
||||
lines.append(f"*{key}:* {value}")
|
||||
@@ -110,7 +711,7 @@ def _build_sync_comment(data: dict) -> str:
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ── Link CRUD ────────────────────────────────────────────────────────
|
||||
# ── Link CRUD ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_link(
|
||||
@@ -122,7 +723,6 @@ def create_link(
|
||||
sync_direction: JiraSyncDirection,
|
||||
created_by: UUID,
|
||||
) -> JiraLink:
|
||||
"""Create a Jira link and optionally pull initial data from Jira."""
|
||||
link = JiraLink(
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
@@ -147,18 +747,19 @@ def list_links(
|
||||
*,
|
||||
entity_type: Optional[JiraLinkEntityType] = None,
|
||||
entity_id: Optional[UUID] = None,
|
||||
entity_ids: Optional[list[UUID]] = None,
|
||||
) -> list[JiraLink]:
|
||||
"""List Jira links with optional filters."""
|
||||
query = db.query(JiraLink)
|
||||
if entity_type:
|
||||
query = query.filter(JiraLink.entity_type == entity_type)
|
||||
if entity_id:
|
||||
query = query.filter(JiraLink.entity_id == entity_id)
|
||||
elif entity_ids:
|
||||
query = query.filter(JiraLink.entity_id.in_(entity_ids))
|
||||
return query.order_by(JiraLink.created_at.desc()).all()
|
||||
|
||||
|
||||
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
|
||||
"""Get a Jira link by ID or raise EntityNotFoundError."""
|
||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
||||
if not link:
|
||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
||||
@@ -166,23 +767,23 @@ def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
|
||||
|
||||
|
||||
def delete_link(db: Session, link_id: UUID) -> JiraLink:
|
||||
"""Delete a Jira link. Returns the deleted link (for audit)."""
|
||||
link = get_link_or_raise(db, link_id)
|
||||
db.delete(link)
|
||||
return link
|
||||
|
||||
|
||||
def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UUID) -> tuple[str, str]:
|
||||
def build_issue_data(
|
||||
db: Session, entity_type: JiraLinkEntityType, entity_id: UUID
|
||||
) -> tuple[str, str]:
|
||||
"""Build Jira issue summary and description from an Aegis entity."""
|
||||
if entity_type == JiraLinkEntityType.test:
|
||||
entity = db.query(Test).filter(Test.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Test", str(entity_id))
|
||||
technique = db.query(Technique).filter(Technique.id == entity.technique_id).first()
|
||||
return (
|
||||
f"[Aegis Test] {entity.name}",
|
||||
f"Test: {entity.name}\n"
|
||||
f"State: {entity.state.value if entity.state else 'draft'}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
f"[Aegis] {technique.mitre_id if technique else 'N/A'} — {entity.name}",
|
||||
_build_test_description(entity, technique),
|
||||
)
|
||||
elif entity_type == JiraLinkEntityType.campaign:
|
||||
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
||||
@@ -190,8 +791,7 @@ def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UU
|
||||
raise EntityNotFoundError("Campaign", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Campaign] {entity.name}",
|
||||
f"Campaign: {entity.name}\n"
|
||||
f"Type: {entity.type}\nStatus: {entity.status}\n"
|
||||
f"Campaign: {entity.name}\nType: {entity.type}\nStatus: {entity.status}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
elif entity_type == JiraLinkEntityType.technique:
|
||||
@@ -215,10 +815,11 @@ def create_issue_and_link(
|
||||
entity_id: UUID,
|
||||
created_by: UUID,
|
||||
) -> dict:
|
||||
"""Create a Jira issue from an Aegis entity and link them."""
|
||||
"""Create a Jira issue from an Aegis entity and link them (global creds)."""
|
||||
summary, description = build_issue_data(db, entity_type, entity_id)
|
||||
project_key = settings.JIRA_DEFAULT_PROJECT
|
||||
result = create_jira_issue(
|
||||
project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||
project_key=project_key,
|
||||
summary=summary,
|
||||
description=description,
|
||||
labels=["aegis", entity_type.value],
|
||||
@@ -228,7 +829,7 @@ def create_issue_and_link(
|
||||
entity_id=entity_id,
|
||||
jira_issue_key=result["issue_key"],
|
||||
jira_issue_id=result["issue_id"],
|
||||
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||
jira_project_key=project_key,
|
||||
created_by=created_by,
|
||||
)
|
||||
db.add(link)
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
"""Phase 11: Lesson Learned service."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.knowledge import LessonLearned
|
||||
|
||||
|
||||
def _get_or_404(db: Session, ll_id: UUID) -> LessonLearned:
|
||||
ll = db.query(LessonLearned).filter(
|
||||
LessonLearned.id == ll_id,
|
||||
LessonLearned.is_active == True,
|
||||
).first()
|
||||
if not ll:
|
||||
raise EntityNotFoundError("LessonLearned", str(ll_id))
|
||||
return ll
|
||||
|
||||
|
||||
# ── Read ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def get_lesson_learned(db: Session, ll_id: UUID) -> LessonLearned:
|
||||
return _get_or_404(db, ll_id)
|
||||
|
||||
|
||||
def list_lessons_learned(
|
||||
db: Session,
|
||||
entity_type: Optional[str] = None,
|
||||
entity_id: Optional[UUID] = None,
|
||||
severity: Optional[str] = None,
|
||||
tag: Optional[str] = None,
|
||||
technique_id: Optional[str] = None,
|
||||
include_inactive: bool = False,
|
||||
) -> List[LessonLearned]:
|
||||
q = db.query(LessonLearned)
|
||||
if not include_inactive:
|
||||
q = q.filter(LessonLearned.is_active == True)
|
||||
if entity_type:
|
||||
q = q.filter(LessonLearned.entity_type == entity_type)
|
||||
if entity_id:
|
||||
q = q.filter(LessonLearned.entity_id == entity_id)
|
||||
if severity:
|
||||
q = q.filter(LessonLearned.severity == severity)
|
||||
if tag:
|
||||
# JSONB contains operator — filter lessons that have this tag
|
||||
q = q.filter(LessonLearned.tags.contains([tag]))
|
||||
if technique_id:
|
||||
q = q.filter(LessonLearned.technique_ids.contains([technique_id]))
|
||||
return q.order_by(LessonLearned.created_at.desc()).all()
|
||||
|
||||
|
||||
# ── Create ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_lesson_learned(db: Session, data: dict, user_id: UUID) -> LessonLearned:
|
||||
ll = LessonLearned(
|
||||
title = data["title"],
|
||||
what_happened = data.get("what_happened", ""),
|
||||
root_cause = data.get("root_cause", ""),
|
||||
fix_applied = data.get("fix_applied"),
|
||||
severity = data.get("severity", "medium"),
|
||||
entity_type = data.get("entity_type", "manual"),
|
||||
entity_id = data.get("entity_id"),
|
||||
technique_ids = data.get("technique_ids") or [],
|
||||
tags = data.get("tags") or [],
|
||||
created_by = user_id,
|
||||
)
|
||||
db.add(ll)
|
||||
db.commit()
|
||||
db.refresh(ll)
|
||||
return ll
|
||||
|
||||
|
||||
# ── Update ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def update_lesson_learned(db: Session, ll_id: UUID, data: dict, user_id: UUID) -> LessonLearned:
|
||||
ll = _get_or_404(db, ll_id)
|
||||
|
||||
for field in ("title", "what_happened", "root_cause", "fix_applied",
|
||||
"severity", "technique_ids", "tags"):
|
||||
if field in data and data[field] is not None:
|
||||
setattr(ll, field, data[field])
|
||||
|
||||
ll.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(ll)
|
||||
return ll
|
||||
|
||||
|
||||
# ── Delete (soft) ─────────────────────────────────────────────────────────────
|
||||
|
||||
def delete_lesson_learned(db: Session, ll_id: UUID, user_id: UUID) -> None:
|
||||
"""Soft-delete — admin-only enforcement is done at the router level."""
|
||||
ll = _get_or_404(db, ll_id)
|
||||
ll.is_active = False
|
||||
ll.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
|
||||
# ── Stats ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def get_knowledge_stats(db: Session) -> dict:
|
||||
"""Summary counts for the knowledge base dashboard."""
|
||||
from app.models.knowledge import Playbook
|
||||
|
||||
total_playbooks = db.query(Playbook).filter(Playbook.is_active == True).count()
|
||||
total_lessons = db.query(LessonLearned).filter(LessonLearned.is_active == True).count()
|
||||
|
||||
severity_counts: dict = {}
|
||||
for sev in ("critical", "high", "medium", "low", "info"):
|
||||
severity_counts[sev] = (
|
||||
db.query(LessonLearned)
|
||||
.filter(LessonLearned.is_active == True, LessonLearned.severity == sev)
|
||||
.count()
|
||||
)
|
||||
|
||||
playbook_type_counts: dict = {}
|
||||
from app.schemas.knowledge_schema import VALID_PLAYBOOK_TYPES
|
||||
for ptype in VALID_PLAYBOOK_TYPES:
|
||||
playbook_type_counts[ptype] = (
|
||||
db.query(Playbook)
|
||||
.filter(Playbook.is_active == True, Playbook.playbook_type == ptype)
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"total_playbooks": total_playbooks,
|
||||
"total_lessons": total_lessons,
|
||||
"lessons_by_severity": severity_counts,
|
||||
"playbooks_by_type": playbook_type_counts,
|
||||
}
|
||||
@@ -39,6 +39,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -295,6 +296,7 @@ def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict:
|
||||
|
||||
created = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
for item in items:
|
||||
if item["atomic_test_id"] in existing_ids:
|
||||
@@ -315,8 +317,14 @@ def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict:
|
||||
)
|
||||
db.add(template)
|
||||
existing_ids.add(item["atomic_test_id"])
|
||||
new_technique_ids.add(item["mitre_technique_id"])
|
||||
created += 1
|
||||
|
||||
if new_technique_ids:
|
||||
db.query(Technique).filter(
|
||||
Technique.mitre_id.in_(new_technique_ids)
|
||||
).update({"review_required": True}, synchronize_session=False)
|
||||
|
||||
db.commit()
|
||||
return {"created": created, "skipped_existing": skipped, "total_parsed": len(items)}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user