Compare commits
209 Commits
1f19bd8432
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 986682aad1 | |||
| f8824291a2 | |||
| 443a04befb | |||
| 88c2af472e | |||
| 8ba9790625 | |||
| af5b6e1cff | |||
| dcd4bebc92 | |||
| f54dc0d342 | |||
| acc9092baa | |||
| 6d3617938e | |||
| 709a810775 | |||
| cf33c69f95 | |||
| 392ce162dc | |||
| 5e8b5ee33c | |||
| ebf47c6142 | |||
| 0e2e9d0bb0 | |||
| 9472fe91fa | |||
| 675870b469 | |||
| 92f4bdcdce | |||
| 3ec51524d6 | |||
| 7ded48bdb7 | |||
| 6ca37f743f | |||
| cea518b33c | |||
| 22293804ab | |||
| 64cc438bcc | |||
| 8fea0c1ada | |||
| 98fddccd32 | |||
| f6d33638fd | |||
| 0001b33594 | |||
| a7725ba519 | |||
| 0c9f3051b4 | |||
| e2861a08bc | |||
| 467afc334d | |||
| b630cd3210 | |||
| c0cecab797 | |||
| 51d86e5436 | |||
| 8515b8de17 | |||
| b037500b7c | |||
| 6f835c8501 | |||
| 46ade20d14 | |||
| 5f54396cb6 | |||
| f4289249b8 | |||
| 6ab61c8ace | |||
| 725cf3406e | |||
| 564eb406aa | |||
| bf3add9b09 | |||
| 840e1ac0bb | |||
| b5f924abe0 | |||
| 6b1f5d690a | |||
| 27c67a5f76 | |||
| f605b52d89 | |||
| af864ed735 | |||
| 92e8ff7aff | |||
| 9fb84fa65c | |||
| cafd7db94b | |||
| 80991b2f59 | |||
| 200ef88d67 | |||
| fd39658f5d | |||
| 3b552dbe4e | |||
| 2ecb950770 | |||
| c141e5bb67 | |||
| 6c343bd7a1 | |||
| 643e65fbe5 | |||
| de0db3cec8 | |||
| cd2fe5aad6 | |||
| a1415a379f | |||
| 611689f3ce | |||
| 7d7d351ca8 | |||
| 33e6a1a3f4 | |||
| 399628e20a | |||
| 0531e7e73e | |||
| a1c67419e7 | |||
| e5e1779208 | |||
| f3c07fdaf1 | |||
| 6ceb4125a0 | |||
| 96b6d683a4 | |||
| 14aea87675 | |||
| 2624585e05 | |||
| aae032445c | |||
| 0eeca61de2 | |||
| 31c644d23f | |||
| 70d5274448 | |||
| 17f9d1078f | |||
| a33a13eca8 | |||
| b438dd0af0 | |||
| 5c5398683a | |||
| 62f5542ef2 | |||
| e82af44a6c | |||
| 546b5692f0 | |||
| ebe8eecb94 | |||
| aa3e08f9b6 | |||
| 6e1f51e0ff | |||
| abcc948513 | |||
| 1fd5e37bd0 | |||
| 865a7b6e0f | |||
| f0fe8be005 | |||
| 0e51af9cf7 | |||
| 6021f0801c | |||
| 98e0f27172 | |||
| 6af37030f4 | |||
| 6f1f09d74d | |||
| 857c793f31 | |||
| b60e5562c0 | |||
| a238b05ca8 | |||
| 5e748dbf80 | |||
| 104ea5c65b | |||
| 5a5f8a01e7 | |||
| 791407d02f | |||
| 940e575a65 | |||
| d761b46590 | |||
| 662d38423e | |||
| fa994801a5 | |||
| 6a4a153d59 | |||
| 069728a010 | |||
| 14e9b8b43a | |||
| d125b0c8e4 | |||
| 9a30c11413 | |||
| ab68542120 | |||
| 4b1ea7b9d2 | |||
| 4c3773de34 | |||
| ea453feea0 | |||
| 64cda5e608 | |||
| 8d71ee1da2 | |||
| 1c27e31101 | |||
| 366fc2170c | |||
| 7594a09b20 | |||
| 60e2a31046 | |||
| f0bd4b7e7d | |||
| 8d64905739 | |||
| 965ff96433 | |||
| 785b5b44a3 | |||
| 424eef70c5 | |||
| 322b6fcb62 | |||
| cf19a18810 | |||
| a48bd3c475 | |||
| 117600acea | |||
| 8b48716766 | |||
| 1a974265de | |||
| cd718512ad | |||
| 18df271d07 | |||
| c0a0e1aa00 | |||
| e6c188c782 | |||
| eac6d10639 | |||
| 7e9a5a35f6 | |||
| 76a76607b0 | |||
| dd9d817d5d | |||
| cd9bdc7399 | |||
| b5a81b69ed | |||
| aaff54f432 | |||
| e2b8e7e207 | |||
| 3b20911c93 | |||
| 851100d8ec | |||
| 06e45b098c | |||
| 8968382731 | |||
| e9a3985a1f | |||
| 3d8f445d1b | |||
| 8577588a21 | |||
| 001cefb882 | |||
| 1ce427db88 | |||
| 95b4c4ea65 | |||
| 174b7e8d24 | |||
| d307039a41 | |||
| aaa344ab79 | |||
| 1f08eb014b | |||
| e08d8a9beb | |||
| e4342e1c3f | |||
| d3831b8ed9 | |||
| 87af1735ce | |||
| f3109644cb | |||
| ee1c773073 | |||
| a294300052 | |||
| f72287984c | |||
| 8a51f98631 | |||
| 8e7ee1494e | |||
| 6f24d340d2 | |||
| da89a9ae51 | |||
| 6665efd276 | |||
| 37f2d6daa6 | |||
| f2787bf860 | |||
| 21ed939569 | |||
| 3c077f971e | |||
| d9292fb3ff | |||
| 6fad769c13 | |||
| d1443d1ffa | |||
| 9d0cb6d67d | |||
| 3f174e7d89 | |||
| 6c3f00f6e6 | |||
| ed579fb8f7 | |||
| 612dec7a93 | |||
| a138c7a8ed | |||
| 6c4517c7f3 | |||
| dd1f0e472f | |||
| 8c73377571 | |||
| ab50bcd90e | |||
| b78593ca10 | |||
| 0b81580b44 | |||
| 95b46a95a8 | |||
| b493f92f75 | |||
| 61c26ddd0f | |||
| 634abc289b | |||
| 519ddfb7a0 | |||
| 7009fcabbf | |||
| b714b466c8 | |||
| ca17675253 | |||
| 9552ba2f14 | |||
| c172a8af00 | |||
| 83b74c5262 | |||
| d6fce0bc4e | |||
| cdb5055193 |
@@ -1,189 +0,0 @@
|
||||
---
|
||||
description: Aegis backend Clean Architecture rules. Apply when working on any backend Python file under backend/app/ or backend/tests/.
|
||||
globs: backend/**/*.py
|
||||
---
|
||||
|
||||
# Aegis — Clean Modular Monolith Architecture
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
Aegis follows a **Clean Architecture** pattern inside a modular monolith. The backend has four layers with strict dependency rules:
|
||||
|
||||
```
|
||||
Presentation → Application → Domain ← Infrastructure
|
||||
```
|
||||
|
||||
**The golden rule:** dependencies only point towards the Domain layer. Infrastructure implements the ports (interfaces) defined in Domain.
|
||||
|
||||
## Layer Structure and Rules
|
||||
|
||||
### Domain Layer (`backend/app/domain/`)
|
||||
|
||||
The innermost layer. **ZERO** imports from FastAPI, SQLAlchemy, Pydantic, or any framework.
|
||||
|
||||
| Directory | Purpose |
|
||||
|-----------|---------|
|
||||
| `domain/enums.py` | Canonical domain enums (TechniqueStatus, TestState, TeamSide, TestResult) |
|
||||
| `domain/errors.py` | Exception hierarchy (DomainError → EntityNotFoundError, InvalidStateTransition, etc.) |
|
||||
| `domain/exceptions.py` | Backward-compatible re-exports from errors.py |
|
||||
| `domain/test_entity.py` | TestEntity — pure state machine with domain events |
|
||||
| `domain/entities/` | Rich domain entities (TechniqueEntity, etc.) with business behavior |
|
||||
| `domain/value_objects/` | Immutable value types (MitreId, ScoringWeights) |
|
||||
| `domain/ports/repositories/` | Protocol interfaces defining data access contracts |
|
||||
| `domain/ports/services/` | Protocol interfaces for external capabilities (storage, events) |
|
||||
| `domain/unit_of_work.py` | UnitOfWork wrapping SQLAlchemy session |
|
||||
|
||||
**NEVER** import from `app.models`, `app.routers`, `app.infrastructure`, `fastapi`, or `sqlalchemy` inside `domain/`.
|
||||
|
||||
### Application Layer (`backend/app/application/` — future)
|
||||
|
||||
Use case orchestrators. Depends only on Domain.
|
||||
|
||||
| Directory | Purpose |
|
||||
|-----------|---------|
|
||||
| `application/use_cases/` | One class per business operation |
|
||||
| `application/dto/` | Plain data containers for use case input/output |
|
||||
| `application/interfaces/` | Application-level contracts (UnitOfWork protocol) |
|
||||
|
||||
### Infrastructure Layer (`backend/app/infrastructure/`)
|
||||
|
||||
Implements ports defined in Domain. Depends on Domain and Application.
|
||||
|
||||
| Directory | Purpose |
|
||||
|-----------|---------|
|
||||
| `infrastructure/redis_client.py` | Redis connection singleton |
|
||||
| `infrastructure/persistence/repositories/` | SQLAlchemy implementations of repository ports |
|
||||
| `infrastructure/persistence/mappers/` | ORM model ↔ domain entity converters |
|
||||
|
||||
### Presentation Layer (routers, schemas, dependencies)
|
||||
|
||||
HTTP boundary. Depends on Application and Domain (for exceptions).
|
||||
|
||||
| Directory | Purpose |
|
||||
|-----------|---------|
|
||||
| `routers/` | FastAPI routers — HTTP mapping only |
|
||||
| `schemas/` | Pydantic request/response models |
|
||||
| `dependencies/` | FastAPI `Depends()` wiring (auth, repositories) |
|
||||
| `middleware/` | Error handler mapping domain exceptions → HTTP responses |
|
||||
|
||||
## Import Rules (Strict)
|
||||
|
||||
| From \ To | domain/ | application/ | infrastructure/ | presentation/ |
|
||||
|-----------|---------|-------------|----------------|--------------|
|
||||
| **domain/** | Self only | FORBIDDEN | FORBIDDEN | FORBIDDEN |
|
||||
| **application/** | ALLOWED | Self only | FORBIDDEN | FORBIDDEN |
|
||||
| **infrastructure/** | ALLOWED (ports) | ALLOWED (UoW) | Self only | FORBIDDEN |
|
||||
| **presentation/** | ALLOWED (exceptions) | ALLOWED (use cases) | ALLOWED (wiring in dependencies/) | Self only |
|
||||
|
||||
## How to Add a New Feature
|
||||
|
||||
### 1. Start from the Domain
|
||||
|
||||
- Define or reuse domain entities in `domain/entities/`
|
||||
- Add value objects if needed in `domain/value_objects/`
|
||||
- Define repository port if a new aggregate root in `domain/ports/repositories/`
|
||||
- Domain exceptions go in `domain/errors.py`
|
||||
- Business rules live IN the entity, not in services or routers
|
||||
|
||||
### 2. Implement Infrastructure
|
||||
|
||||
- Create SQLAlchemy repository implementation in `infrastructure/persistence/repositories/`
|
||||
- Create mapper if converting between ORM model and domain entity
|
||||
- Repository does NOT call `commit()` — only `flush()`
|
||||
- Transaction control belongs to the Unit of Work
|
||||
|
||||
### 3. Wire in Presentation
|
||||
|
||||
- Add FastAPI `Depends()` provider in `dependencies/repositories.py`
|
||||
- Keep routers thin: parse request → call service/use case → return response
|
||||
- Map domain exceptions to HTTP via the error handler middleware (automatic)
|
||||
|
||||
### 4. Tests (Mandatory)
|
||||
|
||||
Every change MUST include tests:
|
||||
- **Domain entities/value objects**: pure unit tests, no DB, no mocking frameworks
|
||||
- **Repositories**: integration tests using the `db` fixture from conftest
|
||||
- **Routers**: API tests using the `client` fixture
|
||||
- At least one success test + one failure/edge-case test per behavior
|
||||
|
||||
Before committing, run: `scripts/agent_validate_backend.sh`
|
||||
|
||||
## Existing Patterns to Follow
|
||||
|
||||
### Domain Entity Pattern (see `domain/test_entity.py`)
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class SomeEntity:
|
||||
id: uuid.UUID
|
||||
# fields...
|
||||
_events: list[DomainEvent] = field(default_factory=list, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, model: Any) -> "SomeEntity":
|
||||
"""Build from SQLAlchemy model."""
|
||||
...
|
||||
|
||||
def apply_to(self, model: Any) -> None:
|
||||
"""Copy mutable fields back onto the ORM model."""
|
||||
...
|
||||
|
||||
def some_business_method(self) -> None:
|
||||
"""Business logic lives HERE, not in services."""
|
||||
...
|
||||
self._events.append(DomainEvent("something_happened"))
|
||||
```
|
||||
|
||||
### Repository Port Pattern (Protocol)
|
||||
|
||||
```python
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
@runtime_checkable
|
||||
class SomeRepository(Protocol):
|
||||
def find_by_id(self, id: uuid.UUID) -> SomeEntity | None: ...
|
||||
def save(self, entity: SomeEntity) -> SomeEntity: ...
|
||||
```
|
||||
|
||||
### Repository Implementation Pattern
|
||||
|
||||
```python
|
||||
class SASomeRepository:
|
||||
def __init__(self, session: Session) -> None:
|
||||
self._session = session
|
||||
|
||||
def find_by_id(self, id: uuid.UUID) -> SomeEntity | None:
|
||||
model = self._session.query(SomeModel).filter(SomeModel.id == id).first()
|
||||
return SomeMapper.to_entity(model) if model else None
|
||||
|
||||
def save(self, entity: SomeEntity) -> SomeEntity:
|
||||
model = SomeMapper.to_model(entity)
|
||||
merged = self._session.merge(model)
|
||||
self._session.flush() # NO commit — UoW does that
|
||||
return SomeMapper.to_entity(merged)
|
||||
```
|
||||
|
||||
### Error Handling (automatic via middleware)
|
||||
|
||||
Services raise domain exceptions → middleware maps to HTTP:
|
||||
- `EntityNotFoundError` → 404
|
||||
- `DuplicateEntityError` → 409
|
||||
- `InvalidStateTransition` → 400
|
||||
- `BusinessRuleViolation` → 400
|
||||
- `PermissionViolation` → 403
|
||||
|
||||
### Coexistence Strategy
|
||||
|
||||
Old code (direct `db.query()` in routers) and new code (repositories) coexist. Migration is incremental:
|
||||
1. New endpoints use repositories
|
||||
2. Existing endpoints are migrated one at a time
|
||||
3. Both access the same DB, same session, same tables
|
||||
|
||||
## Key Conventions
|
||||
|
||||
- **Enums**: canonical source is `domain/enums.py`, `models/enums.py` re-exports
|
||||
- **Exceptions**: raise from `domain/errors.py`, never raise `HTTPException` from services
|
||||
- **Commits**: only via `UnitOfWork.commit()` or at the router level, never inside services/repos
|
||||
- **IDs**: UUID everywhere (primary keys, foreign keys)
|
||||
- **Tests**: SQLite in-memory for unit/integration, PostgreSQL in CI
|
||||
- **Validation**: Pydantic in schemas (presentation), domain rules in entities (domain)
|
||||
@@ -0,0 +1,71 @@
|
||||
name: Snyk Security Scan
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
schedule:
|
||||
- cron: '0 6 * * 1' # Weekly on Monday 06:00 UTC
|
||||
|
||||
jobs:
|
||||
snyk-backend:
|
||||
name: Python vulnerabilities (backend)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install backend dependencies
|
||||
run: pip install -r backend/requirements-lock.txt
|
||||
|
||||
- name: Snyk — scan Python packages
|
||||
uses: snyk/actions/python@master
|
||||
env:
|
||||
SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
|
||||
with:
|
||||
args: --file=backend/requirements-lock.txt --severity-threshold=high
|
||||
continue-on-error: true # report without blocking CI during initial cleanup
|
||||
|
||||
snyk-frontend:
|
||||
name: npm vulnerabilities (frontend)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install frontend dependencies
|
||||
run: npm ci
|
||||
working-directory: frontend
|
||||
|
||||
- name: Snyk — scan npm packages
|
||||
uses: snyk/actions/node@master
|
||||
env:
|
||||
SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
|
||||
with:
|
||||
args: --file=frontend/package.json --severity-threshold=high
|
||||
continue-on-error: true
|
||||
|
||||
snyk-docker-backend:
|
||||
name: Docker image vulnerabilities (backend)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Build backend image for scanning
|
||||
run: docker build -t aegis-backend:scan backend/
|
||||
|
||||
- name: Snyk — scan Docker image
|
||||
uses: snyk/actions/docker@master
|
||||
env:
|
||||
SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
|
||||
with:
|
||||
image: aegis-backend:scan
|
||||
args: --severity-threshold=high
|
||||
continue-on-error: true
|
||||
@@ -60,3 +60,12 @@ Thumbs.db
|
||||
|
||||
# Local development
|
||||
*.local
|
||||
|
||||
# Documentation drafts — never commit, delivered directly in chat
|
||||
docs/confluence/
|
||||
docs/drafts/
|
||||
|
||||
# Editor / AI assistant working files — never commit
|
||||
.claude/
|
||||
.cursor/
|
||||
CLAUDE.md
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
skips:
|
||||
- B311
|
||||
+5
-1
@@ -3,10 +3,14 @@ FROM python:3.11-slim
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
RUN apt-get update && apt-get upgrade -y && apt-get install -y \
|
||||
gcc \
|
||||
libpq-dev \
|
||||
curl \
|
||||
pkg-config \
|
||||
libxml2-dev \
|
||||
libxmlsec1-dev \
|
||||
libxmlsec1-openssl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements first for better caching
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Phase 6.1: webhook_configs table.
|
||||
|
||||
Revision ID: b031phase6
|
||||
Revises: b030phase5
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b031phase6"
|
||||
down_revision: Union[str, None] = "b030phase5"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"webhook_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("url", sa.Text, nullable=False),
|
||||
sa.Column("secret", sa.String(256), nullable=True),
|
||||
sa.Column("events", postgresql.JSONB, nullable=False, server_default="[]"),
|
||||
sa.Column("is_active", sa.Boolean, nullable=False, server_default="true"),
|
||||
sa.Column("created_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("last_triggered_at", sa.DateTime, nullable=True),
|
||||
sa.Column("failure_count", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("webhook_configs")
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Phase 7.2: user notification_preferences and jira_account_id columns.
|
||||
|
||||
Revision ID: b032phase7
|
||||
Revises: b031phase6
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b032phase7"
|
||||
down_revision: Union[str, None] = "b031phase6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_DEFAULT_PREFS = '{"email_on_test_validated": true, "email_on_campaign_completed": true, "email_on_new_mitre_techniques": false, "in_app_all": true}'
|
||||
|
||||
def _column_names(table: str) -> set[str]:
|
||||
bind = op.get_bind()
|
||||
insp = sa.inspect(bind)
|
||||
return {c["name"] for c in insp.get_columns(table)}
|
||||
|
||||
def upgrade() -> None:
|
||||
user_cols = _column_names("users")
|
||||
if "notification_preferences" not in user_cols:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("notification_preferences", postgresql.JSONB, nullable=True, server_default=_DEFAULT_PREFS),
|
||||
)
|
||||
if "jira_account_id" not in user_cols:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("jira_account_id", sa.String(100), nullable=True),
|
||||
)
|
||||
|
||||
def downgrade() -> None:
|
||||
user_cols = _column_names("users")
|
||||
if "jira_account_id" in user_cols:
|
||||
op.drop_column("users", "jira_account_id")
|
||||
if "notification_preferences" in user_cols:
|
||||
op.drop_column("users", "notification_preferences")
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Phase 8: system_configs table for runtime configuration.
|
||||
|
||||
Revision ID: b033syscfg
|
||||
Revises: b032phase7
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b033syscfg"
|
||||
down_revision: Union[str, None] = "b032phase7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _table_exists(name: str) -> bool:
|
||||
bind = op.get_bind()
|
||||
insp = sa.inspect(bind)
|
||||
return name in insp.get_table_names()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if not _table_exists("system_configs"):
|
||||
op.create_table(
|
||||
"system_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("key", sa.String(200), unique=True, nullable=False),
|
||||
sa.Column("value", sa.Text, nullable=True),
|
||||
sa.Column("description", sa.String(500), nullable=True),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
)
|
||||
op.create_index("ix_system_configs_key", "system_configs", ["key"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if _table_exists("system_configs"):
|
||||
op.drop_index("ix_system_configs_key", table_name="system_configs")
|
||||
op.drop_table("system_configs")
|
||||
@@ -0,0 +1,174 @@
|
||||
"""Phase 8: Detection Lifecycle Management tables.
|
||||
|
||||
Revision ID: b034dlm
|
||||
Revises: b033syscfg
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b034dlm"
|
||||
down_revision: Union[str, None] = "b033syscfg"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _table_exists(name: str) -> bool:
|
||||
bind = op.get_bind()
|
||||
insp = sa.inspect(bind)
|
||||
return name in insp.get_table_names()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if not _table_exists("detection_assets"):
|
||||
op.create_table(
|
||||
"detection_assets",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("name", sa.String(500), nullable=False),
|
||||
sa.Column("description", sa.Text),
|
||||
sa.Column("asset_type", sa.String(50), nullable=False),
|
||||
sa.Column("platform", sa.String(100)),
|
||||
sa.Column("rule_content", sa.Text),
|
||||
sa.Column("rule_language", sa.String(50)),
|
||||
sa.Column("rule_repository_url", sa.Text),
|
||||
sa.Column("rule_file_path", sa.String(500)),
|
||||
sa.Column("rule_version", sa.String(50)),
|
||||
sa.Column("rule_hash", sa.String(64)),
|
||||
sa.Column("last_rule_change_at", sa.DateTime),
|
||||
sa.Column("log_source_name", sa.String(200)),
|
||||
sa.Column("log_source_version", sa.String(50)),
|
||||
sa.Column("log_source_config", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("infrastructure_hash", sa.String(64)),
|
||||
sa.Column("infrastructure_details", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("health_status", sa.String(20), server_default="untested", nullable=False),
|
||||
sa.Column("last_alert_at", sa.DateTime),
|
||||
sa.Column("alert_count_30d", sa.Integer, server_default="0"),
|
||||
sa.Column("false_positive_rate", sa.Float),
|
||||
sa.Column("expected_alert_frequency", sa.String(50)),
|
||||
sa.Column("owner_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("backup_owner_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("team", sa.String(100)),
|
||||
sa.Column("is_active", sa.Boolean, server_default="true", nullable=False),
|
||||
sa.Column("tags", postgresql.JSONB, server_default="[]"),
|
||||
sa.Column("asset_metadata", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("created_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
op.create_index("ix_detection_assets_platform", "detection_assets", ["platform"])
|
||||
op.create_index("ix_detection_assets_health_status", "detection_assets", ["health_status"])
|
||||
op.create_index("ix_detection_assets_owner_id", "detection_assets", ["owner_id"])
|
||||
|
||||
if not _table_exists("detection_technique_mappings"):
|
||||
op.create_table(
|
||||
"detection_technique_mappings",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("detection_asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("coverage_type", sa.String(50), server_default="detect"),
|
||||
sa.Column("confidence_level", sa.String(20), server_default="medium"),
|
||||
sa.Column("notes", sa.Text),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
op.create_index("ix_detection_technique_mappings_technique_id", "detection_technique_mappings", ["technique_id"])
|
||||
op.create_index("ix_detection_technique_mappings_asset_id", "detection_technique_mappings", ["detection_asset_id"])
|
||||
|
||||
if not _table_exists("detection_validations"):
|
||||
op.create_table(
|
||||
"detection_validations",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("detection_asset_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="SET NULL")),
|
||||
sa.Column("test_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("tests.id", ondelete="SET NULL")),
|
||||
sa.Column("validated_at", sa.DateTime),
|
||||
sa.Column("expires_at", sa.DateTime, nullable=False),
|
||||
sa.Column("is_valid", sa.Boolean, server_default="true", nullable=False),
|
||||
sa.Column("validation_result", sa.String(50)),
|
||||
sa.Column("validation_method", sa.String(100)),
|
||||
sa.Column("rule_hash_at_validation", sa.String(64)),
|
||||
sa.Column("log_source_version_at_validation", sa.String(50)),
|
||||
sa.Column("infrastructure_hash_at_validation", sa.String(64)),
|
||||
sa.Column("environment_snapshot", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("invalidated_at", sa.DateTime),
|
||||
sa.Column("invalidation_reason", sa.String(50)),
|
||||
sa.Column("invalidation_details", sa.Text),
|
||||
sa.Column("invalidated_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("validated_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=False),
|
||||
sa.Column("integrity_hash", sa.String(64)),
|
||||
sa.Column("notes", sa.Text),
|
||||
sa.Column("evidence_ids", postgresql.JSONB, server_default="[]"),
|
||||
)
|
||||
op.create_index("ix_detection_validations_asset_id_valid", "detection_validations", ["detection_asset_id", "is_valid"])
|
||||
op.create_index("ix_detection_validations_expires_at", "detection_validations", ["expires_at"])
|
||||
|
||||
if not _table_exists("technique_confidence_scores"):
|
||||
op.create_table(
|
||||
"technique_confidence_scores",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("technique_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False, unique=True),
|
||||
sa.Column("confidence_level", sa.String(20), server_default="unknown"),
|
||||
sa.Column("confidence_score", sa.Float, server_default="0.0"),
|
||||
sa.Column("detection_count", sa.Integer, server_default="0"),
|
||||
sa.Column("valid_detection_count", sa.Integer, server_default="0"),
|
||||
sa.Column("last_validated_at", sa.DateTime),
|
||||
sa.Column("next_validation_due", sa.DateTime),
|
||||
sa.Column("last_recalculated_at", sa.DateTime),
|
||||
sa.Column("recency_factor", sa.Float, server_default="0.0"),
|
||||
sa.Column("coverage_factor", sa.Float, server_default="0.0"),
|
||||
sa.Column("health_factor", sa.Float, server_default="0.0"),
|
||||
sa.Column("diversity_factor", sa.Float, server_default="0.0"),
|
||||
sa.Column("score_breakdown", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("risk_factors", postgresql.JSONB, server_default="[]"),
|
||||
sa.Column("updated_at", sa.DateTime),
|
||||
)
|
||||
op.create_index("ix_technique_confidence_scores_technique_id", "technique_confidence_scores", ["technique_id"])
|
||||
op.create_index("ix_technique_confidence_scores_confidence_level", "technique_confidence_scores", ["confidence_level"])
|
||||
|
||||
if not _table_exists("infrastructure_change_logs"):
|
||||
op.create_table(
|
||||
"infrastructure_change_logs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("change_type", sa.String(100), nullable=False),
|
||||
sa.Column("description", sa.Text, nullable=False),
|
||||
sa.Column("affected_platforms", postgresql.JSONB, server_default="[]"),
|
||||
sa.Column("affected_log_sources", postgresql.JSONB, server_default="[]"),
|
||||
sa.Column("change_date", sa.DateTime),
|
||||
sa.Column("reported_by", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL")),
|
||||
sa.Column("auto_invalidate", sa.Boolean, server_default="true"),
|
||||
sa.Column("invalidated_count", sa.Integer, server_default="0"),
|
||||
sa.Column("change_metadata", postgresql.JSONB, server_default="{}"),
|
||||
sa.Column("created_at", sa.DateTime),
|
||||
)
|
||||
op.create_index("ix_infrastructure_change_logs_change_date", "infrastructure_change_logs", ["change_date"])
|
||||
|
||||
if not _table_exists("decay_policies"):
|
||||
op.create_table(
|
||||
"decay_policies",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("description", sa.Text),
|
||||
sa.Column("applies_to_platform", sa.String(100)),
|
||||
sa.Column("applies_to_asset_type", sa.String(50)),
|
||||
sa.Column("applies_to_tactic", sa.String(100)),
|
||||
sa.Column("fresh_days", sa.Integer, server_default="90"),
|
||||
sa.Column("aging_days", sa.Integer, server_default="180"),
|
||||
sa.Column("stale_days", sa.Integer, server_default="365"),
|
||||
sa.Column("default_validity_days", sa.Integer, server_default="180"),
|
||||
sa.Column("silent_threshold_days", sa.Integer, server_default="30"),
|
||||
sa.Column("noisy_threshold_daily", sa.Integer, server_default="100"),
|
||||
sa.Column("recency_weight", sa.Float, server_default="0.3"),
|
||||
sa.Column("coverage_weight", sa.Float, server_default="0.3"),
|
||||
sa.Column("health_weight", sa.Float, server_default="0.25"),
|
||||
sa.Column("diversity_weight", sa.Float, server_default="0.15"),
|
||||
sa.Column("is_default", sa.Boolean, server_default="false"),
|
||||
sa.Column("is_active", sa.Boolean, server_default="true"),
|
||||
sa.Column("created_at", sa.DateTime),
|
||||
sa.Column("updated_at", sa.DateTime),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ["decay_policies", "infrastructure_change_logs", "technique_confidence_scores", "detection_validations", "detection_technique_mappings", "detection_assets"]:
|
||||
if _table_exists(table):
|
||||
op.drop_table(table)
|
||||
@@ -0,0 +1,118 @@
|
||||
"""Phase 9: Ownership & Revalidation Queue
|
||||
|
||||
Revision ID: b035ownerq
|
||||
Revises: b034dlm
|
||||
Create Date: 2026-05-19
|
||||
|
||||
Uses raw SQL for all DDL to avoid SQLAlchemy before_create hook issues
|
||||
with existing enum types.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "b035ownerq"
|
||||
down_revision: Union[str, None] = "b034dlm"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── Enums (idempotent) ────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE queue_priority AS ENUM ('critical', 'high', 'medium', 'low');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE queue_status AS ENUM ('pending', 'in_progress', 'completed', 'dismissed');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE queue_reason AS ENUM (
|
||||
'validation_expired', 'infra_change', 'osint_alert',
|
||||
'mitre_update', 'rule_modified', 'low_confidence', 'manual');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$
|
||||
"""))
|
||||
|
||||
# ── technique_ownerships ──────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS technique_ownerships (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
technique_id UUID NOT NULL UNIQUE
|
||||
REFERENCES techniques(id) ON DELETE CASCADE,
|
||||
owner_id UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
backup_owner_id UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
team VARCHAR(200),
|
||||
notes TEXT,
|
||||
assigned_at TIMESTAMP,
|
||||
assigned_by UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_techown_owner_id ON technique_ownerships (owner_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_techown_technique_id ON technique_ownerships (technique_id)"
|
||||
))
|
||||
|
||||
# ── revalidation_queue_items ──────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS revalidation_queue_items (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
technique_id UUID
|
||||
REFERENCES techniques(id) ON DELETE CASCADE,
|
||||
detection_asset_id UUID
|
||||
REFERENCES detection_assets(id) ON DELETE CASCADE,
|
||||
priority queue_priority NOT NULL DEFAULT 'medium',
|
||||
reason queue_reason NOT NULL,
|
||||
reason_detail TEXT,
|
||||
status queue_status NOT NULL DEFAULT 'pending',
|
||||
assigned_to UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
due_date TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
completed_at TIMESTAMP,
|
||||
dismissed_at TIMESTAMP,
|
||||
completed_by UUID
|
||||
REFERENCES users(id) ON DELETE SET NULL,
|
||||
extra JSONB
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_status ON revalidation_queue_items (status)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_priority ON revalidation_queue_items (priority)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_assigned_to ON revalidation_queue_items (assigned_to)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_technique_id ON revalidation_queue_items (technique_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_rqueue_asset_id ON revalidation_queue_items (detection_asset_id)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS revalidation_queue_items"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS technique_ownerships"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS queue_reason"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS queue_status"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS queue_priority"))
|
||||
@@ -0,0 +1,184 @@
|
||||
"""Phase 10: Attack Paths & Advanced Purple Team
|
||||
|
||||
Revision ID: b036atk
|
||||
Revises: b035ownerq
|
||||
Create Date: 2026-05-19
|
||||
|
||||
Uses raw SQL to avoid SQLAlchemy DDL hook issues with enum types.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "b036atk"
|
||||
down_revision: Union[str, None] = "b035ownerq"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── Enums ─────────────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE execution_status AS ENUM
|
||||
('planned','in_progress','completed','aborted');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL; END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE step_result_status AS ENUM
|
||||
('pending','executing','detected','not_detected','skipped');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL; END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE timeline_actor_side AS ENUM ('red','blue','system');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL; END $$
|
||||
"""))
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE timeline_entry_type AS ENUM
|
||||
('action','detection','note','phase_transition','flag');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL; END $$
|
||||
"""))
|
||||
|
||||
# ── attack_paths ──────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_paths (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(300) NOT NULL,
|
||||
description TEXT,
|
||||
objective TEXT,
|
||||
is_template BOOLEAN DEFAULT FALSE,
|
||||
threat_actor_id UUID REFERENCES threat_actors(id) ON DELETE SET NULL,
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
tags JSONB,
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_attack_paths_created_by ON attack_paths (created_by)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_attack_paths_is_template ON attack_paths (is_template)"
|
||||
))
|
||||
|
||||
# ── attack_path_steps ─────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_path_steps (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
attack_path_id UUID NOT NULL REFERENCES attack_paths(id) ON DELETE CASCADE,
|
||||
order_index INTEGER NOT NULL DEFAULT 0,
|
||||
kill_chain_phase VARCHAR(60),
|
||||
technique_id UUID REFERENCES techniques(id) ON DELETE SET NULL,
|
||||
test_id UUID REFERENCES tests(id) ON DELETE SET NULL,
|
||||
name VARCHAR(300),
|
||||
description TEXT,
|
||||
expected_detection BOOLEAN DEFAULT TRUE,
|
||||
notes TEXT
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_steps_path_id ON attack_path_steps (attack_path_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_steps_technique_id ON attack_path_steps (technique_id)"
|
||||
))
|
||||
|
||||
# ── attack_path_executions ────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_path_executions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
attack_path_id UUID NOT NULL REFERENCES attack_paths(id) ON DELETE CASCADE,
|
||||
status execution_status NOT NULL DEFAULT 'planned',
|
||||
environment VARCHAR(100),
|
||||
red_team_lead UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
blue_team_lead UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
started_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
notes TEXT,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
-- kill-chain metrics (populated on completion)
|
||||
total_steps INTEGER,
|
||||
detected_steps INTEGER,
|
||||
not_detected_steps INTEGER,
|
||||
skipped_steps INTEGER,
|
||||
detection_rate FLOAT,
|
||||
mttd_seconds FLOAT,
|
||||
furthest_undetected_step INTEGER
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_exec_path_id ON attack_path_executions (attack_path_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_exec_status ON attack_path_executions (status)"
|
||||
))
|
||||
|
||||
# ── attack_path_step_results ──────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_path_step_results (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
execution_id UUID NOT NULL
|
||||
REFERENCES attack_path_executions(id) ON DELETE CASCADE,
|
||||
step_id UUID NOT NULL
|
||||
REFERENCES attack_path_steps(id) ON DELETE CASCADE,
|
||||
step_order INTEGER NOT NULL DEFAULT 0,
|
||||
status step_result_status NOT NULL DEFAULT 'pending',
|
||||
executed_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
executed_at TIMESTAMP,
|
||||
detected_at TIMESTAMP,
|
||||
time_to_detect_seconds FLOAT,
|
||||
detection_asset_id UUID
|
||||
REFERENCES detection_assets(id) ON DELETE SET NULL,
|
||||
notes TEXT,
|
||||
evidence_ids JSONB
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_stepres_exec ON attack_path_step_results (execution_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ap_stepres_step ON attack_path_step_results (step_id)"
|
||||
))
|
||||
|
||||
# ── attack_path_timeline ──────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS attack_path_timeline (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
execution_id UUID NOT NULL
|
||||
REFERENCES attack_path_executions(id) ON DELETE CASCADE,
|
||||
step_id UUID REFERENCES attack_path_steps(id) ON DELETE SET NULL,
|
||||
timestamp TIMESTAMP NOT NULL DEFAULT now(),
|
||||
actor_side timeline_actor_side NOT NULL,
|
||||
actor_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
entry_type timeline_entry_type NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
extra JSONB
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_timeline_execution_id ON attack_path_timeline (execution_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_timeline_timestamp ON attack_path_timeline (timestamp)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_timeline"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_step_results"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_executions"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_path_steps"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS attack_paths"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS timeline_entry_type"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS timeline_actor_side"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS step_result_status"))
|
||||
conn.execute(sa.text("DROP TYPE IF EXISTS execution_status"))
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Phase 11: Knowledge Management — Playbooks + Lessons Learned
|
||||
|
||||
Revision ID: b037know
|
||||
Revises: b036atk
|
||||
Create Date: 2026-05-20
|
||||
|
||||
Uses raw SQL to bypass SQLAlchemy DDL hooks (no enum types — string columns
|
||||
with Pydantic-layer validation instead, so no PostgreSQL enums needed).
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "b037know"
|
||||
down_revision: Union[str, None] = "b036atk"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── playbooks ──────────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS playbooks (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
technique_id UUID NOT NULL REFERENCES techniques(id) ON DELETE CASCADE,
|
||||
playbook_type VARCHAR(32) NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
version INTEGER NOT NULL DEFAULT 1,
|
||||
tools JSONB,
|
||||
prerequisites JSONB,
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
updated_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now(),
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
CONSTRAINT uq_playbook_technique_type UNIQUE (technique_id, playbook_type)
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_playbooks_technique_id ON playbooks (technique_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_playbooks_type ON playbooks (playbook_type)"
|
||||
))
|
||||
|
||||
# ── playbook_versions ──────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS playbook_versions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
playbook_id UUID NOT NULL REFERENCES playbooks(id) ON DELETE CASCADE,
|
||||
version INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
tools JSONB,
|
||||
prerequisites JSONB,
|
||||
changed_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
change_note VARCHAR(500),
|
||||
created_at TIMESTAMP DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_pb_versions_playbook_id ON playbook_versions (playbook_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_pb_versions_version ON playbook_versions (playbook_id, version)"
|
||||
))
|
||||
|
||||
# ── lessons_learned ────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS lessons_learned (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
title VARCHAR(255) NOT NULL,
|
||||
what_happened TEXT NOT NULL DEFAULT '',
|
||||
root_cause TEXT NOT NULL DEFAULT '',
|
||||
fix_applied TEXT,
|
||||
severity VARCHAR(16) NOT NULL DEFAULT 'medium',
|
||||
entity_type VARCHAR(32) NOT NULL DEFAULT 'manual',
|
||||
entity_id UUID,
|
||||
technique_ids JSONB,
|
||||
tags JSONB,
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now(),
|
||||
is_active BOOLEAN DEFAULT TRUE
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ll_entity ON lessons_learned (entity_type, entity_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ll_severity ON lessons_learned (severity)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_ll_created_by ON lessons_learned (created_by)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS lessons_learned"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS playbook_versions"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS playbooks"))
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Phase 12: Risk Intelligence — technique_risk_profiles table
|
||||
|
||||
Revision ID: b038risk
|
||||
Revises: b037know
|
||||
Create Date: 2026-05-20
|
||||
|
||||
Uses raw SQL to bypass SQLAlchemy DDL hooks.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "b038risk"
|
||||
down_revision: Union[str, None] = "b037know"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS technique_risk_profiles (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
technique_id UUID NOT NULL REFERENCES techniques(id) ON DELETE CASCADE,
|
||||
risk_score FLOAT NOT NULL DEFAULT 0.0,
|
||||
likelihood FLOAT NOT NULL DEFAULT 0.0,
|
||||
impact FLOAT NOT NULL DEFAULT 0.0,
|
||||
risk_level VARCHAR(16) NOT NULL DEFAULT 'info',
|
||||
detection_gap FLOAT NOT NULL DEFAULT 1.0,
|
||||
threat_actor_count INTEGER NOT NULL DEFAULT 0,
|
||||
osint_signal_count INTEGER NOT NULL DEFAULT 0,
|
||||
test_fail_count INTEGER NOT NULL DEFAULT 0,
|
||||
test_total_count INTEGER NOT NULL DEFAULT 0,
|
||||
test_failure_rate FLOAT NOT NULL DEFAULT 0.0,
|
||||
confidence_level FLOAT NOT NULL DEFAULT 0.0,
|
||||
scoring_breakdown JSONB,
|
||||
recommendations JSONB,
|
||||
computed_at TIMESTAMP DEFAULT now(),
|
||||
is_stale BOOLEAN DEFAULT TRUE,
|
||||
CONSTRAINT uq_risk_profile_technique UNIQUE (technique_id)
|
||||
)
|
||||
"""))
|
||||
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_risk_profiles_risk_score "
|
||||
"ON technique_risk_profiles (risk_score)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_risk_profiles_risk_level "
|
||||
"ON technique_risk_profiles (risk_level)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_risk_profiles_stale "
|
||||
"ON technique_risk_profiles (is_stale)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS technique_risk_profiles"))
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Phase 13: Executive Dashboard — posture_snapshots table.
|
||||
|
||||
Revision ID: b039exec
|
||||
Revises: b038risk
|
||||
Create Date: 2026-05-20
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b039exec"
|
||||
down_revision = "b038risk"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS posture_snapshots (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
snapshot_date DATE NOT NULL,
|
||||
|
||||
-- Coverage
|
||||
total_techniques INTEGER NOT NULL DEFAULT 0,
|
||||
validated_count INTEGER NOT NULL DEFAULT 0,
|
||||
partial_count INTEGER NOT NULL DEFAULT 0,
|
||||
not_covered_count INTEGER NOT NULL DEFAULT 0,
|
||||
coverage_pct FLOAT NOT NULL DEFAULT 0.0,
|
||||
|
||||
-- Risk
|
||||
avg_risk_score FLOAT NOT NULL DEFAULT 0.0,
|
||||
critical_count INTEGER NOT NULL DEFAULT 0,
|
||||
high_count INTEGER NOT NULL DEFAULT 0,
|
||||
medium_count INTEGER NOT NULL DEFAULT 0,
|
||||
low_count INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
-- Operations
|
||||
open_queue_items INTEGER NOT NULL DEFAULT 0,
|
||||
orphan_techniques INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
-- Knowledge
|
||||
playbook_count INTEGER NOT NULL DEFAULT 0,
|
||||
lesson_count INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
-- MTTD
|
||||
mttd_avg_seconds FLOAT,
|
||||
executions_30d INTEGER NOT NULL DEFAULT 0,
|
||||
detection_rate_30d FLOAT,
|
||||
|
||||
-- Meta
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now(),
|
||||
extra JSONB
|
||||
)
|
||||
"""))
|
||||
|
||||
# Unique constraint: one snapshot per calendar day
|
||||
conn.execute(sa.text("""
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE posture_snapshots
|
||||
ADD CONSTRAINT uq_posture_snapshot_date UNIQUE (snapshot_date);
|
||||
EXCEPTION WHEN duplicate_table THEN NULL;
|
||||
WHEN duplicate_object THEN NULL;
|
||||
END $$
|
||||
"""))
|
||||
|
||||
# Index for date-range trend queries
|
||||
conn.execute(sa.text("""
|
||||
CREATE INDEX IF NOT EXISTS ix_posture_snapshots_date
|
||||
ON posture_snapshots (snapshot_date)
|
||||
"""))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS posture_snapshots CASCADE"))
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Phase 14: Enterprise Readiness — api_keys and sso_configs tables.
|
||||
|
||||
Revision ID: b040ent
|
||||
Revises: b039exec
|
||||
Create Date: 2026-05-20
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b040ent"
|
||||
down_revision = "b039exec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── api_keys ──────────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(200) NOT NULL,
|
||||
description TEXT,
|
||||
key_prefix VARCHAR(13) NOT NULL,
|
||||
key_hash VARCHAR(64) NOT NULL UNIQUE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
scopes JSONB NOT NULL DEFAULT '["read"]',
|
||||
last_used_at TIMESTAMP WITHOUT TIME ZONE,
|
||||
expires_at TIMESTAMP WITHOUT TIME ZONE,
|
||||
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_api_keys_user_id ON api_keys (user_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_api_keys_key_hash ON api_keys (key_hash)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_api_keys_active ON api_keys (is_active)"
|
||||
))
|
||||
|
||||
# ── sso_configs ───────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS sso_configs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
is_enabled BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
provider_name VARCHAR(200),
|
||||
sp_entity_id VARCHAR(500),
|
||||
sp_acs_url VARCHAR(500),
|
||||
sp_slo_url VARCHAR(500),
|
||||
sp_certificate TEXT,
|
||||
sp_private_key TEXT,
|
||||
idp_entity_id VARCHAR(500),
|
||||
idp_sso_url VARCHAR(500),
|
||||
idp_slo_url VARCHAR(500),
|
||||
idp_certificate TEXT,
|
||||
attr_email VARCHAR(200) DEFAULT 'email',
|
||||
attr_username VARCHAR(200) DEFAULT 'username',
|
||||
attr_role VARCHAR(200) DEFAULT 'role',
|
||||
default_role VARCHAR(50) DEFAULT 'viewer',
|
||||
auto_provision BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now(),
|
||||
updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS api_keys CASCADE"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS sso_configs CASCADE"))
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Phase 13: Operational Alerts — alert_rules and alert_instances tables.
|
||||
|
||||
Revision ID: b041alerts
|
||||
Revises: b040ent
|
||||
Create Date: 2026-05-21
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b041alerts"
|
||||
down_revision = "b040ent"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ── alert_rules ───────────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS alert_rules (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(300) NOT NULL,
|
||||
description TEXT,
|
||||
rule_type VARCHAR(50) NOT NULL,
|
||||
severity VARCHAR(20) NOT NULL DEFAULT 'medium',
|
||||
is_enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
is_system BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
config JSONB NOT NULL DEFAULT '{}',
|
||||
notify_in_app BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
notify_webhook BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
webhook_id UUID REFERENCES webhook_configs(id) ON DELETE SET NULL,
|
||||
cooldown_hours INTEGER NOT NULL DEFAULT 24,
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now(),
|
||||
last_fired_at TIMESTAMP WITHOUT TIME ZONE
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_rules_type ON alert_rules (rule_type)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_rules_enabled ON alert_rules (is_enabled)"
|
||||
))
|
||||
|
||||
# ── alert_instances ───────────────────────────────────────────────────────
|
||||
conn.execute(sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS alert_instances (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
rule_id UUID REFERENCES alert_rules(id) ON DELETE SET NULL,
|
||||
rule_name VARCHAR(300) NOT NULL,
|
||||
rule_type VARCHAR(50) NOT NULL,
|
||||
severity VARCHAR(20) NOT NULL,
|
||||
title VARCHAR(500) NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
details JSONB,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'open',
|
||||
acknowledged_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
acknowledged_at TIMESTAMP WITHOUT TIME ZONE,
|
||||
resolved_at TIMESTAMP WITHOUT TIME ZONE,
|
||||
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT now()
|
||||
)
|
||||
"""))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_instances_rule_id ON alert_instances (rule_id)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_instances_status ON alert_instances (status)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_instances_severity ON alert_instances (severity)"
|
||||
))
|
||||
conn.execute(sa.text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_alert_instances_created ON alert_instances (created_at)"
|
||||
))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS alert_instances CASCADE"))
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS alert_rules CASCADE"))
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Add jira_api_token to users table.
|
||||
|
||||
Revision ID: b042
|
||||
Revises: b041_operational_alerts
|
||||
Create Date: 2026-05-26
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b042"
|
||||
down_revision = "b041alerts"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("jira_api_token", sa.String(500), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("users", "jira_api_token")
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Add jira_email to users table.
|
||||
|
||||
Allows each user to specify a separate email for Jira authentication,
|
||||
independent of their Aegis account email.
|
||||
|
||||
Revision ID: b043
|
||||
Revises: b042
|
||||
Create Date: 2026-05-26
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b043"
|
||||
down_revision = "b042"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("jira_email", sa.String(255), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("users", "jira_email")
|
||||
@@ -0,0 +1,25 @@
|
||||
"""add tempo_api_token to users
|
||||
|
||||
Revision ID: b044
|
||||
Revises: b043
|
||||
Create Date: 2026-05-27
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b044"
|
||||
down_revision = "b043"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("tempo_api_token", sa.String(500), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("users", "tempo_api_token")
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Add blue_work_started_at to tests table."""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b045"
|
||||
down_revision = "b044"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column("tests", sa.Column("blue_work_started_at", sa.DateTime(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("tests", "blue_work_started_at")
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Add 'disputed' value to teststate enum.
|
||||
|
||||
Revision ID: b046
|
||||
Revises: b045
|
||||
Create Date: 2026-06-03
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "b046"
|
||||
down_revision = "b045"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("ALTER TYPE teststate ADD VALUE IF NOT EXISTS 'disputed'")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Add start_date to campaigns.
|
||||
|
||||
Revision ID: b047
|
||||
Revises: b046
|
||||
Create Date: 2026-06-03
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "b047"
|
||||
down_revision = "b046"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"campaigns",
|
||||
sa.Column("start_date", sa.DateTime(), nullable=True),
|
||||
)
|
||||
op.create_index("ix_campaigns_start_date", "campaigns", ["start_date"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_campaigns_start_date", table_name="campaigns")
|
||||
op.drop_column("campaigns", "start_date")
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Add evaluation_imports table.
|
||||
|
||||
Revision ID: b048
|
||||
Revises: b047
|
||||
Create Date: 2026-06-05
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
revision = "b048"
|
||||
down_revision = "b047"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"evaluation_imports",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("adversary_name", sa.String, nullable=False),
|
||||
sa.Column("adversary_display", sa.String, nullable=False),
|
||||
sa.Column("eval_round", sa.Integer, nullable=False),
|
||||
sa.Column("imported_at", sa.DateTime, nullable=False),
|
||||
sa.Column("imported_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=True),
|
||||
sa.Column("tests_created", sa.Integer, default=0),
|
||||
sa.Column("techniques_covered", sa.Integer, default=0),
|
||||
sa.Column("status", sa.String, default="completed"),
|
||||
sa.Column("notes", sa.Text, nullable=True),
|
||||
)
|
||||
op.create_index("ix_evaluation_imports_adversary", "evaluation_imports", ["adversary_name"])
|
||||
op.create_index("ix_evaluation_imports_round", "evaluation_imports", ["eval_round"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_evaluation_imports_round", table_name="evaluation_imports")
|
||||
op.drop_index("ix_evaluation_imports_adversary", table_name="evaluation_imports")
|
||||
op.drop_table("evaluation_imports")
|
||||
+24
-8
@@ -39,8 +39,7 @@ class Settings(BaseSettings):
|
||||
SECRET_KEY: str = ""
|
||||
# Assign ALGORITHM = "HS256"
|
||||
ALGORITHM: str = "HS256"
|
||||
# Assign ACCESS_TOKEN_EXPIRE_MINUTES = 15 # short-lived for security; configurable via env
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 480 # 8 hours — /auth/refresh extends active sessions
|
||||
|
||||
# ── Redis ─────────────────────────────────────────────────────────
|
||||
REDIS_URL: str = "redis://redis:6379/0"
|
||||
@@ -57,7 +56,10 @@ class Settings(BaseSettings):
|
||||
|
||||
# ── MinIO / S3 ───────────────────────────────────────────────────
|
||||
MINIO_ENDPOINT: str = "minio:9000"
|
||||
# Assign MINIO_ACCESS_KEY = "minioadmin"
|
||||
# Public hostname used in presigned URLs returned to browsers.
|
||||
# In production set this to <server-ip>:9000 (or a public FQDN) so
|
||||
# the browser can reach MinIO directly. Defaults to MINIO_ENDPOINT.
|
||||
MINIO_PUBLIC_ENDPOINT: str = ""
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
# Assign MINIO_SECRET_KEY = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
@@ -81,10 +83,11 @@ class Settings(BaseSettings):
|
||||
JIRA_IS_CLOUD: bool = True
|
||||
# Assign JIRA_DEFAULT_PROJECT = ""
|
||||
JIRA_DEFAULT_PROJECT: str = ""
|
||||
# Assign JIRA_ISSUE_TYPE_TEST = "Task"
|
||||
JIRA_ISSUE_TYPE_TEST: str = "Task"
|
||||
# Assign JIRA_ISSUE_TYPE_CAMPAIGN = "Epic"
|
||||
JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic"
|
||||
JIRA_ISSUE_TYPE_TEST: str = "Task" # tests (campaign or standalone)
|
||||
JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic" # campaigns (under Initiative)
|
||||
# Jira custom field ID for "Start date" — Jira Cloud team-managed: customfield_10015
|
||||
# Override with the correct field ID for your Jira instance if different.
|
||||
JIRA_START_DATE_FIELD: str = "customfield_10015"
|
||||
|
||||
# ── Tempo Integration ─────────────────────────────────────────────
|
||||
TEMPO_ENABLED: bool = False
|
||||
@@ -94,6 +97,9 @@ class Settings(BaseSettings):
|
||||
TEMPO_API_VERSION: int = 4
|
||||
# Assign TEMPO_DEFAULT_WORK_TYPE = "Red Team"
|
||||
TEMPO_DEFAULT_WORK_TYPE: str = "Red Team"
|
||||
# Tempo API base URL — use https://api.eu.tempo.io/4 for EU workspaces.
|
||||
# Can also be set via system_configs key "tempo.base_url" at runtime.
|
||||
TEMPO_BASE_URL: str = "" # empty → falls back to https://api.tempo.io/4
|
||||
|
||||
# ── OSINT / Intelligence ────────────────────────────────────────
|
||||
NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s
|
||||
@@ -103,12 +109,22 @@ class Settings(BaseSettings):
|
||||
# ── Reporting ─────────────────────────────────────────────────────
|
||||
REPORT_TEMPLATES_DIR: str = "app/templates/reports"
|
||||
# Assign REPORT_OUTPUT_DIR = "/tmp/aegis_reports"
|
||||
REPORT_OUTPUT_DIR: str = "/tmp/aegis_reports"
|
||||
REPORT_OUTPUT_DIR: str = "/app/reports"
|
||||
# Assign COMPANY_NAME = "Organization"
|
||||
COMPANY_NAME: str = "Organization"
|
||||
# Assign COMPANY_LOGO_PATH = "app/templates/reports/assets/logo.png"
|
||||
COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png"
|
||||
|
||||
# ── Email / SMTP ──────────────────────────────────────────────────
|
||||
SMTP_ENABLED: bool = False
|
||||
SMTP_HOST: str = ""
|
||||
SMTP_PORT: int = 587
|
||||
SMTP_USERNAME: str = ""
|
||||
SMTP_PASSWORD: str = ""
|
||||
SMTP_FROM_EMAIL: str = "aegis@company.com"
|
||||
SMTP_USE_TLS: bool = True
|
||||
PLATFORM_URL: str = "http://localhost:5173" # base URL for links in emails
|
||||
|
||||
# ── Scoring weights (must sum to 100) ────────────────────────────
|
||||
SCORING_WEIGHT_TESTS: int = 40
|
||||
# Assign SCORING_WEIGHT_DETECTION_RULES = 25
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
Provides:
|
||||
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
|
||||
Authorization header (fallback), fetches user from DB, raises 401 on failure.
|
||||
Also accepts Aegis API keys (``aegis_…`` prefix) as Bearer tokens.
|
||||
- ``require_role``: factory that returns a dependency enforcing a specific role
|
||||
(admins always pass).
|
||||
"""
|
||||
@@ -36,6 +37,7 @@ from app.database import get_db
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.models.api_key import KEY_PREFIX
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth2 scheme (reads Authorization header — used as fallback / Swagger UI)
|
||||
@@ -98,7 +100,15 @@ async def get_current_user(
|
||||
# Raise credentials_exception
|
||||
raise credentials_exception
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
# ── API Key path (Bearer token starts with "aegis_") ──────────────────
|
||||
if token.startswith(KEY_PREFIX):
|
||||
from app.services.api_key_service import authenticate_raw_key
|
||||
user = authenticate_raw_key(db, token)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
# ── JWT path ──────────────────────────────────────────────────────────
|
||||
try:
|
||||
# Assign payload = jwt.decode(
|
||||
payload = jwt.decode(
|
||||
@@ -162,12 +172,27 @@ async def require_password_changed(
|
||||
return current_user
|
||||
|
||||
|
||||
# Define function require_role
|
||||
def require_role(required_role: str) -> Callable[..., object]:
|
||||
def _check_api_key_scope(user: User, required_scope: str) -> None:
|
||||
"""Raise 403 if the request was authenticated via an API key that lacks *required_scope*.
|
||||
|
||||
When authenticated via JWT (browser session), ``_api_key_scopes`` is not set
|
||||
and the check is skipped — full access is granted based on role alone.
|
||||
"""
|
||||
key_scopes = getattr(user, "_api_key_scopes", None)
|
||||
if key_scopes is not None and required_scope not in key_scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"API key scope '{required_scope}' required for this operation",
|
||||
)
|
||||
|
||||
|
||||
def require_role(required_role: str):
|
||||
"""Return a FastAPI dependency that enforces *required_role*.
|
||||
|
||||
The dependency allows the request to proceed when
|
||||
``user.role == required_role`` **or** ``user.role == "admin"``.
|
||||
Also enforces API key scopes: admin-role endpoints require the ``admin``
|
||||
scope; all other role-restricted endpoints require ``write``.
|
||||
Otherwise it raises :class:`~fastapi.HTTPException` **403**.
|
||||
"""
|
||||
|
||||
@@ -185,7 +210,8 @@ def require_role(required_role: str) -> Callable[..., object]:
|
||||
# Keyword argument: detail
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
# Return current_user
|
||||
scope = "admin" if required_role == "admin" else "write"
|
||||
_check_api_key_scope(current_user, scope)
|
||||
return current_user
|
||||
|
||||
# Return role_checker
|
||||
@@ -196,7 +222,11 @@ def require_role(required_role: str) -> Callable[..., object]:
|
||||
def require_any_role(*roles: str) -> Callable[..., object]:
|
||||
"""Return a FastAPI dependency that enforces **any** of the given *roles*.
|
||||
|
||||
Admins always pass. Usage example::
|
||||
Admins always pass. Also enforces API key scopes: if the only accepted
|
||||
role is ``admin``, the key must carry the ``admin`` scope; otherwise the
|
||||
``write`` scope is required.
|
||||
|
||||
Usage example::
|
||||
|
||||
@router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))])
|
||||
"""
|
||||
@@ -215,8 +245,28 @@ def require_any_role(*roles: str) -> Callable[..., object]:
|
||||
# Keyword argument: detail
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
# Return current_user
|
||||
scope = "admin" if set(roles) == {"admin"} else "write"
|
||||
_check_api_key_scope(current_user, scope)
|
||||
return current_user
|
||||
|
||||
# Return role_checker
|
||||
return role_checker
|
||||
|
||||
|
||||
def require_scope(scope: str):
|
||||
"""Return a dependency that enforces the API key carries *scope*.
|
||||
|
||||
JWT-authenticated requests (browser sessions) bypass this check entirely.
|
||||
Use on mutation endpoints that don't already use ``require_role`` /
|
||||
``require_any_role``::
|
||||
|
||||
@router.post("/resource", dependencies=[Depends(require_scope("write"))])
|
||||
"""
|
||||
|
||||
async def scope_checker(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
_check_api_key_scope(current_user, scope)
|
||||
return current_user
|
||||
|
||||
return scope_checker
|
||||
|
||||
@@ -221,14 +221,17 @@ class TechniqueEntity:
|
||||
) -> TechniqueStatus:
|
||||
"""Recompute ``status_global`` from a list of (state, detection_result) pairs.
|
||||
|
||||
Rules (v2):
|
||||
Rules (v3):
|
||||
1. No tests -> not_evaluated
|
||||
2. All validated -> inspect detection results:
|
||||
- All detected -> validated
|
||||
- Any partially_detected -> partial
|
||||
- Otherwise -> not_covered
|
||||
3. Some validated, others in progress -> partial
|
||||
4. All in intermediate states -> in_progress
|
||||
2. All tests validated -> inspect detection results:
|
||||
a. All detected AND ≥ 1 validated test -> validated
|
||||
b. Any partially_detected -> partial
|
||||
d. Otherwise (no detected results) -> not_covered
|
||||
3. Some validated, others in intermediate states -> partial
|
||||
4. All tests in intermediate states (draft/executing/evaluating/review/rejected)
|
||||
-> in_progress
|
||||
|
||||
Minimum validated count for "validated": 1 test.
|
||||
|
||||
Args:
|
||||
test_snapshots (list[tuple[str, str | None]]): Each element is a
|
||||
@@ -240,7 +243,8 @@ class TechniqueEntity:
|
||||
TechniqueStatus: The newly computed status, which is also stored on
|
||||
the entity's ``status_global`` field.
|
||||
"""
|
||||
# Assign tests = [
|
||||
min_validated_for_full = 1 # require ≥ N validated tests for "validated"
|
||||
|
||||
tests = [
|
||||
_TestSnapshot(
|
||||
# Keyword argument: state
|
||||
@@ -257,13 +261,15 @@ class TechniqueEntity:
|
||||
self.status_global = TechniqueStatus.not_evaluated
|
||||
# Alternative: all(t.state == TestState.validated for t in tests)
|
||||
elif all(t.state == TestState.validated for t in tests):
|
||||
# Assign results = [t.detection_result for t in tests if t.detection_result]
|
||||
validated_count = len(tests)
|
||||
results = [t.detection_result for t in tests if t.detection_result]
|
||||
# Check: results and all(r == TestResult.detected or r == "detected" for r i...
|
||||
if results and all(r == TestResult.detected or r == "detected" for r in results):
|
||||
# Assign self.status_global = TechniqueStatus.validated
|
||||
self.status_global = TechniqueStatus.validated
|
||||
# elif any(
|
||||
# Need at least min_validated_for_full tests for "validated"
|
||||
if validated_count >= min_validated_for_full:
|
||||
self.status_global = TechniqueStatus.validated
|
||||
else:
|
||||
self.status_global = TechniqueStatus.partial
|
||||
elif any(
|
||||
# Keyword argument: r
|
||||
r == TestResult.partially_detected or r == "partially_detected"
|
||||
|
||||
@@ -43,6 +43,7 @@ class TestState(str, enum.Enum):
|
||||
validated = "validated"
|
||||
# Assign rejected = "rejected"
|
||||
rejected = "rejected"
|
||||
disputed = "disputed" # one lead approved, the other rejected
|
||||
|
||||
|
||||
# Define class TeamSide
|
||||
|
||||
@@ -68,6 +68,7 @@ class TestState(str, enum.Enum):
|
||||
validated = "validated"
|
||||
# Assign rejected = "rejected"
|
||||
rejected = "rejected"
|
||||
disputed = "disputed" # one lead approved, the other rejected
|
||||
|
||||
|
||||
# Assign VALID_TRANSITIONS = {
|
||||
@@ -75,7 +76,8 @@ VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
|
||||
TestState.draft: [TestState.red_executing],
|
||||
TestState.red_executing: [TestState.blue_evaluating],
|
||||
TestState.blue_evaluating: [TestState.in_review],
|
||||
TestState.in_review: [TestState.validated, TestState.rejected],
|
||||
TestState.in_review: [TestState.validated, TestState.rejected, TestState.disputed],
|
||||
TestState.disputed: [TestState.validated, TestState.rejected],
|
||||
TestState.rejected: [TestState.draft],
|
||||
TestState.validated: [],
|
||||
}
|
||||
@@ -591,37 +593,23 @@ class TestEntity:
|
||||
def check_dual_validation(self) -> None:
|
||||
"""Evaluate both leads' votes and advance state if appropriate.
|
||||
|
||||
- Both **approved** -> ``validated``
|
||||
- Either **rejected** -> ``rejected``
|
||||
- Otherwise no change (waiting for the other lead).
|
||||
Rules (v2 — consensus required):
|
||||
- Both **approved** -> ``validated``
|
||||
- Both **rejected** -> ``rejected``
|
||||
- One approved + one rejected -> ``disputed`` (conflict, needs discussion)
|
||||
- Otherwise (one or both still pending) -> no change
|
||||
|
||||
Called automatically by :meth:`validate_red` and :meth:`validate_blue`.
|
||||
Also available as a standalone entry point for backward compatibility
|
||||
when validation fields are set externally.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._check_dual_validation()
|
||||
self._check_dual_validation()
|
||||
|
||||
# Define function _assert_in_review
|
||||
def _assert_in_review(self, side: str) -> None:
|
||||
"""Raise InvalidOperationError unless the test is in ``in_review`` state.
|
||||
|
||||
Args:
|
||||
side (str): The team side being validated (``"red"`` or ``"blue"``),
|
||||
used in the error message.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: self.state != TestState.in_review
|
||||
if self.state != TestState.in_review:
|
||||
# Raise InvalidOperationError
|
||||
if self.state not in (TestState.in_review, TestState.disputed):
|
||||
raise InvalidOperationError(
|
||||
f"Cannot validate {side} side while test is in "
|
||||
f"'{self.state.value}' state (must be in_review)"
|
||||
f"'{self.state.value}' state (must be in_review or disputed)"
|
||||
)
|
||||
|
||||
# Apply the @staticmethod decorator
|
||||
@@ -646,22 +634,15 @@ class TestEntity:
|
||||
|
||||
# Define function _check_dual_validation
|
||||
def _check_dual_validation(self) -> None:
|
||||
"""Advance to ``validated`` or ``rejected`` once both leads have voted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# r, b = self.red_validation_status, self.blue_validation_status
|
||||
"""Advance the test state once both leads have voted."""
|
||||
r, b = self.red_validation_status, self.blue_validation_status
|
||||
# Check: r == "rejected" or b == "rejected"
|
||||
if r == "rejected" or b == "rejected":
|
||||
# Assign self.state = TestState.rejected
|
||||
self.state = TestState.rejected
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("dual_validation_rejected"))
|
||||
# Alternative: r == "approved" and b == "approved"
|
||||
elif r == "approved" and b == "approved":
|
||||
# Assign self.state = TestState.validated
|
||||
|
||||
if r == "approved" and b == "approved":
|
||||
self.state = TestState.validated
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("dual_validation_approved"))
|
||||
|
||||
elif r == "rejected" or b == "rejected":
|
||||
# Any rejection is a veto — one lead can reject without waiting for the other
|
||||
self.state = TestState.rejected
|
||||
self._events.append(DomainEvent("dual_validation_rejected"))
|
||||
|
||||
@@ -12,6 +12,7 @@ sessions.
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
# Import BackgroundScheduler from apscheduler.schedulers.background
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
@@ -63,7 +64,7 @@ scheduler = BackgroundScheduler()
|
||||
|
||||
def _run_mitre_sync() -> None:
|
||||
"""Execute a MITRE sync inside its own DB session."""
|
||||
# Log info: "Scheduled MITRE sync job starting..."
|
||||
from app.services.webhook_service import dispatch_webhook
|
||||
logger.info("Scheduled MITRE sync job starting...")
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
@@ -73,7 +74,7 @@ def _run_mitre_sync() -> None:
|
||||
summary = sync_mitre(db)
|
||||
# Log info: "Scheduled MITRE sync job finished — %s", summary
|
||||
logger.info("Scheduled MITRE sync job finished — %s", summary)
|
||||
# Handle Exception
|
||||
dispatch_webhook("mitre.synced", {"created": summary.get("created", 0), "updated": summary.get("updated", 0)})
|
||||
except Exception:
|
||||
# Log exception: "Scheduled MITRE sync job failed"
|
||||
logger.exception("Scheduled MITRE sync job failed")
|
||||
@@ -163,7 +164,96 @@ def _run_recurring_campaigns() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
# Define function _run_intel_scan
|
||||
def _run_scheduled_campaign_activation() -> None:
|
||||
"""Auto-activate campaigns whose start_date has arrived.
|
||||
|
||||
Finds all campaigns in 'draft' state with a start_date <= now,
|
||||
activates them, creates Jira tickets, and notifies the red_tech team.
|
||||
Runs every hour so campaigns activate within ~1 hour of their scheduled time.
|
||||
"""
|
||||
logger.info("Scheduled campaign auto-activation check starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from datetime import datetime as _dt
|
||||
from app.models.campaign import Campaign
|
||||
from app.models.user import User
|
||||
from app.services.campaign_crud_service import activate_campaign as _activate
|
||||
from app.services.notification_service import notify_role
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
now = _dt.utcnow()
|
||||
due_campaigns = (
|
||||
db.query(Campaign)
|
||||
.filter(
|
||||
Campaign.status == "draft",
|
||||
Campaign.start_date != None, # noqa: E711
|
||||
Campaign.start_date <= now,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
activated = 0
|
||||
for campaign in due_campaigns:
|
||||
try:
|
||||
_activate(db, str(campaign.id))
|
||||
notify_role(
|
||||
db,
|
||||
role="red_tech",
|
||||
type="campaign_activated",
|
||||
title="Campaign auto-activated",
|
||||
message=f'Campaign "{campaign.name}" has been automatically activated on its scheduled start date.',
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
)
|
||||
log_action(
|
||||
db,
|
||||
user_id=None,
|
||||
action="auto_activate_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"name": campaign.name, "start_date": str(campaign.start_date)},
|
||||
)
|
||||
|
||||
# Create Jira tickets non-fatally
|
||||
try:
|
||||
from app.services.jira_service import (
|
||||
auto_create_campaign_issue,
|
||||
auto_create_test_issue,
|
||||
get_campaign_jira_key,
|
||||
get_test_jira_key,
|
||||
)
|
||||
# Use first admin user as actor for Jira auth
|
||||
admin_user = db.query(User).filter(User.role == "admin").first()
|
||||
if admin_user:
|
||||
db.refresh(campaign)
|
||||
campaign_jira_key = get_campaign_jira_key(db, str(campaign.id))
|
||||
if not campaign_jira_key:
|
||||
campaign_jira_key = auto_create_campaign_issue(db, campaign, admin_user)
|
||||
if campaign_jira_key:
|
||||
for ct in campaign.campaign_tests:
|
||||
if ct.test and not get_test_jira_key(db, ct.test.id):
|
||||
auto_create_test_issue(
|
||||
db, ct.test, admin_user,
|
||||
parent_ticket_override=campaign_jira_key,
|
||||
campaign_start_date=campaign.start_date,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Jira auto-create failed for auto-activated campaign %s", campaign.id)
|
||||
|
||||
db.commit()
|
||||
activated += 1
|
||||
logger.info("Auto-activated campaign %s (%s)", campaign.id, campaign.name)
|
||||
except Exception:
|
||||
logger.exception("Failed to auto-activate campaign %s", campaign.id)
|
||||
db.rollback()
|
||||
|
||||
logger.info("Campaign auto-activation check finished — activated %d campaigns", activated)
|
||||
except Exception:
|
||||
logger.exception("Campaign auto-activation job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_intel_scan() -> None:
|
||||
"""Execute an intel scan inside its own DB session."""
|
||||
# Log info: "Scheduled intel scan job starting..."
|
||||
@@ -186,7 +276,83 @@ def _run_intel_scan() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
# Define function _run_osint_enrichment
|
||||
def _run_evaluation_round_check() -> None:
|
||||
"""Weekly job: check if a new ATT&CK Evaluation round is available.
|
||||
|
||||
If a new round is found it is imported automatically and an admin
|
||||
notification is created so the team knows new baseline data is available.
|
||||
"""
|
||||
logger.info("ATT&CK Evaluations new-round check starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.services.attck_evaluations_service import check_for_new_round, import_evaluation_round
|
||||
from app.models.user import User as UserModel
|
||||
|
||||
result = check_for_new_round(db)
|
||||
if result.get("error"):
|
||||
logger.warning("ATT&CK Evaluations check failed: %s", result["error"])
|
||||
return
|
||||
|
||||
if not result.get("new_round_available"):
|
||||
logger.info(
|
||||
"ATT&CK Evaluations check — latest round '%s' already imported",
|
||||
result.get("latest_round", {}).get("display_name", "?"),
|
||||
)
|
||||
return
|
||||
|
||||
latest = result["latest_round"]
|
||||
logger.info(
|
||||
"New ATT&CK Evaluation round detected: %s (round %d) — starting auto-import",
|
||||
latest["display_name"], latest["eval_round"],
|
||||
)
|
||||
|
||||
# Use the first admin user as the importer (system action)
|
||||
admin = db.query(UserModel).filter(UserModel.role == "admin").first()
|
||||
if not admin:
|
||||
logger.warning("ATT&CK Evaluations auto-import: no admin user found — skipping")
|
||||
return
|
||||
|
||||
summary = import_evaluation_round(
|
||||
db,
|
||||
latest["name"],
|
||||
latest["display_name"],
|
||||
latest["eval_round"],
|
||||
admin,
|
||||
)
|
||||
logger.info(
|
||||
"ATT&CK Evaluations auto-import complete — round %d (%s): %d tests created",
|
||||
latest["eval_round"], latest["display_name"], summary["created"],
|
||||
)
|
||||
|
||||
# Notify all admins
|
||||
try:
|
||||
from app.services.notification_service import create_notification
|
||||
admins = db.query(UserModel).filter(UserModel.role == "admin").all()
|
||||
for adm in admins:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=adm.id,
|
||||
title="New ATT&CK Evaluation round imported",
|
||||
message=(
|
||||
f"Round {latest['eval_round']} — {latest['display_name']} — "
|
||||
f"has been automatically imported. "
|
||||
f"{summary['created']} tests created in In Review state. "
|
||||
f"Blue Leads must validate each result before it counts as coverage."
|
||||
),
|
||||
notification_type="eval_import",
|
||||
entity_type="evaluation",
|
||||
entity_id=None,
|
||||
)
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.warning("Failed to send eval import notifications", exc_info=True)
|
||||
|
||||
except Exception:
|
||||
logger.exception("ATT&CK Evaluations round check job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_osint_enrichment() -> None:
|
||||
"""Execute weekly OSINT enrichment inside its own DB session."""
|
||||
# Log info: "Scheduled OSINT enrichment job starting..."
|
||||
@@ -209,7 +375,61 @@ def _run_osint_enrichment() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
# Define function _run_stale_detection
|
||||
_FREQUENCY_INTERVALS: dict[str, timedelta] = {
|
||||
"daily": timedelta(days=1),
|
||||
"weekly": timedelta(weeks=1),
|
||||
"monthly": timedelta(days=30),
|
||||
}
|
||||
|
||||
|
||||
def _run_data_sources_sync() -> None:
|
||||
"""Check all enabled data sources and sync those that are overdue."""
|
||||
logger.info("Scheduled data sources sync check starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.data_source import DataSource
|
||||
from app.services.data_source_service import sync_source
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
sources = (
|
||||
db.query(DataSource)
|
||||
.filter(DataSource.is_enabled == True) # noqa: E712
|
||||
.all()
|
||||
)
|
||||
synced = 0
|
||||
for ds in sources:
|
||||
freq = ds.sync_frequency
|
||||
if not freq or freq == "manual":
|
||||
continue
|
||||
interval = _FREQUENCY_INTERVALS.get(freq)
|
||||
if interval is None:
|
||||
continue
|
||||
last = ds.last_sync_at
|
||||
if last is None:
|
||||
# Never synced — run it now
|
||||
overdue = True
|
||||
else:
|
||||
# Make last timezone-aware if needed
|
||||
if last.tzinfo is None:
|
||||
last = last.replace(tzinfo=timezone.utc)
|
||||
overdue = now - last >= interval
|
||||
if overdue:
|
||||
logger.info(
|
||||
"Data source '%s' is overdue (freq=%s, last=%s) — syncing",
|
||||
ds.name, freq, last,
|
||||
)
|
||||
try:
|
||||
sync_source(db, str(ds.id))
|
||||
synced += 1
|
||||
except Exception:
|
||||
logger.exception("Failed to sync data source '%s'", ds.name)
|
||||
logger.info("Data sources sync check finished — %d source(s) synced", synced)
|
||||
except Exception:
|
||||
logger.exception("Data sources sync check failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_stale_detection() -> None:
|
||||
"""Execute daily stale coverage detection inside its own DB session."""
|
||||
# Log info: "Scheduled stale coverage detection starting..."
|
||||
@@ -232,6 +452,53 @@ def _run_stale_detection() -> None:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_decay_engine() -> None:
|
||||
"""Execute the decay engine inside its own DB session."""
|
||||
logger.info("Scheduled decay engine job starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.services.decay_engine_service import run_decay_engine
|
||||
results = run_decay_engine(db)
|
||||
logger.info("Decay engine job finished — %s", results)
|
||||
except Exception:
|
||||
logger.exception("Decay engine job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_queue_generation() -> None:
|
||||
"""Generate revalidation queue items for analysts — runs after decay engine."""
|
||||
logger.info("Scheduled revalidation queue generation starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.services.revalidation_queue_service import generate_queue_items
|
||||
results = generate_queue_items(db)
|
||||
logger.info("Queue generation finished — %s", results)
|
||||
except Exception:
|
||||
logger.exception("Queue generation job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _run_alert_evaluation() -> None:
|
||||
"""Evaluate all enabled operational alert rules (hourly)."""
|
||||
logger.info("Scheduled alert evaluation job starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.services.operational_alert_service import evaluate_all_rules
|
||||
result = evaluate_all_rules(db)
|
||||
logger.info(
|
||||
"Alert evaluation finished — %d rules, %d alerts fired in %.3fs",
|
||||
result["rules_evaluated"],
|
||||
result["alerts_fired"],
|
||||
result["duration_seconds"],
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Alert evaluation job failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler bootstrap
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -308,6 +575,14 @@ def start_scheduler() -> None:
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
_run_scheduled_campaign_activation,
|
||||
trigger="interval",
|
||||
hours=1,
|
||||
id="scheduled_campaign_activation",
|
||||
name="Auto-activate campaigns on start_date (hourly)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_recurring_campaigns,
|
||||
# Keyword argument: trigger
|
||||
@@ -377,7 +652,50 @@ def start_scheduler() -> None:
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.start()
|
||||
scheduler.add_job(
|
||||
_run_data_sources_sync,
|
||||
trigger="interval",
|
||||
hours=6,
|
||||
id="data_sources_sync",
|
||||
name="Data sources auto-sync (every 6h)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_decay_engine,
|
||||
trigger="cron",
|
||||
hour=2,
|
||||
minute=0,
|
||||
id="decay_engine",
|
||||
name="Detection decay engine (daily 02:00)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_queue_generation,
|
||||
trigger="cron",
|
||||
hour=2,
|
||||
minute=30,
|
||||
id="queue_generation",
|
||||
name="Revalidation queue generation (daily 02:30)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_alert_evaluation,
|
||||
trigger="interval",
|
||||
hours=1,
|
||||
id="alert_evaluation",
|
||||
name="Operational alert evaluation (hourly)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_run_evaluation_round_check,
|
||||
trigger="cron",
|
||||
day_of_week="mon",
|
||||
hour=6,
|
||||
minute=0,
|
||||
id="attck_evaluation_check",
|
||||
name="ATT&CK Evaluations new-round check (Mondays 06:00)",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.start()
|
||||
# Log info:
|
||||
logger.info(
|
||||
@@ -389,6 +707,6 @@ def start_scheduler() -> None:
|
||||
"recurring_campaigns (daily), jira_sync (1h), "
|
||||
# Literal argument value
|
||||
"osint_enrichment (weekly), stale_detection (daily), "
|
||||
# Literal argument value
|
||||
"retention_policies (daily)"
|
||||
"retention_policies (daily), data_sources_sync (6h), "
|
||||
"alert_evaluation (1h), attck_evaluation_check (Mondays 06:00)"
|
||||
)
|
||||
|
||||
+89
-89
@@ -38,10 +38,45 @@ from slowapi.errors import RateLimitExceeded
|
||||
# Import SQLAlchemyError from sqlalchemy.exc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
# Import settings as _settings from app.config
|
||||
from app.config import settings as _settings
|
||||
|
||||
# Import DomainError from app.domain.errors
|
||||
from app.routers import auth as auth_router
|
||||
from app.routers import techniques as techniques_router
|
||||
from app.routers import tests as tests_router
|
||||
from app.routers import evidence as evidence_router
|
||||
from app.routers import test_templates as test_templates_router
|
||||
from app.routers import system as system_router
|
||||
from app.routers import metrics as metrics_router
|
||||
from app.routers import users as users_router
|
||||
from app.routers import audit as audit_router
|
||||
from app.routers import notifications as notifications_router
|
||||
from app.routers import reports as reports_router
|
||||
from app.routers import data_sources as data_sources_router
|
||||
from app.routers import threat_actors as threat_actors_router
|
||||
from app.routers import d3fend as d3fend_router
|
||||
from app.routers import detection_rules as detection_rules_router
|
||||
from app.routers import campaigns as campaigns_router
|
||||
from app.routers import heatmap as heatmap_router
|
||||
from app.routers import scores as scores_router
|
||||
from app.routers import operational_metrics as operational_metrics_router
|
||||
from app.routers import compliance as compliance_router
|
||||
from app.routers import snapshots as snapshots_router
|
||||
from app.routers import jira as jira_router
|
||||
from app.routers import worklogs as worklogs_router
|
||||
from app.routers import professional_reports as professional_reports_router
|
||||
from app.routers import analytics as analytics_router
|
||||
from app.routers import advanced_metrics as advanced_metrics_router
|
||||
from app.routers import osint as osint_router
|
||||
from app.routers import webhooks as webhooks_router
|
||||
from app.routers import detection_lifecycle as detection_lifecycle_router
|
||||
from app.routers import intel as intel_router
|
||||
from app.routers import admin_config as admin_config_router
|
||||
from app.routers import ownership as ownership_router
|
||||
from app.routers import attack_paths as attack_paths_router
|
||||
from app.routers import knowledge as knowledge_router
|
||||
from app.routers import risk_intelligence as risk_router
|
||||
from app.routers import executive_dashboard as dashboard_router
|
||||
from app.routers import api_keys as api_keys_router
|
||||
from app.routers import sso as sso_router
|
||||
from app.routers import operational_alerts as alerts_router
|
||||
from app.domain.errors import DomainError
|
||||
|
||||
# Import scheduler, start_scheduler from app.jobs.mitre_sync_job
|
||||
@@ -58,94 +93,15 @@ from app.middleware.error_handler import domain_exception_handler
|
||||
|
||||
# Import RequestContextMiddleware from app.middleware.request_context
|
||||
from app.middleware.request_context import RequestContextMiddleware
|
||||
|
||||
# Import advanced_metrics as advanced_metrics_router from app.routers
|
||||
from app.routers import advanced_metrics as advanced_metrics_router
|
||||
|
||||
# Import analytics as analytics_router from app.routers
|
||||
from app.routers import analytics as analytics_router
|
||||
|
||||
# Import audit as audit_router from app.routers
|
||||
from app.routers import audit as audit_router
|
||||
|
||||
# Import auth as auth_router from app.routers
|
||||
from app.routers import auth as auth_router
|
||||
|
||||
# Import campaigns as campaigns_router from app.routers
|
||||
from app.routers import campaigns as campaigns_router
|
||||
|
||||
# Import compliance as compliance_router from app.routers
|
||||
from app.routers import compliance as compliance_router
|
||||
|
||||
# Import d3fend as d3fend_router from app.routers
|
||||
from app.routers import d3fend as d3fend_router
|
||||
|
||||
# Import data_sources as data_sources_router from app.routers
|
||||
from app.routers import data_sources as data_sources_router
|
||||
|
||||
# Import detection_rules as detection_rules_router from app.routers
|
||||
from app.routers import detection_rules as detection_rules_router
|
||||
|
||||
# Import evidence as evidence_router from app.routers
|
||||
from app.routers import evidence as evidence_router
|
||||
|
||||
# Import heatmap as heatmap_router from app.routers
|
||||
from app.routers import heatmap as heatmap_router
|
||||
|
||||
# Import jira as jira_router from app.routers
|
||||
from app.routers import jira as jira_router
|
||||
|
||||
# Import metrics as metrics_router from app.routers
|
||||
from app.routers import metrics as metrics_router
|
||||
|
||||
# Import notifications as notifications_router from app.routers
|
||||
from app.routers import notifications as notifications_router
|
||||
|
||||
# Import operational_metrics as operational_metrics_router from app.routers
|
||||
from app.routers import operational_metrics as operational_metrics_router
|
||||
|
||||
# Import osint as osint_router from app.routers
|
||||
from app.routers import osint as osint_router
|
||||
|
||||
# Import professional_reports as professional_reports_ro... from app.routers
|
||||
from app.routers import professional_reports as professional_reports_router
|
||||
|
||||
# Import reports as reports_router from app.routers
|
||||
from app.routers import reports as reports_router
|
||||
|
||||
# Import scores as scores_router from app.routers
|
||||
from app.routers import scores as scores_router
|
||||
|
||||
# Import snapshots as snapshots_router from app.routers
|
||||
from app.routers import snapshots as snapshots_router
|
||||
|
||||
# Import system as system_router from app.routers
|
||||
from app.routers import system as system_router
|
||||
|
||||
# Import techniques as techniques_router from app.routers
|
||||
from app.routers import techniques as techniques_router
|
||||
|
||||
# Import test_templates as test_templates_router from app.routers
|
||||
from app.routers import test_templates as test_templates_router
|
||||
|
||||
# Import tests as tests_router from app.routers
|
||||
from app.routers import tests as tests_router
|
||||
|
||||
# Import threat_actors as threat_actors_router from app.routers
|
||||
from app.routers import threat_actors as threat_actors_router
|
||||
|
||||
# Import users as users_router from app.routers
|
||||
from app.routers import users as users_router
|
||||
|
||||
# Import worklogs as worklogs_router from app.routers
|
||||
from app.routers import worklogs as worklogs_router
|
||||
|
||||
# Import ensure_bucket_exists from app.storage
|
||||
from app.storage import ensure_bucket_exists
|
||||
from app.config import settings as _settings
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
# Configure structured logging before any module initialises its own logger
|
||||
setup_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Environment detection ─────────────────────────────────────────────────
|
||||
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
|
||||
@@ -165,7 +121,25 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
ensure_bucket_exists()
|
||||
# Call start_scheduler()
|
||||
start_scheduler()
|
||||
# Yield value
|
||||
# Seed decay policies
|
||||
from app.database import SessionLocal
|
||||
from app.seed_decay_policies import seed_decay_policies
|
||||
db = SessionLocal()
|
||||
try:
|
||||
seed_decay_policies(db)
|
||||
except Exception as e:
|
||||
logger.warning("seed_decay_policies failed at startup: %s", e)
|
||||
finally:
|
||||
db.close()
|
||||
# Seed operational alert system rules
|
||||
db2 = SessionLocal()
|
||||
try:
|
||||
from app.services.operational_alert_service import seed_system_rules
|
||||
seed_system_rules(db2)
|
||||
except Exception as e:
|
||||
logger.warning("seed_system_rules failed at startup: %s", e)
|
||||
finally:
|
||||
db2.close()
|
||||
yield
|
||||
# Graceful shutdown of the background scheduler
|
||||
scheduler.shutdown(wait=False)
|
||||
@@ -193,6 +167,21 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
# Call app.add_middleware()
|
||||
app.add_middleware(RequestContextMiddleware)
|
||||
|
||||
|
||||
# ── No-cache middleware for all /api/ responses ───────────────────────────
|
||||
# Prevents Cloudflare and browser caches from storing API responses,
|
||||
# which would cause stale/empty data to be served after backend restarts.
|
||||
class NoCacheAPIMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
if request.url.path.startswith("/api/"):
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
return response
|
||||
|
||||
app.add_middleware(NoCacheAPIMiddleware)
|
||||
|
||||
|
||||
# ── Domain exception → HTTP mapping ──────────────────────────────────────
|
||||
app.add_exception_handler(DomainError, domain_exception_handler)
|
||||
|
||||
@@ -254,7 +243,8 @@ app.include_router(scores_router.router, prefix="/api/v1")
|
||||
app.include_router(operational_metrics_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(compliance_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(intel_router.router, prefix="/api/v1")
|
||||
app.include_router(admin_config_router.router, prefix="/api/v1")
|
||||
app.include_router(snapshots_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(jira_router.router, prefix="/api/v1")
|
||||
@@ -268,6 +258,16 @@ app.include_router(analytics_router.router, prefix="/api/v1")
|
||||
app.include_router(advanced_metrics_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(osint_router.router, prefix="/api/v1")
|
||||
app.include_router(webhooks_router.router, prefix="/api/v1")
|
||||
app.include_router(detection_lifecycle_router.router, prefix="/api/v1")
|
||||
app.include_router(ownership_router.router, prefix="/api/v1")
|
||||
app.include_router(attack_paths_router.router, prefix="/api/v1")
|
||||
app.include_router(knowledge_router.router, prefix="/api/v1")
|
||||
app.include_router(risk_router.router, prefix="/api/v1")
|
||||
app.include_router(dashboard_router.router, prefix="/api/v1")
|
||||
app.include_router(api_keys_router.router, prefix="/api/v1")
|
||||
app.include_router(sso_router.router, prefix="/api/v1")
|
||||
app.include_router(alerts_router.router, prefix="/api/v1")
|
||||
|
||||
|
||||
# Apply the @app.get decorator
|
||||
|
||||
@@ -1,73 +1,50 @@
|
||||
"""SQLAlchemy ORM model definitions for all database tables."""
|
||||
# Import all models here so Alembic can detect them
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
# Import Campaign, CampaignTest from app.models.campaign
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
|
||||
# Import from app.models.compliance
|
||||
from app.models.compliance import (
|
||||
ComplianceControl,
|
||||
ComplianceControlMapping,
|
||||
ComplianceFramework,
|
||||
)
|
||||
|
||||
# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
|
||||
# Import DefensiveTechnique, DefensiveTechniqueMapping from app.models.defensive_technique
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
|
||||
# Import DetectionRule from app.models.detection_rule
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
# Import TeamSide, TechniqueStatus, TestResult, TestState from app.models.enums
|
||||
from app.models.enums import TeamSide, TechniqueStatus, TestResult, TestState
|
||||
|
||||
# Import Evidence from app.models.evidence
|
||||
from app.models.evidence import Evidence
|
||||
|
||||
# Import IntelItem from app.models.intel
|
||||
from app.models.intel import IntelItem
|
||||
|
||||
# Import JiraLink, JiraLinkEntityType, JiraSyncDirection from app.models.jira_link
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||
|
||||
# Import Notification from app.models.notification
|
||||
from app.models.notification import Notification
|
||||
|
||||
# Import OsintItem from app.models.osint_item
|
||||
from app.models.osint_item import OsintItem
|
||||
|
||||
# Import ScoringConfig from app.models.scoring_config
|
||||
from app.models.scoring_config import ScoringConfig
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import TestDetectionResult from app.models.test_detection_result
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import TestTemplateDetectionRule from app.models.test_template_detection_rule
|
||||
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
|
||||
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import Worklog from app.models.worklog
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||
from app.models.worklog import Worklog
|
||||
from app.models.osint_item import OsintItem
|
||||
from app.models.scoring_config import ScoringConfig
|
||||
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
|
||||
from app.models.webhook_config import WebhookConfig
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.models.detection_lifecycle import (
|
||||
DetectionAsset, DetectionTechniqueMapping, DetectionValidation,
|
||||
TechniqueConfidenceScore, InfrastructureChangeLog,
|
||||
DetectionConfidence, DetectionHealthStatus, InvalidationReason,
|
||||
)
|
||||
from app.models.decay_policy import DecayPolicy
|
||||
from app.models.ownership_queue import (
|
||||
TechniqueOwnership, RevalidationQueueItem,
|
||||
QueuePriority, QueueStatus, QueueReason,
|
||||
)
|
||||
from app.models.attack_path import (
|
||||
AttackPath, AttackPathStep, AttackPathExecution,
|
||||
AttackPathStepResult, TimelineEntry,
|
||||
ExecutionStatus, StepResultStatus, TimelineActorSide, TimelineEntryType,
|
||||
)
|
||||
from app.models.knowledge import Playbook, PlaybookVersion, LessonLearned
|
||||
from app.models.risk_intelligence import TechniqueRiskProfile
|
||||
from app.models.executive_dashboard import PostureSnapshot
|
||||
from app.models.api_key import ApiKey
|
||||
from app.models.sso_config import SsoConfig
|
||||
from app.models.operational_alert import AlertRule, AlertInstance
|
||||
from app.models.evidence import Evidence
|
||||
from app.models.intel import IntelItem
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.user import User
|
||||
|
||||
# Assign __all__ = [
|
||||
__all__ = [
|
||||
@@ -93,4 +70,20 @@ __all__ = [
|
||||
"Worklog", "OsintItem", "ScoringConfig",
|
||||
# Literal argument value
|
||||
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
||||
"WebhookConfig", "SystemConfig",
|
||||
"DetectionAsset", "DetectionTechniqueMapping", "DetectionValidation",
|
||||
"TechniqueConfidenceScore", "InfrastructureChangeLog",
|
||||
"DetectionConfidence", "DetectionHealthStatus", "InvalidationReason", "DecayPolicy",
|
||||
"TechniqueOwnership", "RevalidationQueueItem",
|
||||
"QueuePriority", "QueueStatus", "QueueReason",
|
||||
"AttackPath", "AttackPathStep", "AttackPathExecution",
|
||||
"AttackPathStepResult", "TimelineEntry",
|
||||
"ExecutionStatus", "StepResultStatus", "TimelineActorSide", "TimelineEntryType",
|
||||
"Playbook", "PlaybookVersion", "LessonLearned",
|
||||
"TechniqueRiskProfile",
|
||||
"PostureSnapshot",
|
||||
"ApiKey",
|
||||
"SsoConfig",
|
||||
"AlertRule",
|
||||
"AlertInstance",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Phase 14: API Key model for programmatic access."""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
# ── Key generation constants ──────────────────────────────────────────────────
|
||||
KEY_PREFIX = "aegis_"
|
||||
KEY_BYTES = 32 # 32 random bytes → 64 hex chars → 70-char key total
|
||||
DISPLAY_LEN = 12 # chars stored as prefix for UI display
|
||||
|
||||
|
||||
def generate_raw_key() -> str:
|
||||
"""Generate a fresh raw API key (must be shown to user only once)."""
|
||||
return KEY_PREFIX + secrets.token_hex(KEY_BYTES)
|
||||
|
||||
|
||||
def hash_key(raw_key: str) -> str:
|
||||
"""SHA-256 hash of a raw API key for secure storage."""
|
||||
return hashlib.sha256(raw_key.encode()).hexdigest()
|
||||
|
||||
|
||||
def key_prefix_display(raw_key: str) -> str:
|
||||
"""First DISPLAY_LEN characters of the raw key (safe for UI)."""
|
||||
return raw_key[:DISPLAY_LEN]
|
||||
|
||||
|
||||
# ── Valid scopes ──────────────────────────────────────────────────────────────
|
||||
VALID_SCOPES = {"read", "write", "admin"}
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
"""
|
||||
Scoped API key for programmatic / BI / SOAR access.
|
||||
|
||||
The full raw key is **never stored** — only a SHA-256 hash.
|
||||
The first 12 characters (``key_prefix``) are retained for display.
|
||||
"""
|
||||
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(200), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Display only — never use for auth
|
||||
key_prefix = Column(String(DISPLAY_LEN + 1), nullable=False)
|
||||
|
||||
# Auth token — SHA-256 of the full raw key
|
||||
key_hash = Column(String(64), nullable=False, unique=True)
|
||||
|
||||
# Owner
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Permissions
|
||||
scopes = Column(JSONB, nullable=False, default=["read"]) # ["read","write","admin"]
|
||||
|
||||
# Lifecycle
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
is_active = Column(Boolean, nullable=False, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
user = relationship("User", foreign_keys=[user_id])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_api_keys_user_id", "user_id"),
|
||||
Index("ix_api_keys_key_hash", "key_hash"),
|
||||
Index("ix_api_keys_active", "is_active"),
|
||||
)
|
||||
@@ -0,0 +1,253 @@
|
||||
"""Phase 10: Attack Paths & Advanced Purple Team models."""
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, Column, DateTime, Enum, Float, ForeignKey,
|
||||
Index, Integer, String, Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ExecutionStatus(str, enum.Enum):
|
||||
planned = "planned"
|
||||
in_progress = "in_progress"
|
||||
completed = "completed"
|
||||
aborted = "aborted"
|
||||
|
||||
|
||||
class StepResultStatus(str, enum.Enum):
|
||||
pending = "pending"
|
||||
executing = "executing"
|
||||
detected = "detected"
|
||||
not_detected = "not_detected"
|
||||
skipped = "skipped"
|
||||
|
||||
|
||||
class TimelineActorSide(str, enum.Enum):
|
||||
red = "red"
|
||||
blue = "blue"
|
||||
system = "system"
|
||||
|
||||
|
||||
class TimelineEntryType(str, enum.Enum):
|
||||
action = "action"
|
||||
detection = "detection"
|
||||
note = "note"
|
||||
phase_transition = "phase_transition"
|
||||
flag = "flag"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AttackPath(Base):
|
||||
"""
|
||||
A reusable attack scenario composed of ordered kill-chain steps.
|
||||
Can be a template (shared) or a one-off scenario.
|
||||
"""
|
||||
__tablename__ = "attack_paths"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(300), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
objective = Column(Text, nullable=True) # what the attacker aims to achieve
|
||||
is_template = Column(Boolean, default=False) # reusable template flag
|
||||
threat_actor_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("threat_actors.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
tags = Column(JSONB, nullable=True, default=list)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
steps = relationship(
|
||||
"AttackPathStep", back_populates="attack_path",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="AttackPathStep.order_index",
|
||||
)
|
||||
executions = relationship("AttackPathExecution", back_populates="attack_path")
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
threat_actor = relationship("ThreatActor", foreign_keys=[threat_actor_id])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_attack_paths_created_by", "created_by"),
|
||||
Index("ix_attack_paths_is_template", "is_template"),
|
||||
)
|
||||
|
||||
|
||||
class AttackPathStep(Base):
|
||||
"""One step in an attack path — maps to a kill-chain phase + technique."""
|
||||
|
||||
__tablename__ = "attack_path_steps"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
attack_path_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_paths.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
order_index = Column(Integer, nullable=False, default=0)
|
||||
kill_chain_phase = Column(String(60), nullable=True) # initial_access, execution, …
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
test_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("tests.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
name = Column(String(300), nullable=True) # human label for the step
|
||||
description = Column(Text, nullable=True)
|
||||
expected_detection = Column(Boolean, default=True) # do we expect blue to detect this?
|
||||
notes = Column(Text, nullable=True)
|
||||
|
||||
attack_path = relationship("AttackPath", back_populates="steps")
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
test = relationship("Test", foreign_keys=[test_id])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_ap_steps_path_id", "attack_path_id"),
|
||||
Index("ix_ap_steps_technique_id", "technique_id"),
|
||||
)
|
||||
|
||||
|
||||
class AttackPathExecution(Base):
|
||||
"""
|
||||
A single run of an attack path.
|
||||
Tracks Red/Blue participants, timing, and aggregated kill-chain metrics.
|
||||
"""
|
||||
__tablename__ = "attack_path_executions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
attack_path_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_paths.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
status = Column(
|
||||
Enum(ExecutionStatus, name="execution_status"), nullable=False,
|
||||
default=ExecutionStatus.planned,
|
||||
)
|
||||
environment = Column(String(100), nullable=True) # prod, staging, lab
|
||||
red_team_lead = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
blue_team_lead = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
started_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
started_at = Column(DateTime, nullable=True)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
notes = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# ── Computed kill-chain metrics (written on complete) ─────────────────
|
||||
total_steps = Column(Integer, nullable=True)
|
||||
detected_steps = Column(Integer, nullable=True)
|
||||
not_detected_steps = Column(Integer, nullable=True)
|
||||
skipped_steps = Column(Integer, nullable=True)
|
||||
detection_rate = Column(Float, nullable=True) # 0.0–1.0
|
||||
mttd_seconds = Column(Float, nullable=True) # mean time to detect (avg across detected)
|
||||
furthest_undetected_step = Column(Integer, nullable=True) # order_index of deepest undetected step
|
||||
|
||||
attack_path = relationship("AttackPath", back_populates="executions")
|
||||
step_results = relationship(
|
||||
"AttackPathStepResult", back_populates="execution",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="AttackPathStepResult.step_order",
|
||||
)
|
||||
timeline = relationship(
|
||||
"TimelineEntry", back_populates="execution",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="TimelineEntry.timestamp",
|
||||
)
|
||||
red_lead_user = relationship("User", foreign_keys=[red_team_lead])
|
||||
blue_lead_user = relationship("User", foreign_keys=[blue_team_lead])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_ap_exec_path_id", "attack_path_id"),
|
||||
Index("ix_ap_exec_status", "status"),
|
||||
)
|
||||
|
||||
|
||||
class AttackPathStepResult(Base):
|
||||
"""Result of executing one step in an attack path execution."""
|
||||
|
||||
__tablename__ = "attack_path_step_results"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
execution_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_path_executions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
step_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_path_steps.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
step_order = Column(Integer, nullable=False, default=0) # denormalized for sorting
|
||||
status = Column(
|
||||
Enum(StepResultStatus, name="step_result_status"), nullable=False,
|
||||
default=StepResultStatus.pending,
|
||||
)
|
||||
executed_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
executed_at = Column(DateTime, nullable=True)
|
||||
detected_at = Column(DateTime, nullable=True)
|
||||
time_to_detect_seconds = Column(Float, nullable=True)
|
||||
detection_asset_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("detection_assets.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
notes = Column(Text, nullable=True)
|
||||
evidence_ids = Column(JSONB, nullable=True, default=list)
|
||||
|
||||
execution = relationship("AttackPathExecution", back_populates="step_results")
|
||||
step = relationship("AttackPathStep")
|
||||
detection_asset = relationship("DetectionAsset", foreign_keys=[detection_asset_id])
|
||||
executor = relationship("User", foreign_keys=[executed_by])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_ap_stepres_execution_id", "execution_id"),
|
||||
Index("ix_ap_stepres_step_id", "step_id"),
|
||||
)
|
||||
|
||||
|
||||
class TimelineEntry(Base):
|
||||
"""Timestamped Red/Blue action during an execution — used for MTTD/MTTR."""
|
||||
|
||||
__tablename__ = "attack_path_timeline"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
execution_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_path_executions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
step_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("attack_path_steps.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
timestamp = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||
actor_side = Column(
|
||||
Enum(TimelineActorSide, name="timeline_actor_side"), nullable=False,
|
||||
)
|
||||
actor_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
entry_type = Column(
|
||||
Enum(TimelineEntryType, name="timeline_entry_type"), nullable=False,
|
||||
)
|
||||
content = Column(Text, nullable=False)
|
||||
extra = Column(JSONB, nullable=True)
|
||||
|
||||
execution = relationship("AttackPathExecution", back_populates="timeline")
|
||||
actor = relationship("User", foreign_keys=[actor_id])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_timeline_execution_id", "execution_id"),
|
||||
Index("ix_timeline_timestamp", "timestamp"),
|
||||
)
|
||||
@@ -73,7 +73,7 @@ class Campaign(Base):
|
||||
# Keyword argument: nullable
|
||||
nullable=True,
|
||||
)
|
||||
# Assign scheduled_at = Column(DateTime, nullable=True)
|
||||
start_date = Column(DateTime, nullable=True) # campaign won't activate before this date
|
||||
scheduled_at = Column(DateTime, nullable=True)
|
||||
# Assign completed_at = Column(DateTime, nullable=True)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Decay Policy model — configurable detection validity rules."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, Integer, Float, Boolean, DateTime, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class DecayPolicy(Base):
|
||||
__tablename__ = "decay_policies"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(200), nullable=False)
|
||||
description = Column(Text)
|
||||
applies_to_platform = Column(String(100))
|
||||
applies_to_asset_type = Column(String(50))
|
||||
applies_to_tactic = Column(String(100))
|
||||
fresh_days = Column(Integer, default=90, server_default='90')
|
||||
aging_days = Column(Integer, default=180, server_default='180')
|
||||
stale_days = Column(Integer, default=365, server_default='365')
|
||||
default_validity_days = Column(Integer, default=180, server_default='180')
|
||||
silent_threshold_days = Column(Integer, default=30, server_default='30')
|
||||
noisy_threshold_daily = Column(Integer, default=100, server_default='100')
|
||||
recency_weight = Column(Float, default=0.3, server_default='0.3')
|
||||
coverage_weight = Column(Float, default=0.3, server_default='0.3')
|
||||
health_weight = Column(Float, default=0.25, server_default='0.25')
|
||||
diversity_weight = Column(Float, default=0.15, server_default='0.15')
|
||||
is_default = Column(Boolean, default=False, server_default='false')
|
||||
is_active = Column(Boolean, default=True, server_default='true')
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Detection Lifecycle Management models."""
|
||||
|
||||
import uuid
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from sqlalchemy import (
|
||||
Column, String, Integer, Float, Boolean, DateTime,
|
||||
ForeignKey, Text, Enum as SQLEnum
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class DetectionConfidence(str, enum.Enum):
|
||||
fresh = "fresh"
|
||||
aging = "aging"
|
||||
stale = "stale"
|
||||
broken = "broken"
|
||||
unknown = "unknown"
|
||||
|
||||
|
||||
class DetectionHealthStatus(str, enum.Enum):
|
||||
healthy = "healthy"
|
||||
silent = "silent"
|
||||
noisy = "noisy"
|
||||
orphan = "orphan"
|
||||
deprecated = "deprecated"
|
||||
untested = "untested"
|
||||
|
||||
|
||||
class InvalidationReason(str, enum.Enum):
|
||||
time_decay = "time_decay"
|
||||
mitre_update = "mitre_update"
|
||||
log_source_change = "log_source_change"
|
||||
siem_update = "siem_update"
|
||||
edr_update = "edr_update"
|
||||
infrastructure_change = "infrastructure_change"
|
||||
parser_change = "parser_change"
|
||||
manual = "manual"
|
||||
rule_modified = "rule_modified"
|
||||
|
||||
|
||||
class DetectionAsset(Base):
|
||||
__tablename__ = "detection_assets"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(500), nullable=False)
|
||||
description = Column(Text)
|
||||
asset_type = Column(String(50), nullable=False)
|
||||
platform = Column(String(100))
|
||||
rule_content = Column(Text)
|
||||
rule_language = Column(String(50))
|
||||
rule_repository_url = Column(Text)
|
||||
rule_file_path = Column(String(500))
|
||||
rule_version = Column(String(50))
|
||||
rule_hash = Column(String(64))
|
||||
last_rule_change_at = Column(DateTime)
|
||||
log_source_name = Column(String(200))
|
||||
log_source_version = Column(String(50))
|
||||
log_source_config = Column(JSONB, server_default='{}')
|
||||
infrastructure_hash = Column(String(64))
|
||||
infrastructure_details = Column(JSONB, server_default='{}')
|
||||
health_status = Column(
|
||||
SQLEnum(DetectionHealthStatus, name="detectionhealthstatus"),
|
||||
default=DetectionHealthStatus.untested,
|
||||
nullable=False,
|
||||
server_default="untested",
|
||||
)
|
||||
last_alert_at = Column(DateTime)
|
||||
alert_count_30d = Column(Integer, default=0, server_default='0')
|
||||
false_positive_rate = Column(Float)
|
||||
expected_alert_frequency = Column(String(50))
|
||||
owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
backup_owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
team = Column(String(100))
|
||||
is_active = Column(Boolean, default=True, nullable=False, server_default='true')
|
||||
tags = Column(JSONB, server_default='[]')
|
||||
asset_metadata = Column(JSONB, server_default='{}')
|
||||
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default='now()')
|
||||
updated_at = Column(DateTime(timezone=True), server_default='now()')
|
||||
|
||||
technique_mappings = relationship("DetectionTechniqueMapping", back_populates="detection_asset", cascade="all, delete-orphan")
|
||||
validations = relationship("DetectionValidation", back_populates="detection_asset", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class DetectionTechniqueMapping(Base):
|
||||
__tablename__ = "detection_technique_mappings"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
detection_asset_id = Column(UUID(as_uuid=True), ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False)
|
||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False)
|
||||
coverage_type = Column(String(50), default="detect", server_default="detect")
|
||||
confidence_level = Column(String(20), default="medium", server_default="medium")
|
||||
notes = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default='now()')
|
||||
|
||||
detection_asset = relationship("DetectionAsset", back_populates="technique_mappings")
|
||||
|
||||
|
||||
class DetectionValidation(Base):
|
||||
__tablename__ = "detection_validations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
detection_asset_id = Column(UUID(as_uuid=True), ForeignKey("detection_assets.id", ondelete="CASCADE"), nullable=False)
|
||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="SET NULL"), nullable=True)
|
||||
test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id", ondelete="SET NULL"), nullable=True)
|
||||
validated_at = Column(DateTime, default=datetime.utcnow)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
is_valid = Column(Boolean, default=True, nullable=False, server_default='true')
|
||||
validation_result = Column(String(50))
|
||||
validation_method = Column(String(100))
|
||||
rule_hash_at_validation = Column(String(64))
|
||||
log_source_version_at_validation = Column(String(50))
|
||||
infrastructure_hash_at_validation = Column(String(64))
|
||||
environment_snapshot = Column(JSONB, server_default='{}')
|
||||
invalidated_at = Column(DateTime)
|
||||
invalidation_reason = Column(SQLEnum(InvalidationReason, name="invalidationreason"), nullable=True)
|
||||
invalidation_details = Column(Text)
|
||||
invalidated_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=False)
|
||||
integrity_hash = Column(String(64))
|
||||
notes = Column(Text)
|
||||
evidence_ids = Column(JSONB, server_default='[]')
|
||||
|
||||
detection_asset = relationship("DetectionAsset", back_populates="validations")
|
||||
|
||||
|
||||
class TechniqueConfidenceScore(Base):
|
||||
__tablename__ = "technique_confidence_scores"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False, unique=True)
|
||||
confidence_level = Column(
|
||||
SQLEnum(DetectionConfidence, name="detectionconfidence"),
|
||||
default=DetectionConfidence.unknown,
|
||||
server_default="unknown",
|
||||
)
|
||||
confidence_score = Column(Float, default=0.0, server_default='0.0')
|
||||
detection_count = Column(Integer, default=0, server_default='0')
|
||||
valid_detection_count = Column(Integer, default=0, server_default='0')
|
||||
last_validated_at = Column(DateTime)
|
||||
next_validation_due = Column(DateTime)
|
||||
last_recalculated_at = Column(DateTime, default=datetime.utcnow)
|
||||
recency_factor = Column(Float, default=0.0, server_default='0.0')
|
||||
coverage_factor = Column(Float, default=0.0, server_default='0.0')
|
||||
health_factor = Column(Float, default=0.0, server_default='0.0')
|
||||
diversity_factor = Column(Float, default=0.0, server_default='0.0')
|
||||
score_breakdown = Column(JSONB, server_default='{}')
|
||||
risk_factors = Column(JSONB, server_default='[]')
|
||||
updated_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
class InfrastructureChangeLog(Base):
|
||||
__tablename__ = "infrastructure_change_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
change_type = Column(String(100), nullable=False)
|
||||
description = Column(Text, nullable=False)
|
||||
affected_platforms = Column(JSONB, server_default='[]')
|
||||
affected_log_sources = Column(JSONB, server_default='[]')
|
||||
change_date = Column(DateTime, default=datetime.utcnow)
|
||||
reported_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
auto_invalidate = Column(Boolean, default=True, server_default='true')
|
||||
invalidated_count = Column(Integer, default=0, server_default='0')
|
||||
change_metadata = Column(JSONB, server_default='{}')
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,34 @@
|
||||
"""SQLAlchemy model for tracking imported ATT&CK Evaluation rounds."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, ForeignKey, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class EvaluationImport(Base):
|
||||
"""Tracks which ATT&CK Evaluation rounds have been imported into the platform.
|
||||
|
||||
Each row represents one vendor+adversary combination that has been processed
|
||||
and turned into Test records. Used to avoid duplicate imports and to show
|
||||
the admin panel which rounds are available vs imported.
|
||||
"""
|
||||
__tablename__ = "evaluation_imports"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
adversary_name = Column(String, nullable=False) # "apt29", "turla"
|
||||
adversary_display = Column(String, nullable=False) # "APT29", "Turla"
|
||||
eval_round = Column(Integer, nullable=False) # 1, 2, 3 …
|
||||
imported_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||
imported_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
tests_created = Column(Integer, default=0)
|
||||
techniques_covered = Column(Integer, default=0)
|
||||
status = Column(String, default="completed") # "completed" | "failed"
|
||||
notes = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_evaluation_imports_adversary", "adversary_name"),
|
||||
Index("ix_evaluation_imports_round", "eval_round"),
|
||||
)
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Phase 13: Executive Dashboard — PostureSnapshot model."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, Date, DateTime, Float, ForeignKey,
|
||||
Index, Integer, UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class PostureSnapshot(Base):
|
||||
"""
|
||||
Daily point-in-time capture of the organisation's security posture.
|
||||
|
||||
Aggregates data from all phases (coverage, risk, ownership, knowledge,
|
||||
attack-paths) into a single row that can be trended over time.
|
||||
"""
|
||||
|
||||
__tablename__ = "posture_snapshots"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
snapshot_date = Column(Date, nullable=False) # one per calendar day
|
||||
|
||||
# ── Coverage ──────────────────────────────────────────────────────────────
|
||||
total_techniques = Column(Integer, nullable=False, default=0)
|
||||
validated_count = Column(Integer, nullable=False, default=0)
|
||||
partial_count = Column(Integer, nullable=False, default=0)
|
||||
not_covered_count = Column(Integer, nullable=False, default=0)
|
||||
coverage_pct = Column(Float, nullable=False, default=0.0) # 0–100
|
||||
|
||||
# ── Risk ─────────────────────────────────────────────────────────────────
|
||||
avg_risk_score = Column(Float, nullable=False, default=0.0)
|
||||
critical_count = Column(Integer, nullable=False, default=0)
|
||||
high_count = Column(Integer, nullable=False, default=0)
|
||||
medium_count = Column(Integer, nullable=False, default=0)
|
||||
low_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# ── Operations ────────────────────────────────────────────────────────────
|
||||
open_queue_items = Column(Integer, nullable=False, default=0)
|
||||
orphan_techniques = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# ── Knowledge ─────────────────────────────────────────────────────────────
|
||||
playbook_count = Column(Integer, nullable=False, default=0)
|
||||
lesson_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# ── MTTD (from attack-path executions completed in last 30 d) ────────────
|
||||
mttd_avg_seconds = Column(Float, nullable=True) # None if no data
|
||||
executions_30d = Column(Integer, nullable=False, default=0)
|
||||
detection_rate_30d = Column(Float, nullable=True) # avg across executions
|
||||
|
||||
# ── Meta ─────────────────────────────────────────────────────────────────
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
extra = Column(JSONB, nullable=True) # full breakdown / by-tactic
|
||||
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("snapshot_date", name="uq_posture_snapshot_date"),
|
||||
Index("ix_posture_snapshots_date", "snapshot_date"),
|
||||
)
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Phase 11: Knowledge Management models — Playbooks + Lessons Learned."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, Column, DateTime, ForeignKey,
|
||||
Index, Integer, String, Text, UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# ── Playbooks ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class Playbook(Base):
|
||||
"""
|
||||
Structured runbook for a specific technique and playbook type.
|
||||
|
||||
playbook_type: attack | detect | investigate | respond | hunt
|
||||
One playbook per (technique, type). Edits increment ``version``
|
||||
and save a snapshot to ``PlaybookVersion``.
|
||||
"""
|
||||
__tablename__ = "playbooks"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("techniques.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
playbook_type = Column(String(32), nullable=False) # attack/detect/investigate/respond/hunt
|
||||
title = Column(String(255), nullable=False)
|
||||
content = Column(Text, nullable=False, default="")
|
||||
version = Column(Integer, default=1, nullable=False)
|
||||
tools = Column(JSONB, default=list) # list of tool name strings
|
||||
prerequisites = Column(JSONB, default=list) # list of prerequisite strings
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
updated_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Relationships
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
updater = relationship("User", foreign_keys=[updated_by])
|
||||
versions = relationship(
|
||||
"PlaybookVersion", back_populates="playbook",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="PlaybookVersion.version.desc()",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("technique_id", "playbook_type", name="uq_playbook_technique_type"),
|
||||
Index("ix_playbooks_technique_id", "technique_id"),
|
||||
Index("ix_playbooks_type", "playbook_type"),
|
||||
)
|
||||
|
||||
|
||||
class PlaybookVersion(Base):
|
||||
"""Immutable snapshot of a playbook at a given version number."""
|
||||
|
||||
__tablename__ = "playbook_versions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
playbook_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
version = Column(Integer, nullable=False)
|
||||
title = Column(String(255), nullable=False)
|
||||
content = Column(Text, nullable=False, default="")
|
||||
tools = Column(JSONB, default=list)
|
||||
prerequisites = Column(JSONB, default=list)
|
||||
changed_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
change_note = Column(String(500), nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
playbook = relationship("Playbook", back_populates="versions")
|
||||
changer = relationship("User", foreign_keys=[changed_by])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_pb_versions_playbook_id", "playbook_id"),
|
||||
Index("ix_pb_versions_version", "playbook_id", "version"),
|
||||
)
|
||||
|
||||
|
||||
# ── Lessons Learned ────────────────────────────────────────────────────────────
|
||||
|
||||
class LessonLearned(Base):
|
||||
"""
|
||||
Immutable post-mortem record linked to a test, campaign, attack-path or
|
||||
created manually.
|
||||
|
||||
severity: critical | high | medium | low | info
|
||||
entity_type: test | campaign | attack_path | manual
|
||||
"""
|
||||
__tablename__ = "lessons_learned"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
title = Column(String(255), nullable=False)
|
||||
what_happened = Column(Text, nullable=False, default="")
|
||||
root_cause = Column(Text, nullable=False, default="")
|
||||
fix_applied = Column(Text, nullable=True)
|
||||
severity = Column(String(16), nullable=False, default="medium")
|
||||
entity_type = Column(String(32), nullable=False, default="manual")
|
||||
entity_id = Column(UUID(as_uuid=True), nullable=True)
|
||||
technique_ids = Column(JSONB, default=list) # list of UUID strings
|
||||
tags = Column(JSONB, default=list)
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
is_active = Column(Boolean, default=True) # soft-delete (admin only)
|
||||
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_ll_entity", "entity_type", "entity_id"),
|
||||
Index("ix_ll_severity", "severity"),
|
||||
Index("ix_ll_created_by", "created_by"),
|
||||
)
|
||||
@@ -0,0 +1,144 @@
|
||||
"""Phase 13: Operational Alerts — AlertRule and AlertInstance models."""
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, Column, DateTime, ForeignKey,
|
||||
Index, Integer, String, Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# ── Enumerations ──────────────────────────────────────────────────────────────
|
||||
|
||||
class AlertSeverity(str, enum.Enum):
|
||||
critical = "critical"
|
||||
high = "high"
|
||||
medium = "medium"
|
||||
low = "low"
|
||||
info = "info"
|
||||
|
||||
|
||||
class AlertStatus(str, enum.Enum):
|
||||
open = "open"
|
||||
acknowledged = "acknowledged"
|
||||
resolved = "resolved"
|
||||
dismissed = "dismissed"
|
||||
|
||||
|
||||
class AlertRuleType(str, enum.Enum):
|
||||
high_risk = "high_risk" # risk_score >= threshold
|
||||
stale_technique = "stale_technique" # not validated in N days
|
||||
coverage_regression = "coverage_regression" # coverage_pct dropped
|
||||
low_coverage = "low_coverage" # coverage below min
|
||||
expiry_wave = "expiry_wave" # many pending queue items
|
||||
new_technique = "new_technique" # new MITRE techniques added
|
||||
orphan_spike = "orphan_spike" # many unowned techniques
|
||||
custom = "custom" # future extension placeholder
|
||||
|
||||
|
||||
# ── AlertRule ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class AlertRule(Base):
|
||||
"""
|
||||
Defines a condition that, when satisfied, fires an AlertInstance.
|
||||
|
||||
System rules (is_system=True) are seeded at startup and cannot be deleted.
|
||||
Custom rules (is_system=False) can be created by admins.
|
||||
"""
|
||||
|
||||
__tablename__ = "alert_rules"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(300), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
rule_type = Column(String(50), nullable=False)
|
||||
severity = Column(String(20), nullable=False, default=AlertSeverity.medium.value)
|
||||
is_enabled = Column(Boolean, nullable=False, default=True)
|
||||
is_system = Column(Boolean, nullable=False, default=False) # seeded, not deletable
|
||||
|
||||
# Rule-specific thresholds/config (varies by rule_type)
|
||||
config = Column(JSONB, nullable=False, default={})
|
||||
|
||||
# Delivery
|
||||
notify_in_app = Column(Boolean, nullable=False, default=True)
|
||||
notify_webhook = Column(Boolean, nullable=False, default=False)
|
||||
webhook_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("webhook_configs.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Cooldown — don't re-fire within N hours of last firing
|
||||
cooldown_hours = Column(Integer, nullable=False, default=24)
|
||||
|
||||
# Meta
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_fired_at = Column(DateTime, nullable=True)
|
||||
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
instances = relationship("AlertInstance", back_populates="rule",
|
||||
cascade="all, delete-orphan")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_alert_rules_type", "rule_type"),
|
||||
Index("ix_alert_rules_enabled", "is_enabled"),
|
||||
)
|
||||
|
||||
|
||||
# ── AlertInstance ─────────────────────────────────────────────────────────────
|
||||
|
||||
class AlertInstance(Base):
|
||||
"""
|
||||
A single firing of an AlertRule.
|
||||
|
||||
Transitions: open → acknowledged → resolved
|
||||
open → dismissed
|
||||
"""
|
||||
|
||||
__tablename__ = "alert_instances"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
rule_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("alert_rules.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
# Denormalised fields kept for history even after rule deletion
|
||||
rule_name = Column(String(300), nullable=False)
|
||||
rule_type = Column(String(50), nullable=False)
|
||||
severity = Column(String(20), nullable=False)
|
||||
|
||||
title = Column(String(500), nullable=False)
|
||||
message = Column(Text, nullable=False)
|
||||
details = Column(JSONB, nullable=True) # structured context
|
||||
|
||||
status = Column(String(20), nullable=False, default=AlertStatus.open.value)
|
||||
acknowledged_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
acknowledged_at = Column(DateTime, nullable=True)
|
||||
resolved_at = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
rule = relationship("AlertRule", back_populates="instances")
|
||||
acknowledger = relationship("User", foreign_keys=[acknowledged_by])
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_alert_instances_rule_id", "rule_id"),
|
||||
Index("ix_alert_instances_status", "status"),
|
||||
Index("ix_alert_instances_severity", "severity"),
|
||||
Index("ix_alert_instances_created", "created_at"),
|
||||
)
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Phase 9: Ownership & Revalidation Queue models."""
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class QueuePriority(str, enum.Enum):
|
||||
critical = "critical"
|
||||
high = "high"
|
||||
medium = "medium"
|
||||
low = "low"
|
||||
|
||||
|
||||
class QueueStatus(str, enum.Enum):
|
||||
pending = "pending"
|
||||
in_progress = "in_progress"
|
||||
completed = "completed"
|
||||
dismissed = "dismissed"
|
||||
|
||||
|
||||
class QueueReason(str, enum.Enum):
|
||||
validation_expired = "validation_expired"
|
||||
infra_change = "infra_change"
|
||||
osint_alert = "osint_alert"
|
||||
mitre_update = "mitre_update"
|
||||
rule_modified = "rule_modified"
|
||||
low_confidence = "low_confidence"
|
||||
manual = "manual"
|
||||
|
||||
|
||||
class TechniqueOwnership(Base):
|
||||
"""Ownership assignment for a MITRE technique."""
|
||||
|
||||
__tablename__ = "technique_ownerships"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
)
|
||||
owner_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
backup_owner_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
team = Column(String(200), nullable=True)
|
||||
notes = Column(Text, nullable=True)
|
||||
assigned_at = Column(DateTime, nullable=True)
|
||||
assigned_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
owner = relationship("User", foreign_keys=[owner_id])
|
||||
backup_owner = relationship("User", foreign_keys=[backup_owner_id])
|
||||
|
||||
|
||||
class RevalidationQueueItem(Base):
|
||||
"""A prioritised work item for the analyst's daily queue."""
|
||||
|
||||
__tablename__ = "revalidation_queue_items"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
)
|
||||
detection_asset_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("detection_assets.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
priority = Column(
|
||||
Enum(QueuePriority, name="queue_priority"),
|
||||
nullable=False,
|
||||
default=QueuePriority.medium,
|
||||
)
|
||||
reason = Column(
|
||||
Enum(QueueReason, name="queue_reason"),
|
||||
nullable=False,
|
||||
)
|
||||
reason_detail = Column(Text, nullable=True)
|
||||
status = Column(
|
||||
Enum(QueueStatus, name="queue_status"),
|
||||
nullable=False,
|
||||
default=QueueStatus.pending,
|
||||
)
|
||||
|
||||
assigned_to = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
due_date = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
dismissed_at = Column(DateTime, nullable=True)
|
||||
completed_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
extra = Column(JSONB, nullable=True) # arbitrary metadata
|
||||
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
detection_asset = relationship("DetectionAsset", foreign_keys=[detection_asset_id])
|
||||
assignee = relationship("User", foreign_keys=[assigned_to])
|
||||
|
||||
|
||||
# Indexes
|
||||
Index("ix_rqueue_status", RevalidationQueueItem.status)
|
||||
Index("ix_rqueue_priority", RevalidationQueueItem.priority)
|
||||
Index("ix_rqueue_assigned_to", RevalidationQueueItem.assigned_to)
|
||||
Index("ix_rqueue_technique_id", RevalidationQueueItem.technique_id)
|
||||
Index("ix_rqueue_asset_id", RevalidationQueueItem.detection_asset_id)
|
||||
Index("ix_techown_owner_id", TechniqueOwnership.owner_id)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Phase 12: Risk Intelligence model — per-technique risk scoring."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, Column, DateTime, Float, ForeignKey,
|
||||
Index, Integer, String, UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class TechniqueRiskProfile(Base):
|
||||
"""
|
||||
Aggregated risk profile for one technique.
|
||||
|
||||
Combines four weighted factors:
|
||||
• detection_gap (35 %) — 0=fully covered → 1=no coverage
|
||||
• threat_actor_rel (30 %) — normalised actor count
|
||||
• osint_signals (20 %) — normalised recent OSINT items (30 d)
|
||||
• test_failure_rate (15 %) — proportion of tests where blue didn't detect
|
||||
|
||||
risk_score = weighted sum × 100 → 0–100
|
||||
risk_level: critical ≥75 | high ≥50 | medium ≥25 | low ≥10 | info
|
||||
"""
|
||||
|
||||
__tablename__ = "technique_risk_profiles"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# ── Computed scores ───────────────────────────────────────────────────────
|
||||
risk_score = Column(Float, nullable=False, default=0.0) # 0–100
|
||||
likelihood = Column(Float, nullable=False, default=0.0) # 0–100
|
||||
impact = Column(Float, nullable=False, default=0.0) # 0–100
|
||||
risk_level = Column(String(16), nullable=False, default="info")
|
||||
|
||||
# ── Raw factor values ─────────────────────────────────────────────────────
|
||||
detection_gap = Column(Float, nullable=False, default=1.0) # 0–1
|
||||
threat_actor_count = Column(Integer, nullable=False, default=0)
|
||||
osint_signal_count = Column(Integer, nullable=False, default=0) # last 30 d
|
||||
test_fail_count = Column(Integer, nullable=False, default=0)
|
||||
test_total_count = Column(Integer, nullable=False, default=0)
|
||||
test_failure_rate = Column(Float, nullable=False, default=0.0) # 0–1
|
||||
confidence_level = Column(Float, nullable=False, default=0.0) # DLC 0–1
|
||||
|
||||
# ── Rich detail ──────────────────────────────────────────────────────────
|
||||
scoring_breakdown = Column(JSONB, nullable=True) # per-factor contributions
|
||||
recommendations = Column(JSONB, nullable=True) # list[str]
|
||||
|
||||
# ── Meta ─────────────────────────────────────────────────────────────────
|
||||
computed_at = Column(DateTime, default=datetime.utcnow)
|
||||
is_stale = Column(Boolean, default=True)
|
||||
|
||||
technique = relationship("Technique", foreign_keys=[technique_id])
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("technique_id", name="uq_risk_profile_technique"),
|
||||
Index("ix_risk_profiles_risk_score", "risk_score"),
|
||||
Index("ix_risk_profiles_risk_level", "risk_level"),
|
||||
Index("ix_risk_profiles_stale", "is_stale"),
|
||||
)
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Phase 14: SSO / SAML 2.0 configuration model."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class SsoConfig(Base):
|
||||
"""
|
||||
SAML 2.0 Identity Provider configuration.
|
||||
|
||||
Exactly one row is expected (use upsert). The SP metadata endpoint
|
||||
reads from this row to generate XML for IdP registration.
|
||||
"""
|
||||
|
||||
__tablename__ = "sso_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
is_enabled = Column(Boolean, nullable=False, default=False)
|
||||
provider_name = Column(String(200), nullable=True) # e.g., "Okta", "Azure AD"
|
||||
|
||||
# ── Service Provider (Aegis) settings ────────────────────────────────────
|
||||
sp_entity_id = Column(String(500), nullable=True) # e.g., https://aegis.co/api/v1/sso/metadata
|
||||
sp_acs_url = Column(String(500), nullable=True) # Assertion Consumer Service URL
|
||||
sp_slo_url = Column(String(500), nullable=True) # Single Logout URL (optional)
|
||||
sp_certificate = Column(Text, nullable=True) # SP public cert for signed requests
|
||||
sp_private_key = Column(Text, nullable=True) # SP private key (stored encrypted in future)
|
||||
|
||||
# ── Identity Provider settings ────────────────────────────────────────────
|
||||
idp_entity_id = Column(String(500), nullable=True)
|
||||
idp_sso_url = Column(String(500), nullable=True) # IdP redirect/POST binding URL
|
||||
idp_slo_url = Column(String(500), nullable=True) # IdP SLO URL
|
||||
idp_certificate = Column(Text, nullable=True) # IdP X.509 cert for response validation
|
||||
|
||||
# ── Attribute mapping ─────────────────────────────────────────────────────
|
||||
# SAML attribute name → Aegis field
|
||||
attr_email = Column(String(200), nullable=True, default="email")
|
||||
attr_username = Column(String(200), nullable=True, default="username")
|
||||
attr_role = Column(String(200), nullable=True, default="role")
|
||||
default_role = Column(String(50), nullable=True, default="viewer")
|
||||
auto_provision = Column(Boolean, nullable=False, default=True) # create user on first login
|
||||
|
||||
# ── Meta ─────────────────────────────────────────────────────────────────
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""SystemConfig model — runtime key-value configuration store."""
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class SystemConfig(Base):
|
||||
"""Generic key-value store for runtime system configuration.
|
||||
|
||||
Currently used for:
|
||||
- SMTP email settings (overrides .env values when present)
|
||||
|
||||
Keys are namespaced by convention: ``smtp.host``, ``smtp.port``, etc.
|
||||
"""
|
||||
|
||||
__tablename__ = "system_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
key = Column(String(200), unique=True, nullable=False, index=True)
|
||||
value = Column(Text, nullable=True)
|
||||
description = Column(String(500), nullable=True)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
@@ -96,7 +96,7 @@ class Test(Base):
|
||||
red_started_at = Column(DateTime, nullable=True)
|
||||
# Assign blue_started_at = Column(DateTime, nullable=True)
|
||||
blue_started_at = Column(DateTime, nullable=True)
|
||||
# Assign paused_at = Column(DateTime, nullable=True)
|
||||
blue_work_started_at = Column(DateTime, nullable=True) # when blue tech picks up (Tempo start)
|
||||
paused_at = Column(DateTime, nullable=True)
|
||||
# Assign red_paused_seconds = Column(Integer, default=0)
|
||||
red_paused_seconds = Column(Integer, default=0)
|
||||
|
||||
@@ -2,12 +2,8 @@
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Boolean, Column, DateTime, String, func from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, DateTime, String, func
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
@@ -46,3 +42,8 @@ class User(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
# Assign last_login = Column(DateTime, nullable=True)
|
||||
last_login = Column(DateTime, nullable=True)
|
||||
notification_preferences = Column(JSONB, nullable=True, server_default='{"email_on_test_validated": true, "email_on_campaign_completed": true, "email_on_new_mitre_techniques": false, "in_app_all": true}')
|
||||
jira_account_id = Column(String(100), nullable=True)
|
||||
jira_api_token = Column(String(500), nullable=True) # personal Atlassian token
|
||||
jira_email = Column(String(255), nullable=True) # Atlassian email (overrides account email)
|
||||
tempo_api_token = Column(String(500), nullable=True) # personal Tempo API token
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
"""WebhookConfig model — outbound HTTP notification endpoints."""
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text, ForeignKey, func
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from app.database import Base
|
||||
|
||||
class WebhookConfig(Base):
|
||||
__tablename__ = "webhook_configs"
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(200), nullable=False)
|
||||
url = Column(Text, nullable=False)
|
||||
secret = Column(String(256), nullable=True) # HMAC signature key
|
||||
events = Column(JSONB, nullable=False, server_default="[]") # list of event types
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
last_triggered_at = Column(DateTime, nullable=True)
|
||||
failure_count = Column(Integer, default=0, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
@@ -0,0 +1,339 @@
|
||||
"""Admin configuration export/import — single-file migration bundle.
|
||||
|
||||
GET /admin/export-config — download JSON bundle (admin only)
|
||||
POST /admin/import-config — upload JSON bundle and restore (admin only)
|
||||
|
||||
What is exported (and what is NOT):
|
||||
✓ system_configs — email / jira settings (passwords REDACTED)
|
||||
✓ webhook_configs — notification webhooks (secrets REDACTED)
|
||||
✓ sso_configs — SAML/SSO config (private keys REDACTED)
|
||||
✓ scoring_config — technique scoring weights
|
||||
✓ test_templates — CUSTOM templates only (source='custom')
|
||||
✓ users — username / email / role (no passwords / tokens)
|
||||
✗ atomic/sigma/elastic templates, techniques, tests, campaigns, reports
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth import hash_password
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.models.scoring_config import ScoringConfig
|
||||
from app.models.sso_config import SsoConfig
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.user import User
|
||||
from app.models.webhook_config import WebhookConfig
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
# Keys whose values contain secrets and must be redacted in the export
|
||||
_REDACTED_KEYS = {
|
||||
"smtp.password",
|
||||
"jira.api_token",
|
||||
"jira.password",
|
||||
"tempo.api_token",
|
||||
}
|
||||
|
||||
_EXPORT_VERSION = "1.0"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _redact(key: str, value: Any) -> Any:
|
||||
if key in _REDACTED_KEYS:
|
||||
return "[REDACTED]"
|
||||
return value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /admin/export-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/export-config")
|
||||
def export_config(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Export all platform configuration as a downloadable JSON bundle."""
|
||||
|
||||
# ── 1. system_configs ────────────────────────────────────────────
|
||||
system_configs = [
|
||||
{
|
||||
"key": r.key,
|
||||
"value": _redact(r.key, r.value),
|
||||
"description": r.description,
|
||||
}
|
||||
for r in db.query(SystemConfig).order_by(SystemConfig.key).all()
|
||||
]
|
||||
|
||||
# ── 2. webhook_configs ───────────────────────────────────────────
|
||||
webhooks = [
|
||||
{
|
||||
"name": w.name,
|
||||
"url": w.url,
|
||||
"secret": "[REDACTED]" if w.secret else None,
|
||||
"events": w.events or [],
|
||||
"is_active": w.is_active,
|
||||
}
|
||||
for w in db.query(WebhookConfig).order_by(WebhookConfig.name).all()
|
||||
]
|
||||
|
||||
# ── 3. SSO config (single row) ───────────────────────────────────
|
||||
sso_row = db.query(SsoConfig).first()
|
||||
sso = None
|
||||
if sso_row:
|
||||
sso = {
|
||||
"is_enabled": sso_row.is_enabled,
|
||||
"provider_name": sso_row.provider_name,
|
||||
"sp_entity_id": sso_row.sp_entity_id,
|
||||
"sp_acs_url": sso_row.sp_acs_url,
|
||||
"sp_slo_url": sso_row.sp_slo_url,
|
||||
"sp_certificate": sso_row.sp_certificate,
|
||||
"sp_private_key": "[REDACTED]", # never export private keys
|
||||
"idp_entity_id": sso_row.idp_entity_id,
|
||||
"idp_sso_url": getattr(sso_row, "idp_sso_url", None),
|
||||
"idp_slo_url": getattr(sso_row, "idp_slo_url", None),
|
||||
"idp_certificate": getattr(sso_row, "idp_certificate", None),
|
||||
"attr_email": getattr(sso_row, "attr_email", None),
|
||||
"attr_username": getattr(sso_row, "attr_username", None),
|
||||
"attr_role": getattr(sso_row, "attr_role", None),
|
||||
"default_role": getattr(sso_row, "default_role", None),
|
||||
"auto_provision": getattr(sso_row, "auto_provision", False),
|
||||
}
|
||||
|
||||
# ── 4. Scoring config (single row) ──────────────────────────────
|
||||
sc = db.query(ScoringConfig).first()
|
||||
scoring = None
|
||||
if sc:
|
||||
scoring = {
|
||||
"weight_tests": sc.weight_tests,
|
||||
"weight_detection_rules": sc.weight_detection_rules,
|
||||
"weight_d3fend": sc.weight_d3fend,
|
||||
"weight_recency": sc.weight_recency,
|
||||
"weight_severity": sc.weight_severity,
|
||||
}
|
||||
|
||||
# ── 5. Custom test templates only ───────────────────────────────
|
||||
templates = [
|
||||
{
|
||||
"mitre_technique_id": t.mitre_technique_id,
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"source": t.source,
|
||||
"source_url": t.source_url,
|
||||
"attack_procedure": t.attack_procedure,
|
||||
"expected_detection": t.expected_detection,
|
||||
"platform": t.platform,
|
||||
"tool_suggested": t.tool_suggested,
|
||||
"severity": t.severity,
|
||||
"suggested_remediation": t.suggested_remediation,
|
||||
"is_active": t.is_active,
|
||||
}
|
||||
for t in db.query(TestTemplate).filter(TestTemplate.source == "custom").all()
|
||||
]
|
||||
|
||||
# ── 6. Users (sanitized — no passwords/tokens) ───────────────────
|
||||
users = [
|
||||
{
|
||||
"username": u.username,
|
||||
"email": u.email if hasattr(u, "email") else None,
|
||||
"role": u.role,
|
||||
"is_active": u.is_active,
|
||||
"must_change_password": True, # force password reset on new instance # nosec B105
|
||||
}
|
||||
for u in db.query(User).order_by(User.username).all()
|
||||
]
|
||||
|
||||
bundle = {
|
||||
"_meta": {
|
||||
"version": _EXPORT_VERSION,
|
||||
"exported_at": datetime.utcnow().isoformat() + "Z",
|
||||
"exported_by": current_user.username,
|
||||
"note": (
|
||||
"Sensitive values (passwords, API tokens, private keys) are REDACTED. "
|
||||
"Re-enter them manually after import. "
|
||||
"User passwords are NOT exported — users must reset passwords on first login."
|
||||
),
|
||||
},
|
||||
"system_configs": system_configs,
|
||||
"webhooks": webhooks,
|
||||
"sso": sso,
|
||||
"scoring": scoring,
|
||||
"custom_templates": templates,
|
||||
"users": users,
|
||||
}
|
||||
|
||||
filename = f"aegis-config-{datetime.utcnow().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
return JSONResponse(
|
||||
content=bundle,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"X-Export-Version": _EXPORT_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /admin/import-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/import-config")
|
||||
async def import_config(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Restore platform configuration from a previously exported JSON bundle.
|
||||
|
||||
Idempotent: safe to run multiple times. Existing records are updated,
|
||||
missing ones are created. REDACTED values are skipped (left as-is).
|
||||
User passwords are set to a random temp value with must_change_password=True.
|
||||
"""
|
||||
try:
|
||||
bundle = await request.json()
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
||||
|
||||
meta = bundle.get("_meta", {})
|
||||
version = meta.get("version", "unknown")
|
||||
summary: dict[str, int] = {
|
||||
"system_configs": 0,
|
||||
"webhooks": 0,
|
||||
"custom_templates": 0,
|
||||
"users_created": 0,
|
||||
"users_updated": 0,
|
||||
}
|
||||
|
||||
# ── 1. system_configs ────────────────────────────────────────────
|
||||
for item in bundle.get("system_configs", []):
|
||||
key = item.get("key")
|
||||
value = item.get("value")
|
||||
if not key or value == "[REDACTED]":
|
||||
continue
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if row:
|
||||
row.value = value
|
||||
row.description = item.get("description") or row.description
|
||||
else:
|
||||
db.add(SystemConfig(key=key, value=value, description=item.get("description")))
|
||||
summary["system_configs"] += 1
|
||||
|
||||
# ── 2. webhooks ──────────────────────────────────────────────────
|
||||
for item in bundle.get("webhooks", []):
|
||||
name = item.get("name")
|
||||
url = item.get("url")
|
||||
if not name or not url:
|
||||
continue
|
||||
existing = db.query(WebhookConfig).filter(WebhookConfig.name == name).first()
|
||||
if existing:
|
||||
existing.url = url
|
||||
existing.events = item.get("events", [])
|
||||
existing.is_active = item.get("is_active", True)
|
||||
existing.failure_count = 0
|
||||
else:
|
||||
db.add(WebhookConfig(
|
||||
name=name,
|
||||
url=url,
|
||||
secret=None, # never restore secrets
|
||||
events=item.get("events", []),
|
||||
is_active=item.get("is_active", True),
|
||||
created_by=current_user.id,
|
||||
failure_count=0,
|
||||
))
|
||||
summary["webhooks"] += 1
|
||||
|
||||
# ── 3. SSO config ────────────────────────────────────────────────
|
||||
sso_data = bundle.get("sso")
|
||||
if sso_data:
|
||||
sso_row = db.query(SsoConfig).first()
|
||||
if sso_row:
|
||||
for field, val in sso_data.items():
|
||||
if val == "[REDACTED]":
|
||||
continue
|
||||
if hasattr(sso_row, field):
|
||||
setattr(sso_row, field, val)
|
||||
else:
|
||||
clean = {k: v for k, v in sso_data.items() if v != "[REDACTED]"}
|
||||
clean.pop("sp_private_key", None)
|
||||
db.add(SsoConfig(**clean))
|
||||
|
||||
# ── 4. Scoring config ────────────────────────────────────────────
|
||||
scoring_data = bundle.get("scoring")
|
||||
if scoring_data:
|
||||
sc = db.query(ScoringConfig).first()
|
||||
if sc:
|
||||
for field, val in scoring_data.items():
|
||||
if hasattr(sc, field) and val is not None:
|
||||
setattr(sc, field, val)
|
||||
else:
|
||||
db.add(ScoringConfig(**scoring_data))
|
||||
|
||||
# ── 5. Custom templates ──────────────────────────────────────────
|
||||
for item in bundle.get("custom_templates", []):
|
||||
name = item.get("name")
|
||||
mitre_id = item.get("mitre_technique_id")
|
||||
if not name or not mitre_id:
|
||||
continue
|
||||
existing = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.name == name, TestTemplate.source == "custom")
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
for field, val in item.items():
|
||||
if hasattr(existing, field):
|
||||
setattr(existing, field, val)
|
||||
else:
|
||||
db.add(TestTemplate(**{k: v for k, v in item.items()
|
||||
if k not in ("id", "created_at")}))
|
||||
summary["custom_templates"] += 1
|
||||
|
||||
# ── 6. Users ─────────────────────────────────────────────────────
|
||||
import secrets as _secrets
|
||||
for item in bundle.get("users", []):
|
||||
username = item.get("username")
|
||||
if not username:
|
||||
continue
|
||||
existing = db.query(User).filter(User.username == username).first()
|
||||
if existing:
|
||||
existing.role = item.get("role", existing.role)
|
||||
existing.is_active = item.get("is_active", existing.is_active)
|
||||
summary["users_updated"] += 1
|
||||
else:
|
||||
# Create with random temp password — user must reset on login
|
||||
temp_pw = _secrets.token_urlsafe(16) + "Aa1!"
|
||||
new_user = User(
|
||||
username=username,
|
||||
hashed_password=hash_password(temp_pw),
|
||||
role=item.get("role", "viewer"),
|
||||
is_active=item.get("is_active", True),
|
||||
must_change_password=True,
|
||||
)
|
||||
if item.get("email") and hasattr(User, "email"):
|
||||
new_user.email = item["email"]
|
||||
db.add(new_user)
|
||||
summary["users_created"] += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"imported_from_version": version,
|
||||
"summary": summary,
|
||||
"warnings": [
|
||||
"REDACTED values were skipped — re-enter passwords/tokens manually.",
|
||||
"All imported users have must_change_password=True.",
|
||||
"SSO private key was not restored — re-upload it manually.",
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Phase 14: API Key management router."""
|
||||
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
from app.schemas.api_key_schema import (
|
||||
ApiKeyCreate, ApiKeyCreated, ApiKeyOut, ApiKeyUpdate,
|
||||
)
|
||||
import app.services.api_key_service as svc
|
||||
|
||||
router = APIRouter(prefix="/api-keys", tags=["API Keys"])
|
||||
|
||||
|
||||
@router.post("", response_model=ApiKeyCreated, status_code=201)
|
||||
def create_key(
|
||||
body: ApiKeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Create a scoped API key.
|
||||
|
||||
The ``raw_key`` field in the response is shown **exactly once** and
|
||||
cannot be retrieved later. Store it securely.
|
||||
"""
|
||||
key, raw_key = svc.create_api_key(
|
||||
db,
|
||||
user_id = current_user.id,
|
||||
name = body.name,
|
||||
scopes = body.scopes,
|
||||
description = body.description,
|
||||
expires_at = body.expires_at,
|
||||
)
|
||||
out = ApiKeyOut.model_validate(key)
|
||||
return ApiKeyCreated(**out.model_dump(), raw_key=raw_key)
|
||||
|
||||
|
||||
@router.get("", response_model=List[ApiKeyOut])
|
||||
def list_keys(
|
||||
include_inactive: bool = Query(False),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List API keys owned by the current user."""
|
||||
# Admins can see all keys; others only see their own
|
||||
user_id = None if current_user.role == "admin" else current_user.id
|
||||
return svc.list_api_keys(db, user_id=user_id, include_inactive=include_inactive)
|
||||
|
||||
|
||||
@router.get("/{key_id}", response_model=ApiKeyOut)
|
||||
def get_key(
|
||||
key_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a single API key (owner or admin)."""
|
||||
user_id = None if current_user.role == "admin" else current_user.id
|
||||
return svc.get_api_key(db, key_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.patch("/{key_id}", response_model=ApiKeyOut)
|
||||
def update_key(
|
||||
key_id: UUID,
|
||||
body: ApiKeyUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Update name, description, scopes, expiry, or active status."""
|
||||
user_id = None if current_user.role == "admin" else current_user.id
|
||||
return svc.update_api_key(
|
||||
db, key_id, user_id,
|
||||
name = body.name,
|
||||
description = body.description,
|
||||
scopes = body.scopes,
|
||||
expires_at = body.expires_at,
|
||||
is_active = body.is_active,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{key_id}/revoke", response_model=ApiKeyOut)
|
||||
def revoke_key(
|
||||
key_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Revoke an API key (soft-delete — sets is_active=False)."""
|
||||
user_id = None if current_user.role == "admin" else current_user.id
|
||||
return svc.revoke_api_key(db, key_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.delete("/{key_id}", status_code=204)
|
||||
def delete_key(
|
||||
key_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Permanently delete an API key (admin only)."""
|
||||
svc.delete_api_key(db, key_id)
|
||||
@@ -0,0 +1,249 @@
|
||||
"""Phase 10: Attack Paths & Advanced Purple Team router."""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.schemas.attack_path_schema import (
|
||||
AttackPathCreate, AttackPathUpdate, AttackPathOut,
|
||||
AttackPathStepCreate, AttackPathStepUpdate, AttackPathStepOut,
|
||||
ExecutionCreate, ExecutionOut,
|
||||
StepExecuteRequest, StepResultOut,
|
||||
TimelineEntryCreate, TimelineEntryOut,
|
||||
)
|
||||
from app.services import attack_path_service as svc
|
||||
|
||||
router = APIRouter(prefix="/attack-paths", tags=["attack-paths"])
|
||||
|
||||
|
||||
# ── Attack Paths CRUD ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("", response_model=AttackPathOut, status_code=201)
|
||||
def create_attack_path(
|
||||
body: AttackPathCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.create_attack_path(db, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("", response_model=list[AttackPathOut])
|
||||
def list_attack_paths(
|
||||
is_template: Optional[bool] = None,
|
||||
technique_id: Optional[UUID] = None,
|
||||
is_active: Optional[bool] = True,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
paths = svc.list_attack_paths(db, is_template=is_template,
|
||||
technique_id=technique_id, is_active=is_active)
|
||||
# Inject step_count
|
||||
result = []
|
||||
for p in paths:
|
||||
d = AttackPathOut.model_validate(p)
|
||||
d.step_count = len(p.steps)
|
||||
result.append(d)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{path_id}", response_model=AttackPathOut)
|
||||
def get_attack_path(
|
||||
path_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
p = svc.get_attack_path(db, path_id)
|
||||
d = AttackPathOut.model_validate(p)
|
||||
d.step_count = len(p.steps)
|
||||
return d
|
||||
|
||||
|
||||
@router.patch("/{path_id}", response_model=AttackPathOut)
|
||||
def update_attack_path(
|
||||
path_id: UUID,
|
||||
body: AttackPathUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.update_attack_path(db, path_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.delete("/{path_id}", status_code=204)
|
||||
def delete_attack_path(
|
||||
path_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
svc.delete_attack_path(db, path_id, user.id)
|
||||
|
||||
|
||||
# ── Steps ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{path_id}/steps", response_model=list[AttackPathStepOut])
|
||||
def list_steps(
|
||||
path_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
path = svc.get_attack_path(db, path_id)
|
||||
return path.steps
|
||||
|
||||
|
||||
@router.post("/{path_id}/steps", response_model=AttackPathStepOut, status_code=201)
|
||||
def add_step(
|
||||
path_id: UUID,
|
||||
body: AttackPathStepCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.add_step(db, path_id, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.patch("/{path_id}/steps/{step_id}", response_model=AttackPathStepOut)
|
||||
def update_step(
|
||||
path_id: UUID,
|
||||
step_id: UUID,
|
||||
body: AttackPathStepUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.update_step(db, step_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.delete("/{path_id}/steps/{step_id}", status_code=204)
|
||||
def delete_step(
|
||||
path_id: UUID,
|
||||
step_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
svc.delete_step(db, step_id, user.id)
|
||||
|
||||
|
||||
@router.post("/{path_id}/steps/reorder", response_model=list[AttackPathStepOut])
|
||||
def reorder_steps(
|
||||
path_id: UUID,
|
||||
step_ids: list[UUID],
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Pass an ordered list of step UUIDs to reorder the steps."""
|
||||
return svc.reorder_steps(db, path_id, step_ids, user.id)
|
||||
|
||||
|
||||
# ── Executions ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/{path_id}/executions", response_model=ExecutionOut, status_code=201)
|
||||
def create_execution(
|
||||
path_id: UUID,
|
||||
body: ExecutionCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.create_execution(db, path_id, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/{path_id}/executions", response_model=list[ExecutionOut])
|
||||
def list_executions(
|
||||
path_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.list_executions(db, path_id)
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}", response_model=ExecutionOut)
|
||||
def get_execution(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.get_execution(db, execution_id)
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/start", response_model=ExecutionOut)
|
||||
def start_execution(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.start_execution(db, execution_id, user.id)
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/steps/{step_id}", response_model=StepResultOut)
|
||||
def execute_step(
|
||||
execution_id: UUID,
|
||||
step_id: UUID,
|
||||
body: StepExecuteRequest,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Record the result of one step (detected / not_detected / skipped)."""
|
||||
return svc.execute_step(db, execution_id, step_id, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}/steps", response_model=list[StepResultOut])
|
||||
def list_step_results(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
ex = svc.get_execution(db, execution_id)
|
||||
return ex.step_results
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/complete", response_model=ExecutionOut)
|
||||
def complete_execution(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Mark execution as complete and compute kill-chain metrics."""
|
||||
return svc.complete_execution(db, execution_id, user.id)
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/abort", response_model=ExecutionOut)
|
||||
def abort_execution(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return svc.abort_execution(db, execution_id, user.id)
|
||||
|
||||
|
||||
# ── Timeline ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/executions/{execution_id}/timeline",
|
||||
response_model=TimelineEntryOut, status_code=201)
|
||||
def add_timeline_entry(
|
||||
execution_id: UUID,
|
||||
body: TimelineEntryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.add_timeline_entry(db, execution_id, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}/timeline", response_model=list[TimelineEntryOut])
|
||||
def get_timeline(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return svc.get_timeline(db, execution_id)
|
||||
|
||||
|
||||
# ── Kill-Chain Metrics ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/executions/{execution_id}/metrics")
|
||||
def get_metrics(
|
||||
execution_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Return full kill-chain metrics for a completed (or partial) execution."""
|
||||
return svc.get_kill_chain_metrics(db, execution_id)
|
||||
@@ -18,6 +18,7 @@ from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
# Import jwt (PyJWT)
|
||||
import jwt
|
||||
from jwt.exceptions import PyJWTError as JWTError
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -71,9 +72,16 @@ from app.services.auth_service import (
|
||||
# Assign router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# Assign _IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
# Assign _COOKIE_NAME = "aegis_token"
|
||||
# SECURE_COOKIES desacopla la seguridad de la cookie del entorno de ejecucion.
|
||||
# Por defecto activo en produccion; ponlo en "false" para servidores HTTP.
|
||||
_aegis_env = os.environ.get("AEGIS_ENV", "development").lower()
|
||||
_secure_cookie_env = os.environ.get("SECURE_COOKIES", "auto").lower()
|
||||
if _secure_cookie_env == "false":
|
||||
_IS_HTTPS = False
|
||||
elif _secure_cookie_env == "true":
|
||||
_IS_HTTPS = True
|
||||
else: # "auto" — activo solo si AEGIS_ENV=production
|
||||
_IS_HTTPS = _aegis_env == "production"
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
|
||||
@@ -256,7 +264,57 @@ def logout(
|
||||
return {"detail": "Logged out"}
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
def refresh_token(
|
||||
response: Response,
|
||||
aegis_token: str | None = Cookie(None),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Issue a new access token if the current one is valid.
|
||||
|
||||
Called automatically by the frontend when it detects an expired
|
||||
session while the user is actively using the app. If the current
|
||||
cookie token is still valid (not blacklisted, not expired), a fresh
|
||||
token is issued and the cookie is renewed — keeping the session alive
|
||||
without requiring re-authentication.
|
||||
"""
|
||||
if not aegis_token:
|
||||
raise PermissionViolation("No active session")
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
aegis_token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM],
|
||||
)
|
||||
except JWTError:
|
||||
raise PermissionViolation("Session expired — please log in again")
|
||||
|
||||
username: str | None = payload.get("sub")
|
||||
if not username:
|
||||
raise PermissionViolation("Invalid session")
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if user is None or not user.is_active:
|
||||
raise PermissionViolation("Account not found or disabled")
|
||||
|
||||
if getattr(user, "must_change_password", False):
|
||||
raise PermissionViolation("Password change required before refreshing session")
|
||||
|
||||
# Issue a fresh token with a new expiry
|
||||
new_token = create_access_token(data={"sub": user.username})
|
||||
response.set_cookie(
|
||||
key=_COOKIE_NAME,
|
||||
value=new_token,
|
||||
httponly=True,
|
||||
secure=_IS_HTTPS,
|
||||
samesite="strict",
|
||||
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
path="/",
|
||||
)
|
||||
return TokenResponse(access_token=new_token)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
# Define function read_current_user
|
||||
def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut:
|
||||
|
||||
@@ -9,8 +9,7 @@ import logging
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Optional from typing
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
@@ -33,16 +32,9 @@ from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
activate_campaign as crud_activate,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.test import Test
|
||||
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||
from app.services.campaign_crud_service import (
|
||||
add_test_to_campaign as crud_add_test,
|
||||
)
|
||||
@@ -55,10 +47,7 @@ from app.services.campaign_crud_service import (
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
create_campaign as crud_create,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
delete_campaign as crud_delete,
|
||||
get_campaign_detail as crud_get_detail,
|
||||
)
|
||||
|
||||
@@ -97,11 +86,17 @@ from app.services.campaign_crud_service import (
|
||||
update_campaign as crud_update,
|
||||
)
|
||||
|
||||
# Import generate_campaign_from_threat_actor from app.services.campaign_service
|
||||
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||
# Import activate_campaign from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
activate_campaign as crud_activate,
|
||||
)
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import notify_role from app.services.notification_service
|
||||
from app.services.notification_service import notify_role
|
||||
from app.services.webhook_service import dispatch_webhook
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -129,6 +124,7 @@ class CampaignCreate(BaseModel):
|
||||
tags: Optional[list[str]] = Field(default_factory=list)
|
||||
# Assign scheduled_at = None
|
||||
scheduled_at: Optional[str] = None
|
||||
start_date: Optional[str] = None # ISO date — campaign won't activate before this
|
||||
|
||||
|
||||
# Define class CampaignUpdate
|
||||
@@ -147,6 +143,7 @@ class CampaignUpdate(BaseModel):
|
||||
tags: Optional[list[str]] = None
|
||||
# Assign scheduled_at = None
|
||||
scheduled_at: Optional[str] = None
|
||||
start_date: Optional[str] = None # ISO date — can be updated while still in draft
|
||||
|
||||
|
||||
# Define class AddTestPayload
|
||||
@@ -198,7 +195,7 @@ def list_campaigns(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""List campaigns with optional filters and pagination.
|
||||
|
||||
Args:
|
||||
@@ -277,8 +274,9 @@ def create_campaign(
|
||||
tags=payload.tags,
|
||||
# Keyword argument: scheduled_at
|
||||
scheduled_at=payload.scheduled_at,
|
||||
start_date=payload.start_date,
|
||||
)
|
||||
# Call log_action()
|
||||
campaign_id = result["id"]
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
@@ -287,9 +285,7 @@ def create_campaign(
|
||||
action="create_campaign",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=result["id"],
|
||||
# Keyword argument: details
|
||||
entity_id=campaign_id,
|
||||
details={"name": payload.name, "type": payload.type},
|
||||
)
|
||||
# Call uow.commit()
|
||||
@@ -389,6 +385,37 @@ def update_campaign(
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /campaigns/{id} — Delete campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/{campaign_id}", status_code=204)
|
||||
def delete_campaign(
|
||||
campaign_id: str,
|
||||
delete_tests: bool = Query(False, description="Also delete associated tests"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a campaign. Only draft campaigns can be deleted (admins can delete any)."""
|
||||
with UnitOfWork(db) as uow:
|
||||
crud_delete(
|
||||
db,
|
||||
campaign_id,
|
||||
deleter_id=current_user.id,
|
||||
deleter_role=current_user.role,
|
||||
delete_tests=delete_tests,
|
||||
)
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="delete_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign_id,
|
||||
details={"delete_tests": delete_tests},
|
||||
)
|
||||
uow.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /campaigns/{id}/tests — Add test to campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -433,7 +460,7 @@ def add_test_to_campaign(
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Return result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -483,22 +510,36 @@ def remove_test_from_campaign(
|
||||
def activate_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: db
|
||||
force: bool = Query(False, description="Activate even if start_date is in the future"),
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
) -> dict:
|
||||
):
|
||||
"""Activate a campaign, moving it from draft to active.
|
||||
|
||||
Args:
|
||||
campaign_id (str): UUID string of the campaign to activate.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead activating the campaign.
|
||||
|
||||
Returns:
|
||||
dict: Serialised representation of the activated campaign.
|
||||
If the campaign has a start_date in the future and force=False, returns a 409
|
||||
with a warning so the frontend can show a confirmation modal. If force=True,
|
||||
activates immediately regardless of start_date.
|
||||
"""
|
||||
# Open context manager
|
||||
from fastapi import HTTPException
|
||||
campaign_obj = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if campaign_obj and campaign_obj.start_date and not force:
|
||||
now = datetime.utcnow()
|
||||
if campaign_obj.start_date > now:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"code": "start_date_in_future",
|
||||
"start_date": campaign_obj.start_date.strftime("%Y-%m-%d"),
|
||||
"message": (
|
||||
f"This campaign is scheduled to start on "
|
||||
f"{campaign_obj.start_date.strftime('%d %b %Y')}. "
|
||||
f"It will activate automatically on that date. "
|
||||
f"Do you want to activate it now anyway?"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign campaign = crud_activate(db, campaign_id)
|
||||
campaign = crud_activate(db, campaign_id)
|
||||
@@ -537,7 +578,33 @@ def activate_campaign(
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(campaign)
|
||||
|
||||
# Return serialize_campaign(db, campaign)
|
||||
# Create Jira tickets for campaign and tests at activation time (non-fatal).
|
||||
# Campaign ticket is created here if it doesn't already exist (deferred from creation).
|
||||
try:
|
||||
from app.services.jira_service import (
|
||||
auto_create_campaign_issue,
|
||||
auto_create_test_issue,
|
||||
get_campaign_jira_key,
|
||||
get_test_jira_key,
|
||||
)
|
||||
campaign_jira_key = get_campaign_jira_key(db, campaign_id)
|
||||
if not campaign_jira_key:
|
||||
campaign_jira_key = auto_create_campaign_issue(db, campaign, current_user)
|
||||
if campaign_jira_key:
|
||||
for ct in campaign.campaign_tests:
|
||||
if ct.test and not get_test_jira_key(db, ct.test.id):
|
||||
auto_create_test_issue(
|
||||
db, ct.test, current_user,
|
||||
parent_ticket_override=campaign_jira_key,
|
||||
campaign_start_date=campaign.start_date,
|
||||
)
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Jira ticket creation failed during activation of campaign %s",
|
||||
campaign_id,
|
||||
)
|
||||
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
@@ -587,6 +654,7 @@ def complete_campaign(
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(campaign)
|
||||
dispatch_webhook("campaign.completed", {"campaign_id": str(campaign.id), "name": campaign.name})
|
||||
|
||||
# Return serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
@@ -624,12 +692,16 @@ def get_campaign_progress_endpoint(
|
||||
# POST /campaigns/from-threat-actor/{actor_id} — Auto-generate campaign
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GenerateFromActorPayload(BaseModel):
|
||||
start_date: Optional[str] = None # ISO date YYYY-MM-DD
|
||||
|
||||
|
||||
@router.post("/from-threat-actor/{actor_id}", status_code=201)
|
||||
# Define function generate_campaign_from_actor
|
||||
def generate_campaign_from_actor(
|
||||
# Entry: actor_id
|
||||
actor_id: str,
|
||||
# Entry: db
|
||||
payload: GenerateFromActorPayload = GenerateFromActorPayload(),
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
@@ -647,11 +719,14 @@ def generate_campaign_from_actor(
|
||||
Returns:
|
||||
dict: Serialised representation of the newly generated campaign.
|
||||
"""
|
||||
# Assign campaign = generate_campaign_from_threat_actor(
|
||||
start_date_parsed = (
|
||||
datetime.fromisoformat(payload.start_date) if payload.start_date else None
|
||||
)
|
||||
campaign = generate_campaign_from_threat_actor(
|
||||
db,
|
||||
uuid.UUID(actor_id),
|
||||
current_user,
|
||||
start_date=start_date_parsed,
|
||||
)
|
||||
|
||||
# Open context manager
|
||||
@@ -779,3 +854,97 @@ def get_campaign_history(
|
||||
"""
|
||||
# Return crud_get_history(db, campaign_id)
|
||||
return crud_get_history(db, campaign_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns/{id}/timing-summary — Aggregated timing across campaign tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seconds_between(start: datetime | None, end: datetime | None) -> int:
|
||||
"""Return elapsed seconds between two datetimes; 0 if either is None."""
|
||||
if not start or not end:
|
||||
return 0
|
||||
diff = (end - start).total_seconds()
|
||||
return max(0, int(diff))
|
||||
|
||||
|
||||
@router.get("/{campaign_id}/timing-summary")
|
||||
def get_campaign_timing_summary(
|
||||
campaign_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return aggregated Red/Blue timing metrics for all tests in a campaign.
|
||||
|
||||
For each test we calculate:
|
||||
- red_execution_secs : red_started_at → blue_started_at (minus red_paused_seconds)
|
||||
- blue_queue_secs : blue_started_at → blue_work_started_at (waiting for Blue pick-up)
|
||||
- blue_evaluation_secs: blue_work_started_at → first validation timestamp (minus blue_paused_seconds)
|
||||
- total_secs : sum of the three phases
|
||||
|
||||
Returns totals + per-test breakdown.
|
||||
"""
|
||||
# Load campaign
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
# Load all tests for this campaign
|
||||
test_ids = [
|
||||
ct.test_id
|
||||
for ct in db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign.id).all()
|
||||
]
|
||||
tests = db.query(Test).filter(Test.id.in_(test_ids)).all() if test_ids else []
|
||||
|
||||
breakdown = []
|
||||
total_red = 0
|
||||
total_queue = 0
|
||||
total_blue = 0
|
||||
|
||||
for t in tests:
|
||||
# Red execution: from start-execution to submit-to-blue, minus paused time
|
||||
red_secs = max(
|
||||
0,
|
||||
_seconds_between(t.red_started_at, t.blue_started_at) - (t.red_paused_seconds or 0),
|
||||
)
|
||||
|
||||
# Blue queue: from receiving the test to actually starting evaluation
|
||||
queue_secs = _seconds_between(t.blue_started_at, t.blue_work_started_at)
|
||||
|
||||
# Blue evaluation: from starting evaluation to first validation, minus paused time
|
||||
eval_end = t.red_validated_at or t.blue_validated_at
|
||||
blue_secs = max(
|
||||
0,
|
||||
_seconds_between(t.blue_work_started_at, eval_end) - (t.blue_paused_seconds or 0),
|
||||
)
|
||||
|
||||
total_red += red_secs
|
||||
total_queue += queue_secs
|
||||
total_blue += blue_secs
|
||||
|
||||
breakdown.append({
|
||||
"test_id": str(t.id),
|
||||
"test_name": t.name,
|
||||
"state": t.state.value if t.state else None,
|
||||
"red_execution_secs": red_secs,
|
||||
"blue_queue_secs": queue_secs,
|
||||
"blue_evaluation_secs": blue_secs,
|
||||
"total_secs": red_secs + queue_secs + blue_secs,
|
||||
"has_timing": bool(t.red_started_at),
|
||||
})
|
||||
|
||||
total_secs = total_red + total_queue + total_blue
|
||||
|
||||
return {
|
||||
"campaign_id": campaign_id,
|
||||
"campaign_name": campaign.name,
|
||||
"tests_total": len(tests),
|
||||
"tests_with_timing": sum(1 for b in breakdown if b["has_timing"]),
|
||||
"red_execution_secs": total_red,
|
||||
"blue_queue_secs": total_queue,
|
||||
"blue_evaluation_secs": total_blue,
|
||||
"total_secs": total_secs,
|
||||
"breakdown": sorted(breakdown, key=lambda x: -(x["total_secs"])),
|
||||
}
|
||||
|
||||
@@ -26,6 +26,9 @@ from app.models.user import User
|
||||
# Import from app.services.compliance_import_service
|
||||
from app.services.compliance_import_service import (
|
||||
import_cis_controls_v8_mappings,
|
||||
import_dora_mappings,
|
||||
import_iso_27001_mappings,
|
||||
import_iso_42001_mappings,
|
||||
import_nist_800_53_mappings,
|
||||
)
|
||||
|
||||
@@ -232,3 +235,33 @@ def import_cis(
|
||||
result = import_cis_controls_v8_mappings(db)
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/import/dora")
|
||||
def import_dora(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import DORA (EU 2022/2554) compliance mappings (admin only)."""
|
||||
result = import_dora_mappings(db)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/import/iso-27001")
|
||||
def import_iso27001(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import ISO/IEC 27001:2022 Annex A compliance mappings (admin only)."""
|
||||
result = import_iso_27001_mappings(db)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/import/iso-42001")
|
||||
def import_iso42001(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import ISO/IEC 42001:2023 AI Management System compliance mappings (admin only)."""
|
||||
result = import_iso_42001_mappings(db)
|
||||
return result
|
||||
|
||||
@@ -64,7 +64,7 @@ def list_defensive_techniques(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""List all D3FEND defensive techniques with optional filters."""
|
||||
# Return list_defensive_techniques_svc(
|
||||
return list_defensive_techniques_svc(
|
||||
@@ -102,7 +102,7 @@ def get_defenses_for_attack_technique_endpoint(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||
# Return get_defenses_for_attack_technique(db, mitre_id)
|
||||
return get_defenses_for_attack_technique(db, mitre_id)
|
||||
|
||||
@@ -0,0 +1,319 @@
|
||||
"""Detection Lifecycle Management router."""
|
||||
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.models.detection_lifecycle import (
|
||||
DetectionAsset, DetectionTechniqueMapping, DetectionValidation,
|
||||
TechniqueConfidenceScore, InfrastructureChangeLog,
|
||||
)
|
||||
from app.schemas.detection_lifecycle_schema import (
|
||||
DetectionAssetCreate, DetectionAssetUpdate, DetectionAssetOut,
|
||||
DetectionValidationCreate, DetectionValidationOut,
|
||||
TechniqueConfidenceOut,
|
||||
InfrastructureChangeCreate, InfrastructureChangeOut,
|
||||
)
|
||||
from app.services import detection_asset_service, decay_engine_service, audit_service
|
||||
|
||||
router = APIRouter(prefix="/detection-lifecycle", tags=["detection-lifecycle"])
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
# ── Detection Assets ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/assets", response_model=DetectionAssetOut, status_code=201)
|
||||
def create_asset(body: DetectionAssetCreate, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
asset = detection_asset_service.create_detection_asset(db, body.model_dump(), user.id)
|
||||
return asset
|
||||
|
||||
|
||||
@router.get("/assets", response_model=list[DetectionAssetOut])
|
||||
def list_assets(
|
||||
platform: Optional[str] = None,
|
||||
asset_type: Optional[str] = None,
|
||||
health_status: Optional[str] = None,
|
||||
technique_id: Optional[UUID] = None,
|
||||
is_active: Optional[bool] = True,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return detection_asset_service.list_assets(db, platform=platform, asset_type=asset_type, health_status=health_status, technique_id=technique_id, is_active=is_active)
|
||||
|
||||
|
||||
@router.get("/assets/{asset_id}", response_model=DetectionAssetOut)
|
||||
def get_asset(asset_id: UUID, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
return detection_asset_service.get_asset_with_details(db, asset_id)
|
||||
|
||||
|
||||
@router.patch("/assets/{asset_id}", response_model=DetectionAssetOut)
|
||||
def update_asset(asset_id: UUID, body: DetectionAssetUpdate, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
return detection_asset_service.update_detection_asset(db, asset_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.delete("/assets/{asset_id}", status_code=204)
|
||||
def delete_asset(asset_id: UUID, db: Session = Depends(get_db), user=Depends(require_any_role("red_lead", "blue_lead"))):
|
||||
asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first()
|
||||
if not asset:
|
||||
raise EntityNotFoundError("DetectionAsset", str(asset_id))
|
||||
asset.is_active = False
|
||||
db.commit()
|
||||
|
||||
|
||||
# ── Technique Mappings ───────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/assets/{asset_id}/techniques/{technique_id}")
|
||||
def map_technique(
|
||||
asset_id: UUID, technique_id: UUID,
|
||||
coverage_type: str = Query("detect"),
|
||||
confidence_level: str = Query("medium"),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
# Validate asset exists
|
||||
asset = db.query(DetectionAsset).filter(DetectionAsset.id == asset_id).first()
|
||||
if not asset:
|
||||
raise EntityNotFoundError("DetectionAsset", str(asset_id))
|
||||
|
||||
# Prevent duplicate mappings
|
||||
existing = db.query(DetectionTechniqueMapping).filter(
|
||||
DetectionTechniqueMapping.detection_asset_id == asset_id,
|
||||
DetectionTechniqueMapping.technique_id == technique_id,
|
||||
).first()
|
||||
if existing:
|
||||
# Update coverage/confidence on existing mapping instead of duplicating
|
||||
existing.coverage_type = coverage_type
|
||||
existing.confidence_level = confidence_level
|
||||
db.commit()
|
||||
return {"message": "Technique mapping updated", "mapping_id": str(existing.id)}
|
||||
|
||||
mapping = DetectionTechniqueMapping(
|
||||
detection_asset_id=asset_id, technique_id=technique_id,
|
||||
coverage_type=coverage_type, confidence_level=confidence_level,
|
||||
)
|
||||
db.add(mapping)
|
||||
db.commit()
|
||||
return {"message": "Technique mapped", "mapping_id": str(mapping.id)}
|
||||
|
||||
|
||||
@router.get("/techniques/{technique_id}/detections")
|
||||
def get_technique_detections(technique_id: UUID, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
return detection_asset_service.get_technique_detection_summary(db, technique_id)
|
||||
|
||||
|
||||
# ── Validations ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/validations", response_model=DetectionValidationOut, status_code=201)
|
||||
def create_validation(body: DetectionValidationCreate, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
asset = db.query(DetectionAsset).filter(DetectionAsset.id == body.detection_asset_id).first()
|
||||
if not asset:
|
||||
raise EntityNotFoundError("DetectionAsset", str(body.detection_asset_id))
|
||||
|
||||
now = _now()
|
||||
validation = DetectionValidation(
|
||||
detection_asset_id=body.detection_asset_id,
|
||||
technique_id=body.technique_id,
|
||||
test_id=body.test_id,
|
||||
validation_result=body.validation_result,
|
||||
validation_method=body.validation_method,
|
||||
notes=body.notes,
|
||||
evidence_ids=[str(e) for e in (body.evidence_ids or [])],
|
||||
validated_by=user.id,
|
||||
validated_at=now,
|
||||
expires_at=now + timedelta(days=body.validity_days),
|
||||
rule_hash_at_validation=asset.rule_hash,
|
||||
log_source_version_at_validation=asset.log_source_version,
|
||||
infrastructure_hash_at_validation=asset.infrastructure_hash,
|
||||
)
|
||||
data = f"{validation.detection_asset_id}:{validation.validated_by}:{validation.validation_result}:{validation.validated_at}"
|
||||
validation.integrity_hash = hashlib.sha256(data.encode()).hexdigest()
|
||||
|
||||
db.add(validation)
|
||||
db.commit()
|
||||
db.refresh(validation)
|
||||
|
||||
if body.technique_id:
|
||||
decay_engine_service.calculate_confidence_for_technique(db, body.technique_id)
|
||||
|
||||
audit_service.log_action(db, user.id, "DETECTION_VALIDATED", "detection_validation", str(validation.id),
|
||||
details={"asset_id": str(body.detection_asset_id), "result": body.validation_result, "validity_days": body.validity_days})
|
||||
|
||||
return validation
|
||||
|
||||
|
||||
@router.get("/validations", response_model=list[DetectionValidationOut])
|
||||
def list_validations(
|
||||
asset_id: Optional[UUID] = None,
|
||||
technique_id: Optional[UUID] = None,
|
||||
is_valid: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
query = db.query(DetectionValidation)
|
||||
if asset_id:
|
||||
query = query.filter(DetectionValidation.detection_asset_id == asset_id)
|
||||
if technique_id:
|
||||
query = query.filter(DetectionValidation.technique_id == technique_id)
|
||||
if is_valid is not None:
|
||||
query = query.filter(DetectionValidation.is_valid == is_valid)
|
||||
return query.order_by(DetectionValidation.validated_at.desc()).all()
|
||||
|
||||
|
||||
@router.post("/validations/{validation_id}/invalidate")
|
||||
def invalidate_validation(
|
||||
validation_id: UUID,
|
||||
reason: str = Query(...),
|
||||
details: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead")),
|
||||
):
|
||||
validation = db.query(DetectionValidation).filter(DetectionValidation.id == validation_id).first()
|
||||
if not validation:
|
||||
raise EntityNotFoundError("DetectionValidation", str(validation_id))
|
||||
|
||||
from app.models.detection_lifecycle import InvalidationReason
|
||||
try:
|
||||
reason_enum = InvalidationReason(reason)
|
||||
except ValueError:
|
||||
reason_enum = InvalidationReason.manual
|
||||
|
||||
validation.is_valid = False
|
||||
validation.invalidated_at = _now()
|
||||
validation.invalidation_reason = reason_enum
|
||||
validation.invalidation_details = details
|
||||
validation.invalidated_by = user.id
|
||||
db.commit()
|
||||
return {"message": "Validation invalidated"}
|
||||
|
||||
|
||||
# ── Confidence Scores ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/confidence", response_model=list[TechniqueConfidenceOut])
|
||||
def list_confidence_scores(
|
||||
confidence_level: Optional[str] = None,
|
||||
min_score: Optional[float] = None,
|
||||
max_score: Optional[float] = None,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
query = db.query(TechniqueConfidenceScore)
|
||||
if confidence_level:
|
||||
query = query.filter(TechniqueConfidenceScore.confidence_level == confidence_level)
|
||||
if min_score is not None:
|
||||
query = query.filter(TechniqueConfidenceScore.confidence_score >= min_score)
|
||||
if max_score is not None:
|
||||
query = query.filter(TechniqueConfidenceScore.confidence_score <= max_score)
|
||||
return query.order_by(TechniqueConfidenceScore.confidence_score.asc()).all()
|
||||
|
||||
|
||||
@router.get("/confidence/{technique_id}", response_model=TechniqueConfidenceOut)
|
||||
def get_technique_confidence(
|
||||
technique_id: UUID,
|
||||
recalculate: bool = Query(False),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
if recalculate:
|
||||
return decay_engine_service.calculate_confidence_for_technique(db, technique_id)
|
||||
score = db.query(TechniqueConfidenceScore).filter(TechniqueConfidenceScore.technique_id == technique_id).first()
|
||||
if not score:
|
||||
return decay_engine_service.calculate_confidence_for_technique(db, technique_id)
|
||||
return score
|
||||
|
||||
|
||||
# ── Infrastructure Changes ───────────────────────────────────────────────────
|
||||
|
||||
@router.post("/infrastructure-changes", response_model=InfrastructureChangeOut, status_code=201)
|
||||
def report_infrastructure_change(
|
||||
body: InfrastructureChangeCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead")),
|
||||
):
|
||||
change = InfrastructureChangeLog(
|
||||
change_type=body.change_type,
|
||||
description=body.description,
|
||||
affected_platforms=body.affected_platforms,
|
||||
affected_log_sources=body.affected_log_sources,
|
||||
change_date=body.change_date or _now(),
|
||||
auto_invalidate=body.auto_invalidate,
|
||||
reported_by=user.id,
|
||||
)
|
||||
db.add(change)
|
||||
db.commit()
|
||||
db.refresh(change)
|
||||
|
||||
if change.auto_invalidate:
|
||||
decay_engine_service.process_infrastructure_change(db, change.id)
|
||||
db.refresh(change)
|
||||
|
||||
audit_service.log_action(db, user.id, "INFRASTRUCTURE_CHANGE_REPORTED", "infrastructure_change", str(change.id),
|
||||
details={"type": body.change_type, "invalidated_count": change.invalidated_count})
|
||||
|
||||
return change
|
||||
|
||||
|
||||
@router.get("/infrastructure-changes", response_model=list[InfrastructureChangeOut])
|
||||
def list_infrastructure_changes(
|
||||
days: int = Query(90, ge=1, le=730),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
cutoff = _now() - timedelta(days=days)
|
||||
return db.query(InfrastructureChangeLog).filter(InfrastructureChangeLog.change_date >= cutoff).order_by(InfrastructureChangeLog.change_date.desc()).all()
|
||||
|
||||
|
||||
# ── Decay Engine Control ─────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/decay-engine/run")
|
||||
def trigger_decay_engine(db: Session = Depends(get_db), user=Depends(require_any_role("admin"))):
|
||||
results = decay_engine_service.run_decay_engine(db)
|
||||
return {"message": "Decay engine completed", "results": results}
|
||||
|
||||
|
||||
# ── Dashboard ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/dashboard")
|
||||
def lifecycle_dashboard(db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
now = _now()
|
||||
|
||||
health_dist = dict(
|
||||
db.query(DetectionAsset.health_status, func.count(DetectionAsset.id))
|
||||
.filter(DetectionAsset.is_active == True)
|
||||
.group_by(DetectionAsset.health_status)
|
||||
.all()
|
||||
)
|
||||
confidence_dist = dict(
|
||||
db.query(TechniqueConfidenceScore.confidence_level, func.count(TechniqueConfidenceScore.id))
|
||||
.group_by(TechniqueConfidenceScore.confidence_level)
|
||||
.all()
|
||||
)
|
||||
expiring_soon = db.query(func.count(DetectionValidation.id)).filter(
|
||||
DetectionValidation.is_valid == True,
|
||||
DetectionValidation.expires_at <= (now + timedelta(days=7)),
|
||||
).scalar() or 0
|
||||
|
||||
total_assets = db.query(func.count(DetectionAsset.id)).filter(DetectionAsset.is_active == True).scalar() or 0
|
||||
total_valid = db.query(func.count(DetectionValidation.id)).filter(DetectionValidation.is_valid == True).scalar() or 0
|
||||
recent_changes = db.query(func.count(InfrastructureChangeLog.id)).filter(
|
||||
InfrastructureChangeLog.change_date >= (now - timedelta(days=30))
|
||||
).scalar() or 0
|
||||
|
||||
return {
|
||||
"total_detection_assets": total_assets,
|
||||
"total_valid_validations": total_valid,
|
||||
"health_distribution": {k.value if hasattr(k, "value") else str(k): v for k, v in health_dist.items()},
|
||||
"confidence_distribution": {k.value if hasattr(k, "value") else str(k): v for k, v in confidence_dist.items()},
|
||||
"validations_expiring_7d": expiring_soon,
|
||||
"infrastructure_changes_30d": recent_changes,
|
||||
}
|
||||
@@ -80,7 +80,7 @@ def list_detection_rules(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""List detection rules with optional filters and pagination."""
|
||||
# Return list_rules(
|
||||
return list_rules(
|
||||
@@ -112,7 +112,7 @@ def get_detection_rules_for_template(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""Get detection rules associated with a test template."""
|
||||
# Return get_rules_for_template(db, template_id)
|
||||
return get_rules_for_template(db, template_id)
|
||||
@@ -151,7 +151,7 @@ def get_detection_rules_for_test(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""Get detection rules relevant to a test, along with their evaluation results.
|
||||
|
||||
Finds rules by matching the test's technique_id to detection rules,
|
||||
|
||||
@@ -4,7 +4,8 @@ Endpoints
|
||||
---------
|
||||
POST /tests/{test_id}/evidence — upload evidence (with team=red/blue)
|
||||
GET /tests/{test_id}/evidence — list evidences (filterable by team)
|
||||
GET /evidence/{id} — presigned download URL
|
||||
GET /evidence/{id} — metadata + download_url
|
||||
GET /evidence/{id}/file — proxy download (streams file through backend)
|
||||
DELETE /evidence/{id} — delete evidence (only in editable states)
|
||||
|
||||
Access Control
|
||||
@@ -21,20 +22,17 @@ Access Control
|
||||
|
||||
# Import hashlib
|
||||
import hashlib
|
||||
|
||||
# Import os
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Import uuid
|
||||
import uuid as _uuid
|
||||
|
||||
# Import Optional from typing
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, File, Form, Query, Request,... from fastapi
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
@@ -74,9 +72,9 @@ from app.services.evidence_service import (
|
||||
validate_file,
|
||||
validate_upload_permission,
|
||||
)
|
||||
from app.storage import download_file, upload_file
|
||||
|
||||
# Import get_presigned_url, upload_file from app.storage
|
||||
from app.storage import get_presigned_url, upload_file
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign router = APIRouter(tags=["evidence"])
|
||||
router = APIRouter(tags=["evidence"])
|
||||
@@ -87,8 +85,11 @@ router = APIRouter(tags=["evidence"])
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
"""Convert an ORM ``Evidence`` to the API schema, injecting a presigned URL."""
|
||||
# Return EvidenceOut(
|
||||
"""Convert an ORM ``Evidence`` to the API schema.
|
||||
|
||||
``download_url`` points to the backend proxy endpoint so the browser
|
||||
never needs direct access to MinIO.
|
||||
"""
|
||||
return EvidenceOut(
|
||||
# Keyword argument: id
|
||||
id=evidence.id,
|
||||
@@ -106,8 +107,7 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
team=evidence.team,
|
||||
# Keyword argument: notes
|
||||
notes=evidence.notes,
|
||||
# Keyword argument: download_url
|
||||
download_url=get_presigned_url(evidence.file_path),
|
||||
download_url=f"/api/v1/evidence/{evidence.id}/file",
|
||||
)
|
||||
|
||||
|
||||
@@ -185,7 +185,7 @@ async def upload_evidence(
|
||||
sha256_hash=sha256,
|
||||
# Keyword argument: uploaded_by
|
||||
uploaded_by=current_user.id,
|
||||
# Keyword argument: team
|
||||
uploaded_at=datetime.utcnow(), # set explicitly — DB column has no server default
|
||||
team=team,
|
||||
# Keyword argument: notes
|
||||
notes=notes,
|
||||
@@ -222,10 +222,43 @@ async def upload_evidence(
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(evidence)
|
||||
|
||||
# Return _evidence_to_out(evidence)
|
||||
# 7. Attach to Jira ticket if one exists (non-fatal)
|
||||
_attach_evidence_to_jira(db, test_id, content, safe_name, current_user)
|
||||
|
||||
return _evidence_to_out(evidence)
|
||||
|
||||
|
||||
def _attach_evidence_to_jira(
|
||||
db,
|
||||
test_id: _uuid.UUID,
|
||||
content: bytes,
|
||||
file_name: str,
|
||||
actor,
|
||||
) -> None:
|
||||
"""Attach uploaded evidence to the linked Jira ticket (non-fatal)."""
|
||||
try:
|
||||
from app.services.jira_service import get_test_jira_key, get_user_jira_client, has_jira_configured
|
||||
if not has_jira_configured(actor, db):
|
||||
return
|
||||
issue_key = get_test_jira_key(db, test_id)
|
||||
if not issue_key:
|
||||
return
|
||||
import io
|
||||
jira = get_user_jira_client(actor, db)
|
||||
buf = io.BytesIO(content)
|
||||
buf.name = file_name # requests uses .name as the multipart filename
|
||||
jira.add_attachment_object(issue_key, buf)
|
||||
import logging
|
||||
logging.getLogger(__name__).info(
|
||||
"Attached evidence '%s' to Jira ticket %s", file_name, issue_key
|
||||
)
|
||||
except Exception as exc:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to attach evidence '%s' to Jira: %s", file_name, exc, exc_info=True
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /tests/{test_id}/evidence — list (with optional team filter)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -253,7 +286,7 @@ def list_evidence(
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /evidence/{id} — presigned download URL
|
||||
# GET /evidence/{id} — metadata + proxy download URL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -266,14 +299,50 @@ def get_evidence(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> EvidenceOut:
|
||||
"""Return evidence metadata together with a presigned download URL."""
|
||||
# Assign evidence = get_evidence_or_raise(db, evidence_id)
|
||||
):
|
||||
"""Return evidence metadata. ``download_url`` is a backend proxy URL."""
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
# Return _evidence_to_out(evidence)
|
||||
return _evidence_to_out(evidence)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /evidence/{id}/file — proxy download (streams file via backend)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/evidence/{evidence_id}/file")
|
||||
def download_evidence_file(
|
||||
evidence_id: _uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Stream the evidence file through the backend.
|
||||
|
||||
The browser calls this endpoint (authenticated via JWT cookie/header).
|
||||
The backend fetches the file from MinIO internally and streams it back,
|
||||
so MinIO never needs to be publicly accessible.
|
||||
"""
|
||||
import mimetypes
|
||||
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
content = download_file(evidence.file_path)
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(evidence.file_name)
|
||||
if not mime_type:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
safe_name = evidence.file_name.replace('"', '\\"')
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type=mime_type,
|
||||
headers={
|
||||
"Content-Disposition": f'inline; filename="{safe_name}"',
|
||||
"Content-Length": str(len(content)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /evidence/{id} — delete evidence (editable states only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Phase 13: Executive Dashboard router."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.schemas.executive_dashboard_schema import (
|
||||
PostureSnapshotOut,
|
||||
ExecutiveSummary,
|
||||
KpiBlock,
|
||||
CoverageByTactic,
|
||||
PostureHistoryEntry,
|
||||
ActivityEntry,
|
||||
)
|
||||
import app.services.executive_dashboard_service as svc
|
||||
|
||||
router = APIRouter(prefix="/dashboard", tags=["Executive Dashboard"])
|
||||
|
||||
|
||||
@router.get("/executive", response_model=ExecutiveSummary)
|
||||
def executive_view(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Full executive view — snapshot, 30-day trends, top risks,
|
||||
coverage by tactic, and recent activity feed.
|
||||
"""
|
||||
data = svc.get_executive_summary(db)
|
||||
snap = data["snapshot"]
|
||||
return ExecutiveSummary(
|
||||
snapshot=PostureSnapshotOut.model_validate(snap),
|
||||
coverage_trend=data["coverage_trend"],
|
||||
risk_trend=data["risk_trend"],
|
||||
top_risks=data["top_risks"],
|
||||
coverage_by_tactic=data["coverage_by_tactic"],
|
||||
recent_activity=data["recent_activity"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/kpis", response_model=KpiBlock)
|
||||
def kpis(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Compact KPI block — live aggregation without persisting a snapshot."""
|
||||
live = svc.get_live_kpis(db)
|
||||
|
||||
# Try to find today's snapshot id; fall back to None
|
||||
from datetime import date
|
||||
from app.models.executive_dashboard import PostureSnapshot
|
||||
today_snap = db.query(PostureSnapshot).filter(
|
||||
PostureSnapshot.snapshot_date == date.today()
|
||||
).first()
|
||||
|
||||
return KpiBlock(
|
||||
coverage_pct=live["coverage_pct"],
|
||||
avg_risk_score=live["avg_risk_score"],
|
||||
critical_count=live["critical_count"],
|
||||
open_queue_items=live["open_queue_items"],
|
||||
orphan_techniques=live["orphan_techniques"],
|
||||
mttd_avg_seconds=live.get("mttd_avg_seconds"),
|
||||
detection_rate_30d=live.get("detection_rate_30d"),
|
||||
playbook_count=live["playbook_count"],
|
||||
lesson_count=live["lesson_count"],
|
||||
snapshot_date=live["snapshot_date"],
|
||||
snapshot_id=today_snap.id if today_snap else None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/coverage-by-tactic", response_model=List[CoverageByTactic])
|
||||
def coverage_by_tactic(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Per-tactic validated / partial / not_covered breakdown."""
|
||||
return svc.get_coverage_by_tactic(db)
|
||||
|
||||
|
||||
@router.get("/posture-history", response_model=List[PostureHistoryEntry])
|
||||
def posture_history(
|
||||
days: int = Query(30, ge=1, le=365),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Historical posture snapshots for trend charts (default last 30 days)."""
|
||||
snaps = svc.get_posture_history(db, days=days)
|
||||
return [
|
||||
PostureHistoryEntry(
|
||||
snapshot_date=s.snapshot_date,
|
||||
coverage_pct=s.coverage_pct,
|
||||
avg_risk_score=s.avg_risk_score,
|
||||
critical_count=s.critical_count,
|
||||
open_queue_items=s.open_queue_items,
|
||||
)
|
||||
for s in snaps
|
||||
]
|
||||
|
||||
|
||||
@router.post("/posture-snapshot", response_model=PostureSnapshotOut, status_code=201)
|
||||
def create_posture_snapshot(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""
|
||||
Take (or refresh) today's posture snapshot — admin / leads only.
|
||||
Aggregates live data from all phases into a single PostureSnapshot row.
|
||||
"""
|
||||
snap = svc.take_posture_snapshot(db, created_by=user.id)
|
||||
return PostureSnapshotOut.model_validate(snap)
|
||||
|
||||
|
||||
@router.get("/activity", response_model=List[ActivityEntry])
|
||||
def recent_activity(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Recent activity feed — tests, attack-path executions, OSINT signals."""
|
||||
return svc.get_recent_activity(db, limit=limit)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Intel items endpoints — list and manage threat intelligence items."""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.intel import IntelItem
|
||||
from app.models.user import User
|
||||
|
||||
router = APIRouter(prefix="/intel", tags=["intel"])
|
||||
|
||||
|
||||
class IntelItemOut(BaseModel):
|
||||
id: uuid.UUID
|
||||
technique_id: Optional[uuid.UUID] = None
|
||||
url: str
|
||||
title: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
detected_at: Optional[str] = None
|
||||
reviewed: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@router.get("/items", response_model=list[IntelItemOut])
|
||||
def list_intel_items(
|
||||
technique_id: Optional[uuid.UUID] = Query(None, description="Filter by technique"),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List threat intelligence items, optionally filtered by technique."""
|
||||
query = db.query(IntelItem).order_by(IntelItem.detected_at.desc())
|
||||
if technique_id:
|
||||
query = query.filter(IntelItem.technique_id == technique_id)
|
||||
items = query.limit(limit).all()
|
||||
return [
|
||||
IntelItemOut(
|
||||
id=item.id,
|
||||
technique_id=item.technique_id,
|
||||
url=item.url,
|
||||
title=item.title,
|
||||
source=item.source,
|
||||
detected_at=item.detected_at.isoformat() if item.detected_at else None,
|
||||
reviewed=item.reviewed,
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
@@ -129,19 +129,19 @@ def list_links(
|
||||
entity_type: Optional[JiraLinkEntityType] = None,
|
||||
# Entry: entity_id
|
||||
entity_id: Optional[UUID] = None,
|
||||
# Entry: db
|
||||
entity_ids: Optional[list[UUID]] = Query(default=None, description="Filter by multiple entity IDs"),
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
) -> list[JiraLinkOut]:
|
||||
"""List Jira links, optionally filtered by entity."""
|
||||
# Return jira_service.list_links(
|
||||
):
|
||||
"""List Jira links, optionally filtered by entity or a list of entity IDs."""
|
||||
return jira_service.list_links(
|
||||
db,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=entity_id,
|
||||
entity_ids=entity_ids,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Phase 11: Knowledge Management router — Playbooks + Lessons Learned."""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.schemas.knowledge_schema import (
|
||||
PlaybookCreate, PlaybookUpdate, PlaybookOut, PlaybookVersionOut,
|
||||
LessonLearnedCreate, LessonLearnedUpdate, LessonLearnedOut,
|
||||
)
|
||||
from app.services import playbook_service as pb_svc
|
||||
from app.services import lesson_learned_service as ll_svc
|
||||
|
||||
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# Playbooks
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@router.get("/playbooks", response_model=List[PlaybookOut])
|
||||
def list_playbooks(
|
||||
technique_id: Optional[UUID] = None,
|
||||
playbook_type: Optional[str] = None,
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return pb_svc.list_playbooks(
|
||||
db,
|
||||
technique_id=technique_id,
|
||||
playbook_type=playbook_type,
|
||||
include_inactive=include_inactive,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/playbooks", response_model=PlaybookOut, status_code=201)
|
||||
def create_playbook(
|
||||
body: PlaybookCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return pb_svc.create_playbook(db, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/playbooks/{playbook_id}", response_model=PlaybookOut)
|
||||
def get_playbook(
|
||||
playbook_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return pb_svc.get_playbook(db, playbook_id)
|
||||
|
||||
|
||||
@router.patch("/playbooks/{playbook_id}", response_model=PlaybookOut)
|
||||
def update_playbook(
|
||||
playbook_id: UUID,
|
||||
body: PlaybookUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return pb_svc.update_playbook(db, playbook_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.delete("/playbooks/{playbook_id}", status_code=204)
|
||||
def delete_playbook(
|
||||
playbook_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
pb_svc.delete_playbook(db, playbook_id, user.id)
|
||||
|
||||
|
||||
# ── Versions ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/playbooks/{playbook_id}/versions", response_model=List[PlaybookVersionOut])
|
||||
def list_versions(
|
||||
playbook_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return pb_svc.get_playbook_versions(db, playbook_id)
|
||||
|
||||
|
||||
@router.post("/playbooks/{playbook_id}/restore/{version}", response_model=PlaybookOut)
|
||||
def restore_version(
|
||||
playbook_id: UUID,
|
||||
version: int,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Roll the playbook back to a specific historical version."""
|
||||
return pb_svc.restore_version(db, playbook_id, version, user.id)
|
||||
|
||||
|
||||
# ── By technique (convenience) ────────────────────────────────────────────────
|
||||
|
||||
@router.get(
|
||||
"/techniques/{technique_id}/playbooks",
|
||||
response_model=List[PlaybookOut],
|
||||
)
|
||||
def playbooks_for_technique(
|
||||
technique_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""List all active playbooks for a specific technique."""
|
||||
return pb_svc.list_playbooks(db, technique_id=technique_id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/techniques/{technique_id}/playbooks/{playbook_type}",
|
||||
response_model=PlaybookOut,
|
||||
)
|
||||
def get_playbook_by_technique_type(
|
||||
technique_id: UUID,
|
||||
playbook_type: str,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
pb = pb_svc.get_playbook_by_technique_type(db, technique_id, playbook_type)
|
||||
if not pb:
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
raise EntityNotFoundError("Playbook", f"{technique_id}/{playbook_type}")
|
||||
return pb
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# Lessons Learned
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@router.get("/lessons", response_model=List[LessonLearnedOut])
|
||||
def list_lessons(
|
||||
entity_type: Optional[str] = None,
|
||||
entity_id: Optional[UUID] = None,
|
||||
severity: Optional[str] = None,
|
||||
tag: Optional[str] = None,
|
||||
technique_id: Optional[str] = None,
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return ll_svc.list_lessons_learned(
|
||||
db,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
severity=severity,
|
||||
tag=tag,
|
||||
technique_id=technique_id,
|
||||
include_inactive=include_inactive,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/lessons", response_model=LessonLearnedOut, status_code=201)
|
||||
def create_lesson(
|
||||
body: LessonLearnedCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return ll_svc.create_lesson_learned(db, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.get("/lessons/{lesson_id}", response_model=LessonLearnedOut)
|
||||
def get_lesson(
|
||||
lesson_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return ll_svc.get_lesson_learned(db, lesson_id)
|
||||
|
||||
|
||||
@router.patch("/lessons/{lesson_id}", response_model=LessonLearnedOut)
|
||||
def update_lesson(
|
||||
lesson_id: UUID,
|
||||
body: LessonLearnedUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
return ll_svc.update_lesson_learned(
|
||||
db, lesson_id, body.model_dump(exclude_unset=True), user.id
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/lessons/{lesson_id}", status_code=204)
|
||||
def delete_lesson(
|
||||
lesson_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Soft-delete a lesson (admin / lead only)."""
|
||||
ll_svc.delete_lesson_learned(db, lesson_id, user.id)
|
||||
|
||||
|
||||
# ── Stats ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def knowledge_stats(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Summary counts: total playbooks, lessons by severity, playbooks by type."""
|
||||
return ll_svc.get_knowledge_stats(db)
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Phase 13: Operational Alerts router."""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
from app.schemas.operational_alert_schema import (
|
||||
AlertRuleCreate, AlertRuleOut, AlertRuleUpdate,
|
||||
AlertInstanceOut, EvaluationResult, AlertSummary,
|
||||
)
|
||||
import app.services.operational_alert_service as svc
|
||||
|
||||
router = APIRouter(prefix="/alerts", tags=["Operational Alerts"])
|
||||
|
||||
|
||||
# ── Evaluation ────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/evaluate", response_model=EvaluationResult, status_code=202)
|
||||
def evaluate_rules(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""
|
||||
Run the alert evaluation engine against all enabled rules.
|
||||
|
||||
Fires AlertInstances for rules whose conditions are met and are not in cooldown.
|
||||
Admin / leads only.
|
||||
"""
|
||||
result = svc.evaluate_all_rules(db)
|
||||
return EvaluationResult(
|
||||
rules_evaluated = result["rules_evaluated"],
|
||||
alerts_fired = result["alerts_fired"],
|
||||
alerts = [AlertInstanceOut.model_validate(a) for a in result["alerts"]],
|
||||
duration_seconds = result["duration_seconds"],
|
||||
)
|
||||
|
||||
|
||||
# ── Alert instances ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("", response_model=List[AlertInstanceOut])
|
||||
def list_alerts(
|
||||
status: Optional[str] = Query(None),
|
||||
severity: Optional[str] = Query(None),
|
||||
rule_type: Optional[str] = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""List alert instances with optional filters."""
|
||||
return svc.list_instances(db, status=status, severity=severity,
|
||||
rule_type=rule_type, limit=limit, offset=offset)
|
||||
|
||||
|
||||
@router.get("/summary", response_model=AlertSummary)
|
||||
def alert_summary(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Aggregate counts by status, severity, and rule type."""
|
||||
data = svc.get_summary(db)
|
||||
return AlertSummary(
|
||||
total_open = data["total_open"],
|
||||
total_acknowledged = data["total_acknowledged"],
|
||||
total_resolved = data["total_resolved"],
|
||||
by_severity = data["by_severity"],
|
||||
by_rule_type = data["by_rule_type"],
|
||||
recent_alerts = [AlertInstanceOut.model_validate(a) for a in data["recent_alerts"]],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{alert_id}", response_model=AlertInstanceOut)
|
||||
def get_alert(
|
||||
alert_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Get a single alert instance."""
|
||||
return svc.get_instance(db, alert_id)
|
||||
|
||||
|
||||
@router.post("/{alert_id}/acknowledge", response_model=AlertInstanceOut)
|
||||
def acknowledge_alert(
|
||||
alert_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Acknowledge an open alert (admin / lead roles only)."""
|
||||
return svc.acknowledge(db, alert_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/{alert_id}/resolve", response_model=AlertInstanceOut)
|
||||
def resolve_alert(
|
||||
alert_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Mark an alert as resolved (admin / lead roles only)."""
|
||||
return svc.resolve(db, alert_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/{alert_id}/dismiss", response_model=AlertInstanceOut)
|
||||
def dismiss_alert(
|
||||
alert_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Dismiss an alert (admin / lead roles only — won't re-fire until cooldown resets)."""
|
||||
return svc.dismiss(db, alert_id, current_user.id)
|
||||
|
||||
|
||||
# ── Alert rules ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/rules/list", response_model=List[AlertRuleOut])
|
||||
def list_rules(
|
||||
rule_type: Optional[str] = Query(None),
|
||||
include_disabled: bool = Query(False),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""List alert rules (all users can read; admins/leads manage them)."""
|
||||
return svc.list_rules(db, rule_type=rule_type, include_disabled=include_disabled)
|
||||
|
||||
|
||||
@router.post("/rules", response_model=AlertRuleOut, status_code=201)
|
||||
def create_rule(
|
||||
body: AlertRuleCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Create a custom alert rule."""
|
||||
return svc.create_rule(
|
||||
db,
|
||||
created_by = current_user.id,
|
||||
name = body.name,
|
||||
description = body.description,
|
||||
rule_type = body.rule_type,
|
||||
severity = body.severity,
|
||||
config = body.config,
|
||||
notify_in_app = body.notify_in_app,
|
||||
notify_webhook = body.notify_webhook,
|
||||
webhook_id = body.webhook_id,
|
||||
cooldown_hours = body.cooldown_hours,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/rules/{rule_id}", response_model=AlertRuleOut)
|
||||
def get_rule(
|
||||
rule_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Get a single alert rule."""
|
||||
return svc.get_rule(db, rule_id)
|
||||
|
||||
|
||||
@router.patch("/rules/{rule_id}", response_model=AlertRuleOut)
|
||||
def update_rule(
|
||||
rule_id: UUID,
|
||||
body: AlertRuleUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Update an alert rule (enable/disable, thresholds, cooldown)."""
|
||||
return svc.update_rule(
|
||||
db, rule_id,
|
||||
name = body.name,
|
||||
description = body.description,
|
||||
severity = body.severity,
|
||||
is_enabled = body.is_enabled,
|
||||
config = body.config,
|
||||
notify_in_app = body.notify_in_app,
|
||||
notify_webhook = body.notify_webhook,
|
||||
webhook_id = body.webhook_id,
|
||||
cooldown_hours = body.cooldown_hours,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/rules/{rule_id}", status_code=204)
|
||||
def delete_rule(
|
||||
rule_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Delete a custom alert rule (system rules cannot be deleted)."""
|
||||
svc.delete_rule(db, rule_id)
|
||||
@@ -97,7 +97,7 @@ def list_osint_items(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""List OSINT items with optional filters.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
"""Phase 9: Ownership & Daily Operations router."""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.schemas.ownership_queue_schema import (
|
||||
TechniqueOwnershipSet, TechniqueOwnershipOut,
|
||||
DetectionAssetOwnershipPatch,
|
||||
BulkAssignRequest, BulkAssignResult,
|
||||
QueueItemCreate, QueueItemPatch, QueueItemOut,
|
||||
)
|
||||
from app.services import ownership_service, revalidation_queue_service
|
||||
from app.models.ownership_queue import RevalidationQueueItem
|
||||
|
||||
router = APIRouter(prefix="/ownership", tags=["ownership"])
|
||||
|
||||
|
||||
# ── Technique Ownership ───────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/techniques/{technique_id}", response_model=TechniqueOwnershipOut)
|
||||
def get_technique_ownership(
|
||||
technique_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
ownership = ownership_service.get_technique_ownership(db, technique_id)
|
||||
if not ownership:
|
||||
raise EntityNotFoundError("TechniqueOwnership", str(technique_id))
|
||||
return ownership
|
||||
|
||||
|
||||
@router.put("/techniques/{technique_id}", response_model=TechniqueOwnershipOut)
|
||||
def set_technique_ownership(
|
||||
technique_id: UUID,
|
||||
body: TechniqueOwnershipSet,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead", "red_lead")),
|
||||
):
|
||||
return ownership_service.set_technique_ownership(
|
||||
db, technique_id,
|
||||
owner_id=body.owner_id,
|
||||
backup_owner_id=body.backup_owner_id,
|
||||
team=body.team,
|
||||
notes=body.notes,
|
||||
assigned_by=user.id,
|
||||
)
|
||||
|
||||
|
||||
# ── Detection Asset Ownership ─────────────────────────────────────────────────
|
||||
|
||||
@router.patch("/assets/{asset_id}", response_model=dict)
|
||||
def set_asset_ownership(
|
||||
asset_id: UUID,
|
||||
body: DetectionAssetOwnershipPatch,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead")),
|
||||
):
|
||||
ownership_service.set_asset_ownership(
|
||||
db, asset_id,
|
||||
owner_id=body.owner_id,
|
||||
backup_owner_id=body.backup_owner_id,
|
||||
team=body.team,
|
||||
user_id=user.id,
|
||||
)
|
||||
return {"message": "Asset ownership updated"}
|
||||
|
||||
|
||||
# ── Orphan Reports ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/orphans/techniques", response_model=list[dict])
|
||||
def orphan_techniques(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Return techniques with no assigned owner."""
|
||||
return ownership_service.get_orphan_techniques(db)
|
||||
|
||||
|
||||
@router.get("/orphans/assets", response_model=list[dict])
|
||||
def orphan_assets(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Return active detection assets with no assigned owner."""
|
||||
return ownership_service.get_orphan_assets(db)
|
||||
|
||||
|
||||
# ── Bulk Assignment ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/bulk-assign", response_model=BulkAssignResult)
|
||||
def bulk_assign(
|
||||
body: BulkAssignRequest,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead", "red_lead")),
|
||||
):
|
||||
"""
|
||||
Bulk-assign ownership.
|
||||
- If `tactic` is set → assigns technique ownership for all techniques of that tactic.
|
||||
- If `platform` is set → assigns asset ownership for all assets on that platform.
|
||||
At least one of tactic/platform must be provided.
|
||||
"""
|
||||
if not body.tactic and not body.platform:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=422, detail="Provide at least one of: tactic, platform")
|
||||
|
||||
if body.tactic:
|
||||
result = ownership_service.bulk_assign_techniques_by_tactic(
|
||||
db, body.tactic,
|
||||
owner_id=body.owner_id,
|
||||
backup_owner_id=body.backup_owner_id,
|
||||
team=body.team,
|
||||
overwrite=body.overwrite,
|
||||
user_id=user.id,
|
||||
)
|
||||
else:
|
||||
result = ownership_service.bulk_assign_assets_by_platform(
|
||||
db, body.platform,
|
||||
owner_id=body.owner_id,
|
||||
backup_owner_id=body.backup_owner_id,
|
||||
team=body.team,
|
||||
overwrite=body.overwrite,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
return BulkAssignResult(**result)
|
||||
|
||||
|
||||
# ── Revalidation Queue ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/queue", response_model=list[QueueItemOut])
|
||||
def list_queue(
|
||||
status: Optional[str] = Query(None),
|
||||
priority: Optional[str] = Query(None),
|
||||
reason: Optional[str] = Query(None),
|
||||
assigned_to: Optional[UUID] = Query(None),
|
||||
technique_id: Optional[UUID] = Query(None),
|
||||
detection_asset_id: Optional[UUID] = Query(None),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return revalidation_queue_service.list_queue(
|
||||
db, status=status, priority=priority, reason=reason,
|
||||
assigned_to=assigned_to, technique_id=technique_id,
|
||||
detection_asset_id=detection_asset_id, limit=limit, offset=offset,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/queue", response_model=QueueItemOut, status_code=201)
|
||||
def create_queue_item(
|
||||
body: QueueItemCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return revalidation_queue_service.create_queue_item(db, body.model_dump(), user.id)
|
||||
|
||||
|
||||
@router.patch("/queue/{item_id}", response_model=QueueItemOut)
|
||||
def update_queue_item(
|
||||
item_id: UUID,
|
||||
body: QueueItemPatch,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return revalidation_queue_service.update_queue_item(db, item_id, body.model_dump(exclude_unset=True), user.id)
|
||||
|
||||
|
||||
@router.post("/queue/generate", response_model=dict)
|
||||
def generate_queue(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "blue_lead")),
|
||||
):
|
||||
"""Scan the system and create new revalidation queue items."""
|
||||
return revalidation_queue_service.generate_queue_items(db)
|
||||
|
||||
|
||||
# ── Analyst Dashboard ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/analyst-dashboard")
|
||||
def analyst_dashboard(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Personalised daily workday view: my queue, expiring validations, infra changes, low-confidence techniques."""
|
||||
dashboard = revalidation_queue_service.get_analyst_dashboard(db, user.id)
|
||||
|
||||
# Serialize queue items to dicts (ORM objects → plain dicts)
|
||||
def _item_to_dict(item: RevalidationQueueItem) -> dict:
|
||||
return {
|
||||
"id": str(item.id),
|
||||
"technique_id": str(item.technique_id) if item.technique_id else None,
|
||||
"detection_asset_id": str(item.detection_asset_id) if item.detection_asset_id else None,
|
||||
"priority": item.priority.value if hasattr(item.priority, "value") else item.priority,
|
||||
"reason": item.reason.value if hasattr(item.reason, "value") else item.reason,
|
||||
"reason_detail": item.reason_detail,
|
||||
"status": item.status.value if hasattr(item.status, "value") else item.status,
|
||||
"assigned_to": str(item.assigned_to) if item.assigned_to else None,
|
||||
"due_date": item.due_date.isoformat() if item.due_date else None,
|
||||
"created_at": item.created_at.isoformat() if item.created_at else None,
|
||||
}
|
||||
|
||||
return {
|
||||
"my_pending_items": [_item_to_dict(i) for i in dashboard["my_pending_items"]],
|
||||
"expiring_validations_7d": dashboard["expiring_validations_7d"],
|
||||
"recent_infra_changes": dashboard["recent_infra_changes"],
|
||||
"my_low_confidence_techniques": dashboard["my_low_confidence_techniques"],
|
||||
"summary": dashboard["summary"],
|
||||
}
|
||||
@@ -2,9 +2,10 @@
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
from pathlib import Path
|
||||
|
||||
# Import APIRouter, Depends, Query, Request from fastapi
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
# Import APIRouter, Depends, HTTPException, Query, Request from fastapi
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
|
||||
# Import FileResponse from fastapi.responses
|
||||
from fastapi.responses import FileResponse
|
||||
@@ -21,12 +22,24 @@ from app.dependencies.auth import get_current_user, require_any_role
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import report_generation_service from app.services
|
||||
from app.services import report_generation_service
|
||||
|
||||
|
||||
def _assert_safe_report_path(filepath: str) -> str:
|
||||
"""Raise 500 if the generated filepath escapes the configured report directory."""
|
||||
output_dir = Path(settings.REPORT_OUTPUT_DIR).resolve()
|
||||
resolved = Path(filepath).resolve()
|
||||
if not resolved.is_relative_to(output_dir):
|
||||
raise HTTPException(status_code=500, detail="Report generation path error")
|
||||
return filepath
|
||||
|
||||
# Assign router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
||||
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
||||
|
||||
@@ -65,7 +78,7 @@ def generate_purple_report(
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
_assert_safe_report_path(filepath),
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
@@ -95,7 +108,7 @@ def generate_coverage_report(
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
_assert_safe_report_path(filepath),
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
@@ -125,7 +138,7 @@ def generate_executive_report(
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
_assert_safe_report_path(filepath),
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
@@ -155,7 +168,7 @@ def generate_quarterly_report(
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
_assert_safe_report_path(filepath),
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
@@ -187,7 +200,7 @@ def generate_technique_report(
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
_assert_safe_report_path(filepath),
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Phase 12: Risk Intelligence router."""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.schemas.risk_schema import (
|
||||
TechniqueRiskProfileOut,
|
||||
ComputeResult,
|
||||
)
|
||||
from app.services import risk_intelligence_service as svc
|
||||
|
||||
router = APIRouter(prefix="/risk", tags=["risk-intelligence"])
|
||||
|
||||
|
||||
# ── Compute ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/compute", response_model=ComputeResult, status_code=202)
|
||||
def compute_all(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(require_any_role("admin", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Recompute risk scores for ALL techniques (admin / leads only)."""
|
||||
result = svc.compute_all_risk_scores(db)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/profiles/{technique_id}/compute", response_model=TechniqueRiskProfileOut)
|
||||
def compute_one(
|
||||
technique_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Compute (or refresh) the risk profile for a single technique."""
|
||||
return svc.compute_technique_risk(db, technique_id)
|
||||
|
||||
|
||||
# ── Read ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/profiles", response_model=List[TechniqueRiskProfileOut])
|
||||
def list_profiles(
|
||||
risk_level: Optional[str] = None,
|
||||
min_score: Optional[float] = None,
|
||||
max_score: Optional[float] = None,
|
||||
stale_only: bool = False,
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""List risk profiles with optional filters."""
|
||||
return svc.list_risk_profiles(
|
||||
db,
|
||||
risk_level=risk_level,
|
||||
min_score=min_score,
|
||||
max_score=max_score,
|
||||
stale_only=stale_only,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/profiles/{technique_id}", response_model=TechniqueRiskProfileOut)
|
||||
def get_profile(
|
||||
technique_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Get the current risk profile for a technique."""
|
||||
return svc.get_risk_profile(db, technique_id)
|
||||
|
||||
|
||||
@router.get("/matrix")
|
||||
def risk_matrix(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""All profiled techniques with likelihood/impact coordinates for matrix view."""
|
||||
return svc.get_risk_matrix(db)
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
def risk_summary(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Aggregate risk statistics: counts by level, average score, top risks."""
|
||||
return svc.get_risk_summary(db)
|
||||
|
||||
|
||||
@router.get("/recommendations")
|
||||
def recommendations(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Prioritised list of techniques with actionable recommendations."""
|
||||
return svc.get_recommendations(db, limit=limit)
|
||||
|
||||
|
||||
@router.get("/top")
|
||||
def top_risks(
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
"""Top N highest-risk techniques (sorted by risk score desc)."""
|
||||
profiles = svc.list_risk_profiles(db, limit=limit)
|
||||
return profiles
|
||||
@@ -87,7 +87,7 @@ def list_snapshots(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||
# Return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||
return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Phase 14: SSO / SAML 2.0 router."""
|
||||
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_any_role
|
||||
from app import auth as auth_lib
|
||||
from app.schemas.sso_schema import (
|
||||
SsoConfigCreate, SsoConfigOut, SsoStatusResponse,
|
||||
)
|
||||
import app.services.sso_service as svc
|
||||
|
||||
router = APIRouter(prefix="/sso", tags=["SSO"])
|
||||
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
# Mirror the same SECURE_COOKIES logic used in the auth router so that
|
||||
# SAML-authenticated sessions respect the deployment's HTTPS configuration.
|
||||
_aegis_env = os.environ.get("AEGIS_ENV", "development").lower()
|
||||
_secure_cookie_env = os.environ.get("SECURE_COOKIES", "auto").lower()
|
||||
if _secure_cookie_env == "false":
|
||||
_IS_HTTPS = False
|
||||
elif _secure_cookie_env == "true":
|
||||
_IS_HTTPS = True
|
||||
else: # "auto" — active only when AEGIS_ENV=production
|
||||
_IS_HTTPS = _aegis_env == "production"
|
||||
|
||||
_COOKIE_OPTS = {"httponly": True, "samesite": "lax", "secure": _IS_HTTPS}
|
||||
|
||||
|
||||
# ── Public ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/status", response_model=SsoStatusResponse)
|
||||
def sso_status(db: Session = Depends(get_db)):
|
||||
"""Return whether SSO is enabled and configured (public — for login page)."""
|
||||
return svc.get_status(db)
|
||||
|
||||
|
||||
@router.get("/metadata", response_model=None)
|
||||
def sp_metadata(db: Session = Depends(get_db)):
|
||||
"""
|
||||
Return the Service Provider SAML metadata XML.
|
||||
|
||||
Upload this XML to your IdP (Okta, Azure AD, etc.) to register Aegis.
|
||||
"""
|
||||
try:
|
||||
xml = svc.get_sp_metadata(db)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc))
|
||||
return Response(content=xml, media_type="application/xml")
|
||||
|
||||
|
||||
@router.get("/login")
|
||||
def sso_login(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Initiate SAML login — redirects the browser to the IdP.
|
||||
|
||||
The IdP will POST the SAML Response to ``/sso/callback`` after authentication.
|
||||
"""
|
||||
request_data = {
|
||||
"https": request.url.scheme == "https",
|
||||
"http_host": request.url.hostname,
|
||||
"path": request.url.path,
|
||||
"port": str(request.url.port or (443 if request.url.scheme == "https" else 80)),
|
||||
"get_data": dict(request.query_params),
|
||||
"post_data": {},
|
||||
"query_string": str(request.url.query),
|
||||
}
|
||||
try:
|
||||
result = svc.initiate_login(db, request_data)
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc))
|
||||
redirect_url = result["redirect_url"]
|
||||
if urlparse(redirect_url).scheme not in ("http", "https"):
|
||||
raise HTTPException(status_code=400, detail="Invalid IdP redirect URL")
|
||||
return RedirectResponse(url=redirect_url)
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def sso_callback(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
SAML Assertion Consumer Service (ACS) endpoint.
|
||||
|
||||
The IdP POSTs the SAML Response here. On success, sets the aegis_token
|
||||
cookie and redirects to the frontend.
|
||||
"""
|
||||
form = await request.form()
|
||||
request_data = {
|
||||
"https": request.url.scheme == "https",
|
||||
"http_host": request.url.hostname,
|
||||
"path": request.url.path,
|
||||
"port": str(request.url.port or (443 if request.url.scheme == "https" else 80)),
|
||||
"get_data": dict(request.query_params),
|
||||
"post_data": dict(form),
|
||||
"query_string": str(request.url.query),
|
||||
}
|
||||
try:
|
||||
user = svc.process_callback(db, request_data)
|
||||
except (ValueError, RuntimeError) as exc:
|
||||
raise HTTPException(status_code=401, detail=str(exc))
|
||||
|
||||
access_token = auth_lib.create_access_token({"sub": user.username})
|
||||
response = RedirectResponse(url="/", status_code=302)
|
||||
response.set_cookie(_COOKIE_NAME, access_token, **_COOKIE_OPTS)
|
||||
return response
|
||||
|
||||
|
||||
# ── Admin configuration ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/config", response_model=SsoConfigOut)
|
||||
def get_sso_config(
|
||||
db: Session = Depends(get_db),
|
||||
_user=Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Return the current SSO configuration (admin only)."""
|
||||
cfg = svc.get_config(db)
|
||||
if not cfg:
|
||||
raise HTTPException(status_code=404, detail="SSO not configured yet")
|
||||
return SsoConfigOut.model_validate(cfg)
|
||||
|
||||
|
||||
@router.put("/config", response_model=SsoConfigOut)
|
||||
def upsert_sso_config(
|
||||
body: SsoConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
_user=Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Create or replace the SSO configuration (admin only)."""
|
||||
cfg = svc.upsert_config(db, **body.model_dump(exclude_unset=False))
|
||||
return SsoConfigOut.model_validate(cfg)
|
||||
+699
-40
@@ -3,41 +3,30 @@
|
||||
Provides manual triggers for background operations such as the MITRE
|
||||
ATT&CK synchronisation, intel scanning, Atomic Red Team import, and
|
||||
scheduler health introspection.
|
||||
|
||||
Also exposes email configuration CRUD (admin only) that writes to the
|
||||
system_configs table so settings survive container restarts.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Request from fastapi
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import require_role
|
||||
|
||||
# Import scheduler from app.jobs.mitre_sync_job
|
||||
from app.database import SessionLocal, get_db
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
from app.models.user import User
|
||||
from app.services.mitre_sync_service import sync_mitre
|
||||
from app.services.intel_service import scan_intel
|
||||
from app.services.atomic_import_service import import_atomic_red_team
|
||||
from app.jobs.mitre_sync_job import scheduler
|
||||
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import import_atomic_red_team from app.services.atomic_import_service
|
||||
from app.services.atomic_import_service import import_atomic_red_team
|
||||
|
||||
# Import scan_intel from app.services.intel_service
|
||||
from app.services.intel_service import scan_intel
|
||||
|
||||
# Import sync_mitre from app.services.mitre_sync_service
|
||||
from app.services.mitre_sync_service import sync_mitre
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,7 +34,81 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic schemas for email config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EmailConfigOut(BaseModel):
|
||||
enabled: bool
|
||||
host: str
|
||||
port: int
|
||||
username: str
|
||||
from_email: str
|
||||
use_tls: bool
|
||||
# password is never returned
|
||||
|
||||
|
||||
class EmailConfigUpdate(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
from_email: Optional[str] = None
|
||||
use_tls: Optional[bool] = None
|
||||
|
||||
|
||||
class EmailTestRequest(BaseModel):
|
||||
to: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for system_configs CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SMTP_KEYS = {
|
||||
"enabled": "smtp.enabled",
|
||||
"host": "smtp.host",
|
||||
"port": "smtp.port",
|
||||
"username": "smtp.username",
|
||||
"password": "smtp.password", # nosec B105
|
||||
"from_email": "smtp.from_email",
|
||||
"use_tls": "smtp.use_tls",
|
||||
}
|
||||
|
||||
|
||||
def _upsert_config(db: Session, key: str, value: str) -> None:
|
||||
from app.models.system_config import SystemConfig # lazy import avoids circular
|
||||
|
||||
row = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if row:
|
||||
row.value = value
|
||||
else:
|
||||
row = SystemConfig(key=key, value=value)
|
||||
db.add(row)
|
||||
|
||||
|
||||
def _read_email_config_from_db(db: Session) -> dict:
|
||||
"""Return a dict with resolved email settings (DB overrides env)."""
|
||||
from app.services.email_service import _get_smtp_config
|
||||
|
||||
return _get_smtp_config(db)
|
||||
|
||||
|
||||
def _bg_mitre_sync() -> None:
|
||||
"""Run MITRE sync in a background task with its own DB session."""
|
||||
logger.info("Background MITRE sync task starting...")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
summary = sync_mitre(db)
|
||||
logger.info("Background MITRE sync task finished — %s", summary)
|
||||
except Exception:
|
||||
logger.exception("Background MITRE sync task failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/sync-mitre")
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("2/hour")
|
||||
@@ -53,28 +116,22 @@ router = APIRouter(prefix="/system", tags=["system"])
|
||||
def trigger_mitre_sync(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
) -> dict:
|
||||
"""Manually trigger a MITRE ATT&CK synchronisation.
|
||||
):
|
||||
"""Manually trigger a MITRE ATT&CK synchronisation in the background.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
|
||||
Returns a JSON object with the sync summary including the count of
|
||||
new and updated techniques.
|
||||
Returns immediately — the sync runs asynchronously. Poll
|
||||
``/system/scheduler-status`` for progress, or check server logs.
|
||||
"""
|
||||
# Assign summary = sync_mitre(db)
|
||||
summary = sync_mitre(db)
|
||||
# Return {
|
||||
background_tasks.add_task(_bg_mitre_sync)
|
||||
return {
|
||||
# Literal argument value
|
||||
"message": "MITRE sync completed",
|
||||
# Literal argument value
|
||||
"new": summary["created"],
|
||||
# Literal argument value
|
||||
"updated": summary["updated"],
|
||||
"message": "MITRE sync started in background",
|
||||
"status": "started",
|
||||
"new": 0,
|
||||
"updated": 0,
|
||||
}
|
||||
|
||||
|
||||
@@ -185,3 +242,605 @@ def scheduler_status(
|
||||
for job in jobs
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Jira config endpoints (admin only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JiraConfigOut(BaseModel):
|
||||
enabled: bool
|
||||
url: str
|
||||
project_key: str
|
||||
parent_ticket: str
|
||||
parent_ticket_standalone: str # parent for tests not in a campaign
|
||||
# Credentials are never returned
|
||||
|
||||
|
||||
class JiraConfigUpdate(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
url: Optional[str] = None
|
||||
project_key: Optional[str] = None
|
||||
parent_ticket: Optional[str] = None
|
||||
parent_ticket_standalone: Optional[str] = None
|
||||
|
||||
|
||||
_JIRA_KEYS = {
|
||||
"enabled": "jira.enabled",
|
||||
"url": "jira.url",
|
||||
"project_key": "jira.project_key",
|
||||
"parent_ticket": "jira.parent_ticket",
|
||||
"parent_ticket_standalone": "jira.parent_ticket_standalone",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/jira-config", response_model=JiraConfigOut)
|
||||
def get_jira_config(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return current Jira configuration (merged DB + env).
|
||||
|
||||
**Requires** the ``admin`` role. Credentials are never returned.
|
||||
"""
|
||||
from app.services.jira_service import (
|
||||
get_jira_url, get_jira_project_key, is_jira_enabled,
|
||||
get_jira_parent_ticket, get_jira_parent_ticket_standalone,
|
||||
)
|
||||
|
||||
return JiraConfigOut(
|
||||
enabled=is_jira_enabled(db),
|
||||
url=get_jira_url(db) or "",
|
||||
project_key=get_jira_project_key(db) or "",
|
||||
parent_ticket=get_jira_parent_ticket(db) or "",
|
||||
parent_ticket_standalone=get_jira_parent_ticket_standalone(db) or "",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/jira-config", response_model=JiraConfigOut)
|
||||
def update_jira_config(
|
||||
payload: JiraConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update Jira configuration and persist to DB.
|
||||
|
||||
**Requires** the ``admin`` role. Only provided fields are updated.
|
||||
"""
|
||||
from app.services.jira_service import (
|
||||
upsert_jira_config, get_jira_url, get_jira_project_key, is_jira_enabled,
|
||||
get_jira_parent_ticket, get_jira_parent_ticket_standalone,
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, val in update_data.items():
|
||||
db_key = _JIRA_KEYS.get(field)
|
||||
if db_key:
|
||||
upsert_jira_config(db, db_key, str(val))
|
||||
db.commit()
|
||||
|
||||
return JiraConfigOut(
|
||||
enabled=is_jira_enabled(db),
|
||||
url=get_jira_url(db) or "",
|
||||
project_key=get_jira_project_key(db) or "",
|
||||
parent_ticket=get_jira_parent_ticket(db) or "",
|
||||
parent_ticket_standalone=get_jira_parent_ticket_standalone(db) or "",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/jira-test")
|
||||
def test_jira_connection(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Test the Jira connection using the current user's credentials.
|
||||
|
||||
Requires the admin to have a personal Jira API token configured in their
|
||||
profile settings.
|
||||
|
||||
Always returns HTTP 200 with a ``status`` field so Cloudflare never
|
||||
replaces the response with its own error page.
|
||||
"""
|
||||
from app.services.jira_service import get_user_jira_client, get_jira_url, _effective_jira_email
|
||||
|
||||
jira_url = get_jira_url(db)
|
||||
if not jira_url:
|
||||
return {"status": "error", "message": "Jira URL is not configured. Set it in System Settings → Jira Configuration.", "jira_url": ""}
|
||||
|
||||
auth_email = _effective_jira_email(current_user)
|
||||
|
||||
try:
|
||||
jira = get_user_jira_client(current_user, db)
|
||||
# 10-second timeout so we never block Cloudflare into a 524
|
||||
try:
|
||||
jira._session.timeout = 10 # type: ignore[attr-defined]
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
myself = jira.myself()
|
||||
logger.info("Jira myself() response keys: %s", list(myself.keys()) if isinstance(myself, dict) else type(myself))
|
||||
# Use displayName → emailAddress → name → the auth email as fallback
|
||||
connected_as = (
|
||||
(myself.get("displayName") if isinstance(myself, dict) else None)
|
||||
or (myself.get("emailAddress") if isinstance(myself, dict) else None)
|
||||
or (myself.get("name") if isinstance(myself, dict) else None)
|
||||
or auth_email
|
||||
or "authenticated"
|
||||
)
|
||||
return {
|
||||
"status": "ok",
|
||||
"connected_as": connected_as,
|
||||
"jira_url": jira_url,
|
||||
}
|
||||
except Exception as exc:
|
||||
err = str(exc)
|
||||
# Always return HTTP 200 with status="error" so Cloudflare never
|
||||
# intercepts the response and the frontend always sees our message.
|
||||
if "Expecting value" in err or "line 1 column 1" in err:
|
||||
msg = (
|
||||
"Jira returned a non-JSON response. "
|
||||
"Verify the URL (e.g. https://company.atlassian.net), "
|
||||
"email and API token."
|
||||
)
|
||||
elif "401" in err or "Unauthorized" in err:
|
||||
msg = (
|
||||
"Authentication failed (401). "
|
||||
f"Check that the Atlassian email ({auth_email or 'not set'}) "
|
||||
"and API token are correct. The token must be an Atlassian API token "
|
||||
"(not your account password)."
|
||||
)
|
||||
elif "403" in err or "Forbidden" in err:
|
||||
msg = "Access denied (403). The token may not have permission for this Jira project."
|
||||
elif "timed out" in err.lower() or "timeout" in err.lower():
|
||||
msg = "Connection timed out. Check that the Jira URL is reachable from the server."
|
||||
elif "not configured" in err.lower():
|
||||
msg = err
|
||||
else:
|
||||
msg = f"Jira connection failed: {err}"
|
||||
logger.warning("Jira test connection failed: %s", err)
|
||||
return {"status": "error", "message": msg, "jira_url": jira_url}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /system/tempo-test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/tempo-test")
|
||||
def test_tempo_connection(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Test the current user's personal Tempo connection.
|
||||
|
||||
Uses the Tempo API token stored in the user's profile (not a global token).
|
||||
Always returns HTTP 200 with a ``status`` field so Cloudflare never
|
||||
intercepts the response.
|
||||
"""
|
||||
|
||||
tempo_token = getattr(current_user, "tempo_api_token", None)
|
||||
if not tempo_token:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"No Tempo API token configured. "
|
||||
"Add it in Settings → Profile → Tempo Integration."
|
||||
),
|
||||
}
|
||||
|
||||
jira_account_id = getattr(current_user, "jira_account_id", None)
|
||||
if not jira_account_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"No Jira Account ID configured. "
|
||||
"Set it in Settings → Profile → Jira Integration → Account ID."
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
from tempoapiclient import client_v4 as tempo_client
|
||||
tempo = tempo_client.Tempo(auth_token=tempo_token)
|
||||
# search_worklogs by authorId is the correct v4 method; use a tight
|
||||
# date range so we fetch almost nothing but still verify connectivity.
|
||||
worklogs = tempo.search_worklogs(
|
||||
dateFrom="2024-01-01",
|
||||
dateTo="2024-01-02",
|
||||
authorIds=[jira_account_id],
|
||||
)
|
||||
count = len(worklogs) if isinstance(worklogs, list) else "n/a"
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": f"Tempo connected successfully. Account ID: {jira_account_id}",
|
||||
"worklogs_found": count,
|
||||
}
|
||||
except Exception as exc:
|
||||
err = str(exc)
|
||||
if "401" in err or "Unauthorized" in err:
|
||||
msg = (
|
||||
"Authentication failed (401). "
|
||||
"Check your Tempo API token — obtain it at "
|
||||
"Jira → Apps → Tempo → Settings → API Integration."
|
||||
)
|
||||
elif "403" in err or "Forbidden" in err:
|
||||
msg = "Access denied (403). The Tempo token lacks the required permissions."
|
||||
elif "404" in err or "not found" in err.lower():
|
||||
msg = (
|
||||
"Account ID not found (404). "
|
||||
f"The value '{jira_account_id}' may be wrong — see the instructions "
|
||||
"below to find your correct Atlassian Account ID."
|
||||
)
|
||||
else:
|
||||
msg = f"Tempo connection failed: {err}"
|
||||
logger.warning(
|
||||
"Tempo test connection failed for user %s (account_id=%s): %s",
|
||||
current_user.username, jira_account_id, err,
|
||||
)
|
||||
return {"status": "error", "message": msg}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /system/email-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/email-config", response_model=EmailConfigOut)
|
||||
def get_email_config(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return current SMTP email configuration (merged DB + env).
|
||||
|
||||
**Requires** the ``admin`` role. Password is never returned.
|
||||
"""
|
||||
cfg = _read_email_config_from_db(db)
|
||||
return EmailConfigOut(
|
||||
enabled=cfg["enabled"],
|
||||
host=cfg["host"],
|
||||
port=cfg["port"],
|
||||
username=cfg["username"],
|
||||
from_email=cfg["from_email"],
|
||||
use_tls=cfg["use_tls"],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /system/email-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/email-config", response_model=EmailConfigOut)
|
||||
def update_email_config(
|
||||
payload: EmailConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update SMTP email configuration and persist to DB.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
Only provided fields are updated (partial update).
|
||||
"""
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, val in update_data.items():
|
||||
db_key = _SMTP_KEYS.get(field)
|
||||
if db_key:
|
||||
_upsert_config(db, db_key, str(val))
|
||||
db.commit()
|
||||
|
||||
cfg = _read_email_config_from_db(db)
|
||||
return EmailConfigOut(
|
||||
enabled=cfg["enabled"],
|
||||
host=cfg["host"],
|
||||
port=cfg["port"],
|
||||
username=cfg["username"],
|
||||
from_email=cfg["from_email"],
|
||||
use_tls=cfg["use_tls"],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /system/email-test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ATT&CK Evaluations endpoints (admin only)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/attck-evaluations/rounds")
|
||||
def list_evaluation_rounds(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return all public CrowdStrike ENTERPRISE evaluation rounds with import status.
|
||||
|
||||
Each entry includes whether it has already been imported into this platform.
|
||||
"""
|
||||
from app.services.attck_evaluations_service import fetch_rounds_with_status
|
||||
from app.models.evaluation_import import EvaluationImport
|
||||
|
||||
status_info = fetch_rounds_with_status()
|
||||
rounds = status_info["rounds"]
|
||||
|
||||
imported = {
|
||||
row.adversary_name.lower(): row
|
||||
for row in db.query(EvaluationImport).filter(EvaluationImport.status == "completed").all()
|
||||
}
|
||||
|
||||
round_list = [
|
||||
{
|
||||
"name": r["name"],
|
||||
"display_name": r.get("display_name", r["name"]),
|
||||
"eval_round": r["eval_round"],
|
||||
"imported": r["name"].lower() in imported,
|
||||
"imported_at": imported[r["name"].lower()].imported_at.isoformat()
|
||||
if r["name"].lower() in imported else None,
|
||||
"tests_created": imported[r["name"].lower()].tests_created
|
||||
if r["name"].lower() in imported else None,
|
||||
"techniques_covered": imported[r["name"].lower()].techniques_covered
|
||||
if r["name"].lower() in imported else None,
|
||||
}
|
||||
for r in rounds
|
||||
]
|
||||
|
||||
return {
|
||||
"rounds": round_list,
|
||||
"api_reachable": status_info["api_reachable"],
|
||||
"api_error": status_info.get("api_error"),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/attck-evaluations/import")
|
||||
def import_evaluation_round(
|
||||
payload: dict,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import a specific ATT&CK Evaluation round for CrowdStrike.
|
||||
|
||||
Body: { "adversary_name": "apt29", "adversary_display": "APT29", "eval_round": 2 }
|
||||
|
||||
Creates tests in ``in_review`` state — Blue Leads must validate each
|
||||
result before it counts as real coverage.
|
||||
"""
|
||||
from app.services.attck_evaluations_service import import_evaluation_round as _import
|
||||
|
||||
adversary_name = payload.get("adversary_name", "")
|
||||
adversary_display = payload.get("adversary_display", adversary_name)
|
||||
eval_round = payload.get("eval_round", 0)
|
||||
|
||||
if not adversary_name or not eval_round:
|
||||
raise HTTPException(status_code=400, detail="adversary_name and eval_round are required")
|
||||
|
||||
try:
|
||||
summary = _import(db, adversary_name, adversary_display, eval_round, current_user)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc))
|
||||
except Exception as exc:
|
||||
logger.error("ATT&CK Evaluation import failed: %s", exc, exc_info=True)
|
||||
raise HTTPException(status_code=502, detail=f"Import failed: {exc}")
|
||||
|
||||
return {
|
||||
"message": f"Import complete — {summary['created']} tests created",
|
||||
**summary,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/attck-evaluations/import-latest")
|
||||
def import_latest_evaluation(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import the latest available CrowdStrike evaluation round automatically.
|
||||
|
||||
Returns 409 if the latest round was already imported.
|
||||
"""
|
||||
from app.services.attck_evaluations_service import get_latest_round, import_evaluation_round as _import
|
||||
|
||||
try:
|
||||
latest = get_latest_round()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Could not reach MITRE Evaluations API: {exc}")
|
||||
|
||||
try:
|
||||
summary = _import(
|
||||
db,
|
||||
latest["name"],
|
||||
latest.get("display_name", latest["name"]),
|
||||
latest["eval_round"],
|
||||
current_user,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc))
|
||||
except Exception as exc:
|
||||
logger.error("ATT&CK Evaluation import failed: %s", exc, exc_info=True)
|
||||
raise HTTPException(status_code=502, detail=f"Import failed: {exc}")
|
||||
|
||||
return {
|
||||
"message": f"Import complete — {summary['created']} tests created",
|
||||
**summary,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/attck-evaluations/check-new")
|
||||
def check_new_evaluation_round(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Check if a new ATT&CK Evaluation round is available that hasn't been imported yet."""
|
||||
from app.services.attck_evaluations_service import check_for_new_round
|
||||
|
||||
return check_for_new_round(db)
|
||||
|
||||
|
||||
@router.post("/attck-evaluations/bulk-approve")
|
||||
def bulk_approve_evaluation_tests(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Bulk-approve all Blue Team validation for ATT&CK Evaluation imported tests.
|
||||
|
||||
Finds every test in ``in_review`` state whose name starts with ``[EVAL R``
|
||||
and approves the Blue Team side. Because all evaluation imports pre-approve
|
||||
the Red Team side, this moves every matched test to ``validated`` state.
|
||||
|
||||
**Important caveats** (enforced by UI warnings before this is called):
|
||||
- Results come from a controlled MITRE lab, NOT the organisation's env.
|
||||
- Validated tests will immediately affect coverage metrics and dashboards.
|
||||
- Blue Leads should still spot-check high-priority techniques individually.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from app.models.test import Test
|
||||
from app.models.enums import TestState
|
||||
from app.models.technique import Technique
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Find all pending evaluation tests
|
||||
pending = (
|
||||
db.query(Test)
|
||||
.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.name.like("[EVAL R%"),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not pending:
|
||||
return {
|
||||
"approved": 0,
|
||||
"techniques_recalculated": 0,
|
||||
"message": "No pending evaluation tests found — nothing to approve.",
|
||||
}
|
||||
|
||||
now = datetime.utcnow()
|
||||
affected_technique_ids: set = set()
|
||||
|
||||
for test in pending:
|
||||
# Approve blue side
|
||||
test.blue_validation_status = "approved"
|
||||
test.blue_validated_by = current_user.id
|
||||
test.blue_validated_at = now
|
||||
test.blue_validation_notes = (
|
||||
"Bulk-approved via ATT&CK Evaluations admin panel. "
|
||||
"Source: MITRE lab environment — not organisational detection."
|
||||
)
|
||||
|
||||
# Red side was pre-approved during import → move to validated
|
||||
if test.red_validation_status == "approved":
|
||||
test.state = TestState.validated
|
||||
# else stays in_review (shouldn't happen for eval imports, but be safe)
|
||||
|
||||
if test.technique_id:
|
||||
affected_technique_ids.add(test.technique_id)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="bulk_eval_approve",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={"source": "attck_evaluation_bulk_approve"},
|
||||
)
|
||||
|
||||
db.flush()
|
||||
|
||||
# Recalculate coverage for every touched technique
|
||||
for tech_id in affected_technique_ids:
|
||||
tech = db.query(Technique).filter(Technique.id == tech_id).first()
|
||||
if tech:
|
||||
recalculate_technique_status(db, tech)
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
"Bulk eval approval: %d tests validated, %d techniques recalculated (by %s)",
|
||||
len(pending), len(affected_technique_ids), current_user.email,
|
||||
)
|
||||
|
||||
return {
|
||||
"approved": len(pending),
|
||||
"techniques_recalculated": len(affected_technique_ids),
|
||||
"message": (
|
||||
f"{len(pending)} evaluation tests approved and moved to Validated. "
|
||||
f"{len(affected_technique_ids)} technique statuses recalculated."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/attck-evaluations/pending-count")
|
||||
def get_pending_evaluation_count(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return the number of imported evaluation tests still awaiting Blue approval."""
|
||||
from app.models.test import Test
|
||||
from app.models.enums import TestState
|
||||
|
||||
count = (
|
||||
db.query(Test)
|
||||
.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.name.like("[EVAL R%"),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
return {"pending": count}
|
||||
|
||||
|
||||
@router.post("/attck-evaluations/re-enrich")
|
||||
def re_enrich_evaluation_round(
|
||||
payload: dict,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Re-enrich already-imported evaluation tests with rich data from the MITRE API.
|
||||
|
||||
Updates procedure_text (attack path + criteria), description (data sources +
|
||||
substep references) and red_summary — without changing detection results,
|
||||
state or validation status.
|
||||
|
||||
Body: { "adversary_name": "turla", "adversary_display": "Turla", "eval_round": 5 }
|
||||
|
||||
Useful to upgrade tests that were imported before the enrichment feature
|
||||
was added.
|
||||
"""
|
||||
from app.services.attck_evaluations_service import re_enrich_evaluation_round as _re_enrich
|
||||
|
||||
adversary_name = payload.get("adversary_name", "")
|
||||
adversary_display = payload.get("adversary_display", adversary_name)
|
||||
eval_round = payload.get("eval_round", 0)
|
||||
|
||||
if not adversary_name or not eval_round:
|
||||
raise HTTPException(status_code=400, detail="adversary_name and eval_round are required")
|
||||
|
||||
try:
|
||||
summary = _re_enrich(db, adversary_name, adversary_display, eval_round, current_user)
|
||||
except Exception as exc:
|
||||
logger.error("ATT&CK Evaluation re-enrich failed: %s", exc, exc_info=True)
|
||||
raise HTTPException(status_code=502, detail=f"Re-enrich failed: {exc}")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
@router.post("/email-test")
|
||||
def send_test_email(
|
||||
payload: EmailTestRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Send a test email to verify SMTP configuration.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
Returns 200 on success, 502 if sending fails.
|
||||
"""
|
||||
from app.services.email_service import send_test_email as _send_test
|
||||
|
||||
ok = _send_test(payload.to, db=db)
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Failed to send test email. Check SMTP configuration and server logs.",
|
||||
)
|
||||
return {"detail": f"Test email sent to {payload.to}"}
|
||||
|
||||
@@ -42,8 +42,7 @@ from app.dependencies.auth import get_current_user, require_any_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.technique import Technique
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.schemas.test_template
|
||||
@@ -334,7 +333,15 @@ def create_template(
|
||||
template = create_template_svc(db, **payload.model_dump())
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
# Flag the associated technique for review — new template available
|
||||
if template.mitre_technique_id:
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == template.mitre_technique_id)
|
||||
.first()
|
||||
)
|
||||
if technique:
|
||||
technique.review_required = True
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
|
||||
+442
-64
@@ -11,6 +11,7 @@ PATCH /tests/{id}/red — Red Team updates (draft, red_executing)
|
||||
PATCH /tests/{id}/blue — Blue Team updates (blue_evaluating)
|
||||
POST /tests/{id}/start-execution — draft → red_executing
|
||||
POST /tests/{id}/submit-red — red_executing → blue_evaluating
|
||||
POST /tests/{id}/start-blue-work — blue tech picks up (sets Tempo timer)
|
||||
POST /tests/{id}/submit-blue — blue_evaluating → in_review
|
||||
POST /tests/{id}/validate-red — Red Lead validates
|
||||
POST /tests/{id}/validate-blue — Blue Lead validates
|
||||
@@ -18,16 +19,16 @@ POST /tests/{id}/reopen — rejected → draft
|
||||
GET /tests/{id}/timeline — audit-log history for this test
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
# Import APIRouter, Depends, HTTPException, Query, Reque... from fastapi
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
@@ -41,11 +42,11 @@ from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import TestState from app.models.enums
|
||||
from app.models.enums import TestState
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.enums import TestState, TestResult, TeamSide
|
||||
from app.models.evidence import Evidence
|
||||
from app.storage import upload_file
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.schemas.test
|
||||
@@ -69,8 +70,7 @@ from app.services.audit_service import log_action
|
||||
|
||||
# Import recalculate_technique_status from app.services.status_service
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
|
||||
# Import from app.services.test_crud_service
|
||||
from app.services.webhook_service import dispatch_webhook
|
||||
from app.services.test_crud_service import (
|
||||
create_test as crud_create_test,
|
||||
)
|
||||
@@ -120,54 +120,19 @@ from app.services.test_crud_service import (
|
||||
update_test_red as crud_update_test_red,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
get_retest_chain as wf_get_retest_chain,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
handle_remediation_completed as wf_handle_remediation,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
pause_timer as wf_pause_timer,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
reopen_test as wf_reopen,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
resume_timer as wf_resume_timer,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
start_execution as wf_start_execution,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
submit_blue_evidence as wf_submit_blue,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
submit_red_evidence as wf_submit_red,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
validate_as_blue_lead as wf_validate_blue,
|
||||
)
|
||||
|
||||
# Import from app.services.test_workflow_service
|
||||
from app.services.test_workflow_service import (
|
||||
submit_blue_evidence as wf_submit_blue,
|
||||
start_blue_work as wf_start_blue_work,
|
||||
validate_as_red_lead as wf_validate_red,
|
||||
validate_as_blue_lead as wf_validate_blue,
|
||||
reopen_test as wf_reopen,
|
||||
handle_remediation_completed as wf_handle_remediation,
|
||||
get_retest_chain as wf_get_retest_chain,
|
||||
pause_timer as wf_pause_timer,
|
||||
resume_timer as wf_resume_timer,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/tests", tags=["tests"])
|
||||
@@ -194,7 +159,9 @@ def list_tests(
|
||||
pending_validation_side: Optional[str] = Query(
|
||||
None, description="Filter in_review tests pending validation on 'red' or 'blue' side"
|
||||
),
|
||||
# Entry: offset
|
||||
not_in_any_campaign: bool = Query(
|
||||
False, description="Only return tests not linked to any campaign"
|
||||
),
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
@@ -233,7 +200,7 @@ def list_tests(
|
||||
created_by=created_by,
|
||||
# Keyword argument: pending_validation_side
|
||||
pending_validation_side=pending_validation_side,
|
||||
# Keyword argument: offset
|
||||
not_in_any_campaign=not_in_any_campaign,
|
||||
offset=offset,
|
||||
# Keyword argument: limit
|
||||
limit=limit,
|
||||
@@ -309,7 +276,14 @@ def create_test(
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(test)
|
||||
|
||||
# Return test
|
||||
# Auto-create Jira ticket (non-fatal — any failure is logged, not raised)
|
||||
try:
|
||||
from app.services.jira_service import auto_create_test_issue
|
||||
auto_create_test_issue(db, test, current_user)
|
||||
db.commit()
|
||||
except Exception: # nosec B110
|
||||
pass # jira_service already logs warnings internally
|
||||
|
||||
return test
|
||||
|
||||
|
||||
@@ -363,6 +337,11 @@ def create_test_from_template(
|
||||
technique_id_or_mitre=payload.technique_id,
|
||||
# Keyword argument: creator_id
|
||||
creator_id=current_user.id,
|
||||
name_override=payload.name,
|
||||
description_override=payload.description,
|
||||
platform_override=payload.platform,
|
||||
procedure_text_override=payload.procedure_text,
|
||||
tool_used_override=payload.tool_used,
|
||||
)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
@@ -390,7 +369,14 @@ def create_test_from_template(
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(test)
|
||||
|
||||
# Return test
|
||||
# Auto-create Jira ticket (non-fatal)
|
||||
try:
|
||||
from app.services.jira_service import auto_create_test_issue
|
||||
auto_create_test_issue(db, test, current_user)
|
||||
db.commit()
|
||||
except Exception: # nosec B110
|
||||
pass # jira_service already logs warnings internally
|
||||
|
||||
return test
|
||||
|
||||
|
||||
@@ -780,6 +766,26 @@ def submit_blue(
|
||||
return test
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{id}/start-blue-work — blue tech picks up test for evaluation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{test_id}/start-blue-work", response_model=TestOut)
|
||||
def start_blue_work(
|
||||
test_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||
):
|
||||
"""Blue tech picks up the test to start evaluating. Sets the Tempo timer start."""
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_start_blue_work(db, test, current_user)
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
return test
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{id}/pause-timer — pause the active phase timer
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -902,11 +908,16 @@ def validate_red(
|
||||
if test.state in (TestState.validated, TestState.rejected):
|
||||
# Call recalculate_technique_status()
|
||||
recalculate_technique_status(db, test.technique)
|
||||
# Call uow.commit()
|
||||
# Flag technique for review — coverage changed
|
||||
if test.technique:
|
||||
test.technique.review_required = True
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(test)
|
||||
# Return test
|
||||
if test.state == TestState.validated:
|
||||
dispatch_webhook("test.validated", {"test_id": str(test.id), "technique_id": str(test.technique_id), "result": test.result.value if test.result else None})
|
||||
elif test.state == TestState.rejected:
|
||||
dispatch_webhook("test.rejected", {"test_id": str(test.id), "technique_id": str(test.technique_id)})
|
||||
return test
|
||||
|
||||
|
||||
@@ -954,11 +965,16 @@ def validate_blue(
|
||||
if test.state in (TestState.validated, TestState.rejected):
|
||||
# Call recalculate_technique_status()
|
||||
recalculate_technique_status(db, test.technique)
|
||||
# Call uow.commit()
|
||||
# Flag technique for review — coverage changed
|
||||
if test.technique:
|
||||
test.technique.review_required = True
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(test)
|
||||
# Return test
|
||||
if test.state == TestState.validated:
|
||||
dispatch_webhook("test.validated", {"test_id": str(test.id), "technique_id": str(test.technique_id), "result": test.result.value if test.result else None})
|
||||
elif test.state == TestState.rejected:
|
||||
dispatch_webhook("test.rejected", {"test_id": str(test.id), "technique_id": str(test.technique_id)})
|
||||
return test
|
||||
|
||||
|
||||
@@ -1164,3 +1180,365 @@ def get_retest_chain(
|
||||
}
|
||||
for t in chain
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{id}/sync-tempo — manual Tempo sync for red execution worklog
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{test_id}/sync-tempo")
|
||||
def sync_tempo(
|
||||
test_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Manually sync this test's red team execution worklog(s) to Tempo.
|
||||
|
||||
Useful when the automatic sync failed at phase completion (e.g. Tempo
|
||||
was not yet configured). Only red_team_execution worklogs are eligible.
|
||||
Already-synced worklogs are skipped. Returns a summary of what happened.
|
||||
"""
|
||||
from datetime import datetime as _dt
|
||||
from app.models.worklog import Worklog
|
||||
from app.services.tempo_service import auto_log_test_worklog
|
||||
from app.services.test_crud_service import get_test_or_raise as _get
|
||||
|
||||
test = _get(db, test_id)
|
||||
|
||||
worklogs = (
|
||||
db.query(Worklog)
|
||||
.filter(
|
||||
Worklog.entity_type == "test",
|
||||
Worklog.entity_id == test_id,
|
||||
Worklog.activity_type == "red_team_execution",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not worklogs:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No red team execution worklog found for this test.",
|
||||
)
|
||||
|
||||
results = []
|
||||
for wl in worklogs:
|
||||
if wl.tempo_synced:
|
||||
results.append({"worklog_id": str(wl.id), "status": "already_synced"})
|
||||
continue
|
||||
|
||||
try:
|
||||
result = auto_log_test_worklog(
|
||||
db=db,
|
||||
test=test,
|
||||
user=current_user,
|
||||
activity_type=wl.activity_type,
|
||||
duration_seconds=wl.duration_seconds,
|
||||
)
|
||||
if result and isinstance(result, dict):
|
||||
wl.tempo_synced = _dt.utcnow()
|
||||
wl.tempo_worklog_id = str(result.get("tempoWorklogId", ""))
|
||||
db.commit()
|
||||
results.append({"worklog_id": str(wl.id), "status": "synced"})
|
||||
else:
|
||||
results.append({
|
||||
"worklog_id": str(wl.id),
|
||||
"status": "skipped",
|
||||
"detail": "Tempo not configured or conditions not met.",
|
||||
})
|
||||
except Exception as exc:
|
||||
results.append({
|
||||
"worklog_id": str(wl.id),
|
||||
"status": "error",
|
||||
"detail": str(exc),
|
||||
})
|
||||
|
||||
return {"results": results}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{id}/request-discussion — disputed: confirm vote + notify other lead
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{test_id}/request-discussion")
|
||||
def request_discussion(
|
||||
test_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
|
||||
):
|
||||
"""Called when the approving lead confirms their vote in a disputed test.
|
||||
|
||||
Sends a notification to the other lead (who rejected) asking them to
|
||||
discuss and resolve the conflict. The test remains in 'disputed' state.
|
||||
"""
|
||||
from app.models.user import User as UserModel
|
||||
from app.services.notification_service import create_notification
|
||||
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
|
||||
if test.state.value != "disputed":
|
||||
from app.domain.errors import BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Test is not in disputed state")
|
||||
|
||||
role = current_user.role
|
||||
|
||||
# Identify who the "other lead" is (the one who rejected)
|
||||
if (role in ("red_lead", "admin")) and test.red_validation_status == "approved":
|
||||
# Red approved, Blue rejected → notify Blue Lead who rejected
|
||||
rejector_id = test.blue_validated_by
|
||||
rejector_label = "Blue Lead"
|
||||
requester_label = "Red Lead"
|
||||
elif (role in ("blue_lead", "admin")) and test.blue_validation_status == "approved":
|
||||
# Blue approved, Red rejected → notify Red Lead who rejected
|
||||
rejector_id = test.red_validated_by
|
||||
rejector_label = "Red Lead"
|
||||
requester_label = "Blue Lead"
|
||||
else:
|
||||
from app.domain.errors import BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
"The conflict state is inconsistent — no approving lead found"
|
||||
)
|
||||
|
||||
# Look up the rejecting lead's full info for the response
|
||||
rejector = (
|
||||
db.query(UserModel).filter(UserModel.id == rejector_id).first()
|
||||
if rejector_id else None
|
||||
)
|
||||
rejector_name = rejector.username if rejector else rejector_label
|
||||
rejector_email = getattr(rejector, "email", None) if rejector else None
|
||||
|
||||
# Notify the rejecting lead
|
||||
if rejector_id:
|
||||
try:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=rejector_id,
|
||||
type="validation_conflict",
|
||||
title="Discussion requested on disputed test",
|
||||
message=(
|
||||
f"{requester_label} ({current_user.username}) is confirming their approval "
|
||||
f"of test '{test.name}' and wants to discuss your rejection with you. "
|
||||
f"Please reach out to resolve the disagreement."
|
||||
),
|
||||
entity_type="test",
|
||||
entity_id=str(test.id),
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to send discussion notification: %s", e
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="request_dispute_discussion",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={"test_name": test.name, "rejector": rejector_name},
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"status": "notification_sent",
|
||||
"message": f"Discussion request sent to {rejector_name}",
|
||||
"rejector_username": rejector_name,
|
||||
"rejector_email": rejector_email,
|
||||
"rejector_role": rejector_label,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/import-rt — bulk import from a real Red Team engagement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_ALLOWED_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"}
|
||||
_MAX_EVIDENCE_BYTES = 10 * 1024 * 1024 # 10 MB decoded per image
|
||||
|
||||
|
||||
class RTEvidenceEntry(BaseModel):
|
||||
filename: str # e.g. "screenshot_edr.png"
|
||||
data: str # base64-encoded image content
|
||||
caption: Optional[str] = None # optional description shown as evidence notes
|
||||
|
||||
|
||||
class RTTechniqueEntry(BaseModel):
|
||||
mitre_id: str
|
||||
result: str # "detected" | "not_detected" | "partially_detected"
|
||||
attack_success: bool = True
|
||||
platform: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
evidence: list[RTEvidenceEntry] # REQUIRED — at least one image per technique
|
||||
|
||||
|
||||
class RTImportPayload(BaseModel):
|
||||
name: str # engagement name, e.g. "Red Team Q1 2024"
|
||||
date: Optional[str] = None # ISO date string
|
||||
description: Optional[str] = None
|
||||
operator: Optional[str] = None # team / company that ran the RT
|
||||
techniques: list[RTTechniqueEntry]
|
||||
|
||||
|
||||
@router.post("/import-rt", status_code=status.HTTP_201_CREATED)
|
||||
def import_rt(
|
||||
payload: RTImportPayload,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead")),
|
||||
):
|
||||
"""Import results from a real Red Team engagement.
|
||||
|
||||
Creates one Test record per technique in ``validated`` state (bypassing
|
||||
the normal Red/Blue workflow) and immediately recalculates coverage metrics.
|
||||
Requires ``red_lead`` or ``admin`` role.
|
||||
"""
|
||||
# Pre-validate: every technique must include at least one evidence image
|
||||
for entry in payload.techniques:
|
||||
if not entry.evidence:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=(
|
||||
f"Technique {entry.mitre_id} is missing evidence. "
|
||||
"At least one screenshot or image is required per technique."
|
||||
),
|
||||
)
|
||||
|
||||
# Execution date from payload or now
|
||||
exec_date_str = payload.date or datetime.utcnow().date().isoformat()
|
||||
|
||||
# Result string → TestResult enum
|
||||
_result_map = {
|
||||
"detected": TestResult.detected,
|
||||
"not_detected": TestResult.not_detected,
|
||||
"partially_detected": TestResult.partially_detected,
|
||||
}
|
||||
|
||||
created: list[dict[str, Any]] = []
|
||||
skipped: list[dict[str, str]] = []
|
||||
affected_technique_ids: set = set()
|
||||
|
||||
with UnitOfWork(db) as uow:
|
||||
for entry in payload.techniques:
|
||||
# Find technique
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == entry.mitre_id.upper())
|
||||
.first()
|
||||
)
|
||||
if technique is None:
|
||||
skipped.append({"mitre_id": entry.mitre_id, "reason": "Technique not found"})
|
||||
continue
|
||||
|
||||
detection_result = _result_map.get(entry.result)
|
||||
if detection_result is None:
|
||||
skipped.append({"mitre_id": entry.mitre_id, "reason": f"Unknown result value '{entry.result}'"})
|
||||
continue
|
||||
|
||||
test_name = f"[RT] {payload.name} — {technique.name}"
|
||||
|
||||
# Build red_summary from notes + engagement metadata
|
||||
parts = []
|
||||
if payload.operator:
|
||||
parts.append(f"Operator: {payload.operator}")
|
||||
parts.append(f"Engagement date: {exec_date_str}")
|
||||
if entry.notes:
|
||||
parts.append(f"\n{entry.notes}")
|
||||
red_summary_text = "\n".join(parts)
|
||||
|
||||
# RT pre-validates the Red side (they ran it), but Blue Lead
|
||||
# must still validate the detection result before it counts.
|
||||
# State = in_review so it appears in the Blue Lead's validation queue.
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name=test_name,
|
||||
description=payload.description,
|
||||
platform=entry.platform,
|
||||
procedure_text=entry.notes,
|
||||
created_by=current_user.id,
|
||||
state=TestState.in_review,
|
||||
# Red team — approved by the RT operator
|
||||
attack_success=entry.attack_success,
|
||||
red_summary=red_summary_text,
|
||||
red_validation_status="approved",
|
||||
red_validated_by=current_user.id,
|
||||
red_validated_at=datetime.utcnow(),
|
||||
# Blue team — pre-fill the detection result but leave
|
||||
# validation_status pending so Blue Lead must confirm
|
||||
detection_result=detection_result,
|
||||
blue_validation_status=None,
|
||||
# Timing
|
||||
execution_date=exec_date_str,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
|
||||
# ── Store evidence images ──────────────────────────────
|
||||
evidence_count = 0
|
||||
for ev in entry.evidence:
|
||||
safe_name = os.path.basename(ev.filename) or "evidence.png"
|
||||
ext = os.path.splitext(safe_name)[1].lower()
|
||||
if ext not in _ALLOWED_IMAGE_EXTS:
|
||||
# Skip non-image files silently (log warning)
|
||||
continue
|
||||
try:
|
||||
img_bytes = base64.b64decode(ev.data)
|
||||
except Exception: # nosec B112
|
||||
continue # malformed base64 — skip
|
||||
if len(img_bytes) > _MAX_EVIDENCE_BYTES:
|
||||
continue # over size limit — skip
|
||||
sha256 = hashlib.sha256(img_bytes).hexdigest()
|
||||
key = f"{test.id}/{uuid.uuid4()}_{safe_name}"
|
||||
try:
|
||||
upload_file(img_bytes, key)
|
||||
except Exception: # nosec B112
|
||||
continue # storage error — skip but don't abort
|
||||
evidence_obj = Evidence(
|
||||
test_id=test.id,
|
||||
file_name=safe_name,
|
||||
file_path=key,
|
||||
sha256_hash=sha256,
|
||||
uploaded_by=current_user.id,
|
||||
uploaded_at=datetime.utcnow(),
|
||||
team=TeamSide.red,
|
||||
notes=ev.caption,
|
||||
)
|
||||
db.add(evidence_obj)
|
||||
evidence_count += 1
|
||||
|
||||
affected_technique_ids.add(technique.id)
|
||||
created.append({
|
||||
"mitre_id": entry.mitre_id,
|
||||
"test_name": test_name,
|
||||
"result": entry.result,
|
||||
"attack_success": entry.attack_success,
|
||||
"evidence_attached": evidence_count,
|
||||
})
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="rt_import_test",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={"engagement": payload.name, "mitre_id": entry.mitre_id},
|
||||
)
|
||||
|
||||
# Recalculate coverage for all affected techniques
|
||||
for tech_id in affected_technique_ids:
|
||||
tech = db.query(Technique).filter(Technique.id == tech_id).first()
|
||||
if tech:
|
||||
recalculate_technique_status(db, tech)
|
||||
|
||||
uow.commit()
|
||||
|
||||
return {
|
||||
"created": len(created),
|
||||
"skipped": len(skipped),
|
||||
"items": created,
|
||||
"warnings": skipped,
|
||||
"engagement": payload.name,
|
||||
}
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ def list_threat_actors(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""List threat actors with optional filters and pagination.
|
||||
|
||||
**Requires** authentication (any role).
|
||||
@@ -138,7 +138,7 @@ def get_threat_actor_gaps(
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
) -> dict:
|
||||
"""Identify techniques of this actor that are NOT fully validated.
|
||||
|
||||
**Requires** authentication (any role).
|
||||
|
||||
@@ -20,11 +20,8 @@ from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import UserCreate, UserOut, UserUpdate from app.schemas.user
|
||||
from app.schemas.user import UserCreate, UserOut, UserUpdate
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserOut, UserPreferencesUpdate
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.user_service
|
||||
@@ -39,6 +36,47 @@ from app.services.user_service import (
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /users/me/preferences — update current user preferences
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/me/preferences", response_model=UserOut)
|
||||
def update_my_preferences(
|
||||
payload: UserPreferencesUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Update the current user's notification preferences, Jira account ID and Jira API token.
|
||||
|
||||
Send ``jira_api_token: ""`` to clear a previously stored token.
|
||||
The token is never returned in any response.
|
||||
"""
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field in ("jira_api_token", "jira_email", "tempo_api_token"):
|
||||
# Empty string means "clear the value"
|
||||
setattr(current_user, field, value if value else None)
|
||||
else:
|
||||
setattr(current_user, field, value)
|
||||
db.commit()
|
||||
db.refresh(current_user)
|
||||
return current_user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users/me — get current user's own profile
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def get_me(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return the currently authenticated user's profile."""
|
||||
return current_user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users — list all users
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Webhook configuration CRUD router — admin only.
|
||||
|
||||
Endpoints
|
||||
---------
|
||||
GET /webhooks — list all webhook configs
|
||||
POST /webhooks — create a new webhook config
|
||||
GET /webhooks/{id} — get a single webhook config
|
||||
PATCH /webhooks/{id} — update a webhook config
|
||||
DELETE /webhooks/{id} — hard-delete a webhook config
|
||||
POST /webhooks/{id}/test — send a test ping
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_any_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.schemas.webhook import WebhookConfigCreate, WebhookConfigOut, WebhookConfigUpdate
|
||||
from app.services.webhook_service import (
|
||||
create_webhook,
|
||||
delete_webhook,
|
||||
dispatch_webhook,
|
||||
get_webhook_or_raise,
|
||||
list_webhooks,
|
||||
update_webhook,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
|
||||
|
||||
|
||||
def _mask_secret(wh) -> WebhookConfigOut:
|
||||
"""Return a WebhookConfigOut with the secret masked."""
|
||||
out = WebhookConfigOut.model_validate(wh)
|
||||
if wh.secret:
|
||||
out.secret = "***" # nosec B105
|
||||
else:
|
||||
out.secret = None
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /webhooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("", response_model=list[WebhookConfigOut])
|
||||
def list_webhooks_route(
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Return all webhook configurations. **Requires admin role.**"""
|
||||
webhooks = list_webhooks(db, offset=offset, limit=limit)
|
||||
return [_mask_secret(wh) for wh in webhooks]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /webhooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("", response_model=WebhookConfigOut, status_code=status.HTTP_201_CREATED)
|
||||
def create_webhook_route(
|
||||
payload: WebhookConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Create a new webhook configuration. **Requires admin role.**"""
|
||||
with UnitOfWork(db) as uow:
|
||||
wh = create_webhook(db, created_by=current_user.id, payload=payload)
|
||||
uow.commit()
|
||||
db.refresh(wh)
|
||||
return _mask_secret(wh)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /webhooks/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{webhook_id}", response_model=WebhookConfigOut)
|
||||
def get_webhook_route(
|
||||
webhook_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Return a single webhook configuration. **Requires admin role.**"""
|
||||
wh = get_webhook_or_raise(db, webhook_id)
|
||||
return _mask_secret(wh)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /webhooks/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/{webhook_id}", response_model=WebhookConfigOut)
|
||||
def update_webhook_route(
|
||||
webhook_id: uuid.UUID,
|
||||
payload: WebhookConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Update one or more fields of a webhook configuration. **Requires admin role.**"""
|
||||
with UnitOfWork(db) as uow:
|
||||
wh = update_webhook(db, webhook_id, payload)
|
||||
uow.commit()
|
||||
db.refresh(wh)
|
||||
return _mask_secret(wh)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /webhooks/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.delete("/{webhook_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_webhook_route(
|
||||
webhook_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Hard-delete a webhook configuration. **Requires admin role.**"""
|
||||
with UnitOfWork(db) as uow:
|
||||
delete_webhook(db, webhook_id)
|
||||
uow.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /webhooks/{id}/test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{webhook_id}/test", status_code=status.HTTP_202_ACCEPTED)
|
||||
def test_webhook_route(
|
||||
webhook_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("admin")),
|
||||
):
|
||||
"""Send a test ping to the webhook endpoint. **Requires admin role.**"""
|
||||
# Verify the webhook exists before dispatching
|
||||
get_webhook_or_raise(db, webhook_id)
|
||||
dispatch_webhook("webhook.test", {"webhook_id": str(webhook_id), "message": "Test ping from Aegis"})
|
||||
return {"detail": "Test ping dispatched"}
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Phase 14: API Key Pydantic schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.models.api_key import VALID_SCOPES
|
||||
|
||||
|
||||
class ApiKeyCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
description: Optional[str] = None
|
||||
scopes: List[str] = Field(default=["read"])
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
@field_validator("scopes")
|
||||
@classmethod
|
||||
def validate_scopes(cls, v: list) -> list:
|
||||
invalid = set(v) - VALID_SCOPES
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid scopes: {invalid}. Valid: {VALID_SCOPES}")
|
||||
if not v:
|
||||
raise ValueError("At least one scope is required")
|
||||
return v
|
||||
|
||||
|
||||
class ApiKeyOut(BaseModel):
|
||||
"""Safe representation — never exposes key_hash."""
|
||||
id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
key_prefix: str
|
||||
user_id: UUID
|
||||
scopes: List[str]
|
||||
last_used_at: Optional[datetime] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
is_active: bool
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ApiKeyCreated(ApiKeyOut):
|
||||
"""Returned only once at creation — includes the raw key."""
|
||||
raw_key: str = Field(..., description="The full API key — shown only this once.")
|
||||
|
||||
|
||||
class ApiKeyUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
description: Optional[str] = None
|
||||
scopes: Optional[List[str]] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
@field_validator("scopes")
|
||||
@classmethod
|
||||
def validate_scopes(cls, v: Optional[list]) -> Optional[list]:
|
||||
if v is None:
|
||||
return v
|
||||
invalid = set(v) - VALID_SCOPES
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid scopes: {invalid}")
|
||||
return v
|
||||
@@ -0,0 +1,230 @@
|
||||
"""Pydantic schemas for Phase 10: Attack Paths & Advanced Purple Team."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
VALID_KILL_CHAIN_PHASES = [
|
||||
"reconnaissance", "resource_development", "initial_access", "execution",
|
||||
"persistence", "privilege_escalation", "defense_evasion", "credential_access",
|
||||
"discovery", "lateral_movement", "collection", "command_and_control",
|
||||
"exfiltration", "impact",
|
||||
]
|
||||
|
||||
|
||||
# ── Attack Path ───────────────────────────────────────────────────────────────
|
||||
|
||||
class AttackPathCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
objective: Optional[str] = None
|
||||
is_template: bool = False
|
||||
threat_actor_id: Optional[UUID] = None
|
||||
tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class AttackPathUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
objective: Optional[str] = None
|
||||
is_template: Optional[bool] = None
|
||||
threat_actor_id: Optional[UUID] = None
|
||||
tags: Optional[list[str]] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class AttackPathOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
objective: Optional[str] = None
|
||||
is_template: bool
|
||||
threat_actor_id: Optional[UUID] = None
|
||||
created_by: Optional[UUID] = None
|
||||
tags: Optional[list] = None
|
||||
is_active: bool
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
step_count: Optional[int] = None # injected by service
|
||||
|
||||
|
||||
# ── Attack Path Step ──────────────────────────────────────────────────────────
|
||||
|
||||
class AttackPathStepCreate(BaseModel):
|
||||
order_index: int = 0
|
||||
kill_chain_phase: Optional[str] = None
|
||||
technique_id: Optional[UUID] = None
|
||||
test_id: Optional[UUID] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
expected_detection: bool = True
|
||||
notes: Optional[str] = None
|
||||
|
||||
@field_validator("kill_chain_phase")
|
||||
@classmethod
|
||||
def validate_phase(cls, v):
|
||||
if v is not None and v not in VALID_KILL_CHAIN_PHASES:
|
||||
raise ValueError(f"Invalid kill_chain_phase '{v}'. Valid: {VALID_KILL_CHAIN_PHASES}")
|
||||
return v
|
||||
|
||||
|
||||
class AttackPathStepUpdate(BaseModel):
|
||||
order_index: Optional[int] = None
|
||||
kill_chain_phase: Optional[str] = None
|
||||
technique_id: Optional[UUID] = None
|
||||
test_id: Optional[UUID] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
expected_detection: Optional[bool] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
@field_validator("kill_chain_phase")
|
||||
@classmethod
|
||||
def validate_phase(cls, v):
|
||||
if v is not None and v not in VALID_KILL_CHAIN_PHASES:
|
||||
raise ValueError(f"Invalid kill_chain_phase '{v}'.")
|
||||
return v
|
||||
|
||||
|
||||
class AttackPathStepOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
attack_path_id: UUID
|
||||
order_index: int
|
||||
kill_chain_phase: Optional[str] = None
|
||||
technique_id: Optional[UUID] = None
|
||||
test_id: Optional[UUID] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
expected_detection: bool
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
# ── Execution ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class ExecutionCreate(BaseModel):
|
||||
environment: Optional[str] = None
|
||||
red_team_lead: Optional[UUID] = None
|
||||
blue_team_lead: Optional[UUID] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class ExecutionOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
attack_path_id: UUID
|
||||
status: str
|
||||
environment: Optional[str] = None
|
||||
red_team_lead: Optional[UUID] = None
|
||||
blue_team_lead: Optional[UUID] = None
|
||||
started_by: Optional[UUID] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
notes: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
# metrics
|
||||
total_steps: Optional[int] = None
|
||||
detected_steps: Optional[int] = None
|
||||
not_detected_steps: Optional[int] = None
|
||||
skipped_steps: Optional[int] = None
|
||||
detection_rate: Optional[float] = None
|
||||
mttd_seconds: Optional[float] = None
|
||||
furthest_undetected_step: Optional[int] = None
|
||||
|
||||
|
||||
# ── Step Result ───────────────────────────────────────────────────────────────
|
||||
|
||||
class StepExecuteRequest(BaseModel):
|
||||
status: str # detected / not_detected / skipped
|
||||
executed_at: Optional[datetime] = None
|
||||
detected_at: Optional[datetime] = None
|
||||
detection_asset_id: Optional[UUID] = None
|
||||
notes: Optional[str] = None
|
||||
evidence_ids: Optional[list[UUID]] = None
|
||||
|
||||
@field_validator("status")
|
||||
@classmethod
|
||||
def validate_status(cls, v):
|
||||
valid = ("detected", "not_detected", "skipped", "executing")
|
||||
if v not in valid:
|
||||
raise ValueError(f"status must be one of {valid}")
|
||||
return v
|
||||
|
||||
|
||||
class StepResultOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
execution_id: UUID
|
||||
step_id: UUID
|
||||
step_order: int
|
||||
status: str
|
||||
executed_by: Optional[UUID] = None
|
||||
executed_at: Optional[datetime] = None
|
||||
detected_at: Optional[datetime] = None
|
||||
time_to_detect_seconds: Optional[float] = None
|
||||
detection_asset_id: Optional[UUID] = None
|
||||
notes: Optional[str] = None
|
||||
evidence_ids: Optional[list] = None
|
||||
|
||||
|
||||
# ── Timeline ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class TimelineEntryCreate(BaseModel):
|
||||
actor_side: str
|
||||
entry_type: str
|
||||
content: str
|
||||
step_id: Optional[UUID] = None
|
||||
timestamp: Optional[datetime] = None
|
||||
extra: Optional[dict] = None
|
||||
|
||||
@field_validator("actor_side")
|
||||
@classmethod
|
||||
def validate_side(cls, v):
|
||||
if v not in ("red", "blue", "system"):
|
||||
raise ValueError("actor_side must be red, blue or system")
|
||||
return v
|
||||
|
||||
@field_validator("entry_type")
|
||||
@classmethod
|
||||
def validate_type(cls, v):
|
||||
valid = ("action", "detection", "note", "phase_transition", "flag")
|
||||
if v not in valid:
|
||||
raise ValueError(f"entry_type must be one of {valid}")
|
||||
return v
|
||||
|
||||
|
||||
class TimelineEntryOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
execution_id: UUID
|
||||
step_id: Optional[UUID] = None
|
||||
timestamp: datetime
|
||||
actor_side: str
|
||||
actor_id: Optional[UUID] = None
|
||||
entry_type: str
|
||||
content: str
|
||||
extra: Optional[dict] = None
|
||||
|
||||
|
||||
# ── Metrics ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class KillChainMetrics(BaseModel):
|
||||
execution_id: UUID
|
||||
total_steps: int
|
||||
detected_steps: int
|
||||
not_detected_steps: int
|
||||
skipped_steps: int
|
||||
detection_rate: float # 0.0–1.0
|
||||
mttd_seconds: Optional[float] # mean time to detect
|
||||
furthest_undetected_step: Optional[int]
|
||||
furthest_undetected_phase: Optional[str]
|
||||
step_breakdown: list[dict] # per-step detail
|
||||
phase_summary: dict # detection rate per kill-chain phase
|
||||
@@ -5,9 +5,7 @@ import uuid
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Any from typing
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
# Import BaseModel, ConfigDict from pydantic
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@@ -29,9 +27,7 @@ class AuditLogOut(BaseModel):
|
||||
entity_type: str | None = None
|
||||
# Assign entity_id = None
|
||||
entity_id: str | None = None
|
||||
# timestamp: datetime
|
||||
timestamp: datetime
|
||||
# Assign details = None
|
||||
timestamp: Optional[datetime] = None
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
# Assign model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Pydantic schemas for Detection Lifecycle endpoints."""
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from app.models.detection_lifecycle import (
|
||||
DetectionConfidence, DetectionHealthStatus, InvalidationReason
|
||||
)
|
||||
|
||||
|
||||
class DetectionAssetCreate(BaseModel):
|
||||
name: str = Field(..., min_length=3, max_length=500)
|
||||
description: Optional[str] = None
|
||||
asset_type: str = Field(..., pattern=r'^(siem_rule|edr_rule|sigma_rule|yara_rule|spl_query|kql_query|custom_script)$')
|
||||
platform: Optional[str] = None
|
||||
rule_content: Optional[str] = None
|
||||
rule_language: Optional[str] = None
|
||||
rule_repository_url: Optional[str] = None
|
||||
rule_file_path: Optional[str] = None
|
||||
rule_version: Optional[str] = None
|
||||
log_source_name: Optional[str] = None
|
||||
log_source_version: Optional[str] = None
|
||||
log_source_config: Optional[dict] = Field(default_factory=dict)
|
||||
infrastructure_details: Optional[dict] = Field(default_factory=dict)
|
||||
expected_alert_frequency: Optional[str] = None
|
||||
tags: Optional[list[str]] = Field(default_factory=list)
|
||||
technique_ids: Optional[list[UUID]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DetectionAssetUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
rule_content: Optional[str] = None
|
||||
rule_version: Optional[str] = None
|
||||
log_source_version: Optional[str] = None
|
||||
infrastructure_details: Optional[dict] = None
|
||||
expected_alert_frequency: Optional[str] = None
|
||||
health_status: Optional[DetectionHealthStatus] = None
|
||||
last_alert_at: Optional[datetime] = None
|
||||
alert_count_30d: Optional[int] = None
|
||||
false_positive_rate: Optional[float] = None
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class DetectionAssetOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
asset_type: str
|
||||
platform: Optional[str] = None
|
||||
rule_language: Optional[str] = None
|
||||
rule_version: Optional[str] = None
|
||||
rule_hash: Optional[str] = None
|
||||
health_status: DetectionHealthStatus
|
||||
last_alert_at: Optional[datetime] = None
|
||||
alert_count_30d: int
|
||||
false_positive_rate: Optional[float] = None
|
||||
expected_alert_frequency: Optional[str] = None
|
||||
owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
is_active: bool
|
||||
tags: list = Field(default_factory=list)
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class DetectionValidationCreate(BaseModel):
|
||||
detection_asset_id: UUID
|
||||
technique_id: Optional[UUID] = None
|
||||
test_id: Optional[UUID] = None
|
||||
validation_result: str = Field(..., pattern=r'^(detected|not_detected|partial|error)$')
|
||||
validation_method: str
|
||||
notes: Optional[str] = None
|
||||
evidence_ids: Optional[list[UUID]] = Field(default_factory=list)
|
||||
validity_days: int = Field(default=180, ge=30, le=730)
|
||||
|
||||
|
||||
class DetectionValidationOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
detection_asset_id: UUID
|
||||
technique_id: Optional[UUID] = None
|
||||
validated_at: Optional[datetime] = None
|
||||
expires_at: datetime
|
||||
is_valid: bool
|
||||
validation_result: Optional[str] = None
|
||||
validation_method: Optional[str] = None
|
||||
invalidated_at: Optional[datetime] = None
|
||||
invalidation_reason: Optional[InvalidationReason] = None
|
||||
validated_by: Optional[UUID] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class TechniqueConfidenceOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
technique_id: UUID
|
||||
confidence_level: DetectionConfidence
|
||||
confidence_score: float
|
||||
detection_count: int
|
||||
valid_detection_count: int
|
||||
last_validated_at: Optional[datetime] = None
|
||||
next_validation_due: Optional[datetime] = None
|
||||
recency_factor: float
|
||||
coverage_factor: float
|
||||
health_factor: float
|
||||
diversity_factor: float
|
||||
risk_factors: list = Field(default_factory=list)
|
||||
|
||||
|
||||
class InfrastructureChangeCreate(BaseModel):
|
||||
change_type: str
|
||||
description: str = Field(..., min_length=10)
|
||||
affected_platforms: list[str] = Field(default_factory=list)
|
||||
affected_log_sources: list[str] = Field(default_factory=list)
|
||||
change_date: Optional[datetime] = None
|
||||
auto_invalidate: bool = True
|
||||
|
||||
|
||||
class InfrastructureChangeOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
change_type: str
|
||||
description: str
|
||||
affected_platforms: list = Field(default_factory=list)
|
||||
affected_log_sources: list = Field(default_factory=list)
|
||||
change_date: Optional[datetime] = None
|
||||
auto_invalidate: bool
|
||||
invalidated_count: int
|
||||
reported_by: Optional[UUID] = None
|
||||
created_at: Optional[datetime] = None
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Phase 13: Executive Dashboard — Pydantic schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PostureSnapshotOut(BaseModel):
|
||||
id: UUID
|
||||
snapshot_date: date
|
||||
|
||||
# Coverage
|
||||
total_techniques: int
|
||||
validated_count: int
|
||||
partial_count: int
|
||||
not_covered_count: int
|
||||
coverage_pct: float
|
||||
|
||||
# Risk
|
||||
avg_risk_score: float
|
||||
critical_count: int
|
||||
high_count: int
|
||||
medium_count: int
|
||||
low_count: int
|
||||
|
||||
# Operations
|
||||
open_queue_items: int
|
||||
orphan_techniques: int
|
||||
|
||||
# Knowledge
|
||||
playbook_count: int
|
||||
lesson_count: int
|
||||
|
||||
# MTTD
|
||||
mttd_avg_seconds: Optional[float] = None
|
||||
executions_30d: int
|
||||
detection_rate_30d: Optional[float] = None
|
||||
|
||||
# Meta
|
||||
created_by: Optional[UUID] = None
|
||||
created_at: Optional[datetime] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ExecutiveSummary(BaseModel):
|
||||
"""Full executive view — current posture + trends."""
|
||||
snapshot: PostureSnapshotOut
|
||||
coverage_trend: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Last 30-day coverage_pct series [{date, value}]",
|
||||
)
|
||||
risk_trend: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Last 30-day avg_risk_score series [{date, value}]",
|
||||
)
|
||||
top_risks: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Top 5 highest-risk techniques",
|
||||
)
|
||||
coverage_by_tactic: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Per-tactic validated/partial/not_covered counts",
|
||||
)
|
||||
recent_activity: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="Most-recent events (tests, paths, queue changes)",
|
||||
)
|
||||
|
||||
|
||||
class KpiBlock(BaseModel):
|
||||
"""Compact KPI block for a dashboard header."""
|
||||
coverage_pct: float
|
||||
avg_risk_score: float
|
||||
critical_count: int
|
||||
open_queue_items: int
|
||||
orphan_techniques: int
|
||||
mttd_avg_seconds: Optional[float] = None
|
||||
detection_rate_30d: Optional[float] = None
|
||||
playbook_count: int
|
||||
lesson_count: int
|
||||
snapshot_date: date
|
||||
snapshot_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class CoverageByTactic(BaseModel):
|
||||
tactic: str
|
||||
total: int
|
||||
validated: int
|
||||
partial: int
|
||||
not_covered: int
|
||||
coverage_pct: float
|
||||
|
||||
|
||||
class PostureHistoryEntry(BaseModel):
|
||||
snapshot_date: date
|
||||
coverage_pct: float
|
||||
avg_risk_score: float
|
||||
critical_count: int
|
||||
open_queue_items: int
|
||||
|
||||
|
||||
class ActivityEntry(BaseModel):
|
||||
ts: datetime
|
||||
category: str # "test" | "attack_path" | "queue" | "osint"
|
||||
title: str
|
||||
detail: Optional[str] = None
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Phase 11: Knowledge Management schemas — Playbooks + Lessons Learned."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
# ── Constants ─────────────────────────────────────────────────────────────────
|
||||
|
||||
VALID_PLAYBOOK_TYPES = ["attack", "detect", "investigate", "respond", "hunt"]
|
||||
VALID_SEVERITIES = ["critical", "high", "medium", "low", "info"]
|
||||
VALID_ENTITY_TYPES = ["test", "campaign", "attack_path", "manual"]
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# Playbook schemas
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class PlaybookCreate(BaseModel):
|
||||
technique_id: UUID
|
||||
playbook_type: str
|
||||
title: str
|
||||
content: str = ""
|
||||
tools: List[str] = []
|
||||
prerequisites: List[str] = []
|
||||
change_note: Optional[str] = None
|
||||
|
||||
@field_validator("playbook_type")
|
||||
@classmethod
|
||||
def validate_playbook_type(cls, v: str) -> str:
|
||||
if v not in VALID_PLAYBOOK_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid playbook_type '{v}'. Must be one of: {VALID_PLAYBOOK_TYPES}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class PlaybookUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
tools: Optional[List[str]] = None
|
||||
prerequisites: Optional[List[str]] = None
|
||||
change_note: Optional[str] = None
|
||||
|
||||
|
||||
class PlaybookVersionOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
playbook_id: UUID
|
||||
version: int
|
||||
title: str
|
||||
content: str
|
||||
tools: List[str] = []
|
||||
prerequisites: List[str] = []
|
||||
changed_by: Optional[UUID]
|
||||
change_note: Optional[str]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class PlaybookOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
technique_id: UUID
|
||||
playbook_type: str
|
||||
title: str
|
||||
content: str
|
||||
version: int
|
||||
tools: List[str] = []
|
||||
prerequisites: List[str] = []
|
||||
created_by: Optional[UUID]
|
||||
updated_by: Optional[UUID]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
is_active: bool
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# Lesson Learned schemas
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class LessonLearnedCreate(BaseModel):
|
||||
title: str
|
||||
what_happened: str
|
||||
root_cause: str
|
||||
fix_applied: Optional[str] = None
|
||||
severity: str = "medium"
|
||||
entity_type: str = "manual"
|
||||
entity_id: Optional[UUID] = None
|
||||
technique_ids: List[str] = []
|
||||
tags: List[str] = []
|
||||
|
||||
@field_validator("severity")
|
||||
@classmethod
|
||||
def validate_severity(cls, v: str) -> str:
|
||||
if v not in VALID_SEVERITIES:
|
||||
raise ValueError(
|
||||
f"Invalid severity '{v}'. Must be one of: {VALID_SEVERITIES}"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("entity_type")
|
||||
@classmethod
|
||||
def validate_entity_type(cls, v: str) -> str:
|
||||
if v not in VALID_ENTITY_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid entity_type '{v}'. Must be one of: {VALID_ENTITY_TYPES}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class LessonLearnedUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
what_happened: Optional[str] = None
|
||||
root_cause: Optional[str] = None
|
||||
fix_applied: Optional[str] = None
|
||||
severity: Optional[str] = None
|
||||
technique_ids: Optional[List[str]] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
@field_validator("severity")
|
||||
@classmethod
|
||||
def validate_severity(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is not None and v not in VALID_SEVERITIES:
|
||||
raise ValueError(
|
||||
f"Invalid severity '{v}'. Must be one of: {VALID_SEVERITIES}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class LessonLearnedOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
title: str
|
||||
what_happened: str
|
||||
root_cause: str
|
||||
fix_applied: Optional[str]
|
||||
severity: str
|
||||
entity_type: str
|
||||
entity_id: Optional[UUID]
|
||||
technique_ids: List[str] = []
|
||||
tags: List[str] = []
|
||||
created_by: Optional[UUID]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
is_active: bool
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Phase 13: Operational Alerts — Pydantic schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.models.operational_alert import AlertRuleType, AlertSeverity, AlertStatus
|
||||
|
||||
VALID_SEVERITIES = {s.value for s in AlertSeverity}
|
||||
VALID_STATUSES = {s.value for s in AlertStatus}
|
||||
VALID_RULE_TYPES = {r.value for r in AlertRuleType}
|
||||
|
||||
|
||||
# ── AlertRule schemas ─────────────────────────────────────────────────────────
|
||||
|
||||
class AlertRuleCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=300)
|
||||
description: Optional[str] = None
|
||||
rule_type: str
|
||||
severity: str = "medium"
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
notify_in_app: bool = True
|
||||
notify_webhook: bool = False
|
||||
webhook_id: Optional[UUID] = None
|
||||
cooldown_hours: int = Field(24, ge=0, le=8760)
|
||||
|
||||
@field_validator("rule_type")
|
||||
@classmethod
|
||||
def validate_rule_type(cls, v: str) -> str:
|
||||
if v not in VALID_RULE_TYPES:
|
||||
raise ValueError(f"Invalid rule_type. Valid: {VALID_RULE_TYPES}")
|
||||
return v
|
||||
|
||||
@field_validator("severity")
|
||||
@classmethod
|
||||
def validate_severity(cls, v: str) -> str:
|
||||
if v not in VALID_SEVERITIES:
|
||||
raise ValueError(f"Invalid severity. Valid: {VALID_SEVERITIES}")
|
||||
return v
|
||||
|
||||
|
||||
class AlertRuleUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=300)
|
||||
description: Optional[str] = None
|
||||
severity: Optional[str] = None
|
||||
is_enabled: Optional[bool] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
notify_in_app: Optional[bool] = None
|
||||
notify_webhook: Optional[bool] = None
|
||||
webhook_id: Optional[UUID] = None
|
||||
cooldown_hours: Optional[int] = Field(None, ge=0, le=8760)
|
||||
|
||||
@field_validator("severity")
|
||||
@classmethod
|
||||
def validate_severity(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is not None and v not in VALID_SEVERITIES:
|
||||
raise ValueError(f"Invalid severity. Valid: {VALID_SEVERITIES}")
|
||||
return v
|
||||
|
||||
|
||||
class AlertRuleOut(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
rule_type: str
|
||||
severity: str
|
||||
is_enabled: bool
|
||||
is_system: bool
|
||||
config: Dict[str, Any]
|
||||
notify_in_app: bool
|
||||
notify_webhook: bool
|
||||
webhook_id: Optional[UUID] = None
|
||||
cooldown_hours: int
|
||||
created_by: Optional[UUID] = None
|
||||
created_at: Optional[datetime] = None
|
||||
last_fired_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ── AlertInstance schemas ─────────────────────────────────────────────────────
|
||||
|
||||
class AlertInstanceOut(BaseModel):
|
||||
id: UUID
|
||||
rule_id: Optional[UUID] = None
|
||||
rule_name: str
|
||||
rule_type: str
|
||||
severity: str
|
||||
title: str
|
||||
message: str
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
status: str
|
||||
acknowledged_by: Optional[UUID] = None
|
||||
acknowledged_at: Optional[datetime] = None
|
||||
resolved_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ── Evaluation result ─────────────────────────────────────────────────────────
|
||||
|
||||
class EvaluationResult(BaseModel):
|
||||
rules_evaluated: int
|
||||
alerts_fired: int
|
||||
alerts: List[AlertInstanceOut] = Field(default_factory=list)
|
||||
duration_seconds: float
|
||||
|
||||
|
||||
# ── Summary ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class AlertSummary(BaseModel):
|
||||
total_open: int
|
||||
total_acknowledged: int
|
||||
total_resolved: int
|
||||
by_severity: Dict[str, int]
|
||||
by_rule_type: Dict[str, int]
|
||||
recent_alerts: List[AlertInstanceOut] = Field(default_factory=list)
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Pydantic schemas for Phase 9: Ownership & Revalidation Queue."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
|
||||
# ── Technique Ownership ───────────────────────────────────────────────────────
|
||||
|
||||
class TechniqueOwnershipSet(BaseModel):
|
||||
"""Set (create or replace) ownership for a technique."""
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class TechniqueOwnershipOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
technique_id: UUID
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
assigned_at: Optional[datetime] = None
|
||||
assigned_by: Optional[UUID] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class DetectionAssetOwnershipPatch(BaseModel):
|
||||
"""Update ownership fields on a detection asset."""
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
|
||||
|
||||
# ── Bulk Assignment ───────────────────────────────────────────────────────────
|
||||
|
||||
class BulkAssignRequest(BaseModel):
|
||||
"""Bulk-assign ownership by tactic, platform, or team filter."""
|
||||
owner_id: Optional[UUID] = None
|
||||
backup_owner_id: Optional[UUID] = None
|
||||
team: Optional[str] = None
|
||||
# Filters — at least one must be set
|
||||
tactic: Optional[str] = None # assign all techniques with this tactic
|
||||
platform: Optional[str] = None # assign all detection assets with this platform
|
||||
overwrite: bool = False # overwrite existing assignments
|
||||
|
||||
|
||||
class BulkAssignResult(BaseModel):
|
||||
assigned_count: int
|
||||
skipped_count: int
|
||||
target_type: str # "technique" or "detection_asset"
|
||||
|
||||
|
||||
# ── Revalidation Queue ────────────────────────────────────────────────────────
|
||||
|
||||
class QueueItemPatch(BaseModel):
|
||||
"""Update a revalidation queue item."""
|
||||
status: Optional[str] = None
|
||||
assigned_to: Optional[UUID] = None
|
||||
priority: Optional[str] = None
|
||||
due_date: Optional[datetime] = None
|
||||
|
||||
@field_validator("status")
|
||||
@classmethod
|
||||
def validate_status(cls, v):
|
||||
from app.models.ownership_queue import QueueStatus
|
||||
if v is not None:
|
||||
try:
|
||||
QueueStatus(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid status: {v}")
|
||||
return v
|
||||
|
||||
@field_validator("priority")
|
||||
@classmethod
|
||||
def validate_priority(cls, v):
|
||||
from app.models.ownership_queue import QueuePriority
|
||||
if v is not None:
|
||||
try:
|
||||
QueuePriority(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid priority: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class QueueItemCreate(BaseModel):
|
||||
"""Manually create a queue item."""
|
||||
technique_id: Optional[UUID] = None
|
||||
detection_asset_id: Optional[UUID] = None
|
||||
priority: str = "medium"
|
||||
reason: str = "manual"
|
||||
reason_detail: Optional[str] = None
|
||||
assigned_to: Optional[UUID] = None
|
||||
due_date: Optional[datetime] = None
|
||||
|
||||
@field_validator("reason")
|
||||
@classmethod
|
||||
def validate_reason(cls, v):
|
||||
from app.models.ownership_queue import QueueReason
|
||||
try:
|
||||
QueueReason(v)
|
||||
except ValueError:
|
||||
valid = [e.value for e in QueueReason]
|
||||
raise ValueError(f"Invalid reason '{v}'. Must be one of: {valid}")
|
||||
return v
|
||||
|
||||
@field_validator("priority")
|
||||
@classmethod
|
||||
def validate_priority(cls, v):
|
||||
from app.models.ownership_queue import QueuePriority
|
||||
try:
|
||||
QueuePriority(v)
|
||||
except ValueError:
|
||||
valid = [e.value for e in QueuePriority]
|
||||
raise ValueError(f"Invalid priority '{v}'. Must be one of: {valid}")
|
||||
return v
|
||||
|
||||
|
||||
class QueueItemOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
technique_id: Optional[UUID] = None
|
||||
detection_asset_id: Optional[UUID] = None
|
||||
priority: str
|
||||
reason: str
|
||||
reason_detail: Optional[str] = None
|
||||
status: str
|
||||
assigned_to: Optional[UUID] = None
|
||||
due_date: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
dismissed_at: Optional[datetime] = None
|
||||
completed_by: Optional[UUID] = None
|
||||
extra: Optional[dict] = None
|
||||
|
||||
|
||||
# ── Analyst Dashboard ─────────────────────────────────────────────────────────
|
||||
|
||||
class AnalystDashboard(BaseModel):
|
||||
"""Personalised daily workday view for an analyst."""
|
||||
my_pending_items: list[QueueItemOut]
|
||||
expiring_validations_7d: list[dict]
|
||||
recent_infra_changes: list[dict]
|
||||
my_low_confidence_techniques: list[dict]
|
||||
summary: dict
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Phase 12: Risk Intelligence schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
VALID_RISK_LEVELS = ["critical", "high", "medium", "low", "info"]
|
||||
|
||||
|
||||
class TechniqueRiskProfileOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
technique_id: UUID
|
||||
risk_score: float
|
||||
likelihood: float
|
||||
impact: float
|
||||
risk_level: str
|
||||
detection_gap: float
|
||||
threat_actor_count: int
|
||||
osint_signal_count: int
|
||||
test_fail_count: int
|
||||
test_total_count: int
|
||||
test_failure_rate: float
|
||||
confidence_level: float
|
||||
scoring_breakdown: Optional[Dict[str, Any]]
|
||||
recommendations: Optional[List[str]]
|
||||
computed_at: datetime
|
||||
is_stale: bool
|
||||
|
||||
|
||||
class RiskMatrixEntry(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
technique_id: UUID
|
||||
technique_name: Optional[str] = None
|
||||
technique_tid: Optional[str] = None # e.g. "T1059"
|
||||
risk_score: float
|
||||
likelihood: float
|
||||
impact: float
|
||||
risk_level: str
|
||||
detection_gap: float
|
||||
computed_at: datetime
|
||||
|
||||
|
||||
class RiskSummary(BaseModel):
|
||||
total_techniques: int
|
||||
scored_techniques: int
|
||||
stale_count: int
|
||||
by_level: Dict[str, int] # {"critical": 3, "high": 12, ...}
|
||||
avg_risk_score: float
|
||||
top_risks: List[RiskMatrixEntry]
|
||||
|
||||
|
||||
class RecommendationItem(BaseModel):
|
||||
technique_id: UUID
|
||||
technique_name: Optional[str] = None
|
||||
technique_tid: Optional[str] = None
|
||||
risk_level: str
|
||||
risk_score: float
|
||||
recommendations: List[str]
|
||||
priority: int # 1 = highest
|
||||
|
||||
|
||||
class ComputeResult(BaseModel):
|
||||
computed: int
|
||||
skipped: int
|
||||
errors: int
|
||||
duration_seconds: float
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Phase 14: SSO / SAML 2.0 Pydantic schemas."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SsoConfigCreate(BaseModel):
|
||||
is_enabled: bool = False
|
||||
provider_name: Optional[str] = None
|
||||
|
||||
# SP settings (auto-derived if not provided)
|
||||
sp_entity_id: Optional[str] = None
|
||||
sp_acs_url: Optional[str] = None
|
||||
sp_slo_url: Optional[str] = None
|
||||
sp_certificate: Optional[str] = None
|
||||
sp_private_key: Optional[str] = None
|
||||
|
||||
# IdP settings
|
||||
idp_entity_id: Optional[str] = None
|
||||
idp_sso_url: Optional[str] = None
|
||||
idp_slo_url: Optional[str] = None
|
||||
idp_certificate: Optional[str] = None
|
||||
|
||||
# Attribute mapping
|
||||
attr_email: Optional[str] = "email"
|
||||
attr_username: Optional[str] = "username"
|
||||
attr_role: Optional[str] = "role"
|
||||
default_role: Optional[str] = "viewer"
|
||||
auto_provision: bool = True
|
||||
|
||||
|
||||
class SsoConfigUpdate(SsoConfigCreate):
|
||||
"""All fields optional for partial updates."""
|
||||
pass
|
||||
|
||||
|
||||
class SsoConfigOut(BaseModel):
|
||||
id: UUID
|
||||
is_enabled: bool
|
||||
provider_name: Optional[str] = None
|
||||
|
||||
sp_entity_id: Optional[str] = None
|
||||
sp_acs_url: Optional[str] = None
|
||||
sp_slo_url: Optional[str] = None
|
||||
sp_certificate: Optional[str] = None
|
||||
# sp_private_key is intentionally OMITTED from responses
|
||||
|
||||
idp_entity_id: Optional[str] = None
|
||||
idp_sso_url: Optional[str] = None
|
||||
idp_slo_url: Optional[str] = None
|
||||
idp_certificate: Optional[str] = None
|
||||
|
||||
attr_email: Optional[str] = None
|
||||
attr_username: Optional[str] = None
|
||||
attr_role: Optional[str] = None
|
||||
default_role: Optional[str] = None
|
||||
auto_provision: bool = True
|
||||
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SsoLoginInitResponse(BaseModel):
|
||||
redirect_url: str = Field(..., description="URL to redirect the browser to for IdP login")
|
||||
request_id: str = Field(..., description="SAML AuthnRequest ID for validation")
|
||||
|
||||
|
||||
class SsoStatusResponse(BaseModel):
|
||||
enabled: bool
|
||||
provider_name: Optional[str] = None
|
||||
configured: bool = Field(..., description="True if IdP settings are present")
|
||||
login_url: Optional[str] = None # /sso/login URL
|
||||
+58
-23
@@ -6,14 +6,12 @@ import uuid
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import BaseModel, ConfigDict from pydantic
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
# Import DataClassification from app.domain.enums
|
||||
from app.domain.enums import DataClassification
|
||||
|
||||
# Import TestResult, TestState from app.models.enums
|
||||
from app.models.enums import TestResult, TestState
|
||||
from app.schemas.evidence import EvidenceOut
|
||||
|
||||
# ── Create ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -209,7 +207,7 @@ class TestOut(BaseModel):
|
||||
red_started_at: datetime | None = None
|
||||
# Assign blue_started_at = None
|
||||
blue_started_at: datetime | None = None
|
||||
# Assign paused_at = None
|
||||
blue_work_started_at: datetime | None = None
|
||||
paused_at: datetime | None = None
|
||||
# Assign red_paused_seconds = 0
|
||||
red_paused_seconds: int = 0
|
||||
@@ -235,27 +233,64 @@ class TestOut(BaseModel):
|
||||
# Assign technique_name = None
|
||||
technique_name: str | None = None
|
||||
|
||||
# Assign model_config = ConfigDict(from_attributes=True)
|
||||
# Evidences split by team (populated from the ORM relationship)
|
||||
red_evidences: list[EvidenceOut] = []
|
||||
blue_evidences: list[EvidenceOut] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# Apply the @classmethod decorator
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
# Define function model_validate
|
||||
def model_validate(cls, obj: object, **kwargs: object) -> "TestOut":
|
||||
"""Populate technique fields from the ORM relationship before validation.
|
||||
def _populate_derived_fields(cls, obj):
|
||||
"""Populate technique and evidence fields from ORM relationships.
|
||||
|
||||
Args:
|
||||
obj (object): The ORM model instance (or any compatible object) to validate.
|
||||
**kwargs (object): Additional keyword arguments forwarded to the parent.
|
||||
Uses ``@model_validator(mode='before')`` so it is called by Pydantic's
|
||||
internal Rust validation pipeline, including FastAPI's TypeAdapter path.
|
||||
A plain ``model_validate`` classmethod override is **not** invoked by
|
||||
FastAPI's response serialisation in Pydantic v2 — only registered
|
||||
validators are guaranteed to run.
|
||||
|
||||
Returns:
|
||||
TestOut: The validated schema instance with technique fields populated.
|
||||
Evidences are only processed when the relationship was **explicitly loaded**
|
||||
(via joinedload or prior access). Accessing ``obj.evidences`` blindly on a
|
||||
session-expired ORM object triggers a lazy query that fails on mutation
|
||||
endpoints that do not joinload the relationship. We inspect ``obj.__dict__``
|
||||
directly — SQLAlchemy stores loaded relationships there; if the key is absent
|
||||
the relationship is unloaded and we leave the lists empty (the frontend
|
||||
invalidates and refetches the detail endpoint, which *does* joinload).
|
||||
"""
|
||||
# Check: hasattr(obj, "technique") and obj.technique is not None
|
||||
if hasattr(obj, "technique") and obj.technique is not None:
|
||||
# Assign obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id
|
||||
obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id
|
||||
# Assign obj.__dict__["technique_name"] = obj.technique.name
|
||||
obj.__dict__["technique_name"] = obj.technique.name
|
||||
# Return super().model_validate(obj, **kwargs)
|
||||
return super().model_validate(obj, **kwargs)
|
||||
if not hasattr(obj, "__dict__"):
|
||||
return obj
|
||||
|
||||
# Technique info (lazy-load is fine here: session is still open on GET)
|
||||
try:
|
||||
if hasattr(obj, "technique") and obj.technique is not None:
|
||||
obj.__dict__["technique_mitre_id"] = obj.technique.mitre_id
|
||||
obj.__dict__["technique_name"] = obj.technique.name
|
||||
except Exception: # nosec B110
|
||||
pass # DetachedInstanceError or similar — leave technique fields None
|
||||
|
||||
# Only split evidences when they are already in memory (loaded via joinedload)
|
||||
raw_evs = obj.__dict__.get("evidences")
|
||||
if raw_evs is not None:
|
||||
red_evs: list[EvidenceOut] = []
|
||||
blue_evs: list[EvidenceOut] = []
|
||||
for ev in raw_evs:
|
||||
ev_out = EvidenceOut(
|
||||
id=ev.id,
|
||||
test_id=ev.test_id,
|
||||
file_name=ev.file_name,
|
||||
sha256_hash=ev.sha256_hash,
|
||||
uploaded_by=ev.uploaded_by,
|
||||
uploaded_at=ev.uploaded_at,
|
||||
team=ev.team,
|
||||
notes=ev.notes,
|
||||
download_url=f"/api/v1/evidence/{ev.id}/file",
|
||||
)
|
||||
if ev.team and ev.team.value == "blue":
|
||||
blue_evs.append(ev_out)
|
||||
else:
|
||||
red_evs.append(ev_out)
|
||||
obj.__dict__["red_evidences"] = red_evs
|
||||
obj.__dict__["blue_evidences"] = blue_evs
|
||||
|
||||
return obj
|
||||
|
||||
@@ -111,9 +111,19 @@ class TestTemplateSummary(BaseModel):
|
||||
|
||||
|
||||
class TestTemplateInstantiate(BaseModel):
|
||||
"""Payload to create a real test from an existing template."""
|
||||
"""Payload to create a real test from an existing template.
|
||||
|
||||
Optional override fields take precedence over the template values when provided.
|
||||
"""
|
||||
|
||||
# template_id: uuid.UUID
|
||||
template_id: uuid.UUID
|
||||
# technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001")
|
||||
technique_id: str # accepts both UUID and MITRE ID (e.g. "T1059.001")
|
||||
|
||||
# User-editable overrides (if omitted the template value is used)
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
platform: str | None = None
|
||||
procedure_text: str | None = None
|
||||
tool_used: str | None = None
|
||||
|
||||
@@ -9,8 +9,8 @@ import uuid
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import BaseModel, ConfigDict, field_validator from pydantic
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
||||
# ── Username policy ─────────────────────────────────────────────────
|
||||
|
||||
@@ -225,7 +225,22 @@ class PasswordChange(BaseModel):
|
||||
return _validate_password_strength(v)
|
||||
|
||||
|
||||
# Define class UserOut
|
||||
class UserPreferencesUpdate(BaseModel):
|
||||
"""Payload for updating current user's notification preferences and Jira/Tempo settings."""
|
||||
|
||||
notification_preferences: dict | None = None
|
||||
jira_account_id: str | None = None
|
||||
# Personal Jira API token (Atlassian token) — write-only.
|
||||
# Set to empty string "" to clear the token.
|
||||
jira_api_token: str | None = None
|
||||
# Atlassian email for Jira auth — overrides account email.
|
||||
# Set to empty string "" to clear (falls back to account email).
|
||||
jira_email: str | None = None
|
||||
# Personal Tempo API token — write-only.
|
||||
# Set to empty string "" to clear the token.
|
||||
tempo_api_token: str | None = None
|
||||
|
||||
|
||||
class UserOut(BaseModel):
|
||||
"""Complete representation returned by the API."""
|
||||
|
||||
@@ -245,6 +260,27 @@ class UserOut(BaseModel):
|
||||
created_at: datetime | None = None
|
||||
# Assign last_login = None
|
||||
last_login: datetime | None = None
|
||||
notification_preferences: dict | None = None
|
||||
jira_account_id: str | None = None
|
||||
jira_email: str | None = None
|
||||
# Read from ORM but NEVER exposed in responses — used only to derive *_token_set flags.
|
||||
jira_api_token: str | None = Field(default=None, exclude=True)
|
||||
tempo_api_token: str | None = Field(default=None, exclude=True)
|
||||
# True when the user has the respective token stored.
|
||||
jira_token_set: bool = False
|
||||
tempo_token_set: bool = False
|
||||
|
||||
# Assign model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _derive_token_set_flags(self) -> "UserOut":
|
||||
"""Derive *_token_set booleans from the (excluded) raw token fields.
|
||||
|
||||
Uses @model_validator(mode='after') so Pydantic's Rust core calls it
|
||||
during FastAPI response serialisation — model_validate() overrides are
|
||||
bypassed by FastAPI's __pydantic_validator__.validate_python() path.
|
||||
"""
|
||||
self.jira_token_set = bool(self.jira_api_token)
|
||||
self.tempo_token_set = bool(self.tempo_api_token)
|
||||
return self
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Pydantic schemas for Webhook endpoints."""
|
||||
import ipaddress
|
||||
import socket
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
# RFC-5735 / RFC-1918 / RFC-3927 — ranges that must never be webhook targets
|
||||
_BLOCKED_NETWORKS = [
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("169.254.0.0/16"), # link-local / AWS IMDS
|
||||
ipaddress.ip_network("127.0.0.0/8"), # loopback
|
||||
ipaddress.ip_network("::1/128"), # IPv6 loopback
|
||||
ipaddress.ip_network("fc00::/7"), # IPv6 ULA
|
||||
]
|
||||
|
||||
|
||||
def _validate_webhook_url(url: str) -> str:
|
||||
"""Reject URLs that point to private/reserved addresses (SSRF prevention)."""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError("Webhook URL must use http or https")
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
raise ValueError("Webhook URL must include a hostname")
|
||||
|
||||
# Resolve hostname to IP(s) and reject any private/reserved address
|
||||
try:
|
||||
infos = socket.getaddrinfo(hostname, None)
|
||||
for info in infos:
|
||||
raw_ip = info[4][0]
|
||||
try:
|
||||
ip_obj = ipaddress.ip_address(raw_ip)
|
||||
except ValueError:
|
||||
continue
|
||||
for network in _BLOCKED_NETWORKS:
|
||||
if ip_obj in network:
|
||||
raise ValueError(
|
||||
f"Webhook URL resolves to a private/reserved address ({raw_ip}) "
|
||||
"and cannot be used"
|
||||
)
|
||||
except OSError:
|
||||
# DNS resolution failure — allow (will fail at dispatch time)
|
||||
pass
|
||||
|
||||
return url
|
||||
|
||||
|
||||
class WebhookConfigCreate(BaseModel):
|
||||
name: str
|
||||
url: str
|
||||
secret: str | None = None
|
||||
events: list[str] = []
|
||||
is_active: bool = True
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_must_be_external(cls, v: str) -> str:
|
||||
return _validate_webhook_url(v)
|
||||
|
||||
|
||||
class WebhookConfigUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
url: str | None = None
|
||||
secret: str | None = None
|
||||
events: list[str] | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_must_be_external(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
return _validate_webhook_url(v)
|
||||
|
||||
class WebhookConfigOut(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
url: str
|
||||
secret: str | None = None # masked on read
|
||||
events: list[str]
|
||||
is_active: bool
|
||||
created_by: uuid.UUID | None = None
|
||||
last_triggered_at: datetime | None = None
|
||||
failure_count: int
|
||||
created_at: datetime | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Seed default decay policies."""
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.decay_policy import DecayPolicy
|
||||
|
||||
|
||||
def seed_decay_policies(db: Session) -> None:
|
||||
existing = db.query(DecayPolicy).filter(DecayPolicy.is_default == True).first()
|
||||
if existing:
|
||||
return
|
||||
|
||||
now = datetime.utcnow()
|
||||
default_policy = DecayPolicy(
|
||||
name="Default Decay Policy",
|
||||
description="Standard: Fresh 90d, Aging 91-180d, Stale 181-365d.",
|
||||
fresh_days=90, aging_days=180, stale_days=365,
|
||||
default_validity_days=180, silent_threshold_days=30,
|
||||
noisy_threshold_daily=100,
|
||||
recency_weight=0.30, coverage_weight=0.30,
|
||||
health_weight=0.25, diversity_weight=0.15,
|
||||
is_default=True, is_active=True,
|
||||
created_at=now, updated_at=now,
|
||||
)
|
||||
db.add(default_policy)
|
||||
|
||||
critical_policy = DecayPolicy(
|
||||
name="Critical Techniques Policy",
|
||||
description="Stricter: Fresh 60d, Aging 90d, Stale 180d.",
|
||||
applies_to_tactic="initial-access",
|
||||
fresh_days=60, aging_days=90, stale_days=180,
|
||||
default_validity_days=90, silent_threshold_days=14,
|
||||
noisy_threshold_daily=50,
|
||||
recency_weight=0.35, coverage_weight=0.30,
|
||||
health_weight=0.25, diversity_weight=0.10,
|
||||
is_default=False, is_active=True,
|
||||
created_at=now, updated_at=now,
|
||||
)
|
||||
db.add(critical_policy)
|
||||
db.commit()
|
||||
@@ -0,0 +1,155 @@
|
||||
"""Phase 14: API Key service — create, list, revoke, authenticate."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import EntityNotFoundError, DuplicateEntityError
|
||||
from app.models.api_key import ApiKey, generate_raw_key, hash_key, key_prefix_display
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
# ── Create ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_api_key(
|
||||
db: Session,
|
||||
user_id: UUID,
|
||||
name: str,
|
||||
scopes: List[str],
|
||||
description: Optional[str] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
) -> tuple[ApiKey, str]:
|
||||
"""
|
||||
Create a new API key.
|
||||
|
||||
Returns ``(ApiKey, raw_key)`` — the raw_key must be shown to the user
|
||||
immediately and is never retrievable again.
|
||||
"""
|
||||
raw_key = generate_raw_key()
|
||||
key_hash = hash_key(raw_key)
|
||||
prefix = key_prefix_display(raw_key)
|
||||
|
||||
# Detect accidental collision (astronomically unlikely)
|
||||
if db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first():
|
||||
raise DuplicateEntityError("ApiKey", "key_hash", "<collision>")
|
||||
|
||||
key = ApiKey(
|
||||
name = name,
|
||||
description = description,
|
||||
key_prefix = prefix,
|
||||
key_hash = key_hash,
|
||||
user_id = user_id,
|
||||
scopes = scopes,
|
||||
expires_at = expires_at,
|
||||
)
|
||||
db.add(key)
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
return key, raw_key
|
||||
|
||||
|
||||
# ── Read ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def list_api_keys(
|
||||
db: Session,
|
||||
user_id: Optional[UUID] = None,
|
||||
include_inactive: bool = False,
|
||||
) -> List[ApiKey]:
|
||||
q = db.query(ApiKey)
|
||||
if user_id is not None:
|
||||
q = q.filter(ApiKey.user_id == user_id)
|
||||
if not include_inactive:
|
||||
q = q.filter(ApiKey.is_active == True)
|
||||
return q.order_by(ApiKey.created_at.desc()).all()
|
||||
|
||||
|
||||
def get_api_key(db: Session, key_id: UUID, user_id: Optional[UUID] = None) -> ApiKey:
|
||||
q = db.query(ApiKey).filter(ApiKey.id == key_id)
|
||||
if user_id is not None:
|
||||
q = q.filter(ApiKey.user_id == user_id)
|
||||
key = q.first()
|
||||
if not key:
|
||||
raise EntityNotFoundError("ApiKey", str(key_id))
|
||||
return key
|
||||
|
||||
|
||||
# ── Update / Revoke ───────────────────────────────────────────────────────────
|
||||
|
||||
def update_api_key(
|
||||
db: Session,
|
||||
key_id: UUID,
|
||||
user_id: Optional[UUID] = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
scopes: Optional[List[str]] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
) -> ApiKey:
|
||||
key = get_api_key(db, key_id, user_id)
|
||||
if name is not None:
|
||||
key.name = name
|
||||
if description is not None:
|
||||
key.description = description
|
||||
if scopes is not None:
|
||||
key.scopes = scopes
|
||||
if expires_at is not None:
|
||||
key.expires_at = expires_at
|
||||
if is_active is not None:
|
||||
key.is_active = is_active
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
return key
|
||||
|
||||
|
||||
def revoke_api_key(
|
||||
db: Session,
|
||||
key_id: UUID,
|
||||
user_id: Optional[UUID] = None,
|
||||
) -> ApiKey:
|
||||
"""Soft-revoke: set is_active = False."""
|
||||
return update_api_key(db, key_id, user_id, is_active=False)
|
||||
|
||||
|
||||
def delete_api_key(db: Session, key_id: UUID, user_id: Optional[UUID] = None) -> None:
|
||||
"""Hard delete — use revoke instead for audit trail."""
|
||||
key = get_api_key(db, key_id, user_id)
|
||||
db.delete(key)
|
||||
db.commit()
|
||||
|
||||
|
||||
# ── Authentication ────────────────────────────────────────────────────────────
|
||||
|
||||
def authenticate_raw_key(db: Session, raw_key: str) -> Optional[User]:
|
||||
"""
|
||||
Verify a raw API key.
|
||||
|
||||
Returns the owning User if the key is valid, active, and not expired.
|
||||
Updates ``last_used_at`` (throttled to once per request — always updates).
|
||||
Returns None on any failure.
|
||||
"""
|
||||
h = hash_key(raw_key)
|
||||
key: Optional[ApiKey] = db.query(ApiKey).filter(ApiKey.key_hash == h).first()
|
||||
|
||||
if key is None or not key.is_active:
|
||||
return None
|
||||
if key.expires_at and key.expires_at < datetime.utcnow():
|
||||
return None
|
||||
|
||||
# Update last_used_at
|
||||
key.last_used_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
user: Optional[User] = db.query(User).filter(User.id == key.user_id).first()
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
|
||||
# Attach the key's scopes to the user instance so scope-enforcement
|
||||
# dependencies can verify them without an additional DB query.
|
||||
# _api_key_scopes=None means "full user access" (JWT path).
|
||||
user._api_key_scopes = key.scopes or []
|
||||
return user
|
||||
@@ -51,8 +51,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.models.technique import Technique
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
@@ -136,7 +135,7 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
f"exceeds limit of {_MAX_UNCOMPRESSED_SIZE / (1024 * 1024):.0f} MB"
|
||||
)
|
||||
|
||||
# Iterate over entries
|
||||
# Iterate over entries — validate and extract each member individually
|
||||
for member in entries:
|
||||
# Assign target = (dest_path / member.filename).resolve()
|
||||
target = (dest_path / member.filename).resolve()
|
||||
@@ -147,9 +146,7 @@ def _safe_extract_zip(zip_bytes: bytes, dest: str) -> None:
|
||||
f"Zip Slip detected — member '{member.filename}' "
|
||||
f"resolves outside target directory"
|
||||
)
|
||||
|
||||
# Call zf.extractall()
|
||||
zf.extractall(dest)
|
||||
zf.extract(member, dest)
|
||||
|
||||
|
||||
# Define function _extract_zip
|
||||
@@ -310,6 +307,7 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
# Iterate over parsed_tests
|
||||
for item in parsed_tests:
|
||||
@@ -347,10 +345,14 @@ def import_atomic_red_team(db: Session) -> dict:
|
||||
db.add(template)
|
||||
# Call existing_ids.add()
|
||||
existing_ids.add(item["atomic_test_id"])
|
||||
# Assign created = 1
|
||||
new_technique_ids.add(item["technique_id"])
|
||||
created += 1
|
||||
|
||||
# Commit all pending changes to the database
|
||||
if new_technique_ids:
|
||||
db.query(Technique).filter(
|
||||
Technique.mitre_id.in_(new_technique_ids)
|
||||
).update({"review_required": True}, synchronize_session=False)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Count distinct YAML files by technique_id
|
||||
|
||||
@@ -0,0 +1,553 @@
|
||||
"""Phase 10: Attack Path CRUD service."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.models.attack_path import (
|
||||
AttackPath, AttackPathStep, AttackPathExecution,
|
||||
AttackPathStepResult, TimelineEntry,
|
||||
ExecutionStatus, StepResultStatus, TimelineActorSide, TimelineEntryType,
|
||||
)
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.services import audit_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
# ── Attack Path CRUD ──────────────────────────────────────────────────────────
|
||||
|
||||
def create_attack_path(db: Session, data: dict, user_id: UUID) -> AttackPath:
|
||||
path = AttackPath(
|
||||
name=data["name"],
|
||||
description=data.get("description"),
|
||||
objective=data.get("objective"),
|
||||
is_template=data.get("is_template", False),
|
||||
threat_actor_id=data.get("threat_actor_id"),
|
||||
tags=data.get("tags") or [],
|
||||
created_by=user_id,
|
||||
)
|
||||
db.add(path)
|
||||
db.commit()
|
||||
db.refresh(path)
|
||||
audit_service.log_action(
|
||||
db, user_id, "ATTACK_PATH_CREATED", "attack_path", str(path.id),
|
||||
details={"name": path.name, "is_template": path.is_template},
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def get_attack_path(db: Session, path_id: UUID) -> AttackPath:
|
||||
path = (
|
||||
db.query(AttackPath)
|
||||
.options(joinedload(AttackPath.steps))
|
||||
.filter(AttackPath.id == path_id)
|
||||
.first()
|
||||
)
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
return path
|
||||
|
||||
|
||||
def list_attack_paths(
|
||||
db: Session,
|
||||
is_template: Optional[bool] = None,
|
||||
technique_id: Optional[UUID] = None,
|
||||
is_active: Optional[bool] = True,
|
||||
) -> list[AttackPath]:
|
||||
q = db.query(AttackPath)
|
||||
if is_active is not None:
|
||||
q = q.filter(AttackPath.is_active == is_active)
|
||||
if is_template is not None:
|
||||
q = q.filter(AttackPath.is_template == is_template)
|
||||
if technique_id:
|
||||
q = q.join(AttackPathStep).filter(AttackPathStep.technique_id == technique_id)
|
||||
return q.order_by(AttackPath.created_at.desc()).all()
|
||||
|
||||
|
||||
def update_attack_path(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPath:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
for k, v in data.items():
|
||||
if v is not None and hasattr(path, k):
|
||||
setattr(path, k, v)
|
||||
path.updated_at = _now()
|
||||
db.commit()
|
||||
db.refresh(path)
|
||||
return path
|
||||
|
||||
|
||||
def delete_attack_path(db: Session, path_id: UUID, user_id: UUID) -> None:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
path.is_active = False
|
||||
path.updated_at = _now()
|
||||
db.commit()
|
||||
audit_service.log_action(db, user_id, "ATTACK_PATH_ARCHIVED", "attack_path", str(path_id))
|
||||
|
||||
|
||||
# ── Steps CRUD ────────────────────────────────────────────────────────────────
|
||||
|
||||
def add_step(db: Session, path_id: UUID, data: dict, user_id: UUID) -> AttackPathStep:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
|
||||
# Auto-assign order_index if not provided
|
||||
if data.get("order_index") is None:
|
||||
max_idx = db.query(AttackPathStep).filter(
|
||||
AttackPathStep.attack_path_id == path_id
|
||||
).count()
|
||||
data["order_index"] = max_idx
|
||||
|
||||
step = AttackPathStep(
|
||||
attack_path_id=path_id,
|
||||
order_index=data.get("order_index", 0),
|
||||
kill_chain_phase=data.get("kill_chain_phase"),
|
||||
technique_id=data.get("technique_id"),
|
||||
test_id=data.get("test_id"),
|
||||
name=data.get("name"),
|
||||
description=data.get("description"),
|
||||
expected_detection=data.get("expected_detection", True),
|
||||
notes=data.get("notes"),
|
||||
)
|
||||
db.add(step)
|
||||
path.updated_at = _now()
|
||||
db.commit()
|
||||
db.refresh(step)
|
||||
return step
|
||||
|
||||
|
||||
def update_step(db: Session, step_id: UUID, data: dict, user_id: UUID) -> AttackPathStep:
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
|
||||
if not step:
|
||||
raise EntityNotFoundError("AttackPathStep", str(step_id))
|
||||
for k, v in data.items():
|
||||
if v is not None and hasattr(step, k):
|
||||
setattr(step, k, v)
|
||||
db.commit()
|
||||
db.refresh(step)
|
||||
return step
|
||||
|
||||
|
||||
def delete_step(db: Session, step_id: UUID, user_id: UUID) -> None:
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
|
||||
if not step:
|
||||
raise EntityNotFoundError("AttackPathStep", str(step_id))
|
||||
db.delete(step)
|
||||
db.commit()
|
||||
|
||||
|
||||
def reorder_steps(db: Session, path_id: UUID, step_ids: list[UUID], user_id: UUID) -> list[AttackPathStep]:
|
||||
"""Reorder steps by providing ordered list of step IDs."""
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
|
||||
for idx, step_id in enumerate(step_ids):
|
||||
db.query(AttackPathStep).filter(
|
||||
AttackPathStep.id == step_id,
|
||||
AttackPathStep.attack_path_id == path_id,
|
||||
).update({"order_index": idx})
|
||||
|
||||
path.updated_at = _now()
|
||||
db.commit()
|
||||
|
||||
return (
|
||||
db.query(AttackPathStep)
|
||||
.filter(AttackPathStep.attack_path_id == path_id)
|
||||
.order_by(AttackPathStep.order_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
# ── Executions ────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_execution(
|
||||
db: Session, path_id: UUID, data: dict, user_id: UUID
|
||||
) -> AttackPathExecution:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
|
||||
execution = AttackPathExecution(
|
||||
attack_path_id=path_id,
|
||||
status=ExecutionStatus.planned,
|
||||
environment=data.get("environment"),
|
||||
red_team_lead=data.get("red_team_lead"),
|
||||
blue_team_lead=data.get("blue_team_lead"),
|
||||
notes=data.get("notes"),
|
||||
started_by=user_id,
|
||||
)
|
||||
db.add(execution)
|
||||
db.flush()
|
||||
|
||||
# Pre-create pending step results for every step in the path
|
||||
steps = (
|
||||
db.query(AttackPathStep)
|
||||
.filter(AttackPathStep.attack_path_id == path_id)
|
||||
.order_by(AttackPathStep.order_index)
|
||||
.all()
|
||||
)
|
||||
for step in steps:
|
||||
result = AttackPathStepResult(
|
||||
execution_id=execution.id,
|
||||
step_id=step.id,
|
||||
step_order=step.order_index,
|
||||
status=StepResultStatus.pending,
|
||||
)
|
||||
db.add(result)
|
||||
|
||||
db.commit()
|
||||
db.refresh(execution)
|
||||
|
||||
# Auto-add system timeline entry
|
||||
_add_system_entry(
|
||||
db, execution.id,
|
||||
entry_type=TimelineEntryType.phase_transition,
|
||||
content=f"Execution created for '{path.name}' with {len(steps)} steps.",
|
||||
)
|
||||
|
||||
audit_service.log_action(
|
||||
db, user_id, "ATTACK_PATH_EXECUTION_STARTED", "attack_path_execution",
|
||||
str(execution.id),
|
||||
details={"path_id": str(path_id), "path_name": path.name, "steps": len(steps)},
|
||||
)
|
||||
return execution
|
||||
|
||||
|
||||
def get_execution(db: Session, execution_id: UUID) -> AttackPathExecution:
|
||||
ex = (
|
||||
db.query(AttackPathExecution)
|
||||
.options(
|
||||
joinedload(AttackPathExecution.step_results),
|
||||
joinedload(AttackPathExecution.timeline),
|
||||
)
|
||||
.filter(AttackPathExecution.id == execution_id)
|
||||
.first()
|
||||
)
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
return ex
|
||||
|
||||
|
||||
def list_executions(db: Session, path_id: UUID) -> list[AttackPathExecution]:
|
||||
path = db.query(AttackPath).filter(AttackPath.id == path_id).first()
|
||||
if not path:
|
||||
raise EntityNotFoundError("AttackPath", str(path_id))
|
||||
return (
|
||||
db.query(AttackPathExecution)
|
||||
.filter(AttackPathExecution.attack_path_id == path_id)
|
||||
.order_by(AttackPathExecution.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def start_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
if ex.status not in (ExecutionStatus.planned,):
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(400, "Execution is not in 'planned' state")
|
||||
ex.status = ExecutionStatus.in_progress
|
||||
ex.started_at = _now()
|
||||
db.commit()
|
||||
db.refresh(ex)
|
||||
_add_system_entry(db, execution_id, TimelineEntryType.phase_transition,
|
||||
"Execution started.", actor_id=user_id, actor_side=TimelineActorSide.system)
|
||||
return ex
|
||||
|
||||
|
||||
# ── Step Execution ────────────────────────────────────────────────────────────
|
||||
|
||||
def execute_step(
|
||||
db: Session,
|
||||
execution_id: UUID,
|
||||
step_id: UUID,
|
||||
data: dict,
|
||||
user_id: UUID,
|
||||
) -> AttackPathStepResult:
|
||||
"""Record the result of executing one step."""
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
if ex.status not in (ExecutionStatus.in_progress, ExecutionStatus.planned):
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(400, "Execution must be in_progress to record step results")
|
||||
|
||||
# Auto-start if still planned
|
||||
if ex.status == ExecutionStatus.planned:
|
||||
ex.status = ExecutionStatus.in_progress
|
||||
ex.started_at = _now()
|
||||
|
||||
result = (
|
||||
db.query(AttackPathStepResult)
|
||||
.filter(
|
||||
AttackPathStepResult.execution_id == execution_id,
|
||||
AttackPathStepResult.step_id == step_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not result:
|
||||
# Create on-the-fly if step was added after execution started
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
|
||||
if not step:
|
||||
raise EntityNotFoundError("AttackPathStep", str(step_id))
|
||||
result = AttackPathStepResult(
|
||||
execution_id=execution_id,
|
||||
step_id=step_id,
|
||||
step_order=step.order_index,
|
||||
)
|
||||
db.add(result)
|
||||
|
||||
now = _now()
|
||||
new_status = StepResultStatus(data["status"])
|
||||
result.status = new_status
|
||||
result.executed_by = user_id
|
||||
result.executed_at = data.get("executed_at") or now
|
||||
result.notes = data.get("notes")
|
||||
result.evidence_ids = [str(e) for e in (data.get("evidence_ids") or [])]
|
||||
result.detection_asset_id = data.get("detection_asset_id")
|
||||
|
||||
if new_status == StepResultStatus.detected:
|
||||
result.detected_at = data.get("detected_at") or now
|
||||
if result.executed_at:
|
||||
delta = (result.detected_at - result.executed_at).total_seconds()
|
||||
result.time_to_detect_seconds = max(0.0, delta)
|
||||
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
|
||||
# Add timeline entry
|
||||
step_obj = db.query(AttackPathStep).filter(AttackPathStep.id == step_id).first()
|
||||
step_name = step_obj.name or (step_obj.kill_chain_phase or "Unknown step")
|
||||
actor_side = TimelineActorSide.red if new_status != StepResultStatus.detected else TimelineActorSide.blue
|
||||
entry_type = (
|
||||
TimelineEntryType.detection if new_status == StepResultStatus.detected
|
||||
else TimelineEntryType.action
|
||||
)
|
||||
content = (
|
||||
f"Step '{step_name}' marked as {new_status.value}."
|
||||
+ (f" Detected in {result.time_to_detect_seconds:.0f}s." if result.time_to_detect_seconds else "")
|
||||
)
|
||||
_add_system_entry(
|
||||
db, execution_id, entry_type, content,
|
||||
actor_id=user_id, actor_side=actor_side, step_id=step_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ── Completion & Metrics ──────────────────────────────────────────────────────
|
||||
|
||||
def complete_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
|
||||
"""Mark execution complete and compute all kill-chain metrics."""
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
|
||||
results = (
|
||||
db.query(AttackPathStepResult)
|
||||
.filter(AttackPathStepResult.execution_id == execution_id)
|
||||
.order_by(AttackPathStepResult.step_order)
|
||||
.all()
|
||||
)
|
||||
|
||||
total = len(results)
|
||||
detected = sum(1 for r in results if r.status == StepResultStatus.detected)
|
||||
not_detected = sum(1 for r in results if r.status == StepResultStatus.not_detected)
|
||||
skipped = sum(1 for r in results if r.status == StepResultStatus.skipped)
|
||||
|
||||
detection_rate = (detected / total) if total > 0 else 0.0
|
||||
|
||||
ttds = [r.time_to_detect_seconds for r in results
|
||||
if r.time_to_detect_seconds is not None]
|
||||
mttd = (sum(ttds) / len(ttds)) if ttds else None
|
||||
|
||||
# Furthest undetected step (highest order_index with not_detected status)
|
||||
undetected = [r for r in results if r.status == StepResultStatus.not_detected]
|
||||
furthest = max((r.step_order for r in undetected), default=None)
|
||||
|
||||
ex.status = ExecutionStatus.completed
|
||||
ex.completed_at = _now()
|
||||
ex.total_steps = total
|
||||
ex.detected_steps = detected
|
||||
ex.not_detected_steps = not_detected
|
||||
ex.skipped_steps = skipped
|
||||
ex.detection_rate = round(detection_rate, 4)
|
||||
ex.mttd_seconds = round(mttd, 1) if mttd is not None else None
|
||||
ex.furthest_undetected_step = furthest
|
||||
|
||||
db.commit()
|
||||
db.refresh(ex)
|
||||
|
||||
_add_system_entry(
|
||||
db, execution_id, TimelineEntryType.phase_transition,
|
||||
f"Execution completed. Detection rate: {detection_rate:.0%}. "
|
||||
f"Detected {detected}/{total} steps. "
|
||||
+ (f"MTTD: {mttd:.0f}s." if mttd else ""),
|
||||
actor_id=user_id, actor_side=TimelineActorSide.system,
|
||||
)
|
||||
|
||||
audit_service.log_action(
|
||||
db, user_id, "ATTACK_PATH_EXECUTION_COMPLETED", "attack_path_execution",
|
||||
str(execution_id),
|
||||
details={"detection_rate": detection_rate, "mttd_seconds": mttd,
|
||||
"detected": detected, "total": total},
|
||||
)
|
||||
return ex
|
||||
|
||||
|
||||
def abort_execution(db: Session, execution_id: UUID, user_id: UUID) -> AttackPathExecution:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
ex.status = ExecutionStatus.aborted
|
||||
ex.completed_at = _now()
|
||||
db.commit()
|
||||
db.refresh(ex)
|
||||
_add_system_entry(db, execution_id, TimelineEntryType.flag, "Execution aborted.",
|
||||
actor_id=user_id, actor_side=TimelineActorSide.system)
|
||||
return ex
|
||||
|
||||
|
||||
# ── Timeline ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def add_timeline_entry(
|
||||
db: Session, execution_id: UUID, data: dict, user_id: UUID
|
||||
) -> TimelineEntry:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
|
||||
entry = TimelineEntry(
|
||||
execution_id=execution_id,
|
||||
step_id=data.get("step_id"),
|
||||
timestamp=data.get("timestamp") or _now(),
|
||||
actor_side=TimelineActorSide(data["actor_side"]),
|
||||
actor_id=user_id,
|
||||
entry_type=TimelineEntryType(data["entry_type"]),
|
||||
content=data["content"],
|
||||
extra=data.get("extra"),
|
||||
)
|
||||
db.add(entry)
|
||||
db.commit()
|
||||
db.refresh(entry)
|
||||
return entry
|
||||
|
||||
|
||||
def get_timeline(db: Session, execution_id: UUID) -> list[TimelineEntry]:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
return (
|
||||
db.query(TimelineEntry)
|
||||
.filter(TimelineEntry.execution_id == execution_id)
|
||||
.order_by(TimelineEntry.timestamp.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
# ── Kill-Chain Metrics ────────────────────────────────────────────────────────
|
||||
|
||||
def get_kill_chain_metrics(db: Session, execution_id: UUID) -> dict:
|
||||
ex = db.query(AttackPathExecution).filter(AttackPathExecution.id == execution_id).first()
|
||||
if not ex:
|
||||
raise EntityNotFoundError("AttackPathExecution", str(execution_id))
|
||||
|
||||
results = (
|
||||
db.query(AttackPathStepResult)
|
||||
.filter(AttackPathStepResult.execution_id == execution_id)
|
||||
.order_by(AttackPathStepResult.step_order)
|
||||
.all()
|
||||
)
|
||||
|
||||
step_breakdown = []
|
||||
phase_detected: dict[str, list] = {}
|
||||
|
||||
for r in results:
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == r.step_id).first()
|
||||
phase = step.kill_chain_phase if step else None
|
||||
entry = {
|
||||
"step_id": str(r.step_id),
|
||||
"step_order": r.step_order,
|
||||
"step_name": step.name if step else None,
|
||||
"kill_chain_phase": phase,
|
||||
"status": r.status.value if hasattr(r.status, "value") else r.status,
|
||||
"executed_at": r.executed_at.isoformat() if r.executed_at else None,
|
||||
"detected_at": r.detected_at.isoformat() if r.detected_at else None,
|
||||
"time_to_detect_seconds": r.time_to_detect_seconds,
|
||||
"detection_asset_id": str(r.detection_asset_id) if r.detection_asset_id else None,
|
||||
}
|
||||
step_breakdown.append(entry)
|
||||
if phase:
|
||||
phase_detected.setdefault(phase, []).append(
|
||||
r.status == StepResultStatus.detected
|
||||
)
|
||||
|
||||
phase_summary = {
|
||||
phase: {
|
||||
"total": len(v),
|
||||
"detected": sum(v),
|
||||
"detection_rate": round(sum(v) / len(v), 3) if v else 0.0,
|
||||
}
|
||||
for phase, v in phase_detected.items()
|
||||
}
|
||||
|
||||
# Furthest undetected phase
|
||||
furthest_undetected_phase = None
|
||||
if ex.furthest_undetected_step is not None:
|
||||
for r in reversed(results):
|
||||
if r.step_order == ex.furthest_undetected_step:
|
||||
step = db.query(AttackPathStep).filter(AttackPathStep.id == r.step_id).first()
|
||||
if step:
|
||||
furthest_undetected_phase = step.kill_chain_phase
|
||||
break
|
||||
|
||||
return {
|
||||
"execution_id": str(execution_id),
|
||||
"total_steps": ex.total_steps or len(results),
|
||||
"detected_steps": ex.detected_steps or 0,
|
||||
"not_detected_steps": ex.not_detected_steps or 0,
|
||||
"skipped_steps": ex.skipped_steps or 0,
|
||||
"detection_rate": ex.detection_rate or 0.0,
|
||||
"mttd_seconds": ex.mttd_seconds,
|
||||
"furthest_undetected_step": ex.furthest_undetected_step,
|
||||
"furthest_undetected_phase": furthest_undetected_phase,
|
||||
"step_breakdown": step_breakdown,
|
||||
"phase_summary": phase_summary,
|
||||
}
|
||||
|
||||
|
||||
# ── Helper ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _add_system_entry(
|
||||
db: Session,
|
||||
execution_id: UUID,
|
||||
entry_type: TimelineEntryType,
|
||||
content: str,
|
||||
actor_id: Optional[UUID] = None,
|
||||
actor_side: TimelineActorSide = TimelineActorSide.system,
|
||||
step_id: Optional[UUID] = None,
|
||||
) -> None:
|
||||
entry = TimelineEntry(
|
||||
execution_id=execution_id,
|
||||
step_id=step_id,
|
||||
timestamp=_now(),
|
||||
actor_side=actor_side,
|
||||
actor_id=actor_id,
|
||||
entry_type=entry_type,
|
||||
content=content,
|
||||
)
|
||||
db.add(entry)
|
||||
db.commit()
|
||||
@@ -0,0 +1,798 @@
|
||||
"""ATT&CK Evaluations importer — fetches real CrowdStrike detection results
|
||||
from MITRE Engenuity's public API and seeds the platform with validated tests.
|
||||
|
||||
Data source
|
||||
-----------
|
||||
https://evals.mitre.org/api/
|
||||
- /participants/ → list of vendors + rounds they completed
|
||||
- /results/?participant=crowdstrike&domain=ENTERPRISE
|
||||
→ per-substep detection results per adversary
|
||||
|
||||
Detection level mapping (MITRE → Aegis)
|
||||
---------------------------------------
|
||||
Technique / Specific Behavior → detected (correctly identified ATT&CK technique)
|
||||
Tactic → partially_detected (behavior noted but not categorized)
|
||||
General / IOC / MSSP → partially_detected (anomaly detected, not ATT&CK-mapped)
|
||||
Telemetry → partially_detected (raw data only — marginal detection)
|
||||
None / N/A → not_detected
|
||||
|
||||
All imported tests are created in ``in_review`` state so Blue Leads must
|
||||
confirm each result before it counts as real coverage for the organisation.
|
||||
|
||||
Important caveats stored in every test's description
|
||||
------------------------------------------------------
|
||||
"Source: MITRE ATT&CK Evaluation (Round N — Adversary). Results reflect
|
||||
CrowdStrike Falcon in a controlled lab environment, NOT this organisation's
|
||||
deployment. Validate detection in your own environment before approving."
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.enums import TestState, TestResult
|
||||
from app.models.evaluation_import import EvaluationImport
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BASE = "https://evals.mitre.org"
|
||||
_TIMEOUT = 30 # seconds per HTTP call
|
||||
_VENDOR = "crowdstrike"
|
||||
_DOMAIN = "ENTERPRISE"
|
||||
|
||||
# Browser-like headers to bypass Cloudflare bot protection on evals.mitre.org
|
||||
_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/124.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
"Referer": "https://evals.mitre.org/",
|
||||
"Origin": "https://evals.mitre.org",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback: hardcoded public CrowdStrike ENTERPRISE rounds
|
||||
# Used when evals.mitre.org API is unreachable (Cloudflare 502, outage, etc.)
|
||||
#
|
||||
# Names use the EXACT slugs the live API returns (hyphens, not underscores).
|
||||
# Verified from live API response on 2025-06-05.
|
||||
# CrowdStrike did NOT participate in Round 6 (OilRig) — not included.
|
||||
# ---------------------------------------------------------------------------
|
||||
_FALLBACK_ROUNDS: list[dict[str, Any]] = [
|
||||
{
|
||||
"name": "apt3",
|
||||
"display_name": "APT3",
|
||||
"eval_round": 1,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "apt29",
|
||||
"display_name": "APT29",
|
||||
"eval_round": 2,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "carbanak-fin7",
|
||||
"display_name": "Carbanak+FIN7",
|
||||
"eval_round": 3,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "wizard-spider-sandworm",
|
||||
"display_name": "Wizard Spider + Sandworm",
|
||||
"eval_round": 4,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "turla",
|
||||
"display_name": "Turla",
|
||||
"eval_round": 5,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "er7",
|
||||
"display_name": "Enterprise 2025",
|
||||
"eval_round": 7,
|
||||
"domain": "ENTERPRISE",
|
||||
"status": "PUBLIC",
|
||||
},
|
||||
]
|
||||
|
||||
# Detection type → quality score (higher = better)
|
||||
_DETECTION_SCORE: dict[str, int] = {
|
||||
"none": 0,
|
||||
"n/a": 0,
|
||||
"telemetry": 1,
|
||||
"mssp": 2,
|
||||
"general": 2,
|
||||
"ioc": 2,
|
||||
"tactic": 3,
|
||||
"technique": 4,
|
||||
"specific behavior": 4,
|
||||
}
|
||||
|
||||
|
||||
def _score(detection_type: str) -> int:
|
||||
key = (detection_type or "").lower().strip()
|
||||
for pattern, score in _DETECTION_SCORE.items():
|
||||
if pattern in key:
|
||||
return score
|
||||
return 0
|
||||
|
||||
|
||||
def _score_to_result(score: int) -> TestResult:
|
||||
if score >= 4:
|
||||
return TestResult.detected
|
||||
if score >= 1:
|
||||
return TestResult.partially_detected
|
||||
return TestResult.not_detected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def fetch_rounds_with_status() -> dict[str, Any]:
|
||||
"""Fetch CrowdStrike ENTERPRISE rounds and report whether the live API was reachable.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"rounds": [{"name": ..., "display_name": ..., "eval_round": ...}, ...],
|
||||
"api_reachable": True | False,
|
||||
"api_error": None | "<error message>",
|
||||
}
|
||||
"""
|
||||
try:
|
||||
session = requests.Session()
|
||||
session.headers.update(_HEADERS)
|
||||
resp = session.get(f"{_BASE}/api/participants/", timeout=_TIMEOUT)
|
||||
resp.raise_for_status()
|
||||
participants = resp.json()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"evals.mitre.org API unreachable (%s) — using hardcoded fallback round list.",
|
||||
exc,
|
||||
)
|
||||
return {
|
||||
"rounds": list(_FALLBACK_ROUNDS),
|
||||
"api_reachable": False,
|
||||
"api_error": str(exc),
|
||||
}
|
||||
|
||||
crowdstrike = next(
|
||||
(p for p in participants if p.get("name", "").lower() == _VENDOR),
|
||||
None,
|
||||
)
|
||||
if not crowdstrike:
|
||||
logger.warning("Vendor '%s' not found in live data — using fallback.", _VENDOR)
|
||||
return {
|
||||
"rounds": list(_FALLBACK_ROUNDS),
|
||||
"api_reachable": True, # API was reachable, vendor just wasn't listed
|
||||
"api_error": f"Vendor '{_VENDOR}' not found in participants list",
|
||||
}
|
||||
|
||||
rounds = [
|
||||
adv
|
||||
for adv in crowdstrike.get("adversaries_completed", [])
|
||||
if adv.get("domain", "").upper() == _DOMAIN
|
||||
and adv.get("status", "").upper() == "PUBLIC"
|
||||
]
|
||||
rounds.sort(key=lambda x: x.get("eval_round", 0))
|
||||
return {
|
||||
"rounds": rounds if rounds else list(_FALLBACK_ROUNDS),
|
||||
"api_reachable": True,
|
||||
"api_error": None,
|
||||
}
|
||||
|
||||
|
||||
def fetch_available_rounds() -> list[dict[str, Any]]:
|
||||
"""Return all evaluation rounds CrowdStrike has completed (ENTERPRISE only).
|
||||
|
||||
Each dict has: name, display_name, eval_round.
|
||||
Sorted by eval_round ascending.
|
||||
|
||||
Falls back to ``_FALLBACK_ROUNDS`` if the live API is unreachable.
|
||||
"""
|
||||
return fetch_rounds_with_status()["rounds"]
|
||||
|
||||
|
||||
def get_latest_round() -> dict[str, Any]:
|
||||
"""Return the most recent PUBLIC ENTERPRISE round CrowdStrike participated in."""
|
||||
rounds = fetch_available_rounds()
|
||||
if not rounds:
|
||||
raise ValueError("No public Enterprise evaluation rounds found for CrowdStrike")
|
||||
return rounds[-1]
|
||||
|
||||
|
||||
def fetch_results_for_adversary(adversary_name: str) -> list[dict[str, Any]]:
|
||||
"""Fetch all per-substep detection results for a specific adversary round.
|
||||
|
||||
Returns a flat list of substep dicts, each containing:
|
||||
technique_id, technique_name, tactic_id, best_score, detection_type, note.
|
||||
"""
|
||||
url = f"{_BASE}/api/results/?participant={_VENDOR}&domain={_DOMAIN}"
|
||||
try:
|
||||
session = requests.Session()
|
||||
session.headers.update(_HEADERS)
|
||||
resp = session.get(url, timeout=_TIMEOUT)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except Exception as exc:
|
||||
logger.error("Failed to fetch ATT&CK Evaluations results: %s", exc)
|
||||
raise
|
||||
|
||||
# The results endpoint returns a LIST of vendor objects:
|
||||
# [{"name": "crowdstrike", "adversaries": [{"Adversary_Name": "apt3", ...}, ...]}, ...]
|
||||
# (not a dict — hence the explicit vendor lookup below)
|
||||
if isinstance(data, list):
|
||||
vendor_entry = next(
|
||||
(v for v in data if isinstance(v, dict) and v.get("name", "").lower() == _VENDOR),
|
||||
None,
|
||||
)
|
||||
if not vendor_entry:
|
||||
raise ValueError(
|
||||
f"Vendor '{_VENDOR}' not found in results response. "
|
||||
f"Available vendors: {[v.get('name') for v in data if isinstance(v, dict)]}"
|
||||
)
|
||||
adversaries = vendor_entry.get("adversaries", [])
|
||||
else:
|
||||
# Fallback for legacy dict-shaped response (just in case API changes again)
|
||||
adversaries = data.get("adversaries", [])
|
||||
|
||||
target = next(
|
||||
(a for a in adversaries if a.get("Adversary_Name", "").lower() == adversary_name.lower()),
|
||||
None,
|
||||
)
|
||||
if not target:
|
||||
raise ValueError(
|
||||
f"Adversary '{adversary_name}' not found in results. "
|
||||
f"Available: {[a.get('Adversary_Name') for a in adversaries]}"
|
||||
)
|
||||
|
||||
substeps: list[dict[str, Any]] = []
|
||||
|
||||
scenarios = target.get("Detections_By_Step", {})
|
||||
for scenario_name, scenario_data in scenarios.items():
|
||||
for step in scenario_data.get("Steps", []):
|
||||
step_num = step.get("Step_Num", "")
|
||||
step_name = step.get("Step_Name", "")
|
||||
# Strip HTML tags from the Step.Description narrative
|
||||
step_desc_raw = step.get("Description") or ""
|
||||
step_description = re.sub(r"<[^>]+>", " ", step_desc_raw)
|
||||
step_description = re.sub(r"\s+", " ", step_description).strip()
|
||||
|
||||
for substep in step.get("Substeps", []):
|
||||
# Prefer sub-technique over technique
|
||||
sub = substep.get("Subtechnique") or {}
|
||||
tech = substep.get("Technique") or {}
|
||||
tactic = substep.get("Tactic") or {}
|
||||
|
||||
technique_id = (
|
||||
sub.get("Subtechnique_Id")
|
||||
or tech.get("Technique_Id")
|
||||
or ""
|
||||
).strip()
|
||||
technique_name = (
|
||||
sub.get("Subtechnique_Name")
|
||||
or tech.get("Technique_Name")
|
||||
or "Unknown"
|
||||
).strip()
|
||||
|
||||
if not technique_id:
|
||||
continue
|
||||
|
||||
detections = substep.get("Detections", [])
|
||||
best_score = 0
|
||||
best_type = "None"
|
||||
best_note = ""
|
||||
for det in detections:
|
||||
dtype = det.get("Detection_Type", "None")
|
||||
s = _score(dtype)
|
||||
if s > best_score:
|
||||
best_score = s
|
||||
best_type = dtype
|
||||
best_note = det.get("Detection_Note", "")
|
||||
|
||||
# Collect all unique data sources from screenshots across all detections
|
||||
data_sources: list[str] = sorted({
|
||||
src
|
||||
for det in detections
|
||||
for sc in det.get("Screenshots", [])
|
||||
for src in sc.get("Data_Sources", [])
|
||||
})
|
||||
|
||||
substeps.append(
|
||||
{
|
||||
"technique_id": technique_id,
|
||||
"technique_name": technique_name,
|
||||
"tactic_id": tactic.get("Tactic_Id", ""),
|
||||
"tactic_name": tactic.get("Tactic_Name", ""),
|
||||
"best_score": best_score,
|
||||
"detection_type": best_type,
|
||||
"note": best_note,
|
||||
# Enrichment fields from the API
|
||||
"scenario_name": scenario_name,
|
||||
"step_num": step_num,
|
||||
"step_name": step_name,
|
||||
"step_description": step_description,
|
||||
"substep_ref": substep.get("Substep", ""),
|
||||
"criteria": (substep.get("Criteria") or "").strip(),
|
||||
"data_sources": data_sources,
|
||||
}
|
||||
)
|
||||
|
||||
return substeps
|
||||
|
||||
|
||||
def _aggregate_by_technique(substeps: list[dict]) -> dict[str, dict]:
|
||||
"""Aggregate substep results per technique.
|
||||
|
||||
- Deduplicates substeps by (substep_ref, criteria) — prevents duplicates
|
||||
that arise when adversaries with multiple scenarios (e.g. Wizard Spider +
|
||||
Sandworm) repeat the same substep across a "combined" replay scenario.
|
||||
- Groups unique occurrences by scenario_name so the narrative can show
|
||||
"Wizard Spider scenario" vs "Sandworm scenario" separately.
|
||||
- Tracks best detection score across all unique substeps.
|
||||
"""
|
||||
by_technique: dict[str, dict] = {}
|
||||
|
||||
for sub in substeps:
|
||||
tid = sub["technique_id"]
|
||||
if tid not in by_technique:
|
||||
by_technique[tid] = {
|
||||
**sub,
|
||||
"occurrences": [], # flat list of unique occurrences
|
||||
"_seen_keys": set(), # (substep_ref, criteria) dedup set
|
||||
}
|
||||
|
||||
agg = by_technique[tid]
|
||||
|
||||
# Deduplication key: same substep_ref + same criteria text = duplicate
|
||||
dedup_key = (sub["substep_ref"], sub["criteria"])
|
||||
if dedup_key in agg["_seen_keys"]:
|
||||
continue
|
||||
agg["_seen_keys"].add(dedup_key)
|
||||
|
||||
agg["occurrences"].append({
|
||||
"scenario_name": sub["scenario_name"],
|
||||
"substep_ref": sub["substep_ref"],
|
||||
"step_num": sub["step_num"],
|
||||
"step_name": sub["step_name"],
|
||||
"step_description": sub["step_description"],
|
||||
"criteria": sub["criteria"],
|
||||
"data_sources": sub["data_sources"],
|
||||
"detection_type": sub["detection_type"],
|
||||
"best_score": sub["best_score"],
|
||||
"note": sub["note"],
|
||||
})
|
||||
|
||||
# Promote best detection score
|
||||
if sub["best_score"] > agg["best_score"]:
|
||||
agg["best_score"] = sub["best_score"]
|
||||
agg["detection_type"] = sub["detection_type"]
|
||||
agg["note"] = sub["note"]
|
||||
agg["tactic_id"] = sub["tactic_id"]
|
||||
agg["tactic_name"] = sub["tactic_name"]
|
||||
|
||||
# Clean up internal dedup sets before returning
|
||||
for agg in by_technique.values():
|
||||
agg.pop("_seen_keys", None)
|
||||
|
||||
return by_technique
|
||||
|
||||
|
||||
def _group_occurrences_by_scenario(occurrences: list[dict]) -> dict[str, list[dict]]:
|
||||
"""Group a technique's occurrences by scenario, preserving insertion order."""
|
||||
grouped: dict[str, list[dict]] = {}
|
||||
for occ in occurrences:
|
||||
sc = occ.get("scenario_name", "Scenario_1")
|
||||
grouped.setdefault(sc, []).append(occ)
|
||||
return grouped
|
||||
|
||||
|
||||
def _build_procedure_text(agg: dict, adversary_display: str, eval_round: int) -> str:
|
||||
"""Build a rich attack-path narrative for the Test.procedure_text field.
|
||||
|
||||
Groups substeps by scenario so adversaries with multiple threat groups
|
||||
(e.g. Wizard Spider + Sandworm with 3 scenarios) are clearly separated.
|
||||
Includes Step.Description narrative for context.
|
||||
"""
|
||||
occurrences = agg.get("occurrences", [])
|
||||
if not occurrences:
|
||||
return (
|
||||
f"MITRE ATT&CK Evaluation simulation using {adversary_display} TTPs. "
|
||||
f"See evaluation report at https://evals.mitre.org for full details."
|
||||
)
|
||||
|
||||
lines: list[str] = [f"ATT&CK Evaluation R{eval_round} — {adversary_display}", ""]
|
||||
|
||||
grouped = _group_occurrences_by_scenario(occurrences)
|
||||
scenario_count = len(grouped)
|
||||
|
||||
for sc_name, sc_occs in grouped.items():
|
||||
# Scenario header — only shown when there are multiple scenarios
|
||||
if scenario_count > 1:
|
||||
idx = sc_name.replace("Scenario_", "Scenario ")
|
||||
lines.append(f"=== {idx} ===")
|
||||
|
||||
# Within each scenario, group by step to emit description once per step
|
||||
seen_step_descs: set = set()
|
||||
for occ in sc_occs:
|
||||
step_num = occ.get("step_num", "")
|
||||
step_name = occ.get("step_name", "")
|
||||
step_desc = occ.get("step_description", "")
|
||||
# Use (step_num or step_name) as dedup key for descriptions
|
||||
step_key = str(step_num) if step_num else step_name
|
||||
|
||||
if step_key and step_key not in seen_step_descs:
|
||||
seen_step_descs.add(step_key)
|
||||
header = f"Step {step_num} — {step_name}:" if step_num else f"— {step_name}:"
|
||||
lines.append(header)
|
||||
if step_desc:
|
||||
truncated = step_desc[:450] + ("…" if len(step_desc) > 450 else "")
|
||||
lines.append(truncated)
|
||||
|
||||
ref = occ.get("substep_ref", "")
|
||||
criteria = occ.get("criteria", "")
|
||||
det = occ.get("detection_type", "")
|
||||
if criteria:
|
||||
tag = f" [{ref}]" if ref else " •"
|
||||
det_tag = f" [{det}]" if det and det.lower() not in ("none", "") else ""
|
||||
lines.append(f"{tag}{det_tag} {criteria}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines).rstrip()
|
||||
|
||||
|
||||
def _build_description(agg: dict, adversary_display: str, eval_round: int) -> str:
|
||||
"""Build Test.description with source metadata, detection guidance and warning.
|
||||
|
||||
The 'criteria' field from the MITRE API describes what each substep does AND
|
||||
what should be detected, so it doubles as blue-team detection guidance.
|
||||
"""
|
||||
occurrences = agg.get("occurrences", [])
|
||||
|
||||
# Collect all unique data sources across every unique occurrence
|
||||
all_data_sources: list[str] = sorted({
|
||||
src
|
||||
for occ in occurrences
|
||||
for src in occ.get("data_sources", [])
|
||||
})
|
||||
|
||||
header = (
|
||||
f"Source: MITRE ATT&CK Evaluation — Round {eval_round} ({adversary_display}).\n"
|
||||
f"Vendor: CrowdStrike Falcon.\n"
|
||||
f"Detection type achieved: {agg['detection_type']}."
|
||||
)
|
||||
|
||||
ds_section = ""
|
||||
if all_data_sources:
|
||||
ds_section = "\n\nData sources observed:\n" + "\n".join(
|
||||
f" • {ds}" for ds in all_data_sources
|
||||
)
|
||||
|
||||
# Detection guidance — what criteria were observed (blue team can use these as IOCs)
|
||||
det_lines: list[str] = []
|
||||
grouped = _group_occurrences_by_scenario(occurrences)
|
||||
for sc_name, sc_occs in grouped.items():
|
||||
scenario_label = f"[{sc_name}] " if len(grouped) > 1 else ""
|
||||
for occ in sc_occs:
|
||||
ref = occ.get("substep_ref", "")
|
||||
step_name = occ.get("step_name", "")
|
||||
criteria = occ.get("criteria", "")
|
||||
det_type = occ.get("detection_type", "")
|
||||
if criteria:
|
||||
label = f"[{ref}]" if ref else "•"
|
||||
step_label = f" ({step_name})" if step_name else ""
|
||||
det_label = f" — {det_type}" if det_type and det_type.lower() not in ("none", "") else ""
|
||||
det_lines.append(f" {scenario_label}{label}{step_label}{det_label}: {criteria}")
|
||||
|
||||
det_section = ""
|
||||
if det_lines:
|
||||
det_section = "\n\nDetection criteria (what to look for):\n" + "\n".join(det_lines)
|
||||
|
||||
warning = (
|
||||
f"\n\n⚠️ IMPORTANT: These results reflect CrowdStrike Falcon performance in a "
|
||||
f"controlled MITRE lab environment against a simulated {adversary_display} "
|
||||
f"adversary. They do NOT represent your organisation's actual detection "
|
||||
f"capability. Validate in your own environment before approving."
|
||||
)
|
||||
|
||||
note_section = f"\n\nMITRE note: {agg['note']}" if agg.get("note") else ""
|
||||
|
||||
return header + ds_section + det_section + warning + note_section
|
||||
|
||||
|
||||
def _build_red_summary(agg: dict, adversary_display: str, eval_round: int) -> str:
|
||||
"""Build the Red Team summary for the Test.red_summary field."""
|
||||
occurrences = agg.get("occurrences", [])
|
||||
|
||||
lines = [
|
||||
f"MITRE ATT&CK Evaluation — Round {eval_round} ({adversary_display})",
|
||||
"Vendor: CrowdStrike Falcon",
|
||||
f"Best detection level: {agg['detection_type']}",
|
||||
f"Tactic: {agg['tactic_name']} ({agg['tactic_id']})",
|
||||
f"Unique substeps: {len(occurrences)}",
|
||||
]
|
||||
|
||||
if occurrences:
|
||||
lines.append("")
|
||||
grouped = _group_occurrences_by_scenario(occurrences)
|
||||
for sc_name, sc_occs in grouped.items():
|
||||
if len(grouped) > 1:
|
||||
lines.append(f"{sc_name}:")
|
||||
for occ in sc_occs:
|
||||
ref = occ.get("substep_ref", "")
|
||||
criteria = occ.get("criteria", "")
|
||||
step_name = occ.get("step_name", "")
|
||||
det = occ.get("detection_type", "")
|
||||
if criteria:
|
||||
tag = f" [{ref}]" if ref else " •"
|
||||
step_tag = f" {step_name} —" if step_name else ""
|
||||
det_tag = f" [{det}]" if det and det.lower() not in ("none", "") else ""
|
||||
lines.append(f"{tag}{step_tag}{det_tag} {criteria}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main import function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def import_evaluation_round(
|
||||
db: Session,
|
||||
adversary_name: str,
|
||||
adversary_display: str,
|
||||
eval_round: int,
|
||||
current_user: User,
|
||||
) -> dict[str, Any]:
|
||||
"""Import a single ATT&CK Evaluation round for CrowdStrike into the platform.
|
||||
|
||||
Creates one Test per unique technique with the best detection result
|
||||
observed across all substeps for that technique. All tests land in
|
||||
``in_review`` state — Blue Leads must confirm before they count as coverage.
|
||||
|
||||
Returns a summary dict: created, skipped, techniques_covered.
|
||||
Raises if the round was already imported (idempotency guard).
|
||||
"""
|
||||
# Idempotency — refuse duplicate imports
|
||||
existing = (
|
||||
db.query(EvaluationImport)
|
||||
.filter(
|
||||
EvaluationImport.adversary_name == adversary_name.lower(),
|
||||
EvaluationImport.status == "completed",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(
|
||||
f"Round '{adversary_display}' (round {eval_round}) was already imported "
|
||||
f"on {existing.imported_at.date()}. Re-import is not allowed."
|
||||
)
|
||||
|
||||
# Fetch and aggregate substep results
|
||||
substeps = fetch_results_for_adversary(adversary_name)
|
||||
by_technique = _aggregate_by_technique(substeps)
|
||||
|
||||
created = 0
|
||||
skipped = 0
|
||||
affected_technique_ids: set = set()
|
||||
|
||||
for mitre_id, agg in by_technique.items():
|
||||
# Look up the technique in our DB
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == mitre_id.upper())
|
||||
.first()
|
||||
)
|
||||
if technique is None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
detection_result = _score_to_result(agg["best_score"])
|
||||
|
||||
description = _build_description(agg, adversary_display, eval_round)
|
||||
red_summary = _build_red_summary(agg, adversary_display, eval_round)
|
||||
procedure_text = _build_procedure_text(agg, adversary_display, eval_round)
|
||||
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name=f"[EVAL R{eval_round}] {adversary_display} — {technique.name}",
|
||||
description=description,
|
||||
platform=None,
|
||||
procedure_text=procedure_text,
|
||||
created_by=current_user.id,
|
||||
state=TestState.in_review,
|
||||
attack_success=True,
|
||||
red_summary=red_summary,
|
||||
red_validation_status="approved",
|
||||
red_validated_by=current_user.id,
|
||||
red_validated_at=datetime.utcnow(),
|
||||
detection_result=detection_result,
|
||||
blue_validation_status=None,
|
||||
execution_date=datetime.utcnow(),
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="eval_import_test",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={
|
||||
"adversary": adversary_name,
|
||||
"eval_round": eval_round,
|
||||
"mitre_id": mitre_id,
|
||||
"detection_type": agg["detection_type"],
|
||||
},
|
||||
)
|
||||
|
||||
affected_technique_ids.add(technique.id)
|
||||
created += 1
|
||||
|
||||
# Recalculate coverage for all touched techniques
|
||||
for tech_id in affected_technique_ids:
|
||||
tech = db.query(Technique).filter(Technique.id == tech_id).first()
|
||||
if tech:
|
||||
recalculate_technique_status(db, tech)
|
||||
|
||||
# Record the import
|
||||
record = EvaluationImport(
|
||||
id=uuid.uuid4(),
|
||||
adversary_name=adversary_name.lower(),
|
||||
adversary_display=adversary_display,
|
||||
eval_round=eval_round,
|
||||
imported_at=datetime.utcnow(),
|
||||
imported_by=current_user.id,
|
||||
tests_created=created,
|
||||
techniques_covered=len(affected_technique_ids),
|
||||
status="completed",
|
||||
notes=f"Skipped {skipped} techniques not found in local DB.",
|
||||
)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
"ATT&CK Evaluation import complete — round %d (%s): %d tests created, %d skipped",
|
||||
eval_round, adversary_display, created, skipped,
|
||||
)
|
||||
return {
|
||||
"created": created,
|
||||
"skipped": skipped,
|
||||
"techniques_covered": len(affected_technique_ids),
|
||||
"adversary": adversary_display,
|
||||
"eval_round": eval_round,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# New-round check (used by the weekly scheduler)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def check_for_new_round(db: Session) -> dict[str, Any]:
|
||||
"""Check if a new evaluation round is available that hasn't been imported yet.
|
||||
|
||||
Returns:
|
||||
{"new_round_available": bool, "latest_round": dict | None, "already_imported": bool}
|
||||
"""
|
||||
try:
|
||||
latest = get_latest_round()
|
||||
except Exception as exc:
|
||||
logger.warning("Could not check for new ATT&CK Evaluation round: %s", exc)
|
||||
return {"new_round_available": False, "latest_round": None, "error": str(exc)}
|
||||
|
||||
already = (
|
||||
db.query(EvaluationImport)
|
||||
.filter(
|
||||
EvaluationImport.adversary_name == latest["name"].lower(),
|
||||
EvaluationImport.status == "completed",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return {
|
||||
"new_round_available": already is None,
|
||||
"already_imported": already is not None,
|
||||
"latest_round": {
|
||||
"name": latest["name"],
|
||||
"display_name": latest.get("display_name", latest["name"]),
|
||||
"eval_round": latest["eval_round"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Re-enrich existing tests with richer API data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def re_enrich_evaluation_round(
|
||||
db: Session,
|
||||
adversary_name: str,
|
||||
adversary_display: str,
|
||||
eval_round: int,
|
||||
current_user: User,
|
||||
) -> dict[str, Any]:
|
||||
"""Update procedure_text / description / red_summary on already-imported tests
|
||||
for a given round using the enriched API data (attack path, criteria, data sources).
|
||||
|
||||
This is non-destructive — it only updates the three narrative fields and does
|
||||
not change detection results, state, or validation status.
|
||||
"""
|
||||
# Fetch & aggregate (same flow as import)
|
||||
substeps = fetch_results_for_adversary(adversary_name)
|
||||
by_technique = _aggregate_by_technique(substeps)
|
||||
|
||||
updated = 0
|
||||
skipped = 0
|
||||
|
||||
for mitre_id, agg in by_technique.items():
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == mitre_id.upper())
|
||||
.first()
|
||||
)
|
||||
if technique is None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Find the existing test for this round + technique
|
||||
existing_test = (
|
||||
db.query(Test)
|
||||
.filter(
|
||||
Test.technique_id == technique.id,
|
||||
Test.name.like(f"[EVAL R{eval_round}]%"),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not existing_test:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
existing_test.description = _build_description(agg, adversary_display, eval_round)
|
||||
existing_test.red_summary = _build_red_summary(agg, adversary_display, eval_round)
|
||||
existing_test.procedure_text = _build_procedure_text(agg, adversary_display, eval_round)
|
||||
updated += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
"Re-enrichment complete — round %d (%s): %d tests updated, %d skipped",
|
||||
eval_round, adversary_display, updated, skipped,
|
||||
)
|
||||
return {
|
||||
"updated": updated,
|
||||
"skipped": skipped,
|
||||
"adversary": adversary_display,
|
||||
"eval_round": eval_round,
|
||||
"message": (
|
||||
f"Re-enriched {updated} tests for {adversary_display} (Round {eval_round}) "
|
||||
f"with attack path, criteria and data sources from MITRE API."
|
||||
),
|
||||
}
|
||||
@@ -104,11 +104,16 @@ def log_action(
|
||||
user_agent=ua or None,
|
||||
# Keyword argument: session_id
|
||||
session_id=session_id,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(entry)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
# Reload from DB so the timestamp is in DB-stable format before hashing.
|
||||
# Without this, a round-trip through the DB (e.g. refresh after commit) can
|
||||
# return a timestamp with different precision/timezone, causing hash mismatch.
|
||||
db.refresh(entry)
|
||||
# Assign entry.integrity_hash = compute_integrity_hash(entry)
|
||||
entry.integrity_hash = compute_integrity_hash(entry)
|
||||
# Return entry
|
||||
|
||||
@@ -65,7 +65,10 @@ def change_password(
|
||||
if not verify_password(current_password, user.hashed_password):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Current password is incorrect")
|
||||
# Assign user.hashed_password = hash_password(new_password)
|
||||
if verify_password(new_password, user.hashed_password):
|
||||
raise BusinessRuleViolation(
|
||||
"New password must be different from the current password"
|
||||
)
|
||||
user.hashed_password = hash_password(new_password)
|
||||
# Assign user.must_change_password = False
|
||||
user.must_change_password = False
|
||||
|
||||
@@ -53,11 +53,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
|
||||
from app.models.technique import Technique
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
@@ -104,10 +102,15 @@ def _download_zip(url: str = CALDERA_ZIP_URL) -> bytes:
|
||||
# Define function _extract_zip
|
||||
def _extract_zip(zip_bytes: bytes, dest: str) -> Path:
|
||||
"""Extract *zip_bytes* into *dest* and return abilities dir."""
|
||||
# Open context manager
|
||||
dest_path = Path(dest).resolve()
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
# Call zf.extractall()
|
||||
zf.extractall(dest)
|
||||
for member in zf.infolist():
|
||||
target = (dest_path / member.filename).resolve()
|
||||
if not target.is_relative_to(dest_path):
|
||||
raise ValueError(
|
||||
f"Zip Slip detected — '{member.filename}' resolves outside target directory"
|
||||
)
|
||||
zf.extract(member, dest)
|
||||
# Assign abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities"
|
||||
abilities_dir = Path(dest) / _ZIP_ROOT_PREFIX / "data" / "abilities"
|
||||
# Check: not abilities_dir.is_dir()
|
||||
@@ -368,6 +371,7 @@ def sync(db: Session) -> dict:
|
||||
created = 0
|
||||
# Assign skipped = 0
|
||||
skipped = 0
|
||||
new_technique_ids: set[str] = set()
|
||||
|
||||
# Iterate over parsed
|
||||
for item in parsed:
|
||||
@@ -405,10 +409,14 @@ def sync(db: Session) -> dict:
|
||||
db.add(template)
|
||||
# Call existing_ids.add()
|
||||
existing_ids.add(item["atomic_test_id"])
|
||||
# Assign created = 1
|
||||
new_technique_ids.add(item["mitre_technique_id"])
|
||||
created += 1
|
||||
|
||||
# Commit all pending changes to the database
|
||||
if new_technique_ids:
|
||||
db.query(Technique).filter(
|
||||
Technique.mitre_id.in_(new_technique_ids)
|
||||
).update({"review_required": True}, synchronize_session=False)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Assign summary = {
|
||||
|
||||
@@ -34,6 +34,7 @@ from app.models.test import Test
|
||||
|
||||
# Import calculate_next_run from app.services.campaign_scheduler_service
|
||||
from app.services.campaign_scheduler_service import calculate_next_run
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
|
||||
# Import from app.services.campaign_service
|
||||
from app.services.campaign_service import (
|
||||
@@ -120,7 +121,7 @@ def serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
# Literal argument value
|
||||
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
||||
# Literal argument value
|
||||
"start_date": campaign.start_date.isoformat() if campaign.start_date else None,
|
||||
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
||||
# Literal argument value
|
||||
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
||||
@@ -171,7 +172,7 @@ def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
# Literal argument value
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
# Literal argument value
|
||||
"start_date": campaign.start_date.isoformat() if campaign.start_date else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
# Literal argument value
|
||||
"tags": campaign.tags or [],
|
||||
@@ -274,6 +275,7 @@ def create_campaign(
|
||||
tags: Optional[list[str]] = None,
|
||||
# Entry: scheduled_at
|
||||
scheduled_at: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create a new campaign. Does not commit; caller commits."""
|
||||
# Assign campaign = Campaign(
|
||||
@@ -294,6 +296,7 @@ def create_campaign(
|
||||
created_by=creator_id,
|
||||
# Keyword argument: scheduled_at
|
||||
scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None,
|
||||
start_date=datetime.fromisoformat(start_date) if start_date else None,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(campaign)
|
||||
@@ -358,6 +361,8 @@ def update_campaign(
|
||||
if "scheduled_at" in fields and fields["scheduled_at"]:
|
||||
# Assign fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
|
||||
fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
|
||||
if "start_date" in fields and fields["start_date"]:
|
||||
fields["start_date"] = datetime.fromisoformat(fields["start_date"])
|
||||
|
||||
# Iterate over fields.items()
|
||||
for field, value in fields.items():
|
||||
@@ -531,11 +536,29 @@ def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: s
|
||||
# Assign dep.depends_on = None
|
||||
dep.depends_on = None
|
||||
|
||||
# Mark record for deletion on next commit
|
||||
# Keep a reference to the underlying test before deleting the join record
|
||||
test_id = ct.test_id
|
||||
technique_id = None
|
||||
test_obj = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test_obj:
|
||||
technique_id = test_obj.technique_id
|
||||
|
||||
db.delete(ct)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush()
|
||||
|
||||
# Also delete the actual test record (it was created for this campaign)
|
||||
if test_obj:
|
||||
db.delete(test_obj)
|
||||
db.flush()
|
||||
|
||||
# Recalculate technique status_global so coverage metrics stay consistent
|
||||
if technique_id:
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
if technique:
|
||||
recalculate_technique_status(db, technique)
|
||||
db.flush()
|
||||
|
||||
|
||||
# Define function activate_campaign
|
||||
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||
@@ -697,7 +720,72 @@ def schedule_campaign(
|
||||
return campaign
|
||||
|
||||
|
||||
# Define function get_campaign_history
|
||||
def delete_campaign(
|
||||
db: Session,
|
||||
campaign_id: str,
|
||||
*,
|
||||
deleter_id: uuid.UUID,
|
||||
deleter_role: str,
|
||||
delete_tests: bool = False,
|
||||
) -> None:
|
||||
"""Delete a campaign.
|
||||
|
||||
Only draft campaigns can be deleted unless the caller is admin.
|
||||
If delete_tests=True, the associated Test objects are also deleted.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
if campaign.status != "draft" and deleter_role != "admin":
|
||||
raise BusinessRuleViolation("Only draft campaigns can be deleted")
|
||||
|
||||
if str(campaign.created_by) != str(deleter_id) and deleter_role != "admin":
|
||||
raise PermissionViolation("Only the creator or admin can delete this campaign")
|
||||
|
||||
# Collect test IDs before removing associations
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).all()
|
||||
)
|
||||
test_ids = [ct.test_id for ct in campaign_tests]
|
||||
|
||||
# Remove CampaignTest join rows (clear depends_on refs first to avoid FK cycles)
|
||||
for ct in campaign_tests:
|
||||
ct.depends_on = None
|
||||
db.flush()
|
||||
for ct in campaign_tests:
|
||||
db.delete(ct)
|
||||
db.flush()
|
||||
|
||||
# Optionally delete the associated tests
|
||||
if delete_tests:
|
||||
affected_technique_ids: set = set()
|
||||
for test_id in test_ids:
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test:
|
||||
if test.technique_id:
|
||||
affected_technique_ids.add(test.technique_id)
|
||||
db.delete(test)
|
||||
db.flush()
|
||||
# Recalculate status_global for every affected technique so the
|
||||
# coverage metrics stay consistent after test deletion.
|
||||
for tech_id in affected_technique_ids:
|
||||
technique = db.query(Technique).filter(Technique.id == tech_id).first()
|
||||
if technique:
|
||||
recalculate_technique_status(db, technique)
|
||||
db.flush()
|
||||
|
||||
# Null-out parent_campaign_id on child campaigns to avoid FK violation
|
||||
db.query(Campaign).filter(Campaign.parent_campaign_id == campaign.id).update(
|
||||
{"parent_campaign_id": None}
|
||||
)
|
||||
db.flush()
|
||||
|
||||
db.delete(campaign)
|
||||
db.flush()
|
||||
|
||||
|
||||
def get_campaign_history(db: Session, campaign_id: str) -> dict:
|
||||
"""List all child campaigns (execution history) of a recurring campaign.
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@ threat actors, and progress calculation.
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -179,6 +179,8 @@ def generate_campaign_from_threat_actor(
|
||||
actor_id: uuid.UUID,
|
||||
# Entry: user
|
||||
user: User,
|
||||
*,
|
||||
start_date: Optional[datetime] = None,
|
||||
) -> Campaign:
|
||||
"""Auto-generate a campaign from a threat actor's uncovered techniques.
|
||||
|
||||
@@ -236,6 +238,7 @@ def generate_campaign_from_threat_actor(
|
||||
created_by=user.id,
|
||||
# Keyword argument: tags
|
||||
tags=[actor.name, "auto-generated"],
|
||||
start_date=start_date,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(campaign)
|
||||
@@ -288,6 +291,7 @@ def generate_campaign_from_threat_actor(
|
||||
created_by=user.id,
|
||||
# Keyword argument: state
|
||||
state=TestState.draft,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(test)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -102,7 +102,7 @@ def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, An
|
||||
"control_id": control.control_id,
|
||||
# Literal argument value
|
||||
"title": control.title,
|
||||
# Literal argument value
|
||||
"description": control.description,
|
||||
"category": control.category,
|
||||
# Literal argument value
|
||||
"status": "not_evaluated",
|
||||
@@ -173,7 +173,7 @@ def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, An
|
||||
"control_id": control.control_id,
|
||||
# Literal argument value
|
||||
"title": control.title,
|
||||
# Literal argument value
|
||||
"description": control.description,
|
||||
"category": control.category,
|
||||
# Literal argument value
|
||||
"status": status,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user