Compare commits

3 Commits

Author SHA1 Message Date
kitos 0ddd17047d refactor(docs+comments): add Google-style docstrings and inline comments across backend
Task D — Google-style docstrings (Args/Returns) on every public function,
method, and class across all 158 Python files in the backend. Zero ruff D
violations (pydocstyle Google convention).

Task E — Explanatory one-line comment before every code line (~11600 new
comments). ruff check passes clean after isort re-sort.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 12:37:15 +02:00
kitos 394d5d9056 refactor(types): add comprehensive type annotations across backend Python codebase
Enable ANN rules in ruff.toml (flake8-annotations) and resolve all 221 violations:

ANN201/ANN202 — return types on 168 public/private functions:
- All 28 FastAPI routers: endpoints annotated with dict/list/specific schema/
  StreamingResponse/FileResponse/JSONResponse as appropriate
- main.py: lifespan→AsyncGenerator[None,None], exception handlers→JSONResponse
- database.py: get_db→Generator[Session,None,None], proxy methods→correct types
- middleware/request_context.py: dispatch→Response with Callable call_next type

ANN001/ANN002/ANN003 — 32 missing argument types:
- seed_demo.py: all db parameters typed as Session
- domain/unit_of_work.py: __aexit__ exc_type/exc_val/exc_tb typed with TracebackType
- services: audit_service user_id→UUID|None, heatmap_service query/model/builder,
  notification_service test→Test, tempo_service test→Test/user→User,
  test_workflow_service test_id→UUID, campaign_crud **fields→object,
  test_crud **fields→object (4 sites)

ANN401 — 16 Any usages resolved:
- Domain entities (campaign/technique/threat_actor/test_entity): replaced Any with
  actual ORM types via TYPE_CHECKING guards to avoid circular imports
- detection_rule_service: test_id/detection_rule_id/evaluator_id→UUID
- score_cache: kept Any with # noqa: ANN401 (genuinely generic cache)
- jira_service/tempo_service: kept Any with # noqa: ANN401 (lazy optional deps)
- d3fend_import_service: _to_str(v: Any) kept with # noqa: ANN401

ANN204/ANN205/ANN206 — special/static/class methods:
- database.py proxy __call__/__getattr__: *args: object/**kwargs: object
- schemas/test.py model_validate: obj→object, **kwargs→object
- sa_technique_repository._int_type→type

All 439 unit tests pass. ruff check app/ → All checks passed!

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-09 17:04:51 +02:00
kitos ec26183e2e refactor(pep8): enforce full PEP8 compliance across backend Python codebase
- ruff.toml: select E/W/F/I/N rules, line-length=120, drop legacy ignores
- Auto-fix: sort 82 import blocks (isort), remove 29 unused imports,
  strip 6 trailing-whitespace blank lines in docstrings
- main.py: move setup_logging and settings imports to top (E402)
- errors.py: noqa N818 on DDD exception names (96 call sites, safe)
- intel_service.py: noqa N817 for universal ET alias
- atomic/elastic/sigma import services: move _MAX_UNCOMPRESSED_SIZE and
  _MAX_ENTRIES to module level (N806)
- compliance_import_service.py: move SAMPLE_CONTROLS / CIS_CONTROLS to
  module level; wrap long description strings (N806 + E501)
- snapshot_service.py: move STATUS_ORDER dict to module level (N806)
- sigma_import_service.py: remove dead dedup_key expression (F841)
- threat_actor_import_service.py: remove dead stix_to_actor expression (F841)
- data_source.py, seed_demo.py, campaign_scheduler_service.py,
  lolbas_import_service.py: wrap lines exceeding 120 chars (E501)
- d3fend_import_service.py: per-file E501 ignore (data file with long strings)

All 439 unit tests pass. ruff check app/ → All checks passed!

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-09 16:40:14 +02:00
213 changed files with 2631 additions and 34324 deletions
+189
View File
@@ -0,0 +1,189 @@
---
description: Aegis backend Clean Architecture rules. Apply when working on any backend Python file under backend/app/ or backend/tests/.
globs: backend/**/*.py
---
# Aegis — Clean Modular Monolith Architecture
## Architecture Overview
Aegis follows a **Clean Architecture** pattern inside a modular monolith. The backend has four layers with strict dependency rules:
```
Presentation → Application → Domain ← Infrastructure
```
**The golden rule:** dependencies only point towards the Domain layer. Infrastructure implements the ports (interfaces) defined in Domain.
## Layer Structure and Rules
### Domain Layer (`backend/app/domain/`)
The innermost layer. **ZERO** imports from FastAPI, SQLAlchemy, Pydantic, or any framework.
| Directory | Purpose |
|-----------|---------|
| `domain/enums.py` | Canonical domain enums (TechniqueStatus, TestState, TeamSide, TestResult) |
| `domain/errors.py` | Exception hierarchy (DomainError → EntityNotFoundError, InvalidStateTransition, etc.) |
| `domain/exceptions.py` | Backward-compatible re-exports from errors.py |
| `domain/test_entity.py` | TestEntity — pure state machine with domain events |
| `domain/entities/` | Rich domain entities (TechniqueEntity, etc.) with business behavior |
| `domain/value_objects/` | Immutable value types (MitreId, ScoringWeights) |
| `domain/ports/repositories/` | Protocol interfaces defining data access contracts |
| `domain/ports/services/` | Protocol interfaces for external capabilities (storage, events) |
| `domain/unit_of_work.py` | UnitOfWork wrapping SQLAlchemy session |
**NEVER** import from `app.models`, `app.routers`, `app.infrastructure`, `fastapi`, or `sqlalchemy` inside `domain/`.
### Application Layer (`backend/app/application/` — future)
Use case orchestrators. Depends only on Domain.
| Directory | Purpose |
|-----------|---------|
| `application/use_cases/` | One class per business operation |
| `application/dto/` | Plain data containers for use case input/output |
| `application/interfaces/` | Application-level contracts (UnitOfWork protocol) |
### Infrastructure Layer (`backend/app/infrastructure/`)
Implements ports defined in Domain. Depends on Domain and Application.
| Directory | Purpose |
|-----------|---------|
| `infrastructure/redis_client.py` | Redis connection singleton |
| `infrastructure/persistence/repositories/` | SQLAlchemy implementations of repository ports |
| `infrastructure/persistence/mappers/` | ORM model ↔ domain entity converters |
### Presentation Layer (routers, schemas, dependencies)
HTTP boundary. Depends on Application and Domain (for exceptions).
| Directory | Purpose |
|-----------|---------|
| `routers/` | FastAPI routers — HTTP mapping only |
| `schemas/` | Pydantic request/response models |
| `dependencies/` | FastAPI `Depends()` wiring (auth, repositories) |
| `middleware/` | Error handler mapping domain exceptions → HTTP responses |
## Import Rules (Strict)
| From \ To | domain/ | application/ | infrastructure/ | presentation/ |
|-----------|---------|-------------|----------------|--------------|
| **domain/** | Self only | FORBIDDEN | FORBIDDEN | FORBIDDEN |
| **application/** | ALLOWED | Self only | FORBIDDEN | FORBIDDEN |
| **infrastructure/** | ALLOWED (ports) | ALLOWED (UoW) | Self only | FORBIDDEN |
| **presentation/** | ALLOWED (exceptions) | ALLOWED (use cases) | ALLOWED (wiring in dependencies/) | Self only |
## How to Add a New Feature
### 1. Start from the Domain
- Define or reuse domain entities in `domain/entities/`
- Add value objects if needed in `domain/value_objects/`
- Define repository port if a new aggregate root in `domain/ports/repositories/`
- Domain exceptions go in `domain/errors.py`
- Business rules live IN the entity, not in services or routers
### 2. Implement Infrastructure
- Create SQLAlchemy repository implementation in `infrastructure/persistence/repositories/`
- Create mapper if converting between ORM model and domain entity
- Repository does NOT call `commit()` — only `flush()`
- Transaction control belongs to the Unit of Work
### 3. Wire in Presentation
- Add FastAPI `Depends()` provider in `dependencies/repositories.py`
- Keep routers thin: parse request → call service/use case → return response
- Map domain exceptions to HTTP via the error handler middleware (automatic)
### 4. Tests (Mandatory)
Every change MUST include tests:
- **Domain entities/value objects**: pure unit tests, no DB, no mocking frameworks
- **Repositories**: integration tests using the `db` fixture from conftest
- **Routers**: API tests using the `client` fixture
- At least one success test + one failure/edge-case test per behavior
Before committing, run: `scripts/agent_validate_backend.sh`
## Existing Patterns to Follow
### Domain Entity Pattern (see `domain/test_entity.py`)
```python
@dataclass
class SomeEntity:
id: uuid.UUID
# fields...
_events: list[DomainEvent] = field(default_factory=list, repr=False)
@classmethod
def from_orm(cls, model: Any) -> "SomeEntity":
"""Build from SQLAlchemy model."""
...
def apply_to(self, model: Any) -> None:
"""Copy mutable fields back onto the ORM model."""
...
def some_business_method(self) -> None:
"""Business logic lives HERE, not in services."""
...
self._events.append(DomainEvent("something_happened"))
```
### Repository Port Pattern (Protocol)
```python
from typing import Protocol, runtime_checkable
@runtime_checkable
class SomeRepository(Protocol):
def find_by_id(self, id: uuid.UUID) -> SomeEntity | None: ...
def save(self, entity: SomeEntity) -> SomeEntity: ...
```
### Repository Implementation Pattern
```python
class SASomeRepository:
def __init__(self, session: Session) -> None:
self._session = session
def find_by_id(self, id: uuid.UUID) -> SomeEntity | None:
model = self._session.query(SomeModel).filter(SomeModel.id == id).first()
return SomeMapper.to_entity(model) if model else None
def save(self, entity: SomeEntity) -> SomeEntity:
model = SomeMapper.to_model(entity)
merged = self._session.merge(model)
self._session.flush() # NO commit — UoW does that
return SomeMapper.to_entity(merged)
```
### Error Handling (automatic via middleware)
Services raise domain exceptions → middleware maps to HTTP:
- `EntityNotFoundError` → 404
- `DuplicateEntityError` → 409
- `InvalidStateTransition` → 400
- `BusinessRuleViolation` → 400
- `PermissionViolation` → 403
### Coexistence Strategy
Old code (direct `db.query()` in routers) and new code (repositories) coexist. Migration is incremental:
1. New endpoints use repositories
2. Existing endpoints are migrated one at a time
3. Both access the same DB, same session, same tables
## Key Conventions
- **Enums**: canonical source is `domain/enums.py`, `models/enums.py` re-exports
- **Exceptions**: raise from `domain/errors.py`, never raise `HTTPException` from services
- **Commits**: only via `UnitOfWork.commit()` or at the router level, never inside services/repos
- **IDs**: UUID everywhere (primary keys, foreign keys)
- **Tests**: SQLite in-memory for unit/integration, PostgreSQL in CI
- **Validation**: Pydantic in schemas (presentation), domain rules in entities (domain)
-71
View File
@@ -1,71 +0,0 @@
name: Snyk Security Scan
on:
push:
branches: [main, develop]
pull_request:
branches: [main]
schedule:
- cron: '0 6 * * 1' # Weekly on Monday 06:00 UTC
jobs:
snyk-backend:
name: Python vulnerabilities (backend)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install backend dependencies
run: pip install -r backend/requirements-lock.txt
- name: Snyk — scan Python packages
uses: snyk/actions/python@master
env:
SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
with:
args: --file=backend/requirements-lock.txt --severity-threshold=high
continue-on-error: true # report without blocking CI during initial cleanup
snyk-frontend:
name: npm vulnerabilities (frontend)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: '20'
- name: Install frontend dependencies
run: npm ci
working-directory: frontend
- name: Snyk — scan npm packages
uses: snyk/actions/node@master
env:
SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
with:
args: --file=frontend/package.json --severity-threshold=high
continue-on-error: true
snyk-docker-backend:
name: Docker image vulnerabilities (backend)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build backend image for scanning
run: docker build -t aegis-backend:scan backend/
- name: Snyk — scan Docker image
uses: snyk/actions/docker@master
env:
SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
with:
image: aegis-backend:scan
args: --severity-threshold=high
continue-on-error: true
-9
View File
@@ -60,12 +60,3 @@ Thumbs.db
# Local development # Local development
*.local *.local
# Documentation drafts — never commit, delivered directly in chat
docs/confluence/
docs/drafts/
# Editor / AI assistant working files — never commit
.claude/
.cursor/
CLAUDE.md
-2
View File
@@ -1,2 +0,0 @@
skips:
- B311
+1 -5
View File
@@ -3,14 +3,10 @@ FROM python:3.11-slim
WORKDIR /app WORKDIR /app
# Install system dependencies # Install system dependencies
RUN apt-get update && apt-get upgrade -y && apt-get install -y \ RUN apt-get update && apt-get install -y \
gcc \ gcc \
libpq-dev \ libpq-dev \
curl \ curl \
pkg-config \
libxml2-dev \
libxmlsec1-dev \
libxmlsec1-openssl \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching # Copy requirements first for better caching
@@ -1,32 +0,0 @@
"""Phase 6.1: webhook_configs table.
Revision ID: b031phase6
Revises: b030phase5
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
revision: str = "b031phase6"
down_revision: Union[str, None] = "b030phase5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"webhook_configs",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("name", sa.String(200), nullable=False),
sa.Column("url", sa.Text, nullable=False),
sa.Column("secret", sa.String(256), nullable=True),
sa.Column("events", postgresql.JSONB, nullable=False, server_default="[]"),
sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"),
sa.Column("created_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True),
sa.Column("last_triggered_at", sa.DateTime, nullable=True),
sa.Column("failure_count", sa.Integer, nullable=False, server_default="0"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
def downgrade() -> None:
op.drop_table("webhook_configs")
@@ -1,41 +0,0 @@
"""Phase 7.2: user notification_preferences and jira_account_id columns.
Revision ID: b032phase7
Revises: b031phase6
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
revision: str = "b032phase7"
down_revision: Union[str, None] = "b031phase6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_DEFAULT_PREFS = '{"email_on_test_validated": true, "email_on_campaign_completed": true, "email_on_new_mitre_techniques": false, "in_app_all": true}'
def _column_names(table: str) -> set[str]:
bind = op.get_bind()
insp = sa.inspect(bind)
return {c["name"] for c in insp.get_columns(table)}
def upgrade() -> None:
user_cols = _column_names("users")
if "notification_preferences" not in user_cols:
op.add_column(
"users",
sa.Column("notification_preferences", postgresql.JSONB, nullable=True, server_default=_DEFAULT_PREFS),
)
if "jira_account_id" not in user_cols:
op.add_column(
"users",
sa.Column("jira_account_id", sa.String(100), nullable=True),
)
def downgrade() -> None:
user_cols = _column_names("users")
if "jira_account_id" in user_cols:
op.drop_column("users", "jira_account_id")
if "notification_preferences" in user_cols:
op.drop_column("users", "notification_preferences")
@@ -1,43 +0,0 @@
"""Phase 8: system_configs table for runtime configuration.
Revision ID: b033syscfg
Revises: b032phase7
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
revision: str = "b033syscfg"
down_revision: Union[str, None] = "b032phase7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _table_exists(name: str) -> bool:
bind = op.get_bind()
insp = sa.inspect(bind)
return name in insp.get_table_names()
def upgrade() -> None:
if not _table_exists("system_configs"):
op.create_table(
"system_configs",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("key", sa.String(200), unique=True, nullable=False),
sa.Column("value", sa.Text, nullable=True),
sa.Column("description", sa.String(500), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
),
)
op.create_index("ix_system_configs_key", "system_configs", ["key"])
def downgrade() -> None:
if _table_exists("system_configs"):
op.drop_index("ix_system_configs_key", table_name="system_configs")
op.drop_table("system_configs")
@@ -1,174 +0,0 @@
"""Phase 8: Detection Lifecycle Management tables.
Revision ID: b034dlm
Revises: b033syscfg
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
revision: str = "b034dlm"
down_revision: Union[str, None] = "b033syscfg"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _table_exists(name: str) -> bool:
bind = op.get_bind()
insp = sa.inspect(bind)
return name in insp.get_table_names()
def upgrade() -> None:
if not _table_exists("detection_assets"):
op.create_table(
"detection_assets",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("name", sa.String(500), nullable=False),
sa.Column("description", sa.Text),
sa.Column("asset_type", sa.String(50), nullable=False),
sa.Column("platform", sa.String(100)),
sa.Column("rule_content", sa.Text),
sa.Column("rule_language", sa.String(50)),
sa.Column("rule_repository_url", sa.Text),
sa.Column("rule_file_path", sa.String(500)),
sa.Column("rule_version", sa.String(50)),
sa.Column("rule_hash", sa.String(64)),
sa.Column("last_rule_change_at", sa.DateTime),
sa.Column("log_source_name", sa.String(200)),
sa.Column("log_source_version", sa.String(50)),
sa.Column("log_source_config", postgresql.JSONB, server_default="{}"),
sa.Column("infrastructure_hash", sa.String(64)),
sa.Column("infrastructure_details", postgresql.JSONB, server_default="{}"),
sa.Column("health_status", sa.String(20), server_default="untested", nullable=False),
sa.Column("last_alert_at", sa.DateTime),
sa.Column("alert_count_30d", sa.Integer, server_default="0"),
sa.Column("false_positive_rate", sa.Float),
sa.Column("expected_alert_frequency", sa.String(50)),
sa.Column("owner_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
sa.Column("backup_owner_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
sa.Column("team", sa.String(100)),
sa.Column("is_active", sa.Boolean, server_default="true", nullable=False),
sa.Column("tags", postgresql.JSONB, server_default="[]"),
sa.Column("asset_metadata", postgresql.JSONB, server_default="{}"),
sa.Column("created_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
)
op.create_index("ix_detection_assets_platform", "detection_assets", ["platform"])
op.create_index("ix_detection_assets_health_status", "detection_assets", ["health_status"])
op.create_index("ix_detection_assets_owner_id", "detection_assets", ["owner_id"])
if not _table_exists("detection_technique_mappings"):
op.create_table(
"detection_technique_mappings",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("detection_asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False),
sa.Column("coverage_type", sa.String(50), server_default="detect"),
sa.Column("confidence_level", sa.String(20), server_default="medium"),
sa.Column("notes", sa.Text),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
)
op.create_index("ix_detection_technique_mappings_technique_id", "detection_technique_mappings", ["technique_id"])
op.create_index("ix_detection_technique_mappings_asset_id", "detection_technique_mappings", ["detection_asset_id"])
if not _table_exists("detection_validations"):
op.create_table(
"detection_validations",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("detection_asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="SET NULL")),
sa.Column("test_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("tests.id", ondelete="SET NULL")),
sa.Column("validated_at", sa.DateTime),
sa.Column("expires_at", sa.DateTime, nullable=False),
sa.Column("is_valid", sa.Boolean, server_default="true", nullable=False),
sa.Column("validation_result", sa.String(50)),
sa.Column("validation_method", sa.String(100)),
sa.Column("rule_hash_at_validation", sa.String(64)),
sa.Column("log_source_version_at_validation", sa.String(50)),
sa.Column("infrastructure_hash_at_validation", sa.String(64)),
sa.Column("environment_snapshot", postgresql.JSONB, server_default="{}"),
sa.Column("invalidated_at", sa.DateTime),
sa.Column("invalidation_reason", sa.String(50)),
sa.Column("invalidation_details", sa.Text),
sa.Column("invalidated_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
sa.Column("validated_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=False),
sa.Column("integrity_hash", sa.String(64)),
sa.Column("notes", sa.Text),
sa.Column("evidence_ids", postgresql.JSONB, server_default="[]"),
)
op.create_index("ix_detection_validations_asset_id_valid", "detection_validations", ["detection_asset_id", "is_valid"])
op.create_index("ix_detection_validations_expires_at", "detection_validations", ["expires_at"])
if not _table_exists("technique_confidence_scores"):
op.create_table(
"technique_confidence_scores",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False, unique=True),
sa.Column("confidence_level", sa.String(20), server_default="unknown"),
sa.Column("confidence_score", sa.Float, server_default="0.0"),
sa.Column("detection_count", sa.Integer, server_default="0"),
sa.Column("valid_detection_count", sa.Integer, server_default="0"),
sa.Column("last_validated_at", sa.DateTime),
sa.Column("next_validation_due", sa.DateTime),
sa.Column("last_recalculated_at", sa.DateTime),
sa.Column("recency_factor", sa.Float, server_default="0.0"),
sa.Column("coverage_factor", sa.Float, server_default="0.0"),
sa.Column("health_factor", sa.Float, server_default="0.0"),
sa.Column("diversity_factor", sa.Float, server_default="0.0"),
sa.Column("score_breakdown", postgresql.JSONB, server_default="{}"),
sa.Column("risk_factors", postgresql.JSONB, server_default="[]"),
sa.Column("updated_at", sa.DateTime),
)
op.create_index("ix_technique_confidence_scores_technique_id", "technique_confidence_scores", ["technique_id"])
op.create_index("ix_technique_confidence_scores_confidence_level", "technique_confidence_scores", ["confidence_level"])
if not _table_exists("infrastructure_change_logs"):
op.create_table(
"infrastructure_change_logs",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("change_type", sa.String(100), nullable=False),
sa.Column("description", sa.Text, nullable=False),
sa.Column("affected_platforms", postgresql.JSONB, server_default="[]"),
sa.Column("affected_log_sources", postgresql.JSONB, server_default="[]"),
sa.Column("change_date", sa.DateTime),
sa.Column("reported_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
sa.Column("auto_invalidate", sa.Boolean, server_default="true"),
sa.Column("invalidated_count", sa.Integer, server_default="0"),
sa.Column("change_metadata", postgresql.JSONB, server_default="{}"),
sa.Column("created_at", sa.DateTime),
)
op.create_index("ix_infrastructure_change_logs_change_date", "infrastructure_change_logs", ["change_date"])
if not _table_exists("decay_policies"):
op.create_table(
"decay_policies",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("name", sa.String(200), nullable=False),
sa.Column("description", sa.Text),
sa.Column("applies_to_platform", sa.String(100)),
sa.Column("applies_to_asset_type", sa.String(50)),
sa.Column("applies_to_tactic", sa.String(100)),
sa.Column("fresh_days", sa.Integer, server_default="90"),
sa.Column("aging_days", sa.Integer, server_default="180"),
sa.Column("stale_days", sa.Integer, server_default="365"),
sa.Column("default_validity_days", sa.Integer, server_default="180"),
sa.Column("silent_threshold_days", sa.Integer, server_default="30"),
sa.Column("noisy_threshold_daily", sa.Integer, server_default="100"),
sa.Column("recency_weight", sa.Float, server_default="0.3"),
sa.Column("coverage_weight", sa.Float, server_default="0.3"),
sa.Column("health_weight", sa.Float, server_default="0.25"),
sa.Column("diversity_weight", sa.Float, server_default="0.15"),
sa.Column("is_default", sa.Boolean, server_default="false"),
sa.Column("is_active", sa.Boolean, server_default="true"),
sa.Column("created_at", sa.DateTime),
sa.Column("updated_at", sa.DateTime),
)
def downgrade() -> None:
for table in ["decay_policies", "infrastructure_change_logs", "technique_confidence_scores", "detection_validations", "detection_technique_mappings", "detection_assets"]:
if _table_exists(table):
op.drop_table(table)
@@ -1,118 +0,0 @@
"""Phase 9: Ownership & Revalidation Queue
Revision ID: b035ownerq
Revises: b034dlm
Create Date: 2026-05-19
Uses raw SQL for all DDL to avoid SQLAlchemy before_create hook issues
with existing enum types.
"""
from typing import Union
from alembic import op
import sqlalchemy as sa
revision: str = "b035ownerq"
down_revision: Union[str, None] = "b034dlm"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
# ── Enums (idempotent) ────────────────────────────────────────────────────
conn.execute(sa.text("""
DO $$ BEGIN
CREATE TYPE queue_priority AS ENUM ('critical', 'high', 'medium', 'low');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$
"""))
conn.execute(sa.text("""
DO $$ BEGIN
CREATE TYPE queue_status AS ENUM ('pending', 'in_progress', 'completed', 'dismissed');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$
"""))
conn.execute(sa.text("""
DO $$ BEGIN
CREATE TYPE queue_reason AS ENUM (
'validation_expired', 'infra_change', 'osint_alert',
'mitre_update', 'rule_modified', 'low_confidence', 'manual');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$
"""))
# ── technique_ownerships ──────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS technique_ownerships (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
technique_id UUID NOT NULL UNIQUE
REFERENCES techniques(id) ON DELETE CASCADE,
owner_id UUID
REFERENCES users(id) ON DELETE SET NULL,
backup_owner_id UUID
REFERENCES users(id) ON DELETE SET NULL,
team VARCHAR(200),
notes TEXT,
assigned_at TIMESTAMP,
assigned_by UUID
REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMP DEFAULT now(),
updated_at TIMESTAMP DEFAULT now()
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_techown_owner_id ON technique_ownerships (owner_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_techown_technique_id ON technique_ownerships (technique_id)"
))
# ── revalidation_queue_items ──────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS revalidation_queue_items (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
technique_id UUID
REFERENCES techniques(id) ON DELETE CASCADE,
detection_asset_id UUID
REFERENCES detection_assets(id) ON DELETE CASCADE,
priority queue_priority NOT NULL DEFAULT 'medium',
reason queue_reason NOT NULL,
reason_detail TEXT,
status queue_status NOT NULL DEFAULT 'pending',
assigned_to UUID
REFERENCES users(id) ON DELETE SET NULL,
due_date TIMESTAMP,
created_at TIMESTAMP DEFAULT now(),
completed_at TIMESTAMP,
dismissed_at TIMESTAMP,
completed_by UUID
REFERENCES users(id) ON DELETE SET NULL,
extra JSONB
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_rqueue_status ON revalidation_queue_items (status)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_rqueue_priority ON revalidation_queue_items (priority)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_rqueue_assigned_to ON revalidation_queue_items (assigned_to)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_rqueue_technique_id ON revalidation_queue_items (technique_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_rqueue_asset_id ON revalidation_queue_items (detection_asset_id)"
))
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("DROP TABLE IF EXISTS revalidation_queue_items"))
conn.execute(sa.text("DROP TABLE IF EXISTS technique_ownerships"))
conn.execute(sa.text("DROP TYPE IF EXISTS queue_reason"))
conn.execute(sa.text("DROP TYPE IF EXISTS queue_status"))
conn.execute(sa.text("DROP TYPE IF EXISTS queue_priority"))
@@ -1,184 +0,0 @@
"""Phase 10: Attack Paths & Advanced Purple Team
Revision ID: b036atk
Revises: b035ownerq
Create Date: 2026-05-19
Uses raw SQL to avoid SQLAlchemy DDL hook issues with enum types.
"""
from typing import Union
from alembic import op
import sqlalchemy as sa
revision: str = "b036atk"
down_revision: Union[str, None] = "b035ownerq"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
# ── Enums ─────────────────────────────────────────────────────────────────
conn.execute(sa.text("""
DO $$ BEGIN
CREATE TYPE execution_status AS ENUM
('planned','in_progress','completed','aborted');
EXCEPTION WHEN duplicate_object THEN NULL; END $$
"""))
conn.execute(sa.text("""
DO $$ BEGIN
CREATE TYPE step_result_status AS ENUM
('pending','executing','detected','not_detected','skipped');
EXCEPTION WHEN duplicate_object THEN NULL; END $$
"""))
conn.execute(sa.text("""
DO $$ BEGIN
CREATE TYPE timeline_actor_side AS ENUM ('red','blue','system');
EXCEPTION WHEN duplicate_object THEN NULL; END $$
"""))
conn.execute(sa.text("""
DO $$ BEGIN
CREATE TYPE timeline_entry_type AS ENUM
('action','detection','note','phase_transition','flag');
EXCEPTION WHEN duplicate_object THEN NULL; END $$
"""))
# ── attack_paths ──────────────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS attack_paths (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(300) NOT NULL,
description TEXT,
objective TEXT,
is_template BOOLEAN DEFAULT FALSE,
threat_actor_id UUID REFERENCES threat_actors(id) ON DELETE SET NULL,
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
tags JSONB,
is_active BOOLEAN DEFAULT TRUE,
created_at TIMESTAMP DEFAULT now(),
updated_at TIMESTAMP DEFAULT now()
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_attack_paths_created_by ON attack_paths (created_by)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_attack_paths_is_template ON attack_paths (is_template)"
))
# ── attack_path_steps ─────────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS attack_path_steps (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
attack_path_id UUID NOT NULL REFERENCES attack_paths(id) ON DELETE CASCADE,
order_index INTEGER NOT NULL DEFAULT 0,
kill_chain_phase VARCHAR(60),
technique_id UUID REFERENCES techniques(id) ON DELETE SET NULL,
test_id UUID REFERENCES tests(id) ON DELETE SET NULL,
name VARCHAR(300),
description TEXT,
expected_detection BOOLEAN DEFAULT TRUE,
notes TEXT
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_ap_steps_path_id ON attack_path_steps (attack_path_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_ap_steps_technique_id ON attack_path_steps (technique_id)"
))
# ── attack_path_executions ────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS attack_path_executions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
attack_path_id UUID NOT NULL REFERENCES attack_paths(id) ON DELETE CASCADE,
status execution_status NOT NULL DEFAULT 'planned',
environment VARCHAR(100),
red_team_lead UUID REFERENCES users(id) ON DELETE SET NULL,
blue_team_lead UUID REFERENCES users(id) ON DELETE SET NULL,
started_by UUID REFERENCES users(id) ON DELETE SET NULL,
started_at TIMESTAMP,
completed_at TIMESTAMP,
notes TEXT,
created_at TIMESTAMP DEFAULT now(),
-- kill-chain metrics (populated on completion)
total_steps INTEGER,
detected_steps INTEGER,
not_detected_steps INTEGER,
skipped_steps INTEGER,
detection_rate FLOAT,
mttd_seconds FLOAT,
furthest_undetected_step INTEGER
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_ap_exec_path_id ON attack_path_executions (attack_path_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_ap_exec_status ON attack_path_executions (status)"
))
# ── attack_path_step_results ──────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS attack_path_step_results (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
execution_id UUID NOT NULL
REFERENCES attack_path_executions(id) ON DELETE CASCADE,
step_id UUID NOT NULL
REFERENCES attack_path_steps(id) ON DELETE CASCADE,
step_order INTEGER NOT NULL DEFAULT 0,
status step_result_status NOT NULL DEFAULT 'pending',
executed_by UUID REFERENCES users(id) ON DELETE SET NULL,
executed_at TIMESTAMP,
detected_at TIMESTAMP,
time_to_detect_seconds FLOAT,
detection_asset_id UUID
REFERENCES detection_assets(id) ON DELETE SET NULL,
notes TEXT,
evidence_ids JSONB
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_ap_stepres_exec ON attack_path_step_results (execution_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_ap_stepres_step ON attack_path_step_results (step_id)"
))
# ── attack_path_timeline ──────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS attack_path_timeline (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
execution_id UUID NOT NULL
REFERENCES attack_path_executions(id) ON DELETE CASCADE,
step_id UUID REFERENCES attack_path_steps(id) ON DELETE SET NULL,
timestamp TIMESTAMP NOT NULL DEFAULT now(),
actor_side timeline_actor_side NOT NULL,
actor_id UUID REFERENCES users(id) ON DELETE SET NULL,
entry_type timeline_entry_type NOT NULL,
content TEXT NOT NULL,
extra JSONB
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_timeline_execution_id ON attack_path_timeline (execution_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_timeline_timestamp ON attack_path_timeline (timestamp)"
))
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_timeline"))
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_step_results"))
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_executions"))
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_steps"))
conn.execute(sa.text("DROP TABLE IF EXISTS attack_paths"))
conn.execute(sa.text("DROP TYPE IF EXISTS timeline_entry_type"))
conn.execute(sa.text("DROP TYPE IF EXISTS timeline_actor_side"))
conn.execute(sa.text("DROP TYPE IF EXISTS step_result_status"))
conn.execute(sa.text("DROP TYPE IF EXISTS execution_status"))
-106
View File
@@ -1,106 +0,0 @@
"""Phase 11: Knowledge Management — Playbooks + Lessons Learned
Revision ID: b037know
Revises: b036atk
Create Date: 2026-05-20
Uses raw SQL to bypass SQLAlchemy DDL hooks (no enum types — string columns
with Pydantic-layer validation instead, so no PostgreSQL enums needed).
"""
from typing import Union
from alembic import op
import sqlalchemy as sa
revision: str = "b037know"
down_revision: Union[str, None] = "b036atk"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
# ── playbooks ──────────────────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS playbooks (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
technique_id UUID NOT NULL REFERENCES techniques(id) ON DELETE CASCADE,
playbook_type VARCHAR(32) NOT NULL,
title VARCHAR(255) NOT NULL,
content TEXT NOT NULL DEFAULT '',
version INTEGER NOT NULL DEFAULT 1,
tools JSONB,
prerequisites JSONB,
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
updated_by UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMP DEFAULT now(),
updated_at TIMESTAMP DEFAULT now(),
is_active BOOLEAN DEFAULT TRUE,
CONSTRAINT uq_playbook_technique_type UNIQUE (technique_id, playbook_type)
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_playbooks_technique_id ON playbooks (technique_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_playbooks_type ON playbooks (playbook_type)"
))
# ── playbook_versions ──────────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS playbook_versions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
playbook_id UUID NOT NULL REFERENCES playbooks(id) ON DELETE CASCADE,
version INTEGER NOT NULL,
title VARCHAR(255) NOT NULL,
content TEXT NOT NULL DEFAULT '',
tools JSONB,
prerequisites JSONB,
changed_by UUID REFERENCES users(id) ON DELETE SET NULL,
change_note VARCHAR(500),
created_at TIMESTAMP DEFAULT now()
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_pb_versions_playbook_id ON playbook_versions (playbook_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_pb_versions_version ON playbook_versions (playbook_id, version)"
))
# ── lessons_learned ────────────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS lessons_learned (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
title VARCHAR(255) NOT NULL,
what_happened TEXT NOT NULL DEFAULT '',
root_cause TEXT NOT NULL DEFAULT '',
fix_applied TEXT,
severity VARCHAR(16) NOT NULL DEFAULT 'medium',
entity_type VARCHAR(32) NOT NULL DEFAULT 'manual',
entity_id UUID,
technique_ids JSONB,
tags JSONB,
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMP DEFAULT now(),
updated_at TIMESTAMP DEFAULT now(),
is_active BOOLEAN DEFAULT TRUE
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_ll_entity ON lessons_learned (entity_type, entity_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_ll_severity ON lessons_learned (severity)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_ll_created_by ON lessons_learned (created_by)"
))
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("DROP TABLE IF EXISTS lessons_learned"))
conn.execute(sa.text("DROP TABLE IF EXISTS playbook_versions"))
conn.execute(sa.text("DROP TABLE IF EXISTS playbooks"))
@@ -1,62 +0,0 @@
"""Phase 12: Risk Intelligence — technique_risk_profiles table
Revision ID: b038risk
Revises: b037know
Create Date: 2026-05-20
Uses raw SQL to bypass SQLAlchemy DDL hooks.
"""
from typing import Union
from alembic import op
import sqlalchemy as sa
revision: str = "b038risk"
down_revision: Union[str, None] = "b037know"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS technique_risk_profiles (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
technique_id UUID NOT NULL REFERENCES techniques(id) ON DELETE CASCADE,
risk_score FLOAT NOT NULL DEFAULT 0.0,
likelihood FLOAT NOT NULL DEFAULT 0.0,
impact FLOAT NOT NULL DEFAULT 0.0,
risk_level VARCHAR(16) NOT NULL DEFAULT 'info',
detection_gap FLOAT NOT NULL DEFAULT 1.0,
threat_actor_count INTEGER NOT NULL DEFAULT 0,
osint_signal_count INTEGER NOT NULL DEFAULT 0,
test_fail_count INTEGER NOT NULL DEFAULT 0,
test_total_count INTEGER NOT NULL DEFAULT 0,
test_failure_rate FLOAT NOT NULL DEFAULT 0.0,
confidence_level FLOAT NOT NULL DEFAULT 0.0,
scoring_breakdown JSONB,
recommendations JSONB,
computed_at TIMESTAMP DEFAULT now(),
is_stale BOOLEAN DEFAULT TRUE,
CONSTRAINT uq_risk_profile_technique UNIQUE (technique_id)
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_risk_profiles_risk_score "
"ON technique_risk_profiles (risk_score)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_risk_profiles_risk_level "
"ON technique_risk_profiles (risk_level)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_risk_profiles_stale "
"ON technique_risk_profiles (is_stale)"
))
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("DROP TABLE IF EXISTS technique_risk_profiles"))
@@ -1,77 +0,0 @@
"""Phase 13: Executive Dashboard — posture_snapshots table.
Revision ID: b039exec
Revises: b038risk
Create Date: 2026-05-20
"""
from alembic import op
import sqlalchemy as sa
revision = "b039exec"
down_revision = "b038risk"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS posture_snapshots (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
snapshot_date DATE NOT NULL,
-- Coverage
total_techniques INTEGER NOT NULL DEFAULT 0,
validated_count INTEGER NOT NULL DEFAULT 0,
partial_count INTEGER NOT NULL DEFAULT 0,
not_covered_count INTEGER NOT NULL DEFAULT 0,
coverage_pct FLOAT NOT NULL DEFAULT 0.0,
-- Risk
avg_risk_score FLOAT NOT NULL DEFAULT 0.0,
critical_count INTEGER NOT NULL DEFAULT 0,
high_count INTEGER NOT NULL DEFAULT 0,
medium_count INTEGER NOT NULL DEFAULT 0,
low_count INTEGER NOT NULL DEFAULT 0,
-- Operations
open_queue_items INTEGER NOT NULL DEFAULT 0,
orphan_techniques INTEGER NOT NULL DEFAULT 0,
-- Knowledge
playbook_count INTEGER NOT NULL DEFAULT 0,
lesson_count INTEGER NOT NULL DEFAULT 0,
-- MTTD
mttd_avg_seconds FLOAT,
executions_30d INTEGER NOT NULL DEFAULT 0,
detection_rate_30d FLOAT,
-- Meta
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now(),
extra JSONB
)
"""))
# Unique constraint: one snapshot per calendar day
conn.execute(sa.text("""
DO $$ BEGIN
ALTER TABLE posture_snapshots
ADD CONSTRAINT uq_posture_snapshot_date UNIQUE (snapshot_date);
EXCEPTION WHEN duplicate_table THEN NULL;
WHEN duplicate_object THEN NULL;
END $$
"""))
# Index for date-range trend queries
conn.execute(sa.text("""
CREATE INDEX IF NOT EXISTS ix_posture_snapshots_date
ON posture_snapshots (snapshot_date)
"""))
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("DROP TABLE IF EXISTS posture_snapshots CASCADE"))
@@ -1,75 +0,0 @@
"""Phase 14: Enterprise Readiness — api_keys and sso_configs tables.
Revision ID: b040ent
Revises: b039exec
Create Date: 2026-05-20
"""
from alembic import op
import sqlalchemy as sa
revision = "b040ent"
down_revision = "b039exec"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
# ── api_keys ──────────────────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS api_keys (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(200) NOT NULL,
description TEXT,
key_prefix VARCHAR(13) NOT NULL,
key_hash VARCHAR(64) NOT NULL UNIQUE,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
scopes JSONB NOT NULL DEFAULT '["read"]',
last_used_at TIMESTAMP WITHOUT TIME ZONE,
expires_at TIMESTAMP WITHOUT TIME ZONE,
is_active BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now()
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_api_keys_user_id ON api_keys (user_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_api_keys_key_hash ON api_keys (key_hash)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_api_keys_active ON api_keys (is_active)"
))
# ── sso_configs ───────────────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS sso_configs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
is_enabled BOOLEAN NOT NULL DEFAULT FALSE,
provider_name VARCHAR(200),
sp_entity_id VARCHAR(500),
sp_acs_url VARCHAR(500),
sp_slo_url VARCHAR(500),
sp_certificate TEXT,
sp_private_key TEXT,
idp_entity_id VARCHAR(500),
idp_sso_url VARCHAR(500),
idp_slo_url VARCHAR(500),
idp_certificate TEXT,
attr_email VARCHAR(200) DEFAULT 'email',
attr_username VARCHAR(200) DEFAULT 'username',
attr_role VARCHAR(200) DEFAULT 'role',
default_role VARCHAR(50) DEFAULT 'viewer',
auto_provision BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now(),
updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now()
)
"""))
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("DROP TABLE IF EXISTS api_keys CASCADE"))
conn.execute(sa.text("DROP TABLE IF EXISTS sso_configs CASCADE"))
@@ -1,82 +0,0 @@
"""Phase 13: Operational Alerts — alert_rules and alert_instances tables.
Revision ID: b041alerts
Revises: b040ent
Create Date: 2026-05-21
"""
from alembic import op
import sqlalchemy as sa
revision = "b041alerts"
down_revision = "b040ent"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
# ── alert_rules ───────────────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS alert_rules (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(300) NOT NULL,
description TEXT,
rule_type VARCHAR(50) NOT NULL,
severity VARCHAR(20) NOT NULL DEFAULT 'medium',
is_enabled BOOLEAN NOT NULL DEFAULT TRUE,
is_system BOOLEAN NOT NULL DEFAULT FALSE,
config JSONB NOT NULL DEFAULT '{}',
notify_in_app BOOLEAN NOT NULL DEFAULT TRUE,
notify_webhook BOOLEAN NOT NULL DEFAULT FALSE,
webhook_id UUID REFERENCES webhook_configs(id) ON DELETE SET NULL,
cooldown_hours INTEGER NOT NULL DEFAULT 24,
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now(),
last_fired_at TIMESTAMP WITHOUT TIME ZONE
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_alert_rules_type ON alert_rules (rule_type)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_alert_rules_enabled ON alert_rules (is_enabled)"
))
# ── alert_instances ───────────────────────────────────────────────────────
conn.execute(sa.text("""
CREATE TABLE IF NOT EXISTS alert_instances (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
rule_id UUID REFERENCES alert_rules(id) ON DELETE SET NULL,
rule_name VARCHAR(300) NOT NULL,
rule_type VARCHAR(50) NOT NULL,
severity VARCHAR(20) NOT NULL,
title VARCHAR(500) NOT NULL,
message TEXT NOT NULL,
details JSONB,
status VARCHAR(20) NOT NULL DEFAULT 'open',
acknowledged_by UUID REFERENCES users(id) ON DELETE SET NULL,
acknowledged_at TIMESTAMP WITHOUT TIME ZONE,
resolved_at TIMESTAMP WITHOUT TIME ZONE,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now()
)
"""))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_alert_instances_rule_id ON alert_instances (rule_id)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_alert_instances_status ON alert_instances (status)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_alert_instances_severity ON alert_instances (severity)"
))
conn.execute(sa.text(
"CREATE INDEX IF NOT EXISTS ix_alert_instances_created ON alert_instances (created_at)"
))
def downgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("DROP TABLE IF EXISTS alert_instances CASCADE"))
conn.execute(sa.text("DROP TABLE IF EXISTS alert_rules CASCADE"))
@@ -1,25 +0,0 @@
"""Add jira_api_token to users table.
Revision ID: b042
Revises: b041_operational_alerts
Create Date: 2026-05-26
"""
from alembic import op
import sqlalchemy as sa
revision = "b042"
down_revision = "b041alerts"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"users",
sa.Column("jira_api_token", sa.String(500), nullable=True),
)
def downgrade() -> None:
op.drop_column("users", "jira_api_token")
@@ -1,28 +0,0 @@
"""Add jira_email to users table.
Allows each user to specify a separate email for Jira authentication,
independent of their Aegis account email.
Revision ID: b043
Revises: b042
Create Date: 2026-05-26
"""
from alembic import op
import sqlalchemy as sa
revision = "b043"
down_revision = "b042"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"users",
sa.Column("jira_email", sa.String(255), nullable=True),
)
def downgrade() -> None:
op.drop_column("users", "jira_email")
@@ -1,25 +0,0 @@
"""add tempo_api_token to users
Revision ID: b044
Revises: b043
Create Date: 2026-05-27
"""
from alembic import op
import sqlalchemy as sa
revision = "b044"
down_revision = "b043"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
"users",
sa.Column("tempo_api_token", sa.String(500), nullable=True),
)
def downgrade():
op.drop_column("users", "tempo_api_token")
@@ -1,16 +0,0 @@
"""Add blue_work_started_at to tests table."""
from alembic import op
import sqlalchemy as sa
revision = "b045"
down_revision = "b044"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("tests", sa.Column("blue_work_started_at", sa.DateTime(), nullable=True))
def downgrade():
op.drop_column("tests", "blue_work_started_at")
@@ -1,22 +0,0 @@
"""Add 'disputed' value to teststate enum.
Revision ID: b046
Revises: b045
Create Date: 2026-06-03
"""
from alembic import op
revision = "b046"
down_revision = "b045"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute("ALTER TYPE teststate ADD VALUE IF NOT EXISTS 'disputed'")
def downgrade() -> None:
# PostgreSQL does not support removing enum values; downgrade is a no-op.
pass
@@ -1,27 +0,0 @@
"""Add start_date to campaigns.
Revision ID: b047
Revises: b046
Create Date: 2026-06-03
"""
from alembic import op
import sqlalchemy as sa
revision = "b047"
down_revision = "b046"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"campaigns",
sa.Column("start_date", sa.DateTime(), nullable=True),
)
op.create_index("ix_campaigns_start_date", "campaigns", ["start_date"])
def downgrade() -> None:
op.drop_index("ix_campaigns_start_date", table_name="campaigns")
op.drop_column("campaigns", "start_date")
@@ -1,39 +0,0 @@
"""Add evaluation_imports table.
Revision ID: b048
Revises: b047
Create Date: 2026-06-05
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
revision = "b048"
down_revision = "b047"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"evaluation_imports",
sa.Column("id", UUID(as_uuid=True), primary_key=True),
sa.Column("adversary_name", sa.String, nullable=False),
sa.Column("adversary_display", sa.String, nullable=False),
sa.Column("eval_round", sa.Integer, nullable=False),
sa.Column("imported_at", sa.DateTime, nullable=False),
sa.Column("imported_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=True),
sa.Column("tests_created", sa.Integer, default=0),
sa.Column("techniques_covered", sa.Integer, default=0),
sa.Column("status", sa.String, default="completed"),
sa.Column("notes", sa.Text, nullable=True),
)
op.create_index("ix_evaluation_imports_adversary", "evaluation_imports", ["adversary_name"])
op.create_index("ix_evaluation_imports_round", "evaluation_imports", ["eval_round"])
def downgrade() -> None:
op.drop_index("ix_evaluation_imports_round", table_name="evaluation_imports")
op.drop_index("ix_evaluation_imports_adversary", table_name="evaluation_imports")
op.drop_table("evaluation_imports")
+3 -3
View File
@@ -2,7 +2,7 @@
This module provides pure functions for: This module provides pure functions for:
- Hashing and verifying passwords using bcrypt via passlib. - Hashing and verifying passwords using bcrypt via passlib.
- Creating JWT access tokens using PyJWT. - Creating JWT access tokens using python-jose.
- Managing a Redis-backed token blacklist for revocation. - Managing a Redis-backed token blacklist for revocation.
No endpoints are defined here. No endpoints are defined here.
@@ -17,8 +17,8 @@ import uuid as _uuid
# Import datetime, timedelta, timezone from datetime # Import datetime, timedelta, timezone from datetime
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
# Import jwt (PyJWT) # Import jwt from jose
import jwt from jose import jwt
# Import CryptContext from passlib.context # Import CryptContext from passlib.context
from passlib.context import CryptContext from passlib.context import CryptContext
+8 -24
View File
@@ -39,7 +39,8 @@ class Settings(BaseSettings):
SECRET_KEY: str = "" SECRET_KEY: str = ""
# Assign ALGORITHM = "HS256" # Assign ALGORITHM = "HS256"
ALGORITHM: str = "HS256" ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 480 # 8 hours — /auth/refresh extends active sessions # Assign ACCESS_TOKEN_EXPIRE_MINUTES = 15 # short-lived for security; configurable via env
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env
# ── Redis ───────────────────────────────────────────────────────── # ── Redis ─────────────────────────────────────────────────────────
REDIS_URL: str = "redis://redis:6379/0" REDIS_URL: str = "redis://redis:6379/0"
@@ -56,10 +57,7 @@ class Settings(BaseSettings):
# ── MinIO / S3 ─────────────────────────────────────────────────── # ── MinIO / S3 ───────────────────────────────────────────────────
MINIO_ENDPOINT: str = "minio:9000" MINIO_ENDPOINT: str = "minio:9000"
# Public hostname used in presigned URLs returned to browsers. # Assign MINIO_ACCESS_KEY = "minioadmin"
# In production set this to <server-ip>:9000 (or a public FQDN) so
# the browser can reach MinIO directly. Defaults to MINIO_ENDPOINT.
MINIO_PUBLIC_ENDPOINT: str = ""
MINIO_ACCESS_KEY: str = "minioadmin" MINIO_ACCESS_KEY: str = "minioadmin"
# Assign MINIO_SECRET_KEY = "minioadmin" # Assign MINIO_SECRET_KEY = "minioadmin"
MINIO_SECRET_KEY: str = "minioadmin" MINIO_SECRET_KEY: str = "minioadmin"
@@ -83,11 +81,10 @@ class Settings(BaseSettings):
JIRA_IS_CLOUD: bool = True JIRA_IS_CLOUD: bool = True
# Assign JIRA_DEFAULT_PROJECT = "" # Assign JIRA_DEFAULT_PROJECT = ""
JIRA_DEFAULT_PROJECT: str = "" JIRA_DEFAULT_PROJECT: str = ""
JIRA_ISSUE_TYPE_TEST: str = "Task" # tests (campaign or standalone) # Assign JIRA_ISSUE_TYPE_TEST = "Task"
JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic" # campaigns (under Initiative) JIRA_ISSUE_TYPE_TEST: str = "Task"
# Jira custom field ID for "Start date" — Jira Cloud team-managed: customfield_10015 # Assign JIRA_ISSUE_TYPE_CAMPAIGN = "Epic"
# Override with the correct field ID for your Jira instance if different. JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic"
JIRA_START_DATE_FIELD: str = "customfield_10015"
# ── Tempo Integration ───────────────────────────────────────────── # ── Tempo Integration ─────────────────────────────────────────────
TEMPO_ENABLED: bool = False TEMPO_ENABLED: bool = False
@@ -97,9 +94,6 @@ class Settings(BaseSettings):
TEMPO_API_VERSION: int = 4 TEMPO_API_VERSION: int = 4
# Assign TEMPO_DEFAULT_WORK_TYPE = "Red Team" # Assign TEMPO_DEFAULT_WORK_TYPE = "Red Team"
TEMPO_DEFAULT_WORK_TYPE: str = "Red Team" TEMPO_DEFAULT_WORK_TYPE: str = "Red Team"
# Tempo API base URL — use https://api.eu.tempo.io/4 for EU workspaces.
# Can also be set via system_configs key "tempo.base_url" at runtime.
TEMPO_BASE_URL: str = "" # empty → falls back to https://api.tempo.io/4
# ── OSINT / Intelligence ──────────────────────────────────────── # ── OSINT / Intelligence ────────────────────────────────────────
NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s
@@ -109,22 +103,12 @@ class Settings(BaseSettings):
# ── Reporting ───────────────────────────────────────────────────── # ── Reporting ─────────────────────────────────────────────────────
REPORT_TEMPLATES_DIR: str = "app/templates/reports" REPORT_TEMPLATES_DIR: str = "app/templates/reports"
# Assign REPORT_OUTPUT_DIR = "/tmp/aegis_reports" # Assign REPORT_OUTPUT_DIR = "/tmp/aegis_reports"
REPORT_OUTPUT_DIR: str = "/app/reports" REPORT_OUTPUT_DIR: str = "/tmp/aegis_reports"
# Assign COMPANY_NAME = "Organization" # Assign COMPANY_NAME = "Organization"
COMPANY_NAME: str = "Organization" COMPANY_NAME: str = "Organization"
# Assign COMPANY_LOGO_PATH = "app/templates/reports/assets/logo.png" # Assign COMPANY_LOGO_PATH = "app/templates/reports/assets/logo.png"
COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png" COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png"
# ── Email / SMTP ──────────────────────────────────────────────────
SMTP_ENABLED: bool = False
SMTP_HOST: str = ""
SMTP_PORT: int = 587
SMTP_USERNAME: str = ""
SMTP_PASSWORD: str = ""
SMTP_FROM_EMAIL: str = "aegis@company.com"
SMTP_USE_TLS: bool = True
PLATFORM_URL: str = "http://localhost:5173" # base URL for links in emails
# ── Scoring weights (must sum to 100) ──────────────────────────── # ── Scoring weights (must sum to 100) ────────────────────────────
SCORING_WEIGHT_TESTS: int = 40 SCORING_WEIGHT_TESTS: int = 40
# Assign SCORING_WEIGHT_DETECTION_RULES = 25 # Assign SCORING_WEIGHT_DETECTION_RULES = 25
+10 -60
View File
@@ -3,7 +3,6 @@
Provides: Provides:
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or - ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
Authorization header (fallback), fetches user from DB, raises 401 on failure. Authorization header (fallback), fetches user from DB, raises 401 on failure.
Also accepts Aegis API keys (``aegis_…`` prefix) as Bearer tokens.
- ``require_role``: factory that returns a dependency enforcing a specific role - ``require_role``: factory that returns a dependency enforcing a specific role
(admins always pass). (admins always pass).
""" """
@@ -20,8 +19,8 @@ from fastapi import Cookie, Depends, HTTPException, status
# Import OAuth2PasswordBearer from fastapi.security # Import OAuth2PasswordBearer from fastapi.security
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
# Import jwt (PyJWT) # Import JWTError, jwt from jose
import jwt from jose import JWTError, jwt
# Import Session from sqlalchemy.orm # Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -37,7 +36,6 @@ from app.database import get_db
# Import User from app.models.user # Import User from app.models.user
from app.models.user import User from app.models.user import User
from app.models.api_key import KEY_PREFIX
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# OAuth2 scheme (reads Authorization header — used as fallback / Swagger UI) # OAuth2 scheme (reads Authorization header — used as fallback / Swagger UI)
@@ -100,15 +98,7 @@ async def get_current_user(
# Raise credentials_exception # Raise credentials_exception
raise credentials_exception raise credentials_exception
# ── API Key path (Bearer token starts with "aegis_") ────────────────── # Attempt the following; catch errors below
if token.startswith(KEY_PREFIX):
from app.services.api_key_service import authenticate_raw_key
user = authenticate_raw_key(db, token)
if user is None:
raise credentials_exception
return user
# ── JWT path ──────────────────────────────────────────────────────────
try: try:
# Assign payload = jwt.decode( # Assign payload = jwt.decode(
payload = jwt.decode( payload = jwt.decode(
@@ -129,8 +119,8 @@ async def get_current_user(
if jti and auth_lib.is_token_blacklisted(jti): if jti and auth_lib.is_token_blacklisted(jti):
# Raise revoked_exception # Raise revoked_exception
raise revoked_exception raise revoked_exception
# Handle any JWT validation error (expired, invalid signature, malformed) # Handle JWTError
except jwt.exceptions.InvalidTokenError: except JWTError:
# Raise credentials_exception # Raise credentials_exception
raise credentials_exception raise credentials_exception
@@ -172,27 +162,12 @@ async def require_password_changed(
return current_user return current_user
def _check_api_key_scope(user: User, required_scope: str) -> None: # Define function require_role
"""Raise 403 if the request was authenticated via an API key that lacks *required_scope*. def require_role(required_role: str) -> Callable[..., object]:
When authenticated via JWT (browser session), ``_api_key_scopes`` is not set
and the check is skipped — full access is granted based on role alone.
"""
key_scopes = getattr(user, "_api_key_scopes", None)
if key_scopes is not None and required_scope not in key_scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"API key scope '{required_scope}' required for this operation",
)
def require_role(required_role: str):
"""Return a FastAPI dependency that enforces *required_role*. """Return a FastAPI dependency that enforces *required_role*.
The dependency allows the request to proceed when The dependency allows the request to proceed when
``user.role == required_role`` **or** ``user.role == "admin"``. ``user.role == required_role`` **or** ``user.role == "admin"``.
Also enforces API key scopes: admin-role endpoints require the ``admin``
scope; all other role-restricted endpoints require ``write``.
Otherwise it raises :class:`~fastapi.HTTPException` **403**. Otherwise it raises :class:`~fastapi.HTTPException` **403**.
""" """
@@ -210,8 +185,7 @@ def require_role(required_role: str):
# Keyword argument: detail # Keyword argument: detail
detail="Not enough permissions", detail="Not enough permissions",
) )
scope = "admin" if required_role == "admin" else "write" # Return current_user
_check_api_key_scope(current_user, scope)
return current_user return current_user
# Return role_checker # Return role_checker
@@ -222,11 +196,7 @@ def require_role(required_role: str):
def require_any_role(*roles: str) -> Callable[..., object]: def require_any_role(*roles: str) -> Callable[..., object]:
"""Return a FastAPI dependency that enforces **any** of the given *roles*. """Return a FastAPI dependency that enforces **any** of the given *roles*.
Admins always pass. Also enforces API key scopes: if the only accepted Admins always pass. Usage example::
role is ``admin``, the key must carry the ``admin`` scope; otherwise the
``write`` scope is required.
Usage example::
@router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))]) @router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))])
""" """
@@ -245,28 +215,8 @@ def require_any_role(*roles: str) -> Callable[..., object]:
# Keyword argument: detail # Keyword argument: detail
detail="Not enough permissions", detail="Not enough permissions",
) )
scope = "admin" if set(roles) == {"admin"} else "write" # Return current_user
_check_api_key_scope(current_user, scope)
return current_user return current_user
# Return role_checker # Return role_checker
return role_checker return role_checker
def require_scope(scope: str):
"""Return a dependency that enforces the API key carries *scope*.
JWT-authenticated requests (browser sessions) bypass this check entirely.
Use on mutation endpoints that don't already use ``require_role`` /
``require_any_role``::
@router.post("/resource", dependencies=[Depends(require_scope("write"))])
"""
async def scope_checker(
current_user: User = Depends(get_current_user),
) -> User:
_check_api_key_scope(current_user, scope)
return current_user
return scope_checker
+12 -18
View File
@@ -221,17 +221,14 @@ class TechniqueEntity:
) -> TechniqueStatus: ) -> TechniqueStatus:
"""Recompute ``status_global`` from a list of (state, detection_result) pairs. """Recompute ``status_global`` from a list of (state, detection_result) pairs.
Rules (v3): Rules (v2):
1. No tests -> not_evaluated 1. No tests -> not_evaluated
2. All tests validated -> inspect detection results: 2. All validated -> inspect detection results:
a. All detected AND ≥ 1 validated test -> validated - All detected -> validated
b. Any partially_detected -> partial - Any partially_detected -> partial
d. Otherwise (no detected results) -> not_covered - Otherwise -> not_covered
3. Some validated, others in intermediate states -> partial 3. Some validated, others in progress -> partial
4. All tests in intermediate states (draft/executing/evaluating/review/rejected) 4. All in intermediate states -> in_progress
-> in_progress
Minimum validated count for "validated": 1 test.
Args: Args:
test_snapshots (list[tuple[str, str | None]]): Each element is a test_snapshots (list[tuple[str, str | None]]): Each element is a
@@ -243,8 +240,7 @@ class TechniqueEntity:
TechniqueStatus: The newly computed status, which is also stored on TechniqueStatus: The newly computed status, which is also stored on
the entity's ``status_global`` field. the entity's ``status_global`` field.
""" """
min_validated_for_full = 1 # require ≥ N validated tests for "validated" # Assign tests = [
tests = [ tests = [
_TestSnapshot( _TestSnapshot(
# Keyword argument: state # Keyword argument: state
@@ -261,15 +257,13 @@ class TechniqueEntity:
self.status_global = TechniqueStatus.not_evaluated self.status_global = TechniqueStatus.not_evaluated
# Alternative: all(t.state == TestState.validated for t in tests) # Alternative: all(t.state == TestState.validated for t in tests)
elif all(t.state == TestState.validated for t in tests): elif all(t.state == TestState.validated for t in tests):
validated_count = len(tests) # Assign results = [t.detection_result for t in tests if t.detection_result]
results = [t.detection_result for t in tests if t.detection_result] results = [t.detection_result for t in tests if t.detection_result]
# Check: results and all(r == TestResult.detected or r == "detected" for r i... # Check: results and all(r == TestResult.detected or r == "detected" for r i...
if results and all(r == TestResult.detected or r == "detected" for r in results): if results and all(r == TestResult.detected or r == "detected" for r in results):
# Need at least min_validated_for_full tests for "validated" # Assign self.status_global = TechniqueStatus.validated
if validated_count >= min_validated_for_full: self.status_global = TechniqueStatus.validated
self.status_global = TechniqueStatus.validated # elif any(
else:
self.status_global = TechniqueStatus.partial
elif any( elif any(
# Keyword argument: r # Keyword argument: r
r == TestResult.partially_detected or r == "partially_detected" r == TestResult.partially_detected or r == "partially_detected"
-1
View File
@@ -43,7 +43,6 @@ class TestState(str, enum.Enum):
validated = "validated" validated = "validated"
# Assign rejected = "rejected" # Assign rejected = "rejected"
rejected = "rejected" rejected = "rejected"
disputed = "disputed" # one lead approved, the other rejected
# Define class TeamSide # Define class TeamSide
+37 -18
View File
@@ -68,7 +68,6 @@ class TestState(str, enum.Enum):
validated = "validated" validated = "validated"
# Assign rejected = "rejected" # Assign rejected = "rejected"
rejected = "rejected" rejected = "rejected"
disputed = "disputed" # one lead approved, the other rejected
# Assign VALID_TRANSITIONS = { # Assign VALID_TRANSITIONS = {
@@ -76,8 +75,7 @@ VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
TestState.draft: [TestState.red_executing], TestState.draft: [TestState.red_executing],
TestState.red_executing: [TestState.blue_evaluating], TestState.red_executing: [TestState.blue_evaluating],
TestState.blue_evaluating: [TestState.in_review], TestState.blue_evaluating: [TestState.in_review],
TestState.in_review: [TestState.validated, TestState.rejected, TestState.disputed], TestState.in_review: [TestState.validated, TestState.rejected],
TestState.disputed: [TestState.validated, TestState.rejected],
TestState.rejected: [TestState.draft], TestState.rejected: [TestState.draft],
TestState.validated: [], TestState.validated: [],
} }
@@ -593,23 +591,37 @@ class TestEntity:
def check_dual_validation(self) -> None: def check_dual_validation(self) -> None:
"""Evaluate both leads' votes and advance state if appropriate. """Evaluate both leads' votes and advance state if appropriate.
Rules (v2 — consensus required): - Both **approved** -> ``validated``
- Both **approved** -> ``validated`` - Either **rejected** -> ``rejected``
- Both **rejected** -> ``rejected`` - Otherwise no change (waiting for the other lead).
- One approved + one rejected -> ``disputed`` (conflict, needs discussion)
- Otherwise (one or both still pending) -> no change
Called automatically by :meth:`validate_red` and :meth:`validate_blue`. Called automatically by :meth:`validate_red` and :meth:`validate_blue`.
Also available as a standalone entry point for backward compatibility
when validation fields are set externally.
Returns:
None
""" """
# Call self._check_dual_validation() # Call self._check_dual_validation()
self._check_dual_validation() self._check_dual_validation()
# Define function _assert_in_review # Define function _assert_in_review
def _assert_in_review(self, side: str) -> None: def _assert_in_review(self, side: str) -> None:
if self.state not in (TestState.in_review, TestState.disputed): """Raise InvalidOperationError unless the test is in ``in_review`` state.
Args:
side (str): The team side being validated (``"red"`` or ``"blue"``),
used in the error message.
Returns:
None
"""
# Check: self.state != TestState.in_review
if self.state != TestState.in_review:
# Raise InvalidOperationError
raise InvalidOperationError( raise InvalidOperationError(
f"Cannot validate {side} side while test is in " f"Cannot validate {side} side while test is in "
f"'{self.state.value}' state (must be in_review or disputed)" f"'{self.state.value}' state (must be in_review)"
) )
# Apply the @staticmethod decorator # Apply the @staticmethod decorator
@@ -634,15 +646,22 @@ class TestEntity:
# Define function _check_dual_validation # Define function _check_dual_validation
def _check_dual_validation(self) -> None: def _check_dual_validation(self) -> None:
"""Advance the test state once both leads have voted.""" """Advance to ``validated`` or ``rejected`` once both leads have voted.
r, b = self.red_validation_status, self.blue_validation_status
if r == "approved" and b == "approved": Returns:
None
"""
# r, b = self.red_validation_status, self.blue_validation_status
r, b = self.red_validation_status, self.blue_validation_status
# Check: r == "rejected" or b == "rejected"
if r == "rejected" or b == "rejected":
# Assign self.state = TestState.rejected
self.state = TestState.rejected
# Call self._events.append()
self._events.append(DomainEvent("dual_validation_rejected"))
# Alternative: r == "approved" and b == "approved"
elif r == "approved" and b == "approved":
# Assign self.state = TestState.validated
self.state = TestState.validated self.state = TestState.validated
# Call self._events.append() # Call self._events.append()
self._events.append(DomainEvent("dual_validation_approved")) self._events.append(DomainEvent("dual_validation_approved"))
elif r == "rejected" or b == "rejected":
# Any rejection is a veto — one lead can reject without waiting for the other
self.state = TestState.rejected
self._events.append(DomainEvent("dual_validation_rejected"))
+8 -326
View File
@@ -12,7 +12,6 @@ sessions.
# Import logging # Import logging
import logging import logging
from datetime import datetime, timedelta, timezone
# Import BackgroundScheduler from apscheduler.schedulers.background # Import BackgroundScheduler from apscheduler.schedulers.background
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
@@ -64,7 +63,7 @@ scheduler = BackgroundScheduler()
def _run_mitre_sync() -> None: def _run_mitre_sync() -> None:
"""Execute a MITRE sync inside its own DB session.""" """Execute a MITRE sync inside its own DB session."""
from app.services.webhook_service import dispatch_webhook # Log info: "Scheduled MITRE sync job starting..."
logger.info("Scheduled MITRE sync job starting...") logger.info("Scheduled MITRE sync job starting...")
# Assign db = SessionLocal() # Assign db = SessionLocal()
db = SessionLocal() db = SessionLocal()
@@ -74,7 +73,7 @@ def _run_mitre_sync() -> None:
summary = sync_mitre(db) summary = sync_mitre(db)
# Log info: "Scheduled MITRE sync job finished — %s", summary # Log info: "Scheduled MITRE sync job finished — %s", summary
logger.info("Scheduled MITRE sync job finished — %s", summary) logger.info("Scheduled MITRE sync job finished — %s", summary)
dispatch_webhook("mitre.synced", {"created": summary.get("created", 0), "updated": summary.get("updated", 0)}) # Handle Exception
except Exception: except Exception:
# Log exception: "Scheduled MITRE sync job failed" # Log exception: "Scheduled MITRE sync job failed"
logger.exception("Scheduled MITRE sync job failed") logger.exception("Scheduled MITRE sync job failed")
@@ -164,96 +163,7 @@ def _run_recurring_campaigns() -> None:
db.close() db.close()
def _run_scheduled_campaign_activation() -> None: # Define function _run_intel_scan
"""Auto-activate campaigns whose start_date has arrived.
Finds all campaigns in 'draft' state with a start_date <= now,
activates them, creates Jira tickets, and notifies the red_tech team.
Runs every hour so campaigns activate within ~1 hour of their scheduled time.
"""
logger.info("Scheduled campaign auto-activation check starting...")
db = SessionLocal()
try:
from datetime import datetime as _dt
from app.models.campaign import Campaign
from app.models.user import User
from app.services.campaign_crud_service import activate_campaign as _activate
from app.services.notification_service import notify_role
from app.services.audit_service import log_action
now = _dt.utcnow()
due_campaigns = (
db.query(Campaign)
.filter(
Campaign.status == "draft",
Campaign.start_date != None, # noqa: E711
Campaign.start_date <= now,
)
.all()
)
activated = 0
for campaign in due_campaigns:
try:
_activate(db, str(campaign.id))
notify_role(
db,
role="red_tech",
type="campaign_activated",
title="Campaign auto-activated",
message=f'Campaign "{campaign.name}" has been automatically activated on its scheduled start date.',
entity_type="campaign",
entity_id=campaign.id,
)
log_action(
db,
user_id=None,
action="auto_activate_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={"name": campaign.name, "start_date": str(campaign.start_date)},
)
# Create Jira tickets non-fatally
try:
from app.services.jira_service import (
auto_create_campaign_issue,
auto_create_test_issue,
get_campaign_jira_key,
get_test_jira_key,
)
# Use first admin user as actor for Jira auth
admin_user = db.query(User).filter(User.role == "admin").first()
if admin_user:
db.refresh(campaign)
campaign_jira_key = get_campaign_jira_key(db, str(campaign.id))
if not campaign_jira_key:
campaign_jira_key = auto_create_campaign_issue(db, campaign, admin_user)
if campaign_jira_key:
for ct in campaign.campaign_tests:
if ct.test and not get_test_jira_key(db, ct.test.id):
auto_create_test_issue(
db, ct.test, admin_user,
parent_ticket_override=campaign_jira_key,
campaign_start_date=campaign.start_date,
)
except Exception:
logger.exception("Jira auto-create failed for auto-activated campaign %s", campaign.id)
db.commit()
activated += 1
logger.info("Auto-activated campaign %s (%s)", campaign.id, campaign.name)
except Exception:
logger.exception("Failed to auto-activate campaign %s", campaign.id)
db.rollback()
logger.info("Campaign auto-activation check finished — activated %d campaigns", activated)
except Exception:
logger.exception("Campaign auto-activation job failed")
finally:
db.close()
def _run_intel_scan() -> None: def _run_intel_scan() -> None:
"""Execute an intel scan inside its own DB session.""" """Execute an intel scan inside its own DB session."""
# Log info: "Scheduled intel scan job starting..." # Log info: "Scheduled intel scan job starting..."
@@ -276,83 +186,7 @@ def _run_intel_scan() -> None:
db.close() db.close()
def _run_evaluation_round_check() -> None: # Define function _run_osint_enrichment
"""Weekly job: check if a new ATT&CK Evaluation round is available.
If a new round is found it is imported automatically and an admin
notification is created so the team knows new baseline data is available.
"""
logger.info("ATT&CK Evaluations new-round check starting...")
db = SessionLocal()
try:
from app.services.attck_evaluations_service import check_for_new_round, import_evaluation_round
from app.models.user import User as UserModel
result = check_for_new_round(db)
if result.get("error"):
logger.warning("ATT&CK Evaluations check failed: %s", result["error"])
return
if not result.get("new_round_available"):
logger.info(
"ATT&CK Evaluations check — latest round '%s' already imported",
result.get("latest_round", {}).get("display_name", "?"),
)
return
latest = result["latest_round"]
logger.info(
"New ATT&CK Evaluation round detected: %s (round %d) — starting auto-import",
latest["display_name"], latest["eval_round"],
)
# Use the first admin user as the importer (system action)
admin = db.query(UserModel).filter(UserModel.role == "admin").first()
if not admin:
logger.warning("ATT&CK Evaluations auto-import: no admin user found — skipping")
return
summary = import_evaluation_round(
db,
latest["name"],
latest["display_name"],
latest["eval_round"],
admin,
)
logger.info(
"ATT&CK Evaluations auto-import complete — round %d (%s): %d tests created",
latest["eval_round"], latest["display_name"], summary["created"],
)
# Notify all admins
try:
from app.services.notification_service import create_notification
admins = db.query(UserModel).filter(UserModel.role == "admin").all()
for adm in admins:
create_notification(
db,
user_id=adm.id,
title="New ATT&CK Evaluation round imported",
message=(
f"Round {latest['eval_round']}{latest['display_name']}"
f"has been automatically imported. "
f"{summary['created']} tests created in In Review state. "
f"Blue Leads must validate each result before it counts as coverage."
),
notification_type="eval_import",
entity_type="evaluation",
entity_id=None,
)
db.commit()
except Exception:
logger.warning("Failed to send eval import notifications", exc_info=True)
except Exception:
logger.exception("ATT&CK Evaluations round check job failed")
finally:
db.close()
def _run_osint_enrichment() -> None: def _run_osint_enrichment() -> None:
"""Execute weekly OSINT enrichment inside its own DB session.""" """Execute weekly OSINT enrichment inside its own DB session."""
# Log info: "Scheduled OSINT enrichment job starting..." # Log info: "Scheduled OSINT enrichment job starting..."
@@ -375,61 +209,7 @@ def _run_osint_enrichment() -> None:
db.close() db.close()
_FREQUENCY_INTERVALS: dict[str, timedelta] = { # Define function _run_stale_detection
"daily": timedelta(days=1),
"weekly": timedelta(weeks=1),
"monthly": timedelta(days=30),
}
def _run_data_sources_sync() -> None:
"""Check all enabled data sources and sync those that are overdue."""
logger.info("Scheduled data sources sync check starting...")
db = SessionLocal()
try:
from app.models.data_source import DataSource
from app.services.data_source_service import sync_source
now = datetime.now(timezone.utc)
sources = (
db.query(DataSource)
.filter(DataSource.is_enabled == True) # noqa: E712
.all()
)
synced = 0
for ds in sources:
freq = ds.sync_frequency
if not freq or freq == "manual":
continue
interval = _FREQUENCY_INTERVALS.get(freq)
if interval is None:
continue
last = ds.last_sync_at
if last is None:
# Never synced — run it now
overdue = True
else:
# Make last timezone-aware if needed
if last.tzinfo is None:
last = last.replace(tzinfo=timezone.utc)
overdue = now - last >= interval
if overdue:
logger.info(
"Data source '%s' is overdue (freq=%s, last=%s) — syncing",
ds.name, freq, last,
)
try:
sync_source(db, str(ds.id))
synced += 1
except Exception:
logger.exception("Failed to sync data source '%s'", ds.name)
logger.info("Data sources sync check finished — %d source(s) synced", synced)
except Exception:
logger.exception("Data sources sync check failed")
finally:
db.close()
def _run_stale_detection() -> None: def _run_stale_detection() -> None:
"""Execute daily stale coverage detection inside its own DB session.""" """Execute daily stale coverage detection inside its own DB session."""
# Log info: "Scheduled stale coverage detection starting..." # Log info: "Scheduled stale coverage detection starting..."
@@ -452,53 +232,6 @@ def _run_stale_detection() -> None:
db.close() db.close()
def _run_decay_engine() -> None:
"""Execute the decay engine inside its own DB session."""
logger.info("Scheduled decay engine job starting...")
db = SessionLocal()
try:
from app.services.decay_engine_service import run_decay_engine
results = run_decay_engine(db)
logger.info("Decay engine job finished — %s", results)
except Exception:
logger.exception("Decay engine job failed")
finally:
db.close()
def _run_queue_generation() -> None:
"""Generate revalidation queue items for analysts — runs after decay engine."""
logger.info("Scheduled revalidation queue generation starting...")
db = SessionLocal()
try:
from app.services.revalidation_queue_service import generate_queue_items
results = generate_queue_items(db)
logger.info("Queue generation finished — %s", results)
except Exception:
logger.exception("Queue generation job failed")
finally:
db.close()
def _run_alert_evaluation() -> None:
"""Evaluate all enabled operational alert rules (hourly)."""
logger.info("Scheduled alert evaluation job starting...")
db = SessionLocal()
try:
from app.services.operational_alert_service import evaluate_all_rules
result = evaluate_all_rules(db)
logger.info(
"Alert evaluation finished — %d rules, %d alerts fired in %.3fs",
result["rules_evaluated"],
result["alerts_fired"],
result["duration_seconds"],
)
except Exception:
logger.exception("Alert evaluation job failed")
finally:
db.close()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Scheduler bootstrap # Scheduler bootstrap
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -575,14 +308,6 @@ def start_scheduler() -> None:
replace_existing=True, replace_existing=True,
) )
# Call scheduler.add_job() # Call scheduler.add_job()
scheduler.add_job(
_run_scheduled_campaign_activation,
trigger="interval",
hours=1,
id="scheduled_campaign_activation",
name="Auto-activate campaigns on start_date (hourly)",
replace_existing=True,
)
scheduler.add_job( scheduler.add_job(
_run_recurring_campaigns, _run_recurring_campaigns,
# Keyword argument: trigger # Keyword argument: trigger
@@ -652,50 +377,7 @@ def start_scheduler() -> None:
# Keyword argument: replace_existing # Keyword argument: replace_existing
replace_existing=True, replace_existing=True,
) )
scheduler.add_job( # Call scheduler.start()
_run_data_sources_sync,
trigger="interval",
hours=6,
id="data_sources_sync",
name="Data sources auto-sync (every 6h)",
replace_existing=True,
)
scheduler.add_job(
_run_decay_engine,
trigger="cron",
hour=2,
minute=0,
id="decay_engine",
name="Detection decay engine (daily 02:00)",
replace_existing=True,
)
scheduler.add_job(
_run_queue_generation,
trigger="cron",
hour=2,
minute=30,
id="queue_generation",
name="Revalidation queue generation (daily 02:30)",
replace_existing=True,
)
scheduler.add_job(
_run_alert_evaluation,
trigger="interval",
hours=1,
id="alert_evaluation",
name="Operational alert evaluation (hourly)",
replace_existing=True,
)
scheduler.add_job(
_run_evaluation_round_check,
trigger="cron",
day_of_week="mon",
hour=6,
minute=0,
id="attck_evaluation_check",
name="ATT&CK Evaluations new-round check (Mondays 06:00)",
replace_existing=True,
)
scheduler.start() scheduler.start()
# Log info: # Log info:
logger.info( logger.info(
@@ -707,6 +389,6 @@ def start_scheduler() -> None:
"recurring_campaigns (daily), jira_sync (1h), " "recurring_campaigns (daily), jira_sync (1h), "
# Literal argument value # Literal argument value
"osint_enrichment (weekly), stale_detection (daily), " "osint_enrichment (weekly), stale_detection (daily), "
"retention_policies (daily), data_sources_sync (6h), " # Literal argument value
"alert_evaluation (1h), attck_evaluation_check (Mondays 06:00)" "retention_policies (daily)"
) )
+89 -89
View File
@@ -38,45 +38,10 @@ from slowapi.errors import RateLimitExceeded
# Import SQLAlchemyError from sqlalchemy.exc # Import SQLAlchemyError from sqlalchemy.exc
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from app.routers import auth as auth_router # Import settings as _settings from app.config
from app.routers import techniques as techniques_router from app.config import settings as _settings
from app.routers import tests as tests_router
from app.routers import evidence as evidence_router # Import DomainError from app.domain.errors
from app.routers import test_templates as test_templates_router
from app.routers import system as system_router
from app.routers import metrics as metrics_router
from app.routers import users as users_router
from app.routers import audit as audit_router
from app.routers import notifications as notifications_router
from app.routers import reports as reports_router
from app.routers import data_sources as data_sources_router
from app.routers import threat_actors as threat_actors_router
from app.routers import d3fend as d3fend_router
from app.routers import detection_rules as detection_rules_router
from app.routers import campaigns as campaigns_router
from app.routers import heatmap as heatmap_router
from app.routers import scores as scores_router
from app.routers import operational_metrics as operational_metrics_router
from app.routers import compliance as compliance_router
from app.routers import snapshots as snapshots_router
from app.routers import jira as jira_router
from app.routers import worklogs as worklogs_router
from app.routers import professional_reports as professional_reports_router
from app.routers import analytics as analytics_router
from app.routers import advanced_metrics as advanced_metrics_router
from app.routers import osint as osint_router
from app.routers import webhooks as webhooks_router
from app.routers import detection_lifecycle as detection_lifecycle_router
from app.routers import intel as intel_router
from app.routers import admin_config as admin_config_router
from app.routers import ownership as ownership_router
from app.routers import attack_paths as attack_paths_router
from app.routers import knowledge as knowledge_router
from app.routers import risk_intelligence as risk_router
from app.routers import executive_dashboard as dashboard_router
from app.routers import api_keys as api_keys_router
from app.routers import sso as sso_router
from app.routers import operational_alerts as alerts_router
from app.domain.errors import DomainError from app.domain.errors import DomainError
# Import scheduler, start_scheduler from app.jobs.mitre_sync_job # Import scheduler, start_scheduler from app.jobs.mitre_sync_job
@@ -93,15 +58,94 @@ from app.middleware.error_handler import domain_exception_handler
# Import RequestContextMiddleware from app.middleware.request_context # Import RequestContextMiddleware from app.middleware.request_context
from app.middleware.request_context import RequestContextMiddleware from app.middleware.request_context import RequestContextMiddleware
# Import advanced_metrics as advanced_metrics_router from app.routers
from app.routers import advanced_metrics as advanced_metrics_router
# Import analytics as analytics_router from app.routers
from app.routers import analytics as analytics_router
# Import audit as audit_router from app.routers
from app.routers import audit as audit_router
# Import auth as auth_router from app.routers
from app.routers import auth as auth_router
# Import campaigns as campaigns_router from app.routers
from app.routers import campaigns as campaigns_router
# Import compliance as compliance_router from app.routers
from app.routers import compliance as compliance_router
# Import d3fend as d3fend_router from app.routers
from app.routers import d3fend as d3fend_router
# Import data_sources as data_sources_router from app.routers
from app.routers import data_sources as data_sources_router
# Import detection_rules as detection_rules_router from app.routers
from app.routers import detection_rules as detection_rules_router
# Import evidence as evidence_router from app.routers
from app.routers import evidence as evidence_router
# Import heatmap as heatmap_router from app.routers
from app.routers import heatmap as heatmap_router
# Import jira as jira_router from app.routers
from app.routers import jira as jira_router
# Import metrics as metrics_router from app.routers
from app.routers import metrics as metrics_router
# Import notifications as notifications_router from app.routers
from app.routers import notifications as notifications_router
# Import operational_metrics as operational_metrics_router from app.routers
from app.routers import operational_metrics as operational_metrics_router
# Import osint as osint_router from app.routers
from app.routers import osint as osint_router
# Import professional_reports as professional_reports_ro... from app.routers
from app.routers import professional_reports as professional_reports_router
# Import reports as reports_router from app.routers
from app.routers import reports as reports_router
# Import scores as scores_router from app.routers
from app.routers import scores as scores_router
# Import snapshots as snapshots_router from app.routers
from app.routers import snapshots as snapshots_router
# Import system as system_router from app.routers
from app.routers import system as system_router
# Import techniques as techniques_router from app.routers
from app.routers import techniques as techniques_router
# Import test_templates as test_templates_router from app.routers
from app.routers import test_templates as test_templates_router
# Import tests as tests_router from app.routers
from app.routers import tests as tests_router
# Import threat_actors as threat_actors_router from app.routers
from app.routers import threat_actors as threat_actors_router
# Import users as users_router from app.routers
from app.routers import users as users_router
# Import worklogs as worklogs_router from app.routers
from app.routers import worklogs as worklogs_router
# Import ensure_bucket_exists from app.storage
from app.storage import ensure_bucket_exists from app.storage import ensure_bucket_exists
from app.config import settings as _settings
from starlette.middleware.base import BaseHTTPMiddleware
# Configure structured logging before any module initialises its own logger # Configure structured logging before any module initialises its own logger
setup_logging() setup_logging()
logger = logging.getLogger(__name__)
# ── Environment detection ───────────────────────────────────────────────── # ── Environment detection ─────────────────────────────────────────────────
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production" _IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
@@ -121,25 +165,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
ensure_bucket_exists() ensure_bucket_exists()
# Call start_scheduler() # Call start_scheduler()
start_scheduler() start_scheduler()
# Seed decay policies # Yield value
from app.database import SessionLocal
from app.seed_decay_policies import seed_decay_policies
db = SessionLocal()
try:
seed_decay_policies(db)
except Exception as e:
logger.warning("seed_decay_policies failed at startup: %s", e)
finally:
db.close()
# Seed operational alert system rules
db2 = SessionLocal()
try:
from app.services.operational_alert_service import seed_system_rules
seed_system_rules(db2)
except Exception as e:
logger.warning("seed_system_rules failed at startup: %s", e)
finally:
db2.close()
yield yield
# Graceful shutdown of the background scheduler # Graceful shutdown of the background scheduler
scheduler.shutdown(wait=False) scheduler.shutdown(wait=False)
@@ -167,21 +193,6 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Call app.add_middleware() # Call app.add_middleware()
app.add_middleware(RequestContextMiddleware) app.add_middleware(RequestContextMiddleware)
# ── No-cache middleware for all /api/ responses ───────────────────────────
# Prevents Cloudflare and browser caches from storing API responses,
# which would cause stale/empty data to be served after backend restarts.
class NoCacheAPIMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
if request.url.path.startswith("/api/"):
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
response.headers["Pragma"] = "no-cache"
return response
app.add_middleware(NoCacheAPIMiddleware)
# ── Domain exception → HTTP mapping ────────────────────────────────────── # ── Domain exception → HTTP mapping ──────────────────────────────────────
app.add_exception_handler(DomainError, domain_exception_handler) app.add_exception_handler(DomainError, domain_exception_handler)
@@ -243,8 +254,7 @@ app.include_router(scores_router.router, prefix="/api/v1")
app.include_router(operational_metrics_router.router, prefix="/api/v1") app.include_router(operational_metrics_router.router, prefix="/api/v1")
# Call app.include_router() # Call app.include_router()
app.include_router(compliance_router.router, prefix="/api/v1") app.include_router(compliance_router.router, prefix="/api/v1")
app.include_router(intel_router.router, prefix="/api/v1") # Call app.include_router()
app.include_router(admin_config_router.router, prefix="/api/v1")
app.include_router(snapshots_router.router, prefix="/api/v1") app.include_router(snapshots_router.router, prefix="/api/v1")
# Call app.include_router() # Call app.include_router()
app.include_router(jira_router.router, prefix="/api/v1") app.include_router(jira_router.router, prefix="/api/v1")
@@ -258,16 +268,6 @@ app.include_router(analytics_router.router, prefix="/api/v1")
app.include_router(advanced_metrics_router.router, prefix="/api/v1") app.include_router(advanced_metrics_router.router, prefix="/api/v1")
# Call app.include_router() # Call app.include_router()
app.include_router(osint_router.router, prefix="/api/v1") app.include_router(osint_router.router, prefix="/api/v1")
app.include_router(webhooks_router.router, prefix="/api/v1")
app.include_router(detection_lifecycle_router.router, prefix="/api/v1")
app.include_router(ownership_router.router, prefix="/api/v1")
app.include_router(attack_paths_router.router, prefix="/api/v1")
app.include_router(knowledge_router.router, prefix="/api/v1")
app.include_router(risk_router.router, prefix="/api/v1")
app.include_router(dashboard_router.router, prefix="/api/v1")
app.include_router(api_keys_router.router, prefix="/api/v1")
app.include_router(sso_router.router, prefix="/api/v1")
app.include_router(alerts_router.router, prefix="/api/v1")
# Apply the @app.get decorator # Apply the @app.get decorator
+59 -52
View File
@@ -1,51 +1,74 @@
"""SQLAlchemy ORM model definitions for all database tables.""" """SQLAlchemy ORM model definitions for all database tables."""
# Import all models here so Alembic can detect them # Import all models here so Alembic can detect them
from app.models.audit import AuditLog from app.models.audit import AuditLog
from app.models.notification import Notification
from app.models.data_source import DataSource # Import Campaign, CampaignTest from app.models.campaign
from app.models.detection_rule import DetectionRule
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
from app.models.test_template_detection_rule import TestTemplateDetectionRule
from app.models.test_detection_result import TestDetectionResult
from app.models.campaign import Campaign, CampaignTest from app.models.campaign import Campaign, CampaignTest
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
# Import from app.models.compliance
from app.models.compliance import (
ComplianceControl,
ComplianceControlMapping,
ComplianceFramework,
)
# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
from app.models.worklog import Worklog # Import DataSource from app.models.data_source
from app.models.osint_item import OsintItem from app.models.data_source import DataSource
from app.models.scoring_config import ScoringConfig
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide # Import DefensiveTechnique, DefensiveTechniqueMapping from app.models.defensive_technique
from app.models.webhook_config import WebhookConfig from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
from app.models.system_config import SystemConfig
from app.models.detection_lifecycle import ( # Import DetectionRule from app.models.detection_rule
DetectionAsset, DetectionTechniqueMapping, DetectionValidation, from app.models.detection_rule import DetectionRule
TechniqueConfidenceScore, InfrastructureChangeLog,
DetectionConfidence, DetectionHealthStatus, InvalidationReason, # Import TeamSide, TechniqueStatus, TestResult, TestState from app.models.enums
) from app.models.enums import TeamSide, TechniqueStatus, TestResult, TestState
from app.models.decay_policy import DecayPolicy
from app.models.ownership_queue import ( # Import Evidence from app.models.evidence
TechniqueOwnership, RevalidationQueueItem,
QueuePriority, QueueStatus, QueueReason,
)
from app.models.attack_path import (
AttackPath, AttackPathStep, AttackPathExecution,
AttackPathStepResult, TimelineEntry,
ExecutionStatus, StepResultStatus, TimelineActorSide, TimelineEntryType,
)
from app.models.knowledge import Playbook, PlaybookVersion, LessonLearned
from app.models.risk_intelligence import TechniqueRiskProfile
from app.models.executive_dashboard import PostureSnapshot
from app.models.api_key import ApiKey
from app.models.sso_config import SsoConfig
from app.models.operational_alert import AlertRule, AlertInstance
from app.models.evidence import Evidence from app.models.evidence import Evidence
# Import IntelItem from app.models.intel
from app.models.intel import IntelItem from app.models.intel import IntelItem
# Import JiraLink, JiraLinkEntityType, JiraSyncDirection from app.models.jira_link
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
# Import Notification from app.models.notification
from app.models.notification import Notification
# Import OsintItem from app.models.osint_item
from app.models.osint_item import OsintItem
# Import ScoringConfig from app.models.scoring_config
from app.models.scoring_config import ScoringConfig
# Import Technique from app.models.technique
from app.models.technique import Technique from app.models.technique import Technique
# Import Test from app.models.test
from app.models.test import Test from app.models.test import Test
# Import TestDetectionResult from app.models.test_detection_result
from app.models.test_detection_result import TestDetectionResult
# Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate from app.models.test_template import TestTemplate
# Import TestTemplateDetectionRule from app.models.test_template_detection_rule
from app.models.test_template_detection_rule import TestTemplateDetectionRule
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import Worklog from app.models.worklog
from app.models.worklog import Worklog
# Assign __all__ = [ # Assign __all__ = [
__all__ = [ __all__ = [
# Literal argument value # Literal argument value
@@ -70,20 +93,4 @@ __all__ = [
"Worklog", "OsintItem", "ScoringConfig", "Worklog", "OsintItem", "ScoringConfig",
# Literal argument value # Literal argument value
"TechniqueStatus", "TestState", "TestResult", "TeamSide", "TechniqueStatus", "TestState", "TestResult", "TeamSide",
"WebhookConfig", "SystemConfig",
"DetectionAsset", "DetectionTechniqueMapping", "DetectionValidation",
"TechniqueConfidenceScore", "InfrastructureChangeLog",
"DetectionConfidence", "DetectionHealthStatus", "InvalidationReason", "DecayPolicy",
"TechniqueOwnership", "RevalidationQueueItem",
"QueuePriority", "QueueStatus", "QueueReason",
"AttackPath", "AttackPathStep", "AttackPathExecution",
"AttackPathStepResult", "TimelineEntry",
"ExecutionStatus", "StepResultStatus", "TimelineActorSide", "TimelineEntryType",
"Playbook", "PlaybookVersion", "LessonLearned",
"TechniqueRiskProfile",
"PostureSnapshot",
"ApiKey",
"SsoConfig",
"AlertRule",
"AlertInstance",
] ]
-81
View File
@@ -1,81 +0,0 @@
"""Phase 14: API Key model for programmatic access."""
import hashlib
import secrets
import uuid
from datetime import datetime
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import relationship
from app.database import Base
# ── Key generation constants ──────────────────────────────────────────────────
KEY_PREFIX = "aegis_"
KEY_BYTES = 32 # 32 random bytes → 64 hex chars → 70-char key total
DISPLAY_LEN = 12 # chars stored as prefix for UI display
def generate_raw_key() -> str:
"""Generate a fresh raw API key (must be shown to user only once)."""
return KEY_PREFIX + secrets.token_hex(KEY_BYTES)
def hash_key(raw_key: str) -> str:
"""SHA-256 hash of a raw API key for secure storage."""
return hashlib.sha256(raw_key.encode()).hexdigest()
def key_prefix_display(raw_key: str) -> str:
"""First DISPLAY_LEN characters of the raw key (safe for UI)."""
return raw_key[:DISPLAY_LEN]
# ── Valid scopes ──────────────────────────────────────────────────────────────
VALID_SCOPES = {"read", "write", "admin"}
class ApiKey(Base):
"""
Scoped API key for programmatic / BI / SOAR access.
The full raw key is **never stored** — only a SHA-256 hash.
The first 12 characters (``key_prefix``) are retained for display.
"""
__tablename__ = "api_keys"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(200), nullable=False)
description = Column(Text, nullable=True)
# Display only — never use for auth
key_prefix = Column(String(DISPLAY_LEN + 1), nullable=False)
# Auth token — SHA-256 of the full raw key
key_hash = Column(String(64), nullable=False, unique=True)
# Owner
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Permissions
scopes = Column(JSONB, nullable=False, default=["read"]) # ["read","write","admin"]
# Lifecycle
last_used_at = Column(DateTime, nullable=True)
expires_at = Column(DateTime, nullable=True)
is_active = Column(Boolean, nullable=False, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
user = relationship("User", foreign_keys=[user_id])
__table_args__ = (
Index("ix_api_keys_user_id", "user_id"),
Index("ix_api_keys_key_hash", "key_hash"),
Index("ix_api_keys_active", "is_active"),
)
-253
View File
@@ -1,253 +0,0 @@
"""Phase 10: Attack Paths & Advanced Purple Team models."""
import enum
import uuid
from datetime import datetime
from sqlalchemy import (
Boolean, Column, DateTime, Enum, Float, ForeignKey,
Index, Integer, String, Text,
)
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from app.database import Base
class ExecutionStatus(str, enum.Enum):
planned = "planned"
in_progress = "in_progress"
completed = "completed"
aborted = "aborted"
class StepResultStatus(str, enum.Enum):
pending = "pending"
executing = "executing"
detected = "detected"
not_detected = "not_detected"
skipped = "skipped"
class TimelineActorSide(str, enum.Enum):
red = "red"
blue = "blue"
system = "system"
class TimelineEntryType(str, enum.Enum):
action = "action"
detection = "detection"
note = "note"
phase_transition = "phase_transition"
flag = "flag"
# ---------------------------------------------------------------------------
class AttackPath(Base):
"""
A reusable attack scenario composed of ordered kill-chain steps.
Can be a template (shared) or a one-off scenario.
"""
__tablename__ = "attack_paths"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(300), nullable=False)
description = Column(Text, nullable=True)
objective = Column(Text, nullable=True) # what the attacker aims to achieve
is_template = Column(Boolean, default=False) # reusable template flag
threat_actor_id = Column(
UUID(as_uuid=True), ForeignKey("threat_actors.id", ondelete="SET NULL"), nullable=True
)
created_by = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
tags = Column(JSONB, nullable=True, default=list)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
steps = relationship(
"AttackPathStep", back_populates="attack_path",
cascade="all, delete-orphan",
order_by="AttackPathStep.order_index",
)
executions = relationship("AttackPathExecution", back_populates="attack_path")
creator = relationship("User", foreign_keys=[created_by])
threat_actor = relationship("ThreatActor", foreign_keys=[threat_actor_id])
__table_args__ = (
Index("ix_attack_paths_created_by", "created_by"),
Index("ix_attack_paths_is_template", "is_template"),
)
class AttackPathStep(Base):
"""One step in an attack path — maps to a kill-chain phase + technique."""
__tablename__ = "attack_path_steps"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
attack_path_id = Column(
UUID(as_uuid=True), ForeignKey("attack_paths.id", ondelete="CASCADE"), nullable=False
)
order_index = Column(Integer, nullable=False, default=0)
kill_chain_phase = Column(String(60), nullable=True) # initial_access, execution, …
technique_id = Column(
UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="SET NULL"), nullable=True
)
test_id = Column(
UUID(as_uuid=True), ForeignKey("tests.id", ondelete="SET NULL"), nullable=True
)
name = Column(String(300), nullable=True) # human label for the step
description = Column(Text, nullable=True)
expected_detection = Column(Boolean, default=True) # do we expect blue to detect this?
notes = Column(Text, nullable=True)
attack_path = relationship("AttackPath", back_populates="steps")
technique = relationship("Technique", foreign_keys=[technique_id])
test = relationship("Test", foreign_keys=[test_id])
__table_args__ = (
Index("ix_ap_steps_path_id", "attack_path_id"),
Index("ix_ap_steps_technique_id", "technique_id"),
)
class AttackPathExecution(Base):
"""
A single run of an attack path.
Tracks Red/Blue participants, timing, and aggregated kill-chain metrics.
"""
__tablename__ = "attack_path_executions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
attack_path_id = Column(
UUID(as_uuid=True), ForeignKey("attack_paths.id", ondelete="CASCADE"), nullable=False
)
status = Column(
Enum(ExecutionStatus, name="execution_status"), nullable=False,
default=ExecutionStatus.planned,
)
environment = Column(String(100), nullable=True) # prod, staging, lab
red_team_lead = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
blue_team_lead = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
started_by = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
started_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True)
notes = Column(Text, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
# ── Computed kill-chain metrics (written on complete) ─────────────────
total_steps = Column(Integer, nullable=True)
detected_steps = Column(Integer, nullable=True)
not_detected_steps = Column(Integer, nullable=True)
skipped_steps = Column(Integer, nullable=True)
detection_rate = Column(Float, nullable=True) # 0.01.0
mttd_seconds = Column(Float, nullable=True) # mean time to detect (avg across detected)
furthest_undetected_step = Column(Integer, nullable=True) # order_index of deepest undetected step
attack_path = relationship("AttackPath", back_populates="executions")
step_results = relationship(
"AttackPathStepResult", back_populates="execution",
cascade="all, delete-orphan",
order_by="AttackPathStepResult.step_order",
)
timeline = relationship(
"TimelineEntry", back_populates="execution",
cascade="all, delete-orphan",
order_by="TimelineEntry.timestamp",
)
red_lead_user = relationship("User", foreign_keys=[red_team_lead])
blue_lead_user = relationship("User", foreign_keys=[blue_team_lead])
__table_args__ = (
Index("ix_ap_exec_path_id", "attack_path_id"),
Index("ix_ap_exec_status", "status"),
)
class AttackPathStepResult(Base):
"""Result of executing one step in an attack path execution."""
__tablename__ = "attack_path_step_results"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
execution_id = Column(
UUID(as_uuid=True), ForeignKey("attack_path_executions.id", ondelete="CASCADE"),
nullable=False,
)
step_id = Column(
UUID(as_uuid=True), ForeignKey("attack_path_steps.id", ondelete="CASCADE"),
nullable=False,
)
step_order = Column(Integer, nullable=False, default=0) # denormalized for sorting
status = Column(
Enum(StepResultStatus, name="step_result_status"), nullable=False,
default=StepResultStatus.pending,
)
executed_by = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
executed_at = Column(DateTime, nullable=True)
detected_at = Column(DateTime, nullable=True)
time_to_detect_seconds = Column(Float, nullable=True)
detection_asset_id = Column(
UUID(as_uuid=True),
ForeignKey("detection_assets.id", ondelete="SET NULL"), nullable=True
)
notes = Column(Text, nullable=True)
evidence_ids = Column(JSONB, nullable=True, default=list)
execution = relationship("AttackPathExecution", back_populates="step_results")
step = relationship("AttackPathStep")
detection_asset = relationship("DetectionAsset", foreign_keys=[detection_asset_id])
executor = relationship("User", foreign_keys=[executed_by])
__table_args__ = (
Index("ix_ap_stepres_execution_id", "execution_id"),
Index("ix_ap_stepres_step_id", "step_id"),
)
class TimelineEntry(Base):
"""Timestamped Red/Blue action during an execution — used for MTTD/MTTR."""
__tablename__ = "attack_path_timeline"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
execution_id = Column(
UUID(as_uuid=True), ForeignKey("attack_path_executions.id", ondelete="CASCADE"),
nullable=False,
)
step_id = Column(
UUID(as_uuid=True), ForeignKey("attack_path_steps.id", ondelete="SET NULL"),
nullable=True,
)
timestamp = Column(DateTime, nullable=False, default=datetime.utcnow)
actor_side = Column(
Enum(TimelineActorSide, name="timeline_actor_side"), nullable=False,
)
actor_id = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
entry_type = Column(
Enum(TimelineEntryType, name="timeline_entry_type"), nullable=False,
)
content = Column(Text, nullable=False)
extra = Column(JSONB, nullable=True)
execution = relationship("AttackPathExecution", back_populates="timeline")
actor = relationship("User", foreign_keys=[actor_id])
__table_args__ = (
Index("ix_timeline_execution_id", "execution_id"),
Index("ix_timeline_timestamp", "timestamp"),
)
+1 -1
View File
@@ -73,7 +73,7 @@ class Campaign(Base):
# Keyword argument: nullable # Keyword argument: nullable
nullable=True, nullable=True,
) )
start_date = Column(DateTime, nullable=True) # campaign won't activate before this date # Assign scheduled_at = Column(DateTime, nullable=True)
scheduled_at = Column(DateTime, nullable=True) scheduled_at = Column(DateTime, nullable=True)
# Assign completed_at = Column(DateTime, nullable=True) # Assign completed_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True) completed_at = Column(DateTime, nullable=True)
-32
View File
@@ -1,32 +0,0 @@
"""Decay Policy model — configurable detection validity rules."""
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Integer, Float, Boolean, DateTime, Text
from sqlalchemy.dialects.postgresql import UUID
from app.database import Base
class DecayPolicy(Base):
__tablename__ = "decay_policies"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(200), nullable=False)
description = Column(Text)
applies_to_platform = Column(String(100))
applies_to_asset_type = Column(String(50))
applies_to_tactic = Column(String(100))
fresh_days = Column(Integer, default=90, server_default='90')
aging_days = Column(Integer, default=180, server_default='180')
stale_days = Column(Integer, default=365, server_default='365')
default_validity_days = Column(Integer, default=180, server_default='180')
silent_threshold_days = Column(Integer, default=30, server_default='30')
noisy_threshold_daily = Column(Integer, default=100, server_default='100')
recency_weight = Column(Float, default=0.3, server_default='0.3')
coverage_weight = Column(Float, default=0.3, server_default='0.3')
health_weight = Column(Float, default=0.25, server_default='0.25')
diversity_weight = Column(Float, default=0.15, server_default='0.15')
is_default = Column(Boolean, default=False, server_default='false')
is_active = Column(Boolean, default=True, server_default='true')
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow)
-168
View File
@@ -1,168 +0,0 @@
"""Detection Lifecycle Management models."""
import uuid
import enum
from datetime import datetime
from sqlalchemy import (
Column, String, Integer, Float, Boolean, DateTime,
ForeignKey, Text, Enum as SQLEnum
)
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from app.database import Base
class DetectionConfidence(str, enum.Enum):
fresh = "fresh"
aging = "aging"
stale = "stale"
broken = "broken"
unknown = "unknown"
class DetectionHealthStatus(str, enum.Enum):
healthy = "healthy"
silent = "silent"
noisy = "noisy"
orphan = "orphan"
deprecated = "deprecated"
untested = "untested"
class InvalidationReason(str, enum.Enum):
time_decay = "time_decay"
mitre_update = "mitre_update"
log_source_change = "log_source_change"
siem_update = "siem_update"
edr_update = "edr_update"
infrastructure_change = "infrastructure_change"
parser_change = "parser_change"
manual = "manual"
rule_modified = "rule_modified"
class DetectionAsset(Base):
__tablename__ = "detection_assets"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(500), nullable=False)
description = Column(Text)
asset_type = Column(String(50), nullable=False)
platform = Column(String(100))
rule_content = Column(Text)
rule_language = Column(String(50))
rule_repository_url = Column(Text)
rule_file_path = Column(String(500))
rule_version = Column(String(50))
rule_hash = Column(String(64))
last_rule_change_at = Column(DateTime)
log_source_name = Column(String(200))
log_source_version = Column(String(50))
log_source_config = Column(JSONB, server_default='{}')
infrastructure_hash = Column(String(64))
infrastructure_details = Column(JSONB, server_default='{}')
health_status = Column(
SQLEnum(DetectionHealthStatus, name="detectionhealthstatus"),
default=DetectionHealthStatus.untested,
nullable=False,
server_default="untested",
)
last_alert_at = Column(DateTime)
alert_count_30d = Column(Integer, default=0, server_default='0')
false_positive_rate = Column(Float)
expected_alert_frequency = Column(String(50))
owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
backup_owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
team = Column(String(100))
is_active = Column(Boolean, default=True, nullable=False, server_default='true')
tags = Column(JSONB, server_default='[]')
asset_metadata = Column(JSONB, server_default='{}')
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
created_at = Column(DateTime(timezone=True), server_default='now()')
updated_at = Column(DateTime(timezone=True), server_default='now()')
technique_mappings = relationship("DetectionTechniqueMapping", back_populates="detection_asset", cascade="all, delete-orphan")
validations = relationship("DetectionValidation", back_populates="detection_asset", cascade="all, delete-orphan")
class DetectionTechniqueMapping(Base):
__tablename__ = "detection_technique_mappings"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
detection_asset_id = Column(UUID(as_uuid=True), ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False)
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False)
coverage_type = Column(String(50), default="detect", server_default="detect")
confidence_level = Column(String(20), default="medium", server_default="medium")
notes = Column(Text)
created_at = Column(DateTime(timezone=True), server_default='now()')
detection_asset = relationship("DetectionAsset", back_populates="technique_mappings")
class DetectionValidation(Base):
__tablename__ = "detection_validations"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
detection_asset_id = Column(UUID(as_uuid=True), ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False)
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="SET NULL"), nullable=True)
test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id", ondelete="SET NULL"), nullable=True)
validated_at = Column(DateTime, default=datetime.utcnow)
expires_at = Column(DateTime, nullable=False)
is_valid = Column(Boolean, default=True, nullable=False, server_default='true')
validation_result = Column(String(50))
validation_method = Column(String(100))
rule_hash_at_validation = Column(String(64))
log_source_version_at_validation = Column(String(50))
infrastructure_hash_at_validation = Column(String(64))
environment_snapshot = Column(JSONB, server_default='{}')
invalidated_at = Column(DateTime)
invalidation_reason = Column(SQLEnum(InvalidationReason, name="invalidationreason"), nullable=True)
invalidation_details = Column(Text)
invalidated_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=False)
integrity_hash = Column(String(64))
notes = Column(Text)
evidence_ids = Column(JSONB, server_default='[]')
detection_asset = relationship("DetectionAsset", back_populates="validations")
class TechniqueConfidenceScore(Base):
__tablename__ = "technique_confidence_scores"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False, unique=True)
confidence_level = Column(
SQLEnum(DetectionConfidence, name="detectionconfidence"),
default=DetectionConfidence.unknown,
server_default="unknown",
)
confidence_score = Column(Float, default=0.0, server_default='0.0')
detection_count = Column(Integer, default=0, server_default='0')
valid_detection_count = Column(Integer, default=0, server_default='0')
last_validated_at = Column(DateTime)
next_validation_due = Column(DateTime)
last_recalculated_at = Column(DateTime, default=datetime.utcnow)
recency_factor = Column(Float, default=0.0, server_default='0.0')
coverage_factor = Column(Float, default=0.0, server_default='0.0')
health_factor = Column(Float, default=0.0, server_default='0.0')
diversity_factor = Column(Float, default=0.0, server_default='0.0')
score_breakdown = Column(JSONB, server_default='{}')
risk_factors = Column(JSONB, server_default='[]')
updated_at = Column(DateTime, default=datetime.utcnow)
class InfrastructureChangeLog(Base):
__tablename__ = "infrastructure_change_logs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
change_type = Column(String(100), nullable=False)
description = Column(Text, nullable=False)
affected_platforms = Column(JSONB, server_default='[]')
affected_log_sources = Column(JSONB, server_default='[]')
change_date = Column(DateTime, default=datetime.utcnow)
reported_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
auto_invalidate = Column(Boolean, default=True, server_default='true')
invalidated_count = Column(Integer, default=0, server_default='0')
change_metadata = Column(JSONB, server_default='{}')
created_at = Column(DateTime, default=datetime.utcnow)
-34
View File
@@ -1,34 +0,0 @@
"""SQLAlchemy model for tracking imported ATT&CK Evaluation rounds."""
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Integer, DateTime, Text, ForeignKey, Index
from sqlalchemy.dialects.postgresql import UUID
from app.database import Base
class EvaluationImport(Base):
"""Tracks which ATT&CK Evaluation rounds have been imported into the platform.
Each row represents one vendor+adversary combination that has been processed
and turned into Test records. Used to avoid duplicate imports and to show
the admin panel which rounds are available vs imported.
"""
__tablename__ = "evaluation_imports"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
adversary_name = Column(String, nullable=False) # "apt29", "turla"
adversary_display = Column(String, nullable=False) # "APT29", "Turla"
eval_round = Column(Integer, nullable=False) # 1, 2, 3 …
imported_at = Column(DateTime, nullable=False, default=datetime.utcnow)
imported_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
tests_created = Column(Integer, default=0)
techniques_covered = Column(Integer, default=0)
status = Column(String, default="completed") # "completed" | "failed"
notes = Column(Text, nullable=True)
__table_args__ = (
Index("ix_evaluation_imports_adversary", "adversary_name"),
Index("ix_evaluation_imports_round", "eval_round"),
)
-68
View File
@@ -1,68 +0,0 @@
"""Phase 13: Executive Dashboard — PostureSnapshot model."""
import uuid
from datetime import datetime
from sqlalchemy import (
Column, Date, DateTime, Float, ForeignKey,
Index, Integer, UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from app.database import Base
class PostureSnapshot(Base):
"""
Daily point-in-time capture of the organisation's security posture.
Aggregates data from all phases (coverage, risk, ownership, knowledge,
attack-paths) into a single row that can be trended over time.
"""
__tablename__ = "posture_snapshots"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
snapshot_date = Column(Date, nullable=False) # one per calendar day
# ── Coverage ──────────────────────────────────────────────────────────────
total_techniques = Column(Integer, nullable=False, default=0)
validated_count = Column(Integer, nullable=False, default=0)
partial_count = Column(Integer, nullable=False, default=0)
not_covered_count = Column(Integer, nullable=False, default=0)
coverage_pct = Column(Float, nullable=False, default=0.0) # 0100
# ── Risk ─────────────────────────────────────────────────────────────────
avg_risk_score = Column(Float, nullable=False, default=0.0)
critical_count = Column(Integer, nullable=False, default=0)
high_count = Column(Integer, nullable=False, default=0)
medium_count = Column(Integer, nullable=False, default=0)
low_count = Column(Integer, nullable=False, default=0)
# ── Operations ────────────────────────────────────────────────────────────
open_queue_items = Column(Integer, nullable=False, default=0)
orphan_techniques = Column(Integer, nullable=False, default=0)
# ── Knowledge ─────────────────────────────────────────────────────────────
playbook_count = Column(Integer, nullable=False, default=0)
lesson_count = Column(Integer, nullable=False, default=0)
# ── MTTD (from attack-path executions completed in last 30 d) ────────────
mttd_avg_seconds = Column(Float, nullable=True) # None if no data
executions_30d = Column(Integer, nullable=False, default=0)
detection_rate_30d = Column(Float, nullable=True) # avg across executions
# ── Meta ─────────────────────────────────────────────────────────────────
created_by = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
created_at = Column(DateTime, default=datetime.utcnow)
extra = Column(JSONB, nullable=True) # full breakdown / by-tactic
creator = relationship("User", foreign_keys=[created_by])
__table_args__ = (
UniqueConstraint("snapshot_date", name="uq_posture_snapshot_date"),
Index("ix_posture_snapshots_date", "snapshot_date"),
)
-129
View File
@@ -1,129 +0,0 @@
"""Phase 11: Knowledge Management models — Playbooks + Lessons Learned."""
import uuid
from datetime import datetime
from sqlalchemy import (
Boolean, Column, DateTime, ForeignKey,
Index, Integer, String, Text, UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from app.database import Base
# ── Playbooks ──────────────────────────────────────────────────────────────────
class Playbook(Base):
"""
Structured runbook for a specific technique and playbook type.
playbook_type: attack | detect | investigate | respond | hunt
One playbook per (technique, type). Edits increment ``version``
and save a snapshot to ``PlaybookVersion``.
"""
__tablename__ = "playbooks"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
technique_id = Column(
UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False
)
playbook_type = Column(String(32), nullable=False) # attack/detect/investigate/respond/hunt
title = Column(String(255), nullable=False)
content = Column(Text, nullable=False, default="")
version = Column(Integer, default=1, nullable=False)
tools = Column(JSONB, default=list) # list of tool name strings
prerequisites = Column(JSONB, default=list) # list of prerequisite strings
created_by = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
updated_by = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = Column(Boolean, default=True)
# Relationships
technique = relationship("Technique", foreign_keys=[technique_id])
creator = relationship("User", foreign_keys=[created_by])
updater = relationship("User", foreign_keys=[updated_by])
versions = relationship(
"PlaybookVersion", back_populates="playbook",
cascade="all, delete-orphan",
order_by="PlaybookVersion.version.desc()",
)
__table_args__ = (
UniqueConstraint("technique_id", "playbook_type", name="uq_playbook_technique_type"),
Index("ix_playbooks_technique_id", "technique_id"),
Index("ix_playbooks_type", "playbook_type"),
)
class PlaybookVersion(Base):
"""Immutable snapshot of a playbook at a given version number."""
__tablename__ = "playbook_versions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
playbook_id = Column(
UUID(as_uuid=True), ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False
)
version = Column(Integer, nullable=False)
title = Column(String(255), nullable=False)
content = Column(Text, nullable=False, default="")
tools = Column(JSONB, default=list)
prerequisites = Column(JSONB, default=list)
changed_by = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
change_note = Column(String(500), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
playbook = relationship("Playbook", back_populates="versions")
changer = relationship("User", foreign_keys=[changed_by])
__table_args__ = (
Index("ix_pb_versions_playbook_id", "playbook_id"),
Index("ix_pb_versions_version", "playbook_id", "version"),
)
# ── Lessons Learned ────────────────────────────────────────────────────────────
class LessonLearned(Base):
"""
Immutable post-mortem record linked to a test, campaign, attack-path or
created manually.
severity: critical | high | medium | low | info
entity_type: test | campaign | attack_path | manual
"""
__tablename__ = "lessons_learned"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
title = Column(String(255), nullable=False)
what_happened = Column(Text, nullable=False, default="")
root_cause = Column(Text, nullable=False, default="")
fix_applied = Column(Text, nullable=True)
severity = Column(String(16), nullable=False, default="medium")
entity_type = Column(String(32), nullable=False, default="manual")
entity_id = Column(UUID(as_uuid=True), nullable=True)
technique_ids = Column(JSONB, default=list) # list of UUID strings
tags = Column(JSONB, default=list)
created_by = Column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = Column(Boolean, default=True) # soft-delete (admin only)
creator = relationship("User", foreign_keys=[created_by])
__table_args__ = (
Index("ix_ll_entity", "entity_type", "entity_id"),
Index("ix_ll_severity", "severity"),
Index("ix_ll_created_by", "created_by"),
)
-144
View File
@@ -1,144 +0,0 @@
"""Phase 13: Operational Alerts — AlertRule and AlertInstance models."""
import enum
import uuid
from datetime import datetime
from sqlalchemy import (
Boolean, Column, DateTime, ForeignKey,
Index, Integer, String, Text,
)
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import relationship
from app.database import Base
# ── Enumerations ──────────────────────────────────────────────────────────────
class AlertSeverity(str, enum.Enum):
critical = "critical"
high = "high"
medium = "medium"
low = "low"
info = "info"
class AlertStatus(str, enum.Enum):
open = "open"
acknowledged = "acknowledged"
resolved = "resolved"
dismissed = "dismissed"
class AlertRuleType(str, enum.Enum):
high_risk = "high_risk" # risk_score >= threshold
stale_technique = "stale_technique" # not validated in N days
coverage_regression = "coverage_regression" # coverage_pct dropped
low_coverage = "low_coverage" # coverage below min
expiry_wave = "expiry_wave" # many pending queue items
new_technique = "new_technique" # new MITRE techniques added
orphan_spike = "orphan_spike" # many unowned techniques
custom = "custom" # future extension placeholder
# ── AlertRule ─────────────────────────────────────────────────────────────────
class AlertRule(Base):
"""
Defines a condition that, when satisfied, fires an AlertInstance.
System rules (is_system=True) are seeded at startup and cannot be deleted.
Custom rules (is_system=False) can be created by admins.
"""
__tablename__ = "alert_rules"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(300), nullable=False)
description = Column(Text, nullable=True)
rule_type = Column(String(50), nullable=False)
severity = Column(String(20), nullable=False, default=AlertSeverity.medium.value)
is_enabled = Column(Boolean, nullable=False, default=True)
is_system = Column(Boolean, nullable=False, default=False) # seeded, not deletable
# Rule-specific thresholds/config (varies by rule_type)
config = Column(JSONB, nullable=False, default={})
# Delivery
notify_in_app = Column(Boolean, nullable=False, default=True)
notify_webhook = Column(Boolean, nullable=False, default=False)
webhook_id = Column(
UUID(as_uuid=True),
ForeignKey("webhook_configs.id", ondelete="SET NULL"),
nullable=True,
)
# Cooldown — don't re-fire within N hours of last firing
cooldown_hours = Column(Integer, nullable=False, default=24)
# Meta
created_by = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
created_at = Column(DateTime, default=datetime.utcnow)
last_fired_at = Column(DateTime, nullable=True)
creator = relationship("User", foreign_keys=[created_by])
instances = relationship("AlertInstance", back_populates="rule",
cascade="all, delete-orphan")
__table_args__ = (
Index("ix_alert_rules_type", "rule_type"),
Index("ix_alert_rules_enabled", "is_enabled"),
)
# ── AlertInstance ─────────────────────────────────────────────────────────────
class AlertInstance(Base):
"""
A single firing of an AlertRule.
Transitions: open → acknowledged → resolved
open → dismissed
"""
__tablename__ = "alert_instances"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
rule_id = Column(
UUID(as_uuid=True),
ForeignKey("alert_rules.id", ondelete="SET NULL"),
nullable=True,
)
# Denormalised fields kept for history even after rule deletion
rule_name = Column(String(300), nullable=False)
rule_type = Column(String(50), nullable=False)
severity = Column(String(20), nullable=False)
title = Column(String(500), nullable=False)
message = Column(Text, nullable=False)
details = Column(JSONB, nullable=True) # structured context
status = Column(String(20), nullable=False, default=AlertStatus.open.value)
acknowledged_by = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
acknowledged_at = Column(DateTime, nullable=True)
resolved_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
rule = relationship("AlertRule", back_populates="instances")
acknowledger = relationship("User", foreign_keys=[acknowledged_by])
__table_args__ = (
Index("ix_alert_instances_rule_id", "rule_id"),
Index("ix_alert_instances_status", "status"),
Index("ix_alert_instances_severity", "severity"),
Index("ix_alert_instances_created", "created_at"),
)
-136
View File
@@ -1,136 +0,0 @@
"""Phase 9: Ownership & Revalidation Queue models."""
import enum
import uuid
from datetime import datetime
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Index, String, Text
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from app.database import Base
class QueuePriority(str, enum.Enum):
critical = "critical"
high = "high"
medium = "medium"
low = "low"
class QueueStatus(str, enum.Enum):
pending = "pending"
in_progress = "in_progress"
completed = "completed"
dismissed = "dismissed"
class QueueReason(str, enum.Enum):
validation_expired = "validation_expired"
infra_change = "infra_change"
osint_alert = "osint_alert"
mitre_update = "mitre_update"
rule_modified = "rule_modified"
low_confidence = "low_confidence"
manual = "manual"
class TechniqueOwnership(Base):
"""Ownership assignment for a MITRE technique."""
__tablename__ = "technique_ownerships"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
technique_id = Column(
UUID(as_uuid=True),
ForeignKey("techniques.id", ondelete="CASCADE"),
nullable=False,
unique=True,
)
owner_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
backup_owner_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
team = Column(String(200), nullable=True)
notes = Column(Text, nullable=True)
assigned_at = Column(DateTime, nullable=True)
assigned_by = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
technique = relationship("Technique", foreign_keys=[technique_id])
owner = relationship("User", foreign_keys=[owner_id])
backup_owner = relationship("User", foreign_keys=[backup_owner_id])
class RevalidationQueueItem(Base):
"""A prioritised work item for the analyst's daily queue."""
__tablename__ = "revalidation_queue_items"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
technique_id = Column(
UUID(as_uuid=True),
ForeignKey("techniques.id", ondelete="CASCADE"),
nullable=True,
)
detection_asset_id = Column(
UUID(as_uuid=True),
ForeignKey("detection_assets.id", ondelete="CASCADE"),
nullable=True,
)
priority = Column(
Enum(QueuePriority, name="queue_priority"),
nullable=False,
default=QueuePriority.medium,
)
reason = Column(
Enum(QueueReason, name="queue_reason"),
nullable=False,
)
reason_detail = Column(Text, nullable=True)
status = Column(
Enum(QueueStatus, name="queue_status"),
nullable=False,
default=QueueStatus.pending,
)
assigned_to = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
due_date = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
completed_at = Column(DateTime, nullable=True)
dismissed_at = Column(DateTime, nullable=True)
completed_by = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
extra = Column(JSONB, nullable=True) # arbitrary metadata
technique = relationship("Technique", foreign_keys=[technique_id])
detection_asset = relationship("DetectionAsset", foreign_keys=[detection_asset_id])
assignee = relationship("User", foreign_keys=[assigned_to])
# Indexes
Index("ix_rqueue_status", RevalidationQueueItem.status)
Index("ix_rqueue_priority", RevalidationQueueItem.priority)
Index("ix_rqueue_assigned_to", RevalidationQueueItem.assigned_to)
Index("ix_rqueue_technique_id", RevalidationQueueItem.technique_id)
Index("ix_rqueue_asset_id", RevalidationQueueItem.detection_asset_id)
Index("ix_techown_owner_id", TechniqueOwnership.owner_id)
-69
View File
@@ -1,69 +0,0 @@
"""Phase 12: Risk Intelligence model — per-technique risk scoring."""
import uuid
from datetime import datetime
from sqlalchemy import (
Boolean, Column, DateTime, Float, ForeignKey,
Index, Integer, String, UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from app.database import Base
class TechniqueRiskProfile(Base):
"""
Aggregated risk profile for one technique.
Combines four weighted factors:
• detection_gap (35 %) — 0=fully covered → 1=no coverage
• threat_actor_rel (30 %) — normalised actor count
• osint_signals (20 %) — normalised recent OSINT items (30 d)
• test_failure_rate (15 %) — proportion of tests where blue didn't detect
risk_score = weighted sum × 100 → 0100
risk_level: critical ≥75 | high ≥50 | medium ≥25 | low ≥10 | info
"""
__tablename__ = "technique_risk_profiles"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
technique_id = Column(
UUID(as_uuid=True),
ForeignKey("techniques.id", ondelete="CASCADE"),
nullable=False,
)
# ── Computed scores ───────────────────────────────────────────────────────
risk_score = Column(Float, nullable=False, default=0.0) # 0100
likelihood = Column(Float, nullable=False, default=0.0) # 0100
impact = Column(Float, nullable=False, default=0.0) # 0100
risk_level = Column(String(16), nullable=False, default="info")
# ── Raw factor values ─────────────────────────────────────────────────────
detection_gap = Column(Float, nullable=False, default=1.0) # 01
threat_actor_count = Column(Integer, nullable=False, default=0)
osint_signal_count = Column(Integer, nullable=False, default=0) # last 30 d
test_fail_count = Column(Integer, nullable=False, default=0)
test_total_count = Column(Integer, nullable=False, default=0)
test_failure_rate = Column(Float, nullable=False, default=0.0) # 01
confidence_level = Column(Float, nullable=False, default=0.0) # DLC 01
# ── Rich detail ──────────────────────────────────────────────────────────
scoring_breakdown = Column(JSONB, nullable=True) # per-factor contributions
recommendations = Column(JSONB, nullable=True) # list[str]
# ── Meta ─────────────────────────────────────────────────────────────────
computed_at = Column(DateTime, default=datetime.utcnow)
is_stale = Column(Boolean, default=True)
technique = relationship("Technique", foreign_keys=[technique_id])
__table_args__ = (
UniqueConstraint("technique_id", name="uq_risk_profile_technique"),
Index("ix_risk_profiles_risk_score", "risk_score"),
Index("ix_risk_profiles_risk_level", "risk_level"),
Index("ix_risk_profiles_stale", "is_stale"),
)
-49
View File
@@ -1,49 +0,0 @@
"""Phase 14: SSO / SAML 2.0 configuration model."""
import uuid
from datetime import datetime
from sqlalchemy import Boolean, Column, DateTime, String, Text
from sqlalchemy.dialects.postgresql import UUID
from app.database import Base
class SsoConfig(Base):
"""
SAML 2.0 Identity Provider configuration.
Exactly one row is expected (use upsert). The SP metadata endpoint
reads from this row to generate XML for IdP registration.
"""
__tablename__ = "sso_configs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
is_enabled = Column(Boolean, nullable=False, default=False)
provider_name = Column(String(200), nullable=True) # e.g., "Okta", "Azure AD"
# ── Service Provider (Aegis) settings ────────────────────────────────────
sp_entity_id = Column(String(500), nullable=True) # e.g., https://aegis.co/api/v1/sso/metadata
sp_acs_url = Column(String(500), nullable=True) # Assertion Consumer Service URL
sp_slo_url = Column(String(500), nullable=True) # Single Logout URL (optional)
sp_certificate = Column(Text, nullable=True) # SP public cert for signed requests
sp_private_key = Column(Text, nullable=True) # SP private key (stored encrypted in future)
# ── Identity Provider settings ────────────────────────────────────────────
idp_entity_id = Column(String(500), nullable=True)
idp_sso_url = Column(String(500), nullable=True) # IdP redirect/POST binding URL
idp_slo_url = Column(String(500), nullable=True) # IdP SLO URL
idp_certificate = Column(Text, nullable=True) # IdP X.509 cert for response validation
# ── Attribute mapping ─────────────────────────────────────────────────────
# SAML attribute name → Aegis field
attr_email = Column(String(200), nullable=True, default="email")
attr_username = Column(String(200), nullable=True, default="username")
attr_role = Column(String(200), nullable=True, default="role")
default_role = Column(String(50), nullable=True, default="viewer")
auto_provision = Column(Boolean, nullable=False, default=True) # create user on first login
# ── Meta ─────────────────────────────────────────────────────────────────
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
-26
View File
@@ -1,26 +0,0 @@
"""SystemConfig model — runtime key-value configuration store."""
import uuid
from sqlalchemy import Column, String, Text, DateTime, func
from sqlalchemy.dialects.postgresql import UUID
from app.database import Base
class SystemConfig(Base):
"""Generic key-value store for runtime system configuration.
Currently used for:
- SMTP email settings (overrides .env values when present)
Keys are namespaced by convention: ``smtp.host``, ``smtp.port``, etc.
"""
__tablename__ = "system_configs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
key = Column(String(200), unique=True, nullable=False, index=True)
value = Column(Text, nullable=True)
description = Column(String(500), nullable=True)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
+1 -1
View File
@@ -96,7 +96,7 @@ class Test(Base):
red_started_at = Column(DateTime, nullable=True) red_started_at = Column(DateTime, nullable=True)
# Assign blue_started_at = Column(DateTime, nullable=True) # Assign blue_started_at = Column(DateTime, nullable=True)
blue_started_at = Column(DateTime, nullable=True) blue_started_at = Column(DateTime, nullable=True)
blue_work_started_at = Column(DateTime, nullable=True) # when blue tech picks up (Tempo start) # Assign paused_at = Column(DateTime, nullable=True)
paused_at = Column(DateTime, nullable=True) paused_at = Column(DateTime, nullable=True)
# Assign red_paused_seconds = Column(Integer, default=0) # Assign red_paused_seconds = Column(Integer, default=0)
red_paused_seconds = Column(Integer, default=0) red_paused_seconds = Column(Integer, default=0)
+6 -7
View File
@@ -2,8 +2,12 @@
# Import uuid # Import uuid
import uuid import uuid
from sqlalchemy import Column, String, Boolean, DateTime, func
from sqlalchemy.dialects.postgresql import UUID, JSONB # Import Boolean, Column, DateTime, String, func from sqlalchemy
from sqlalchemy import Boolean, Column, DateTime, String, func
# Import UUID from sqlalchemy.dialects.postgresql
from sqlalchemy.dialects.postgresql import UUID
# Import Base from app.database # Import Base from app.database
from app.database import Base from app.database import Base
@@ -42,8 +46,3 @@ class User(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
# Assign last_login = Column(DateTime, nullable=True) # Assign last_login = Column(DateTime, nullable=True)
last_login = Column(DateTime, nullable=True) last_login = Column(DateTime, nullable=True)
notification_preferences = Column(JSONB, nullable=True, server_default='{"email_on_test_validated": true, "email_on_campaign_completed": true, "email_on_new_mitre_techniques": false, "in_app_all": true}')
jira_account_id = Column(String(100), nullable=True)
jira_api_token = Column(String(500), nullable=True) # personal Atlassian token
jira_email = Column(String(255), nullable=True) # Atlassian email (overrides account email)
tempo_api_token = Column(String(500), nullable=True) # personal Tempo API token
-18
View File
@@ -1,18 +0,0 @@
"""WebhookConfig model — outbound HTTP notification endpoints."""
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text, ForeignKey, func
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.database import Base
class WebhookConfig(Base):
__tablename__ = "webhook_configs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(200), nullable=False)
url = Column(Text, nullable=False)
secret = Column(String(256), nullable=True) # HMAC signature key
events = Column(JSONB, nullable=False, server_default="[]") # list of event types
is_active = Column(Boolean, default=True, nullable=False)
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
last_triggered_at = Column(DateTime, nullable=True)
failure_count = Column(Integer, default=0, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
-339
View File
@@ -1,339 +0,0 @@
"""Admin configuration export/import — single-file migration bundle.
GET /admin/export-config — download JSON bundle (admin only)
POST /admin/import-config — upload JSON bundle and restore (admin only)
What is exported (and what is NOT):
✓ system_configs — email / jira settings (passwords REDACTED)
✓ webhook_configs — notification webhooks (secrets REDACTED)
✓ sso_configs — SAML/SSO config (private keys REDACTED)
✓ scoring_config — technique scoring weights
✓ test_templates — CUSTOM templates only (source='custom')
✓ users — username / email / role (no passwords / tokens)
✗ atomic/sigma/elastic templates, techniques, tests, campaigns, reports
"""
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from app.auth import hash_password
from app.database import get_db
from app.dependencies.auth import require_role
from app.models.scoring_config import ScoringConfig
from app.models.sso_config import SsoConfig
from app.models.system_config import SystemConfig
from app.models.test_template import TestTemplate
from app.models.user import User
from app.models.webhook_config import WebhookConfig
router = APIRouter(prefix="/admin", tags=["admin"])
# Keys whose values contain secrets and must be redacted in the export
_REDACTED_KEYS = {
"smtp.password",
"jira.api_token",
"jira.password",
"tempo.api_token",
}
_EXPORT_VERSION = "1.0"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _redact(key: str, value: Any) -> Any:
if key in _REDACTED_KEYS:
return "[REDACTED]"
return value
# ---------------------------------------------------------------------------
# GET /admin/export-config
# ---------------------------------------------------------------------------
@router.get("/export-config")
def export_config(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Export all platform configuration as a downloadable JSON bundle."""
# ── 1. system_configs ────────────────────────────────────────────
system_configs = [
{
"key": r.key,
"value": _redact(r.key, r.value),
"description": r.description,
}
for r in db.query(SystemConfig).order_by(SystemConfig.key).all()
]
# ── 2. webhook_configs ───────────────────────────────────────────
webhooks = [
{
"name": w.name,
"url": w.url,
"secret": "[REDACTED]" if w.secret else None,
"events": w.events or [],
"is_active": w.is_active,
}
for w in db.query(WebhookConfig).order_by(WebhookConfig.name).all()
]
# ── 3. SSO config (single row) ───────────────────────────────────
sso_row = db.query(SsoConfig).first()
sso = None
if sso_row:
sso = {
"is_enabled": sso_row.is_enabled,
"provider_name": sso_row.provider_name,
"sp_entity_id": sso_row.sp_entity_id,
"sp_acs_url": sso_row.sp_acs_url,
"sp_slo_url": sso_row.sp_slo_url,
"sp_certificate": sso_row.sp_certificate,
"sp_private_key": "[REDACTED]", # never export private keys
"idp_entity_id": sso_row.idp_entity_id,
"idp_sso_url": getattr(sso_row, "idp_sso_url", None),
"idp_slo_url": getattr(sso_row, "idp_slo_url", None),
"idp_certificate": getattr(sso_row, "idp_certificate", None),
"attr_email": getattr(sso_row, "attr_email", None),
"attr_username": getattr(sso_row, "attr_username", None),
"attr_role": getattr(sso_row, "attr_role", None),
"default_role": getattr(sso_row, "default_role", None),
"auto_provision": getattr(sso_row, "auto_provision", False),
}
# ── 4. Scoring config (single row) ──────────────────────────────
sc = db.query(ScoringConfig).first()
scoring = None
if sc:
scoring = {
"weight_tests": sc.weight_tests,
"weight_detection_rules": sc.weight_detection_rules,
"weight_d3fend": sc.weight_d3fend,
"weight_recency": sc.weight_recency,
"weight_severity": sc.weight_severity,
}
# ── 5. Custom test templates only ───────────────────────────────
templates = [
{
"mitre_technique_id": t.mitre_technique_id,
"name": t.name,
"description": t.description,
"source": t.source,
"source_url": t.source_url,
"attack_procedure": t.attack_procedure,
"expected_detection": t.expected_detection,
"platform": t.platform,
"tool_suggested": t.tool_suggested,
"severity": t.severity,
"suggested_remediation": t.suggested_remediation,
"is_active": t.is_active,
}
for t in db.query(TestTemplate).filter(TestTemplate.source == "custom").all()
]
# ── 6. Users (sanitized — no passwords/tokens) ───────────────────
users = [
{
"username": u.username,
"email": u.email if hasattr(u, "email") else None,
"role": u.role,
"is_active": u.is_active,
"must_change_password": True, # force password reset on new instance # nosec B105
}
for u in db.query(User).order_by(User.username).all()
]
bundle = {
"_meta": {
"version": _EXPORT_VERSION,
"exported_at": datetime.utcnow().isoformat() + "Z",
"exported_by": current_user.username,
"note": (
"Sensitive values (passwords, API tokens, private keys) are REDACTED. "
"Re-enter them manually after import. "
"User passwords are NOT exported — users must reset passwords on first login."
),
},
"system_configs": system_configs,
"webhooks": webhooks,
"sso": sso,
"scoring": scoring,
"custom_templates": templates,
"users": users,
}
filename = f"aegis-config-{datetime.utcnow().strftime('%Y%m%d-%H%M%S')}.json"
return JSONResponse(
content=bundle,
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"X-Export-Version": _EXPORT_VERSION,
},
)
# ---------------------------------------------------------------------------
# POST /admin/import-config
# ---------------------------------------------------------------------------
@router.post("/import-config")
async def import_config(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Restore platform configuration from a previously exported JSON bundle.
Idempotent: safe to run multiple times. Existing records are updated,
missing ones are created. REDACTED values are skipped (left as-is).
User passwords are set to a random temp value with must_change_password=True.
"""
try:
bundle = await request.json()
except Exception:
raise HTTPException(status_code=400, detail="Invalid JSON body")
meta = bundle.get("_meta", {})
version = meta.get("version", "unknown")
summary: dict[str, int] = {
"system_configs": 0,
"webhooks": 0,
"custom_templates": 0,
"users_created": 0,
"users_updated": 0,
}
# ── 1. system_configs ────────────────────────────────────────────
for item in bundle.get("system_configs", []):
key = item.get("key")
value = item.get("value")
if not key or value == "[REDACTED]":
continue
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if row:
row.value = value
row.description = item.get("description") or row.description
else:
db.add(SystemConfig(key=key, value=value, description=item.get("description")))
summary["system_configs"] += 1
# ── 2. webhooks ──────────────────────────────────────────────────
for item in bundle.get("webhooks", []):
name = item.get("name")
url = item.get("url")
if not name or not url:
continue
existing = db.query(WebhookConfig).filter(WebhookConfig.name == name).first()
if existing:
existing.url = url
existing.events = item.get("events", [])
existing.is_active = item.get("is_active", True)
existing.failure_count = 0
else:
db.add(WebhookConfig(
name=name,
url=url,
secret=None, # never restore secrets
events=item.get("events", []),
is_active=item.get("is_active", True),
created_by=current_user.id,
failure_count=0,
))
summary["webhooks"] += 1
# ── 3. SSO config ────────────────────────────────────────────────
sso_data = bundle.get("sso")
if sso_data:
sso_row = db.query(SsoConfig).first()
if sso_row:
for field, val in sso_data.items():
if val == "[REDACTED]":
continue
if hasattr(sso_row, field):
setattr(sso_row, field, val)
else:
clean = {k: v for k, v in sso_data.items() if v != "[REDACTED]"}
clean.pop("sp_private_key", None)
db.add(SsoConfig(**clean))
# ── 4. Scoring config ────────────────────────────────────────────
scoring_data = bundle.get("scoring")
if scoring_data:
sc = db.query(ScoringConfig).first()
if sc:
for field, val in scoring_data.items():
if hasattr(sc, field) and val is not None:
setattr(sc, field, val)
else:
db.add(ScoringConfig(**scoring_data))
# ── 5. Custom templates ──────────────────────────────────────────
for item in bundle.get("custom_templates", []):
name = item.get("name")
mitre_id = item.get("mitre_technique_id")
if not name or not mitre_id:
continue
existing = (
db.query(TestTemplate)
.filter(TestTemplate.name == name, TestTemplate.source == "custom")
.first()
)
if existing:
for field, val in item.items():
if hasattr(existing, field):
setattr(existing, field, val)
else:
db.add(TestTemplate(**{k: v for k, v in item.items()
if k not in ("id", "created_at")}))
summary["custom_templates"] += 1
# ── 6. Users ─────────────────────────────────────────────────────
import secrets as _secrets
for item in bundle.get("users", []):
username = item.get("username")
if not username:
continue
existing = db.query(User).filter(User.username == username).first()
if existing:
existing.role = item.get("role", existing.role)
existing.is_active = item.get("is_active", existing.is_active)
summary["users_updated"] += 1
else:
# Create with random temp password — user must reset on login
temp_pw = _secrets.token_urlsafe(16) + "Aa1!"
new_user = User(
username=username,
hashed_password=hash_password(temp_pw),
role=item.get("role", "viewer"),
is_active=item.get("is_active", True),
must_change_password=True,
)
if item.get("email") and hasattr(User, "email"):
new_user.email = item["email"]
db.add(new_user)
summary["users_created"] += 1
db.commit()
return {
"status": "ok",
"imported_from_version": version,
"summary": summary,
"warnings": [
"REDACTED values were skipped — re-enter passwords/tokens manually.",
"All imported users have must_change_password=True.",
"SSO private key was not restored — re-upload it manually.",
],
}
-104
View File
@@ -1,104 +0,0 @@
"""Phase 14: API Key management router."""
from typing import List
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.user import User
from app.schemas.api_key_schema import (
ApiKeyCreate, ApiKeyCreated, ApiKeyOut, ApiKeyUpdate,
)
import app.services.api_key_service as svc
router = APIRouter(prefix="/api-keys", tags=["API Keys"])
@router.post("", response_model=ApiKeyCreated, status_code=201)
def create_key(
body: ApiKeyCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Create a scoped API key.
The ``raw_key`` field in the response is shown **exactly once** and
cannot be retrieved later. Store it securely.
"""
key, raw_key = svc.create_api_key(
db,
user_id = current_user.id,
name = body.name,
scopes = body.scopes,
description = body.description,
expires_at = body.expires_at,
)
out = ApiKeyOut.model_validate(key)
return ApiKeyCreated(**out.model_dump(), raw_key=raw_key)
@router.get("", response_model=List[ApiKeyOut])
def list_keys(
include_inactive: bool = Query(False),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""List API keys owned by the current user."""
# Admins can see all keys; others only see their own
user_id = None if current_user.role == "admin" else current_user.id
return svc.list_api_keys(db, user_id=user_id, include_inactive=include_inactive)
@router.get("/{key_id}", response_model=ApiKeyOut)
def get_key(
key_id: UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get a single API key (owner or admin)."""
user_id = None if current_user.role == "admin" else current_user.id
return svc.get_api_key(db, key_id, user_id=user_id)
@router.patch("/{key_id}", response_model=ApiKeyOut)
def update_key(
key_id: UUID,
body: ApiKeyUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Update name, description, scopes, expiry, or active status."""
user_id = None if current_user.role == "admin" else current_user.id
return svc.update_api_key(
db, key_id, user_id,
name = body.name,
description = body.description,
scopes = body.scopes,
expires_at = body.expires_at,
is_active = body.is_active,
)
@router.post("/{key_id}/revoke", response_model=ApiKeyOut)
def revoke_key(
key_id: UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Revoke an API key (soft-delete — sets is_active=False)."""
user_id = None if current_user.role == "admin" else current_user.id
return svc.revoke_api_key(db, key_id, user_id=user_id)
@router.delete("/{key_id}", status_code=204)
def delete_key(
key_id: UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin")),
):
"""Permanently delete an API key (admin only)."""
svc.delete_api_key(db, key_id)
-249
View File
@@ -1,249 +0,0 @@
"""Phase 10: Attack Paths & Advanced Purple Team router."""
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.schemas.attack_path_schema import (
AttackPathCreate, AttackPathUpdate, AttackPathOut,
AttackPathStepCreate, AttackPathStepUpdate, AttackPathStepOut,
ExecutionCreate, ExecutionOut,
StepExecuteRequest, StepResultOut,
TimelineEntryCreate, TimelineEntryOut,
)
from app.services import attack_path_service as svc
router = APIRouter(prefix="/attack-paths", tags=["attack-paths"])
# ── Attack Paths CRUD ─────────────────────────────────────────────────────────
@router.post("", response_model=AttackPathOut, status_code=201)
def create_attack_path(
body: AttackPathCreate,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.create_attack_path(db, body.model_dump(), user.id)
@router.get("", response_model=list[AttackPathOut])
def list_attack_paths(
is_template: Optional[bool] = None,
technique_id: Optional[UUID] = None,
is_active: Optional[bool] = True,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
paths = svc.list_attack_paths(db, is_template=is_template,
technique_id=technique_id, is_active=is_active)
# Inject step_count
result = []
for p in paths:
d = AttackPathOut.model_validate(p)
d.step_count = len(p.steps)
result.append(d)
return result
@router.get("/{path_id}", response_model=AttackPathOut)
def get_attack_path(
path_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
p = svc.get_attack_path(db, path_id)
d = AttackPathOut.model_validate(p)
d.step_count = len(p.steps)
return d
@router.patch("/{path_id}", response_model=AttackPathOut)
def update_attack_path(
path_id: UUID,
body: AttackPathUpdate,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.update_attack_path(db, path_id, body.model_dump(exclude_unset=True), user.id)
@router.delete("/{path_id}", status_code=204)
def delete_attack_path(
path_id: UUID,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
svc.delete_attack_path(db, path_id, user.id)
# ── Steps ─────────────────────────────────────────────────────────────────────
@router.get("/{path_id}/steps", response_model=list[AttackPathStepOut])
def list_steps(
path_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
path = svc.get_attack_path(db, path_id)
return path.steps
@router.post("/{path_id}/steps", response_model=AttackPathStepOut, status_code=201)
def add_step(
path_id: UUID,
body: AttackPathStepCreate,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.add_step(db, path_id, body.model_dump(), user.id)
@router.patch("/{path_id}/steps/{step_id}", response_model=AttackPathStepOut)
def update_step(
path_id: UUID,
step_id: UUID,
body: AttackPathStepUpdate,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.update_step(db, step_id, body.model_dump(exclude_unset=True), user.id)
@router.delete("/{path_id}/steps/{step_id}", status_code=204)
def delete_step(
path_id: UUID,
step_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
svc.delete_step(db, step_id, user.id)
@router.post("/{path_id}/steps/reorder", response_model=list[AttackPathStepOut])
def reorder_steps(
path_id: UUID,
step_ids: list[UUID],
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Pass an ordered list of step UUIDs to reorder the steps."""
return svc.reorder_steps(db, path_id, step_ids, user.id)
# ── Executions ────────────────────────────────────────────────────────────────
@router.post("/{path_id}/executions", response_model=ExecutionOut, status_code=201)
def create_execution(
path_id: UUID,
body: ExecutionCreate,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.create_execution(db, path_id, body.model_dump(), user.id)
@router.get("/{path_id}/executions", response_model=list[ExecutionOut])
def list_executions(
path_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.list_executions(db, path_id)
@router.get("/executions/{execution_id}", response_model=ExecutionOut)
def get_execution(
execution_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.get_execution(db, execution_id)
@router.post("/executions/{execution_id}/start", response_model=ExecutionOut)
def start_execution(
execution_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.start_execution(db, execution_id, user.id)
@router.post("/executions/{execution_id}/steps/{step_id}", response_model=StepResultOut)
def execute_step(
execution_id: UUID,
step_id: UUID,
body: StepExecuteRequest,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Record the result of one step (detected / not_detected / skipped)."""
return svc.execute_step(db, execution_id, step_id, body.model_dump(), user.id)
@router.get("/executions/{execution_id}/steps", response_model=list[StepResultOut])
def list_step_results(
execution_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
ex = svc.get_execution(db, execution_id)
return ex.step_results
@router.post("/executions/{execution_id}/complete", response_model=ExecutionOut)
def complete_execution(
execution_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Mark execution as complete and compute kill-chain metrics."""
return svc.complete_execution(db, execution_id, user.id)
@router.post("/executions/{execution_id}/abort", response_model=ExecutionOut)
def abort_execution(
execution_id: UUID,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
return svc.abort_execution(db, execution_id, user.id)
# ── Timeline ──────────────────────────────────────────────────────────────────
@router.post("/executions/{execution_id}/timeline",
response_model=TimelineEntryOut, status_code=201)
def add_timeline_entry(
execution_id: UUID,
body: TimelineEntryCreate,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.add_timeline_entry(db, execution_id, body.model_dump(), user.id)
@router.get("/executions/{execution_id}/timeline", response_model=list[TimelineEntryOut])
def get_timeline(
execution_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return svc.get_timeline(db, execution_id)
# ── Kill-Chain Metrics ────────────────────────────────────────────────────────
@router.get("/executions/{execution_id}/metrics")
def get_metrics(
execution_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Return full kill-chain metrics for a completed (or partial) execution."""
return svc.get_kill_chain_metrics(db, execution_id)
+8 -66
View File
@@ -16,9 +16,8 @@ from fastapi import APIRouter, Cookie, Depends, Request, Response
# Import OAuth2PasswordRequestForm from fastapi.security # Import OAuth2PasswordRequestForm from fastapi.security
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
# Import jwt (PyJWT) # Import JWTError, jwt from jose
import jwt from jose import JWTError, jwt
from jwt.exceptions import PyJWTError as JWTError
# Import Session from sqlalchemy.orm # Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -72,16 +71,9 @@ from app.services.auth_service import (
# Assign router = APIRouter(prefix="/auth", tags=["auth"]) # Assign router = APIRouter(prefix="/auth", tags=["auth"])
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
# SECURE_COOKIES desacopla la seguridad de la cookie del entorno de ejecucion. # Assign _IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
# Por defecto activo en produccion; ponlo en "false" para servidores HTTP. _IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
_aegis_env = os.environ.get("AEGIS_ENV", "development").lower() # Assign _COOKIE_NAME = "aegis_token"
_secure_cookie_env = os.environ.get("SECURE_COOKIES", "auto").lower()
if _secure_cookie_env == "false":
_IS_HTTPS = False
elif _secure_cookie_env == "true":
_IS_HTTPS = True
else: # "auto" — activo solo si AEGIS_ENV=production
_IS_HTTPS = _aegis_env == "production"
_COOKIE_NAME = "aegis_token" _COOKIE_NAME = "aegis_token"
@@ -242,8 +234,8 @@ def logout(
if jti: if jti:
# Call blacklist_token() # Call blacklist_token()
blacklist_token(jti, float(exp)) blacklist_token(jti, float(exp))
# Handle any JWT validation error during logout (token may be expired or malformed) # Handle JWTError
except jwt.exceptions.InvalidTokenError: except JWTError:
# Intentional no-op placeholder # Intentional no-op placeholder
pass pass
@@ -264,57 +256,7 @@ def logout(
return {"detail": "Logged out"} return {"detail": "Logged out"}
@router.post("/refresh", response_model=TokenResponse) # Apply the @router.get decorator
def refresh_token(
response: Response,
aegis_token: str | None = Cookie(None),
db: Session = Depends(get_db),
):
"""Issue a new access token if the current one is valid.
Called automatically by the frontend when it detects an expired
session while the user is actively using the app. If the current
cookie token is still valid (not blacklisted, not expired), a fresh
token is issued and the cookie is renewed — keeping the session alive
without requiring re-authentication.
"""
if not aegis_token:
raise PermissionViolation("No active session")
try:
payload = jwt.decode(
aegis_token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM],
)
except JWTError:
raise PermissionViolation("Session expired — please log in again")
username: str | None = payload.get("sub")
if not username:
raise PermissionViolation("Invalid session")
user = db.query(User).filter(User.username == username).first()
if user is None or not user.is_active:
raise PermissionViolation("Account not found or disabled")
if getattr(user, "must_change_password", False):
raise PermissionViolation("Password change required before refreshing session")
# Issue a fresh token with a new expiry
new_token = create_access_token(data={"sub": user.username})
response.set_cookie(
key=_COOKIE_NAME,
value=new_token,
httponly=True,
secure=_IS_HTTPS,
samesite="strict",
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
path="/",
)
return TokenResponse(access_token=new_token)
@router.get("/me", response_model=UserOut) @router.get("/me", response_model=UserOut)
# Define function read_current_user # Define function read_current_user
def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut: def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut:
+37 -206
View File
@@ -9,7 +9,8 @@ import logging
# Import uuid # Import uuid
import uuid import uuid
from datetime import datetime
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, Query from fastapi # Import APIRouter, Depends, Query from fastapi
@@ -32,9 +33,16 @@ from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user # Import User from app.models.user
from app.models.user import User from app.models.user import User
from app.models.campaign import Campaign, CampaignTest
from app.models.test import Test # Import log_action from app.services.audit_service
from app.services.campaign_service import generate_campaign_from_threat_actor from app.services.audit_service import log_action
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
activate_campaign as crud_activate,
)
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import ( from app.services.campaign_crud_service import (
add_test_to_campaign as crud_add_test, add_test_to_campaign as crud_add_test,
) )
@@ -47,7 +55,10 @@ from app.services.campaign_crud_service import (
# Import from app.services.campaign_crud_service # Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import ( from app.services.campaign_crud_service import (
create_campaign as crud_create, create_campaign as crud_create,
delete_campaign as crud_delete, )
# Import from app.services.campaign_crud_service
from app.services.campaign_crud_service import (
get_campaign_detail as crud_get_detail, get_campaign_detail as crud_get_detail,
) )
@@ -86,17 +97,11 @@ from app.services.campaign_crud_service import (
update_campaign as crud_update, update_campaign as crud_update,
) )
# Import activate_campaign from app.services.campaign_crud_service # Import generate_campaign_from_threat_actor from app.services.campaign_service
from app.services.campaign_crud_service import ( from app.services.campaign_service import generate_campaign_from_threat_actor
activate_campaign as crud_activate,
)
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action
# Import notify_role from app.services.notification_service # Import notify_role from app.services.notification_service
from app.services.notification_service import notify_role from app.services.notification_service import notify_role
from app.services.webhook_service import dispatch_webhook
# Assign logger = logging.getLogger(__name__) # Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -124,7 +129,6 @@ class CampaignCreate(BaseModel):
tags: Optional[list[str]] = Field(default_factory=list) tags: Optional[list[str]] = Field(default_factory=list)
# Assign scheduled_at = None # Assign scheduled_at = None
scheduled_at: Optional[str] = None scheduled_at: Optional[str] = None
start_date: Optional[str] = None # ISO date — campaign won't activate before this
# Define class CampaignUpdate # Define class CampaignUpdate
@@ -143,7 +147,6 @@ class CampaignUpdate(BaseModel):
tags: Optional[list[str]] = None tags: Optional[list[str]] = None
# Assign scheduled_at = None # Assign scheduled_at = None
scheduled_at: Optional[str] = None scheduled_at: Optional[str] = None
start_date: Optional[str] = None # ISO date — can be updated while still in draft
# Define class AddTestPayload # Define class AddTestPayload
@@ -195,7 +198,7 @@ def list_campaigns(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""List campaigns with optional filters and pagination. """List campaigns with optional filters and pagination.
Args: Args:
@@ -274,9 +277,8 @@ def create_campaign(
tags=payload.tags, tags=payload.tags,
# Keyword argument: scheduled_at # Keyword argument: scheduled_at
scheduled_at=payload.scheduled_at, scheduled_at=payload.scheduled_at,
start_date=payload.start_date,
) )
campaign_id = result["id"] # Call log_action()
log_action( log_action(
db, db,
# Keyword argument: user_id # Keyword argument: user_id
@@ -285,7 +287,9 @@ def create_campaign(
action="create_campaign", action="create_campaign",
# Keyword argument: entity_type # Keyword argument: entity_type
entity_type="campaign", entity_type="campaign",
entity_id=campaign_id, # Keyword argument: entity_id
entity_id=result["id"],
# Keyword argument: details
details={"name": payload.name, "type": payload.type}, details={"name": payload.name, "type": payload.type},
) )
# Call uow.commit() # Call uow.commit()
@@ -385,37 +389,6 @@ def update_campaign(
return result return result
# ---------------------------------------------------------------------------
# DELETE /campaigns/{id} — Delete campaign
# ---------------------------------------------------------------------------
@router.delete("/{campaign_id}", status_code=204)
def delete_campaign(
campaign_id: str,
delete_tests: bool = Query(False, description="Also delete associated tests"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Delete a campaign. Only draft campaigns can be deleted (admins can delete any)."""
with UnitOfWork(db) as uow:
crud_delete(
db,
campaign_id,
deleter_id=current_user.id,
deleter_role=current_user.role,
delete_tests=delete_tests,
)
log_action(
db,
user_id=current_user.id,
action="delete_campaign",
entity_type="campaign",
entity_id=campaign_id,
details={"delete_tests": delete_tests},
)
uow.commit()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# POST /campaigns/{id}/tests — Add test to campaign # POST /campaigns/{id}/tests — Add test to campaign
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -460,7 +433,7 @@ def add_test_to_campaign(
) )
# Call uow.commit() # Call uow.commit()
uow.commit() uow.commit()
# Return result
return result return result
@@ -510,36 +483,22 @@ def remove_test_from_campaign(
def activate_campaign( def activate_campaign(
# Entry: campaign_id # Entry: campaign_id
campaign_id: str, campaign_id: str,
force: bool = Query(False, description="Activate even if start_date is in the future"), # Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ) -> dict:
"""Activate a campaign, moving it from draft to active. """Activate a campaign, moving it from draft to active.
If the campaign has a start_date in the future and force=False, returns a 409 Args:
with a warning so the frontend can show a confirmation modal. If force=True, campaign_id (str): UUID string of the campaign to activate.
activates immediately regardless of start_date. db (Session): SQLAlchemy database session.
""" current_user (User): Authenticated red_lead or blue_lead activating the campaign.
from fastapi import HTTPException
campaign_obj = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if campaign_obj and campaign_obj.start_date and not force:
now = datetime.utcnow()
if campaign_obj.start_date > now:
raise HTTPException(
status_code=409,
detail={
"code": "start_date_in_future",
"start_date": campaign_obj.start_date.strftime("%Y-%m-%d"),
"message": (
f"This campaign is scheduled to start on "
f"{campaign_obj.start_date.strftime('%d %b %Y')}. "
f"It will activate automatically on that date. "
f"Do you want to activate it now anyway?"
),
},
)
Returns:
dict: Serialised representation of the activated campaign.
"""
# Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Assign campaign = crud_activate(db, campaign_id) # Assign campaign = crud_activate(db, campaign_id)
campaign = crud_activate(db, campaign_id) campaign = crud_activate(db, campaign_id)
@@ -578,33 +537,7 @@ def activate_campaign(
# Reload ORM object attributes from the database # Reload ORM object attributes from the database
db.refresh(campaign) db.refresh(campaign)
# Create Jira tickets for campaign and tests at activation time (non-fatal). # Return serialize_campaign(db, campaign)
# Campaign ticket is created here if it doesn't already exist (deferred from creation).
try:
from app.services.jira_service import (
auto_create_campaign_issue,
auto_create_test_issue,
get_campaign_jira_key,
get_test_jira_key,
)
campaign_jira_key = get_campaign_jira_key(db, campaign_id)
if not campaign_jira_key:
campaign_jira_key = auto_create_campaign_issue(db, campaign, current_user)
if campaign_jira_key:
for ct in campaign.campaign_tests:
if ct.test and not get_test_jira_key(db, ct.test.id):
auto_create_test_issue(
db, ct.test, current_user,
parent_ticket_override=campaign_jira_key,
campaign_start_date=campaign.start_date,
)
db.commit()
except Exception:
logger.exception(
"Jira ticket creation failed during activation of campaign %s",
campaign_id,
)
return serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
@@ -654,7 +587,6 @@ def complete_campaign(
uow.commit() uow.commit()
# Reload ORM object attributes from the database # Reload ORM object attributes from the database
db.refresh(campaign) db.refresh(campaign)
dispatch_webhook("campaign.completed", {"campaign_id": str(campaign.id), "name": campaign.name})
# Return serialize_campaign(db, campaign) # Return serialize_campaign(db, campaign)
return serialize_campaign(db, campaign) return serialize_campaign(db, campaign)
@@ -692,16 +624,12 @@ def get_campaign_progress_endpoint(
# POST /campaigns/from-threat-actor/{actor_id} — Auto-generate campaign # POST /campaigns/from-threat-actor/{actor_id} — Auto-generate campaign
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class GenerateFromActorPayload(BaseModel):
start_date: Optional[str] = None # ISO date YYYY-MM-DD
@router.post("/from-threat-actor/{actor_id}", status_code=201) @router.post("/from-threat-actor/{actor_id}", status_code=201)
# Define function generate_campaign_from_actor # Define function generate_campaign_from_actor
def generate_campaign_from_actor( def generate_campaign_from_actor(
# Entry: actor_id # Entry: actor_id
actor_id: str, actor_id: str,
payload: GenerateFromActorPayload = GenerateFromActorPayload(), # Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(require_any_role("red_lead", "blue_lead")), current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
@@ -719,14 +647,11 @@ def generate_campaign_from_actor(
Returns: Returns:
dict: Serialised representation of the newly generated campaign. dict: Serialised representation of the newly generated campaign.
""" """
start_date_parsed = ( # Assign campaign = generate_campaign_from_threat_actor(
datetime.fromisoformat(payload.start_date) if payload.start_date else None
)
campaign = generate_campaign_from_threat_actor( campaign = generate_campaign_from_threat_actor(
db, db,
uuid.UUID(actor_id), uuid.UUID(actor_id),
current_user, current_user,
start_date=start_date_parsed,
) )
# Open context manager # Open context manager
@@ -854,97 +779,3 @@ def get_campaign_history(
""" """
# Return crud_get_history(db, campaign_id) # Return crud_get_history(db, campaign_id)
return crud_get_history(db, campaign_id) return crud_get_history(db, campaign_id)
# ---------------------------------------------------------------------------
# GET /campaigns/{id}/timing-summary — Aggregated timing across campaign tests
# ---------------------------------------------------------------------------
def _seconds_between(start: datetime | None, end: datetime | None) -> int:
"""Return elapsed seconds between two datetimes; 0 if either is None."""
if not start or not end:
return 0
diff = (end - start).total_seconds()
return max(0, int(diff))
@router.get("/{campaign_id}/timing-summary")
def get_campaign_timing_summary(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return aggregated Red/Blue timing metrics for all tests in a campaign.
For each test we calculate:
- red_execution_secs : red_started_at → blue_started_at (minus red_paused_seconds)
- blue_queue_secs : blue_started_at → blue_work_started_at (waiting for Blue pick-up)
- blue_evaluation_secs: blue_work_started_at → first validation timestamp (minus blue_paused_seconds)
- total_secs : sum of the three phases
Returns totals + per-test breakdown.
"""
# Load campaign
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
from fastapi import HTTPException
raise HTTPException(status_code=404, detail="Campaign not found")
# Load all tests for this campaign
test_ids = [
ct.test_id
for ct in db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign.id).all()
]
tests = db.query(Test).filter(Test.id.in_(test_ids)).all() if test_ids else []
breakdown = []
total_red = 0
total_queue = 0
total_blue = 0
for t in tests:
# Red execution: from start-execution to submit-to-blue, minus paused time
red_secs = max(
0,
_seconds_between(t.red_started_at, t.blue_started_at) - (t.red_paused_seconds or 0),
)
# Blue queue: from receiving the test to actually starting evaluation
queue_secs = _seconds_between(t.blue_started_at, t.blue_work_started_at)
# Blue evaluation: from starting evaluation to first validation, minus paused time
eval_end = t.red_validated_at or t.blue_validated_at
blue_secs = max(
0,
_seconds_between(t.blue_work_started_at, eval_end) - (t.blue_paused_seconds or 0),
)
total_red += red_secs
total_queue += queue_secs
total_blue += blue_secs
breakdown.append({
"test_id": str(t.id),
"test_name": t.name,
"state": t.state.value if t.state else None,
"red_execution_secs": red_secs,
"blue_queue_secs": queue_secs,
"blue_evaluation_secs": blue_secs,
"total_secs": red_secs + queue_secs + blue_secs,
"has_timing": bool(t.red_started_at),
})
total_secs = total_red + total_queue + total_blue
return {
"campaign_id": campaign_id,
"campaign_name": campaign.name,
"tests_total": len(tests),
"tests_with_timing": sum(1 for b in breakdown if b["has_timing"]),
"red_execution_secs": total_red,
"blue_queue_secs": total_queue,
"blue_evaluation_secs": total_blue,
"total_secs": total_secs,
"breakdown": sorted(breakdown, key=lambda x: -(x["total_secs"])),
}
-33
View File
@@ -26,9 +26,6 @@ from app.models.user import User
# Import from app.services.compliance_import_service # Import from app.services.compliance_import_service
from app.services.compliance_import_service import ( from app.services.compliance_import_service import (
import_cis_controls_v8_mappings, import_cis_controls_v8_mappings,
import_dora_mappings,
import_iso_27001_mappings,
import_iso_42001_mappings,
import_nist_800_53_mappings, import_nist_800_53_mappings,
) )
@@ -235,33 +232,3 @@ def import_cis(
result = import_cis_controls_v8_mappings(db) result = import_cis_controls_v8_mappings(db)
# Return result # Return result
return result return result
@router.post("/import/dora")
def import_dora(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Import DORA (EU 2022/2554) compliance mappings (admin only)."""
result = import_dora_mappings(db)
return result
@router.post("/import/iso-27001")
def import_iso27001(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Import ISO/IEC 27001:2022 Annex A compliance mappings (admin only)."""
result = import_iso_27001_mappings(db)
return result
@router.post("/import/iso-42001")
def import_iso42001(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Import ISO/IEC 42001:2023 AI Management System compliance mappings (admin only)."""
result = import_iso_42001_mappings(db)
return result
+2 -2
View File
@@ -64,7 +64,7 @@ def list_defensive_techniques(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""List all D3FEND defensive techniques with optional filters.""" """List all D3FEND defensive techniques with optional filters."""
# Return list_defensive_techniques_svc( # Return list_defensive_techniques_svc(
return list_defensive_techniques_svc( return list_defensive_techniques_svc(
@@ -102,7 +102,7 @@ def get_defenses_for_attack_technique_endpoint(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" """Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
# Return get_defenses_for_attack_technique(db, mitre_id) # Return get_defenses_for_attack_technique(db, mitre_id)
return get_defenses_for_attack_technique(db, mitre_id) return get_defenses_for_attack_technique(db, mitre_id)
-319
View File
@@ -1,319 +0,0 @@
"""Detection Lifecycle Management router."""
import hashlib
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.domain.exceptions import EntityNotFoundError
from app.models.detection_lifecycle import (
DetectionAsset, DetectionTechniqueMapping, DetectionValidation,
TechniqueConfidenceScore, InfrastructureChangeLog,
)
from app.schemas.detection_lifecycle_schema import (
DetectionAssetCreate, DetectionAssetUpdate, DetectionAssetOut,
DetectionValidationCreate, DetectionValidationOut,
TechniqueConfidenceOut,
InfrastructureChangeCreate, InfrastructureChangeOut,
)
from app.services import detection_asset_service, decay_engine_service, audit_service
router = APIRouter(prefix="/detection-lifecycle", tags=["detection-lifecycle"])
def _now() -> datetime:
return datetime.utcnow()
# ── Detection Assets ─────────────────────────────────────────────────────────
@router.post("/assets", response_model=DetectionAssetOut, status_code=201)
def create_asset(body: DetectionAssetCreate, db: Session = Depends(get_db), user=Depends(get_current_user)):
asset = detection_asset_service.create_detection_asset(db, body.model_dump(), user.id)
return asset
@router.get("/assets", response_model=list[DetectionAssetOut])
def list_assets(
platform: Optional[str] = None,
asset_type: Optional[str] = None,
health_status: Optional[str] = None,
technique_id: Optional[UUID] = None,
is_active: Optional[bool] = True,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return detection_asset_service.list_assets(db, platform=platform, asset_type=asset_type, health_status=health_status, technique_id=technique_id, is_active=is_active)
@router.get("/assets/{asset_id}", response_model=DetectionAssetOut)
def get_asset(asset_id: UUID, db: Session = Depends(get_db), user=Depends(get_current_user)):
return detection_asset_service.get_asset_with_details(db, asset_id)
@router.patch("/assets/{asset_id}", response_model=DetectionAssetOut)
def update_asset(asset_id: UUID, body: DetectionAssetUpdate, db: Session = Depends(get_db), user=Depends(get_current_user)):
return detection_asset_service.update_detection_asset(db, asset_id, body.model_dump(exclude_unset=True), user.id)
@router.delete("/assets/{asset_id}", status_code=204)
def delete_asset(asset_id: UUID, db: Session = Depends(get_db), user=Depends(require_any_role("red_lead", "blue_lead"))):
asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first()
if not asset:
raise EntityNotFoundError("DetectionAsset", str(asset_id))
asset.is_active = False
db.commit()
# ── Technique Mappings ───────────────────────────────────────────────────────
@router.post("/assets/{asset_id}/techniques/{technique_id}")
def map_technique(
asset_id: UUID, technique_id: UUID,
coverage_type: str = Query("detect"),
confidence_level: str = Query("medium"),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
# Validate asset exists
asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first()
if not asset:
raise EntityNotFoundError("DetectionAsset", str(asset_id))
# Prevent duplicate mappings
existing = db.query(DetectionTechniqueMapping).filter(
DetectionTechniqueMapping.detection_asset_id == asset_id,
DetectionTechniqueMapping.technique_id == technique_id,
).first()
if existing:
# Update coverage/confidence on existing mapping instead of duplicating
existing.coverage_type = coverage_type
existing.confidence_level = confidence_level
db.commit()
return {"message": "Technique mapping updated", "mapping_id": str(existing.id)}
mapping = DetectionTechniqueMapping(
detection_asset_id=asset_id, technique_id=technique_id,
coverage_type=coverage_type, confidence_level=confidence_level,
)
db.add(mapping)
db.commit()
return {"message": "Technique mapped", "mapping_id": str(mapping.id)}
@router.get("/techniques/{technique_id}/detections")
def get_technique_detections(technique_id: UUID, db: Session = Depends(get_db), user=Depends(get_current_user)):
return detection_asset_service.get_technique_detection_summary(db, technique_id)
# ── Validations ──────────────────────────────────────────────────────────────
@router.post("/validations", response_model=DetectionValidationOut, status_code=201)
def create_validation(body: DetectionValidationCreate, db: Session = Depends(get_db), user=Depends(get_current_user)):
asset = db.query(DetectionAsset).filter(DetectionAsset.id == body.detection_asset_id).first()
if not asset:
raise EntityNotFoundError("DetectionAsset", str(body.detection_asset_id))
now = _now()
validation = DetectionValidation(
detection_asset_id=body.detection_asset_id,
technique_id=body.technique_id,
test_id=body.test_id,
validation_result=body.validation_result,
validation_method=body.validation_method,
notes=body.notes,
evidence_ids=[str(e) for e in (body.evidence_ids or [])],
validated_by=user.id,
validated_at=now,
expires_at=now + timedelta(days=body.validity_days),
rule_hash_at_validation=asset.rule_hash,
log_source_version_at_validation=asset.log_source_version,
infrastructure_hash_at_validation=asset.infrastructure_hash,
)
data = f"{validation.detection_asset_id}:{validation.validated_by}:{validation.validation_result}:{validation.validated_at}"
validation.integrity_hash = hashlib.sha256(data.encode()).hexdigest()
db.add(validation)
db.commit()
db.refresh(validation)
if body.technique_id:
decay_engine_service.calculate_confidence_for_technique(db, body.technique_id)
audit_service.log_action(db, user.id, "DETECTION_VALIDATED", "detection_validation", str(validation.id),
details={"asset_id": str(body.detection_asset_id), "result": body.validation_result, "validity_days": body.validity_days})
return validation
@router.get("/validations", response_model=list[DetectionValidationOut])
def list_validations(
asset_id: Optional[UUID] = None,
technique_id: Optional[UUID] = None,
is_valid: Optional[bool] = None,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
query = db.query(DetectionValidation)
if asset_id:
query = query.filter(DetectionValidation.detection_asset_id == asset_id)
if technique_id:
query = query.filter(DetectionValidation.technique_id == technique_id)
if is_valid is not None:
query = query.filter(DetectionValidation.is_valid == is_valid)
return query.order_by(DetectionValidation.validated_at.desc()).all()
@router.post("/validations/{validation_id}/invalidate")
def invalidate_validation(
validation_id: UUID,
reason: str = Query(...),
details: Optional[str] = None,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "blue_lead")),
):
validation = db.query(DetectionValidation).filter(DetectionValidation.id == validation_id).first()
if not validation:
raise EntityNotFoundError("DetectionValidation", str(validation_id))
from app.models.detection_lifecycle import InvalidationReason
try:
reason_enum = InvalidationReason(reason)
except ValueError:
reason_enum = InvalidationReason.manual
validation.is_valid = False
validation.invalidated_at = _now()
validation.invalidation_reason = reason_enum
validation.invalidation_details = details
validation.invalidated_by = user.id
db.commit()
return {"message": "Validation invalidated"}
# ── Confidence Scores ────────────────────────────────────────────────────────
@router.get("/confidence", response_model=list[TechniqueConfidenceOut])
def list_confidence_scores(
confidence_level: Optional[str] = None,
min_score: Optional[float] = None,
max_score: Optional[float] = None,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
query = db.query(TechniqueConfidenceScore)
if confidence_level:
query = query.filter(TechniqueConfidenceScore.confidence_level == confidence_level)
if min_score is not None:
query = query.filter(TechniqueConfidenceScore.confidence_score >= min_score)
if max_score is not None:
query = query.filter(TechniqueConfidenceScore.confidence_score <= max_score)
return query.order_by(TechniqueConfidenceScore.confidence_score.asc()).all()
@router.get("/confidence/{technique_id}", response_model=TechniqueConfidenceOut)
def get_technique_confidence(
technique_id: UUID,
recalculate: bool = Query(False),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
if recalculate:
return decay_engine_service.calculate_confidence_for_technique(db, technique_id)
score = db.query(TechniqueConfidenceScore).filter(TechniqueConfidenceScore.technique_id == technique_id).first()
if not score:
return decay_engine_service.calculate_confidence_for_technique(db, technique_id)
return score
# ── Infrastructure Changes ───────────────────────────────────────────────────
@router.post("/infrastructure-changes", response_model=InfrastructureChangeOut, status_code=201)
def report_infrastructure_change(
body: InfrastructureChangeCreate,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "blue_lead")),
):
change = InfrastructureChangeLog(
change_type=body.change_type,
description=body.description,
affected_platforms=body.affected_platforms,
affected_log_sources=body.affected_log_sources,
change_date=body.change_date or _now(),
auto_invalidate=body.auto_invalidate,
reported_by=user.id,
)
db.add(change)
db.commit()
db.refresh(change)
if change.auto_invalidate:
decay_engine_service.process_infrastructure_change(db, change.id)
db.refresh(change)
audit_service.log_action(db, user.id, "INFRASTRUCTURE_CHANGE_REPORTED", "infrastructure_change", str(change.id),
details={"type": body.change_type, "invalidated_count": change.invalidated_count})
return change
@router.get("/infrastructure-changes", response_model=list[InfrastructureChangeOut])
def list_infrastructure_changes(
days: int = Query(90, ge=1, le=730),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
cutoff = _now() - timedelta(days=days)
return db.query(InfrastructureChangeLog).filter(InfrastructureChangeLog.change_date >= cutoff).order_by(InfrastructureChangeLog.change_date.desc()).all()
# ── Decay Engine Control ─────────────────────────────────────────────────────
@router.post("/decay-engine/run")
def trigger_decay_engine(db: Session = Depends(get_db), user=Depends(require_any_role("admin"))):
results = decay_engine_service.run_decay_engine(db)
return {"message": "Decay engine completed", "results": results}
# ── Dashboard ────────────────────────────────────────────────────────────────
@router.get("/dashboard")
def lifecycle_dashboard(db: Session = Depends(get_db), user=Depends(get_current_user)):
now = _now()
health_dist = dict(
db.query(DetectionAsset.health_status, func.count(DetectionAsset.id))
.filter(DetectionAsset.is_active == True)
.group_by(DetectionAsset.health_status)
.all()
)
confidence_dist = dict(
db.query(TechniqueConfidenceScore.confidence_level, func.count(TechniqueConfidenceScore.id))
.group_by(TechniqueConfidenceScore.confidence_level)
.all()
)
expiring_soon = db.query(func.count(DetectionValidation.id)).filter(
DetectionValidation.is_valid == True,
DetectionValidation.expires_at <= (now + timedelta(days=7)),
).scalar() or 0
total_assets = db.query(func.count(DetectionAsset.id)).filter(DetectionAsset.is_active == True).scalar() or 0
total_valid = db.query(func.count(DetectionValidation.id)).filter(DetectionValidation.is_valid == True).scalar() or 0
recent_changes = db.query(func.count(InfrastructureChangeLog.id)).filter(
InfrastructureChangeLog.change_date >= (now - timedelta(days=30))
).scalar() or 0
return {
"total_detection_assets": total_assets,
"total_valid_validations": total_valid,
"health_distribution": {k.value if hasattr(k, "value") else str(k): v for k, v in health_dist.items()},
"confidence_distribution": {k.value if hasattr(k, "value") else str(k): v for k, v in confidence_dist.items()},
"validations_expiring_7d": expiring_soon,
"infrastructure_changes_30d": recent_changes,
}
+3 -3
View File
@@ -80,7 +80,7 @@ def list_detection_rules(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""List detection rules with optional filters and pagination.""" """List detection rules with optional filters and pagination."""
# Return list_rules( # Return list_rules(
return list_rules( return list_rules(
@@ -112,7 +112,7 @@ def get_detection_rules_for_template(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""Get detection rules associated with a test template.""" """Get detection rules associated with a test template."""
# Return get_rules_for_template(db, template_id) # Return get_rules_for_template(db, template_id)
return get_rules_for_template(db, template_id) return get_rules_for_template(db, template_id)
@@ -151,7 +151,7 @@ def get_detection_rules_for_test(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""Get detection rules relevant to a test, along with their evaluation results. """Get detection rules relevant to a test, along with their evaluation results.
Finds rules by matching the test's technique_id to detection rules, Finds rules by matching the test's technique_id to detection rules,
+19 -88
View File
@@ -4,8 +4,7 @@ Endpoints
--------- ---------
POST /tests/{test_id}/evidence — upload evidence (with team=red/blue) POST /tests/{test_id}/evidence — upload evidence (with team=red/blue)
GET /tests/{test_id}/evidence — list evidences (filterable by team) GET /tests/{test_id}/evidence — list evidences (filterable by team)
GET /evidence/{id}metadata + download_url GET /evidence/{id}presigned download URL
GET /evidence/{id}/file — proxy download (streams file through backend)
DELETE /evidence/{id} — delete evidence (only in editable states) DELETE /evidence/{id} — delete evidence (only in editable states)
Access Control Access Control
@@ -22,17 +21,20 @@ Access Control
# Import hashlib # Import hashlib
import hashlib import hashlib
import logging
# Import os
import os import os
# Import uuid # Import uuid
import uuid as _uuid import uuid as _uuid
from datetime import datetime
# Import Optional from typing
from typing import Optional from typing import Optional
# Import APIRouter, Depends, File, Form, Query, Request,... from fastapi # Import APIRouter, Depends, File, Form, Query, Request,... from fastapi
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
from fastapi.responses import StreamingResponse
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database # Import get_db from app.database
@@ -72,9 +74,9 @@ from app.services.evidence_service import (
validate_file, validate_file,
validate_upload_permission, validate_upload_permission,
) )
from app.storage import download_file, upload_file
logger = logging.getLogger(__name__) # Import get_presigned_url, upload_file from app.storage
from app.storage import get_presigned_url, upload_file
# Assign router = APIRouter(tags=["evidence"]) # Assign router = APIRouter(tags=["evidence"])
router = APIRouter(tags=["evidence"]) router = APIRouter(tags=["evidence"])
@@ -85,11 +87,8 @@ router = APIRouter(tags=["evidence"])
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _evidence_to_out(evidence: Evidence) -> EvidenceOut: def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
"""Convert an ORM ``Evidence`` to the API schema. """Convert an ORM ``Evidence`` to the API schema, injecting a presigned URL."""
# Return EvidenceOut(
``download_url`` points to the backend proxy endpoint so the browser
never needs direct access to MinIO.
"""
return EvidenceOut( return EvidenceOut(
# Keyword argument: id # Keyword argument: id
id=evidence.id, id=evidence.id,
@@ -107,7 +106,8 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
team=evidence.team, team=evidence.team,
# Keyword argument: notes # Keyword argument: notes
notes=evidence.notes, notes=evidence.notes,
download_url=f"/api/v1/evidence/{evidence.id}/file", # Keyword argument: download_url
download_url=get_presigned_url(evidence.file_path),
) )
@@ -185,7 +185,7 @@ async def upload_evidence(
sha256_hash=sha256, sha256_hash=sha256,
# Keyword argument: uploaded_by # Keyword argument: uploaded_by
uploaded_by=current_user.id, uploaded_by=current_user.id,
uploaded_at=datetime.utcnow(), # set explicitly — DB column has no server default # Keyword argument: team
team=team, team=team,
# Keyword argument: notes # Keyword argument: notes
notes=notes, notes=notes,
@@ -222,43 +222,10 @@ async def upload_evidence(
# Reload ORM object attributes from the database # Reload ORM object attributes from the database
db.refresh(evidence) db.refresh(evidence)
# 7. Attach to Jira ticket if one exists (non-fatal) # Return _evidence_to_out(evidence)
_attach_evidence_to_jira(db, test_id, content, safe_name, current_user)
return _evidence_to_out(evidence) return _evidence_to_out(evidence)
def _attach_evidence_to_jira(
db,
test_id: _uuid.UUID,
content: bytes,
file_name: str,
actor,
) -> None:
"""Attach uploaded evidence to the linked Jira ticket (non-fatal)."""
try:
from app.services.jira_service import get_test_jira_key, get_user_jira_client, has_jira_configured
if not has_jira_configured(actor, db):
return
issue_key = get_test_jira_key(db, test_id)
if not issue_key:
return
import io
jira = get_user_jira_client(actor, db)
buf = io.BytesIO(content)
buf.name = file_name # requests uses .name as the multipart filename
jira.add_attachment_object(issue_key, buf)
import logging
logging.getLogger(__name__).info(
"Attached evidence '%s' to Jira ticket %s", file_name, issue_key
)
except Exception as exc:
import logging
logging.getLogger(__name__).warning(
"Failed to attach evidence '%s' to Jira: %s", file_name, exc, exc_info=True
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /tests/{test_id}/evidence — list (with optional team filter) # GET /tests/{test_id}/evidence — list (with optional team filter)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -286,7 +253,7 @@ def list_evidence(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /evidence/{id} — metadata + proxy download URL # GET /evidence/{id} — presigned download URL
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -299,50 +266,14 @@ def get_evidence(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ) -> EvidenceOut:
"""Return evidence metadata. ``download_url`` is a backend proxy URL.""" """Return evidence metadata together with a presigned download URL."""
# Assign evidence = get_evidence_or_raise(db, evidence_id)
evidence = get_evidence_or_raise(db, evidence_id) evidence = get_evidence_or_raise(db, evidence_id)
# Return _evidence_to_out(evidence) # Return _evidence_to_out(evidence)
return _evidence_to_out(evidence) return _evidence_to_out(evidence)
# ---------------------------------------------------------------------------
# GET /evidence/{id}/file — proxy download (streams file via backend)
# ---------------------------------------------------------------------------
@router.get("/evidence/{evidence_id}/file")
def download_evidence_file(
evidence_id: _uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Stream the evidence file through the backend.
The browser calls this endpoint (authenticated via JWT cookie/header).
The backend fetches the file from MinIO internally and streams it back,
so MinIO never needs to be publicly accessible.
"""
import mimetypes
evidence = get_evidence_or_raise(db, evidence_id)
content = download_file(evidence.file_path)
mime_type, _ = mimetypes.guess_type(evidence.file_name)
if not mime_type:
mime_type = "application/octet-stream"
safe_name = evidence.file_name.replace('"', '\\"')
return StreamingResponse(
iter([content]),
media_type=mime_type,
headers={
"Content-Disposition": f'inline; filename="{safe_name}"',
"Content-Length": str(len(content)),
},
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# DELETE /evidence/{id} — delete evidence (editable states only) # DELETE /evidence/{id} — delete evidence (editable states only)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
-123
View File
@@ -1,123 +0,0 @@
"""Phase 13: Executive Dashboard router."""
from typing import List
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.schemas.executive_dashboard_schema import (
PostureSnapshotOut,
ExecutiveSummary,
KpiBlock,
CoverageByTactic,
PostureHistoryEntry,
ActivityEntry,
)
import app.services.executive_dashboard_service as svc
router = APIRouter(prefix="/dashboard", tags=["Executive Dashboard"])
@router.get("/executive", response_model=ExecutiveSummary)
def executive_view(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""
Full executive view — snapshot, 30-day trends, top risks,
coverage by tactic, and recent activity feed.
"""
data = svc.get_executive_summary(db)
snap = data["snapshot"]
return ExecutiveSummary(
snapshot=PostureSnapshotOut.model_validate(snap),
coverage_trend=data["coverage_trend"],
risk_trend=data["risk_trend"],
top_risks=data["top_risks"],
coverage_by_tactic=data["coverage_by_tactic"],
recent_activity=data["recent_activity"],
)
@router.get("/kpis", response_model=KpiBlock)
def kpis(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Compact KPI block — live aggregation without persisting a snapshot."""
live = svc.get_live_kpis(db)
# Try to find today's snapshot id; fall back to None
from datetime import date
from app.models.executive_dashboard import PostureSnapshot
today_snap = db.query(PostureSnapshot).filter(
PostureSnapshot.snapshot_date == date.today()
).first()
return KpiBlock(
coverage_pct=live["coverage_pct"],
avg_risk_score=live["avg_risk_score"],
critical_count=live["critical_count"],
open_queue_items=live["open_queue_items"],
orphan_techniques=live["orphan_techniques"],
mttd_avg_seconds=live.get("mttd_avg_seconds"),
detection_rate_30d=live.get("detection_rate_30d"),
playbook_count=live["playbook_count"],
lesson_count=live["lesson_count"],
snapshot_date=live["snapshot_date"],
snapshot_id=today_snap.id if today_snap else None,
)
@router.get("/coverage-by-tactic", response_model=List[CoverageByTactic])
def coverage_by_tactic(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Per-tactic validated / partial / not_covered breakdown."""
return svc.get_coverage_by_tactic(db)
@router.get("/posture-history", response_model=List[PostureHistoryEntry])
def posture_history(
days: int = Query(30, ge=1, le=365),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Historical posture snapshots for trend charts (default last 30 days)."""
snaps = svc.get_posture_history(db, days=days)
return [
PostureHistoryEntry(
snapshot_date=s.snapshot_date,
coverage_pct=s.coverage_pct,
avg_risk_score=s.avg_risk_score,
critical_count=s.critical_count,
open_queue_items=s.open_queue_items,
)
for s in snaps
]
@router.post("/posture-snapshot", response_model=PostureSnapshotOut, status_code=201)
def create_posture_snapshot(
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""
Take (or refresh) today's posture snapshot — admin / leads only.
Aggregates live data from all phases into a single PostureSnapshot row.
"""
snap = svc.take_posture_snapshot(db, created_by=user.id)
return PostureSnapshotOut.model_validate(snap)
@router.get("/activity", response_model=List[ActivityEntry])
def recent_activity(
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Recent activity feed — tests, attack-path executions, OSINT signals."""
return svc.get_recent_activity(db, limit=limit)
-54
View File
@@ -1,54 +0,0 @@
"""Intel items endpoints — list and manage threat intelligence items."""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.models.intel import IntelItem
from app.models.user import User
router = APIRouter(prefix="/intel", tags=["intel"])
class IntelItemOut(BaseModel):
id: uuid.UUID
technique_id: Optional[uuid.UUID] = None
url: str
title: Optional[str] = None
source: Optional[str] = None
detected_at: Optional[str] = None
reviewed: bool
class Config:
from_attributes = True
@router.get("/items", response_model=list[IntelItemOut])
def list_intel_items(
technique_id: Optional[uuid.UUID] = Query(None, description="Filter by technique"),
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""List threat intelligence items, optionally filtered by technique."""
query = db.query(IntelItem).order_by(IntelItem.detected_at.desc())
if technique_id:
query = query.filter(IntelItem.technique_id == technique_id)
items = query.limit(limit).all()
return [
IntelItemOut(
id=item.id,
technique_id=item.technique_id,
url=item.url,
title=item.title,
source=item.source,
detected_at=item.detected_at.isoformat() if item.detected_at else None,
reviewed=item.reviewed,
)
for item in items
]
+4 -4
View File
@@ -129,19 +129,19 @@ def list_links(
entity_type: Optional[JiraLinkEntityType] = None, entity_type: Optional[JiraLinkEntityType] = None,
# Entry: entity_id # Entry: entity_id
entity_id: Optional[UUID] = None, entity_id: Optional[UUID] = None,
entity_ids: Optional[list[UUID]] = Query(default=None, description="Filter by multiple entity IDs"), # Entry: db
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user # Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ) -> list[JiraLinkOut]:
"""List Jira links, optionally filtered by entity or a list of entity IDs.""" """List Jira links, optionally filtered by entity."""
# Return jira_service.list_links(
return jira_service.list_links( return jira_service.list_links(
db, db,
# Keyword argument: entity_type # Keyword argument: entity_type
entity_type=entity_type, entity_type=entity_type,
# Keyword argument: entity_id # Keyword argument: entity_id
entity_id=entity_id, entity_id=entity_id,
entity_ids=entity_ids,
) )
-206
View File
@@ -1,206 +0,0 @@
"""Phase 11: Knowledge Management router — Playbooks + Lessons Learned."""
from typing import List, Optional
from uuid import UUID
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.schemas.knowledge_schema import (
PlaybookCreate, PlaybookUpdate, PlaybookOut, PlaybookVersionOut,
LessonLearnedCreate, LessonLearnedUpdate, LessonLearnedOut,
)
from app.services import playbook_service as pb_svc
from app.services import lesson_learned_service as ll_svc
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
# ══════════════════════════════════════════════════════════════════════════════
# Playbooks
# ══════════════════════════════════════════════════════════════════════════════
@router.get("/playbooks", response_model=List[PlaybookOut])
def list_playbooks(
technique_id: Optional[UUID] = None,
playbook_type: Optional[str] = None,
include_inactive: bool = False,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return pb_svc.list_playbooks(
db,
technique_id=technique_id,
playbook_type=playbook_type,
include_inactive=include_inactive,
)
@router.post("/playbooks", response_model=PlaybookOut, status_code=201)
def create_playbook(
body: PlaybookCreate,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
return pb_svc.create_playbook(db, body.model_dump(), user.id)
@router.get("/playbooks/{playbook_id}", response_model=PlaybookOut)
def get_playbook(
playbook_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return pb_svc.get_playbook(db, playbook_id)
@router.patch("/playbooks/{playbook_id}", response_model=PlaybookOut)
def update_playbook(
playbook_id: UUID,
body: PlaybookUpdate,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
return pb_svc.update_playbook(db, playbook_id, body.model_dump(exclude_unset=True), user.id)
@router.delete("/playbooks/{playbook_id}", status_code=204)
def delete_playbook(
playbook_id: UUID,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
pb_svc.delete_playbook(db, playbook_id, user.id)
# ── Versions ──────────────────────────────────────────────────────────────────
@router.get("/playbooks/{playbook_id}/versions", response_model=List[PlaybookVersionOut])
def list_versions(
playbook_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return pb_svc.get_playbook_versions(db, playbook_id)
@router.post("/playbooks/{playbook_id}/restore/{version}", response_model=PlaybookOut)
def restore_version(
playbook_id: UUID,
version: int,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""Roll the playbook back to a specific historical version."""
return pb_svc.restore_version(db, playbook_id, version, user.id)
# ── By technique (convenience) ────────────────────────────────────────────────
@router.get(
"/techniques/{technique_id}/playbooks",
response_model=List[PlaybookOut],
)
def playbooks_for_technique(
technique_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""List all active playbooks for a specific technique."""
return pb_svc.list_playbooks(db, technique_id=technique_id)
@router.get(
"/techniques/{technique_id}/playbooks/{playbook_type}",
response_model=PlaybookOut,
)
def get_playbook_by_technique_type(
technique_id: UUID,
playbook_type: str,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
pb = pb_svc.get_playbook_by_technique_type(db, technique_id, playbook_type)
if not pb:
from app.domain.errors import EntityNotFoundError
raise EntityNotFoundError("Playbook", f"{technique_id}/{playbook_type}")
return pb
# ══════════════════════════════════════════════════════════════════════════════
# Lessons Learned
# ══════════════════════════════════════════════════════════════════════════════
@router.get("/lessons", response_model=List[LessonLearnedOut])
def list_lessons(
entity_type: Optional[str] = None,
entity_id: Optional[UUID] = None,
severity: Optional[str] = None,
tag: Optional[str] = None,
technique_id: Optional[str] = None,
include_inactive: bool = False,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return ll_svc.list_lessons_learned(
db,
entity_type=entity_type,
entity_id=entity_id,
severity=severity,
tag=tag,
technique_id=technique_id,
include_inactive=include_inactive,
)
@router.post("/lessons", response_model=LessonLearnedOut, status_code=201)
def create_lesson(
body: LessonLearnedCreate,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
return ll_svc.create_lesson_learned(db, body.model_dump(), user.id)
@router.get("/lessons/{lesson_id}", response_model=LessonLearnedOut)
def get_lesson(
lesson_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return ll_svc.get_lesson_learned(db, lesson_id)
@router.patch("/lessons/{lesson_id}", response_model=LessonLearnedOut)
def update_lesson(
lesson_id: UUID,
body: LessonLearnedUpdate,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
return ll_svc.update_lesson_learned(
db, lesson_id, body.model_dump(exclude_unset=True), user.id
)
@router.delete("/lessons/{lesson_id}", status_code=204)
def delete_lesson(
lesson_id: UUID,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""Soft-delete a lesson (admin / lead only)."""
ll_svc.delete_lesson_learned(db, lesson_id, user.id)
# ── Stats ─────────────────────────────────────────────────────────────────────
@router.get("/stats")
def knowledge_stats(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Summary counts: total playbooks, lessons by severity, playbooks by type."""
return ll_svc.get_knowledge_stats(db)
-191
View File
@@ -1,191 +0,0 @@
"""Phase 13: Operational Alerts router."""
from typing import List, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.user import User
from app.schemas.operational_alert_schema import (
AlertRuleCreate, AlertRuleOut, AlertRuleUpdate,
AlertInstanceOut, EvaluationResult, AlertSummary,
)
import app.services.operational_alert_service as svc
router = APIRouter(prefix="/alerts", tags=["Operational Alerts"])
# ── Evaluation ────────────────────────────────────────────────────────────────
@router.post("/evaluate", response_model=EvaluationResult, status_code=202)
def evaluate_rules(
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""
Run the alert evaluation engine against all enabled rules.
Fires AlertInstances for rules whose conditions are met and are not in cooldown.
Admin / leads only.
"""
result = svc.evaluate_all_rules(db)
return EvaluationResult(
rules_evaluated = result["rules_evaluated"],
alerts_fired = result["alerts_fired"],
alerts = [AlertInstanceOut.model_validate(a) for a in result["alerts"]],
duration_seconds = result["duration_seconds"],
)
# ── Alert instances ───────────────────────────────────────────────────────────
@router.get("", response_model=List[AlertInstanceOut])
def list_alerts(
status: Optional[str] = Query(None),
severity: Optional[str] = Query(None),
rule_type: Optional[str] = Query(None),
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""List alert instances with optional filters."""
return svc.list_instances(db, status=status, severity=severity,
rule_type=rule_type, limit=limit, offset=offset)
@router.get("/summary", response_model=AlertSummary)
def alert_summary(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Aggregate counts by status, severity, and rule type."""
data = svc.get_summary(db)
return AlertSummary(
total_open = data["total_open"],
total_acknowledged = data["total_acknowledged"],
total_resolved = data["total_resolved"],
by_severity = data["by_severity"],
by_rule_type = data["by_rule_type"],
recent_alerts = [AlertInstanceOut.model_validate(a) for a in data["recent_alerts"]],
)
@router.get("/{alert_id}", response_model=AlertInstanceOut)
def get_alert(
alert_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Get a single alert instance."""
return svc.get_instance(db, alert_id)
@router.post("/{alert_id}/acknowledge", response_model=AlertInstanceOut)
def acknowledge_alert(
alert_id: UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""Acknowledge an open alert (admin / lead roles only)."""
return svc.acknowledge(db, alert_id, current_user.id)
@router.post("/{alert_id}/resolve", response_model=AlertInstanceOut)
def resolve_alert(
alert_id: UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""Mark an alert as resolved (admin / lead roles only)."""
return svc.resolve(db, alert_id, current_user.id)
@router.post("/{alert_id}/dismiss", response_model=AlertInstanceOut)
def dismiss_alert(
alert_id: UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""Dismiss an alert (admin / lead roles only — won't re-fire until cooldown resets)."""
return svc.dismiss(db, alert_id, current_user.id)
# ── Alert rules ───────────────────────────────────────────────────────────────
@router.get("/rules/list", response_model=List[AlertRuleOut])
def list_rules(
rule_type: Optional[str] = Query(None),
include_disabled: bool = Query(False),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""List alert rules (all users can read; admins/leads manage them)."""
return svc.list_rules(db, rule_type=rule_type, include_disabled=include_disabled)
@router.post("/rules", response_model=AlertRuleOut, status_code=201)
def create_rule(
body: AlertRuleCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""Create a custom alert rule."""
return svc.create_rule(
db,
created_by = current_user.id,
name = body.name,
description = body.description,
rule_type = body.rule_type,
severity = body.severity,
config = body.config,
notify_in_app = body.notify_in_app,
notify_webhook = body.notify_webhook,
webhook_id = body.webhook_id,
cooldown_hours = body.cooldown_hours,
)
@router.get("/rules/{rule_id}", response_model=AlertRuleOut)
def get_rule(
rule_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Get a single alert rule."""
return svc.get_rule(db, rule_id)
@router.patch("/rules/{rule_id}", response_model=AlertRuleOut)
def update_rule(
rule_id: UUID,
body: AlertRuleUpdate,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""Update an alert rule (enable/disable, thresholds, cooldown)."""
return svc.update_rule(
db, rule_id,
name = body.name,
description = body.description,
severity = body.severity,
is_enabled = body.is_enabled,
config = body.config,
notify_in_app = body.notify_in_app,
notify_webhook = body.notify_webhook,
webhook_id = body.webhook_id,
cooldown_hours = body.cooldown_hours,
)
@router.delete("/rules/{rule_id}", status_code=204)
def delete_rule(
rule_id: UUID,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin")),
):
"""Delete a custom alert rule (system rules cannot be deleted)."""
svc.delete_rule(db, rule_id)
+1 -1
View File
@@ -97,7 +97,7 @@ def list_osint_items(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: user # Entry: user
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""List OSINT items with optional filters. """List OSINT items with optional filters.
Args: Args:
-215
View File
@@ -1,215 +0,0 @@
"""Phase 9: Ownership & Daily Operations router."""
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.domain.exceptions import EntityNotFoundError
from app.schemas.ownership_queue_schema import (
TechniqueOwnershipSet, TechniqueOwnershipOut,
DetectionAssetOwnershipPatch,
BulkAssignRequest, BulkAssignResult,
QueueItemCreate, QueueItemPatch, QueueItemOut,
)
from app.services import ownership_service, revalidation_queue_service
from app.models.ownership_queue import RevalidationQueueItem
router = APIRouter(prefix="/ownership", tags=["ownership"])
# ── Technique Ownership ───────────────────────────────────────────────────────
@router.get("/techniques/{technique_id}", response_model=TechniqueOwnershipOut)
def get_technique_ownership(
technique_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
ownership = ownership_service.get_technique_ownership(db, technique_id)
if not ownership:
raise EntityNotFoundError("TechniqueOwnership", str(technique_id))
return ownership
@router.put("/techniques/{technique_id}", response_model=TechniqueOwnershipOut)
def set_technique_ownership(
technique_id: UUID,
body: TechniqueOwnershipSet,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "blue_lead", "red_lead")),
):
return ownership_service.set_technique_ownership(
db, technique_id,
owner_id=body.owner_id,
backup_owner_id=body.backup_owner_id,
team=body.team,
notes=body.notes,
assigned_by=user.id,
)
# ── Detection Asset Ownership ─────────────────────────────────────────────────
@router.patch("/assets/{asset_id}", response_model=dict)
def set_asset_ownership(
asset_id: UUID,
body: DetectionAssetOwnershipPatch,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "blue_lead")),
):
ownership_service.set_asset_ownership(
db, asset_id,
owner_id=body.owner_id,
backup_owner_id=body.backup_owner_id,
team=body.team,
user_id=user.id,
)
return {"message": "Asset ownership updated"}
# ── Orphan Reports ────────────────────────────────────────────────────────────
@router.get("/orphans/techniques", response_model=list[dict])
def orphan_techniques(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Return techniques with no assigned owner."""
return ownership_service.get_orphan_techniques(db)
@router.get("/orphans/assets", response_model=list[dict])
def orphan_assets(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Return active detection assets with no assigned owner."""
return ownership_service.get_orphan_assets(db)
# ── Bulk Assignment ───────────────────────────────────────────────────────────
@router.post("/bulk-assign", response_model=BulkAssignResult)
def bulk_assign(
body: BulkAssignRequest,
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "blue_lead", "red_lead")),
):
"""
Bulk-assign ownership.
- If `tactic` is set → assigns technique ownership for all techniques of that tactic.
- If `platform` is set → assigns asset ownership for all assets on that platform.
At least one of tactic/platform must be provided.
"""
if not body.tactic and not body.platform:
from fastapi import HTTPException
raise HTTPException(status_code=422, detail="Provide at least one of: tactic, platform")
if body.tactic:
result = ownership_service.bulk_assign_techniques_by_tactic(
db, body.tactic,
owner_id=body.owner_id,
backup_owner_id=body.backup_owner_id,
team=body.team,
overwrite=body.overwrite,
user_id=user.id,
)
else:
result = ownership_service.bulk_assign_assets_by_platform(
db, body.platform,
owner_id=body.owner_id,
backup_owner_id=body.backup_owner_id,
team=body.team,
overwrite=body.overwrite,
user_id=user.id,
)
return BulkAssignResult(**result)
# ── Revalidation Queue ────────────────────────────────────────────────────────
@router.get("/queue", response_model=list[QueueItemOut])
def list_queue(
status: Optional[str] = Query(None),
priority: Optional[str] = Query(None),
reason: Optional[str] = Query(None),
assigned_to: Optional[UUID] = Query(None),
technique_id: Optional[UUID] = Query(None),
detection_asset_id: Optional[UUID] = Query(None),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return revalidation_queue_service.list_queue(
db, status=status, priority=priority, reason=reason,
assigned_to=assigned_to, technique_id=technique_id,
detection_asset_id=detection_asset_id, limit=limit, offset=offset,
)
@router.post("/queue", response_model=QueueItemOut, status_code=201)
def create_queue_item(
body: QueueItemCreate,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return revalidation_queue_service.create_queue_item(db, body.model_dump(), user.id)
@router.patch("/queue/{item_id}", response_model=QueueItemOut)
def update_queue_item(
item_id: UUID,
body: QueueItemPatch,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
return revalidation_queue_service.update_queue_item(db, item_id, body.model_dump(exclude_unset=True), user.id)
@router.post("/queue/generate", response_model=dict)
def generate_queue(
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "blue_lead")),
):
"""Scan the system and create new revalidation queue items."""
return revalidation_queue_service.generate_queue_items(db)
# ── Analyst Dashboard ─────────────────────────────────────────────────────────
@router.get("/analyst-dashboard")
def analyst_dashboard(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Personalised daily workday view: my queue, expiring validations, infra changes, low-confidence techniques."""
dashboard = revalidation_queue_service.get_analyst_dashboard(db, user.id)
# Serialize queue items to dicts (ORM objects → plain dicts)
def _item_to_dict(item: RevalidationQueueItem) -> dict:
return {
"id": str(item.id),
"technique_id": str(item.technique_id) if item.technique_id else None,
"detection_asset_id": str(item.detection_asset_id) if item.detection_asset_id else None,
"priority": item.priority.value if hasattr(item.priority, "value") else item.priority,
"reason": item.reason.value if hasattr(item.reason, "value") else item.reason,
"reason_detail": item.reason_detail,
"status": item.status.value if hasattr(item.status, "value") else item.status,
"assigned_to": str(item.assigned_to) if item.assigned_to else None,
"due_date": item.due_date.isoformat() if item.due_date else None,
"created_at": item.created_at.isoformat() if item.created_at else None,
}
return {
"my_pending_items": [_item_to_dict(i) for i in dashboard["my_pending_items"]],
"expiring_validations_7d": dashboard["expiring_validations_7d"],
"recent_infra_changes": dashboard["recent_infra_changes"],
"my_low_confidence_techniques": dashboard["my_low_confidence_techniques"],
"summary": dashboard["summary"],
}
+7 -20
View File
@@ -2,10 +2,9 @@
# Import UUID from uuid # Import UUID from uuid
from uuid import UUID from uuid import UUID
from pathlib import Path
# Import APIRouter, Depends, HTTPException, Query, Request from fastapi # Import APIRouter, Depends, Query, Request from fastapi
from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi import APIRouter, Depends, Query, Request
# Import FileResponse from fastapi.responses # Import FileResponse from fastapi.responses
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
@@ -22,24 +21,12 @@ from app.dependencies.auth import get_current_user, require_any_role
# Import limiter from app.limiter # Import limiter from app.limiter
from app.limiter import limiter from app.limiter import limiter
# Import settings from app.config
from app.config import settings
# Import User from app.models.user # Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import report_generation_service from app.services # Import report_generation_service from app.services
from app.services import report_generation_service from app.services import report_generation_service
def _assert_safe_report_path(filepath: str) -> str:
"""Raise 500 if the generated filepath escapes the configured report directory."""
output_dir = Path(settings.REPORT_OUTPUT_DIR).resolve()
resolved = Path(filepath).resolve()
if not resolved.is_relative_to(output_dir):
raise HTTPException(status_code=500, detail="Report generation path error")
return filepath
# Assign router = APIRouter(prefix="/reports/generate", tags=["professional-reports"]) # Assign router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"]) router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
@@ -78,7 +65,7 @@ def generate_purple_report(
) )
# Return FileResponse( # Return FileResponse(
return FileResponse( return FileResponse(
_assert_safe_report_path(filepath), filepath,
# Keyword argument: media_type # Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename # Keyword argument: filename
@@ -108,7 +95,7 @@ def generate_coverage_report(
) )
# Return FileResponse( # Return FileResponse(
return FileResponse( return FileResponse(
_assert_safe_report_path(filepath), filepath,
# Keyword argument: media_type # Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename # Keyword argument: filename
@@ -138,7 +125,7 @@ def generate_executive_report(
) )
# Return FileResponse( # Return FileResponse(
return FileResponse( return FileResponse(
_assert_safe_report_path(filepath), filepath,
# Keyword argument: media_type # Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename # Keyword argument: filename
@@ -168,7 +155,7 @@ def generate_quarterly_report(
) )
# Return FileResponse( # Return FileResponse(
return FileResponse( return FileResponse(
_assert_safe_report_path(filepath), filepath,
# Keyword argument: media_type # Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename # Keyword argument: filename
@@ -200,7 +187,7 @@ def generate_technique_report(
) )
# Return FileResponse( # Return FileResponse(
return FileResponse( return FileResponse(
_assert_safe_report_path(filepath), filepath,
# Keyword argument: media_type # Keyword argument: media_type
media_type=_MEDIA_TYPES[format], media_type=_MEDIA_TYPES[format],
# Keyword argument: filename # Keyword argument: filename
-113
View File
@@ -1,113 +0,0 @@
"""Phase 12: Risk Intelligence router."""
from typing import List, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.schemas.risk_schema import (
TechniqueRiskProfileOut,
ComputeResult,
)
from app.services import risk_intelligence_service as svc
router = APIRouter(prefix="/risk", tags=["risk-intelligence"])
# ── Compute ──────────────────────────────────────────────────────────────────
@router.post("/compute", response_model=ComputeResult, status_code=202)
def compute_all(
db: Session = Depends(get_db),
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
):
"""Recompute risk scores for ALL techniques (admin / leads only)."""
result = svc.compute_all_risk_scores(db)
return result
@router.post("/profiles/{technique_id}/compute", response_model=TechniqueRiskProfileOut)
def compute_one(
technique_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Compute (or refresh) the risk profile for a single technique."""
return svc.compute_technique_risk(db, technique_id)
# ── Read ─────────────────────────────────────────────────────────────────────
@router.get("/profiles", response_model=List[TechniqueRiskProfileOut])
def list_profiles(
risk_level: Optional[str] = None,
min_score: Optional[float] = None,
max_score: Optional[float] = None,
stale_only: bool = False,
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""List risk profiles with optional filters."""
return svc.list_risk_profiles(
db,
risk_level=risk_level,
min_score=min_score,
max_score=max_score,
stale_only=stale_only,
limit=limit,
offset=offset,
)
@router.get("/profiles/{technique_id}", response_model=TechniqueRiskProfileOut)
def get_profile(
technique_id: UUID,
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Get the current risk profile for a technique."""
return svc.get_risk_profile(db, technique_id)
@router.get("/matrix")
def risk_matrix(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""All profiled techniques with likelihood/impact coordinates for matrix view."""
return svc.get_risk_matrix(db)
@router.get("/summary")
def risk_summary(
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Aggregate risk statistics: counts by level, average score, top risks."""
return svc.get_risk_summary(db)
@router.get("/recommendations")
def recommendations(
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Prioritised list of techniques with actionable recommendations."""
return svc.get_recommendations(db, limit=limit)
@router.get("/top")
def top_risks(
limit: int = Query(10, ge=1, le=50),
db: Session = Depends(get_db),
user=Depends(get_current_user),
):
"""Top N highest-risk techniques (sorted by risk score desc)."""
profiles = svc.list_risk_profiles(db, limit=limit)
return profiles
+1 -1
View File
@@ -87,7 +87,7 @@ def list_snapshots(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""List coverage snapshots ordered by creation date (newest first).""" """List coverage snapshots ordered by creation date (newest first)."""
# Return list_snapshots_svc(db, offset=offset, limit=limit) # Return list_snapshots_svc(db, offset=offset, limit=limit)
return list_snapshots_svc(db, offset=offset, limit=limit) return list_snapshots_svc(db, offset=offset, limit=limit)
-135
View File
@@ -1,135 +0,0 @@
"""Phase 14: SSO / SAML 2.0 router."""
import os
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import RedirectResponse
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import require_any_role
from app import auth as auth_lib
from app.schemas.sso_schema import (
SsoConfigCreate, SsoConfigOut, SsoStatusResponse,
)
import app.services.sso_service as svc
router = APIRouter(prefix="/sso", tags=["SSO"])
_COOKIE_NAME = "aegis_token"
# Mirror the same SECURE_COOKIES logic used in the auth router so that
# SAML-authenticated sessions respect the deployment's HTTPS configuration.
_aegis_env = os.environ.get("AEGIS_ENV", "development").lower()
_secure_cookie_env = os.environ.get("SECURE_COOKIES", "auto").lower()
if _secure_cookie_env == "false":
_IS_HTTPS = False
elif _secure_cookie_env == "true":
_IS_HTTPS = True
else: # "auto" — active only when AEGIS_ENV=production
_IS_HTTPS = _aegis_env == "production"
_COOKIE_OPTS = {"httponly": True, "samesite": "lax", "secure": _IS_HTTPS}
# ── Public ────────────────────────────────────────────────────────────────────
@router.get("/status", response_model=SsoStatusResponse)
def sso_status(db: Session = Depends(get_db)):
"""Return whether SSO is enabled and configured (public — for login page)."""
return svc.get_status(db)
@router.get("/metadata", response_model=None)
def sp_metadata(db: Session = Depends(get_db)):
"""
Return the Service Provider SAML metadata XML.
Upload this XML to your IdP (Okta, Azure AD, etc.) to register Aegis.
"""
try:
xml = svc.get_sp_metadata(db)
except Exception as exc:
raise HTTPException(status_code=503, detail=str(exc))
return Response(content=xml, media_type="application/xml")
@router.get("/login")
def sso_login(request: Request, db: Session = Depends(get_db)):
"""
Initiate SAML login — redirects the browser to the IdP.
The IdP will POST the SAML Response to ``/sso/callback`` after authentication.
"""
request_data = {
"https": request.url.scheme == "https",
"http_host": request.url.hostname,
"path": request.url.path,
"port": str(request.url.port or (443 if request.url.scheme == "https" else 80)),
"get_data": dict(request.query_params),
"post_data": {},
"query_string": str(request.url.query),
}
try:
result = svc.initiate_login(db, request_data)
except RuntimeError as exc:
raise HTTPException(status_code=503, detail=str(exc))
redirect_url = result["redirect_url"]
if urlparse(redirect_url).scheme not in ("http", "https"):
raise HTTPException(status_code=400, detail="Invalid IdP redirect URL")
return RedirectResponse(url=redirect_url)
@router.post("/callback")
async def sso_callback(request: Request, db: Session = Depends(get_db)):
"""
SAML Assertion Consumer Service (ACS) endpoint.
The IdP POSTs the SAML Response here. On success, sets the aegis_token
cookie and redirects to the frontend.
"""
form = await request.form()
request_data = {
"https": request.url.scheme == "https",
"http_host": request.url.hostname,
"path": request.url.path,
"port": str(request.url.port or (443 if request.url.scheme == "https" else 80)),
"get_data": dict(request.query_params),
"post_data": dict(form),
"query_string": str(request.url.query),
}
try:
user = svc.process_callback(db, request_data)
except (ValueError, RuntimeError) as exc:
raise HTTPException(status_code=401, detail=str(exc))
access_token = auth_lib.create_access_token({"sub": user.username})
response = RedirectResponse(url="/", status_code=302)
response.set_cookie(_COOKIE_NAME, access_token, **_COOKIE_OPTS)
return response
# ── Admin configuration ────────────────────────────────────────────────────────
@router.get("/config", response_model=SsoConfigOut)
def get_sso_config(
db: Session = Depends(get_db),
_user=Depends(require_any_role("admin")),
):
"""Return the current SSO configuration (admin only)."""
cfg = svc.get_config(db)
if not cfg:
raise HTTPException(status_code=404, detail="SSO not configured yet")
return SsoConfigOut.model_validate(cfg)
@router.put("/config", response_model=SsoConfigOut)
def upsert_sso_config(
body: SsoConfigCreate,
db: Session = Depends(get_db),
_user=Depends(require_any_role("admin")),
):
"""Create or replace the SSO configuration (admin only)."""
cfg = svc.upsert_config(db, **body.model_dump(exclude_unset=False))
return SsoConfigOut.model_validate(cfg)
+40 -699
View File
@@ -3,30 +3,41 @@
Provides manual triggers for background operations such as the MITRE Provides manual triggers for background operations such as the MITRE
ATT&CK synchronisation, intel scanning, Atomic Red Team import, and ATT&CK synchronisation, intel scanning, Atomic Red Team import, and
scheduler health introspection. scheduler health introspection.
Also exposes email configuration CRUD (admin only) that writes to the
system_configs table so settings survive container restarts.
""" """
# Import logging # Import logging
import logging import logging
from typing import Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status # Import APIRouter, Depends, Request from fastapi
from pydantic import BaseModel from fastapi import APIRouter, Depends, Request
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import SessionLocal, get_db # Import get_db from app.database
from app.dependencies.auth import get_current_user, require_role from app.database import get_db
from app.models.user import User
from app.services.mitre_sync_service import sync_mitre # Import require_role from app.dependencies.auth
from app.services.intel_service import scan_intel from app.dependencies.auth import require_role
from app.services.atomic_import_service import import_atomic_red_team
# Import scheduler from app.jobs.mitre_sync_job
from app.jobs.mitre_sync_job import scheduler from app.jobs.mitre_sync_job import scheduler
# Import limiter from app.limiter # Import limiter from app.limiter
from app.limiter import limiter from app.limiter import limiter
# Import User from app.models.user
from app.models.user import User
# Import import_atomic_red_team from app.services.atomic_import_service
from app.services.atomic_import_service import import_atomic_red_team
# Import scan_intel from app.services.intel_service
from app.services.intel_service import scan_intel
# Import sync_mitre from app.services.mitre_sync_service
from app.services.mitre_sync_service import sync_mitre
# Assign logger = logging.getLogger(__name__) # Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -34,81 +45,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/system", tags=["system"]) router = APIRouter(prefix="/system", tags=["system"])
# --------------------------------------------------------------------------- # Apply the @router.post decorator
# Pydantic schemas for email config
# ---------------------------------------------------------------------------
class EmailConfigOut(BaseModel):
enabled: bool
host: str
port: int
username: str
from_email: str
use_tls: bool
# password is never returned
class EmailConfigUpdate(BaseModel):
enabled: Optional[bool] = None
host: Optional[str] = None
port: Optional[int] = None
username: Optional[str] = None
password: Optional[str] = None
from_email: Optional[str] = None
use_tls: Optional[bool] = None
class EmailTestRequest(BaseModel):
to: str
# ---------------------------------------------------------------------------
# Helpers for system_configs CRUD
# ---------------------------------------------------------------------------
_SMTP_KEYS = {
"enabled": "smtp.enabled",
"host": "smtp.host",
"port": "smtp.port",
"username": "smtp.username",
"password": "smtp.password", # nosec B105
"from_email": "smtp.from_email",
"use_tls": "smtp.use_tls",
}
def _upsert_config(db: Session, key: str, value: str) -> None:
from app.models.system_config import SystemConfig # lazy import avoids circular
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if row:
row.value = value
else:
row = SystemConfig(key=key, value=value)
db.add(row)
def _read_email_config_from_db(db: Session) -> dict:
"""Return a dict with resolved email settings (DB overrides env)."""
from app.services.email_service import _get_smtp_config
return _get_smtp_config(db)
def _bg_mitre_sync() -> None:
"""Run MITRE sync in a background task with its own DB session."""
logger.info("Background MITRE sync task starting...")
db = SessionLocal()
try:
summary = sync_mitre(db)
logger.info("Background MITRE sync task finished — %s", summary)
except Exception:
logger.exception("Background MITRE sync task failed")
finally:
db.close()
@router.post("/sync-mitre") @router.post("/sync-mitre")
# Apply the @limiter.limit decorator # Apply the @limiter.limit decorator
@limiter.limit("2/hour") @limiter.limit("2/hour")
@@ -116,22 +53,28 @@ def _bg_mitre_sync() -> None:
def trigger_mitre_sync( def trigger_mitre_sync(
# Entry: request # Entry: request
request: Request, request: Request,
background_tasks: BackgroundTasks, # Entry: db
db: Session = Depends(get_db),
# Entry: current_user
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ) -> dict:
"""Manually trigger a MITRE ATT&CK synchronisation in the background. """Manually trigger a MITRE ATT&CK synchronisation.
**Requires** the ``admin`` role. **Requires** the ``admin`` role.
Returns immediately — the sync runs asynchronously. Poll Returns a JSON object with the sync summary including the count of
``/system/scheduler-status`` for progress, or check server logs. new and updated techniques.
""" """
background_tasks.add_task(_bg_mitre_sync) # Assign summary = sync_mitre(db)
summary = sync_mitre(db)
# Return {
return { return {
"message": "MITRE sync started in background", # Literal argument value
"status": "started", "message": "MITRE sync completed",
"new": 0, # Literal argument value
"updated": 0, "new": summary["created"],
# Literal argument value
"updated": summary["updated"],
} }
@@ -242,605 +185,3 @@ def scheduler_status(
for job in jobs for job in jobs
], ],
} }
# ---------------------------------------------------------------------------
# Jira config endpoints (admin only)
# ---------------------------------------------------------------------------
class JiraConfigOut(BaseModel):
enabled: bool
url: str
project_key: str
parent_ticket: str
parent_ticket_standalone: str # parent for tests not in a campaign
# Credentials are never returned
class JiraConfigUpdate(BaseModel):
enabled: Optional[bool] = None
url: Optional[str] = None
project_key: Optional[str] = None
parent_ticket: Optional[str] = None
parent_ticket_standalone: Optional[str] = None
_JIRA_KEYS = {
"enabled": "jira.enabled",
"url": "jira.url",
"project_key": "jira.project_key",
"parent_ticket": "jira.parent_ticket",
"parent_ticket_standalone": "jira.parent_ticket_standalone",
}
@router.get("/jira-config", response_model=JiraConfigOut)
def get_jira_config(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Return current Jira configuration (merged DB + env).
**Requires** the ``admin`` role. Credentials are never returned.
"""
from app.services.jira_service import (
get_jira_url, get_jira_project_key, is_jira_enabled,
get_jira_parent_ticket, get_jira_parent_ticket_standalone,
)
return JiraConfigOut(
enabled=is_jira_enabled(db),
url=get_jira_url(db) or "",
project_key=get_jira_project_key(db) or "",
parent_ticket=get_jira_parent_ticket(db) or "",
parent_ticket_standalone=get_jira_parent_ticket_standalone(db) or "",
)
@router.patch("/jira-config", response_model=JiraConfigOut)
def update_jira_config(
payload: JiraConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update Jira configuration and persist to DB.
**Requires** the ``admin`` role. Only provided fields are updated.
"""
from app.services.jira_service import (
upsert_jira_config, get_jira_url, get_jira_project_key, is_jira_enabled,
get_jira_parent_ticket, get_jira_parent_ticket_standalone,
)
update_data = payload.model_dump(exclude_unset=True)
for field, val in update_data.items():
db_key = _JIRA_KEYS.get(field)
if db_key:
upsert_jira_config(db, db_key, str(val))
db.commit()
return JiraConfigOut(
enabled=is_jira_enabled(db),
url=get_jira_url(db) or "",
project_key=get_jira_project_key(db) or "",
parent_ticket=get_jira_parent_ticket(db) or "",
parent_ticket_standalone=get_jira_parent_ticket_standalone(db) or "",
)
@router.post("/jira-test")
def test_jira_connection(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Test the Jira connection using the current user's credentials.
Requires the admin to have a personal Jira API token configured in their
profile settings.
Always returns HTTP 200 with a ``status`` field so Cloudflare never
replaces the response with its own error page.
"""
from app.services.jira_service import get_user_jira_client, get_jira_url, _effective_jira_email
jira_url = get_jira_url(db)
if not jira_url:
return {"status": "error", "message": "Jira URL is not configured. Set it in System Settings → Jira Configuration.", "jira_url": ""}
auth_email = _effective_jira_email(current_user)
try:
jira = get_user_jira_client(current_user, db)
# 10-second timeout so we never block Cloudflare into a 524
try:
jira._session.timeout = 10 # type: ignore[attr-defined]
except Exception: # nosec B110
pass
myself = jira.myself()
logger.info("Jira myself() response keys: %s", list(myself.keys()) if isinstance(myself, dict) else type(myself))
# Use displayName → emailAddress → name → the auth email as fallback
connected_as = (
(myself.get("displayName") if isinstance(myself, dict) else None)
or (myself.get("emailAddress") if isinstance(myself, dict) else None)
or (myself.get("name") if isinstance(myself, dict) else None)
or auth_email
or "authenticated"
)
return {
"status": "ok",
"connected_as": connected_as,
"jira_url": jira_url,
}
except Exception as exc:
err = str(exc)
# Always return HTTP 200 with status="error" so Cloudflare never
# intercepts the response and the frontend always sees our message.
if "Expecting value" in err or "line 1 column 1" in err:
msg = (
"Jira returned a non-JSON response. "
"Verify the URL (e.g. https://company.atlassian.net), "
"email and API token."
)
elif "401" in err or "Unauthorized" in err:
msg = (
"Authentication failed (401). "
f"Check that the Atlassian email ({auth_email or 'not set'}) "
"and API token are correct. The token must be an Atlassian API token "
"(not your account password)."
)
elif "403" in err or "Forbidden" in err:
msg = "Access denied (403). The token may not have permission for this Jira project."
elif "timed out" in err.lower() or "timeout" in err.lower():
msg = "Connection timed out. Check that the Jira URL is reachable from the server."
elif "not configured" in err.lower():
msg = err
else:
msg = f"Jira connection failed: {err}"
logger.warning("Jira test connection failed: %s", err)
return {"status": "error", "message": msg, "jira_url": jira_url}
# ---------------------------------------------------------------------------
# POST /system/tempo-test
# ---------------------------------------------------------------------------
@router.post("/tempo-test")
def test_tempo_connection(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Test the current user's personal Tempo connection.
Uses the Tempo API token stored in the user's profile (not a global token).
Always returns HTTP 200 with a ``status`` field so Cloudflare never
intercepts the response.
"""
tempo_token = getattr(current_user, "tempo_api_token", None)
if not tempo_token:
return {
"status": "error",
"message": (
"No Tempo API token configured. "
"Add it in Settings → Profile → Tempo Integration."
),
}
jira_account_id = getattr(current_user, "jira_account_id", None)
if not jira_account_id:
return {
"status": "error",
"message": (
"No Jira Account ID configured. "
"Set it in Settings → Profile → Jira Integration → Account ID."
),
}
try:
from tempoapiclient import client_v4 as tempo_client
tempo = tempo_client.Tempo(auth_token=tempo_token)
# search_worklogs by authorId is the correct v4 method; use a tight
# date range so we fetch almost nothing but still verify connectivity.
worklogs = tempo.search_worklogs(
dateFrom="2024-01-01",
dateTo="2024-01-02",
authorIds=[jira_account_id],
)
count = len(worklogs) if isinstance(worklogs, list) else "n/a"
return {
"status": "ok",
"message": f"Tempo connected successfully. Account ID: {jira_account_id}",
"worklogs_found": count,
}
except Exception as exc:
err = str(exc)
if "401" in err or "Unauthorized" in err:
msg = (
"Authentication failed (401). "
"Check your Tempo API token — obtain it at "
"Jira → Apps → Tempo → Settings → API Integration."
)
elif "403" in err or "Forbidden" in err:
msg = "Access denied (403). The Tempo token lacks the required permissions."
elif "404" in err or "not found" in err.lower():
msg = (
"Account ID not found (404). "
f"The value '{jira_account_id}' may be wrong — see the instructions "
"below to find your correct Atlassian Account ID."
)
else:
msg = f"Tempo connection failed: {err}"
logger.warning(
"Tempo test connection failed for user %s (account_id=%s): %s",
current_user.username, jira_account_id, err,
)
return {"status": "error", "message": msg}
# ---------------------------------------------------------------------------
# GET /system/email-config
# ---------------------------------------------------------------------------
@router.get("/email-config", response_model=EmailConfigOut)
def get_email_config(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Return current SMTP email configuration (merged DB + env).
**Requires** the ``admin`` role. Password is never returned.
"""
cfg = _read_email_config_from_db(db)
return EmailConfigOut(
enabled=cfg["enabled"],
host=cfg["host"],
port=cfg["port"],
username=cfg["username"],
from_email=cfg["from_email"],
use_tls=cfg["use_tls"],
)
# ---------------------------------------------------------------------------
# PATCH /system/email-config
# ---------------------------------------------------------------------------
@router.patch("/email-config", response_model=EmailConfigOut)
def update_email_config(
payload: EmailConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update SMTP email configuration and persist to DB.
**Requires** the ``admin`` role.
Only provided fields are updated (partial update).
"""
update_data = payload.model_dump(exclude_unset=True)
for field, val in update_data.items():
db_key = _SMTP_KEYS.get(field)
if db_key:
_upsert_config(db, db_key, str(val))
db.commit()
cfg = _read_email_config_from_db(db)
return EmailConfigOut(
enabled=cfg["enabled"],
host=cfg["host"],
port=cfg["port"],
username=cfg["username"],
from_email=cfg["from_email"],
use_tls=cfg["use_tls"],
)
# ---------------------------------------------------------------------------
# POST /system/email-test
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# ATT&CK Evaluations endpoints (admin only)
# ---------------------------------------------------------------------------
@router.get("/attck-evaluations/rounds")
def list_evaluation_rounds(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Return all public CrowdStrike ENTERPRISE evaluation rounds with import status.
Each entry includes whether it has already been imported into this platform.
"""
from app.services.attck_evaluations_service import fetch_rounds_with_status
from app.models.evaluation_import import EvaluationImport
status_info = fetch_rounds_with_status()
rounds = status_info["rounds"]
imported = {
row.adversary_name.lower(): row
for row in db.query(EvaluationImport).filter(EvaluationImport.status == "completed").all()
}
round_list = [
{
"name": r["name"],
"display_name": r.get("display_name", r["name"]),
"eval_round": r["eval_round"],
"imported": r["name"].lower() in imported,
"imported_at": imported[r["name"].lower()].imported_at.isoformat()
if r["name"].lower() in imported else None,
"tests_created": imported[r["name"].lower()].tests_created
if r["name"].lower() in imported else None,
"techniques_covered": imported[r["name"].lower()].techniques_covered
if r["name"].lower() in imported else None,
}
for r in rounds
]
return {
"rounds": round_list,
"api_reachable": status_info["api_reachable"],
"api_error": status_info.get("api_error"),
}
@router.post("/attck-evaluations/import")
def import_evaluation_round(
payload: dict,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Import a specific ATT&CK Evaluation round for CrowdStrike.
Body: { "adversary_name": "apt29", "adversary_display": "APT29", "eval_round": 2 }
Creates tests in ``in_review`` state — Blue Leads must validate each
result before it counts as real coverage.
"""
from app.services.attck_evaluations_service import import_evaluation_round as _import
adversary_name = payload.get("adversary_name", "")
adversary_display = payload.get("adversary_display", adversary_name)
eval_round = payload.get("eval_round", 0)
if not adversary_name or not eval_round:
raise HTTPException(status_code=400, detail="adversary_name and eval_round are required")
try:
summary = _import(db, adversary_name, adversary_display, eval_round, current_user)
except ValueError as exc:
raise HTTPException(status_code=409, detail=str(exc))
except Exception as exc:
logger.error("ATT&CK Evaluation import failed: %s", exc, exc_info=True)
raise HTTPException(status_code=502, detail=f"Import failed: {exc}")
return {
"message": f"Import complete — {summary['created']} tests created",
**summary,
}
@router.post("/attck-evaluations/import-latest")
def import_latest_evaluation(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Import the latest available CrowdStrike evaluation round automatically.
Returns 409 if the latest round was already imported.
"""
from app.services.attck_evaluations_service import get_latest_round, import_evaluation_round as _import
try:
latest = get_latest_round()
except Exception as exc:
raise HTTPException(status_code=502, detail=f"Could not reach MITRE Evaluations API: {exc}")
try:
summary = _import(
db,
latest["name"],
latest.get("display_name", latest["name"]),
latest["eval_round"],
current_user,
)
except ValueError as exc:
raise HTTPException(status_code=409, detail=str(exc))
except Exception as exc:
logger.error("ATT&CK Evaluation import failed: %s", exc, exc_info=True)
raise HTTPException(status_code=502, detail=f"Import failed: {exc}")
return {
"message": f"Import complete — {summary['created']} tests created",
**summary,
}
@router.get("/attck-evaluations/check-new")
def check_new_evaluation_round(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Check if a new ATT&CK Evaluation round is available that hasn't been imported yet."""
from app.services.attck_evaluations_service import check_for_new_round
return check_for_new_round(db)
@router.post("/attck-evaluations/bulk-approve")
def bulk_approve_evaluation_tests(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Bulk-approve all Blue Team validation for ATT&CK Evaluation imported tests.
Finds every test in ``in_review`` state whose name starts with ``[EVAL R``
and approves the Blue Team side. Because all evaluation imports pre-approve
the Red Team side, this moves every matched test to ``validated`` state.
**Important caveats** (enforced by UI warnings before this is called):
- Results come from a controlled MITRE lab, NOT the organisation's env.
- Validated tests will immediately affect coverage metrics and dashboards.
- Blue Leads should still spot-check high-priority techniques individually.
"""
from datetime import datetime
from app.models.test import Test
from app.models.enums import TestState
from app.models.technique import Technique
from app.services.status_service import recalculate_technique_status
from app.services.audit_service import log_action
# Find all pending evaluation tests
pending = (
db.query(Test)
.filter(
Test.state == TestState.in_review,
Test.name.like("[EVAL R%"),
)
.all()
)
if not pending:
return {
"approved": 0,
"techniques_recalculated": 0,
"message": "No pending evaluation tests found — nothing to approve.",
}
now = datetime.utcnow()
affected_technique_ids: set = set()
for test in pending:
# Approve blue side
test.blue_validation_status = "approved"
test.blue_validated_by = current_user.id
test.blue_validated_at = now
test.blue_validation_notes = (
"Bulk-approved via ATT&CK Evaluations admin panel. "
"Source: MITRE lab environment — not organisational detection."
)
# Red side was pre-approved during import → move to validated
if test.red_validation_status == "approved":
test.state = TestState.validated
# else stays in_review (shouldn't happen for eval imports, but be safe)
if test.technique_id:
affected_technique_ids.add(test.technique_id)
log_action(
db,
user_id=current_user.id,
action="bulk_eval_approve",
entity_type="test",
entity_id=test.id,
details={"source": "attck_evaluation_bulk_approve"},
)
db.flush()
# Recalculate coverage for every touched technique
for tech_id in affected_technique_ids:
tech = db.query(Technique).filter(Technique.id == tech_id).first()
if tech:
recalculate_technique_status(db, tech)
db.commit()
logger.info(
"Bulk eval approval: %d tests validated, %d techniques recalculated (by %s)",
len(pending), len(affected_technique_ids), current_user.email,
)
return {
"approved": len(pending),
"techniques_recalculated": len(affected_technique_ids),
"message": (
f"{len(pending)} evaluation tests approved and moved to Validated. "
f"{len(affected_technique_ids)} technique statuses recalculated."
),
}
@router.get("/attck-evaluations/pending-count")
def get_pending_evaluation_count(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Return the number of imported evaluation tests still awaiting Blue approval."""
from app.models.test import Test
from app.models.enums import TestState
count = (
db.query(Test)
.filter(
Test.state == TestState.in_review,
Test.name.like("[EVAL R%"),
)
.count()
)
return {"pending": count}
@router.post("/attck-evaluations/re-enrich")
def re_enrich_evaluation_round(
payload: dict,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Re-enrich already-imported evaluation tests with rich data from the MITRE API.
Updates procedure_text (attack path + criteria), description (data sources +
substep references) and red_summary — without changing detection results,
state or validation status.
Body: { "adversary_name": "turla", "adversary_display": "Turla", "eval_round": 5 }
Useful to upgrade tests that were imported before the enrichment feature
was added.
"""
from app.services.attck_evaluations_service import re_enrich_evaluation_round as _re_enrich
adversary_name = payload.get("adversary_name", "")
adversary_display = payload.get("adversary_display", adversary_name)
eval_round = payload.get("eval_round", 0)
if not adversary_name or not eval_round:
raise HTTPException(status_code=400, detail="adversary_name and eval_round are required")
try:
summary = _re_enrich(db, adversary_name, adversary_display, eval_round, current_user)
except Exception as exc:
logger.error("ATT&CK Evaluation re-enrich failed: %s", exc, exc_info=True)
raise HTTPException(status_code=502, detail=f"Re-enrich failed: {exc}")
return summary
@router.post("/email-test")
def send_test_email(
payload: EmailTestRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Send a test email to verify SMTP configuration.
**Requires** the ``admin`` role.
Returns 200 on success, 502 if sending fails.
"""
from app.services.email_service import send_test_email as _send_test
ok = _send_test(payload.to, db=db)
if not ok:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Failed to send test email. Check SMTP configuration and server logs.",
)
return {"detail": f"Test email sent to {payload.to}"}
+3 -10
View File
@@ -42,7 +42,8 @@ from app.dependencies.auth import get_current_user, require_any_role
# Import UnitOfWork from app.domain.unit_of_work # Import UnitOfWork from app.domain.unit_of_work
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
from app.models.technique import Technique
# Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.schemas.test_template # Import from app.schemas.test_template
@@ -333,15 +334,7 @@ def create_template(
template = create_template_svc(db, **payload.model_dump()) template = create_template_svc(db, **payload.model_dump())
# Open context manager # Open context manager
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
# Flag the associated technique for review — new template available # Call log_action()
if template.mitre_technique_id:
technique = (
db.query(Technique)
.filter(Technique.mitre_id == template.mitre_technique_id)
.first()
)
if technique:
technique.review_required = True
log_action( log_action(
db, db,
# Keyword argument: user_id # Keyword argument: user_id
+64 -442
View File
@@ -11,7 +11,6 @@ PATCH /tests/{id}/red — Red Team updates (draft, red_executing)
PATCH /tests/{id}/blue — Blue Team updates (blue_evaluating) PATCH /tests/{id}/blue — Blue Team updates (blue_evaluating)
POST /tests/{id}/start-execution — draft → red_executing POST /tests/{id}/start-execution — draft → red_executing
POST /tests/{id}/submit-red — red_executing → blue_evaluating POST /tests/{id}/submit-red — red_executing → blue_evaluating
POST /tests/{id}/start-blue-work — blue tech picks up (sets Tempo timer)
POST /tests/{id}/submit-blue — blue_evaluating → in_review POST /tests/{id}/submit-blue — blue_evaluating → in_review
POST /tests/{id}/validate-red — Red Lead validates POST /tests/{id}/validate-red — Red Lead validates
POST /tests/{id}/validate-blue — Blue Lead validates POST /tests/{id}/validate-blue — Blue Lead validates
@@ -19,16 +18,16 @@ POST /tests/{id}/reopen — rejected → draft
GET /tests/{id}/timeline — audit-log history for this test GET /tests/{id}/timeline — audit-log history for this test
""" """
import base64 # Import uuid
import hashlib
import os
import uuid import uuid
from datetime import datetime
from typing import Any, Optional # Import Optional from typing
from typing import Optional
# Import APIRouter, Depends, HTTPException, Query, Reque... from fastapi # Import APIRouter, Depends, HTTPException, Query, Reque... from fastapi
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
# Import get_db from app.database # Import get_db from app.database
@@ -42,11 +41,11 @@ from app.domain.unit_of_work import UnitOfWork
# Import limiter from app.limiter # Import limiter from app.limiter
from app.limiter import limiter from app.limiter import limiter
from app.models.enums import TestState, TestResult, TeamSide
from app.models.evidence import Evidence # Import TestState from app.models.enums
from app.storage import upload_file from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test # Import User from app.models.user
from app.models.user import User from app.models.user import User
# Import from app.schemas.test # Import from app.schemas.test
@@ -70,7 +69,8 @@ from app.services.audit_service import log_action
# Import recalculate_technique_status from app.services.status_service # Import recalculate_technique_status from app.services.status_service
from app.services.status_service import recalculate_technique_status from app.services.status_service import recalculate_technique_status
from app.services.webhook_service import dispatch_webhook
# Import from app.services.test_crud_service
from app.services.test_crud_service import ( from app.services.test_crud_service import (
create_test as crud_create_test, create_test as crud_create_test,
) )
@@ -122,19 +122,54 @@ from app.services.test_crud_service import (
# Import from app.services.test_workflow_service # Import from app.services.test_workflow_service
from app.services.test_workflow_service import ( from app.services.test_workflow_service import (
start_execution as wf_start_execution,
submit_red_evidence as wf_submit_red,
submit_blue_evidence as wf_submit_blue,
start_blue_work as wf_start_blue_work,
validate_as_red_lead as wf_validate_red,
validate_as_blue_lead as wf_validate_blue,
reopen_test as wf_reopen,
handle_remediation_completed as wf_handle_remediation,
get_retest_chain as wf_get_retest_chain, get_retest_chain as wf_get_retest_chain,
)
# Import from app.services.test_workflow_service
from app.services.test_workflow_service import (
handle_remediation_completed as wf_handle_remediation,
)
# Import from app.services.test_workflow_service
from app.services.test_workflow_service import (
pause_timer as wf_pause_timer, pause_timer as wf_pause_timer,
)
# Import from app.services.test_workflow_service
from app.services.test_workflow_service import (
reopen_test as wf_reopen,
)
# Import from app.services.test_workflow_service
from app.services.test_workflow_service import (
resume_timer as wf_resume_timer, resume_timer as wf_resume_timer,
) )
# Import from app.services.test_workflow_service
from app.services.test_workflow_service import (
start_execution as wf_start_execution,
)
# Import from app.services.test_workflow_service
from app.services.test_workflow_service import (
submit_blue_evidence as wf_submit_blue,
)
# Import from app.services.test_workflow_service
from app.services.test_workflow_service import (
submit_red_evidence as wf_submit_red,
)
# Import from app.services.test_workflow_service
from app.services.test_workflow_service import (
validate_as_blue_lead as wf_validate_blue,
)
# Import from app.services.test_workflow_service
from app.services.test_workflow_service import (
validate_as_red_lead as wf_validate_red,
)
# Assign router = APIRouter(prefix="/tests", tags=["tests"]) # Assign router = APIRouter(prefix="/tests", tags=["tests"])
router = APIRouter(prefix="/tests", tags=["tests"]) router = APIRouter(prefix="/tests", tags=["tests"])
@@ -159,9 +194,7 @@ def list_tests(
pending_validation_side: Optional[str] = Query( pending_validation_side: Optional[str] = Query(
None, description="Filter in_review tests pending validation on 'red' or 'blue' side" None, description="Filter in_review tests pending validation on 'red' or 'blue' side"
), ),
not_in_any_campaign: bool = Query( # Entry: offset
False, description="Only return tests not linked to any campaign"
),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
# Entry: limit # Entry: limit
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
@@ -200,7 +233,7 @@ def list_tests(
created_by=created_by, created_by=created_by,
# Keyword argument: pending_validation_side # Keyword argument: pending_validation_side
pending_validation_side=pending_validation_side, pending_validation_side=pending_validation_side,
not_in_any_campaign=not_in_any_campaign, # Keyword argument: offset
offset=offset, offset=offset,
# Keyword argument: limit # Keyword argument: limit
limit=limit, limit=limit,
@@ -276,14 +309,7 @@ def create_test(
# Reload ORM object attributes from the database # Reload ORM object attributes from the database
db.refresh(test) db.refresh(test)
# Auto-create Jira ticket (non-fatal — any failure is logged, not raised) # Return test
try:
from app.services.jira_service import auto_create_test_issue
auto_create_test_issue(db, test, current_user)
db.commit()
except Exception: # nosec B110
pass # jira_service already logs warnings internally
return test return test
@@ -337,11 +363,6 @@ def create_test_from_template(
technique_id_or_mitre=payload.technique_id, technique_id_or_mitre=payload.technique_id,
# Keyword argument: creator_id # Keyword argument: creator_id
creator_id=current_user.id, creator_id=current_user.id,
name_override=payload.name,
description_override=payload.description,
platform_override=payload.platform,
procedure_text_override=payload.procedure_text,
tool_used_override=payload.tool_used,
) )
# Call log_action() # Call log_action()
log_action( log_action(
@@ -369,14 +390,7 @@ def create_test_from_template(
# Reload ORM object attributes from the database # Reload ORM object attributes from the database
db.refresh(test) db.refresh(test)
# Auto-create Jira ticket (non-fatal) # Return test
try:
from app.services.jira_service import auto_create_test_issue
auto_create_test_issue(db, test, current_user)
db.commit()
except Exception: # nosec B110
pass # jira_service already logs warnings internally
return test return test
@@ -766,26 +780,6 @@ def submit_blue(
return test return test
# ---------------------------------------------------------------------------
# POST /tests/{id}/start-blue-work — blue tech picks up test for evaluation
# ---------------------------------------------------------------------------
@router.post("/{test_id}/start-blue-work", response_model=TestOut)
def start_blue_work(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Blue tech picks up the test to start evaluating. Sets the Tempo timer start."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_start_blue_work(db, test, current_user)
uow.commit()
db.refresh(test)
return test
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# POST /tests/{id}/pause-timer — pause the active phase timer # POST /tests/{id}/pause-timer — pause the active phase timer
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -908,16 +902,11 @@ def validate_red(
if test.state in (TestState.validated, TestState.rejected): if test.state in (TestState.validated, TestState.rejected):
# Call recalculate_technique_status() # Call recalculate_technique_status()
recalculate_technique_status(db, test.technique) recalculate_technique_status(db, test.technique)
# Flag technique for review — coverage changed # Call uow.commit()
if test.technique:
test.technique.review_required = True
uow.commit() uow.commit()
# Reload ORM object attributes from the database # Reload ORM object attributes from the database
db.refresh(test) db.refresh(test)
if test.state == TestState.validated: # Return test
dispatch_webhook("test.validated", {"test_id": str(test.id), "technique_id": str(test.technique_id), "result": test.result.value if test.result else None})
elif test.state == TestState.rejected:
dispatch_webhook("test.rejected", {"test_id": str(test.id), "technique_id": str(test.technique_id)})
return test return test
@@ -965,16 +954,11 @@ def validate_blue(
if test.state in (TestState.validated, TestState.rejected): if test.state in (TestState.validated, TestState.rejected):
# Call recalculate_technique_status() # Call recalculate_technique_status()
recalculate_technique_status(db, test.technique) recalculate_technique_status(db, test.technique)
# Flag technique for review — coverage changed # Call uow.commit()
if test.technique:
test.technique.review_required = True
uow.commit() uow.commit()
# Reload ORM object attributes from the database # Reload ORM object attributes from the database
db.refresh(test) db.refresh(test)
if test.state == TestState.validated: # Return test
dispatch_webhook("test.validated", {"test_id": str(test.id), "technique_id": str(test.technique_id), "result": test.result.value if test.result else None})
elif test.state == TestState.rejected:
dispatch_webhook("test.rejected", {"test_id": str(test.id), "technique_id": str(test.technique_id)})
return test return test
@@ -1180,365 +1164,3 @@ def get_retest_chain(
} }
for t in chain for t in chain
] ]
# ---------------------------------------------------------------------------
# POST /tests/{id}/sync-tempo — manual Tempo sync for red execution worklog
# ---------------------------------------------------------------------------
@router.post("/{test_id}/sync-tempo")
def sync_tempo(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Manually sync this test's red team execution worklog(s) to Tempo.
Useful when the automatic sync failed at phase completion (e.g. Tempo
was not yet configured). Only red_team_execution worklogs are eligible.
Already-synced worklogs are skipped. Returns a summary of what happened.
"""
from datetime import datetime as _dt
from app.models.worklog import Worklog
from app.services.tempo_service import auto_log_test_worklog
from app.services.test_crud_service import get_test_or_raise as _get
test = _get(db, test_id)
worklogs = (
db.query(Worklog)
.filter(
Worklog.entity_type == "test",
Worklog.entity_id == test_id,
Worklog.activity_type == "red_team_execution",
)
.all()
)
if not worklogs:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No red team execution worklog found for this test.",
)
results = []
for wl in worklogs:
if wl.tempo_synced:
results.append({"worklog_id": str(wl.id), "status": "already_synced"})
continue
try:
result = auto_log_test_worklog(
db=db,
test=test,
user=current_user,
activity_type=wl.activity_type,
duration_seconds=wl.duration_seconds,
)
if result and isinstance(result, dict):
wl.tempo_synced = _dt.utcnow()
wl.tempo_worklog_id = str(result.get("tempoWorklogId", ""))
db.commit()
results.append({"worklog_id": str(wl.id), "status": "synced"})
else:
results.append({
"worklog_id": str(wl.id),
"status": "skipped",
"detail": "Tempo not configured or conditions not met.",
})
except Exception as exc:
results.append({
"worklog_id": str(wl.id),
"status": "error",
"detail": str(exc),
})
return {"results": results}
# ---------------------------------------------------------------------------
# POST /tests/{id}/request-discussion — disputed: confirm vote + notify other lead
# ---------------------------------------------------------------------------
@router.post("/{test_id}/request-discussion")
def request_discussion(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
):
"""Called when the approving lead confirms their vote in a disputed test.
Sends a notification to the other lead (who rejected) asking them to
discuss and resolve the conflict. The test remains in 'disputed' state.
"""
from app.models.user import User as UserModel
from app.services.notification_service import create_notification
test = crud_get_test_or_raise(db, test_id)
if test.state.value != "disputed":
from app.domain.errors import BusinessRuleViolation
raise BusinessRuleViolation("Test is not in disputed state")
role = current_user.role
# Identify who the "other lead" is (the one who rejected)
if (role in ("red_lead", "admin")) and test.red_validation_status == "approved":
# Red approved, Blue rejected → notify Blue Lead who rejected
rejector_id = test.blue_validated_by
rejector_label = "Blue Lead"
requester_label = "Red Lead"
elif (role in ("blue_lead", "admin")) and test.blue_validation_status == "approved":
# Blue approved, Red rejected → notify Red Lead who rejected
rejector_id = test.red_validated_by
rejector_label = "Red Lead"
requester_label = "Blue Lead"
else:
from app.domain.errors import BusinessRuleViolation
raise BusinessRuleViolation(
"The conflict state is inconsistent — no approving lead found"
)
# Look up the rejecting lead's full info for the response
rejector = (
db.query(UserModel).filter(UserModel.id == rejector_id).first()
if rejector_id else None
)
rejector_name = rejector.username if rejector else rejector_label
rejector_email = getattr(rejector, "email", None) if rejector else None
# Notify the rejecting lead
if rejector_id:
try:
create_notification(
db,
user_id=rejector_id,
type="validation_conflict",
title="Discussion requested on disputed test",
message=(
f"{requester_label} ({current_user.username}) is confirming their approval "
f"of test '{test.name}' and wants to discuss your rejection with you. "
f"Please reach out to resolve the disagreement."
),
entity_type="test",
entity_id=str(test.id),
)
except Exception as e:
import logging
logging.getLogger(__name__).warning(
"Failed to send discussion notification: %s", e
)
log_action(
db,
user_id=current_user.id,
action="request_dispute_discussion",
entity_type="test",
entity_id=test.id,
details={"test_name": test.name, "rejector": rejector_name},
)
db.commit()
return {
"status": "notification_sent",
"message": f"Discussion request sent to {rejector_name}",
"rejector_username": rejector_name,
"rejector_email": rejector_email,
"rejector_role": rejector_label,
}
# ---------------------------------------------------------------------------
# POST /tests/import-rt — bulk import from a real Red Team engagement
# ---------------------------------------------------------------------------
_ALLOWED_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"}
_MAX_EVIDENCE_BYTES = 10 * 1024 * 1024 # 10 MB decoded per image
class RTEvidenceEntry(BaseModel):
filename: str # e.g. "screenshot_edr.png"
data: str # base64-encoded image content
caption: Optional[str] = None # optional description shown as evidence notes
class RTTechniqueEntry(BaseModel):
mitre_id: str
result: str # "detected" | "not_detected" | "partially_detected"
attack_success: bool = True
platform: Optional[str] = None
notes: Optional[str] = None
evidence: list[RTEvidenceEntry] # REQUIRED — at least one image per technique
class RTImportPayload(BaseModel):
name: str # engagement name, e.g. "Red Team Q1 2024"
date: Optional[str] = None # ISO date string
description: Optional[str] = None
operator: Optional[str] = None # team / company that ran the RT
techniques: list[RTTechniqueEntry]
@router.post("/import-rt", status_code=status.HTTP_201_CREATED)
def import_rt(
payload: RTImportPayload,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead")),
):
"""Import results from a real Red Team engagement.
Creates one Test record per technique in ``validated`` state (bypassing
the normal Red/Blue workflow) and immediately recalculates coverage metrics.
Requires ``red_lead`` or ``admin`` role.
"""
# Pre-validate: every technique must include at least one evidence image
for entry in payload.techniques:
if not entry.evidence:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=(
f"Technique {entry.mitre_id} is missing evidence. "
"At least one screenshot or image is required per technique."
),
)
# Execution date from payload or now
exec_date_str = payload.date or datetime.utcnow().date().isoformat()
# Result string → TestResult enum
_result_map = {
"detected": TestResult.detected,
"not_detected": TestResult.not_detected,
"partially_detected": TestResult.partially_detected,
}
created: list[dict[str, Any]] = []
skipped: list[dict[str, str]] = []
affected_technique_ids: set = set()
with UnitOfWork(db) as uow:
for entry in payload.techniques:
# Find technique
technique = (
db.query(Technique)
.filter(Technique.mitre_id == entry.mitre_id.upper())
.first()
)
if technique is None:
skipped.append({"mitre_id": entry.mitre_id, "reason": "Technique not found"})
continue
detection_result = _result_map.get(entry.result)
if detection_result is None:
skipped.append({"mitre_id": entry.mitre_id, "reason": f"Unknown result value '{entry.result}'"})
continue
test_name = f"[RT] {payload.name}{technique.name}"
# Build red_summary from notes + engagement metadata
parts = []
if payload.operator:
parts.append(f"Operator: {payload.operator}")
parts.append(f"Engagement date: {exec_date_str}")
if entry.notes:
parts.append(f"\n{entry.notes}")
red_summary_text = "\n".join(parts)
# RT pre-validates the Red side (they ran it), but Blue Lead
# must still validate the detection result before it counts.
# State = in_review so it appears in the Blue Lead's validation queue.
test = Test(
technique_id=technique.id,
name=test_name,
description=payload.description,
platform=entry.platform,
procedure_text=entry.notes,
created_by=current_user.id,
state=TestState.in_review,
# Red team — approved by the RT operator
attack_success=entry.attack_success,
red_summary=red_summary_text,
red_validation_status="approved",
red_validated_by=current_user.id,
red_validated_at=datetime.utcnow(),
# Blue team — pre-fill the detection result but leave
# validation_status pending so Blue Lead must confirm
detection_result=detection_result,
blue_validation_status=None,
# Timing
execution_date=exec_date_str,
created_at=datetime.utcnow(),
)
db.add(test)
db.flush()
# ── Store evidence images ──────────────────────────────
evidence_count = 0
for ev in entry.evidence:
safe_name = os.path.basename(ev.filename) or "evidence.png"
ext = os.path.splitext(safe_name)[1].lower()
if ext not in _ALLOWED_IMAGE_EXTS:
# Skip non-image files silently (log warning)
continue
try:
img_bytes = base64.b64decode(ev.data)
except Exception: # nosec B112
continue # malformed base64 — skip
if len(img_bytes) > _MAX_EVIDENCE_BYTES:
continue # over size limit — skip
sha256 = hashlib.sha256(img_bytes).hexdigest()
key = f"{test.id}/{uuid.uuid4()}_{safe_name}"
try:
upload_file(img_bytes, key)
except Exception: # nosec B112
continue # storage error — skip but don't abort
evidence_obj = Evidence(
test_id=test.id,
file_name=safe_name,
file_path=key,
sha256_hash=sha256,
uploaded_by=current_user.id,
uploaded_at=datetime.utcnow(),
team=TeamSide.red,
notes=ev.caption,
)
db.add(evidence_obj)
evidence_count += 1
affected_technique_ids.add(technique.id)
created.append({
"mitre_id": entry.mitre_id,
"test_name": test_name,
"result": entry.result,
"attack_success": entry.attack_success,
"evidence_attached": evidence_count,
})
log_action(
db,
user_id=current_user.id,
action="rt_import_test",
entity_type="test",
entity_id=test.id,
details={"engagement": payload.name, "mitre_id": entry.mitre_id},
)
# Recalculate coverage for all affected techniques
for tech_id in affected_technique_ids:
tech = db.query(Technique).filter(Technique.id == tech_id).first()
if tech:
recalculate_technique_status(db, tech)
uow.commit()
return {
"created": len(created),
"skipped": len(skipped),
"items": created,
"warnings": skipped,
"engagement": payload.name,
}
+2 -2
View File
@@ -62,7 +62,7 @@ def list_threat_actors(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""List threat actors with optional filters and pagination. """List threat actors with optional filters and pagination.
**Requires** authentication (any role). **Requires** authentication (any role).
@@ -138,7 +138,7 @@ def get_threat_actor_gaps(
db: Session = Depends(get_db), db: Session = Depends(get_db),
# Entry: current_user # Entry: current_user
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> list:
"""Identify techniques of this actor that are NOT fully validated. """Identify techniques of this actor that are NOT fully validated.
**Requires** authentication (any role). **Requires** authentication (any role).
+5 -43
View File
@@ -20,8 +20,11 @@ from app.domain.unit_of_work import UnitOfWork
# Import User from app.models.user # Import User from app.models.user
from app.models.user import User from app.models.user import User
from app.dependencies.auth import get_current_user
from app.schemas.user import UserCreate, UserUpdate, UserOut, UserPreferencesUpdate # Import UserCreate, UserOut, UserUpdate from app.schemas.user
from app.schemas.user import UserCreate, UserOut, UserUpdate
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action from app.services.audit_service import log_action
# Import from app.services.user_service # Import from app.services.user_service
@@ -36,47 +39,6 @@ from app.services.user_service import (
router = APIRouter(prefix="/users", tags=["users"]) router = APIRouter(prefix="/users", tags=["users"])
# ---------------------------------------------------------------------------
# PATCH /users/me/preferences — update current user preferences
# ---------------------------------------------------------------------------
@router.patch("/me/preferences", response_model=UserOut)
def update_my_preferences(
payload: UserPreferencesUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Update the current user's notification preferences, Jira account ID and Jira API token.
Send ``jira_api_token: ""`` to clear a previously stored token.
The token is never returned in any response.
"""
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field in ("jira_api_token", "jira_email", "tempo_api_token"):
# Empty string means "clear the value"
setattr(current_user, field, value if value else None)
else:
setattr(current_user, field, value)
db.commit()
db.refresh(current_user)
return current_user
# ---------------------------------------------------------------------------
# GET /users/me — get current user's own profile
# ---------------------------------------------------------------------------
@router.get("/me", response_model=UserOut)
def get_me(
current_user: User = Depends(get_current_user),
):
"""Return the currently authenticated user's profile."""
return current_user
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /users — list all users # GET /users — list all users
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
-149
View File
@@ -1,149 +0,0 @@
"""Webhook configuration CRUD router — admin only.
Endpoints
---------
GET /webhooks — list all webhook configs
POST /webhooks — create a new webhook config
GET /webhooks/{id} — get a single webhook config
PATCH /webhooks/{id} — update a webhook config
DELETE /webhooks/{id} — hard-delete a webhook config
POST /webhooks/{id}/test — send a test ping
"""
import uuid
from fastapi import APIRouter, Depends, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import require_any_role
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.schemas.webhook import WebhookConfigCreate, WebhookConfigOut, WebhookConfigUpdate
from app.services.webhook_service import (
create_webhook,
delete_webhook,
dispatch_webhook,
get_webhook_or_raise,
list_webhooks,
update_webhook,
)
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
def _mask_secret(wh) -> WebhookConfigOut:
"""Return a WebhookConfigOut with the secret masked."""
out = WebhookConfigOut.model_validate(wh)
if wh.secret:
out.secret = "***" # nosec B105
else:
out.secret = None
return out
# ---------------------------------------------------------------------------
# GET /webhooks
# ---------------------------------------------------------------------------
@router.get("", response_model=list[WebhookConfigOut])
def list_webhooks_route(
offset: int = 0,
limit: int = 50,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin")),
):
"""Return all webhook configurations. **Requires admin role.**"""
webhooks = list_webhooks(db, offset=offset, limit=limit)
return [_mask_secret(wh) for wh in webhooks]
# ---------------------------------------------------------------------------
# POST /webhooks
# ---------------------------------------------------------------------------
@router.post("", response_model=WebhookConfigOut, status_code=status.HTTP_201_CREATED)
def create_webhook_route(
payload: WebhookConfigCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin")),
):
"""Create a new webhook configuration. **Requires admin role.**"""
with UnitOfWork(db) as uow:
wh = create_webhook(db, created_by=current_user.id, payload=payload)
uow.commit()
db.refresh(wh)
return _mask_secret(wh)
# ---------------------------------------------------------------------------
# GET /webhooks/{id}
# ---------------------------------------------------------------------------
@router.get("/{webhook_id}", response_model=WebhookConfigOut)
def get_webhook_route(
webhook_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin")),
):
"""Return a single webhook configuration. **Requires admin role.**"""
wh = get_webhook_or_raise(db, webhook_id)
return _mask_secret(wh)
# ---------------------------------------------------------------------------
# PATCH /webhooks/{id}
# ---------------------------------------------------------------------------
@router.patch("/{webhook_id}", response_model=WebhookConfigOut)
def update_webhook_route(
webhook_id: uuid.UUID,
payload: WebhookConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin")),
):
"""Update one or more fields of a webhook configuration. **Requires admin role.**"""
with UnitOfWork(db) as uow:
wh = update_webhook(db, webhook_id, payload)
uow.commit()
db.refresh(wh)
return _mask_secret(wh)
# ---------------------------------------------------------------------------
# DELETE /webhooks/{id}
# ---------------------------------------------------------------------------
@router.delete("/{webhook_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_webhook_route(
webhook_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin")),
):
"""Hard-delete a webhook configuration. **Requires admin role.**"""
with UnitOfWork(db) as uow:
delete_webhook(db, webhook_id)
uow.commit()
# ---------------------------------------------------------------------------
# POST /webhooks/{id}/test
# ---------------------------------------------------------------------------
@router.post("/{webhook_id}/test", status_code=status.HTTP_202_ACCEPTED)
def test_webhook_route(
webhook_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("admin")),
):
"""Send a test ping to the webhook endpoint. **Requires admin role.**"""
# Verify the webhook exists before dispatching
get_webhook_or_raise(db, webhook_id)
dispatch_webhook("webhook.test", {"webhook_id": str(webhook_id), "message": "Test ping from Aegis"})
return {"detail": "Test ping dispatched"}
-68
View File
@@ -1,68 +0,0 @@
"""Phase 14: API Key Pydantic schemas."""
from __future__ import annotations
from datetime import datetime
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from app.models.api_key import VALID_SCOPES
class ApiKeyCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=200)
description: Optional[str] = None
scopes: List[str] = Field(default=["read"])
expires_at: Optional[datetime] = None
@field_validator("scopes")
@classmethod
def validate_scopes(cls, v: list) -> list:
invalid = set(v) - VALID_SCOPES
if invalid:
raise ValueError(f"Invalid scopes: {invalid}. Valid: {VALID_SCOPES}")
if not v:
raise ValueError("At least one scope is required")
return v
class ApiKeyOut(BaseModel):
"""Safe representation — never exposes key_hash."""
id: UUID
name: str
description: Optional[str] = None
key_prefix: str
user_id: UUID
scopes: List[str]
last_used_at: Optional[datetime] = None
expires_at: Optional[datetime] = None
is_active: bool
created_at: Optional[datetime] = None
class Config:
from_attributes = True
class ApiKeyCreated(ApiKeyOut):
"""Returned only once at creation — includes the raw key."""
raw_key: str = Field(..., description="The full API key — shown only this once.")
class ApiKeyUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=200)
description: Optional[str] = None
scopes: Optional[List[str]] = None
expires_at: Optional[datetime] = None
is_active: Optional[bool] = None
@field_validator("scopes")
@classmethod
def validate_scopes(cls, v: Optional[list]) -> Optional[list]:
if v is None:
return v
invalid = set(v) - VALID_SCOPES
if invalid:
raise ValueError(f"Invalid scopes: {invalid}")
return v
-230
View File
@@ -1,230 +0,0 @@
"""Pydantic schemas for Phase 10: Attack Paths & Advanced Purple Team."""
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, field_validator
VALID_KILL_CHAIN_PHASES = [
"reconnaissance", "resource_development", "initial_access", "execution",
"persistence", "privilege_escalation", "defense_evasion", "credential_access",
"discovery", "lateral_movement", "collection", "command_and_control",
"exfiltration", "impact",
]
# ── Attack Path ───────────────────────────────────────────────────────────────
class AttackPathCreate(BaseModel):
name: str
description: Optional[str] = None
objective: Optional[str] = None
is_template: bool = False
threat_actor_id: Optional[UUID] = None
tags: Optional[list[str]] = None
class AttackPathUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
objective: Optional[str] = None
is_template: Optional[bool] = None
threat_actor_id: Optional[UUID] = None
tags: Optional[list[str]] = None
is_active: Optional[bool] = None
class AttackPathOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
name: str
description: Optional[str] = None
objective: Optional[str] = None
is_template: bool
threat_actor_id: Optional[UUID] = None
created_by: Optional[UUID] = None
tags: Optional[list] = None
is_active: bool
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
step_count: Optional[int] = None # injected by service
# ── Attack Path Step ──────────────────────────────────────────────────────────
class AttackPathStepCreate(BaseModel):
order_index: int = 0
kill_chain_phase: Optional[str] = None
technique_id: Optional[UUID] = None
test_id: Optional[UUID] = None
name: Optional[str] = None
description: Optional[str] = None
expected_detection: bool = True
notes: Optional[str] = None
@field_validator("kill_chain_phase")
@classmethod
def validate_phase(cls, v):
if v is not None and v not in VALID_KILL_CHAIN_PHASES:
raise ValueError(f"Invalid kill_chain_phase '{v}'. Valid: {VALID_KILL_CHAIN_PHASES}")
return v
class AttackPathStepUpdate(BaseModel):
order_index: Optional[int] = None
kill_chain_phase: Optional[str] = None
technique_id: Optional[UUID] = None
test_id: Optional[UUID] = None
name: Optional[str] = None
description: Optional[str] = None
expected_detection: Optional[bool] = None
notes: Optional[str] = None
@field_validator("kill_chain_phase")
@classmethod
def validate_phase(cls, v):
if v is not None and v not in VALID_KILL_CHAIN_PHASES:
raise ValueError(f"Invalid kill_chain_phase '{v}'.")
return v
class AttackPathStepOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
attack_path_id: UUID
order_index: int
kill_chain_phase: Optional[str] = None
technique_id: Optional[UUID] = None
test_id: Optional[UUID] = None
name: Optional[str] = None
description: Optional[str] = None
expected_detection: bool
notes: Optional[str] = None
# ── Execution ─────────────────────────────────────────────────────────────────
class ExecutionCreate(BaseModel):
environment: Optional[str] = None
red_team_lead: Optional[UUID] = None
blue_team_lead: Optional[UUID] = None
notes: Optional[str] = None
class ExecutionOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
attack_path_id: UUID
status: str
environment: Optional[str] = None
red_team_lead: Optional[UUID] = None
blue_team_lead: Optional[UUID] = None
started_by: Optional[UUID] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
notes: Optional[str] = None
created_at: Optional[datetime] = None
# metrics
total_steps: Optional[int] = None
detected_steps: Optional[int] = None
not_detected_steps: Optional[int] = None
skipped_steps: Optional[int] = None
detection_rate: Optional[float] = None
mttd_seconds: Optional[float] = None
furthest_undetected_step: Optional[int] = None
# ── Step Result ───────────────────────────────────────────────────────────────
class StepExecuteRequest(BaseModel):
status: str # detected / not_detected / skipped
executed_at: Optional[datetime] = None
detected_at: Optional[datetime] = None
detection_asset_id: Optional[UUID] = None
notes: Optional[str] = None
evidence_ids: Optional[list[UUID]] = None
@field_validator("status")
@classmethod
def validate_status(cls, v):
valid = ("detected", "not_detected", "skipped", "executing")
if v not in valid:
raise ValueError(f"status must be one of {valid}")
return v
class StepResultOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
execution_id: UUID
step_id: UUID
step_order: int
status: str
executed_by: Optional[UUID] = None
executed_at: Optional[datetime] = None
detected_at: Optional[datetime] = None
time_to_detect_seconds: Optional[float] = None
detection_asset_id: Optional[UUID] = None
notes: Optional[str] = None
evidence_ids: Optional[list] = None
# ── Timeline ──────────────────────────────────────────────────────────────────
class TimelineEntryCreate(BaseModel):
actor_side: str
entry_type: str
content: str
step_id: Optional[UUID] = None
timestamp: Optional[datetime] = None
extra: Optional[dict] = None
@field_validator("actor_side")
@classmethod
def validate_side(cls, v):
if v not in ("red", "blue", "system"):
raise ValueError("actor_side must be red, blue or system")
return v
@field_validator("entry_type")
@classmethod
def validate_type(cls, v):
valid = ("action", "detection", "note", "phase_transition", "flag")
if v not in valid:
raise ValueError(f"entry_type must be one of {valid}")
return v
class TimelineEntryOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
execution_id: UUID
step_id: Optional[UUID] = None
timestamp: datetime
actor_side: str
actor_id: Optional[UUID] = None
entry_type: str
content: str
extra: Optional[dict] = None
# ── Metrics ───────────────────────────────────────────────────────────────────
class KillChainMetrics(BaseModel):
execution_id: UUID
total_steps: int
detected_steps: int
not_detected_steps: int
skipped_steps: int
detection_rate: float # 0.01.0
mttd_seconds: Optional[float] # mean time to detect
furthest_undetected_step: Optional[int]
furthest_undetected_phase: Optional[str]
step_breakdown: list[dict] # per-step detail
phase_summary: dict # detection rate per kill-chain phase
+6 -2
View File
@@ -5,7 +5,9 @@ import uuid
# Import datetime from datetime # Import datetime from datetime
from datetime import datetime from datetime import datetime
from typing import Any, Optional
# Import Any from typing
from typing import Any
# Import BaseModel, ConfigDict from pydantic # Import BaseModel, ConfigDict from pydantic
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@@ -27,7 +29,9 @@ class AuditLogOut(BaseModel):
entity_type: str | None = None entity_type: str | None = None
# Assign entity_id = None # Assign entity_id = None
entity_id: str | None = None entity_id: str | None = None
timestamp: Optional[datetime] = None # timestamp: datetime
timestamp: datetime
# Assign details = None
details: dict[str, Any] | None = None details: dict[str, Any] | None = None
# Assign model_config = ConfigDict(from_attributes=True) # Assign model_config = ConfigDict(from_attributes=True)
@@ -1,140 +0,0 @@
"""Pydantic schemas for Detection Lifecycle endpoints."""
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
from uuid import UUID
from datetime import datetime
from app.models.detection_lifecycle import (
DetectionConfidence, DetectionHealthStatus, InvalidationReason
)
class DetectionAssetCreate(BaseModel):
name: str = Field(..., min_length=3, max_length=500)
description: Optional[str] = None
asset_type: str = Field(..., pattern=r'^(siem_rule|edr_rule|sigma_rule|yara_rule|spl_query|kql_query|custom_script)$')
platform: Optional[str] = None
rule_content: Optional[str] = None
rule_language: Optional[str] = None
rule_repository_url: Optional[str] = None
rule_file_path: Optional[str] = None
rule_version: Optional[str] = None
log_source_name: Optional[str] = None
log_source_version: Optional[str] = None
log_source_config: Optional[dict] = Field(default_factory=dict)
infrastructure_details: Optional[dict] = Field(default_factory=dict)
expected_alert_frequency: Optional[str] = None
tags: Optional[list[str]] = Field(default_factory=list)
technique_ids: Optional[list[UUID]] = Field(default_factory=list)
class DetectionAssetUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
rule_content: Optional[str] = None
rule_version: Optional[str] = None
log_source_version: Optional[str] = None
infrastructure_details: Optional[dict] = None
expected_alert_frequency: Optional[str] = None
health_status: Optional[DetectionHealthStatus] = None
last_alert_at: Optional[datetime] = None
alert_count_30d: Optional[int] = None
false_positive_rate: Optional[float] = None
owner_id: Optional[UUID] = None
backup_owner_id: Optional[UUID] = None
team: Optional[str] = None
tags: Optional[list[str]] = None
is_active: Optional[bool] = None
class DetectionAssetOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
name: str
description: Optional[str] = None
asset_type: str
platform: Optional[str] = None
rule_language: Optional[str] = None
rule_version: Optional[str] = None
rule_hash: Optional[str] = None
health_status: DetectionHealthStatus
last_alert_at: Optional[datetime] = None
alert_count_30d: int
false_positive_rate: Optional[float] = None
expected_alert_frequency: Optional[str] = None
owner_id: Optional[UUID] = None
team: Optional[str] = None
is_active: bool
tags: list = Field(default_factory=list)
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class DetectionValidationCreate(BaseModel):
detection_asset_id: UUID
technique_id: Optional[UUID] = None
test_id: Optional[UUID] = None
validation_result: str = Field(..., pattern=r'^(detected|not_detected|partial|error)$')
validation_method: str
notes: Optional[str] = None
evidence_ids: Optional[list[UUID]] = Field(default_factory=list)
validity_days: int = Field(default=180, ge=30, le=730)
class DetectionValidationOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
detection_asset_id: UUID
technique_id: Optional[UUID] = None
validated_at: Optional[datetime] = None
expires_at: datetime
is_valid: bool
validation_result: Optional[str] = None
validation_method: Optional[str] = None
invalidated_at: Optional[datetime] = None
invalidation_reason: Optional[InvalidationReason] = None
validated_by: Optional[UUID] = None
notes: Optional[str] = None
class TechniqueConfidenceOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
technique_id: UUID
confidence_level: DetectionConfidence
confidence_score: float
detection_count: int
valid_detection_count: int
last_validated_at: Optional[datetime] = None
next_validation_due: Optional[datetime] = None
recency_factor: float
coverage_factor: float
health_factor: float
diversity_factor: float
risk_factors: list = Field(default_factory=list)
class InfrastructureChangeCreate(BaseModel):
change_type: str
description: str = Field(..., min_length=10)
affected_platforms: list[str] = Field(default_factory=list)
affected_log_sources: list[str] = Field(default_factory=list)
change_date: Optional[datetime] = None
auto_invalidate: bool = True
class InfrastructureChangeOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
change_type: str
description: str
affected_platforms: list = Field(default_factory=list)
affected_log_sources: list = Field(default_factory=list)
change_date: Optional[datetime] = None
auto_invalidate: bool
invalidated_count: int
reported_by: Optional[UUID] = None
created_at: Optional[datetime] = None
@@ -1,113 +0,0 @@
"""Phase 13: Executive Dashboard — Pydantic schemas."""
from __future__ import annotations
from datetime import date, datetime
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
class PostureSnapshotOut(BaseModel):
id: UUID
snapshot_date: date
# Coverage
total_techniques: int
validated_count: int
partial_count: int
not_covered_count: int
coverage_pct: float
# Risk
avg_risk_score: float
critical_count: int
high_count: int
medium_count: int
low_count: int
# Operations
open_queue_items: int
orphan_techniques: int
# Knowledge
playbook_count: int
lesson_count: int
# MTTD
mttd_avg_seconds: Optional[float] = None
executions_30d: int
detection_rate_30d: Optional[float] = None
# Meta
created_by: Optional[UUID] = None
created_at: Optional[datetime] = None
extra: Optional[Dict[str, Any]] = None
class Config:
from_attributes = True
class ExecutiveSummary(BaseModel):
"""Full executive view — current posture + trends."""
snapshot: PostureSnapshotOut
coverage_trend: List[Dict[str, Any]] = Field(
default_factory=list,
description="Last 30-day coverage_pct series [{date, value}]",
)
risk_trend: List[Dict[str, Any]] = Field(
default_factory=list,
description="Last 30-day avg_risk_score series [{date, value}]",
)
top_risks: List[Dict[str, Any]] = Field(
default_factory=list,
description="Top 5 highest-risk techniques",
)
coverage_by_tactic: List[Dict[str, Any]] = Field(
default_factory=list,
description="Per-tactic validated/partial/not_covered counts",
)
recent_activity: List[Dict[str, Any]] = Field(
default_factory=list,
description="Most-recent events (tests, paths, queue changes)",
)
class KpiBlock(BaseModel):
"""Compact KPI block for a dashboard header."""
coverage_pct: float
avg_risk_score: float
critical_count: int
open_queue_items: int
orphan_techniques: int
mttd_avg_seconds: Optional[float] = None
detection_rate_30d: Optional[float] = None
playbook_count: int
lesson_count: int
snapshot_date: date
snapshot_id: Optional[UUID] = None
class CoverageByTactic(BaseModel):
tactic: str
total: int
validated: int
partial: int
not_covered: int
coverage_pct: float
class PostureHistoryEntry(BaseModel):
snapshot_date: date
coverage_pct: float
avg_risk_score: float
critical_count: int
open_queue_items: int
class ActivityEntry(BaseModel):
ts: datetime
category: str # "test" | "attack_path" | "queue" | "osint"
title: str
detail: Optional[str] = None
-149
View File
@@ -1,149 +0,0 @@
"""Phase 11: Knowledge Management schemas — Playbooks + Lessons Learned."""
from datetime import datetime
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, field_validator
# ── Constants ─────────────────────────────────────────────────────────────────
VALID_PLAYBOOK_TYPES = ["attack", "detect", "investigate", "respond", "hunt"]
VALID_SEVERITIES = ["critical", "high", "medium", "low", "info"]
VALID_ENTITY_TYPES = ["test", "campaign", "attack_path", "manual"]
# ══════════════════════════════════════════════════════════════════════════════
# Playbook schemas
# ══════════════════════════════════════════════════════════════════════════════
class PlaybookCreate(BaseModel):
technique_id: UUID
playbook_type: str
title: str
content: str = ""
tools: List[str] = []
prerequisites: List[str] = []
change_note: Optional[str] = None
@field_validator("playbook_type")
@classmethod
def validate_playbook_type(cls, v: str) -> str:
if v not in VALID_PLAYBOOK_TYPES:
raise ValueError(
f"Invalid playbook_type '{v}'. Must be one of: {VALID_PLAYBOOK_TYPES}"
)
return v
class PlaybookUpdate(BaseModel):
title: Optional[str] = None
content: Optional[str] = None
tools: Optional[List[str]] = None
prerequisites: Optional[List[str]] = None
change_note: Optional[str] = None
class PlaybookVersionOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
playbook_id: UUID
version: int
title: str
content: str
tools: List[str] = []
prerequisites: List[str] = []
changed_by: Optional[UUID]
change_note: Optional[str]
created_at: datetime
class PlaybookOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
technique_id: UUID
playbook_type: str
title: str
content: str
version: int
tools: List[str] = []
prerequisites: List[str] = []
created_by: Optional[UUID]
updated_by: Optional[UUID]
created_at: datetime
updated_at: datetime
is_active: bool
# ══════════════════════════════════════════════════════════════════════════════
# Lesson Learned schemas
# ══════════════════════════════════════════════════════════════════════════════
class LessonLearnedCreate(BaseModel):
title: str
what_happened: str
root_cause: str
fix_applied: Optional[str] = None
severity: str = "medium"
entity_type: str = "manual"
entity_id: Optional[UUID] = None
technique_ids: List[str] = []
tags: List[str] = []
@field_validator("severity")
@classmethod
def validate_severity(cls, v: str) -> str:
if v not in VALID_SEVERITIES:
raise ValueError(
f"Invalid severity '{v}'. Must be one of: {VALID_SEVERITIES}"
)
return v
@field_validator("entity_type")
@classmethod
def validate_entity_type(cls, v: str) -> str:
if v not in VALID_ENTITY_TYPES:
raise ValueError(
f"Invalid entity_type '{v}'. Must be one of: {VALID_ENTITY_TYPES}"
)
return v
class LessonLearnedUpdate(BaseModel):
title: Optional[str] = None
what_happened: Optional[str] = None
root_cause: Optional[str] = None
fix_applied: Optional[str] = None
severity: Optional[str] = None
technique_ids: Optional[List[str]] = None
tags: Optional[List[str]] = None
@field_validator("severity")
@classmethod
def validate_severity(cls, v: Optional[str]) -> Optional[str]:
if v is not None and v not in VALID_SEVERITIES:
raise ValueError(
f"Invalid severity '{v}'. Must be one of: {VALID_SEVERITIES}"
)
return v
class LessonLearnedOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
title: str
what_happened: str
root_cause: str
fix_applied: Optional[str]
severity: str
entity_type: str
entity_id: Optional[UUID]
technique_ids: List[str] = []
tags: List[str] = []
created_by: Optional[UUID]
created_at: datetime
updated_at: datetime
is_active: bool
@@ -1,124 +0,0 @@
"""Phase 13: Operational Alerts — Pydantic schemas."""
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from app.models.operational_alert import AlertRuleType, AlertSeverity, AlertStatus
VALID_SEVERITIES = {s.value for s in AlertSeverity}
VALID_STATUSES = {s.value for s in AlertStatus}
VALID_RULE_TYPES = {r.value for r in AlertRuleType}
# ── AlertRule schemas ─────────────────────────────────────────────────────────
class AlertRuleCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=300)
description: Optional[str] = None
rule_type: str
severity: str = "medium"
config: Dict[str, Any] = Field(default_factory=dict)
notify_in_app: bool = True
notify_webhook: bool = False
webhook_id: Optional[UUID] = None
cooldown_hours: int = Field(24, ge=0, le=8760)
@field_validator("rule_type")
@classmethod
def validate_rule_type(cls, v: str) -> str:
if v not in VALID_RULE_TYPES:
raise ValueError(f"Invalid rule_type. Valid: {VALID_RULE_TYPES}")
return v
@field_validator("severity")
@classmethod
def validate_severity(cls, v: str) -> str:
if v not in VALID_SEVERITIES:
raise ValueError(f"Invalid severity. Valid: {VALID_SEVERITIES}")
return v
class AlertRuleUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=300)
description: Optional[str] = None
severity: Optional[str] = None
is_enabled: Optional[bool] = None
config: Optional[Dict[str, Any]] = None
notify_in_app: Optional[bool] = None
notify_webhook: Optional[bool] = None
webhook_id: Optional[UUID] = None
cooldown_hours: Optional[int] = Field(None, ge=0, le=8760)
@field_validator("severity")
@classmethod
def validate_severity(cls, v: Optional[str]) -> Optional[str]:
if v is not None and v not in VALID_SEVERITIES:
raise ValueError(f"Invalid severity. Valid: {VALID_SEVERITIES}")
return v
class AlertRuleOut(BaseModel):
id: UUID
name: str
description: Optional[str] = None
rule_type: str
severity: str
is_enabled: bool
is_system: bool
config: Dict[str, Any]
notify_in_app: bool
notify_webhook: bool
webhook_id: Optional[UUID] = None
cooldown_hours: int
created_by: Optional[UUID] = None
created_at: Optional[datetime] = None
last_fired_at: Optional[datetime] = None
class Config:
from_attributes = True
# ── AlertInstance schemas ─────────────────────────────────────────────────────
class AlertInstanceOut(BaseModel):
id: UUID
rule_id: Optional[UUID] = None
rule_name: str
rule_type: str
severity: str
title: str
message: str
details: Optional[Dict[str, Any]] = None
status: str
acknowledged_by: Optional[UUID] = None
acknowledged_at: Optional[datetime] = None
resolved_at: Optional[datetime] = None
created_at: Optional[datetime] = None
class Config:
from_attributes = True
# ── Evaluation result ─────────────────────────────────────────────────────────
class EvaluationResult(BaseModel):
rules_evaluated: int
alerts_fired: int
alerts: List[AlertInstanceOut] = Field(default_factory=list)
duration_seconds: float
# ── Summary ───────────────────────────────────────────────────────────────────
class AlertSummary(BaseModel):
total_open: int
total_acknowledged: int
total_resolved: int
by_severity: Dict[str, int]
by_rule_type: Dict[str, int]
recent_alerts: List[AlertInstanceOut] = Field(default_factory=list)
@@ -1,153 +0,0 @@
"""Pydantic schemas for Phase 9: Ownership & Revalidation Queue."""
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, field_validator
# ── Technique Ownership ───────────────────────────────────────────────────────
class TechniqueOwnershipSet(BaseModel):
"""Set (create or replace) ownership for a technique."""
owner_id: Optional[UUID] = None
backup_owner_id: Optional[UUID] = None
team: Optional[str] = None
notes: Optional[str] = None
class TechniqueOwnershipOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
technique_id: UUID
owner_id: Optional[UUID] = None
backup_owner_id: Optional[UUID] = None
team: Optional[str] = None
notes: Optional[str] = None
assigned_at: Optional[datetime] = None
assigned_by: Optional[UUID] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class DetectionAssetOwnershipPatch(BaseModel):
"""Update ownership fields on a detection asset."""
owner_id: Optional[UUID] = None
backup_owner_id: Optional[UUID] = None
team: Optional[str] = None
# ── Bulk Assignment ───────────────────────────────────────────────────────────
class BulkAssignRequest(BaseModel):
"""Bulk-assign ownership by tactic, platform, or team filter."""
owner_id: Optional[UUID] = None
backup_owner_id: Optional[UUID] = None
team: Optional[str] = None
# Filters — at least one must be set
tactic: Optional[str] = None # assign all techniques with this tactic
platform: Optional[str] = None # assign all detection assets with this platform
overwrite: bool = False # overwrite existing assignments
class BulkAssignResult(BaseModel):
assigned_count: int
skipped_count: int
target_type: str # "technique" or "detection_asset"
# ── Revalidation Queue ────────────────────────────────────────────────────────
class QueueItemPatch(BaseModel):
"""Update a revalidation queue item."""
status: Optional[str] = None
assigned_to: Optional[UUID] = None
priority: Optional[str] = None
due_date: Optional[datetime] = None
@field_validator("status")
@classmethod
def validate_status(cls, v):
from app.models.ownership_queue import QueueStatus
if v is not None:
try:
QueueStatus(v)
except ValueError:
raise ValueError(f"Invalid status: {v}")
return v
@field_validator("priority")
@classmethod
def validate_priority(cls, v):
from app.models.ownership_queue import QueuePriority
if v is not None:
try:
QueuePriority(v)
except ValueError:
raise ValueError(f"Invalid priority: {v}")
return v
class QueueItemCreate(BaseModel):
"""Manually create a queue item."""
technique_id: Optional[UUID] = None
detection_asset_id: Optional[UUID] = None
priority: str = "medium"
reason: str = "manual"
reason_detail: Optional[str] = None
assigned_to: Optional[UUID] = None
due_date: Optional[datetime] = None
@field_validator("reason")
@classmethod
def validate_reason(cls, v):
from app.models.ownership_queue import QueueReason
try:
QueueReason(v)
except ValueError:
valid = [e.value for e in QueueReason]
raise ValueError(f"Invalid reason '{v}'. Must be one of: {valid}")
return v
@field_validator("priority")
@classmethod
def validate_priority(cls, v):
from app.models.ownership_queue import QueuePriority
try:
QueuePriority(v)
except ValueError:
valid = [e.value for e in QueuePriority]
raise ValueError(f"Invalid priority '{v}'. Must be one of: {valid}")
return v
class QueueItemOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
technique_id: Optional[UUID] = None
detection_asset_id: Optional[UUID] = None
priority: str
reason: str
reason_detail: Optional[str] = None
status: str
assigned_to: Optional[UUID] = None
due_date: Optional[datetime] = None
created_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
dismissed_at: Optional[datetime] = None
completed_by: Optional[UUID] = None
extra: Optional[dict] = None
# ── Analyst Dashboard ─────────────────────────────────────────────────────────
class AnalystDashboard(BaseModel):
"""Personalised daily workday view for an analyst."""
my_pending_items: list[QueueItemOut]
expiring_validations_7d: list[dict]
recent_infra_changes: list[dict]
my_low_confidence_techniques: list[dict]
summary: dict
-71
View File
@@ -1,71 +0,0 @@
"""Phase 12: Risk Intelligence schemas."""
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict
VALID_RISK_LEVELS = ["critical", "high", "medium", "low", "info"]
class TechniqueRiskProfileOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: UUID
technique_id: UUID
risk_score: float
likelihood: float
impact: float
risk_level: str
detection_gap: float
threat_actor_count: int
osint_signal_count: int
test_fail_count: int
test_total_count: int
test_failure_rate: float
confidence_level: float
scoring_breakdown: Optional[Dict[str, Any]]
recommendations: Optional[List[str]]
computed_at: datetime
is_stale: bool
class RiskMatrixEntry(BaseModel):
model_config = ConfigDict(from_attributes=True)
technique_id: UUID
technique_name: Optional[str] = None
technique_tid: Optional[str] = None # e.g. "T1059"
risk_score: float
likelihood: float
impact: float
risk_level: str
detection_gap: float
computed_at: datetime
class RiskSummary(BaseModel):
total_techniques: int
scored_techniques: int
stale_count: int
by_level: Dict[str, int] # {"critical": 3, "high": 12, ...}
avg_risk_score: float
top_risks: List[RiskMatrixEntry]
class RecommendationItem(BaseModel):
technique_id: UUID
technique_name: Optional[str] = None
technique_tid: Optional[str] = None
risk_level: str
risk_score: float
recommendations: List[str]
priority: int # 1 = highest
class ComputeResult(BaseModel):
computed: int
skipped: int
errors: int
duration_seconds: float
-80
View File
@@ -1,80 +0,0 @@
"""Phase 14: SSO / SAML 2.0 Pydantic schemas."""
from __future__ import annotations
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field
class SsoConfigCreate(BaseModel):
is_enabled: bool = False
provider_name: Optional[str] = None
# SP settings (auto-derived if not provided)
sp_entity_id: Optional[str] = None
sp_acs_url: Optional[str] = None
sp_slo_url: Optional[str] = None
sp_certificate: Optional[str] = None
sp_private_key: Optional[str] = None
# IdP settings
idp_entity_id: Optional[str] = None
idp_sso_url: Optional[str] = None
idp_slo_url: Optional[str] = None
idp_certificate: Optional[str] = None
# Attribute mapping
attr_email: Optional[str] = "email"
attr_username: Optional[str] = "username"
attr_role: Optional[str] = "role"
default_role: Optional[str] = "viewer"
auto_provision: bool = True
class SsoConfigUpdate(SsoConfigCreate):
"""All fields optional for partial updates."""
pass
class SsoConfigOut(BaseModel):
id: UUID
is_enabled: bool
provider_name: Optional[str] = None
sp_entity_id: Optional[str] = None
sp_acs_url: Optional[str] = None
sp_slo_url: Optional[str] = None
sp_certificate: Optional[str] = None
# sp_private_key is intentionally OMITTED from responses
idp_entity_id: Optional[str] = None
idp_sso_url: Optional[str] = None
idp_slo_url: Optional[str] = None
idp_certificate: Optional[str] = None
attr_email: Optional[str] = None
attr_username: Optional[str] = None
attr_role: Optional[str] = None
default_role: Optional[str] = None
auto_provision: bool = True
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
class SsoLoginInitResponse(BaseModel):
redirect_url: str = Field(..., description="URL to redirect the browser to for IdP login")
request_id: str = Field(..., description="SAML AuthnRequest ID for validation")
class SsoStatusResponse(BaseModel):
enabled: bool
provider_name: Optional[str] = None
configured: bool = Field(..., description="True if IdP settings are present")
login_url: Optional[str] = None # /sso/login URL
+23 -58
View File
@@ -6,12 +6,14 @@ import uuid
# Import datetime from datetime # Import datetime from datetime
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, ConfigDict, model_validator # Import BaseModel, ConfigDict from pydantic
from pydantic import BaseModel, ConfigDict
# Import DataClassification from app.domain.enums # Import DataClassification from app.domain.enums
from app.domain.enums import DataClassification from app.domain.enums import DataClassification
# Import TestResult, TestState from app.models.enums
from app.models.enums import TestResult, TestState from app.models.enums import TestResult, TestState
from app.schemas.evidence import EvidenceOut
# ── Create ────────────────────────────────────────────────────────── # ── Create ──────────────────────────────────────────────────────────
@@ -207,7 +209,7 @@ class TestOut(BaseModel):
red_started_at: datetime | None = None red_started_at: datetime | None = None
# Assign blue_started_at = None # Assign blue_started_at = None
blue_started_at: datetime | None = None blue_started_at: datetime | None = None
blue_work_started_at: datetime | None = None # Assign paused_at = None
paused_at: datetime | None = None paused_at: datetime | None = None
# Assign red_paused_seconds = 0 # Assign red_paused_seconds = 0
red_paused_seconds: int = 0 red_paused_seconds: int = 0
@@ -233,64 +235,27 @@ class TestOut(BaseModel):
# Assign technique_name = None # Assign technique_name = None
technique_name: str | None = None technique_name: str | None = None
# Evidences split by team (populated from the ORM relationship) # Assign model_config = ConfigDict(from_attributes=True)
red_evidences: list[EvidenceOut] = []
blue_evidences: list[EvidenceOut] = []
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@model_validator(mode="before") # Apply the @classmethod decorator
@classmethod @classmethod
def _populate_derived_fields(cls, obj): # Define function model_validate
"""Populate technique and evidence fields from ORM relationships. def model_validate(cls, obj: object, **kwargs: object) -> "TestOut":
"""Populate technique fields from the ORM relationship before validation.
Uses ``@model_validator(mode='before')`` so it is called by Pydantic's Args:
internal Rust validation pipeline, including FastAPI's TypeAdapter path. obj (object): The ORM model instance (or any compatible object) to validate.
A plain ``model_validate`` classmethod override is **not** invoked by **kwargs (object): Additional keyword arguments forwarded to the parent.
FastAPI's response serialisation in Pydantic v2 — only registered
validators are guaranteed to run.
Evidences are only processed when the relationship was **explicitly loaded** Returns:
(via joinedload or prior access). Accessing ``obj.evidences`` blindly on a TestOut: The validated schema instance with technique fields populated.
session-expired ORM object triggers a lazy query that fails on mutation
endpoints that do not joinload the relationship. We inspect ``obj.__dict__``
directly — SQLAlchemy stores loaded relationships there; if the key is absent
the relationship is unloaded and we leave the lists empty (the frontend
invalidates and refetches the detail endpoint, which *does* joinload).
""" """
if not hasattr(obj, "__dict__"): # Check: hasattr(obj, "technique") and obj.technique is not None
return obj if hasattr(obj, "technique") and obj.technique is not None:
# Assign obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id
# Technique info (lazy-load is fine here: session is still open on GET) obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id
try: # Assign obj.__dict__["technique_name"] = obj.technique.name
if hasattr(obj, "technique") and obj.technique is not None: obj.__dict__["technique_name"] = obj.technique.name
obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id # Return super().model_validate(obj, **kwargs)
obj.__dict__["technique_name"] = obj.technique.name return super().model_validate(obj, **kwargs)
except Exception: # nosec B110
pass # DetachedInstanceError or similar — leave technique fields None
# Only split evidences when they are already in memory (loaded via joinedload)
raw_evs = obj.__dict__.get("evidences")
if raw_evs is not None:
red_evs: list[EvidenceOut] = []
blue_evs: list[EvidenceOut] = []
for ev in raw_evs:
ev_out = EvidenceOut(
id=ev.id,
test_id=ev.test_id,
file_name=ev.file_name,
sha256_hash=ev.sha256_hash,
uploaded_by=ev.uploaded_by,
uploaded_at=ev.uploaded_at,
team=ev.team,
notes=ev.notes,
download_url=f"/api/v1/evidence/{ev.id}/file",
)
if ev.team and ev.team.value == "blue":
blue_evs.append(ev_out)
else:
red_evs.append(ev_out)
obj.__dict__["red_evidences"] = red_evs
obj.__dict__["blue_evidences"] = blue_evs
return obj
+1 -11
View File
@@ -111,19 +111,9 @@ class TestTemplateSummary(BaseModel):
class TestTemplateInstantiate(BaseModel): class TestTemplateInstantiate(BaseModel):
"""Payload to create a real test from an existing template. """Payload to create a real test from an existing template."""
Optional override fields take precedence over the template values when provided.
"""
# template_id: uuid.UUID # template_id: uuid.UUID
template_id: uuid.UUID template_id: uuid.UUID
# technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001") # technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001")
technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001") technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001")
# User-editable overrides (if omitted the template value is used)
name: str | None = None
description: str | None = None
platform: str | None = None
procedure_text: str | None = None
tool_used: str | None = None
+3 -39
View File
@@ -9,8 +9,8 @@ import uuid
# Import datetime from datetime # Import datetime from datetime
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator # Import BaseModel, ConfigDict, field_validator from pydantic
from pydantic import BaseModel, ConfigDict, field_validator
# ── Username policy ───────────────────────────────────────────────── # ── Username policy ─────────────────────────────────────────────────
@@ -225,22 +225,7 @@ class PasswordChange(BaseModel):
return _validate_password_strength(v) return _validate_password_strength(v)
class UserPreferencesUpdate(BaseModel): # Define class UserOut
"""Payload for updating current user's notification preferences and Jira/Tempo settings."""
notification_preferences: dict | None = None
jira_account_id: str | None = None
# Personal Jira API token (Atlassian token) — write-only.
# Set to empty string "" to clear the token.
jira_api_token: str | None = None
# Atlassian email for Jira auth — overrides account email.
# Set to empty string "" to clear (falls back to account email).
jira_email: str | None = None
# Personal Tempo API token — write-only.
# Set to empty string "" to clear the token.
tempo_api_token: str | None = None
class UserOut(BaseModel): class UserOut(BaseModel):
"""Complete representation returned by the API.""" """Complete representation returned by the API."""
@@ -260,27 +245,6 @@ class UserOut(BaseModel):
created_at: datetime | None = None created_at: datetime | None = None
# Assign last_login = None # Assign last_login = None
last_login: datetime | None = None last_login: datetime | None = None
notification_preferences: dict | None = None
jira_account_id: str | None = None
jira_email: str | None = None
# Read from ORM but NEVER exposed in responses — used only to derive *_token_set flags.
jira_api_token: str | None = Field(default=None, exclude=True)
tempo_api_token: str | None = Field(default=None, exclude=True)
# True when the user has the respective token stored.
jira_token_set: bool = False
tempo_token_set: bool = False
# Assign model_config = ConfigDict(from_attributes=True) # Assign model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@model_validator(mode="after")
def _derive_token_set_flags(self) -> "UserOut":
"""Derive *_token_set booleans from the (excluded) raw token fields.
Uses @model_validator(mode='after') so Pydantic's Rust core calls it
during FastAPI response serialisation — model_validate() overrides are
bypassed by FastAPI's __pydantic_validator__.validate_python() path.
"""
self.jira_token_set = bool(self.jira_api_token)
self.tempo_token_set = bool(self.tempo_api_token)
return self
-91
View File
@@ -1,91 +0,0 @@
"""Pydantic schemas for Webhook endpoints."""
import ipaddress
import socket
import uuid
from datetime import datetime
from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict, field_validator
# RFC-5735 / RFC-1918 / RFC-3927 — ranges that must never be webhook targets
_BLOCKED_NETWORKS = [
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("169.254.0.0/16"), # link-local / AWS IMDS
ipaddress.ip_network("127.0.0.0/8"), # loopback
ipaddress.ip_network("::1/128"), # IPv6 loopback
ipaddress.ip_network("fc00::/7"), # IPv6 ULA
]
def _validate_webhook_url(url: str) -> str:
"""Reject URLs that point to private/reserved addresses (SSRF prevention)."""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
raise ValueError("Webhook URL must use http or https")
hostname = parsed.hostname
if not hostname:
raise ValueError("Webhook URL must include a hostname")
# Resolve hostname to IP(s) and reject any private/reserved address
try:
infos = socket.getaddrinfo(hostname, None)
for info in infos:
raw_ip = info[4][0]
try:
ip_obj = ipaddress.ip_address(raw_ip)
except ValueError:
continue
for network in _BLOCKED_NETWORKS:
if ip_obj in network:
raise ValueError(
f"Webhook URL resolves to a private/reserved address ({raw_ip}) "
"and cannot be used"
)
except OSError:
# DNS resolution failure — allow (will fail at dispatch time)
pass
return url
class WebhookConfigCreate(BaseModel):
name: str
url: str
secret: str | None = None
events: list[str] = []
is_active: bool = True
@field_validator("url")
@classmethod
def url_must_be_external(cls, v: str) -> str:
return _validate_webhook_url(v)
class WebhookConfigUpdate(BaseModel):
name: str | None = None
url: str | None = None
secret: str | None = None
events: list[str] | None = None
is_active: bool | None = None
@field_validator("url")
@classmethod
def url_must_be_external(cls, v: str | None) -> str | None:
if v is None:
return v
return _validate_webhook_url(v)
class WebhookConfigOut(BaseModel):
id: uuid.UUID
name: str
url: str
secret: str | None = None # masked on read
events: list[str]
is_active: bool
created_by: uuid.UUID | None = None
last_triggered_at: datetime | None = None
failure_count: int
created_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
-39
View File
@@ -1,39 +0,0 @@
"""Seed default decay policies."""
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.decay_policy import DecayPolicy
def seed_decay_policies(db: Session) -> None:
existing = db.query(DecayPolicy).filter(DecayPolicy.is_default == True).first()
if existing:
return
now = datetime.utcnow()
default_policy = DecayPolicy(
name="Default Decay Policy",
description="Standard: Fresh 90d, Aging 91-180d, Stale 181-365d.",
fresh_days=90, aging_days=180, stale_days=365,
default_validity_days=180, silent_threshold_days=30,
noisy_threshold_daily=100,
recency_weight=0.30, coverage_weight=0.30,
health_weight=0.25, diversity_weight=0.15,
is_default=True, is_active=True,
created_at=now, updated_at=now,
)
db.add(default_policy)
critical_policy = DecayPolicy(
name="Critical Techniques Policy",
description="Stricter: Fresh 60d, Aging 90d, Stale 180d.",
applies_to_tactic="initial-access",
fresh_days=60, aging_days=90, stale_days=180,
default_validity_days=90, silent_threshold_days=14,
noisy_threshold_daily=50,
recency_weight=0.35, coverage_weight=0.30,
health_weight=0.25, diversity_weight=0.10,
is_default=False, is_active=True,
created_at=now, updated_at=now,
)
db.add(critical_policy)
db.commit()
-155
View File
@@ -1,155 +0,0 @@
"""Phase 14: API Key service — create, list, revoke, authenticate."""
from __future__ import annotations
from datetime import datetime
from typing import List, Optional
from uuid import UUID
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError, DuplicateEntityError
from app.models.api_key import ApiKey, generate_raw_key, hash_key, key_prefix_display
from app.models.user import User
# ── Create ────────────────────────────────────────────────────────────────────
def create_api_key(
db: Session,
user_id: UUID,
name: str,
scopes: List[str],
description: Optional[str] = None,
expires_at: Optional[datetime] = None,
) -> tuple[ApiKey, str]:
"""
Create a new API key.
Returns ``(ApiKey, raw_key)`` — the raw_key must be shown to the user
immediately and is never retrievable again.
"""
raw_key = generate_raw_key()
key_hash = hash_key(raw_key)
prefix = key_prefix_display(raw_key)
# Detect accidental collision (astronomically unlikely)
if db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first():
raise DuplicateEntityError("ApiKey", "key_hash", "<collision>")
key = ApiKey(
name = name,
description = description,
key_prefix = prefix,
key_hash = key_hash,
user_id = user_id,
scopes = scopes,
expires_at = expires_at,
)
db.add(key)
db.commit()
db.refresh(key)
return key, raw_key
# ── Read ──────────────────────────────────────────────────────────────────────
def list_api_keys(
db: Session,
user_id: Optional[UUID] = None,
include_inactive: bool = False,
) -> List[ApiKey]:
q = db.query(ApiKey)
if user_id is not None:
q = q.filter(ApiKey.user_id == user_id)
if not include_inactive:
q = q.filter(ApiKey.is_active == True)
return q.order_by(ApiKey.created_at.desc()).all()
def get_api_key(db: Session, key_id: UUID, user_id: Optional[UUID] = None) -> ApiKey:
q = db.query(ApiKey).filter(ApiKey.id == key_id)
if user_id is not None:
q = q.filter(ApiKey.user_id == user_id)
key = q.first()
if not key:
raise EntityNotFoundError("ApiKey", str(key_id))
return key
# ── Update / Revoke ───────────────────────────────────────────────────────────
def update_api_key(
db: Session,
key_id: UUID,
user_id: Optional[UUID] = None,
*,
name: Optional[str] = None,
description: Optional[str] = None,
scopes: Optional[List[str]] = None,
expires_at: Optional[datetime] = None,
is_active: Optional[bool] = None,
) -> ApiKey:
key = get_api_key(db, key_id, user_id)
if name is not None:
key.name = name
if description is not None:
key.description = description
if scopes is not None:
key.scopes = scopes
if expires_at is not None:
key.expires_at = expires_at
if is_active is not None:
key.is_active = is_active
db.commit()
db.refresh(key)
return key
def revoke_api_key(
db: Session,
key_id: UUID,
user_id: Optional[UUID] = None,
) -> ApiKey:
"""Soft-revoke: set is_active = False."""
return update_api_key(db, key_id, user_id, is_active=False)
def delete_api_key(db: Session, key_id: UUID, user_id: Optional[UUID] = None) -> None:
"""Hard delete — use revoke instead for audit trail."""
key = get_api_key(db, key_id, user_id)
db.delete(key)
db.commit()
# ── Authentication ────────────────────────────────────────────────────────────
def authenticate_raw_key(db: Session, raw_key: str) -> Optional[User]:
"""
Verify a raw API key.
Returns the owning User if the key is valid, active, and not expired.
Updates ``last_used_at`` (throttled to once per request — always updates).
Returns None on any failure.
"""
h = hash_key(raw_key)
key: Optional[ApiKey] = db.query(ApiKey).filter(ApiKey.key_hash == h).first()
if key is None or not key.is_active:
return None
if key.expires_at and key.expires_at < datetime.utcnow():
return None
# Update last_used_at
key.last_used_at = datetime.utcnow()
db.commit()
user: Optional[User] = db.query(User).filter(User.id == key.user_id).first()
if user is None or not user.is_active:
return None
# Attach the key's scopes to the user instance so scope-enforcement
# dependencies can verify them without an additional DB query.
# _api_key_scopes=None means "full user access" (JWT path).
user._api_key_scopes = key.scopes or []
return user
+8 -10
View File
@@ -51,7 +51,8 @@ from sqlalchemy.orm import Session
# Import TestTemplate from app.models.test_template # Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate from app.models.test_template import TestTemplate
from app.models.technique import Technique
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__) # Assign logger = logging.getLogger(__name__)
@@ -135,7 +136,7 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB" f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
) )
# Iterate over entries — validate and extract each member individually # Iterate over entries
for member in entries: for member in entries:
# Assign target = (dest_path / member.filename).resolve() # Assign target = (dest_path / member.filename).resolve()
target = (dest_path / member.filename).resolve() target = (dest_path / member.filename).resolve()
@@ -146,7 +147,9 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
f"Zip Slip detected — member '{member.filename}' " f"Zip Slip detected — member '{member.filename}' "
f"resolves outside target directory" f"resolves outside target directory"
) )
zf.extract(member, dest)
# Call zf.extractall()
zf.extractall(dest)
# Define function _extract_zip # Define function _extract_zip
@@ -307,7 +310,6 @@ def import_atomic_red_team(db: Session) -> dict:
created = 0 created = 0
# Assign skipped = 0 # Assign skipped = 0
skipped = 0 skipped = 0
new_technique_ids: set[str] = set()
# Iterate over parsed_tests # Iterate over parsed_tests
for item in parsed_tests: for item in parsed_tests:
@@ -345,14 +347,10 @@ def import_atomic_red_team(db: Session) -> dict:
db.add(template) db.add(template)
# Call existing_ids.add() # Call existing_ids.add()
existing_ids.add(item["atomic_test_id"]) existing_ids.add(item["atomic_test_id"])
new_technique_ids.add(item["technique_id"]) # Assign created = 1
created += 1 created += 1
if new_technique_ids: # Commit all pending changes to the database
db.query(Technique).filter(
Technique.mitre_id.in_(new_technique_ids)
).update({"review_required": True}, synchronize_session=False)
db.commit() db.commit()
# Count distinct YAML files by technique_id # Count distinct YAML files by technique_id
-553
View File
@@ -1,553 +0,0 @@
"""Phase 10: Attack Path CRUD service."""
import logging
from datetime import datetime
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session, joinedload
from app.models.attack_path import (
AttackPath, AttackPathStep, AttackPathExecution,
AttackPathStepResult, TimelineEntry,
ExecutionStatus, StepResultStatus, TimelineActorSide, TimelineEntryType,
)
from app.domain.exceptions import EntityNotFoundError
from app.services import audit_service
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.utcnow()
# ── Attack Path CRUD ──────────────────────────────────────────────────────────
def create_attack_path(db: Session, data: dict, user_id: UUID) -> AttackPath:
path = AttackPath(
name=data["name"],
description=data.get("description"),
objective=data.get("objective"),
is_template=data.get("is_template", False),
threat_actor_id=data.get("threat_actor_id"),
tags=data.get("tags") or [],
created_by=user_id,
)
db.add(path)
db.commit()
db.refresh(path)
audit_service.log_action(
db, user_id, "ATTACK_PATH_CREATED", "attack_path", str(path.id),
details={"name": path.name, "is_template": path.is_template},
)
return path
def get_attack_path(db: Session, path_id: UUID) -> AttackPath:
path = (
db.query(AttackPath)
.options(joinedload(AttackPath.steps))
.filter(AttackPath.id == path_id)
.first()
)
if not path:
raise EntityNotFoundError("AttackPath", str(path_id))
return path
def list_attack_paths(
db: Session,
is_template: Optional[bool] = None,
technique_id: Optional[UUID] = None,
is_active: Optional[bool] = True,
) -> list[AttackPath]:
q = db.query(AttackPath)
if is_active is not None:
q = q.filter(AttackPath.is_active == is_active)
if is_template is not None:
q = q.filter(AttackPath.is_template == is_template)
if technique_id:
q = q.join(AttackPathStep).filter(AttackPathStep.technique_id == technique_id)
return q.order_by(AttackPath.created_at.desc()).all()
def update_attack_path(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPath:
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
if not path:
raise EntityNotFoundError("AttackPath", str(path_id))
for k, v in data.items():
if v is not None and hasattr(path, k):
setattr(path, k, v)
path.updated_at = _now()
db.commit()
db.refresh(path)
return path
def delete_attack_path(db: Session, path_id: UUID, user_id: UUID) -> None:
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
if not path:
raise EntityNotFoundError("AttackPath", str(path_id))
path.is_active = False
path.updated_at = _now()
db.commit()
audit_service.log_action(db, user_id, "ATTACK_PATH_ARCHIVED", "attack_path", str(path_id))
# ── Steps CRUD ────────────────────────────────────────────────────────────────
def add_step(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPathStep:
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
if not path:
raise EntityNotFoundError("AttackPath", str(path_id))
# Auto-assign order_index if not provided
if data.get("order_index") is None:
max_idx = db.query(AttackPathStep).filter(
AttackPathStep.attack_path_id == path_id
).count()
data["order_index"] = max_idx
step = AttackPathStep(
attack_path_id=path_id,
order_index=data.get("order_index", 0),
kill_chain_phase=data.get("kill_chain_phase"),
technique_id=data.get("technique_id"),
test_id=data.get("test_id"),
name=data.get("name"),
description=data.get("description"),
expected_detection=data.get("expected_detection", True),
notes=data.get("notes"),
)
db.add(step)
path.updated_at = _now()
db.commit()
db.refresh(step)
return step
def update_step(db: Session, step_id: UUID, data: dict, user_id: UUID) -> AttackPathStep:
step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
if not step:
raise EntityNotFoundError("AttackPathStep", str(step_id))
for k, v in data.items():
if v is not None and hasattr(step, k):
setattr(step, k, v)
db.commit()
db.refresh(step)
return step
def delete_step(db: Session, step_id: UUID, user_id: UUID) -> None:
step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
if not step:
raise EntityNotFoundError("AttackPathStep", str(step_id))
db.delete(step)
db.commit()
def reorder_steps(db: Session, path_id: UUID, step_ids: list[UUID], user_id: UUID) -> list[AttackPathStep]:
"""Reorder steps by providing ordered list of step IDs."""
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
if not path:
raise EntityNotFoundError("AttackPath", str(path_id))
for idx, step_id in enumerate(step_ids):
db.query(AttackPathStep).filter(
AttackPathStep.id == step_id,
AttackPathStep.attack_path_id == path_id,
).update({"order_index": idx})
path.updated_at = _now()
db.commit()
return (
db.query(AttackPathStep)
.filter(AttackPathStep.attack_path_id == path_id)
.order_by(AttackPathStep.order_index)
.all()
)
# ── Executions ────────────────────────────────────────────────────────────────
def create_execution(
db: Session, path_id: UUID, data: dict, user_id: UUID
) -> AttackPathExecution:
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
if not path:
raise EntityNotFoundError("AttackPath", str(path_id))
execution = AttackPathExecution(
attack_path_id=path_id,
status=ExecutionStatus.planned,
environment=data.get("environment"),
red_team_lead=data.get("red_team_lead"),
blue_team_lead=data.get("blue_team_lead"),
notes=data.get("notes"),
started_by=user_id,
)
db.add(execution)
db.flush()
# Pre-create pending step results for every step in the path
steps = (
db.query(AttackPathStep)
.filter(AttackPathStep.attack_path_id == path_id)
.order_by(AttackPathStep.order_index)
.all()
)
for step in steps:
result = AttackPathStepResult(
execution_id=execution.id,
step_id=step.id,
step_order=step.order_index,
status=StepResultStatus.pending,
)
db.add(result)
db.commit()
db.refresh(execution)
# Auto-add system timeline entry
_add_system_entry(
db, execution.id,
entry_type=TimelineEntryType.phase_transition,
content=f"Execution created for '{path.name}' with {len(steps)} steps.",
)
audit_service.log_action(
db, user_id, "ATTACK_PATH_EXECUTION_STARTED", "attack_path_execution",
str(execution.id),
details={"path_id": str(path_id), "path_name": path.name, "steps": len(steps)},
)
return execution
def get_execution(db: Session, execution_id: UUID) -> AttackPathExecution:
ex = (
db.query(AttackPathExecution)
.options(
joinedload(AttackPathExecution.step_results),
joinedload(AttackPathExecution.timeline),
)
.filter(AttackPathExecution.id == execution_id)
.first()
)
if not ex:
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
return ex
def list_executions(db: Session, path_id: UUID) -> list[AttackPathExecution]:
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
if not path:
raise EntityNotFoundError("AttackPath", str(path_id))
return (
db.query(AttackPathExecution)
.filter(AttackPathExecution.attack_path_id == path_id)
.order_by(AttackPathExecution.created_at.desc())
.all()
)
def start_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
if not ex:
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
if ex.status not in (ExecutionStatus.planned,):
from fastapi import HTTPException
raise HTTPException(400, "Execution is not in 'planned' state")
ex.status = ExecutionStatus.in_progress
ex.started_at = _now()
db.commit()
db.refresh(ex)
_add_system_entry(db, execution_id, TimelineEntryType.phase_transition,
"Execution started.", actor_id=user_id, actor_side=TimelineActorSide.system)
return ex
# ── Step Execution ────────────────────────────────────────────────────────────
def execute_step(
db: Session,
execution_id: UUID,
step_id: UUID,
data: dict,
user_id: UUID,
) -> AttackPathStepResult:
"""Record the result of executing one step."""
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
if not ex:
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
if ex.status not in (ExecutionStatus.in_progress, ExecutionStatus.planned):
from fastapi import HTTPException
raise HTTPException(400, "Execution must be in_progress to record step results")
# Auto-start if still planned
if ex.status == ExecutionStatus.planned:
ex.status = ExecutionStatus.in_progress
ex.started_at = _now()
result = (
db.query(AttackPathStepResult)
.filter(
AttackPathStepResult.execution_id == execution_id,
AttackPathStepResult.step_id == step_id,
)
.first()
)
if not result:
# Create on-the-fly if step was added after execution started
step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
if not step:
raise EntityNotFoundError("AttackPathStep", str(step_id))
result = AttackPathStepResult(
execution_id=execution_id,
step_id=step_id,
step_order=step.order_index,
)
db.add(result)
now = _now()
new_status = StepResultStatus(data["status"])
result.status = new_status
result.executed_by = user_id
result.executed_at = data.get("executed_at") or now
result.notes = data.get("notes")
result.evidence_ids = [str(e) for e in (data.get("evidence_ids") or [])]
result.detection_asset_id = data.get("detection_asset_id")
if new_status == StepResultStatus.detected:
result.detected_at = data.get("detected_at") or now
if result.executed_at:
delta = (result.detected_at - result.executed_at).total_seconds()
result.time_to_detect_seconds = max(0.0, delta)
db.commit()
db.refresh(result)
# Add timeline entry
step_obj = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
step_name = step_obj.name or (step_obj.kill_chain_phase or "Unknown step")
actor_side = TimelineActorSide.red if new_status != StepResultStatus.detected else TimelineActorSide.blue
entry_type = (
TimelineEntryType.detection if new_status == StepResultStatus.detected
else TimelineEntryType.action
)
content = (
f"Step '{step_name}' marked as {new_status.value}."
+ (f" Detected in {result.time_to_detect_seconds:.0f}s." if result.time_to_detect_seconds else "")
)
_add_system_entry(
db, execution_id, entry_type, content,
actor_id=user_id, actor_side=actor_side, step_id=step_id,
)
return result
# ── Completion & Metrics ──────────────────────────────────────────────────────
def complete_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
"""Mark execution complete and compute all kill-chain metrics."""
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
if not ex:
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
results = (
db.query(AttackPathStepResult)
.filter(AttackPathStepResult.execution_id == execution_id)
.order_by(AttackPathStepResult.step_order)
.all()
)
total = len(results)
detected = sum(1 for r in results if r.status == StepResultStatus.detected)
not_detected = sum(1 for r in results if r.status == StepResultStatus.not_detected)
skipped = sum(1 for r in results if r.status == StepResultStatus.skipped)
detection_rate = (detected / total) if total > 0 else 0.0
ttds = [r.time_to_detect_seconds for r in results
if r.time_to_detect_seconds is not None]
mttd = (sum(ttds) / len(ttds)) if ttds else None
# Furthest undetected step (highest order_index with not_detected status)
undetected = [r for r in results if r.status == StepResultStatus.not_detected]
furthest = max((r.step_order for r in undetected), default=None)
ex.status = ExecutionStatus.completed
ex.completed_at = _now()
ex.total_steps = total
ex.detected_steps = detected
ex.not_detected_steps = not_detected
ex.skipped_steps = skipped
ex.detection_rate = round(detection_rate, 4)
ex.mttd_seconds = round(mttd, 1) if mttd is not None else None
ex.furthest_undetected_step = furthest
db.commit()
db.refresh(ex)
_add_system_entry(
db, execution_id, TimelineEntryType.phase_transition,
f"Execution completed. Detection rate: {detection_rate:.0%}. "
f"Detected {detected}/{total} steps. "
+ (f"MTTD: {mttd:.0f}s." if mttd else ""),
actor_id=user_id, actor_side=TimelineActorSide.system,
)
audit_service.log_action(
db, user_id, "ATTACK_PATH_EXECUTION_COMPLETED", "attack_path_execution",
str(execution_id),
details={"detection_rate": detection_rate, "mttd_seconds": mttd,
"detected": detected, "total": total},
)
return ex
def abort_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
if not ex:
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
ex.status = ExecutionStatus.aborted
ex.completed_at = _now()
db.commit()
db.refresh(ex)
_add_system_entry(db, execution_id, TimelineEntryType.flag, "Execution aborted.",
actor_id=user_id, actor_side=TimelineActorSide.system)
return ex
# ── Timeline ──────────────────────────────────────────────────────────────────
def add_timeline_entry(
db: Session, execution_id: UUID, data: dict, user_id: UUID
) -> TimelineEntry:
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
if not ex:
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
entry = TimelineEntry(
execution_id=execution_id,
step_id=data.get("step_id"),
timestamp=data.get("timestamp") or _now(),
actor_side=TimelineActorSide(data["actor_side"]),
actor_id=user_id,
entry_type=TimelineEntryType(data["entry_type"]),
content=data["content"],
extra=data.get("extra"),
)
db.add(entry)
db.commit()
db.refresh(entry)
return entry
def get_timeline(db: Session, execution_id: UUID) -> list[TimelineEntry]:
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
if not ex:
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
return (
db.query(TimelineEntry)
.filter(TimelineEntry.execution_id == execution_id)
.order_by(TimelineEntry.timestamp.asc())
.all()
)
# ── Kill-Chain Metrics ────────────────────────────────────────────────────────
def get_kill_chain_metrics(db: Session, execution_id: UUID) -> dict:
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
if not ex:
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
results = (
db.query(AttackPathStepResult)
.filter(AttackPathStepResult.execution_id == execution_id)
.order_by(AttackPathStepResult.step_order)
.all()
)
step_breakdown = []
phase_detected: dict[str, list] = {}
for r in results:
step = db.query(AttackPathStep).filter(AttackPathStep.id == r.step_id).first()
phase = step.kill_chain_phase if step else None
entry = {
"step_id": str(r.step_id),
"step_order": r.step_order,
"step_name": step.name if step else None,
"kill_chain_phase": phase,
"status": r.status.value if hasattr(r.status, "value") else r.status,
"executed_at": r.executed_at.isoformat() if r.executed_at else None,
"detected_at": r.detected_at.isoformat() if r.detected_at else None,
"time_to_detect_seconds": r.time_to_detect_seconds,
"detection_asset_id": str(r.detection_asset_id) if r.detection_asset_id else None,
}
step_breakdown.append(entry)
if phase:
phase_detected.setdefault(phase, []).append(
r.status == StepResultStatus.detected
)
phase_summary = {
phase: {
"total": len(v),
"detected": sum(v),
"detection_rate": round(sum(v) / len(v), 3) if v else 0.0,
}
for phase, v in phase_detected.items()
}
# Furthest undetected phase
furthest_undetected_phase = None
if ex.furthest_undetected_step is not None:
for r in reversed(results):
if r.step_order == ex.furthest_undetected_step:
step = db.query(AttackPathStep).filter(AttackPathStep.id == r.step_id).first()
if step:
furthest_undetected_phase = step.kill_chain_phase
break
return {
"execution_id": str(execution_id),
"total_steps": ex.total_steps or len(results),
"detected_steps": ex.detected_steps or 0,
"not_detected_steps": ex.not_detected_steps or 0,
"skipped_steps": ex.skipped_steps or 0,
"detection_rate": ex.detection_rate or 0.0,
"mttd_seconds": ex.mttd_seconds,
"furthest_undetected_step": ex.furthest_undetected_step,
"furthest_undetected_phase": furthest_undetected_phase,
"step_breakdown": step_breakdown,
"phase_summary": phase_summary,
}
# ── Helper ────────────────────────────────────────────────────────────────────
def _add_system_entry(
db: Session,
execution_id: UUID,
entry_type: TimelineEntryType,
content: str,
actor_id: Optional[UUID] = None,
actor_side: TimelineActorSide = TimelineActorSide.system,
step_id: Optional[UUID] = None,
) -> None:
entry = TimelineEntry(
execution_id=execution_id,
step_id=step_id,
timestamp=_now(),
actor_side=actor_side,
actor_id=actor_id,
entry_type=entry_type,
content=content,
)
db.add(entry)
db.commit()
@@ -1,798 +0,0 @@
"""ATT&CK Evaluations importer — fetches real CrowdStrike detection results
from MITRE Engenuity's public API and seeds the platform with validated tests.
Data source
-----------
https://evals.mitre.org/api/
- /participants/ → list of vendors + rounds they completed
- /results/?participant=crowdstrike&domain=ENTERPRISE
→ per-substep detection results per adversary
Detection level mapping (MITRE → Aegis)
---------------------------------------
Technique / Specific Behavior → detected (correctly identified ATT&CK technique)
Tactic → partially_detected (behavior noted but not categorized)
General / IOC / MSSP → partially_detected (anomaly detected, not ATT&CK-mapped)
Telemetry → partially_detected (raw data only — marginal detection)
None / N/A → not_detected
All imported tests are created in ``in_review`` state so Blue Leads must
confirm each result before it counts as real coverage for the organisation.
Important caveats stored in every test's description
------------------------------------------------------
"Source: MITRE ATT&CK Evaluation (Round N — Adversary). Results reflect
CrowdStrike Falcon in a controlled lab environment, NOT this organisation's
deployment. Validate detection in your own environment before approving."
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Any
import requests
from sqlalchemy.orm import Session
from app.models.enums import TestState, TestResult
from app.models.evaluation_import import EvaluationImport
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
from app.services.status_service import recalculate_technique_status
logger = logging.getLogger(__name__)
_BASE = "https://evals.mitre.org"
_TIMEOUT = 30 # seconds per HTTP call
_VENDOR = "crowdstrike"
_DOMAIN = "ENTERPRISE"
# Browser-like headers to bypass Cloudflare bot protection on evals.mitre.org
_HEADERS = {
"User-Agent": (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/124.0.0.0 Safari/537.36"
),
"Accept": "application/json, text/plain, */*",
"Accept-Language": "en-US,en;q=0.9",
"Referer": "https://evals.mitre.org/",
"Origin": "https://evals.mitre.org",
}
# ---------------------------------------------------------------------------
# Fallback: hardcoded public CrowdStrike ENTERPRISE rounds
# Used when evals.mitre.org API is unreachable (Cloudflare 502, outage, etc.)
#
# Names use the EXACT slugs the live API returns (hyphens, not underscores).
# Verified from live API response on 2025-06-05.
# CrowdStrike did NOT participate in Round 6 (OilRig) — not included.
# ---------------------------------------------------------------------------
_FALLBACK_ROUNDS: list[dict[str, Any]] = [
{
"name": "apt3",
"display_name": "APT3",
"eval_round": 1,
"domain": "ENTERPRISE",
"status": "PUBLIC",
},
{
"name": "apt29",
"display_name": "APT29",
"eval_round": 2,
"domain": "ENTERPRISE",
"status": "PUBLIC",
},
{
"name": "carbanak-fin7",
"display_name": "Carbanak+FIN7",
"eval_round": 3,
"domain": "ENTERPRISE",
"status": "PUBLIC",
},
{
"name": "wizard-spider-sandworm",
"display_name": "Wizard Spider + Sandworm",
"eval_round": 4,
"domain": "ENTERPRISE",
"status": "PUBLIC",
},
{
"name": "turla",
"display_name": "Turla",
"eval_round": 5,
"domain": "ENTERPRISE",
"status": "PUBLIC",
},
{
"name": "er7",
"display_name": "Enterprise 2025",
"eval_round": 7,
"domain": "ENTERPRISE",
"status": "PUBLIC",
},
]
# Detection type → quality score (higher = better)
_DETECTION_SCORE: dict[str, int] = {
"none": 0,
"n/a": 0,
"telemetry": 1,
"mssp": 2,
"general": 2,
"ioc": 2,
"tactic": 3,
"technique": 4,
"specific behavior": 4,
}
def _score(detection_type: str) -> int:
key = (detection_type or "").lower().strip()
for pattern, score in _DETECTION_SCORE.items():
if pattern in key:
return score
return 0
def _score_to_result(score: int) -> TestResult:
if score >= 4:
return TestResult.detected
if score >= 1:
return TestResult.partially_detected
return TestResult.not_detected
# ---------------------------------------------------------------------------
# Public API helpers
# ---------------------------------------------------------------------------
def fetch_rounds_with_status() -> dict[str, Any]:
"""Fetch CrowdStrike ENTERPRISE rounds and report whether the live API was reachable.
Returns::
{
"rounds": [{"name": ..., "display_name": ..., "eval_round": ...}, ...],
"api_reachable": True | False,
"api_error": None | "<error message>",
}
"""
try:
session = requests.Session()
session.headers.update(_HEADERS)
resp = session.get(f"{_BASE}/api/participants/", timeout=_TIMEOUT)
resp.raise_for_status()
participants = resp.json()
except Exception as exc:
logger.warning(
"evals.mitre.org API unreachable (%s) — using hardcoded fallback round list.",
exc,
)
return {
"rounds": list(_FALLBACK_ROUNDS),
"api_reachable": False,
"api_error": str(exc),
}
crowdstrike = next(
(p for p in participants if p.get("name", "").lower() == _VENDOR),
None,
)
if not crowdstrike:
logger.warning("Vendor '%s' not found in live data — using fallback.", _VENDOR)
return {
"rounds": list(_FALLBACK_ROUNDS),
"api_reachable": True, # API was reachable, vendor just wasn't listed
"api_error": f"Vendor '{_VENDOR}' not found in participants list",
}
rounds = [
adv
for adv in crowdstrike.get("adversaries_completed", [])
if adv.get("domain", "").upper() == _DOMAIN
and adv.get("status", "").upper() == "PUBLIC"
]
rounds.sort(key=lambda x: x.get("eval_round", 0))
return {
"rounds": rounds if rounds else list(_FALLBACK_ROUNDS),
"api_reachable": True,
"api_error": None,
}
def fetch_available_rounds() -> list[dict[str, Any]]:
"""Return all evaluation rounds CrowdStrike has completed (ENTERPRISE only).
Each dict has: name, display_name, eval_round.
Sorted by eval_round ascending.
Falls back to ``_FALLBACK_ROUNDS`` if the live API is unreachable.
"""
return fetch_rounds_with_status()["rounds"]
def get_latest_round() -> dict[str, Any]:
"""Return the most recent PUBLIC ENTERPRISE round CrowdStrike participated in."""
rounds = fetch_available_rounds()
if not rounds:
raise ValueError("No public Enterprise evaluation rounds found for CrowdStrike")
return rounds[-1]
def fetch_results_for_adversary(adversary_name: str) -> list[dict[str, Any]]:
"""Fetch all per-substep detection results for a specific adversary round.
Returns a flat list of substep dicts, each containing:
technique_id, technique_name, tactic_id, best_score, detection_type, note.
"""
url = f"{_BASE}/api/results/?participant={_VENDOR}&domain={_DOMAIN}"
try:
session = requests.Session()
session.headers.update(_HEADERS)
resp = session.get(url, timeout=_TIMEOUT)
resp.raise_for_status()
data = resp.json()
except Exception as exc:
logger.error("Failed to fetch ATT&CK Evaluations results: %s", exc)
raise
# The results endpoint returns a LIST of vendor objects:
# [{"name": "crowdstrike", "adversaries": [{"Adversary_Name": "apt3", ...}, ...]}, ...]
# (not a dict — hence the explicit vendor lookup below)
if isinstance(data, list):
vendor_entry = next(
(v for v in data if isinstance(v, dict) and v.get("name", "").lower() == _VENDOR),
None,
)
if not vendor_entry:
raise ValueError(
f"Vendor '{_VENDOR}' not found in results response. "
f"Available vendors: {[v.get('name') for v in data if isinstance(v, dict)]}"
)
adversaries = vendor_entry.get("adversaries", [])
else:
# Fallback for legacy dict-shaped response (just in case API changes again)
adversaries = data.get("adversaries", [])
target = next(
(a for a in adversaries if a.get("Adversary_Name", "").lower() == adversary_name.lower()),
None,
)
if not target:
raise ValueError(
f"Adversary '{adversary_name}' not found in results. "
f"Available: {[a.get('Adversary_Name') for a in adversaries]}"
)
substeps: list[dict[str, Any]] = []
scenarios = target.get("Detections_By_Step", {})
for scenario_name, scenario_data in scenarios.items():
for step in scenario_data.get("Steps", []):
step_num = step.get("Step_Num", "")
step_name = step.get("Step_Name", "")
# Strip HTML tags from the Step.Description narrative
step_desc_raw = step.get("Description") or ""
step_description = re.sub(r"<[^>]+>", " ", step_desc_raw)
step_description = re.sub(r"\s+", " ", step_description).strip()
for substep in step.get("Substeps", []):
# Prefer sub-technique over technique
sub = substep.get("Subtechnique") or {}
tech = substep.get("Technique") or {}
tactic = substep.get("Tactic") or {}
technique_id = (
sub.get("Subtechnique_Id")
or tech.get("Technique_Id")
or ""
).strip()
technique_name = (
sub.get("Subtechnique_Name")
or tech.get("Technique_Name")
or "Unknown"
).strip()
if not technique_id:
continue
detections = substep.get("Detections", [])
best_score = 0
best_type = "None"
best_note = ""
for det in detections:
dtype = det.get("Detection_Type", "None")
s = _score(dtype)
if s > best_score:
best_score = s
best_type = dtype
best_note = det.get("Detection_Note", "")
# Collect all unique data sources from screenshots across all detections
data_sources: list[str] = sorted({
src
for det in detections
for sc in det.get("Screenshots", [])
for src in sc.get("Data_Sources", [])
})
substeps.append(
{
"technique_id": technique_id,
"technique_name": technique_name,
"tactic_id": tactic.get("Tactic_Id", ""),
"tactic_name": tactic.get("Tactic_Name", ""),
"best_score": best_score,
"detection_type": best_type,
"note": best_note,
# Enrichment fields from the API
"scenario_name": scenario_name,
"step_num": step_num,
"step_name": step_name,
"step_description": step_description,
"substep_ref": substep.get("Substep", ""),
"criteria": (substep.get("Criteria") or "").strip(),
"data_sources": data_sources,
}
)
return substeps
def _aggregate_by_technique(substeps: list[dict]) -> dict[str, dict]:
"""Aggregate substep results per technique.
- Deduplicates substeps by (substep_ref, criteria) — prevents duplicates
that arise when adversaries with multiple scenarios (e.g. Wizard Spider +
Sandworm) repeat the same substep across a "combined" replay scenario.
- Groups unique occurrences by scenario_name so the narrative can show
"Wizard Spider scenario" vs "Sandworm scenario" separately.
- Tracks best detection score across all unique substeps.
"""
by_technique: dict[str, dict] = {}
for sub in substeps:
tid = sub["technique_id"]
if tid not in by_technique:
by_technique[tid] = {
**sub,
"occurrences": [], # flat list of unique occurrences
"_seen_keys": set(), # (substep_ref, criteria) dedup set
}
agg = by_technique[tid]
# Deduplication key: same substep_ref + same criteria text = duplicate
dedup_key = (sub["substep_ref"], sub["criteria"])
if dedup_key in agg["_seen_keys"]:
continue
agg["_seen_keys"].add(dedup_key)
agg["occurrences"].append({
"scenario_name": sub["scenario_name"],
"substep_ref": sub["substep_ref"],
"step_num": sub["step_num"],
"step_name": sub["step_name"],
"step_description": sub["step_description"],
"criteria": sub["criteria"],
"data_sources": sub["data_sources"],
"detection_type": sub["detection_type"],
"best_score": sub["best_score"],
"note": sub["note"],
})
# Promote best detection score
if sub["best_score"] > agg["best_score"]:
agg["best_score"] = sub["best_score"]
agg["detection_type"] = sub["detection_type"]
agg["note"] = sub["note"]
agg["tactic_id"] = sub["tactic_id"]
agg["tactic_name"] = sub["tactic_name"]
# Clean up internal dedup sets before returning
for agg in by_technique.values():
agg.pop("_seen_keys", None)
return by_technique
def _group_occurrences_by_scenario(occurrences: list[dict]) -> dict[str, list[dict]]:
"""Group a technique's occurrences by scenario, preserving insertion order."""
grouped: dict[str, list[dict]] = {}
for occ in occurrences:
sc = occ.get("scenario_name", "Scenario_1")
grouped.setdefault(sc, []).append(occ)
return grouped
def _build_procedure_text(agg: dict, adversary_display: str, eval_round: int) -> str:
"""Build a rich attack-path narrative for the Test.procedure_text field.
Groups substeps by scenario so adversaries with multiple threat groups
(e.g. Wizard Spider + Sandworm with 3 scenarios) are clearly separated.
Includes Step.Description narrative for context.
"""
occurrences = agg.get("occurrences", [])
if not occurrences:
return (
f"MITRE ATT&CK Evaluation simulation using {adversary_display} TTPs. "
f"See evaluation report at https://evals.mitre.org for full details."
)
lines: list[str] = [f"ATT&CK Evaluation R{eval_round}{adversary_display}", ""]
grouped = _group_occurrences_by_scenario(occurrences)
scenario_count = len(grouped)
for sc_name, sc_occs in grouped.items():
# Scenario header — only shown when there are multiple scenarios
if scenario_count > 1:
idx = sc_name.replace("Scenario_", "Scenario ")
lines.append(f"=== {idx} ===")
# Within each scenario, group by step to emit description once per step
seen_step_descs: set = set()
for occ in sc_occs:
step_num = occ.get("step_num", "")
step_name = occ.get("step_name", "")
step_desc = occ.get("step_description", "")
# Use (step_num or step_name) as dedup key for descriptions
step_key = str(step_num) if step_num else step_name
if step_key and step_key not in seen_step_descs:
seen_step_descs.add(step_key)
header = f"Step {step_num}{step_name}:" if step_num else f"{step_name}:"
lines.append(header)
if step_desc:
truncated = step_desc[:450] + ("" if len(step_desc) > 450 else "")
lines.append(truncated)
ref = occ.get("substep_ref", "")
criteria = occ.get("criteria", "")
det = occ.get("detection_type", "")
if criteria:
tag = f" [{ref}]" if ref else ""
det_tag = f" [{det}]" if det and det.lower() not in ("none", "") else ""
lines.append(f"{tag}{det_tag} {criteria}")
lines.append("")
return "\n".join(lines).rstrip()
def _build_description(agg: dict, adversary_display: str, eval_round: int) -> str:
"""Build Test.description with source metadata, detection guidance and warning.
The 'criteria' field from the MITRE API describes what each substep does AND
what should be detected, so it doubles as blue-team detection guidance.
"""
occurrences = agg.get("occurrences", [])
# Collect all unique data sources across every unique occurrence
all_data_sources: list[str] = sorted({
src
for occ in occurrences
for src in occ.get("data_sources", [])
})
header = (
f"Source: MITRE ATT&CK Evaluation — Round {eval_round} ({adversary_display}).\n"
f"Vendor: CrowdStrike Falcon.\n"
f"Detection type achieved: {agg['detection_type']}."
)
ds_section = ""
if all_data_sources:
ds_section = "\n\nData sources observed:\n" + "\n".join(
f"{ds}" for ds in all_data_sources
)
# Detection guidance — what criteria were observed (blue team can use these as IOCs)
det_lines: list[str] = []
grouped = _group_occurrences_by_scenario(occurrences)
for sc_name, sc_occs in grouped.items():
scenario_label = f"[{sc_name}] " if len(grouped) > 1 else ""
for occ in sc_occs:
ref = occ.get("substep_ref", "")
step_name = occ.get("step_name", "")
criteria = occ.get("criteria", "")
det_type = occ.get("detection_type", "")
if criteria:
label = f"[{ref}]" if ref else ""
step_label = f" ({step_name})" if step_name else ""
det_label = f"{det_type}" if det_type and det_type.lower() not in ("none", "") else ""
det_lines.append(f" {scenario_label}{label}{step_label}{det_label}: {criteria}")
det_section = ""
if det_lines:
det_section = "\n\nDetection criteria (what to look for):\n" + "\n".join(det_lines)
warning = (
f"\n\n⚠️ IMPORTANT: These results reflect CrowdStrike Falcon performance in a "
f"controlled MITRE lab environment against a simulated {adversary_display} "
f"adversary. They do NOT represent your organisation's actual detection "
f"capability. Validate in your own environment before approving."
)
note_section = f"\n\nMITRE note: {agg['note']}" if agg.get("note") else ""
return header + ds_section + det_section + warning + note_section
def _build_red_summary(agg: dict, adversary_display: str, eval_round: int) -> str:
"""Build the Red Team summary for the Test.red_summary field."""
occurrences = agg.get("occurrences", [])
lines = [
f"MITRE ATT&CK Evaluation — Round {eval_round} ({adversary_display})",
"Vendor: CrowdStrike Falcon",
f"Best detection level: {agg['detection_type']}",
f"Tactic: {agg['tactic_name']} ({agg['tactic_id']})",
f"Unique substeps: {len(occurrences)}",
]
if occurrences:
lines.append("")
grouped = _group_occurrences_by_scenario(occurrences)
for sc_name, sc_occs in grouped.items():
if len(grouped) > 1:
lines.append(f"{sc_name}:")
for occ in sc_occs:
ref = occ.get("substep_ref", "")
criteria = occ.get("criteria", "")
step_name = occ.get("step_name", "")
det = occ.get("detection_type", "")
if criteria:
tag = f" [{ref}]" if ref else ""
step_tag = f" {step_name}" if step_name else ""
det_tag = f" [{det}]" if det and det.lower() not in ("none", "") else ""
lines.append(f"{tag}{step_tag}{det_tag} {criteria}")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Main import function
# ---------------------------------------------------------------------------
def import_evaluation_round(
db: Session,
adversary_name: str,
adversary_display: str,
eval_round: int,
current_user: User,
) -> dict[str, Any]:
"""Import a single ATT&CK Evaluation round for CrowdStrike into the platform.
Creates one Test per unique technique with the best detection result
observed across all substeps for that technique. All tests land in
``in_review`` state — Blue Leads must confirm before they count as coverage.
Returns a summary dict: created, skipped, techniques_covered.
Raises if the round was already imported (idempotency guard).
"""
# Idempotency — refuse duplicate imports
existing = (
db.query(EvaluationImport)
.filter(
EvaluationImport.adversary_name == adversary_name.lower(),
EvaluationImport.status == "completed",
)
.first()
)
if existing:
raise ValueError(
f"Round '{adversary_display}' (round {eval_round}) was already imported "
f"on {existing.imported_at.date()}. Re-import is not allowed."
)
# Fetch and aggregate substep results
substeps = fetch_results_for_adversary(adversary_name)
by_technique = _aggregate_by_technique(substeps)
created = 0
skipped = 0
affected_technique_ids: set = set()
for mitre_id, agg in by_technique.items():
# Look up the technique in our DB
technique = (
db.query(Technique)
.filter(Technique.mitre_id == mitre_id.upper())
.first()
)
if technique is None:
skipped += 1
continue
detection_result = _score_to_result(agg["best_score"])
description = _build_description(agg, adversary_display, eval_round)
red_summary = _build_red_summary(agg, adversary_display, eval_round)
procedure_text = _build_procedure_text(agg, adversary_display, eval_round)
test = Test(
technique_id=technique.id,
name=f"[EVAL R{eval_round}] {adversary_display}{technique.name}",
description=description,
platform=None,
procedure_text=procedure_text,
created_by=current_user.id,
state=TestState.in_review,
attack_success=True,
red_summary=red_summary,
red_validation_status="approved",
red_validated_by=current_user.id,
red_validated_at=datetime.utcnow(),
detection_result=detection_result,
blue_validation_status=None,
execution_date=datetime.utcnow(),
created_at=datetime.utcnow(),
)
db.add(test)
db.flush()
log_action(
db,
user_id=current_user.id,
action="eval_import_test",
entity_type="test",
entity_id=test.id,
details={
"adversary": adversary_name,
"eval_round": eval_round,
"mitre_id": mitre_id,
"detection_type": agg["detection_type"],
},
)
affected_technique_ids.add(technique.id)
created += 1
# Recalculate coverage for all touched techniques
for tech_id in affected_technique_ids:
tech = db.query(Technique).filter(Technique.id == tech_id).first()
if tech:
recalculate_technique_status(db, tech)
# Record the import
record = EvaluationImport(
id=uuid.uuid4(),
adversary_name=adversary_name.lower(),
adversary_display=adversary_display,
eval_round=eval_round,
imported_at=datetime.utcnow(),
imported_by=current_user.id,
tests_created=created,
techniques_covered=len(affected_technique_ids),
status="completed",
notes=f"Skipped {skipped} techniques not found in local DB.",
)
db.add(record)
db.commit()
logger.info(
"ATT&CK Evaluation import complete — round %d (%s): %d tests created, %d skipped",
eval_round, adversary_display, created, skipped,
)
return {
"created": created,
"skipped": skipped,
"techniques_covered": len(affected_technique_ids),
"adversary": adversary_display,
"eval_round": eval_round,
}
# ---------------------------------------------------------------------------
# New-round check (used by the weekly scheduler)
# ---------------------------------------------------------------------------
def check_for_new_round(db: Session) -> dict[str, Any]:
"""Check if a new evaluation round is available that hasn't been imported yet.
Returns:
{"new_round_available": bool, "latest_round": dict | None, "already_imported": bool}
"""
try:
latest = get_latest_round()
except Exception as exc:
logger.warning("Could not check for new ATT&CK Evaluation round: %s", exc)
return {"new_round_available": False, "latest_round": None, "error": str(exc)}
already = (
db.query(EvaluationImport)
.filter(
EvaluationImport.adversary_name == latest["name"].lower(),
EvaluationImport.status == "completed",
)
.first()
)
return {
"new_round_available": already is None,
"already_imported": already is not None,
"latest_round": {
"name": latest["name"],
"display_name": latest.get("display_name", latest["name"]),
"eval_round": latest["eval_round"],
},
}
# ---------------------------------------------------------------------------
# Re-enrich existing tests with richer API data
# ---------------------------------------------------------------------------
def re_enrich_evaluation_round(
db: Session,
adversary_name: str,
adversary_display: str,
eval_round: int,
current_user: User,
) -> dict[str, Any]:
"""Update procedure_text / description / red_summary on already-imported tests
for a given round using the enriched API data (attack path, criteria, data sources).
This is non-destructive — it only updates the three narrative fields and does
not change detection results, state, or validation status.
"""
# Fetch & aggregate (same flow as import)
substeps = fetch_results_for_adversary(adversary_name)
by_technique = _aggregate_by_technique(substeps)
updated = 0
skipped = 0
for mitre_id, agg in by_technique.items():
technique = (
db.query(Technique)
.filter(Technique.mitre_id == mitre_id.upper())
.first()
)
if technique is None:
skipped += 1
continue
# Find the existing test for this round + technique
existing_test = (
db.query(Test)
.filter(
Test.technique_id == technique.id,
Test.name.like(f"[EVAL R{eval_round}]%"),
)
.first()
)
if not existing_test:
skipped += 1
continue
existing_test.description = _build_description(agg, adversary_display, eval_round)
existing_test.red_summary = _build_red_summary(agg, adversary_display, eval_round)
existing_test.procedure_text = _build_procedure_text(agg, adversary_display, eval_round)
updated += 1
db.commit()
logger.info(
"Re-enrichment complete — round %d (%s): %d tests updated, %d skipped",
eval_round, adversary_display, updated, skipped,
)
return {
"updated": updated,
"skipped": skipped,
"adversary": adversary_display,
"eval_round": eval_round,
"message": (
f"Re-enriched {updated} tests for {adversary_display} (Round {eval_round}) "
f"with attack path, criteria and data sources from MITRE API."
),
}
-5
View File
@@ -104,16 +104,11 @@ def log_action(
user_agent=ua or None, user_agent=ua or None,
# Keyword argument: session_id # Keyword argument: session_id
session_id=session_id, session_id=session_id,
timestamp=datetime.now(timezone.utc),
) )
# Stage new record(s) for database insertion # Stage new record(s) for database insertion
db.add(entry) db.add(entry)
# Flush changes to DB without committing the transaction # Flush changes to DB without committing the transaction
db.flush() db.flush()
# Reload from DB so the timestamp is in DB-stable format before hashing.
# Without this, a round-trip through the DB (e.g. refresh after commit) can
# return a timestamp with different precision/timezone, causing hash mismatch.
db.refresh(entry)
# Assign entry.integrity_hash = compute_integrity_hash(entry) # Assign entry.integrity_hash = compute_integrity_hash(entry)
entry.integrity_hash = compute_integrity_hash(entry) entry.integrity_hash = compute_integrity_hash(entry)
# Return entry # Return entry
+1 -4
View File
@@ -65,10 +65,7 @@ def change_password(
if not verify_password(current_password, user.hashed_password): if not verify_password(current_password, user.hashed_password):
# Raise BusinessRuleViolation # Raise BusinessRuleViolation
raise BusinessRuleViolation("Current password is incorrect") raise BusinessRuleViolation("Current password is incorrect")
if verify_password(new_password, user.hashed_password): # Assign user.hashed_password = hash_password(new_password)
raise BusinessRuleViolation(
"New password must be different from the current password"
)
user.hashed_password = hash_password(new_password) user.hashed_password = hash_password(new_password)
# Assign user.must_change_password = False # Assign user.must_change_password = False
user.must_change_password = False user.must_change_password = False
+8 -16
View File
@@ -53,9 +53,11 @@ from sqlalchemy.orm import Session
# Import DataSource from app.models.data_source # Import DataSource from app.models.data_source
from app.models.data_source import DataSource from app.models.data_source import DataSource
from app.models.technique import Technique
# Import TestTemplate from app.models.test_template # Import TestTemplate from app.models.test_template
from app.models.test_template import TestTemplate from app.models.test_template import TestTemplate
# Import log_action from app.services.audit_service
from app.services.audit_service import log_action from app.services.audit_service import log_action
# Assign logger = logging.getLogger(__name__) # Assign logger = logging.getLogger(__name__)
@@ -102,15 +104,10 @@ def _download_zip(url: str = CALDERA_ZIP_URL) -> bytes:
# Define function _extract_zip # Define function _extract_zip
def _extract_zip(zip_bytes: bytes, dest: str) -> Path: def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
"""Extract *zip_bytes* into *dest* and return abilities dir.""" """Extract *zip_bytes* into *dest* and return abilities dir."""
dest_path = Path(dest).resolve() # Open context manager
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
for member in zf.infolist(): # Call zf.extractall()
target = (dest_path / member.filename).resolve() zf.extractall(dest)
if not target.is_relative_to(dest_path):
raise ValueError(
f"Zip Slip detected — '{member.filename}' resolves outside target directory"
)
zf.extract(member, dest)
# Assign abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities" # Assign abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities"
abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities" abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities"
# Check: not abilities_dir.is_dir() # Check: not abilities_dir.is_dir()
@@ -371,7 +368,6 @@ def sync(db: Session) -> dict:
created = 0 created = 0
# Assign skipped = 0 # Assign skipped = 0
skipped = 0 skipped = 0
new_technique_ids: set[str] = set()
# Iterate over parsed # Iterate over parsed
for item in parsed: for item in parsed:
@@ -409,14 +405,10 @@ def sync(db: Session) -> dict:
db.add(template) db.add(template)
# Call existing_ids.add() # Call existing_ids.add()
existing_ids.add(item["atomic_test_id"]) existing_ids.add(item["atomic_test_id"])
new_technique_ids.add(item["mitre_technique_id"]) # Assign created = 1
created += 1 created += 1
if new_technique_ids: # Commit all pending changes to the database
db.query(Technique).filter(
Technique.mitre_id.in_(new_technique_ids)
).update({"review_required": True}, synchronize_session=False)
db.commit() db.commit()
# Assign summary = { # Assign summary = {
+4 -92
View File
@@ -34,7 +34,6 @@ from app.models.test import Test
# Import calculate_next_run from app.services.campaign_scheduler_service # Import calculate_next_run from app.services.campaign_scheduler_service
from app.services.campaign_scheduler_service import calculate_next_run from app.services.campaign_scheduler_service import calculate_next_run
from app.services.status_service import recalculate_technique_status
# Import from app.services.campaign_service # Import from app.services.campaign_service
from app.services.campaign_service import ( from app.services.campaign_service import (
@@ -121,7 +120,7 @@ def serialize_campaign(db: Session, campaign: Campaign) -> dict:
"threat_actor_name": actor.name if actor else None, "threat_actor_name": actor.name if actor else None,
# Literal argument value # Literal argument value
"created_by": str(campaign.created_by) if campaign.created_by else None, "created_by": str(campaign.created_by) if campaign.created_by else None,
"start_date": campaign.start_date.isoformat() if campaign.start_date else None, # Literal argument value
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None, "scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
# Literal argument value # Literal argument value
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None, "completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
@@ -172,7 +171,7 @@ def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None, "threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
# Literal argument value # Literal argument value
"threat_actor_name": actor.name if actor else None, "threat_actor_name": actor.name if actor else None,
"start_date": campaign.start_date.isoformat() if campaign.start_date else None, # Literal argument value
"target_platform": campaign.target_platform, "target_platform": campaign.target_platform,
# Literal argument value # Literal argument value
"tags": campaign.tags or [], "tags": campaign.tags or [],
@@ -275,7 +274,6 @@ def create_campaign(
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
# Entry: scheduled_at # Entry: scheduled_at
scheduled_at: Optional[str] = None, scheduled_at: Optional[str] = None,
start_date: Optional[str] = None,
) -> dict: ) -> dict:
"""Create a new campaign. Does not commit; caller commits.""" """Create a new campaign. Does not commit; caller commits."""
# Assign campaign = Campaign( # Assign campaign = Campaign(
@@ -296,7 +294,6 @@ def create_campaign(
created_by=creator_id, created_by=creator_id,
# Keyword argument: scheduled_at # Keyword argument: scheduled_at
scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None, scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None,
start_date=datetime.fromisoformat(start_date) if start_date else None,
) )
# Stage new record(s) for database insertion # Stage new record(s) for database insertion
db.add(campaign) db.add(campaign)
@@ -361,8 +358,6 @@ def update_campaign(
if "scheduled_at" in fields and fields["scheduled_at"]: if "scheduled_at" in fields and fields["scheduled_at"]:
# Assign fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"]) # Assign fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"]) fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
if "start_date" in fields and fields["start_date"]:
fields["start_date"] = datetime.fromisoformat(fields["start_date"])
# Iterate over fields.items() # Iterate over fields.items()
for field, value in fields.items(): for field, value in fields.items():
@@ -536,29 +531,11 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s
# Assign dep.depends_on = None # Assign dep.depends_on = None
dep.depends_on = None dep.depends_on = None
# Keep a reference to the underlying test before deleting the join record # Mark record for deletion on next commit
test_id = ct.test_id
technique_id = None
test_obj = db.query(Test).filter(Test.id == test_id).first()
if test_obj:
technique_id = test_obj.technique_id
db.delete(ct) db.delete(ct)
# Flush changes to DB without committing the transaction # Flush changes to DB without committing the transaction
db.flush() db.flush()
# Also delete the actual test record (it was created for this campaign)
if test_obj:
db.delete(test_obj)
db.flush()
# Recalculate technique status_global so coverage metrics stay consistent
if technique_id:
technique = db.query(Technique).filter(Technique.id == technique_id).first()
if technique:
recalculate_technique_status(db, technique)
db.flush()
# Define function activate_campaign # Define function activate_campaign
def activate_campaign(db: Session, campaign_id: str) -> Campaign: def activate_campaign(db: Session, campaign_id: str) -> Campaign:
@@ -720,72 +697,7 @@ def schedule_campaign(
return campaign return campaign
def delete_campaign( # Define function get_campaign_history
db: Session,
campaign_id: str,
*,
deleter_id: uuid.UUID,
deleter_role: str,
delete_tests: bool = False,
) -> None:
"""Delete a campaign.
Only draft campaigns can be deleted unless the caller is admin.
If delete_tests=True, the associated Test objects are also deleted.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status != "draft" and deleter_role != "admin":
raise BusinessRuleViolation("Only draft campaigns can be deleted")
if str(campaign.created_by) != str(deleter_id) and deleter_role != "admin":
raise PermissionViolation("Only the creator or admin can delete this campaign")
# Collect test IDs before removing associations
campaign_tests = (
db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).all()
)
test_ids = [ct.test_id for ct in campaign_tests]
# Remove CampaignTest join rows (clear depends_on refs first to avoid FK cycles)
for ct in campaign_tests:
ct.depends_on = None
db.flush()
for ct in campaign_tests:
db.delete(ct)
db.flush()
# Optionally delete the associated tests
if delete_tests:
affected_technique_ids: set = set()
for test_id in test_ids:
test = db.query(Test).filter(Test.id == test_id).first()
if test:
if test.technique_id:
affected_technique_ids.add(test.technique_id)
db.delete(test)
db.flush()
# Recalculate status_global for every affected technique so the
# coverage metrics stay consistent after test deletion.
for tech_id in affected_technique_ids:
technique = db.query(Technique).filter(Technique.id == tech_id).first()
if technique:
recalculate_technique_status(db, technique)
db.flush()
# Null-out parent_campaign_id on child campaigns to avoid FK violation
db.query(Campaign).filter(Campaign.parent_campaign_id == campaign.id).update(
{"parent_campaign_id": None}
)
db.flush()
db.delete(campaign)
db.flush()
def get_campaign_history(db: Session, campaign_id: str) -> dict: def get_campaign_history(db: Session, campaign_id: str) -> dict:
"""List all child campaigns (execution history) of a recurring campaign. """List all child campaigns (execution history) of a recurring campaign.
+2 -6
View File
@@ -6,9 +6,9 @@ threat actors, and progress calculation.
# Import logging # Import logging
import logging import logging
# Import uuid
import uuid import uuid
from datetime import datetime
from typing import Optional
# Import Session from sqlalchemy.orm # Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -179,8 +179,6 @@ def generate_campaign_from_threat_actor(
actor_id: uuid.UUID, actor_id: uuid.UUID,
# Entry: user # Entry: user
user: User, user: User,
*,
start_date: Optional[datetime] = None,
) -> Campaign: ) -> Campaign:
"""Auto-generate a campaign from a threat actor's uncovered techniques. """Auto-generate a campaign from a threat actor's uncovered techniques.
@@ -238,7 +236,6 @@ def generate_campaign_from_threat_actor(
created_by=user.id, created_by=user.id,
# Keyword argument: tags # Keyword argument: tags
tags=[actor.name, "auto-generated"], tags=[actor.name, "auto-generated"],
start_date=start_date,
) )
# Stage new record(s) for database insertion # Stage new record(s) for database insertion
db.add(campaign) db.add(campaign)
@@ -291,7 +288,6 @@ def generate_campaign_from_threat_actor(
created_by=user.id, created_by=user.id,
# Keyword argument: state # Keyword argument: state
state=TestState.draft, state=TestState.draft,
created_at=datetime.utcnow(),
) )
# Stage new record(s) for database insertion # Stage new record(s) for database insertion
db.add(test) db.add(test)
File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More