Compare commits
89 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1f19bd8432 | |||
| d2a46feba8 | |||
| 9ff0f04ba3 | |||
| 8f98bdd273 | |||
| 1249391ef0 | |||
| 05b221a22d | |||
| 2ee59d4e18 | |||
| bdeeed54e1 | |||
| 3e854b7b79 | |||
| 5b29c2fc56 | |||
| 6b076f52b2 | |||
| c0aff4cbeb | |||
| a8a24b5429 | |||
| b6f23f385d | |||
| 6ab950ec42 | |||
| ed2c34ef28 | |||
| 96fdd9fa85 | |||
| c28a47c43b | |||
| 0d4c404f08 | |||
| 03d7d1cc80 | |||
| b8c9c4ac6a | |||
| 73867d3990 | |||
| f45b7ea926 | |||
| 6b28934f05 | |||
| 6f35d85a97 | |||
| c5eb6f6dc1 | |||
| 9b70655b7e | |||
| 821c4ac5ec | |||
| abef2a45e0 | |||
| 309b3bc02d | |||
| 0148bf28dc | |||
| 79a4772ab5 | |||
| a9255e15ce | |||
| 0c526c48f9 | |||
| 0d211d5156 | |||
| 14d995b40c | |||
| 339d669498 | |||
| 9e22fde746 | |||
| bbc2dddd86 | |||
| d77075272e | |||
| c0c6cda11d | |||
| 44621364be | |||
| 0eff48c768 | |||
| 764a2f7579 | |||
| f4c74230ec | |||
| 50b70704ae | |||
| 20738d11b3 | |||
| 4e3787d091 | |||
| 93fde55389 | |||
| 560fc0c9f0 | |||
| d305db8794 | |||
| 25fddad17c | |||
| 8d5c5fa80e | |||
| 42a9f4dcd4 | |||
| 2b6d9090c9 | |||
| 0b65f51d1c | |||
| f41b8fd8c2 | |||
| 1521005b62 | |||
| 5c55e7c17f | |||
| e651ef8a8c | |||
| 1338d52cd0 | |||
| 576705d61d | |||
| 9e204b78ec | |||
| bc8025ffcf | |||
| 633c8e46ad | |||
| 611e10620e | |||
| 55dba1e00a | |||
| 6147abc87a | |||
| bfce1a8a0e | |||
| 98e8ca1eef | |||
| f0f59facdb | |||
| 898bb7e4e7 | |||
| 51c927394d | |||
| a4a2adccee | |||
| 8f764d8e39 | |||
| 222979574a | |||
| 31e116b4ba | |||
| febf460580 | |||
| 005a09b42f | |||
| 7e33746539 | |||
| 703dd891d3 | |||
| 9b98f60a9a | |||
| 6d18a5417d | |||
| 6a327f6b51 | |||
| 875d7b1a15 | |||
| 64d64080e0 | |||
| e7e63161e8 | |||
| 38285f885c | |||
| cc0bbdf797 |
@@ -0,0 +1,189 @@
|
||||
---
|
||||
description: Aegis backend Clean Architecture rules. Apply when working on any backend Python file under backend/app/ or backend/tests/.
|
||||
globs: backend/**/*.py
|
||||
---
|
||||
|
||||
# Aegis — Clean Modular Monolith Architecture
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
Aegis follows a **Clean Architecture** pattern inside a modular monolith. The backend has four layers with strict dependency rules:
|
||||
|
||||
```
|
||||
Presentation → Application → Domain ← Infrastructure
|
||||
```
|
||||
|
||||
**The golden rule:** dependencies only point towards the Domain layer. Infrastructure implements the ports (interfaces) defined in Domain.
|
||||
|
||||
## Layer Structure and Rules
|
||||
|
||||
### Domain Layer (`backend/app/domain/`)
|
||||
|
||||
The innermost layer. **ZERO** imports from FastAPI, SQLAlchemy, Pydantic, or any framework.
|
||||
|
||||
| Directory | Purpose |
|
||||
|-----------|---------|
|
||||
| `domain/enums.py` | Canonical domain enums (TechniqueStatus, TestState, TeamSide, TestResult) |
|
||||
| `domain/errors.py` | Exception hierarchy (DomainError → EntityNotFoundError, InvalidStateTransition, etc.) |
|
||||
| `domain/exceptions.py` | Backward-compatible re-exports from errors.py |
|
||||
| `domain/test_entity.py` | TestEntity — pure state machine with domain events |
|
||||
| `domain/entities/` | Rich domain entities (TechniqueEntity, etc.) with business behavior |
|
||||
| `domain/value_objects/` | Immutable value types (MitreId, ScoringWeights) |
|
||||
| `domain/ports/repositories/` | Protocol interfaces defining data access contracts |
|
||||
| `domain/ports/services/` | Protocol interfaces for external capabilities (storage, events) |
|
||||
| `domain/unit_of_work.py` | UnitOfWork wrapping SQLAlchemy session |
|
||||
|
||||
**NEVER** import from `app.models`, `app.routers`, `app.infrastructure`, `fastapi`, or `sqlalchemy` inside `domain/`.
|
||||
|
||||
### Application Layer (`backend/app/application/` — future)
|
||||
|
||||
Use case orchestrators. Depends only on Domain.
|
||||
|
||||
| Directory | Purpose |
|
||||
|-----------|---------|
|
||||
| `application/use_cases/` | One class per business operation |
|
||||
| `application/dto/` | Plain data containers for use case input/output |
|
||||
| `application/interfaces/` | Application-level contracts (UnitOfWork protocol) |
|
||||
|
||||
### Infrastructure Layer (`backend/app/infrastructure/`)
|
||||
|
||||
Implements ports defined in Domain. Depends on Domain and Application.
|
||||
|
||||
| Directory | Purpose |
|
||||
|-----------|---------|
|
||||
| `infrastructure/redis_client.py` | Redis connection singleton |
|
||||
| `infrastructure/persistence/repositories/` | SQLAlchemy implementations of repository ports |
|
||||
| `infrastructure/persistence/mappers/` | ORM model ↔ domain entity converters |
|
||||
|
||||
### Presentation Layer (routers, schemas, dependencies)
|
||||
|
||||
HTTP boundary. Depends on Application and Domain (for exceptions).
|
||||
|
||||
| Directory | Purpose |
|
||||
|-----------|---------|
|
||||
| `routers/` | FastAPI routers — HTTP mapping only |
|
||||
| `schemas/` | Pydantic request/response models |
|
||||
| `dependencies/` | FastAPI `Depends()` wiring (auth, repositories) |
|
||||
| `middleware/` | Error handler mapping domain exceptions → HTTP responses |
|
||||
|
||||
## Import Rules (Strict)
|
||||
|
||||
| From \ To | domain/ | application/ | infrastructure/ | presentation/ |
|
||||
|-----------|---------|-------------|----------------|--------------|
|
||||
| **domain/** | Self only | FORBIDDEN | FORBIDDEN | FORBIDDEN |
|
||||
| **application/** | ALLOWED | Self only | FORBIDDEN | FORBIDDEN |
|
||||
| **infrastructure/** | ALLOWED (ports) | ALLOWED (UoW) | Self only | FORBIDDEN |
|
||||
| **presentation/** | ALLOWED (exceptions) | ALLOWED (use cases) | ALLOWED (wiring in dependencies/) | Self only |
|
||||
|
||||
## How to Add a New Feature
|
||||
|
||||
### 1. Start from the Domain
|
||||
|
||||
- Define or reuse domain entities in `domain/entities/`
|
||||
- Add value objects if needed in `domain/value_objects/`
|
||||
- Define repository port if a new aggregate root in `domain/ports/repositories/`
|
||||
- Domain exceptions go in `domain/errors.py`
|
||||
- Business rules live IN the entity, not in services or routers
|
||||
|
||||
### 2. Implement Infrastructure
|
||||
|
||||
- Create SQLAlchemy repository implementation in `infrastructure/persistence/repositories/`
|
||||
- Create mapper if converting between ORM model and domain entity
|
||||
- Repository does NOT call `commit()` — only `flush()`
|
||||
- Transaction control belongs to the Unit of Work
|
||||
|
||||
### 3. Wire in Presentation
|
||||
|
||||
- Add FastAPI `Depends()` provider in `dependencies/repositories.py`
|
||||
- Keep routers thin: parse request → call service/use case → return response
|
||||
- Map domain exceptions to HTTP via the error handler middleware (automatic)
|
||||
|
||||
### 4. Tests (Mandatory)
|
||||
|
||||
Every change MUST include tests:
|
||||
- **Domain entities/value objects**: pure unit tests, no DB, no mocking frameworks
|
||||
- **Repositories**: integration tests using the `db` fixture from conftest
|
||||
- **Routers**: API tests using the `client` fixture
|
||||
- At least one success test + one failure/edge-case test per behavior
|
||||
|
||||
Before committing, run: `scripts/agent_validate_backend.sh`
|
||||
|
||||
## Existing Patterns to Follow
|
||||
|
||||
### Domain Entity Pattern (see `domain/test_entity.py`)
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class SomeEntity:
|
||||
id: uuid.UUID
|
||||
# fields...
|
||||
_events: list[DomainEvent] = field(default_factory=list, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, model: Any) -> "SomeEntity":
|
||||
"""Build from SQLAlchemy model."""
|
||||
...
|
||||
|
||||
def apply_to(self, model: Any) -> None:
|
||||
"""Copy mutable fields back onto the ORM model."""
|
||||
...
|
||||
|
||||
def some_business_method(self) -> None:
|
||||
"""Business logic lives HERE, not in services."""
|
||||
...
|
||||
self._events.append(DomainEvent("something_happened"))
|
||||
```
|
||||
|
||||
### Repository Port Pattern (Protocol)
|
||||
|
||||
```python
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
@runtime_checkable
|
||||
class SomeRepository(Protocol):
|
||||
def find_by_id(self, id: uuid.UUID) -> SomeEntity | None: ...
|
||||
def save(self, entity: SomeEntity) -> SomeEntity: ...
|
||||
```
|
||||
|
||||
### Repository Implementation Pattern
|
||||
|
||||
```python
|
||||
class SASomeRepository:
|
||||
def __init__(self, session: Session) -> None:
|
||||
self._session = session
|
||||
|
||||
def find_by_id(self, id: uuid.UUID) -> SomeEntity | None:
|
||||
model = self._session.query(SomeModel).filter(SomeModel.id == id).first()
|
||||
return SomeMapper.to_entity(model) if model else None
|
||||
|
||||
def save(self, entity: SomeEntity) -> SomeEntity:
|
||||
model = SomeMapper.to_model(entity)
|
||||
merged = self._session.merge(model)
|
||||
self._session.flush() # NO commit — UoW does that
|
||||
return SomeMapper.to_entity(merged)
|
||||
```
|
||||
|
||||
### Error Handling (automatic via middleware)
|
||||
|
||||
Services raise domain exceptions → middleware maps to HTTP:
|
||||
- `EntityNotFoundError` → 404
|
||||
- `DuplicateEntityError` → 409
|
||||
- `InvalidStateTransition` → 400
|
||||
- `BusinessRuleViolation` → 400
|
||||
- `PermissionViolation` → 403
|
||||
|
||||
### Coexistence Strategy
|
||||
|
||||
Old code (direct `db.query()` in routers) and new code (repositories) coexist. Migration is incremental:
|
||||
1. New endpoints use repositories
|
||||
2. Existing endpoints are migrated one at a time
|
||||
3. Both access the same DB, same session, same tables
|
||||
|
||||
## Key Conventions
|
||||
|
||||
- **Enums**: canonical source is `domain/enums.py`, `models/enums.py` re-exports
|
||||
- **Exceptions**: raise from `domain/errors.py`, never raise `HTTPException` from services
|
||||
- **Commits**: only via `UnitOfWork.commit()` or at the router level, never inside services/repos
|
||||
- **IDs**: UUID everywhere (primary keys, foreign keys)
|
||||
- **Tests**: SQLite in-memory for unit/integration, PostgreSQL in CI
|
||||
- **Validation**: Pydantic in schemas (presentation), domain rules in entities (domain)
|
||||
+26
-6
@@ -1,24 +1,44 @@
|
||||
# =============================================================================
|
||||
# Aegis Environment Variables
|
||||
# =============================================================================
|
||||
# Copy this file to .env and fill in the values
|
||||
# Copy this file to .env and fill in the values BEFORE deploying.
|
||||
#
|
||||
# Generate secure random values with:
|
||||
# openssl rand -hex 32 (for SECRET_KEY)
|
||||
# openssl rand -base64 18 (for passwords)
|
||||
# =============================================================================
|
||||
|
||||
# ── Database ─────────────────────────────────────────────────────────────────
|
||||
DB_USER=postgres
|
||||
DB_PASSWORD=change-me-in-production
|
||||
DB_PASSWORD= # REQUIRED — set a strong password
|
||||
DB_NAME=attackdb
|
||||
|
||||
# ── Security ─────────────────────────────────────────────────────────────────
|
||||
# IMPORTANT: Generate a strong random key for production
|
||||
# Example: openssl rand -hex 32
|
||||
SECRET_KEY=change-me-in-production-use-a-long-random-string
|
||||
# REQUIRED in production — the app will refuse to start without it.
|
||||
# Generate with: openssl rand -hex 32
|
||||
SECRET_KEY=
|
||||
|
||||
TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
# ── Initial Admin Account ────────────────────────────────────────────────────
|
||||
# If ADMIN_PASSWORD is empty, a random password is auto-generated and
|
||||
# printed to the backend container logs on first startup.
|
||||
ADMIN_USERNAME=admin
|
||||
ADMIN_PASSWORD=
|
||||
|
||||
# ── MinIO Object Storage ─────────────────────────────────────────────────────
|
||||
MINIO_ACCESS_KEY=minioadmin
|
||||
MINIO_SECRET_KEY=change-me-in-production
|
||||
MINIO_SECRET_KEY= # REQUIRED — set a strong password
|
||||
MINIO_BUCKET=evidence
|
||||
MINIO_SECURE=false # Set to true if MinIO is behind TLS
|
||||
|
||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||
# Comma-separated list of allowed frontend origins
|
||||
CORS_ORIGINS=https://your-domain.com
|
||||
|
||||
# ── Frontend ─────────────────────────────────────────────────────────────────
|
||||
FRONTEND_PORT=80
|
||||
|
||||
# ── Environment flag ─────────────────────────────────────────────────────────
|
||||
# Set to "production" for production deployments (enforces SECRET_KEY, etc.)
|
||||
AEGIS_ENV=production
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
name: Aegis CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
lint-and-test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
env:
|
||||
POSTGRES_DB: testdb
|
||||
POSTGRES_USER: test
|
||||
POSTGRES_PASSWORD: test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- 6379:6379
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: backend
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: pip
|
||||
cache-dependency-path: backend/requirements.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install ruff
|
||||
|
||||
- name: Lint
|
||||
run: ruff check app/ tests/
|
||||
|
||||
- name: Test
|
||||
env:
|
||||
DATABASE_URL: postgresql://test:test@localhost:5432/testdb
|
||||
REDIS_URL: redis://localhost:6379/0
|
||||
SECRET_KEY: ci-test-secret-key-not-for-production
|
||||
run: pytest tests/ -v --tb=short
|
||||
-1232
File diff suppressed because it is too large
Load Diff
-1431
File diff suppressed because it is too large
Load Diff
-1475
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,7 @@
|
||||
# Aegis — MITRE ATT&CK Coverage Platform
|
||||
|
||||
Continuous integration (lint + tests against PostgreSQL and Redis) is defined in [`.github/workflows/ci.yml`](.github/workflows/ci.yml).
|
||||
|
||||
Aegis is a comprehensive platform for tracking and managing security coverage against the MITRE ATT&CK framework. It enables security teams to document, validate, and visualize their defensive capabilities against known adversary techniques through a structured Red Team / Blue Team validation workflow.
|
||||
|
||||
## Features
|
||||
@@ -81,12 +83,14 @@ Both Red Lead and Blue Lead must independently vote:
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Backend**: FastAPI (Python 3.11)
|
||||
- **Database**: PostgreSQL 15 with UUID primary keys and JSONB columns
|
||||
- **Backend**: FastAPI (Python 3.11) — Clean Modular Monolith with domain entities, services, and repository pattern
|
||||
- **Database**: PostgreSQL 16 with UUID primary keys and JSONB columns
|
||||
- **Object Storage**: MinIO (S3-compatible)
|
||||
- **ORM**: SQLAlchemy with Alembic migrations (18 migration files)
|
||||
- **ORM**: SQLAlchemy 2.x with Alembic migrations
|
||||
- **Frontend**: React 19 + TypeScript + Vite 7 + Tailwind CSS v4 + TanStack Query + TanStack Virtual
|
||||
- **Cache / Token Store**: Redis (token blacklist, score caching)
|
||||
- **Scheduler**: APScheduler (MITRE sync, Intel scan, Notification cleanup, Snapshots, Recurring campaigns)
|
||||
- **Testing**: Pytest (367+ tests), Ruff (linting), GitHub Actions CI
|
||||
- **Charts**: Recharts
|
||||
|
||||
## Quick Start
|
||||
@@ -108,14 +112,16 @@ chmod +x scripts/install.sh
|
||||
./scripts/install.sh
|
||||
```
|
||||
|
||||
The install script will automatically:
|
||||
- Generate a `.env` file with secure random secrets
|
||||
- Build and start all containers (PostgreSQL, MinIO, Backend, Frontend)
|
||||
- Run database migrations
|
||||
- Seed the admin user and data sources
|
||||
- Optionally run the initial MITRE ATT&CK sync
|
||||
The interactive install wizard will guide you through:
|
||||
1. **Domain configuration** — your domain or IP, protocol (HTTP/HTTPS), and port
|
||||
2. **Admin account** — custom username and password (or auto-generated secure password)
|
||||
3. **Database** — name, user, and password (or auto-generated)
|
||||
4. **Session duration** — JWT token expiry (default: 15 minutes)
|
||||
5. **MITRE ATT&CK sync** — optionally import ~700 techniques on first run
|
||||
|
||||
Access the application at **http://your-server:80**.
|
||||
The script automatically generates cryptographically secure random secrets for `SECRET_KEY`, database password, and MinIO credentials. A summary with all credentials is displayed at the end of the installation.
|
||||
|
||||
Access the application at the URL shown in the installation summary.
|
||||
|
||||
### Development Setup
|
||||
|
||||
@@ -132,14 +138,18 @@ Access at **http://localhost:5173** (frontend dev server) and **http://localhost
|
||||
|
||||
### Authentication
|
||||
|
||||
JWT-based authentication. Default admin credentials after seeding:
|
||||
JWT-based authentication with HttpOnly cookies. Admin credentials are configured during installation:
|
||||
|
||||
```
|
||||
Username: admin
|
||||
Password: admin123
|
||||
- If you set a custom password in the wizard, use that.
|
||||
- If you left it blank, a secure random password was auto-generated and displayed in the installation summary and backend logs.
|
||||
|
||||
To retrieve auto-generated credentials after installation:
|
||||
|
||||
```bash
|
||||
docker logs aegis-backend 2>&1 | grep -A 5 "Initial Admin User Created"
|
||||
```
|
||||
|
||||
> **Important:** Change the default `admin123` password immediately after first login.
|
||||
> **Note:** Passwords must meet complexity requirements: minimum 12 characters with at least one uppercase letter, one lowercase letter, one digit, and one special character.
|
||||
|
||||
### Importing Data Sources
|
||||
|
||||
@@ -161,10 +171,11 @@ curl -X POST http://your-server/api/v1/data-sources/sync-all -H "Authorization:
|
||||
|
||||
### Production Considerations
|
||||
|
||||
- **HTTPS/TLS:** For internet-facing deployments, place a reverse proxy with TLS in front (e.g., Traefik, Caddy, or Nginx with Let's Encrypt).
|
||||
- **HTTPS/TLS:** For internet-facing deployments, place a reverse proxy with TLS in front (e.g., Traefik, Caddy, or Nginx with Let's Encrypt). Uncomment the HSTS header in `frontend/nginx.conf` once HTTPS is configured.
|
||||
- **Backups:** Set up regular PostgreSQL backups: `docker exec aegis-postgres pg_dump -U postgres attackdb > backup.sql`
|
||||
- **Updates:** To update, pull the latest code and run: `docker compose -f docker-compose.prod.yml up -d --build`
|
||||
- **Firewall:** Only expose port 80/443. All other services (DB, MinIO, backend) are internal only.
|
||||
- **Reconfigure:** Run `./scripts/install.sh` again to reconfigure the environment (domain, credentials, etc.).
|
||||
|
||||
### Configuring Scoring Weights
|
||||
|
||||
@@ -221,11 +232,13 @@ Or at runtime via the admin API — see [docs/SCORING.md](docs/SCORING.md).
|
||||
|
||||
## API Documentation
|
||||
|
||||
Interactive API documentation available at:
|
||||
Interactive API documentation is available **in development only** (disabled in production for security):
|
||||
|
||||
- **Swagger UI**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
|
||||
> In production (`AEGIS_ENV=production`), these endpoints are disabled. Use the development environment or refer to [docs/API.md](docs/API.md).
|
||||
|
||||
### API Endpoints
|
||||
|
||||
| Group | Prefix | Key Operations |
|
||||
@@ -256,22 +269,66 @@ See [docs/API.md](docs/API.md) for the full endpoint reference.
|
||||
|
||||
## Configuration
|
||||
|
||||
All variables are configured automatically by `scripts/install.sh`. For manual setup, copy `.env.example` to `.env` and fill in the values.
|
||||
|
||||
### Required (production)
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `SECRET_KEY` | JWT signing key — **required** in production (app refuses to start without it). Generate with `openssl rand -hex 32` |
|
||||
| `DB_PASSWORD` | PostgreSQL password |
|
||||
| `MINIO_SECRET_KEY` | MinIO secret key |
|
||||
|
||||
### Security & Auth
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `DATABASE_URL` | `postgresql://postgres:postgres@postgres:5432/attackdb` | PostgreSQL connection |
|
||||
| `SECRET_KEY` | `change-me-in-production` | JWT signing key |
|
||||
| `ALGORITHM` | `HS256` | JWT signing algorithm |
|
||||
| `ACCESS_TOKEN_EXPIRE_MINUTES` | `60` | Token lifetime |
|
||||
| `MINIO_ENDPOINT` | `minio:9000` | MinIO server |
|
||||
| `AEGIS_ENV` | — | Set to `production` to enforce security settings |
|
||||
| `ADMIN_USERNAME` | `admin` | Initial admin account username |
|
||||
| `ADMIN_PASSWORD` | *(auto-generated)* | Initial admin password. If empty, a secure random password is generated and shown in logs |
|
||||
| `ACCESS_TOKEN_EXPIRE_MINUTES` | `15` | JWT token lifetime in minutes |
|
||||
| `CORS_ORIGINS` | `http://localhost:5173` | Comma-separated list of allowed frontend origins |
|
||||
|
||||
### Infrastructure
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `DB_USER` | `postgres` | PostgreSQL username |
|
||||
| `DB_NAME` | `attackdb` | PostgreSQL database name |
|
||||
| `MINIO_ENDPOINT` | `minio:9000` | MinIO server address |
|
||||
| `MINIO_ACCESS_KEY` | `minioadmin` | MinIO access key |
|
||||
| `MINIO_SECRET_KEY` | `minioadmin` | MinIO secret key |
|
||||
| `MINIO_BUCKET` | `evidence` | Evidence bucket |
|
||||
| `MAX_RETEST_COUNT` | `3` | Max automatic retests per original test |
|
||||
| `MINIO_BUCKET` | `evidence` | MinIO bucket for evidence files |
|
||||
| `MINIO_SECURE` | `false` | Set to `true` if MinIO is behind TLS |
|
||||
| `FRONTEND_PORT` | `80` | Port exposed by the frontend container |
|
||||
|
||||
### Scoring Weights
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `SCORING_WEIGHT_TESTS` | `40` | Weight for test validation component |
|
||||
| `SCORING_WEIGHT_DETECTION_RULES` | `20` | Weight for detection rules component |
|
||||
| `SCORING_WEIGHT_D3FEND` | `15` | Weight for D3FEND coverage component |
|
||||
| `SCORING_WEIGHT_FRESHNESS` | `15` | Weight for freshness component |
|
||||
| `SCORING_WEIGHT_PLATFORM_DIVERSITY` | `10` | Weight for platform diversity component |
|
||||
| `MAX_RETEST_COUNT` | `3` | Max automatic retests per original test |
|
||||
|
||||
## Security
|
||||
|
||||
Aegis includes several security hardening measures:
|
||||
|
||||
- **Authentication:** JWT tokens stored in HttpOnly/Secure/SameSite cookies (immune to XSS theft). Token revocation via Redis-backed blacklist on logout.
|
||||
- **Rate limiting:** Login endpoint limited to 5 attempts per minute per IP (via slowapi).
|
||||
- **Password policy:** Minimum 12 characters with uppercase, lowercase, digit, and special character.
|
||||
- **CORS:** Configurable origins via `CORS_ORIGINS` environment variable. Restrictive method and header lists.
|
||||
- **Nginx security headers:** CSP, X-Frame-Options, X-Content-Type-Options, Referrer-Policy, Permissions-Policy.
|
||||
- **Non-root container:** Backend runs as `appuser` (UID 1001), not root.
|
||||
- **File uploads:** 50 MB size limit, extension whitelist, filename sanitization.
|
||||
- **ZIP imports:** Zip Slip (path traversal) and Zip Bomb (size/entry limit) protection.
|
||||
- **API surface:** Swagger UI, ReDoc, and OpenAPI schema disabled in production.
|
||||
- **Health endpoint:** Restricted to internal networks via Nginx ACL.
|
||||
- **Input sanitization:** LIKE wildcard escaping in all search queries; Pydantic schemas on all endpoints.
|
||||
- **XML parsing:** Uses `defusedxml` to prevent Billion Laughs / XXE attacks.
|
||||
- **Error handling:** Internal exception details are logged server-side only, never exposed to clients.
|
||||
|
||||
## Project Structure
|
||||
|
||||
@@ -279,54 +336,50 @@ See [docs/API.md](docs/API.md) for the full endpoint reference.
|
||||
Aegis/
|
||||
├── docker-compose.yml
|
||||
├── docker-compose.prod.yml
|
||||
├── .github/workflows/ci.yml # GitHub Actions: ruff + pytest on PostgreSQL + Redis
|
||||
├── docs/
|
||||
│ ├── API.md # Full API endpoint reference
|
||||
│ ├── ARCHITECTURE.md # System architecture and DB schema
|
||||
│ ├── ARCHITECTURE.md # System architecture, DB schema, service map
|
||||
│ ├── ADR.md # Architecture Decision Records
|
||||
│ ├── DATA_SOURCES.md # External data source documentation
|
||||
│ └── SCORING.md # Scoring system and metrics
|
||||
│ ├── SCORING.md # Scoring system and metrics
|
||||
│ ├── TECHNOLOGY_JUSTIFICATION.md
|
||||
│ ├── C4_CONTEXT_DIAGRAM.md # System context (C4 Level 1)
|
||||
│ └── C4_CONTAINER_DIAGRAM.md # Container architecture (C4 Level 2)
|
||||
├── backend/
|
||||
│ ├── Dockerfile
|
||||
│ ├── requirements.txt
|
||||
│ ├── alembic.ini
|
||||
│ ├── alembic/versions/ # b001–b018 migration files
|
||||
│ ├── alembic/versions/ # Database migration files
|
||||
│ ├── pytest.ini
|
||||
│ ├── tests/ # 367+ pytest tests (domain, service, API)
|
||||
│ └── app/
|
||||
│ ├── main.py # FastAPI app with all routers + lifespan
|
||||
│ ├── config.py # Settings from environment
|
||||
│ ├── config.py # Pydantic Settings from environment
|
||||
│ ├── database.py # SQLAlchemy engine + session (lazy init)
|
||||
│ ├── storage.py # MinIO/S3 helpers
|
||||
│ ├── auth.py # Password hashing + JWT tokens
|
||||
│ ├── models/ # 18 model files (SQLAlchemy ORM)
|
||||
│ ├── domain/ # Pure business logic (zero framework imports)
|
||||
│ │ ├── entities/ # Rich domain entities (Technique, Campaign, etc.)
|
||||
│ │ ├── ports/ # Protocol interfaces (repos, ImportService)
|
||||
│ │ ├── value_objects/ # Immutable types (MitreId, ScoringWeights)
|
||||
│ │ ├── errors.py # Domain exception hierarchy
|
||||
│ │ └── unit_of_work.py # Transaction management
|
||||
│ ├── infrastructure/ # SQLAlchemy repos, Redis, mappers
|
||||
│ ├── models/ # SQLAlchemy ORM models
|
||||
│ ├── schemas/ # Pydantic request/response schemas
|
||||
│ ├── routers/ # 21 API routers
|
||||
│ ├── services/ # 20 business logic services
|
||||
│ ├── dependencies/ # Auth dependencies (get_current_user, require_role)
|
||||
│ └── jobs/
|
||||
│ └── mitre_sync_job.py # APScheduler: 5 background jobs
|
||||
├── frontend/src/
|
||||
│ ├── App.tsx # Routes with lazy loading + role protection
|
||||
│ ├── api/ # 22 API client modules (Axios + TanStack Query)
|
||||
│ ├── components/
|
||||
│ │ ├── Layout.tsx # Sidebar + header + NotificationBell
|
||||
│ │ ├── Sidebar.tsx # Role-aware collapsible navigation
|
||||
│ │ ├── heatmap/ # ATT&CK heatmap (6 components)
|
||||
│ │ ├── compliance/ # Compliance UI (gauge, controls table)
|
||||
│ │ └── test-detail/ # Test detail sub-components
|
||||
│ ├── hooks/
|
||||
│ │ └── useDebounce.ts # Debounce hook for search inputs
|
||||
│ ├── context/
|
||||
│ │ └── AuthContext.tsx # Auth state management
|
||||
│ └── pages/ # 21 page components
|
||||
└── backend/tests/
|
||||
├── conftest.py # SQLite test DB with JSONB/UUID compatibility
|
||||
├── fixtures/ # YAML/TOML/JSON test fixtures
|
||||
├── test_data_sources.py # Data source parsing tests
|
||||
├── test_scoring_and_compliance.py # Scoring + metrics + compliance tests
|
||||
├── test_campaigns_and_snapshots.py # Campaign, snapshot, and retest tests
|
||||
├── test_workflow.py # Red/Blue workflow tests
|
||||
├── test_templates_crud.py # Template CRUD tests
|
||||
├── test_metrics_v2.py # V2 metrics tests
|
||||
└── test_integration_v2.py # Full integration E2E tests
|
||||
│ ├── routers/ # 27 thin HTTP adapter routers
|
||||
│ ├── services/ # 46 framework-agnostic business services
|
||||
│ ├── middleware/ # Error handler (domain exceptions → HTTP)
|
||||
│ ├── dependencies/ # FastAPI dependency injection (auth, repos)
|
||||
│ └── jobs/ # APScheduler background jobs
|
||||
└── frontend/src/
|
||||
├── App.tsx # Routes with lazy loading + role protection
|
||||
├── api/ # API client modules (Axios + TanStack Query)
|
||||
├── components/ # Reusable UI components
|
||||
├── hooks/ # Custom hooks (useDebounce, etc.)
|
||||
├── context/ # Auth state management
|
||||
└── pages/ # Page components
|
||||
```
|
||||
|
||||
## Development
|
||||
@@ -369,10 +422,13 @@ GET /api/v1/compliance/{framework_id}/gaps
|
||||
|
||||
## Further Documentation
|
||||
|
||||
- **[Architecture](docs/ARCHITECTURE.md)** — Database schema, service layer, state machine diagrams
|
||||
- **[Data Sources](docs/DATA_SOURCES.md)** — All external data sources with import instructions
|
||||
- **[Scoring](docs/SCORING.md)** — Scoring system explained with examples and configuration
|
||||
- **[Architecture](docs/ARCHITECTURE.md)** — Database schema, backend layers, domain entities, service map
|
||||
- **[API Reference](docs/API.md)** — Full endpoint documentation
|
||||
- **[Scoring](docs/SCORING.md)** — Scoring system explained with examples and configuration
|
||||
- **[Data Sources](docs/DATA_SOURCES.md)** — All external data sources with import instructions
|
||||
- **[ADRs](docs/ADR.md)** — Architecture Decision Records
|
||||
- **[Technology Justification](docs/TECHNOLOGY_JUSTIFICATION.md)** — Technology choices and rationale
|
||||
- **[C4 Diagrams](docs/C4_CONTEXT_DIAGRAM.md)** — System context and container architecture
|
||||
|
||||
## License
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,6 +19,12 @@ COPY . .
|
||||
# Make entrypoints executable
|
||||
RUN chmod +x /app/entrypoint.sh /app/entrypoint.prod.sh
|
||||
|
||||
# Create a non-root user and give it ownership of /app
|
||||
RUN adduser --disabled-password --gecos '' --uid 1001 appuser \
|
||||
&& chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
"""add_composite_indexes
|
||||
|
||||
Additional composite indexes for scoring, heatmap, metrics, reports,
|
||||
MTTD/MTTR calculations, and notification queries.
|
||||
|
||||
Revision ID: b019composite
|
||||
Revises: b018perfidx
|
||||
Create Date: 2026-02-17 14:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b019composite"
|
||||
down_revision: Union[str, None] = "b018perfidx"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Tests ────────────────────────────────────────────────────────
|
||||
# Used by scoring queries that filter by state + validation date
|
||||
op.create_index(
|
||||
"ix_tests_state_red_validated_at",
|
||||
"tests",
|
||||
["state", "red_validated_at"],
|
||||
)
|
||||
|
||||
# Used by remediation dashboard and metrics
|
||||
op.create_index(
|
||||
"ix_tests_remediation_status",
|
||||
"tests",
|
||||
["remediation_status"],
|
||||
)
|
||||
|
||||
# ── Audit logs ───────────────────────────────────────────────────
|
||||
# Three-column index for MTTD/MTTR queries that filter by entity + action
|
||||
op.create_index(
|
||||
"ix_audit_logs_entity_type_entity_id_action",
|
||||
"audit_logs",
|
||||
["entity_type", "entity_id", "action"],
|
||||
)
|
||||
|
||||
# Used for per-user audit trail queries
|
||||
op.create_index(
|
||||
"ix_audit_logs_user_id",
|
||||
"audit_logs",
|
||||
["user_id"],
|
||||
)
|
||||
|
||||
# ── Notifications ────────────────────────────────────────────────
|
||||
# Used by "unread notifications" badge and inbox queries
|
||||
op.create_index(
|
||||
"ix_notifications_user_id_read",
|
||||
"notifications",
|
||||
["user_id", "read"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_notifications_user_id_read", table_name="notifications")
|
||||
op.drop_index("ix_audit_logs_user_id", table_name="audit_logs")
|
||||
op.drop_index("ix_audit_logs_entity_type_entity_id_action", table_name="audit_logs")
|
||||
op.drop_index("ix_tests_remediation_status", table_name="tests")
|
||||
op.drop_index("ix_tests_state_red_validated_at", table_name="tests")
|
||||
@@ -0,0 +1,95 @@
|
||||
"""add_jira_links_and_worklogs
|
||||
|
||||
Revision ID: b020jiraworklogs
|
||||
Revises: b019composite
|
||||
Create Date: 2026-02-17 16:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b020jiraworklogs"
|
||||
down_revision: Union[str, None] = "b019composite"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── jira_links: 100 % raw SQL to avoid all SQLAlchemy enum hooks ──
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'jiralinkentitytype') THEN
|
||||
CREATE TYPE jiralinkentitytype AS ENUM ('test', 'technique', 'campaign', 'evidence');
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
DO $$ BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'jirasyncdirection') THEN
|
||||
CREATE TYPE jirasyncdirection AS ENUM ('aegis_to_jira', 'jira_to_aegis', 'bidirectional');
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS jira_links (
|
||||
id UUID PRIMARY KEY,
|
||||
entity_type jiralinkentitytype NOT NULL,
|
||||
entity_id UUID NOT NULL,
|
||||
jira_issue_key VARCHAR(50) NOT NULL,
|
||||
jira_issue_id VARCHAR(50),
|
||||
jira_project_key VARCHAR(20),
|
||||
jira_status VARCHAR(100),
|
||||
jira_priority VARCHAR(50),
|
||||
jira_assignee VARCHAR(255),
|
||||
jira_story_points VARCHAR(10),
|
||||
sync_direction jirasyncdirection DEFAULT 'bidirectional',
|
||||
last_synced_at TIMESTAMP,
|
||||
sync_metadata JSONB DEFAULT '{}',
|
||||
created_by UUID REFERENCES users(id),
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
updated_at TIMESTAMP DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS ix_jira_links_entity_id
|
||||
ON jira_links (entity_id);
|
||||
CREATE INDEX IF NOT EXISTS ix_jira_links_issue_key
|
||||
ON jira_links (jira_issue_key);
|
||||
CREATE INDEX IF NOT EXISTS ix_jira_links_entity_type_entity_id
|
||||
ON jira_links (entity_type, entity_id);
|
||||
""")
|
||||
|
||||
# ── worklogs table (no enums, straightforward) ───────────────────
|
||||
op.execute("""
|
||||
CREATE TABLE IF NOT EXISTS worklogs (
|
||||
id UUID PRIMARY KEY,
|
||||
entity_type VARCHAR(50) NOT NULL,
|
||||
entity_id UUID NOT NULL,
|
||||
user_id UUID NOT NULL REFERENCES users(id),
|
||||
activity_type VARCHAR(100) NOT NULL,
|
||||
started_at TIMESTAMP NOT NULL,
|
||||
ended_at TIMESTAMP,
|
||||
duration_seconds INTEGER NOT NULL,
|
||||
description TEXT,
|
||||
tempo_synced TIMESTAMP,
|
||||
tempo_worklog_id VARCHAR(100),
|
||||
integrity_hash VARCHAR(64),
|
||||
created_at TIMESTAMP DEFAULT now(),
|
||||
metadata JSONB DEFAULT '{}'
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS ix_worklogs_entity_id
|
||||
ON worklogs (entity_id);
|
||||
CREATE INDEX IF NOT EXISTS ix_worklogs_user_id
|
||||
ON worklogs (user_id);
|
||||
CREATE INDEX IF NOT EXISTS ix_worklogs_entity_type_entity_id
|
||||
ON worklogs (entity_type, entity_id);
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP TABLE IF EXISTS worklogs")
|
||||
op.execute("DROP TABLE IF EXISTS jira_links")
|
||||
op.execute("DROP TYPE IF EXISTS jirasyncdirection")
|
||||
op.execute("DROP TYPE IF EXISTS jiralinkentitytype")
|
||||
@@ -0,0 +1,38 @@
|
||||
"""add_phase_timing_fields
|
||||
|
||||
Revision ID: b021phasetiming
|
||||
Revises: b020jiraworklogs
|
||||
Create Date: 2026-02-17 18:00:00.000000
|
||||
|
||||
Add red_started_at and blue_started_at columns to the tests table
|
||||
so that automatic worklogs can record real elapsed time per phase.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "b021phasetiming"
|
||||
down_revision = "b020jiraworklogs"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("""
|
||||
ALTER TABLE tests
|
||||
ADD COLUMN IF NOT EXISTS red_started_at TIMESTAMP,
|
||||
ADD COLUMN IF NOT EXISTS blue_started_at TIMESTAMP,
|
||||
ADD COLUMN IF NOT EXISTS paused_at TIMESTAMP,
|
||||
ADD COLUMN IF NOT EXISTS red_paused_seconds INTEGER DEFAULT 0,
|
||||
ADD COLUMN IF NOT EXISTS blue_paused_seconds INTEGER DEFAULT 0;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("""
|
||||
ALTER TABLE tests
|
||||
DROP COLUMN IF EXISTS red_started_at,
|
||||
DROP COLUMN IF EXISTS blue_started_at,
|
||||
DROP COLUMN IF EXISTS paused_at,
|
||||
DROP COLUMN IF EXISTS red_paused_seconds,
|
||||
DROP COLUMN IF EXISTS blue_paused_seconds;
|
||||
""")
|
||||
@@ -0,0 +1,47 @@
|
||||
"""add_osint_items
|
||||
|
||||
Revision ID: b022osintitems
|
||||
Revises: b021phasetiming
|
||||
Create Date: 2026-02-17 22:00:00.000000
|
||||
|
||||
Add osint_items table for OSINT enrichment data linked to techniques.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "b022osintitems"
|
||||
down_revision = "b021phasetiming"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("""
|
||||
CREATE TABLE IF NOT EXISTS osint_items (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
technique_id UUID NOT NULL REFERENCES techniques(id),
|
||||
source_type VARCHAR(50) NOT NULL,
|
||||
source_url TEXT NOT NULL,
|
||||
title VARCHAR(500) NOT NULL,
|
||||
description TEXT,
|
||||
severity VARCHAR(20),
|
||||
discovered_at TIMESTAMP NOT NULL DEFAULT now(),
|
||||
reviewed BOOLEAN NOT NULL DEFAULT false,
|
||||
metadata JSONB DEFAULT '{}'::jsonb
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS ix_osint_items_technique_id
|
||||
ON osint_items (technique_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS ix_osint_items_source_type
|
||||
ON osint_items (source_type);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS ix_osint_items_reviewed
|
||||
ON osint_items (reviewed) WHERE NOT reviewed;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("""
|
||||
DROP TABLE IF EXISTS osint_items CASCADE;
|
||||
""")
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add_must_change_password
|
||||
|
||||
Revision ID: b023mustchgpwd
|
||||
Revises: b022osintitems
|
||||
Create Date: 2026-02-17 23:00:00.000000
|
||||
|
||||
Add must_change_password column to users table to force password
|
||||
change on first login.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision = "b023mustchgpwd"
|
||||
down_revision = "b022osintitems"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("""
|
||||
ALTER TABLE users
|
||||
ADD COLUMN IF NOT EXISTS must_change_password BOOLEAN DEFAULT true;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("""
|
||||
ALTER TABLE users
|
||||
DROP COLUMN IF EXISTS must_change_password;
|
||||
""")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add_critical_test_audit_indexes
|
||||
|
||||
Add missing critical indexes for tests and audit_logs tables to match
|
||||
model __table_args__ declarations. Existing indexes (from b005, b018,
|
||||
b019) are left untouched; only the two genuinely new indexes are created.
|
||||
|
||||
Revision ID: b024critidx
|
||||
Revises: b023mustchgpwd
|
||||
Create Date: 2026-02-18 12:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b024critidx"
|
||||
down_revision: Union[str, None] = "b023mustchgpwd"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
"ix_tests_created_at",
|
||||
"tests",
|
||||
["created_at"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_tests_state_created_at",
|
||||
"tests",
|
||||
["state", "created_at"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_tests_state_created_at", table_name="tests")
|
||||
op.drop_index("ix_tests_created_at", table_name="tests")
|
||||
@@ -0,0 +1,41 @@
|
||||
"""add_unique_test_detection_result
|
||||
|
||||
Enforce one evaluation per (test, detection_rule) pair. Before creating
|
||||
the constraint, duplicate rows (if any) are collapsed so the migration
|
||||
never fails on an existing database.
|
||||
|
||||
Revision ID: b025uqtdr
|
||||
Revises: b024critidx
|
||||
Create Date: 2026-02-18 14:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b025uqtdr"
|
||||
down_revision: Union[str, None] = "b024critidx"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove duplicates keeping the most recently evaluated row
|
||||
op.execute("""
|
||||
DELETE FROM test_detection_results
|
||||
WHERE id NOT IN (
|
||||
SELECT DISTINCT ON (test_id, detection_rule_id) id
|
||||
FROM test_detection_results
|
||||
ORDER BY test_id, detection_rule_id, evaluated_at DESC NULLS LAST
|
||||
)
|
||||
""")
|
||||
|
||||
op.create_unique_constraint(
|
||||
"uq_tdr_test_rule",
|
||||
"test_detection_results",
|
||||
["test_id", "detection_rule_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("uq_tdr_test_rule", "test_detection_results", type_="unique")
|
||||
@@ -0,0 +1,38 @@
|
||||
"""add_technique_query_indexes
|
||||
|
||||
Add indexes on techniques table for common query patterns
|
||||
(filter by tactic, filter by status_global) used in heatmap, scoring,
|
||||
and list-all-techniques operations.
|
||||
|
||||
These may already exist if the ORM model auto-created them; the
|
||||
``if_not_exists`` flag makes this migration safe to run regardless.
|
||||
|
||||
Revision ID: b026techidx
|
||||
Revises: b025uqtdr
|
||||
Create Date: 2026-02-18 18:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b026techidx"
|
||||
down_revision: Union[str, None] = "b025uqtdr"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_techniques_tactic "
|
||||
"ON techniques (tactic)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_techniques_status_global "
|
||||
"ON techniques (status_global)"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_techniques_status_global", table_name="techniques")
|
||||
op.drop_index("ix_techniques_tactic", table_name="techniques")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add_scoring_config
|
||||
|
||||
Single-row table to persist scoring weights in the database,
|
||||
replacing the mutable in-process Settings approach.
|
||||
|
||||
Revision ID: b027scorecfg
|
||||
Revises: b026techidx
|
||||
Create Date: 2026-02-19 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "b027scorecfg"
|
||||
down_revision: Union[str, None] = "b026techidx"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"scoring_config",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("weight_tests", sa.Float(), nullable=False, server_default="40.0"),
|
||||
sa.Column("weight_detection_rules", sa.Float(), nullable=False, server_default="20.0"),
|
||||
sa.Column("weight_d3fend", sa.Float(), nullable=False, server_default="15.0"),
|
||||
sa.Column("weight_freshness", sa.Float(), nullable=False, server_default="15.0"),
|
||||
sa.Column("weight_platform_diversity", sa.Float(), nullable=False, server_default="10.0"),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("scoring_config")
|
||||
@@ -0,0 +1,35 @@
|
||||
"""phase0 SR-006 — campaign_tests composite index
|
||||
|
||||
Most SR-006 indexes already ship in b005, b009, b018, b019, and b026.
|
||||
``tests`` has no ``campaign_id`` column (membership is ``campaign_tests``),
|
||||
so this revision adds a composite index to speed “tests in campaign” joins.
|
||||
|
||||
Revision ID: b028phase0
|
||||
Revises: b027scorecfg
|
||||
Create Date: 2026-05-18 12:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b028phase0"
|
||||
down_revision: Union[str, None] = "b027scorecfg"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
"ix_campaign_tests_campaign_id_test_id",
|
||||
"campaign_tests",
|
||||
["campaign_id", "test_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_campaign_tests_campaign_id_test_id",
|
||||
table_name="campaign_tests",
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Phase 3: audit trail columns and data classification fields.
|
||||
|
||||
Revision ID: b029phase3
|
||||
Revises: b028phase0
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b029phase3"
|
||||
down_revision: Union[str, None] = "b028phase0"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _column_names(table: str) -> set[str]:
|
||||
bind = op.get_bind()
|
||||
insp = sa.inspect(bind)
|
||||
return {c["name"] for c in insp.get_columns(table)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
audit_cols = _column_names("audit_logs")
|
||||
if "ip_address" not in audit_cols:
|
||||
op.add_column("audit_logs", sa.Column("ip_address", sa.String(45), nullable=True))
|
||||
if "user_agent" not in audit_cols:
|
||||
op.add_column("audit_logs", sa.Column("user_agent", sa.String(500), nullable=True))
|
||||
if "integrity_hash" not in audit_cols:
|
||||
op.add_column("audit_logs", sa.Column("integrity_hash", sa.String(64), nullable=True))
|
||||
if "session_id" not in audit_cols:
|
||||
op.add_column("audit_logs", sa.Column("session_id", sa.String(100), nullable=True))
|
||||
|
||||
for table in ("tests", "evidences", "campaigns"):
|
||||
cols = _column_names(table)
|
||||
if "data_classification" not in cols:
|
||||
op.add_column(
|
||||
table,
|
||||
sa.Column(
|
||||
"data_classification",
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default="internal",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ("campaigns", "evidences", "tests"):
|
||||
cols = _column_names(table)
|
||||
if "data_classification" in cols:
|
||||
op.drop_column(table, "data_classification")
|
||||
|
||||
audit_cols = _column_names("audit_logs")
|
||||
for col in ("session_id", "integrity_hash", "user_agent", "ip_address"):
|
||||
if col in audit_cols:
|
||||
op.drop_column("audit_logs", col)
|
||||
@@ -0,0 +1,117 @@
|
||||
"""Phase 5: scoring recency/severity columns and snapshot breakdown fields.
|
||||
|
||||
Revision ID: b030phase5
|
||||
Revises: b029phase3
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "b030phase5"
|
||||
down_revision: Union[str, None] = "b029phase3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _column_names(table: str) -> set[str]:
|
||||
bind = op.get_bind()
|
||||
insp = sa.inspect(bind)
|
||||
return {c["name"] for c in insp.get_columns(table)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
snap_cols = _column_names("coverage_snapshots")
|
||||
if "by_tactic" not in snap_cols:
|
||||
op.add_column(
|
||||
"coverage_snapshots",
|
||||
sa.Column("by_tactic", postgresql.JSONB(), nullable=False, server_default="{}"),
|
||||
)
|
||||
if "by_status" not in snap_cols:
|
||||
op.add_column(
|
||||
"coverage_snapshots",
|
||||
sa.Column("by_status", postgresql.JSONB(), nullable=False, server_default="{}"),
|
||||
)
|
||||
if "stale_count" not in snap_cols:
|
||||
op.add_column(
|
||||
"coverage_snapshots",
|
||||
sa.Column("stale_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
if "never_tested_count" not in snap_cols:
|
||||
op.add_column(
|
||||
"coverage_snapshots",
|
||||
sa.Column("never_tested_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
if "coverage_percentage" not in snap_cols:
|
||||
op.add_column(
|
||||
"coverage_snapshots",
|
||||
sa.Column("coverage_percentage", sa.Float(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
cfg_cols = _column_names("scoring_config")
|
||||
if "weight_recency" not in cfg_cols and "weight_freshness" in cfg_cols:
|
||||
op.alter_column(
|
||||
"scoring_config",
|
||||
"weight_freshness",
|
||||
new_column_name="weight_recency",
|
||||
)
|
||||
cfg_cols.remove("weight_freshness")
|
||||
cfg_cols.add("weight_recency")
|
||||
elif "weight_recency" not in cfg_cols:
|
||||
op.add_column(
|
||||
"scoring_config",
|
||||
sa.Column("weight_recency", sa.Float(), nullable=False, server_default="10.0"),
|
||||
)
|
||||
|
||||
if "weight_severity" not in cfg_cols and "weight_platform_diversity" in cfg_cols:
|
||||
op.alter_column(
|
||||
"scoring_config",
|
||||
"weight_platform_diversity",
|
||||
new_column_name="weight_severity",
|
||||
)
|
||||
elif "weight_severity" not in cfg_cols:
|
||||
op.add_column(
|
||||
"scoring_config",
|
||||
sa.Column("weight_severity", sa.Float(), nullable=False, server_default="10.0"),
|
||||
)
|
||||
|
||||
if "updated_by" not in cfg_cols:
|
||||
op.add_column(
|
||||
"scoring_config",
|
||||
sa.Column(
|
||||
"updated_by",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
cfg_cols = _column_names("scoring_config")
|
||||
if "updated_by" in cfg_cols:
|
||||
op.drop_column("scoring_config", "updated_by")
|
||||
if "weight_severity" in cfg_cols:
|
||||
op.alter_column(
|
||||
"scoring_config",
|
||||
"weight_severity",
|
||||
new_column_name="weight_platform_diversity",
|
||||
)
|
||||
if "weight_recency" in cfg_cols:
|
||||
op.alter_column(
|
||||
"scoring_config",
|
||||
"weight_recency",
|
||||
new_column_name="weight_freshness",
|
||||
)
|
||||
|
||||
for col in (
|
||||
"coverage_percentage",
|
||||
"never_tested_count",
|
||||
"stale_count",
|
||||
"by_status",
|
||||
"by_tactic",
|
||||
):
|
||||
if col in _column_names("coverage_snapshots"):
|
||||
op.drop_column("coverage_snapshots", col)
|
||||
@@ -0,0 +1 @@
|
||||
"""Aegis — MITRE ATT&CK Coverage Platform application package."""
|
||||
|
||||
+95
-7
@@ -1,20 +1,34 @@
|
||||
"""
|
||||
Security utilities: password hashing and JWT token management.
|
||||
"""Security utilities: password hashing and JWT token management.
|
||||
|
||||
This module provides pure functions for:
|
||||
- Hashing and verifying passwords using bcrypt via passlib.
|
||||
- Creating JWT access tokens using python-jose.
|
||||
- Creating JWT access tokens using PyJWT.
|
||||
- Managing a Redis-backed token blacklist for revocation.
|
||||
|
||||
No endpoints are defined here.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import uuid
|
||||
import uuid as _uuid
|
||||
|
||||
# Import datetime, timedelta, timezone from datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import jwt
|
||||
# Import jwt (PyJWT)
|
||||
import jwt
|
||||
|
||||
# Import CryptContext from passlib.context
|
||||
from passlib.context import CryptContext
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Password hashing
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -22,13 +36,17 @@ from app.config import settings
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
# Define function hash_password
|
||||
def hash_password(password: str) -> str:
|
||||
"""Return a bcrypt hash of *password*."""
|
||||
# Return pwd_context.hash(password)
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
# Define function verify_password
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
"""Return ``True`` if *plain* matches the bcrypt *hashed* value."""
|
||||
# Return pwd_context.verify(plain, hashed)
|
||||
return pwd_context.verify(plain, hashed)
|
||||
|
||||
|
||||
@@ -38,13 +56,83 @@ def verify_password(plain: str, hashed: str) -> bool:
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
|
||||
"""Create a signed JWT containing *data* plus an ``exp`` claim.
|
||||
"""Create a signed JWT containing *data* plus ``exp`` and ``jti`` claims.
|
||||
|
||||
The token expires after ``ACCESS_TOKEN_EXPIRE_MINUTES`` (from settings).
|
||||
- ``jti`` (JWT ID): unique identifier that enables token revocation.
|
||||
- ``exp``: expiration timestamp based on ``ACCESS_TOKEN_EXPIRE_MINUTES``.
|
||||
"""
|
||||
# Assign to_encode = data.copy()
|
||||
to_encode = data.copy()
|
||||
# Assign expire = datetime.now(timezone.utc) + timedelta(
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
# Keyword argument: minutes
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
)
|
||||
to_encode.update({"exp": expire})
|
||||
# Call to_encode.update()
|
||||
to_encode.update({
|
||||
# Literal argument value
|
||||
"exp": expire,
|
||||
# Literal argument value
|
||||
"jti": str(_uuid.uuid4()),
|
||||
})
|
||||
# Return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGOR...
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token blacklist (Redis-backed)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Each revoked token's ``jti`` is stored in Redis with a TTL equal to the
|
||||
# token's remaining lifetime. This means entries auto-expire exactly when
|
||||
# the token would have become invalid anyway — no manual cleanup needed.
|
||||
#
|
||||
# Redis survives backend restarts, so blacklisted tokens stay revoked
|
||||
# across deploys and multi-worker setups.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_BLACKLIST_PREFIX = "blacklist:"
|
||||
|
||||
|
||||
# Define function blacklist_token
|
||||
def blacklist_token(jti: str, exp: float) -> None:
|
||||
"""Add *jti* to the Redis blacklist with a TTL derived from *exp*.
|
||||
|
||||
*exp* is the token's ``exp`` claim (epoch timestamp). The TTL is set
|
||||
to ``exp - now`` so the key vanishes when the token would have expired
|
||||
naturally.
|
||||
"""
|
||||
# Import get_redis_blacklist from app.infrastructure.redis_client
|
||||
from app.infrastructure.redis_client import get_redis_blacklist
|
||||
|
||||
# Assign ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
|
||||
ttl = max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign r = get_redis_blacklist()
|
||||
r = get_redis_blacklist()
|
||||
# Call r.setex()
|
||||
r.setex(f"{_BLACKLIST_PREFIX}{jti}", ttl, "1")
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log warning: "Failed to blacklist token %s in Redis", jti, exc_
|
||||
logger.warning("Failed to blacklist token %s in Redis", jti, exc_info=True)
|
||||
|
||||
|
||||
# Define function is_token_blacklisted
|
||||
def is_token_blacklisted(jti: str) -> bool:
|
||||
"""Return ``True`` if *jti* has been revoked (exists in Redis)."""
|
||||
# Import get_redis_blacklist from app.infrastructure.redis_client
|
||||
from app.infrastructure.redis_client import get_redis_blacklist
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign r = get_redis_blacklist()
|
||||
r = get_redis_blacklist()
|
||||
# Return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0
|
||||
return r.exists(f"{_BLACKLIST_PREFIX}{jti}") > 0
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log warning: "Failed to check blacklist for %s in Redis", jti,
|
||||
logger.warning("Failed to check blacklist for %s in Redis", jti, exc_info=True)
|
||||
# Return False
|
||||
return False
|
||||
|
||||
+181
-12
@@ -1,28 +1,197 @@
|
||||
"""Application configuration for the Aegis MITRE ATT&CK Coverage Platform.
|
||||
|
||||
Loads settings from environment variables and ``.env`` files via
|
||||
``pydantic-settings``. Validates critical secrets at import time and raises
|
||||
``RuntimeError`` (production) or issues a ``UserWarning`` (development) when
|
||||
unsafe defaults are detected.
|
||||
"""
|
||||
|
||||
# Import os
|
||||
import os
|
||||
|
||||
# Import secrets
|
||||
import secrets
|
||||
|
||||
# Import warnings
|
||||
import warnings
|
||||
|
||||
# Import BaseSettings from pydantic_settings
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detect environment: "production" when AEGIS_ENV or common indicators are set
|
||||
# ---------------------------------------------------------------------------
|
||||
_is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
|
||||
|
||||
# Define class Settings
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql://postgres:postgres@postgres:5432/attackdb"
|
||||
SECRET_KEY: str = "change-me-in-production"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
|
||||
MINIO_ENDPOINT: str = "minio:9000"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "evidence"
|
||||
"""Application settings loaded from environment variables and .env file."""
|
||||
|
||||
# Re-testing
|
||||
# Assign DATABASE_URL = "postgresql://postgres:postgres@postgres:5432/attackdb"
|
||||
DATABASE_URL: str = "postgresql://postgres:postgres@postgres:5432/attackdb"
|
||||
|
||||
# ── Security ──────────────────────────────────────────────────────
|
||||
# SECRET_KEY has NO safe default. In development a random key is
|
||||
# generated at startup (tokens invalidate on restart — acceptable
|
||||
# for local dev). In production it MUST be supplied via env/.env
|
||||
# so tokens survive restarts.
|
||||
SECRET_KEY: str = ""
|
||||
# Assign ALGORITHM = "HS256"
|
||||
ALGORITHM: str = "HS256"
|
||||
# Assign ACCESS_TOKEN_EXPIRE_MINUTES = 15 # short-lived for security; configurable via env
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # short-lived for security; configurable via env
|
||||
|
||||
# ── Redis ─────────────────────────────────────────────────────────
|
||||
REDIS_URL: str = "redis://redis:6379/0"
|
||||
# Logical DB indices on the same Redis instance (PATH in URL is overridden).
|
||||
REDIS_TOKEN_BLACKLIST_DB: int = 1
|
||||
# Assign REDIS_CACHE_DB = 2
|
||||
REDIS_CACHE_DB: int = 2
|
||||
|
||||
# ── CORS ─────────────────────────────────────────────────────────
|
||||
# Comma-separated list of allowed origins, or a JSON array.
|
||||
# In dev this defaults to common local ports; in production set it
|
||||
# to the actual frontend domain(s).
|
||||
CORS_ORIGINS: str = "http://localhost:3000,http://localhost:5173"
|
||||
|
||||
# ── MinIO / S3 ───────────────────────────────────────────────────
|
||||
MINIO_ENDPOINT: str = "minio:9000"
|
||||
# Assign MINIO_ACCESS_KEY = "minioadmin"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
# Assign MINIO_SECRET_KEY = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
# Assign MINIO_BUCKET = "evidence"
|
||||
MINIO_BUCKET: str = "evidence"
|
||||
# Assign MINIO_SECURE = False # True → use HTTPS to connect to MinIO
|
||||
MINIO_SECURE: bool = False # True → use HTTPS to connect to MinIO
|
||||
|
||||
# ── Re-testing ───────────────────────────────────────────────────
|
||||
MAX_RETEST_COUNT: int = 3 # maximum automatic retests per original test
|
||||
|
||||
# Scoring weights (must sum to 100)
|
||||
# ── Jira Integration ────────────────────────────────────────────
|
||||
JIRA_ENABLED: bool = False
|
||||
# Assign JIRA_URL = ""
|
||||
JIRA_URL: str = ""
|
||||
# Assign JIRA_USERNAME = ""
|
||||
JIRA_USERNAME: str = ""
|
||||
# Assign JIRA_API_TOKEN = ""
|
||||
JIRA_API_TOKEN: str = ""
|
||||
# Assign JIRA_IS_CLOUD = True
|
||||
JIRA_IS_CLOUD: bool = True
|
||||
# Assign JIRA_DEFAULT_PROJECT = ""
|
||||
JIRA_DEFAULT_PROJECT: str = ""
|
||||
# Assign JIRA_ISSUE_TYPE_TEST = "Task"
|
||||
JIRA_ISSUE_TYPE_TEST: str = "Task"
|
||||
# Assign JIRA_ISSUE_TYPE_CAMPAIGN = "Epic"
|
||||
JIRA_ISSUE_TYPE_CAMPAIGN: str = "Epic"
|
||||
|
||||
# ── Tempo Integration ─────────────────────────────────────────────
|
||||
TEMPO_ENABLED: bool = False
|
||||
# Assign TEMPO_API_TOKEN = ""
|
||||
TEMPO_API_TOKEN: str = ""
|
||||
# Assign TEMPO_API_VERSION = 4
|
||||
TEMPO_API_VERSION: int = 4
|
||||
# Assign TEMPO_DEFAULT_WORK_TYPE = "Red Team"
|
||||
TEMPO_DEFAULT_WORK_TYPE: str = "Red Team"
|
||||
|
||||
# ── OSINT / Intelligence ────────────────────────────────────────
|
||||
NVD_API_KEY: str = "" # optional; increases NVD rate limit from 5/30s to 50/30s
|
||||
# Assign STALE_THRESHOLD_DAYS = 365 # days before coverage is considered stale
|
||||
STALE_THRESHOLD_DAYS: int = 365 # days before coverage is considered stale
|
||||
|
||||
# ── Reporting ─────────────────────────────────────────────────────
|
||||
REPORT_TEMPLATES_DIR: str = "app/templates/reports"
|
||||
# Assign REPORT_OUTPUT_DIR = "/tmp/aegis_reports"
|
||||
REPORT_OUTPUT_DIR: str = "/tmp/aegis_reports"
|
||||
# Assign COMPANY_NAME = "Organization"
|
||||
COMPANY_NAME: str = "Organization"
|
||||
# Assign COMPANY_LOGO_PATH = "app/templates/reports/assets/logo.png"
|
||||
COMPANY_LOGO_PATH: str = "app/templates/reports/assets/logo.png"
|
||||
|
||||
# ── Scoring weights (must sum to 100) ────────────────────────────
|
||||
SCORING_WEIGHT_TESTS: int = 40
|
||||
SCORING_WEIGHT_DETECTION_RULES: int = 20
|
||||
# Assign SCORING_WEIGHT_DETECTION_RULES = 25
|
||||
SCORING_WEIGHT_DETECTION_RULES: int = 25
|
||||
# Assign SCORING_WEIGHT_D3FEND = 15
|
||||
SCORING_WEIGHT_D3FEND: int = 15
|
||||
SCORING_WEIGHT_FRESHNESS: int = 15
|
||||
# Assign SCORING_WEIGHT_RECENCY = 10
|
||||
SCORING_WEIGHT_RECENCY: int = 10
|
||||
# Assign SCORING_WEIGHT_SEVERITY = 10
|
||||
SCORING_WEIGHT_SEVERITY: int = 10
|
||||
# Legacy env names (mapped in scoring_config_service)
|
||||
SCORING_WEIGHT_FRESHNESS: int = 10
|
||||
# Assign SCORING_WEIGHT_PLATFORM_DIVERSITY = 10
|
||||
SCORING_WEIGHT_PLATFORM_DIVERSITY: int = 10
|
||||
|
||||
# Define class Config
|
||||
class Config:
|
||||
"""Pydantic BaseSettings configuration — load from .env file."""
|
||||
|
||||
# Assign env_file = ".env"
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
# Assign settings = Settings()
|
||||
settings = Settings()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Post-init validation for SECRET_KEY
|
||||
# ---------------------------------------------------------------------------
|
||||
_UNSAFE_SECRETS = {
|
||||
# Literal argument value
|
||||
"",
|
||||
# Literal argument value
|
||||
"change-me-in-production",
|
||||
# Literal argument value
|
||||
"change-me-in-production-use-a-long-random-string",
|
||||
}
|
||||
|
||||
# Check: settings.SECRET_KEY in _UNSAFE_SECRETS
|
||||
if settings.SECRET_KEY in _UNSAFE_SECRETS:
|
||||
# Check: _is_production
|
||||
if _is_production:
|
||||
# Raise RuntimeError
|
||||
raise RuntimeError(
|
||||
# Literal argument value
|
||||
"CRITICAL: SECRET_KEY is not configured. "
|
||||
# Literal argument value
|
||||
"Set a strong random value (>= 32 chars) via the SECRET_KEY "
|
||||
# Literal argument value
|
||||
"environment variable or in your .env file before running in "
|
||||
# Literal argument value
|
||||
"production. Example: openssl rand -hex 32"
|
||||
)
|
||||
# Development: auto-generate an ephemeral key and warn
|
||||
settings.SECRET_KEY = secrets.token_hex(32)
|
||||
# Call warnings.warn()
|
||||
warnings.warn(
|
||||
# Literal argument value
|
||||
"SECRET_KEY was not set — using an auto-generated ephemeral key. "
|
||||
# Literal argument value
|
||||
"JWT tokens will be invalidated on every restart. "
|
||||
# Literal argument value
|
||||
"Set SECRET_KEY in your environment for persistent sessions.",
|
||||
# Keyword argument: stacklevel
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SEC-002: Reject default credentials in production
|
||||
# ---------------------------------------------------------------------------
|
||||
if _is_production:
|
||||
# Assign _DEFAULT_CREDS = {
|
||||
_DEFAULT_CREDS = {
|
||||
("MINIO_ACCESS_KEY", settings.MINIO_ACCESS_KEY, "minioadmin"),
|
||||
("MINIO_SECRET_KEY", settings.MINIO_SECRET_KEY, "minioadmin"),
|
||||
}
|
||||
# Iterate over _DEFAULT_CREDS
|
||||
for name, current, default in _DEFAULT_CREDS:
|
||||
# Check: current == default
|
||||
if current == default:
|
||||
# Raise RuntimeError
|
||||
raise RuntimeError(
|
||||
f"CRITICAL: {name} is using the default value '{default}'. "
|
||||
f"Set a strong value via the {name} environment variable "
|
||||
f"before running in production."
|
||||
)
|
||||
|
||||
+117
-11
@@ -1,58 +1,164 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
"""Database engine and session management for the Aegis platform.
|
||||
|
||||
The engine and session factory are created lazily so that tests can override
|
||||
``DATABASE_URL`` via environment variables before any import triggers real
|
||||
PostgreSQL engine creation (which requires psycopg2).
|
||||
"""
|
||||
|
||||
# Import Generator from collections.abc
|
||||
from collections.abc import Generator
|
||||
|
||||
# Import create_engine from sqlalchemy
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
# Import Engine from sqlalchemy.engine
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
# Import Session, declarative_base, sessionmaker from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||
|
||||
# Assign Base = declarative_base()
|
||||
Base = declarative_base()
|
||||
|
||||
# Engine and session factory are created lazily so that tests can
|
||||
# override DATABASE_URL via environment *before* any import triggers
|
||||
# the real PostgreSQL engine creation (which requires psycopg2).
|
||||
_engine = None
|
||||
# Assign _SessionLocal = None
|
||||
_SessionLocal = None
|
||||
|
||||
|
||||
def _get_engine():
|
||||
# Define function _get_engine
|
||||
def _get_engine() -> Engine:
|
||||
"""Return the shared SQLAlchemy engine, creating it on first call.
|
||||
|
||||
Returns:
|
||||
Engine: Configured SQLAlchemy engine for the application database.
|
||||
"""
|
||||
# Declare global variable
|
||||
global _engine
|
||||
# Check: _engine is None
|
||||
if _engine is None:
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
_engine = create_engine(settings.DATABASE_URL)
|
||||
|
||||
# Assign url = settings.DATABASE_URL
|
||||
url = settings.DATABASE_URL
|
||||
# Assign kwargs = {}
|
||||
kwargs: dict = {}
|
||||
# Check: url.startswith("postgresql")
|
||||
if url.startswith("postgresql"):
|
||||
# Call kwargs.update()
|
||||
kwargs.update(
|
||||
# Keyword argument: pool_size
|
||||
pool_size=20,
|
||||
# Keyword argument: max_overflow
|
||||
max_overflow=10,
|
||||
# Keyword argument: pool_recycle
|
||||
pool_recycle=3600,
|
||||
# Keyword argument: pool_pre_ping
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
# Assign _engine = create_engine(url, **kwargs)
|
||||
_engine = create_engine(url, **kwargs)
|
||||
# Return _engine
|
||||
return _engine
|
||||
|
||||
|
||||
def _get_session_factory():
|
||||
# Define function _get_session_factory
|
||||
def _get_session_factory() -> sessionmaker:
|
||||
"""Return the shared sessionmaker, creating it on first call.
|
||||
|
||||
Returns:
|
||||
sessionmaker: Configured sessionmaker bound to the application engine.
|
||||
"""
|
||||
# Declare global variable
|
||||
global _SessionLocal
|
||||
# Check: _SessionLocal is None
|
||||
if _SessionLocal is None:
|
||||
# Assign _SessionLocal = sessionmaker(
|
||||
_SessionLocal = sessionmaker(
|
||||
# Keyword argument: autocommit
|
||||
autocommit=False, autoflush=False, bind=_get_engine()
|
||||
)
|
||||
# Return _SessionLocal
|
||||
return _SessionLocal
|
||||
|
||||
|
||||
# Define class _LazySessionLocal
|
||||
class _LazySessionLocal:
|
||||
"""Proxy so ``SessionLocal()`` keeps working as before but the real
|
||||
sessionmaker is only created on first call."""
|
||||
"""Proxy so ``SessionLocal()`` keeps working as before but the real sessionmaker is only created on first call."""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# Define function __call__
|
||||
def __call__(self, *args: object, **kwargs: object) -> Session:
|
||||
"""Create and return a new database session.
|
||||
|
||||
Args:
|
||||
*args (object): Positional arguments forwarded to the sessionmaker.
|
||||
**kwargs (object): Keyword arguments forwarded to the sessionmaker.
|
||||
|
||||
Returns:
|
||||
Session: A new SQLAlchemy database session.
|
||||
"""
|
||||
# Return _get_session_factory()(*args, **kwargs)
|
||||
return _get_session_factory()(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Define function __getattr__
|
||||
def __getattr__(self, name: str) -> object:
|
||||
"""Delegate attribute access to the underlying sessionmaker.
|
||||
|
||||
Args:
|
||||
name (str): Attribute name to look up on the sessionmaker.
|
||||
|
||||
Returns:
|
||||
object: The attribute value from the underlying sessionmaker.
|
||||
"""
|
||||
# Return getattr(_get_session_factory(), name)
|
||||
return getattr(_get_session_factory(), name)
|
||||
|
||||
|
||||
# Assign SessionLocal = _LazySessionLocal()
|
||||
SessionLocal = _LazySessionLocal()
|
||||
|
||||
|
||||
# Define class _EngineProxy
|
||||
class _EngineProxy:
|
||||
"""Thin proxy so ``from app.database import engine`` still works."""
|
||||
def __getattr__(self, name):
|
||||
|
||||
# Define function __getattr__
|
||||
def __getattr__(self, name: str) -> object:
|
||||
"""Delegate attribute access to the lazily-created engine.
|
||||
|
||||
Args:
|
||||
name (str): Attribute name to look up on the real engine.
|
||||
|
||||
Returns:
|
||||
object: The attribute value from the underlying SQLAlchemy engine.
|
||||
"""
|
||||
# Return getattr(_get_engine(), name)
|
||||
return getattr(_get_engine(), name)
|
||||
|
||||
|
||||
# Assign engine = _EngineProxy() # type: ignore[assignment]
|
||||
engine = _EngineProxy() # type: ignore[assignment]
|
||||
|
||||
|
||||
def get_db():
|
||||
# Define function get_db
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""Yield a database session and close it when the request is done.
|
||||
|
||||
Intended for use as a FastAPI dependency.
|
||||
|
||||
Yields:
|
||||
Session: An active SQLAlchemy session for the current request.
|
||||
"""
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Yield db
|
||||
yield db
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""FastAPI dependency injection helpers for auth, DB, and shared state."""
|
||||
|
||||
@@ -1,26 +1,50 @@
|
||||
"""
|
||||
Authentication and RBAC dependencies for FastAPI.
|
||||
"""Authentication and RBAC dependencies for FastAPI.
|
||||
|
||||
Provides:
|
||||
- ``get_current_user``: decodes JWT, fetches user from DB, raises 401 on failure.
|
||||
- ``get_current_user``: decodes JWT from HttpOnly cookie (preferred) or
|
||||
Authorization header (fallback), fetches user from DB, raises 401 on failure.
|
||||
- ``require_role``: factory that returns a dependency enforcing a specific role
|
||||
(admins always pass).
|
||||
"""
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
# Import Callable from collections.abc
|
||||
from collections.abc import Callable
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import Cookie, Depends, HTTPException, status from fastapi
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
|
||||
# Import OAuth2PasswordBearer from fastapi.security
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
|
||||
# Import jwt (PyJWT)
|
||||
import jwt
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import auth as auth_lib from app
|
||||
from app import auth as auth_lib
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth2 scheme
|
||||
# OAuth2 scheme (reads Authorization header — used as fallback / Swagger UI)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False)
|
||||
|
||||
# Cookie name — must match the one set in the auth router
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Current-user dependency
|
||||
@@ -28,38 +52,86 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
# Entry: aegis_token
|
||||
aegis_token: Optional[str] = Cookie(None),
|
||||
# Entry: bearer_token
|
||||
bearer_token: Optional[str] = Depends(oauth2_scheme),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Decode the JWT *token*, look up the user in *db*, and return it.
|
||||
"""Decode the JWT, look up the user in *db*, and return it.
|
||||
|
||||
Token resolution order:
|
||||
1. ``aegis_token`` **HttpOnly cookie** (preferred — immune to XSS).
|
||||
2. ``Authorization: Bearer <token>`` header (fallback for API clients
|
||||
and Swagger UI).
|
||||
|
||||
Raises :class:`~fastapi.HTTPException` **401** when:
|
||||
- no token is found in either location,
|
||||
- the token cannot be decoded,
|
||||
- the ``sub`` claim is missing, or
|
||||
- no matching active user exists in the database.
|
||||
"""
|
||||
# Assign credentials_exception = HTTPException(
|
||||
credentials_exception = HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
# Keyword argument: detail
|
||||
detail="Could not validate credentials",
|
||||
# Keyword argument: headers
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
# Assign revoked_exception = HTTPException(
|
||||
revoked_exception = HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
# Keyword argument: detail
|
||||
detail="Token has been revoked",
|
||||
# Keyword argument: headers
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Prefer cookie, fall back to header
|
||||
token = aegis_token or bearer_token
|
||||
# Check: token is None
|
||||
if token is None:
|
||||
# Raise credentials_exception
|
||||
raise credentials_exception
|
||||
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign payload = jwt.decode(
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
# Keyword argument: algorithms
|
||||
algorithms=[settings.ALGORITHM],
|
||||
)
|
||||
# Assign username = payload.get("sub")
|
||||
username: str | None = payload.get("sub")
|
||||
# Check: username is None
|
||||
if username is None:
|
||||
# Raise credentials_exception
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
# Check token blacklist (revoked tokens)
|
||||
jti: str | None = payload.get("jti")
|
||||
# Check: jti and auth_lib.is_token_blacklisted(jti)
|
||||
if jti and auth_lib.is_token_blacklisted(jti):
|
||||
# Raise revoked_exception
|
||||
raise revoked_exception
|
||||
# Handle any JWT validation error (expired, invalid signature, malformed)
|
||||
except jwt.exceptions.InvalidTokenError:
|
||||
# Raise credentials_exception
|
||||
raise credentials_exception
|
||||
|
||||
# Assign user = db.query(User).filter(User.username == username).first()
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if user is None:
|
||||
# Check: user is None or not user.is_active
|
||||
if user is None or not user.is_active:
|
||||
# Raise credentials_exception
|
||||
raise credentials_exception
|
||||
|
||||
# Return user
|
||||
return user
|
||||
|
||||
|
||||
@@ -68,7 +140,30 @@ async def get_current_user(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def require_role(required_role: str):
|
||||
async def require_password_changed(
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
"""Block all requests when the user still needs to change their password.
|
||||
|
||||
Only ``/auth/change-password`` and ``/auth/me`` are exempt — those
|
||||
endpoints do **not** depend on this function.
|
||||
"""
|
||||
# Check: getattr(current_user, "must_change_password", False)
|
||||
if getattr(current_user, "must_change_password", False):
|
||||
# Raise HTTPException
|
||||
raise HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
# Keyword argument: detail
|
||||
detail="PASSWORD_CHANGE_REQUIRED",
|
||||
)
|
||||
# Return current_user
|
||||
return current_user
|
||||
|
||||
|
||||
# Define function require_role
|
||||
def require_role(required_role: str) -> Callable[..., object]:
|
||||
"""Return a FastAPI dependency that enforces *required_role*.
|
||||
|
||||
The dependency allows the request to proceed when
|
||||
@@ -76,20 +171,29 @@ def require_role(required_role: str):
|
||||
Otherwise it raises :class:`~fastapi.HTTPException` **403**.
|
||||
"""
|
||||
|
||||
# Define async function role_checker
|
||||
async def role_checker(
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
# Check: current_user.role != required_role and current_user.role != "admin"
|
||||
if current_user.role != required_role and current_user.role != "admin":
|
||||
# Raise HTTPException
|
||||
raise HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
# Keyword argument: detail
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
# Return current_user
|
||||
return current_user
|
||||
|
||||
# Return role_checker
|
||||
return role_checker
|
||||
|
||||
|
||||
def require_any_role(*roles: str):
|
||||
# Define function require_any_role
|
||||
def require_any_role(*roles: str) -> Callable[..., object]:
|
||||
"""Return a FastAPI dependency that enforces **any** of the given *roles*.
|
||||
|
||||
Admins always pass. Usage example::
|
||||
@@ -97,14 +201,22 @@ def require_any_role(*roles: str):
|
||||
@router.patch("/resource", dependencies=[Depends(require_any_role("red_lead", "blue_lead"))])
|
||||
"""
|
||||
|
||||
# Define async function role_checker
|
||||
async def role_checker(
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
# Check: current_user.role != "admin" and current_user.role not in roles
|
||||
if current_user.role != "admin" and current_user.role not in roles:
|
||||
# Raise HTTPException
|
||||
raise HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
# Keyword argument: detail
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
# Return current_user
|
||||
return current_user
|
||||
|
||||
# Return role_checker
|
||||
return role_checker
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
"""FastAPI dependency providers for repositories.
|
||||
|
||||
Wiring lives ONLY in the presentation layer — use cases and services
|
||||
never know which concrete repository implementation they receive.
|
||||
"""
|
||||
|
||||
# Import Depends from fastapi
|
||||
from fastapi import Depends
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
|
||||
# Import from app.infrastructure.persistence.repositories.sa_test_repository
|
||||
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||
SATestRepository,
|
||||
)
|
||||
|
||||
|
||||
# Define function get_technique_repository
|
||||
def get_technique_repository(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
) -> SATechniqueRepository:
|
||||
"""Provide a TechniqueRepository backed by the current DB session."""
|
||||
# Return SATechniqueRepository(db)
|
||||
return SATechniqueRepository(db)
|
||||
|
||||
|
||||
# Define function get_test_repository
|
||||
def get_test_repository(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
) -> SATestRepository:
|
||||
"""Provide a TestRepository backed by the current DB session."""
|
||||
# Return SATestRepository(db)
|
||||
return SATestRepository(db)
|
||||
@@ -0,0 +1 @@
|
||||
"""Domain layer — entities, value objects, errors, and repository ports."""
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Domain entity classes representing core business objects."""
|
||||
# Import CampaignEntity from app.domain.entities.campaign
|
||||
from app.domain.entities.campaign import CampaignEntity
|
||||
|
||||
# Import from app.domain.entities.compliance
|
||||
from app.domain.entities.compliance import (
|
||||
ComplianceControlEntity,
|
||||
ComplianceFrameworkEntity,
|
||||
ControlCoverageStatus,
|
||||
)
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
|
||||
# Import ThreatActorEntity, ThreatActorTechniqueRef from app.domain.entities.threat_actor
|
||||
from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef
|
||||
|
||||
# Assign __all__ = [
|
||||
__all__ = [
|
||||
# Literal argument value
|
||||
"CampaignEntity",
|
||||
# Literal argument value
|
||||
"ComplianceControlEntity",
|
||||
# Literal argument value
|
||||
"ComplianceFrameworkEntity",
|
||||
# Literal argument value
|
||||
"ControlCoverageStatus",
|
||||
# Literal argument value
|
||||
"TechniqueEntity",
|
||||
# Literal argument value
|
||||
"ThreatActorEntity",
|
||||
# Literal argument value
|
||||
"ThreatActorTechniqueRef",
|
||||
]
|
||||
@@ -0,0 +1,219 @@
|
||||
"""Campaign domain entity with lifecycle validation.
|
||||
|
||||
Pure domain logic — no framework imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import enum
|
||||
import enum
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Import TYPE_CHECKING from typing
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# Import BusinessRuleViolation, InvalidStateTransition from app.domain.errors
|
||||
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
||||
|
||||
# Check: TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
# Import Campaign as CampaignORM from app.models.campaign
|
||||
from app.models.campaign import Campaign as CampaignORM
|
||||
|
||||
|
||||
# Define class CampaignStatus
|
||||
class CampaignStatus(str, enum.Enum):
|
||||
"""Lifecycle states for a campaign."""
|
||||
|
||||
# Assign draft = "draft"
|
||||
draft = "draft"
|
||||
# Assign active = "active"
|
||||
active = "active"
|
||||
# Assign completed = "completed"
|
||||
completed = "completed"
|
||||
# Assign archived = "archived"
|
||||
archived = "archived"
|
||||
|
||||
|
||||
# Define class CampaignType
|
||||
class CampaignType(str, enum.Enum):
|
||||
"""Classification of the campaign's testing methodology."""
|
||||
|
||||
# Assign custom = "custom"
|
||||
custom = "custom"
|
||||
# Assign apt_emulation = "apt_emulation"
|
||||
apt_emulation = "apt_emulation"
|
||||
# Assign kill_chain = "kill_chain"
|
||||
kill_chain = "kill_chain"
|
||||
# Assign compliance = "compliance"
|
||||
compliance = "compliance"
|
||||
|
||||
|
||||
# Assign VALID_TRANSITIONS = {
|
||||
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
|
||||
CampaignStatus.draft: [CampaignStatus.active],
|
||||
CampaignStatus.active: [CampaignStatus.completed],
|
||||
CampaignStatus.completed: [CampaignStatus.archived],
|
||||
CampaignStatus.archived: [],
|
||||
}
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class CampaignEntity
|
||||
class CampaignEntity:
|
||||
"""Pure domain representation of a security testing campaign.
|
||||
|
||||
Owns all lifecycle state-machine logic for campaign activation,
|
||||
completion, and archival.
|
||||
"""
|
||||
|
||||
# name: str
|
||||
name: str
|
||||
# Assign type = CampaignType.custom
|
||||
type: CampaignType = CampaignType.custom
|
||||
# Assign status = CampaignStatus.draft
|
||||
status: CampaignStatus = CampaignStatus.draft
|
||||
# Assign id = None
|
||||
id: uuid.UUID | None = None
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign threat_actor_id = None
|
||||
threat_actor_id: uuid.UUID | None = None
|
||||
# Assign created_by = None
|
||||
created_by: uuid.UUID | None = None
|
||||
# Assign target_platform = None
|
||||
target_platform: str | None = None
|
||||
# Assign tags = field(default_factory=list)
|
||||
tags: list[str] = field(default_factory=list)
|
||||
# Assign test_count = 0
|
||||
test_count: int = 0
|
||||
|
||||
# Define function can_transition_to
|
||||
def can_transition_to(self, target: CampaignStatus) -> bool:
|
||||
"""Check whether transitioning from the current status to *target* is valid.
|
||||
|
||||
Args:
|
||||
target (CampaignStatus): The desired next status.
|
||||
|
||||
Returns:
|
||||
bool: True if the transition is allowed, False otherwise.
|
||||
"""
|
||||
# Return target in VALID_TRANSITIONS.get(self.status, [])
|
||||
return target in VALID_TRANSITIONS.get(self.status, [])
|
||||
|
||||
# Define function activate
|
||||
def activate(self) -> None:
|
||||
"""Transition the campaign from ``draft`` to ``active``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: not self.can_transition_to(CampaignStatus.active)
|
||||
if not self.can_transition_to(CampaignStatus.active):
|
||||
# Raise InvalidStateTransition
|
||||
raise InvalidStateTransition(
|
||||
self.status.value, CampaignStatus.active.value,
|
||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||
)
|
||||
# Check: self.test_count == 0
|
||||
if self.test_count == 0:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
# Literal argument value
|
||||
"Campaign must have at least one test to activate"
|
||||
)
|
||||
# Assign self.status = CampaignStatus.active
|
||||
self.status = CampaignStatus.active
|
||||
|
||||
# Define function complete
|
||||
def complete(self) -> None:
|
||||
"""Transition the campaign from ``active`` to ``completed``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: not self.can_transition_to(CampaignStatus.completed)
|
||||
if not self.can_transition_to(CampaignStatus.completed):
|
||||
# Raise InvalidStateTransition
|
||||
raise InvalidStateTransition(
|
||||
self.status.value, CampaignStatus.completed.value,
|
||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||
)
|
||||
# Assign self.status = CampaignStatus.completed
|
||||
self.status = CampaignStatus.completed
|
||||
|
||||
# Define function archive
|
||||
def archive(self) -> None:
|
||||
"""Transition the campaign from ``completed`` to ``archived``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: not self.can_transition_to(CampaignStatus.archived)
|
||||
if not self.can_transition_to(CampaignStatus.archived):
|
||||
# Raise InvalidStateTransition
|
||||
raise InvalidStateTransition(
|
||||
self.status.value, CampaignStatus.archived.value,
|
||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||
)
|
||||
# Assign self.status = CampaignStatus.archived
|
||||
self.status = CampaignStatus.archived
|
||||
|
||||
# Define function ensure_modifiable
|
||||
def ensure_modifiable(self) -> None:
|
||||
"""Raise BusinessRuleViolation if the campaign is not in a modifiable state.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: self.status not in (CampaignStatus.draft, CampaignStatus.active)
|
||||
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot modify campaign in '{self.status.value}' state"
|
||||
)
|
||||
|
||||
# Apply the @classmethod decorator
|
||||
@classmethod
|
||||
# Define function from_orm
|
||||
def from_orm(cls, orm: CampaignORM) -> CampaignEntity:
|
||||
"""Build a CampaignEntity from a SQLAlchemy Campaign model.
|
||||
|
||||
Args:
|
||||
orm (CampaignORM): The SQLAlchemy Campaign ORM model instance.
|
||||
|
||||
Returns:
|
||||
CampaignEntity: A fully populated domain entity reflecting the ORM state.
|
||||
"""
|
||||
# Assign test_count = len(getattr(orm, "campaign_tests", None) or [])
|
||||
test_count = len(getattr(orm, "campaign_tests", None) or [])
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=orm.id,
|
||||
# Keyword argument: name
|
||||
name=orm.name,
|
||||
# Keyword argument: type
|
||||
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
|
||||
# Keyword argument: status
|
||||
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
|
||||
# Keyword argument: description
|
||||
description=orm.description,
|
||||
# Keyword argument: threat_actor_id
|
||||
threat_actor_id=orm.threat_actor_id,
|
||||
# Keyword argument: created_by
|
||||
created_by=orm.created_by,
|
||||
# Keyword argument: target_platform
|
||||
target_platform=orm.target_platform,
|
||||
# Keyword argument: tags
|
||||
tags=orm.tags or [],
|
||||
# Keyword argument: test_count
|
||||
test_count=test_count,
|
||||
)
|
||||
@@ -0,0 +1,164 @@
|
||||
"""Compliance domain entities with coverage calculation logic.
|
||||
|
||||
Pure domain logic — no framework imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import enum
|
||||
import enum
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
# Define class ControlCoverageStatus
|
||||
class ControlCoverageStatus(str, enum.Enum):
|
||||
"""Computed coverage level for a single compliance control."""
|
||||
|
||||
# Assign covered = "covered"
|
||||
covered = "covered"
|
||||
# Assign partially_covered = "partially_covered"
|
||||
partially_covered = "partially_covered"
|
||||
# Assign not_covered = "not_covered"
|
||||
not_covered = "not_covered"
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class ComplianceControlEntity
|
||||
class ComplianceControlEntity:
|
||||
"""Pure domain representation of a single compliance framework control.
|
||||
|
||||
Derives its coverage status from the technique statuses associated
|
||||
with it via the ``technique_statuses`` list.
|
||||
"""
|
||||
|
||||
# control_id: str
|
||||
control_id: str
|
||||
# title: str
|
||||
title: str
|
||||
# Assign id = None
|
||||
id: uuid.UUID | None = None
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign category = None
|
||||
category: str | None = None
|
||||
# Assign technique_statuses = field(default_factory=list)
|
||||
technique_statuses: list[str] = field(default_factory=list)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function coverage_status
|
||||
def coverage_status(self) -> ControlCoverageStatus:
|
||||
"""Compute the coverage status for this control based on linked technique statuses.
|
||||
|
||||
Returns:
|
||||
ControlCoverageStatus: ``covered`` when all techniques are covered,
|
||||
``partially_covered`` when at least one is covered, and
|
||||
``not_covered`` when none are covered or the control has no techniques.
|
||||
"""
|
||||
# Check: not self.technique_statuses
|
||||
if not self.technique_statuses:
|
||||
# Return ControlCoverageStatus.not_covered
|
||||
return ControlCoverageStatus.not_covered
|
||||
# Assign covered_statuses = {"validated", "partial"}
|
||||
covered_statuses = {"validated", "partial"}
|
||||
# Assign covered = [s for s in self.technique_statuses if s in covered_statuses]
|
||||
covered = [s for s in self.technique_statuses if s in covered_statuses]
|
||||
# Check: len(covered) == len(self.technique_statuses)
|
||||
if len(covered) == len(self.technique_statuses):
|
||||
# Return ControlCoverageStatus.covered
|
||||
return ControlCoverageStatus.covered
|
||||
# Alternative: len(covered) > 0
|
||||
elif len(covered) > 0:
|
||||
# Return ControlCoverageStatus.partially_covered
|
||||
return ControlCoverageStatus.partially_covered
|
||||
# Return ControlCoverageStatus.not_covered
|
||||
return ControlCoverageStatus.not_covered
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class ComplianceFrameworkEntity
|
||||
class ComplianceFrameworkEntity:
|
||||
"""Pure domain representation of a compliance framework (e.g. NIST 800-53, PCI-DSS).
|
||||
|
||||
Aggregates a collection of controls and provides aggregate coverage statistics.
|
||||
"""
|
||||
|
||||
# name: str
|
||||
name: str
|
||||
# Assign id = None
|
||||
id: uuid.UUID | None = None
|
||||
# Assign version = None
|
||||
version: str | None = None
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign is_active = True
|
||||
is_active: bool = True
|
||||
# Assign controls = field(default_factory=list)
|
||||
controls: list[ComplianceControlEntity] = field(default_factory=list)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function total_controls
|
||||
def total_controls(self) -> int:
|
||||
"""Return the total number of controls in this framework.
|
||||
|
||||
Returns:
|
||||
int: Count of all controls regardless of coverage status.
|
||||
"""
|
||||
# Return len(self.controls)
|
||||
return len(self.controls)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function covered_controls
|
||||
def covered_controls(self) -> int:
|
||||
"""Return the number of fully covered controls in this framework.
|
||||
|
||||
Returns:
|
||||
int: Count of controls with ``ControlCoverageStatus.covered`` status.
|
||||
"""
|
||||
# Return sum(
|
||||
return sum(
|
||||
# Literal argument value
|
||||
1 for c in self.controls
|
||||
if c.coverage_status == ControlCoverageStatus.covered
|
||||
)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function coverage_pct
|
||||
def coverage_pct(self) -> float:
|
||||
"""Return the percentage of controls that are fully covered.
|
||||
|
||||
Returns:
|
||||
float: A value from 0.0 to 100.0, rounded to one decimal place.
|
||||
Returns 0.0 when the framework has no controls.
|
||||
"""
|
||||
# Check: self.total_controls == 0
|
||||
if self.total_controls == 0:
|
||||
# Return 0.0
|
||||
return 0.0
|
||||
# Return round(self.covered_controls / self.total_controls * 100, 1)
|
||||
return round(self.covered_controls / self.total_controls * 100, 1)
|
||||
|
||||
# Define function get_gap_controls
|
||||
def get_gap_controls(self) -> list[ComplianceControlEntity]:
|
||||
"""Return controls that are not fully covered.
|
||||
|
||||
Returns:
|
||||
list[ComplianceControlEntity]: Controls with ``partially_covered`` or
|
||||
``not_covered`` status.
|
||||
"""
|
||||
# Return [
|
||||
return [
|
||||
c for c in self.controls
|
||||
if c.coverage_status != ControlCoverageStatus.covered
|
||||
]
|
||||
@@ -0,0 +1,310 @@
|
||||
"""TechniqueEntity — pure domain object for a MITRE ATT&CK technique.
|
||||
|
||||
Owns the status recalculation logic that was previously in
|
||||
``status_service.py``. Has **no** dependency on FastAPI, SQLAlchemy,
|
||||
or any infrastructure concern.
|
||||
|
||||
Usage::
|
||||
|
||||
entity = TechniqueEntity.from_orm(technique_orm_model)
|
||||
entity.recalculate_status(test_states_and_results)
|
||||
entity.mark_reviewed()
|
||||
entity.apply_to(technique_orm_model)
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import TYPE_CHECKING from typing
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# Import TechniqueStatus, TestResult, TestState from app.domain.enums
|
||||
from app.domain.enums import TechniqueStatus, TestResult, TestState
|
||||
|
||||
# Import MitreId from app.domain.value_objects.mitre_id
|
||||
from app.domain.value_objects.mitre_id import MitreId
|
||||
|
||||
# Check: TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
# Import Technique as TechniqueORM from app.models.technique
|
||||
from app.models.technique import Technique as TechniqueORM
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass(frozen=True)
|
||||
# Define class _TestSnapshot
|
||||
class _TestSnapshot:
|
||||
"""Minimal read-only view of a test for status calculation."""
|
||||
|
||||
# state: TestState
|
||||
state: TestState
|
||||
# detection_result: str | None
|
||||
detection_result: str | None
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class TechniqueEntity
|
||||
class TechniqueEntity:
|
||||
"""Pure domain representation of a MITRE ATT&CK technique."""
|
||||
|
||||
# id: uuid.UUID
|
||||
id: uuid.UUID
|
||||
# mitre_id: str
|
||||
mitre_id: str
|
||||
# name: str
|
||||
name: str
|
||||
# Assign tactic = None
|
||||
tactic: str | None = None
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign platforms = field(default_factory=list)
|
||||
platforms: list[str] = field(default_factory=list)
|
||||
# Assign is_subtechnique = False
|
||||
is_subtechnique: bool = False
|
||||
# Assign parent_mitre_id = None
|
||||
parent_mitre_id: str | None = None
|
||||
# Assign status_global = TechniqueStatus.not_evaluated
|
||||
status_global: TechniqueStatus = TechniqueStatus.not_evaluated
|
||||
# Assign review_required = False
|
||||
review_required: bool = False
|
||||
# Assign last_review_date = None
|
||||
last_review_date: datetime | None = None
|
||||
# Assign mitre_version = None
|
||||
mitre_version: str | None = None
|
||||
# Assign mitre_last_modified = None
|
||||
mitre_last_modified: datetime | None = None
|
||||
|
||||
# -- Factory -----------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
# Define function create
|
||||
def create(
|
||||
cls,
|
||||
*,
|
||||
# Entry: mitre_id
|
||||
mitre_id: str,
|
||||
# Entry: name
|
||||
name: str,
|
||||
# Entry: tactic
|
||||
tactic: str | None = None,
|
||||
# Entry: description
|
||||
description: str | None = None,
|
||||
# Entry: platforms
|
||||
platforms: list[str] | None = None,
|
||||
) -> TechniqueEntity:
|
||||
"""Create a new technique, validating the MITRE ID format.
|
||||
|
||||
Args:
|
||||
mitre_id (str): MITRE ATT&CK identifier (e.g. ``"T1059"`` or ``"T1059.001"``).
|
||||
name (str): Human-readable name of the technique.
|
||||
tactic (str | None): MITRE tactic category the technique belongs to.
|
||||
description (str | None): Optional free-text description.
|
||||
platforms (list[str] | None): List of platform strings the technique applies to.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity: A new entity with a freshly generated UUID and
|
||||
``status_global`` set to ``not_evaluated``.
|
||||
"""
|
||||
# Assign validated_id = MitreId(mitre_id)
|
||||
validated_id = MitreId(mitre_id)
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=uuid.uuid4(),
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=validated_id.value,
|
||||
# Keyword argument: name
|
||||
name=name,
|
||||
# Keyword argument: tactic
|
||||
tactic=tactic,
|
||||
# Keyword argument: description
|
||||
description=description,
|
||||
# Keyword argument: platforms
|
||||
platforms=platforms or [],
|
||||
# Keyword argument: is_subtechnique
|
||||
is_subtechnique=validated_id.is_subtechnique,
|
||||
# Keyword argument: parent_mitre_id
|
||||
parent_mitre_id=validated_id.parent_id,
|
||||
# Keyword argument: status_global
|
||||
status_global=TechniqueStatus.not_evaluated,
|
||||
)
|
||||
|
||||
# Apply the @classmethod decorator
|
||||
@classmethod
|
||||
# Define function from_orm
|
||||
def from_orm(cls, model: TechniqueORM) -> TechniqueEntity:
|
||||
"""Build a TechniqueEntity from a SQLAlchemy Technique model.
|
||||
|
||||
Args:
|
||||
model (TechniqueORM): The ORM model instance to convert.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity: A fully populated domain entity reflecting the ORM state.
|
||||
"""
|
||||
# Assign raw_status = model.status_global
|
||||
raw_status = model.status_global
|
||||
# Check: raw_status is None
|
||||
if raw_status is None:
|
||||
# Assign status = TechniqueStatus.not_evaluated
|
||||
status = TechniqueStatus.not_evaluated
|
||||
# Alternative: isinstance(raw_status, TechniqueStatus)
|
||||
elif isinstance(raw_status, TechniqueStatus):
|
||||
# Assign status = raw_status
|
||||
status = raw_status
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign status = TechniqueStatus(raw_status)
|
||||
status = TechniqueStatus(raw_status)
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=model.id,
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=model.mitre_id,
|
||||
# Keyword argument: name
|
||||
name=model.name,
|
||||
# Keyword argument: tactic
|
||||
tactic=model.tactic,
|
||||
# Keyword argument: description
|
||||
description=model.description,
|
||||
# Keyword argument: platforms
|
||||
platforms=model.platforms or [],
|
||||
# Keyword argument: is_subtechnique
|
||||
is_subtechnique=model.is_subtechnique or False,
|
||||
# Keyword argument: parent_mitre_id
|
||||
parent_mitre_id=model.parent_mitre_id,
|
||||
# Keyword argument: status_global
|
||||
status_global=status,
|
||||
# Keyword argument: review_required
|
||||
review_required=model.review_required or False,
|
||||
# Keyword argument: last_review_date
|
||||
last_review_date=model.last_review_date,
|
||||
# Keyword argument: mitre_version
|
||||
mitre_version=getattr(model, "mitre_version", None),
|
||||
# Keyword argument: mitre_last_modified
|
||||
mitre_last_modified=getattr(model, "mitre_last_modified", None),
|
||||
)
|
||||
|
||||
# Define function apply_to
|
||||
def apply_to(self, model: TechniqueORM) -> None:
|
||||
"""Copy mutable fields back onto the ORM model.
|
||||
|
||||
Args:
|
||||
model (TechniqueORM): The ORM model to update in-place.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign model.status_global = self.status_global
|
||||
model.status_global = self.status_global
|
||||
# Assign model.review_required = self.review_required
|
||||
model.review_required = self.review_required
|
||||
# Assign model.last_review_date = self.last_review_date
|
||||
model.last_review_date = self.last_review_date
|
||||
|
||||
# -- Business logic ----------------------------------------------------
|
||||
|
||||
def recalculate_status(
|
||||
self,
|
||||
# Entry: test_snapshots
|
||||
test_snapshots: list[tuple[str, str | None]],
|
||||
) -> TechniqueStatus:
|
||||
"""Recompute ``status_global`` from a list of (state, detection_result) pairs.
|
||||
|
||||
Rules (v2):
|
||||
1. No tests -> not_evaluated
|
||||
2. All validated -> inspect detection results:
|
||||
- All detected -> validated
|
||||
- Any partially_detected -> partial
|
||||
- Otherwise -> not_covered
|
||||
3. Some validated, others in progress -> partial
|
||||
4. All in intermediate states -> in_progress
|
||||
|
||||
Args:
|
||||
test_snapshots (list[tuple[str, str | None]]): Each element is a
|
||||
``(state, detection_result)`` pair where *state* is a
|
||||
:class:`TestState` value string and *detection_result* is a
|
||||
:class:`TestResult` value string or ``None``.
|
||||
|
||||
Returns:
|
||||
TechniqueStatus: The newly computed status, which is also stored on
|
||||
the entity's ``status_global`` field.
|
||||
"""
|
||||
# Assign tests = [
|
||||
tests = [
|
||||
_TestSnapshot(
|
||||
# Keyword argument: state
|
||||
state=s if isinstance(s, TestState) else TestState(s),
|
||||
# Keyword argument: detection_result
|
||||
detection_result=dr,
|
||||
)
|
||||
for s, dr in test_snapshots
|
||||
]
|
||||
|
||||
# Check: not tests
|
||||
if not tests:
|
||||
# Assign self.status_global = TechniqueStatus.not_evaluated
|
||||
self.status_global = TechniqueStatus.not_evaluated
|
||||
# Alternative: all(t.state == TestState.validated for t in tests)
|
||||
elif all(t.state == TestState.validated for t in tests):
|
||||
# Assign results = [t.detection_result for t in tests if t.detection_result]
|
||||
results = [t.detection_result for t in tests if t.detection_result]
|
||||
# Check: results and all(r == TestResult.detected or r == "detected" for r i...
|
||||
if results and all(r == TestResult.detected or r == "detected" for r in results):
|
||||
# Assign self.status_global = TechniqueStatus.validated
|
||||
self.status_global = TechniqueStatus.validated
|
||||
# elif any(
|
||||
elif any(
|
||||
# Keyword argument: r
|
||||
r == TestResult.partially_detected or r == "partially_detected"
|
||||
for r in results
|
||||
):
|
||||
# Assign self.status_global = TechniqueStatus.partial
|
||||
self.status_global = TechniqueStatus.partial
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign self.status_global = TechniqueStatus.not_covered
|
||||
self.status_global = TechniqueStatus.not_covered
|
||||
# Alternative: any(t.state == TestState.validated for t in tests)
|
||||
elif any(t.state == TestState.validated for t in tests):
|
||||
# Assign self.status_global = TechniqueStatus.partial
|
||||
self.status_global = TechniqueStatus.partial
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign self.status_global = TechniqueStatus.in_progress
|
||||
self.status_global = TechniqueStatus.in_progress
|
||||
|
||||
# Return self.status_global
|
||||
return self.status_global
|
||||
|
||||
# Define function mark_reviewed
|
||||
def mark_reviewed(self) -> None:
|
||||
"""Mark the technique as reviewed, clearing the review flag.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self.review_required = False
|
||||
self.review_required = False
|
||||
# Assign self.last_review_date = datetime.utcnow()
|
||||
self.last_review_date = datetime.utcnow()
|
||||
|
||||
# Define function flag_for_review
|
||||
def flag_for_review(self) -> None:
|
||||
"""Flag the technique as needing review.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self.review_required = True
|
||||
self.review_required = True
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Threat actor domain entity with coverage analysis logic.
|
||||
|
||||
Pure domain logic — no framework imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Import TYPE_CHECKING from typing
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# Check: TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
# Import ThreatActor as ThreatActorORM from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActor as ThreatActorORM
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class ThreatActorTechniqueRef
|
||||
class ThreatActorTechniqueRef:
|
||||
"""Lightweight reference to a technique used by an actor."""
|
||||
|
||||
# technique_id: uuid.UUID
|
||||
technique_id: uuid.UUID
|
||||
# Assign mitre_id = None
|
||||
mitre_id: str | None = None
|
||||
# Assign name = None
|
||||
name: str | None = None
|
||||
# Assign status = None
|
||||
status: str | None = None
|
||||
# Assign usage_description = None
|
||||
usage_description: str | None = None
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass
|
||||
# Define class ThreatActorEntity
|
||||
class ThreatActorEntity:
|
||||
"""Pure domain representation of a MITRE ATT&CK threat actor (group).
|
||||
|
||||
Aggregates references to the techniques the actor is known to use and
|
||||
provides coverage analysis properties.
|
||||
"""
|
||||
|
||||
# name: str
|
||||
name: str
|
||||
# Assign id = None
|
||||
id: uuid.UUID | None = None
|
||||
# Assign mitre_id = None
|
||||
mitre_id: str | None = None
|
||||
# Assign aliases = field(default_factory=list)
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
# Assign description = None
|
||||
description: str | None = None
|
||||
# Assign country = None
|
||||
country: str | None = None
|
||||
# Assign target_sectors = field(default_factory=list)
|
||||
target_sectors: list[str] = field(default_factory=list)
|
||||
# Assign target_regions = field(default_factory=list)
|
||||
target_regions: list[str] = field(default_factory=list)
|
||||
# Assign motivation = None
|
||||
motivation: str | None = None
|
||||
# Assign sophistication = None
|
||||
sophistication: str | None = None
|
||||
# Assign first_seen = None
|
||||
first_seen: str | None = None
|
||||
# Assign last_seen = None
|
||||
last_seen: str | None = None
|
||||
# Assign is_active = True
|
||||
is_active: bool = True
|
||||
# Assign techniques = field(default_factory=list)
|
||||
techniques: list[ThreatActorTechniqueRef] = field(default_factory=list)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function technique_count
|
||||
def technique_count(self) -> int:
|
||||
"""Return the total number of techniques associated with this actor.
|
||||
|
||||
Returns:
|
||||
int: Count of technique references.
|
||||
"""
|
||||
# Return len(self.techniques)
|
||||
return len(self.techniques)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function covered_techniques
|
||||
def covered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||
"""Return technique references whose coverage status is ``validated`` or ``partial``.
|
||||
|
||||
Returns:
|
||||
list[ThreatActorTechniqueRef]: Subset of techniques considered covered.
|
||||
"""
|
||||
# Return [
|
||||
return [
|
||||
t for t in self.techniques
|
||||
if t.status in ("validated", "partial")
|
||||
]
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function uncovered_techniques
|
||||
def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||
"""Return technique references whose coverage status is neither ``validated`` nor ``partial``.
|
||||
|
||||
Returns:
|
||||
list[ThreatActorTechniqueRef]: Subset of techniques not yet covered.
|
||||
"""
|
||||
# Return [
|
||||
return [
|
||||
t for t in self.techniques
|
||||
if t.status not in ("validated", "partial")
|
||||
]
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function coverage_pct
|
||||
def coverage_pct(self) -> float:
|
||||
"""Return the percentage of the actor's techniques that are covered.
|
||||
|
||||
Returns:
|
||||
float: A value from 0.0 to 100.0, rounded to one decimal place.
|
||||
Returns 0.0 when the actor has no associated techniques.
|
||||
"""
|
||||
# Check: not self.techniques
|
||||
if not self.techniques:
|
||||
# Return 0.0
|
||||
return 0.0
|
||||
# Return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
|
||||
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
|
||||
|
||||
# Apply the @classmethod decorator
|
||||
@classmethod
|
||||
# Define function from_orm
|
||||
def from_orm(cls, orm: ThreatActorORM) -> ThreatActorEntity:
|
||||
"""Build a ThreatActorEntity from a SQLAlchemy ThreatActor model.
|
||||
|
||||
Args:
|
||||
orm (ThreatActorORM): The ORM model instance to convert.
|
||||
|
||||
Returns:
|
||||
ThreatActorEntity: A fully populated domain entity including
|
||||
technique references resolved from the ORM relationship.
|
||||
"""
|
||||
# Assign techs = []
|
||||
techs: list[ThreatActorTechniqueRef] = []
|
||||
# Iterate over getattr(orm, "techniques", None) or []
|
||||
for tat in getattr(orm, "techniques", None) or []:
|
||||
# Assign technique = getattr(tat, "technique", None)
|
||||
technique = getattr(tat, "technique", None)
|
||||
# Call techs.append()
|
||||
techs.append(ThreatActorTechniqueRef(
|
||||
# Keyword argument: technique_id
|
||||
technique_id=tat.technique_id,
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=getattr(technique, "mitre_id", None) if technique else None,
|
||||
# Keyword argument: name
|
||||
name=getattr(technique, "name", None) if technique else None,
|
||||
# Keyword argument: status
|
||||
status=(
|
||||
technique.status_global.value
|
||||
if technique and hasattr(technique.status_global, "value")
|
||||
else getattr(technique, "status_global", None) if technique else None
|
||||
),
|
||||
# Keyword argument: usage_description
|
||||
usage_description=tat.usage_description,
|
||||
))
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=orm.id,
|
||||
# Keyword argument: name
|
||||
name=orm.name,
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=orm.mitre_id,
|
||||
# Keyword argument: aliases
|
||||
aliases=orm.aliases or [],
|
||||
# Keyword argument: description
|
||||
description=orm.description,
|
||||
# Keyword argument: country
|
||||
country=orm.country,
|
||||
# Keyword argument: target_sectors
|
||||
target_sectors=orm.target_sectors or [],
|
||||
# Keyword argument: target_regions
|
||||
target_regions=orm.target_regions or [],
|
||||
# Keyword argument: motivation
|
||||
motivation=orm.motivation,
|
||||
# Keyword argument: sophistication
|
||||
sophistication=orm.sophistication,
|
||||
# Keyword argument: first_seen
|
||||
first_seen=orm.first_seen,
|
||||
# Keyword argument: last_seen
|
||||
last_seen=orm.last_seen,
|
||||
# Keyword argument: is_active
|
||||
is_active=orm.is_active if orm.is_active is not None else True,
|
||||
# Keyword argument: techniques
|
||||
techniques=techs,
|
||||
)
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Canonical domain enums for Aegis.
|
||||
|
||||
These enums represent core domain concepts and are the single source of
|
||||
truth. ``models/enums.py`` re-exports them so that existing ORM code
|
||||
continues to work without changes.
|
||||
"""
|
||||
|
||||
# Import enum
|
||||
import enum
|
||||
|
||||
|
||||
# Define class TechniqueStatus
|
||||
class TechniqueStatus(str, enum.Enum):
|
||||
"""Coverage and evaluation status for a MITRE ATT&CK technique."""
|
||||
|
||||
# Assign not_evaluated = "not_evaluated"
|
||||
not_evaluated = "not_evaluated"
|
||||
# Assign in_progress = "in_progress"
|
||||
in_progress = "in_progress"
|
||||
# Assign validated = "validated"
|
||||
validated = "validated"
|
||||
# Assign partial = "partial"
|
||||
partial = "partial"
|
||||
# Assign not_covered = "not_covered"
|
||||
not_covered = "not_covered"
|
||||
# Assign review_required = "review_required"
|
||||
review_required = "review_required"
|
||||
|
||||
|
||||
# Define class TestState
|
||||
class TestState(str, enum.Enum):
|
||||
"""Lifecycle states in the security test state machine."""
|
||||
|
||||
# Assign draft = "draft"
|
||||
draft = "draft"
|
||||
# Assign red_executing = "red_executing"
|
||||
red_executing = "red_executing"
|
||||
# Assign blue_evaluating = "blue_evaluating"
|
||||
blue_evaluating = "blue_evaluating"
|
||||
# Assign in_review = "in_review"
|
||||
in_review = "in_review"
|
||||
# Assign validated = "validated"
|
||||
validated = "validated"
|
||||
# Assign rejected = "rejected"
|
||||
rejected = "rejected"
|
||||
|
||||
|
||||
# Define class TeamSide
|
||||
class TeamSide(str, enum.Enum):
|
||||
"""Identifies which team (red or blue) an action belongs to."""
|
||||
|
||||
# Assign red = "red"
|
||||
red = "red"
|
||||
# Assign blue = "blue"
|
||||
blue = "blue"
|
||||
|
||||
|
||||
# Define class TestResult
|
||||
class TestResult(str, enum.Enum):
|
||||
"""Outcome of a red-team test from a detection perspective."""
|
||||
|
||||
# Assign detected = "detected"
|
||||
detected = "detected"
|
||||
# Assign not_detected = "not_detected"
|
||||
not_detected = "not_detected"
|
||||
# Assign partially_detected = "partially_detected"
|
||||
partially_detected = "partially_detected"
|
||||
|
||||
|
||||
# Define class DataClassification
|
||||
class DataClassification(str, enum.Enum):
|
||||
"""Data sensitivity classification levels for compliance and retention policies."""
|
||||
|
||||
# Assign public = "public"
|
||||
public = "public"
|
||||
# Assign internal = "internal"
|
||||
internal = "internal"
|
||||
# Assign sensitive = "sensitive"
|
||||
sensitive = "sensitive"
|
||||
# Assign restricted = "restricted"
|
||||
restricted = "restricted"
|
||||
@@ -0,0 +1,192 @@
|
||||
"""Canonical domain error hierarchy for Aegis.
|
||||
|
||||
Every service-layer error should be a subclass of :class:`DomainError`.
|
||||
The global exception handler in ``app.middleware.error_handler`` maps
|
||||
each concrete subclass to an appropriate HTTP status code so that
|
||||
services never depend on FastAPI.
|
||||
|
||||
Existing code that imports from ``app.domain.exceptions`` continues to
|
||||
work — that module re-exports everything defined here.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
# Define class DomainError
|
||||
class DomainError(Exception):
|
||||
"""Base for all domain errors."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, message: str, *, code: str = "DOMAIN_ERROR") -> None:
|
||||
"""Initialise the domain error with a human-readable message and error code.
|
||||
|
||||
Args:
|
||||
message (str): Human-readable description of the error.
|
||||
code (str): Machine-readable error code used by the HTTP error handler.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self.message = message
|
||||
self.message = message
|
||||
# Assign self.code = code
|
||||
self.code = code
|
||||
# Call super()
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# ── Entity lifecycle ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class EntityNotFoundError(DomainError):
|
||||
"""A requested entity does not exist."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, entity: str, identifier: str) -> None:
|
||||
"""Initialise an entity-not-found error.
|
||||
|
||||
Args:
|
||||
entity (str): Name of the entity type that was not found (e.g. "Technique").
|
||||
identifier (str): The ID or key used in the failed lookup.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(f"{entity} not found: {identifier}", code="NOT_FOUND")
|
||||
# Assign self.entity = entity
|
||||
self.entity = entity
|
||||
# Assign self.identifier = identifier
|
||||
self.identifier = identifier
|
||||
|
||||
|
||||
# Define class DuplicateEntityError
|
||||
class DuplicateEntityError(DomainError):
|
||||
"""Creating an entity that already exists."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, entity: str, field: str, value: str) -> None:
|
||||
"""Initialise a duplicate-entity error.
|
||||
|
||||
Args:
|
||||
entity (str): Name of the entity type that already exists (e.g. "Campaign").
|
||||
field (str): Name of the field whose value conflicts (e.g. "name").
|
||||
value (str): The conflicting value that is already in use.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(
|
||||
f"{entity} with {field}='{value}' already exists",
|
||||
# Keyword argument: code
|
||||
code="DUPLICATE",
|
||||
)
|
||||
|
||||
|
||||
# ── State machine ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class InvalidStateTransition(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
|
||||
"""A state-machine transition is not allowed."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(
|
||||
self,
|
||||
# Entry: current_state
|
||||
current_state: str,
|
||||
# Entry: target_state
|
||||
target_state: str,
|
||||
# Entry: valid_transitions
|
||||
valid_transitions: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Initialise an invalid state-transition error.
|
||||
|
||||
Args:
|
||||
current_state (str): The entity's present state (e.g. "draft").
|
||||
target_state (str): The state that was illegally requested.
|
||||
valid_transitions (list[str] | None): Allowed target states from the
|
||||
current state; included in the error message when provided.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign msg = f"Cannot transition from '{current_state}' to '{target_state}'"
|
||||
msg = f"Cannot transition from '{current_state}' to '{target_state}'"
|
||||
# Check: valid_transitions
|
||||
if valid_transitions:
|
||||
# Assign msg = f". Valid transitions: {valid_transitions}"
|
||||
msg += f". Valid transitions: {valid_transitions}"
|
||||
# Call super()
|
||||
super().__init__(msg, code="INVALID_TRANSITION")
|
||||
# Assign self.current_state = current_state
|
||||
self.current_state = current_state
|
||||
# Assign self.target_state = target_state
|
||||
self.target_state = target_state
|
||||
# Assign self.valid_transitions = valid_transitions or []
|
||||
self.valid_transitions = valid_transitions or []
|
||||
|
||||
|
||||
# ── Business rules ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class BusinessRuleViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
|
||||
"""An operation violates a business invariant."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, message: str) -> None:
|
||||
"""Initialise a business-rule violation error.
|
||||
|
||||
Args:
|
||||
message (str): Human-readable description of the violated rule.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(message, code="BUSINESS_RULE_VIOLATION")
|
||||
|
||||
|
||||
# Define class InvalidOperationError
|
||||
class InvalidOperationError(BusinessRuleViolation):
|
||||
"""An operation is invalid in the current context.
|
||||
|
||||
Kept for backward compatibility; new code should prefer
|
||||
:class:`BusinessRuleViolation` directly.
|
||||
"""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, message: str) -> None:
|
||||
"""Initialise an invalid-operation error.
|
||||
|
||||
Args:
|
||||
message (str): Human-readable description of why the operation is invalid.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(message)
|
||||
# Assign self.code = "INVALID_OPERATION"
|
||||
self.code = "INVALID_OPERATION"
|
||||
|
||||
|
||||
# ── Authorization ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class PermissionViolation(DomainError): # noqa: N818 — DDD term, renaming would break 96 call sites
|
||||
"""The user lacks permissions for an action."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, message: str = "Insufficient permissions") -> None:
|
||||
"""Initialise a permission-violation error.
|
||||
|
||||
Args:
|
||||
message (str): Human-readable description of the access denial.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call super()
|
||||
super().__init__(message, code="FORBIDDEN")
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Backward-compatible re-exports from :mod:`app.domain.errors`.
|
||||
|
||||
All domain errors now live in ``errors.py``. This module preserves the
|
||||
old import paths so that existing code keeps working without changes::
|
||||
|
||||
from app.domain.exceptions import InvalidTransitionError # still works
|
||||
"""
|
||||
|
||||
# Import # noqa: F401 from app.domain.errors
|
||||
from app.domain.errors import ( # noqa: F401
|
||||
BusinessRuleViolation,
|
||||
DomainError,
|
||||
DuplicateEntityError,
|
||||
EntityNotFoundError,
|
||||
InvalidOperationError,
|
||||
InvalidStateTransition,
|
||||
PermissionViolation,
|
||||
)
|
||||
|
||||
# Legacy aliases — old name → new name
|
||||
DomainException = DomainError
|
||||
# Assign InvalidTransitionError = InvalidStateTransition
|
||||
InvalidTransitionError = InvalidStateTransition
|
||||
# Assign AuthorizationError = PermissionViolation
|
||||
AuthorizationError = PermissionViolation
|
||||
@@ -0,0 +1 @@
|
||||
"""Abstract port interfaces that infrastructure adapters must implement."""
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Port defining the common interface for data import services.
|
||||
|
||||
All import services (Atomic Red Team, Sigma, CALDERA, etc.) follow the
|
||||
same contract: they receive a database session and return a summary dict
|
||||
with import statistics.
|
||||
|
||||
New import sources can be added by:
|
||||
1. Implementing the ``ImportService`` protocol in a new module
|
||||
2. Registering the handler in the ``IMPORT_REGISTRY``
|
||||
|
||||
This satisfies the Open/Closed Principle — the system is open for new
|
||||
import sources without modifying existing code.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import Any, Protocol, runtime_checkable from typing
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
# Apply the @runtime_checkable decorator
|
||||
@runtime_checkable
|
||||
# Define class ImportService
|
||||
class ImportService(Protocol):
|
||||
"""Contract for any data-import operation.
|
||||
|
||||
Each implementation is a callable ``(Session) -> dict`` that
|
||||
downloads, parses, and upserts records from an external source.
|
||||
"""
|
||||
|
||||
# Define function __call__
|
||||
def __call__(self, db: Session) -> dict[str, Any]:
|
||||
"""Execute the import operation against the given database session.
|
||||
|
||||
Args:
|
||||
db (Session): Active SQLAlchemy session to use for all DB operations.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Summary statistics for the import run (e.g. created,
|
||||
updated, skipped counts).
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
|
||||
# Define class ImportServiceEntry
|
||||
class ImportServiceEntry:
|
||||
"""Lazy-loading wrapper that resolves a module-level function on first call."""
|
||||
|
||||
# Assign __slots__ = ("_module_path", "_func_name", "_resolved")
|
||||
__slots__ = ("_module_path", "_func_name", "_resolved")
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, module_path: str, func_name: str) -> None:
|
||||
"""Initialise the lazy entry with the module path and function name to resolve later.
|
||||
|
||||
Args:
|
||||
module_path (str): Dotted Python module path, e.g.
|
||||
``"app.services.atomic_import_service"``.
|
||||
func_name (str): Name of the callable to import from *module_path*.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self._module_path = module_path
|
||||
self._module_path = module_path
|
||||
# Assign self._func_name = func_name
|
||||
self._func_name = func_name
|
||||
# Assign self._resolved = None
|
||||
self._resolved: ImportService | None = None
|
||||
|
||||
# Define function __call__
|
||||
def __call__(self, db: Session) -> dict[str, Any]:
|
||||
"""Resolve the import function on first call and invoke it with *db*.
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy session passed through to the underlying
|
||||
import function.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Import statistics returned by the underlying function
|
||||
(e.g. counts of created/updated/skipped records).
|
||||
"""
|
||||
# Check: self._resolved is None
|
||||
if self._resolved is None:
|
||||
# Import importlib
|
||||
import importlib
|
||||
# Assign mod = importlib.import_module(self._module_path)
|
||||
mod = importlib.import_module(self._module_path)
|
||||
# Assign self._resolved = getattr(mod, self._func_name)
|
||||
self._resolved = getattr(mod, self._func_name)
|
||||
# Return self._resolved(db)
|
||||
return self._resolved(db)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function source_info
|
||||
def source_info(self) -> str:
|
||||
"""Return a human-readable identifier for this import entry.
|
||||
|
||||
Returns:
|
||||
str: The fully qualified function reference as
|
||||
``"<module_path>.<func_name>"``.
|
||||
"""
|
||||
# Return f"{self._module_path}.{self._func_name}"
|
||||
return f"{self._module_path}.{self._func_name}"
|
||||
|
||||
|
||||
# Assign IMPORT_REGISTRY = {
|
||||
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
|
||||
# Literal argument value
|
||||
"atomic_red_team": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.atomic_import_service", "import_atomic_red_team",
|
||||
),
|
||||
# Literal argument value
|
||||
"sigma": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.sigma_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"lolbas": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.lolbas_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"gtfobins": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.lolbas_import_service", "sync_gtfobins",
|
||||
),
|
||||
# Literal argument value
|
||||
"caldera": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.caldera_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"elastic_rules": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.elastic_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"mitre_cti": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.threat_actor_import_service", "sync",
|
||||
),
|
||||
# Literal argument value
|
||||
"d3fend": ImportServiceEntry(
|
||||
# Literal argument value
|
||||
"app.services.d3fend_import_service", "sync",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# Define function get_import_handler
|
||||
def get_import_handler(source_name: str) -> ImportServiceEntry | None:
|
||||
"""Look up the import handler for *source_name*.
|
||||
|
||||
Returns ``None`` when no handler is registered.
|
||||
"""
|
||||
# Return IMPORT_REGISTRY.get(source_name)
|
||||
return IMPORT_REGISTRY.get(source_name)
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Abstract repository port interfaces for domain entity persistence."""
|
||||
# Import TechniqueRepository from app.domain.ports.repositories.technique_repository
|
||||
from app.domain.ports.repositories.technique_repository import TechniqueRepository
|
||||
|
||||
# Import TestRepository from app.domain.ports.repositories.test_repository
|
||||
from app.domain.ports.repositories.test_repository import TestRepository
|
||||
|
||||
# Assign __all__ = ["TechniqueRepository", "TestRepository"]
|
||||
__all__ = ["TechniqueRepository", "TestRepository"]
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Port defining how the application accesses technique data.
|
||||
|
||||
This is a domain contract — implementations live in infrastructure/.
|
||||
The domain layer NEVER imports the implementation.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import NamedTuple, Protocol, runtime_checkable from typing
|
||||
from typing import NamedTuple, Protocol, runtime_checkable
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
|
||||
# Import TechniqueStatus from app.domain.enums
|
||||
from app.domain.enums import TechniqueStatus
|
||||
|
||||
|
||||
# Define class TechniqueWithCounts
|
||||
class TechniqueWithCounts(NamedTuple):
|
||||
"""Pre-aggregated technique data for heatmap/scoring."""
|
||||
|
||||
# entity: TechniqueEntity
|
||||
entity: TechniqueEntity
|
||||
# test_count: int
|
||||
test_count: int
|
||||
# validated_test_count: int
|
||||
validated_test_count: int
|
||||
# detection_rule_count: int
|
||||
detection_rule_count: int
|
||||
|
||||
|
||||
# Apply the @runtime_checkable decorator
|
||||
@runtime_checkable
|
||||
# Define class TechniqueRepository
|
||||
class TechniqueRepository(Protocol):
|
||||
"""Data access contract for techniques (one per aggregate root)."""
|
||||
|
||||
# -- Single-entity access ----------------------------------------------
|
||||
|
||||
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
|
||||
"""Return the technique with the given primary key, or None if absent.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): Primary key of the technique to look up.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity | None: The matching entity, or None if not found.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function find_by_mitre_id
|
||||
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
|
||||
"""Return the technique matching the given MITRE ATT&CK identifier, or None.
|
||||
|
||||
Args:
|
||||
mitre_id (str): MITRE ATT&CK ID (e.g. ``"T1059"`` or ``"T1059.001"``).
|
||||
|
||||
Returns:
|
||||
TechniqueEntity | None: The matching entity, or None if not found.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- List access -------------------------------------------------------
|
||||
|
||||
def list_all(
|
||||
self,
|
||||
*,
|
||||
# Entry: tactic
|
||||
tactic: str | None = None,
|
||||
# Entry: status
|
||||
status: TechniqueStatus | None = None,
|
||||
# Entry: review_required
|
||||
review_required: bool | None = None,
|
||||
) -> list[TechniqueEntity]:
|
||||
"""Return all techniques, optionally filtered by tactic, status, or review flag.
|
||||
|
||||
Args:
|
||||
tactic (str | None): When provided, restrict results to this tactic category.
|
||||
status (TechniqueStatus | None): When provided, restrict results to this status.
|
||||
review_required (bool | None): When provided, restrict results to techniques
|
||||
whose ``review_required`` flag matches this value.
|
||||
|
||||
Returns:
|
||||
list[TechniqueEntity]: Matching technique entities; may be empty.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function list_by_ids
|
||||
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
|
||||
"""Return all techniques whose primary keys are in *ids*.
|
||||
|
||||
Args:
|
||||
ids (list[uuid.UUID]): List of technique UUIDs to retrieve.
|
||||
|
||||
Returns:
|
||||
list[TechniqueEntity]: Entities found for the supplied IDs; order
|
||||
is not guaranteed and missing IDs are silently omitted.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- Batch queries (scoring/heatmap performance) -----------------------
|
||||
|
||||
def count_by_status(self) -> dict[TechniqueStatus, int]:
|
||||
"""Return a count of techniques grouped by their global status.
|
||||
|
||||
Returns:
|
||||
dict[TechniqueStatus, int]: Mapping from each status value to the
|
||||
number of techniques in that state.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function find_all_with_test_counts
|
||||
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
|
||||
"""Return all techniques together with pre-aggregated test and rule counts.
|
||||
|
||||
Returns:
|
||||
list[TechniqueWithCounts]: Each element bundles a TechniqueEntity
|
||||
with its total, validated, and detection-rule counts for use
|
||||
in heatmap and scoring calculations.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- Mutations ---------------------------------------------------------
|
||||
|
||||
def save(self, technique: TechniqueEntity) -> TechniqueEntity:
|
||||
"""Persist a technique entity and return the saved state.
|
||||
|
||||
Args:
|
||||
technique (TechniqueEntity): The entity to create or update.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity: The persisted entity, potentially with updated
|
||||
fields (e.g. server-side timestamps).
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function exists_by_mitre_id
|
||||
def exists_by_mitre_id(self, mitre_id: str) -> bool:
|
||||
"""Return True if a technique with the given MITRE ID exists in the repository.
|
||||
|
||||
Args:
|
||||
mitre_id (str): MITRE ATT&CK ID to check (e.g. ``"T1059"``).
|
||||
|
||||
Returns:
|
||||
bool: True if a matching technique is found, False otherwise.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Port defining how the application accesses test data.
|
||||
|
||||
This is a domain contract — implementations live in infrastructure/.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Protocol from typing
|
||||
from typing import Protocol
|
||||
|
||||
# Import TestState from app.domain.enums
|
||||
from app.domain.enums import TestState
|
||||
|
||||
|
||||
# Define class TestRepository
|
||||
class TestRepository(Protocol):
|
||||
"""Data access contract for tests."""
|
||||
|
||||
# -- Single-entity access ----------------------------------------------
|
||||
|
||||
def find_by_id(self, test_id: uuid.UUID) -> object | None:
|
||||
"""Return a Test ORM model by primary key, or None.
|
||||
|
||||
Returns the ORM model directly (not a domain entity) because
|
||||
the TestEntity is constructed at the service layer via
|
||||
``TestEntity.from_orm()``.
|
||||
|
||||
Args:
|
||||
test_id (uuid.UUID): Primary key of the test to look up.
|
||||
|
||||
Returns:
|
||||
object | None: The ORM model instance, or None if not found.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- List access -------------------------------------------------------
|
||||
|
||||
def list_by_technique(self, technique_id: uuid.UUID) -> list[object]:
|
||||
"""Return all test ORM models associated with the given technique.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): Primary key of the technique whose tests to retrieve.
|
||||
|
||||
Returns:
|
||||
list[object]: ORM model instances for all tests linked to this technique.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function list_by_state
|
||||
def list_by_state(self, state: TestState) -> list[object]:
|
||||
"""Return all test ORM models in the given state.
|
||||
|
||||
Args:
|
||||
state (TestState): The state to filter tests by.
|
||||
|
||||
Returns:
|
||||
list[object]: ORM model instances for all tests currently in *state*.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# Define function count_by_technique_and_state
|
||||
def count_by_technique_and_state(
|
||||
self,
|
||||
# Entry: technique_id
|
||||
technique_id: uuid.UUID,
|
||||
) -> dict[TestState, int]:
|
||||
"""Return test counts grouped by state for a single technique.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): Primary key of the technique whose test
|
||||
counts to aggregate.
|
||||
|
||||
Returns:
|
||||
dict[TestState, int]: Mapping from each test state to the number of
|
||||
tests in that state for the given technique.
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
|
||||
# -- Batch queries -----------------------------------------------------
|
||||
|
||||
def get_states_and_results_for_technique(
|
||||
self,
|
||||
# Entry: technique_id
|
||||
technique_id: uuid.UUID,
|
||||
) -> list[tuple[str, str | None]]:
|
||||
"""Return (state, detection_result) pairs for all tests of a technique.
|
||||
|
||||
Used by TechniqueEntity.recalculate_status() without loading full
|
||||
test models.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): Primary key of the technique whose test
|
||||
data to retrieve.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, str | None]]: Each tuple contains the test state
|
||||
string and the detection result string (or None if not yet set).
|
||||
"""
|
||||
# ...
|
||||
...
|
||||
@@ -0,0 +1,667 @@
|
||||
"""TestEntity — pure domain object for the test lifecycle state machine.
|
||||
|
||||
This entity owns ALL state-transition logic and business rules for a
|
||||
security test. It has **no** dependency on FastAPI, SQLAlchemy, or any
|
||||
infrastructure concern.
|
||||
|
||||
Usage::
|
||||
|
||||
entity = TestEntity.from_orm(test_orm_model)
|
||||
entity.start_execution() # draft → red_executing
|
||||
entity.submit_red_evidence() # red_executing → blue_evaluating
|
||||
entity.pause_timer()
|
||||
entity.resume_timer()
|
||||
entity.submit_blue_evidence() # blue_evaluating → in_review
|
||||
entity.validate_red("approved")
|
||||
entity.validate_blue("approved") # triggers dual-validation → validated
|
||||
entity.reopen() # rejected → draft
|
||||
|
||||
After mutations, the service layer copies ``entity.changes`` back onto
|
||||
the ORM model and persists via Unit of Work.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import enum
|
||||
import enum
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import dataclass, field from dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import TYPE_CHECKING, Any from typing
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
# Import from app.domain.errors
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
InvalidOperationError,
|
||||
InvalidStateTransition,
|
||||
)
|
||||
|
||||
# Check: TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
# Import Test as TestORM from app.models.test
|
||||
from app.models.test import Test as TestORM
|
||||
|
||||
# ── Value objects ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestState(str, enum.Enum):
|
||||
"""Ordered lifecycle states for a security test."""
|
||||
|
||||
# Assign draft = "draft"
|
||||
draft = "draft"
|
||||
# Assign red_executing = "red_executing"
|
||||
red_executing = "red_executing"
|
||||
# Assign blue_evaluating = "blue_evaluating"
|
||||
blue_evaluating = "blue_evaluating"
|
||||
# Assign in_review = "in_review"
|
||||
in_review = "in_review"
|
||||
# Assign validated = "validated"
|
||||
validated = "validated"
|
||||
# Assign rejected = "rejected"
|
||||
rejected = "rejected"
|
||||
|
||||
|
||||
# Assign VALID_TRANSITIONS = {
|
||||
VALID_TRANSITIONS: dict[TestState, list[TestState]] = {
|
||||
TestState.draft: [TestState.red_executing],
|
||||
TestState.red_executing: [TestState.blue_evaluating],
|
||||
TestState.blue_evaluating: [TestState.in_review],
|
||||
TestState.in_review: [TestState.validated, TestState.rejected],
|
||||
TestState.rejected: [TestState.draft],
|
||||
TestState.validated: [],
|
||||
}
|
||||
|
||||
# Assign _PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating})
|
||||
_PAUSABLE_STATES = frozenset({TestState.red_executing, TestState.blue_evaluating})
|
||||
|
||||
|
||||
# ── Domain events (lightweight records of what happened) ─────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
# Define class DomainEvent
|
||||
class DomainEvent:
|
||||
"""Immutable record of a domain-level event emitted by the test entity."""
|
||||
|
||||
# name: str
|
||||
name: str
|
||||
# Assign payload = field(default_factory=dict)
|
||||
payload: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ── Entity ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
# Define class TestEntity
|
||||
class TestEntity:
|
||||
"""Pure domain representation of a security test."""
|
||||
|
||||
# id: uuid.UUID
|
||||
id: uuid.UUID
|
||||
# state: TestState
|
||||
state: TestState
|
||||
|
||||
# Red validation
|
||||
red_validation_status: str | None = None
|
||||
# Assign red_validated_by = None
|
||||
red_validated_by: uuid.UUID | None = None
|
||||
# Assign red_validated_at = None
|
||||
red_validated_at: datetime | None = None
|
||||
# Assign red_validation_notes = None
|
||||
red_validation_notes: str | None = None
|
||||
|
||||
# Blue validation
|
||||
blue_validation_status: str | None = None
|
||||
# Assign blue_validated_by = None
|
||||
blue_validated_by: uuid.UUID | None = None
|
||||
# Assign blue_validated_at = None
|
||||
blue_validated_at: datetime | None = None
|
||||
# Assign blue_validation_notes = None
|
||||
blue_validation_notes: str | None = None
|
||||
|
||||
# Phase timing
|
||||
execution_date: datetime | None = None
|
||||
# Assign red_started_at = None
|
||||
red_started_at: datetime | None = None
|
||||
# Assign blue_started_at = None
|
||||
blue_started_at: datetime | None = None
|
||||
# Assign paused_at = None
|
||||
paused_at: datetime | None = None
|
||||
# Assign red_paused_seconds = 0
|
||||
red_paused_seconds: int = 0
|
||||
# Assign blue_paused_seconds = 0
|
||||
blue_paused_seconds: int = 0
|
||||
|
||||
# Internal bookkeeping (not persisted as-is)
|
||||
_events: list[DomainEvent] = field(default_factory=list, repr=False)
|
||||
|
||||
# -- Factory --------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
# Define function from_orm
|
||||
def from_orm(cls, model: TestORM) -> TestEntity:
|
||||
"""Build a TestEntity from a SQLAlchemy ``Test`` model instance.
|
||||
|
||||
Args:
|
||||
model (TestORM): The ORM model whose fields will be copied into the entity.
|
||||
|
||||
Returns:
|
||||
TestEntity: A fully populated domain entity reflecting the ORM state.
|
||||
"""
|
||||
# Assign raw_state = model.state
|
||||
raw_state = model.state
|
||||
# Assign state = raw_state if isinstance(raw_state, TestState) else TestState(raw_st...
|
||||
state = raw_state if isinstance(raw_state, TestState) else TestState(raw_state)
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: id
|
||||
id=model.id,
|
||||
# Keyword argument: state
|
||||
state=state,
|
||||
# Keyword argument: red_validation_status
|
||||
red_validation_status=model.red_validation_status,
|
||||
# Keyword argument: red_validated_by
|
||||
red_validated_by=model.red_validated_by,
|
||||
# Keyword argument: red_validated_at
|
||||
red_validated_at=model.red_validated_at,
|
||||
# Keyword argument: red_validation_notes
|
||||
red_validation_notes=model.red_validation_notes,
|
||||
# Keyword argument: blue_validation_status
|
||||
blue_validation_status=model.blue_validation_status,
|
||||
# Keyword argument: blue_validated_by
|
||||
blue_validated_by=model.blue_validated_by,
|
||||
# Keyword argument: blue_validated_at
|
||||
blue_validated_at=model.blue_validated_at,
|
||||
# Keyword argument: blue_validation_notes
|
||||
blue_validation_notes=model.blue_validation_notes,
|
||||
# Keyword argument: execution_date
|
||||
execution_date=model.execution_date,
|
||||
# Keyword argument: red_started_at
|
||||
red_started_at=model.red_started_at,
|
||||
# Keyword argument: blue_started_at
|
||||
blue_started_at=model.blue_started_at,
|
||||
# Keyword argument: paused_at
|
||||
paused_at=model.paused_at,
|
||||
# Keyword argument: red_paused_seconds
|
||||
red_paused_seconds=model.red_paused_seconds or 0,
|
||||
# Keyword argument: blue_paused_seconds
|
||||
blue_paused_seconds=model.blue_paused_seconds or 0,
|
||||
)
|
||||
|
||||
# Define function apply_to
|
||||
def apply_to(self, model: TestORM) -> None:
|
||||
"""Copy the entity's mutable fields back onto the ORM model.
|
||||
|
||||
Args:
|
||||
model (TestORM): The ORM model to update in-place.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign model.state = self.state
|
||||
model.state = self.state
|
||||
# Assign model.red_validation_status = self.red_validation_status
|
||||
model.red_validation_status = self.red_validation_status
|
||||
# Assign model.red_validated_by = self.red_validated_by
|
||||
model.red_validated_by = self.red_validated_by
|
||||
# Assign model.red_validated_at = self.red_validated_at
|
||||
model.red_validated_at = self.red_validated_at
|
||||
# Assign model.red_validation_notes = self.red_validation_notes
|
||||
model.red_validation_notes = self.red_validation_notes
|
||||
# Assign model.blue_validation_status = self.blue_validation_status
|
||||
model.blue_validation_status = self.blue_validation_status
|
||||
# Assign model.blue_validated_by = self.blue_validated_by
|
||||
model.blue_validated_by = self.blue_validated_by
|
||||
# Assign model.blue_validated_at = self.blue_validated_at
|
||||
model.blue_validated_at = self.blue_validated_at
|
||||
# Assign model.blue_validation_notes = self.blue_validation_notes
|
||||
model.blue_validation_notes = self.blue_validation_notes
|
||||
# Assign model.execution_date = self.execution_date
|
||||
model.execution_date = self.execution_date
|
||||
# Assign model.red_started_at = self.red_started_at
|
||||
model.red_started_at = self.red_started_at
|
||||
# Assign model.blue_started_at = self.blue_started_at
|
||||
model.blue_started_at = self.blue_started_at
|
||||
# Assign model.paused_at = self.paused_at
|
||||
model.paused_at = self.paused_at
|
||||
# Assign model.red_paused_seconds = self.red_paused_seconds
|
||||
model.red_paused_seconds = self.red_paused_seconds
|
||||
# Assign model.blue_paused_seconds = self.blue_paused_seconds
|
||||
model.blue_paused_seconds = self.blue_paused_seconds
|
||||
|
||||
# -- Query helpers --------------------------------------------------
|
||||
|
||||
@property
|
||||
# Define function events
|
||||
def events(self) -> list[DomainEvent]:
|
||||
"""Return a snapshot of all domain events raised on this entity.
|
||||
|
||||
Returns:
|
||||
list[DomainEvent]: Ordered list of events emitted since the entity
|
||||
was constructed or last cleared.
|
||||
"""
|
||||
# Return list(self._events)
|
||||
return list(self._events)
|
||||
|
||||
# Define function can_transition
|
||||
def can_transition(self, target: TestState) -> bool:
|
||||
"""Check whether a transition from the current state to *target* is valid.
|
||||
|
||||
Args:
|
||||
target (TestState): The desired next state.
|
||||
|
||||
Returns:
|
||||
bool: True if the transition is allowed, False otherwise.
|
||||
"""
|
||||
# Return target in VALID_TRANSITIONS.get(self.state, [])
|
||||
return target in VALID_TRANSITIONS.get(self.state, [])
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function is_terminal
|
||||
def is_terminal(self) -> bool:
|
||||
"""Return True if the test has reached its final (validated) state.
|
||||
|
||||
Returns:
|
||||
bool: True when state is ``validated``, False for all other states.
|
||||
"""
|
||||
# Return self.state == TestState.validated
|
||||
return self.state == TestState.validated
|
||||
|
||||
# -- Core transition ------------------------------------------------
|
||||
|
||||
def transition_to(self, target: TestState | str) -> str:
|
||||
"""Validate and apply a state transition.
|
||||
|
||||
Accepts either a :class:`TestState` member or its string value
|
||||
(so callers using ``models.enums.TestState`` work transparently).
|
||||
|
||||
Returns the *previous* state value as a plain string.
|
||||
|
||||
Raises :class:`InvalidStateTransition` when the move is illegal.
|
||||
|
||||
Args:
|
||||
target (TestState | str): The desired next state, as an enum member
|
||||
or its string equivalent.
|
||||
|
||||
Returns:
|
||||
str: The previous state value before the transition.
|
||||
"""
|
||||
# Assign value = target.value if hasattr(target, "value") else str(target)
|
||||
value = target.value if hasattr(target, "value") else str(target)
|
||||
# Assign resolved = target if isinstance(target, TestState) else TestState(value)
|
||||
resolved = target if isinstance(target, TestState) else TestState(value)
|
||||
# Return self._transition(resolved)
|
||||
return self._transition(resolved)
|
||||
|
||||
# Define function _transition
|
||||
def _transition(self, target: TestState) -> str:
|
||||
"""Validate and apply a state transition, returning the previous state value.
|
||||
|
||||
Args:
|
||||
target (TestState): The desired next state enum member.
|
||||
|
||||
Returns:
|
||||
str: The previous state value before the transition was applied.
|
||||
"""
|
||||
# Check: not self.can_transition(target)
|
||||
if not self.can_transition(target):
|
||||
# Assign valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])]
|
||||
valid = [s.value for s in VALID_TRANSITIONS.get(self.state, [])]
|
||||
# Raise InvalidStateTransition
|
||||
raise InvalidStateTransition(
|
||||
# Keyword argument: current_state
|
||||
current_state=self.state.value,
|
||||
# Keyword argument: target_state
|
||||
target_state=target.value,
|
||||
# Keyword argument: valid_transitions
|
||||
valid_transitions=valid,
|
||||
)
|
||||
# Assign previous = self.state.value
|
||||
previous = self.state.value
|
||||
# Assign self.state = target
|
||||
self.state = target
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent(
|
||||
# Literal argument value
|
||||
"state_changed",
|
||||
{"previous": previous, "new": target.value},
|
||||
))
|
||||
# Return previous
|
||||
return previous
|
||||
|
||||
# -- Lifecycle commands --------------------------------------------
|
||||
|
||||
def start_execution(self) -> None:
|
||||
"""Transition the test from ``draft`` to ``red_executing``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._transition()
|
||||
self._transition(TestState.red_executing)
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign self.execution_date = now
|
||||
self.execution_date = now
|
||||
# Assign self.red_started_at = now
|
||||
self.red_started_at = now
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("execution_started"))
|
||||
|
||||
# Define function submit_red_evidence
|
||||
def submit_red_evidence(self) -> int:
|
||||
"""Transition the test from ``red_executing`` to ``blue_evaluating``.
|
||||
|
||||
Auto-resumes if paused. Returns paused seconds accumulated
|
||||
during this phase (for worklog calculation).
|
||||
|
||||
Returns:
|
||||
int: Total seconds the red phase was paused.
|
||||
"""
|
||||
# Assign paused_extra = self._auto_resume()
|
||||
paused_extra = self._auto_resume()
|
||||
# Call self._transition()
|
||||
self._transition(TestState.blue_evaluating)
|
||||
# Assign total_paused = self.red_paused_seconds + paused_extra
|
||||
total_paused = self.red_paused_seconds + paused_extra
|
||||
# Assign self.blue_started_at = datetime.utcnow()
|
||||
self.blue_started_at = datetime.utcnow()
|
||||
# Assign self.blue_paused_seconds = 0
|
||||
self.blue_paused_seconds = 0
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent(
|
||||
# Literal argument value
|
||||
"red_evidence_submitted",
|
||||
{"red_paused_seconds": total_paused},
|
||||
))
|
||||
# Return total_paused
|
||||
return total_paused
|
||||
|
||||
# Define function submit_blue_evidence
|
||||
def submit_blue_evidence(self) -> int:
|
||||
"""Transition the test from ``blue_evaluating`` to ``in_review``.
|
||||
|
||||
Auto-resumes if paused. Returns paused seconds accumulated
|
||||
during this phase (for worklog calculation).
|
||||
|
||||
Returns:
|
||||
int: Total seconds the blue phase was paused.
|
||||
"""
|
||||
# Assign paused_extra = self._auto_resume()
|
||||
paused_extra = self._auto_resume()
|
||||
# Call self._transition()
|
||||
self._transition(TestState.in_review)
|
||||
# Assign total_paused = self.blue_paused_seconds + paused_extra
|
||||
total_paused = self.blue_paused_seconds + paused_extra
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent(
|
||||
# Literal argument value
|
||||
"blue_evidence_submitted",
|
||||
{"blue_paused_seconds": total_paused},
|
||||
))
|
||||
# Return total_paused
|
||||
return total_paused
|
||||
|
||||
# Define function pause_timer
|
||||
def pause_timer(self) -> None:
|
||||
"""Pause the active phase timer.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: self.state not in _PAUSABLE_STATES
|
||||
if self.state not in _PAUSABLE_STATES:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot pause timer in '{self.state.value}' state"
|
||||
)
|
||||
# Check: self.paused_at is not None
|
||||
if self.paused_at is not None:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Timer is already paused")
|
||||
# Assign self.paused_at = datetime.utcnow()
|
||||
self.paused_at = datetime.utcnow()
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("timer_paused"))
|
||||
|
||||
# Define function resume_timer
|
||||
def resume_timer(self) -> int:
|
||||
"""Resume a paused timer.
|
||||
|
||||
Returns:
|
||||
int: Number of seconds the timer was paused for.
|
||||
"""
|
||||
# Check: self.paused_at is None
|
||||
if self.paused_at is None:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Timer is not paused")
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign paused_seconds = max(int((now - self.paused_at).total_seconds()), 0)
|
||||
paused_seconds = max(int((now - self.paused_at).total_seconds()), 0)
|
||||
# Check: self.state == TestState.red_executing
|
||||
if self.state == TestState.red_executing:
|
||||
# Assign self.red_paused_seconds = paused_seconds
|
||||
self.red_paused_seconds += paused_seconds
|
||||
# Alternative: self.state == TestState.blue_evaluating
|
||||
elif self.state == TestState.blue_evaluating:
|
||||
# Assign self.blue_paused_seconds = paused_seconds
|
||||
self.blue_paused_seconds += paused_seconds
|
||||
# Assign self.paused_at = None
|
||||
self.paused_at = None
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("timer_resumed", {"paused_seconds": paused_seconds}))
|
||||
# Return paused_seconds
|
||||
return paused_seconds
|
||||
|
||||
# Define function validate_red
|
||||
def validate_red(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
|
||||
"""Record Red Lead's validation decision.
|
||||
|
||||
Args:
|
||||
status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``.
|
||||
by (uuid.UUID): UUID of the Red Lead recording the decision.
|
||||
notes (str | None): Optional free-text notes about the decision.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._assert_in_review()
|
||||
self._assert_in_review("red")
|
||||
# Call self._assert_valid_vote()
|
||||
self._assert_valid_vote(status)
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign self.red_validation_status = status
|
||||
self.red_validation_status = status
|
||||
# Assign self.red_validated_by = by
|
||||
self.red_validated_by = by
|
||||
# Assign self.red_validated_at = now
|
||||
self.red_validated_at = now
|
||||
# Assign self.red_validation_notes = notes
|
||||
self.red_validation_notes = notes
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("red_validated", {"status": status}))
|
||||
# Call self._check_dual_validation()
|
||||
self._check_dual_validation()
|
||||
|
||||
# Define function validate_blue
|
||||
def validate_blue(self, status: str, *, by: uuid.UUID, notes: str | None = None) -> None:
|
||||
"""Record Blue Lead's validation decision.
|
||||
|
||||
Args:
|
||||
status (str): Validation outcome; must be ``"approved"`` or ``"rejected"``.
|
||||
by (uuid.UUID): UUID of the Blue Lead recording the decision.
|
||||
notes (str | None): Optional free-text notes about the decision.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._assert_in_review()
|
||||
self._assert_in_review("blue")
|
||||
# Call self._assert_valid_vote()
|
||||
self._assert_valid_vote(status)
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign self.blue_validation_status = status
|
||||
self.blue_validation_status = status
|
||||
# Assign self.blue_validated_by = by
|
||||
self.blue_validated_by = by
|
||||
# Assign self.blue_validated_at = now
|
||||
self.blue_validated_at = now
|
||||
# Assign self.blue_validation_notes = notes
|
||||
self.blue_validation_notes = notes
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("blue_validated", {"status": status}))
|
||||
# Call self._check_dual_validation()
|
||||
self._check_dual_validation()
|
||||
|
||||
# Define function reopen
|
||||
def reopen(self) -> None:
|
||||
"""Transition the test from ``rejected`` back to ``draft``, clearing all validation and timing fields.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._transition()
|
||||
self._transition(TestState.draft)
|
||||
# Assign self.red_validation_status = None
|
||||
self.red_validation_status = None
|
||||
# Assign self.red_validated_by = None
|
||||
self.red_validated_by = None
|
||||
# Assign self.red_validated_at = None
|
||||
self.red_validated_at = None
|
||||
# Assign self.red_validation_notes = None
|
||||
self.red_validation_notes = None
|
||||
# Assign self.blue_validation_status = None
|
||||
self.blue_validation_status = None
|
||||
# Assign self.blue_validated_by = None
|
||||
self.blue_validated_by = None
|
||||
# Assign self.blue_validated_at = None
|
||||
self.blue_validated_at = None
|
||||
# Assign self.blue_validation_notes = None
|
||||
self.blue_validation_notes = None
|
||||
# Assign self.red_started_at = None
|
||||
self.red_started_at = None
|
||||
# Assign self.blue_started_at = None
|
||||
self.blue_started_at = None
|
||||
# Assign self.paused_at = None
|
||||
self.paused_at = None
|
||||
# Assign self.red_paused_seconds = 0
|
||||
self.red_paused_seconds = 0
|
||||
# Assign self.blue_paused_seconds = 0
|
||||
self.blue_paused_seconds = 0
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("test_reopened"))
|
||||
|
||||
# -- Private -------------------------------------------------------
|
||||
|
||||
def _auto_resume(self) -> int:
|
||||
"""Accumulate pause time and clear the paused timestamp if currently paused.
|
||||
|
||||
Returns:
|
||||
int: Extra seconds that were accumulated from the current pause, or 0
|
||||
if the timer was not paused.
|
||||
"""
|
||||
# Check: self.paused_at is None
|
||||
if self.paused_at is None:
|
||||
# Return 0
|
||||
return 0
|
||||
# Assign now = datetime.utcnow()
|
||||
now = datetime.utcnow()
|
||||
# Assign extra = max(int((now - self.paused_at).total_seconds()), 0)
|
||||
extra = max(int((now - self.paused_at).total_seconds()), 0)
|
||||
# Assign self.paused_at = None
|
||||
self.paused_at = None
|
||||
# Return extra
|
||||
return extra
|
||||
|
||||
# Define function check_dual_validation
|
||||
def check_dual_validation(self) -> None:
|
||||
"""Evaluate both leads' votes and advance state if appropriate.
|
||||
|
||||
- Both **approved** -> ``validated``
|
||||
- Either **rejected** -> ``rejected``
|
||||
- Otherwise no change (waiting for the other lead).
|
||||
|
||||
Called automatically by :meth:`validate_red` and :meth:`validate_blue`.
|
||||
Also available as a standalone entry point for backward compatibility
|
||||
when validation fields are set externally.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call self._check_dual_validation()
|
||||
self._check_dual_validation()
|
||||
|
||||
# Define function _assert_in_review
|
||||
def _assert_in_review(self, side: str) -> None:
|
||||
"""Raise InvalidOperationError unless the test is in ``in_review`` state.
|
||||
|
||||
Args:
|
||||
side (str): The team side being validated (``"red"`` or ``"blue"``),
|
||||
used in the error message.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: self.state != TestState.in_review
|
||||
if self.state != TestState.in_review:
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError(
|
||||
f"Cannot validate {side} side while test is in "
|
||||
f"'{self.state.value}' state (must be in_review)"
|
||||
)
|
||||
|
||||
# Apply the @staticmethod decorator
|
||||
@staticmethod
|
||||
# Define function _assert_valid_vote
|
||||
def _assert_valid_vote(status: str) -> None:
|
||||
"""Raise InvalidOperationError if *status* is not a valid vote value.
|
||||
|
||||
Args:
|
||||
status (str): The vote value to validate; must be ``"approved"`` or ``"rejected"``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: status not in ("approved", "rejected")
|
||||
if status not in ("approved", "rejected"):
|
||||
# Raise InvalidOperationError
|
||||
raise InvalidOperationError(
|
||||
# Literal argument value
|
||||
"validation_status must be 'approved' or 'rejected'"
|
||||
)
|
||||
|
||||
# Define function _check_dual_validation
|
||||
def _check_dual_validation(self) -> None:
|
||||
"""Advance to ``validated`` or ``rejected`` once both leads have voted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# r, b = self.red_validation_status, self.blue_validation_status
|
||||
r, b = self.red_validation_status, self.blue_validation_status
|
||||
# Check: r == "rejected" or b == "rejected"
|
||||
if r == "rejected" or b == "rejected":
|
||||
# Assign self.state = TestState.rejected
|
||||
self.state = TestState.rejected
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("dual_validation_rejected"))
|
||||
# Alternative: r == "approved" and b == "approved"
|
||||
elif r == "approved" and b == "approved":
|
||||
# Assign self.state = TestState.validated
|
||||
self.state = TestState.validated
|
||||
# Call self._events.append()
|
||||
self._events.append(DomainEvent("dual_validation_approved"))
|
||||
@@ -0,0 +1,103 @@
|
||||
"""Unit of Work — wraps a SQLAlchemy session for explicit transaction control.
|
||||
|
||||
Usage in routers::
|
||||
|
||||
with UnitOfWork(db) as uow:
|
||||
service_a(db, ...)
|
||||
service_b(db, ...)
|
||||
uow.commit() # single commit for the entire operation
|
||||
|
||||
If an exception propagates, ``__exit__`` issues a rollback automatically.
|
||||
Services should **never** call ``db.commit()``; they use ``db.add()`` /
|
||||
``db.flush()`` to stage work and let the caller decide when to commit.
|
||||
|
||||
**Documented exceptions** (services that may commit internally):
|
||||
- Import services (atomic_import, sigma_import, etc.) — self-contained sync ops.
|
||||
- Background jobs (campaign_scheduler, intel_service, stale_detection,
|
||||
mitre_sync) — self-contained operations.
|
||||
- Self-contained batch ops (e.g. detection_rule_service.auto_associate_rules,
|
||||
snapshot_service.create_snapshot, campaign_service.generate_campaign_from_*,
|
||||
osint_enrichment_service.enrich_technique_with_cves).
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import TracebackType from types
|
||||
from types import TracebackType
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
# Define class UnitOfWork
|
||||
class UnitOfWork:
|
||||
"""Lightweight transaction wrapper around an existing SQLAlchemy session."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, session: Session) -> None:
|
||||
"""Wrap an existing SQLAlchemy session in a Unit of Work.
|
||||
|
||||
Args:
|
||||
session (Session): The active SQLAlchemy session to manage.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign self._session = session
|
||||
self._session = session
|
||||
|
||||
# -- context manager -----------------------------------------------------
|
||||
|
||||
def __enter__(self) -> "UnitOfWork":
|
||||
"""Enter the runtime context, returning this UnitOfWork instance.
|
||||
|
||||
Returns:
|
||||
UnitOfWork: The UnitOfWork itself, for use in ``with`` statements.
|
||||
"""
|
||||
# Return self
|
||||
return self
|
||||
|
||||
# Define function __exit__
|
||||
def __exit__(
|
||||
self,
|
||||
# Entry: exc_type
|
||||
exc_type: type[BaseException] | None,
|
||||
# Entry: exc_val
|
||||
exc_val: BaseException | None,
|
||||
# Entry: exc_tb
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the runtime context, rolling back if an exception propagated.
|
||||
|
||||
Args:
|
||||
exc_type (type[BaseException] | None): Exception class, if raised.
|
||||
exc_val (BaseException | None): Exception instance, if raised.
|
||||
exc_tb (TracebackType | None): Traceback object, if an exception was raised.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: exc_type is not None
|
||||
if exc_type is not None:
|
||||
# Call self.rollback()
|
||||
self.rollback()
|
||||
|
||||
# -- public API ----------------------------------------------------------
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Flush pending changes and commit the transaction."""
|
||||
# Call self._session.commit()
|
||||
self._session.commit()
|
||||
|
||||
# Define function rollback
|
||||
def rollback(self) -> None:
|
||||
"""Roll back the current transaction."""
|
||||
# Call self._session.rollback()
|
||||
self._session.rollback()
|
||||
|
||||
# Define function flush
|
||||
def flush(self) -> None:
|
||||
"""Flush pending changes without committing (useful for getting IDs)."""
|
||||
# Call self._session.flush()
|
||||
self._session.flush()
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Immutable domain value objects."""
|
||||
# Import MitreId from app.domain.value_objects.mitre_id
|
||||
from app.domain.value_objects.mitre_id import MitreId
|
||||
|
||||
# Import ScoringWeights from app.domain.value_objects.scoring_weights
|
||||
from app.domain.value_objects.scoring_weights import ScoringWeights
|
||||
|
||||
# Assign __all__ = ["MitreId", "ScoringWeights"]
|
||||
__all__ = ["MitreId", "ScoringWeights"]
|
||||
@@ -0,0 +1,115 @@
|
||||
"""MitreId — validated MITRE ATT&CK technique identifier.
|
||||
|
||||
Immutable value object that ensures the identifier follows the ATT&CK
|
||||
format: ``T`` followed by 4 digits, optionally a dot and 3 more digits
|
||||
for sub-techniques (e.g. ``T1059``, ``T1059.001``).
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import re
|
||||
import re
|
||||
|
||||
# Import dataclass from dataclasses
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Assign _MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
|
||||
_MITRE_ID_RE = re.compile(r"^T\d{4}(\.\d{3})?$")
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass(frozen=True, slots=True)
|
||||
# Define class MitreId
|
||||
class MitreId:
|
||||
"""Validated MITRE ATT&CK technique identifier."""
|
||||
|
||||
# value: str
|
||||
value: str
|
||||
|
||||
# Define function __post_init__
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that *value* matches the expected MITRE ATT&CK ID format.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check: not _MITRE_ID_RE.match(self.value)
|
||||
if not _MITRE_ID_RE.match(self.value):
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"Invalid MITRE ATT&CK ID '{self.value}'. "
|
||||
# Literal argument value
|
||||
"Expected format: T1234 or T1234.001"
|
||||
)
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function is_subtechnique
|
||||
def is_subtechnique(self) -> bool:
|
||||
"""Return True if this identifier represents a sub-technique.
|
||||
|
||||
Returns:
|
||||
bool: True when the ID contains a dot (e.g. ``T1059.001``).
|
||||
"""
|
||||
# Return "." in self.value
|
||||
return "." in self.value
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function parent_id
|
||||
def parent_id(self) -> str | None:
|
||||
"""Return the parent technique ID (e.g. ``T1059`` for ``T1059.001``).
|
||||
|
||||
Returns:
|
||||
str | None: The parent ID string, or None if this is not a sub-technique.
|
||||
"""
|
||||
# Check: not self.is_subtechnique
|
||||
if not self.is_subtechnique:
|
||||
# Return None
|
||||
return None
|
||||
# Return self.value.split(".")[0]
|
||||
return self.value.split(".")[0]
|
||||
|
||||
# Define function __str__
|
||||
def __str__(self) -> str:
|
||||
"""Return the string representation of the MITRE ID.
|
||||
|
||||
Returns:
|
||||
str: The raw identifier string (e.g. ``"T1059.001"``).
|
||||
"""
|
||||
# Return self.value
|
||||
return self.value
|
||||
|
||||
# Define function __eq__
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare this MitreId to another MitreId or a plain string.
|
||||
|
||||
Args:
|
||||
other (object): The value to compare against; may be a
|
||||
:class:`MitreId` instance or a plain ``str``.
|
||||
|
||||
Returns:
|
||||
bool: True if the identifiers are equal, NotImplemented for
|
||||
unsupported types.
|
||||
"""
|
||||
# Check: isinstance(other, MitreId)
|
||||
if isinstance(other, MitreId):
|
||||
# Return self.value == other.value
|
||||
return self.value == other.value
|
||||
# Check: isinstance(other, str)
|
||||
if isinstance(other, str):
|
||||
# Return self.value == other
|
||||
return self.value == other
|
||||
# Return NotImplemented
|
||||
return NotImplemented
|
||||
|
||||
# Define function __hash__
|
||||
def __hash__(self) -> int:
|
||||
"""Return the hash of the identifier string.
|
||||
|
||||
Returns:
|
||||
int: Hash value derived from the raw identifier string.
|
||||
"""
|
||||
# Return hash(self.value)
|
||||
return hash(self.value)
|
||||
@@ -0,0 +1,107 @@
|
||||
"""ScoringWeights — validated immutable weight set for the scoring engine.
|
||||
|
||||
Enforces that all five weights are non-negative and sum to exactly 100.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import dataclass from dataclasses
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# Apply the @dataclass decorator
|
||||
@dataclass(frozen=True, slots=True)
|
||||
# Define class ScoringWeights
|
||||
class ScoringWeights:
|
||||
"""Five scoring dimension weights that must sum to 100."""
|
||||
|
||||
# tests: float
|
||||
tests: float
|
||||
# detection_rules: float
|
||||
detection_rules: float
|
||||
# d3fend: float
|
||||
d3fend: float
|
||||
# recency: float
|
||||
recency: float
|
||||
# severity: float
|
||||
severity: float
|
||||
|
||||
# Define function __post_init__
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that all weights are non-negative and sum to exactly 100.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Assign fields = [
|
||||
fields = [
|
||||
self.tests,
|
||||
self.detection_rules,
|
||||
self.d3fend,
|
||||
self.recency,
|
||||
self.severity,
|
||||
]
|
||||
# Iterate over fields
|
||||
for f in fields:
|
||||
# Check: f < 0
|
||||
if f < 0:
|
||||
# Raise ValueError
|
||||
raise ValueError("Scoring weights must be non-negative")
|
||||
|
||||
# Assign total = sum(fields)
|
||||
total = sum(fields)
|
||||
# Check: abs(total - 100) > 0.01
|
||||
if abs(total - 100) > 0.01:
|
||||
# Raise ValueError
|
||||
raise ValueError(
|
||||
f"Scoring weights must sum to 100, got {total}"
|
||||
)
|
||||
|
||||
# Apply the @classmethod decorator
|
||||
@classmethod
|
||||
# Define function default
|
||||
def default(cls) -> ScoringWeights:
|
||||
"""Return the default weight distribution.
|
||||
|
||||
Returns:
|
||||
ScoringWeights: A weight set with tests=40, detection_rules=25,
|
||||
d3fend=15, recency=10, severity=10.
|
||||
"""
|
||||
# Return cls(
|
||||
return cls(
|
||||
# Keyword argument: tests
|
||||
tests=40.0,
|
||||
# Keyword argument: detection_rules
|
||||
detection_rules=25.0,
|
||||
# Keyword argument: d3fend
|
||||
d3fend=15.0,
|
||||
# Keyword argument: recency
|
||||
recency=10.0,
|
||||
# Keyword argument: severity
|
||||
severity=10.0,
|
||||
)
|
||||
|
||||
# Backward-compatible aliases for older API payloads
|
||||
@property
|
||||
# Define function freshness
|
||||
def freshness(self) -> float:
|
||||
"""Return the recency weight (backward-compatible alias).
|
||||
|
||||
Returns:
|
||||
float: The value of the ``recency`` weight.
|
||||
"""
|
||||
# Return self.recency
|
||||
return self.recency
|
||||
|
||||
# Apply the @property decorator
|
||||
@property
|
||||
# Define function platform_diversity
|
||||
def platform_diversity(self) -> float:
|
||||
"""Return the severity weight (backward-compatible alias).
|
||||
|
||||
Returns:
|
||||
float: The value of the ``severity`` weight.
|
||||
"""
|
||||
# Return self.severity
|
||||
return self.severity
|
||||
@@ -0,0 +1 @@
|
||||
"""Infrastructure adapters — persistence, caching, and external services."""
|
||||
@@ -0,0 +1 @@
|
||||
"""SQLAlchemy-based persistence adapters for the domain repository ports."""
|
||||
@@ -0,0 +1 @@
|
||||
"""ORM-to-domain entity mapper functions."""
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Technique ORM model <-> domain entity mapper."""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
|
||||
|
||||
# Define class TechniqueMapper
|
||||
class TechniqueMapper:
|
||||
"""Converts between SQLAlchemy Technique model and TechniqueEntity."""
|
||||
|
||||
# Apply the @staticmethod decorator
|
||||
@staticmethod
|
||||
# Define function to_entity
|
||||
def to_entity(model: object) -> TechniqueEntity:
|
||||
"""Convert an ORM Technique model to a domain TechniqueEntity."""
|
||||
# Return TechniqueEntity.from_orm(model)
|
||||
return TechniqueEntity.from_orm(model)
|
||||
|
||||
# Apply the @staticmethod decorator
|
||||
@staticmethod
|
||||
# Define function to_model_updates
|
||||
def to_model_updates(entity: TechniqueEntity, model: object) -> None:
|
||||
"""Apply entity changes back onto an existing ORM model."""
|
||||
# Call entity.apply_to()
|
||||
entity.apply_to(model)
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Concrete SQLAlchemy repository implementations."""
|
||||
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
|
||||
# Import from app.infrastructure.persistence.repositories.sa_test_repository
|
||||
from app.infrastructure.persistence.repositories.sa_test_repository import (
|
||||
SATestRepository,
|
||||
)
|
||||
|
||||
# Assign __all__ = ["SATechniqueRepository", "SATestRepository"]
|
||||
__all__ = ["SATechniqueRepository", "SATestRepository"]
|
||||
@@ -0,0 +1,380 @@
|
||||
"""SQLAlchemy implementation of TechniqueRepository.
|
||||
|
||||
Receives a Session from the caller — does NOT create its own.
|
||||
Does NOT call commit() — the Unit of Work owns that.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
|
||||
# Import TechniqueStatus, TestState from app.domain.enums
|
||||
from app.domain.enums import TechniqueStatus, TestState
|
||||
|
||||
# Import TechniqueWithCounts from app.domain.ports.repositories.technique_repository
|
||||
from app.domain.ports.repositories.technique_repository import TechniqueWithCounts
|
||||
|
||||
# Import TechniqueMapper from app.infrastructure.persistence.mappers.technique_mapper
|
||||
from app.infrastructure.persistence.mappers.technique_mapper import TechniqueMapper
|
||||
|
||||
# Import DetectionRule from app.models.detection_rule
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
|
||||
# Define class SATechniqueRepository
|
||||
class SATechniqueRepository:
|
||||
"""Concrete repository backed by SQLAlchemy."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, session: Session) -> None:
|
||||
"""Initialise the repository with a caller-provided session.
|
||||
|
||||
Args:
|
||||
session (Session): The SQLAlchemy session to use for all queries.
|
||||
"""
|
||||
# Assign self._session = session
|
||||
self._session = session
|
||||
|
||||
# -- Single-entity access ----------------------------------------------
|
||||
|
||||
def find_by_id(self, technique_id: uuid.UUID) -> TechniqueEntity | None:
|
||||
"""Return a single technique by its primary key.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): The UUID primary key of the technique.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity | None: The matching entity, or ``None`` if not found.
|
||||
"""
|
||||
# Assign model = (
|
||||
model = (
|
||||
self._session.query(Technique)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.id == technique_id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Return TechniqueMapper.to_entity(model) if model else None
|
||||
return TechniqueMapper.to_entity(model) if model else None
|
||||
|
||||
# Define function find_by_mitre_id
|
||||
def find_by_mitre_id(self, mitre_id: str) -> TechniqueEntity | None:
|
||||
"""Return a single technique by its MITRE ATT&CK ID (e.g. ``T1059.001``).
|
||||
|
||||
Args:
|
||||
mitre_id (str): The MITRE ATT&CK identifier string.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity | None: The matching entity, or ``None`` if not found.
|
||||
"""
|
||||
# Assign model = (
|
||||
model = (
|
||||
self._session.query(Technique)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Return TechniqueMapper.to_entity(model) if model else None
|
||||
return TechniqueMapper.to_entity(model) if model else None
|
||||
|
||||
# -- List access -------------------------------------------------------
|
||||
|
||||
def list_all(
|
||||
self,
|
||||
*,
|
||||
# Entry: tactic
|
||||
tactic: str | None = None,
|
||||
# Entry: status
|
||||
status: TechniqueStatus | None = None,
|
||||
# Entry: review_required
|
||||
review_required: bool | None = None,
|
||||
) -> list[TechniqueEntity]:
|
||||
"""Return all techniques, optionally filtered by tactic, status, or review flag.
|
||||
|
||||
Args:
|
||||
tactic (str | None): Filter to techniques belonging to this tactic name.
|
||||
status (TechniqueStatus | None): Filter to techniques with this coverage status.
|
||||
review_required (bool | None): Filter to techniques where ``review_required`` matches.
|
||||
|
||||
Returns:
|
||||
list[TechniqueEntity]: Ordered list of matching technique entities.
|
||||
"""
|
||||
# Assign query = self._session.query(Technique)
|
||||
query = self._session.query(Technique)
|
||||
# Check: tactic is not None
|
||||
if tactic is not None:
|
||||
# Assign query = query.filter(Technique.tactic == tactic)
|
||||
query = query.filter(Technique.tactic == tactic)
|
||||
# Check: status is not None
|
||||
if status is not None:
|
||||
# Assign query = query.filter(Technique.status_global == status)
|
||||
query = query.filter(Technique.status_global == status)
|
||||
# Check: review_required is not None
|
||||
if review_required is not None:
|
||||
# Assign query = query.filter(Technique.review_required == review_required)
|
||||
query = query.filter(Technique.review_required == review_required)
|
||||
# Assign models = query.order_by(Technique.mitre_id).all()
|
||||
models = query.order_by(Technique.mitre_id).all()
|
||||
# Return [TechniqueMapper.to_entity(m) for m in models]
|
||||
return [TechniqueMapper.to_entity(m) for m in models]
|
||||
|
||||
# Define function list_by_ids
|
||||
def list_by_ids(self, ids: list[uuid.UUID]) -> list[TechniqueEntity]:
|
||||
"""Return techniques matching the provided list of UUIDs.
|
||||
|
||||
Args:
|
||||
ids (list[uuid.UUID]): UUIDs of the techniques to retrieve.
|
||||
|
||||
Returns:
|
||||
list[TechniqueEntity]: Technique entities corresponding to the given IDs.
|
||||
"""
|
||||
# Check: not ids
|
||||
if not ids:
|
||||
# Return []
|
||||
return []
|
||||
# Assign models = (
|
||||
models = (
|
||||
self._session.query(Technique)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.id.in_(ids))
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [TechniqueMapper.to_entity(m) for m in models]
|
||||
return [TechniqueMapper.to_entity(m) for m in models]
|
||||
|
||||
# -- Batch queries (for scoring/heatmap) -------------------------------
|
||||
|
||||
def count_by_status(self) -> dict[TechniqueStatus, int]:
|
||||
"""Return a count of techniques grouped by their coverage status.
|
||||
|
||||
Returns:
|
||||
dict[TechniqueStatus, int]: Mapping of each status value to its technique count.
|
||||
"""
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
self._session.query(
|
||||
Technique.status_global,
|
||||
func.count(Technique.id),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(Technique.status_global)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign result = {s: 0 for s in TechniqueStatus}
|
||||
result = {s: 0 for s in TechniqueStatus}
|
||||
# Iterate over rows
|
||||
for status_val, count in rows:
|
||||
# Assign key = (
|
||||
key = (
|
||||
status_val
|
||||
if isinstance(status_val, TechniqueStatus)
|
||||
else TechniqueStatus(status_val)
|
||||
)
|
||||
# Assign result[key] = count
|
||||
result[key] = count
|
||||
# Return result
|
||||
return result
|
||||
|
||||
# Define function find_all_with_test_counts
|
||||
def find_all_with_test_counts(self) -> list[TechniqueWithCounts]:
|
||||
"""Return all techniques with pre-aggregated test and detection rule counts.
|
||||
|
||||
Uses a single query with subqueries to avoid the N+1 pattern.
|
||||
|
||||
Returns:
|
||||
list[TechniqueWithCounts]: All techniques with their associated counts.
|
||||
"""
|
||||
# Assign test_count_sq = (
|
||||
test_count_sq = (
|
||||
self._session.query(
|
||||
Test.technique_id,
|
||||
func.count(Test.id).label("test_count"),
|
||||
func.sum(
|
||||
func.cast(Test.state == TestState.validated, self._int_type())
|
||||
).label("validated_count"),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(Test.technique_id)
|
||||
# Chain .subquery() call
|
||||
.subquery()
|
||||
)
|
||||
# Assign rule_count_sq = (
|
||||
rule_count_sq = (
|
||||
self._session.query(
|
||||
DetectionRule.mitre_technique_id,
|
||||
func.count(DetectionRule.id).label("rule_count"),
|
||||
)
|
||||
# Chain .group_by() call
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
# Chain .subquery() call
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
self._session.query(
|
||||
Technique,
|
||||
func.coalesce(test_count_sq.c.test_count, 0),
|
||||
func.coalesce(test_count_sq.c.validated_count, 0),
|
||||
func.coalesce(rule_count_sq.c.rule_count, 0),
|
||||
)
|
||||
# Chain .outerjoin() call
|
||||
.outerjoin(test_count_sq, Technique.id == test_count_sq.c.technique_id)
|
||||
# Chain .outerjoin() call
|
||||
.outerjoin(
|
||||
rule_count_sq,
|
||||
Technique.mitre_id == rule_count_sq.c.mitre_technique_id,
|
||||
)
|
||||
# Chain .order_by() call
|
||||
.order_by(Technique.mitre_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Return [
|
||||
return [
|
||||
TechniqueWithCounts(
|
||||
# Keyword argument: entity
|
||||
entity=TechniqueMapper.to_entity(tech),
|
||||
# Keyword argument: test_count
|
||||
test_count=int(tc),
|
||||
# Keyword argument: validated_test_count
|
||||
validated_test_count=int(vtc),
|
||||
# Keyword argument: detection_rule_count
|
||||
detection_rule_count=int(rc),
|
||||
)
|
||||
for tech, tc, vtc, rc in rows
|
||||
]
|
||||
|
||||
# -- Mutations ---------------------------------------------------------
|
||||
|
||||
def save(self, technique: TechniqueEntity) -> TechniqueEntity:
|
||||
"""Persist a technique entity, inserting or updating as needed.
|
||||
|
||||
Args:
|
||||
technique (TechniqueEntity): The domain entity to persist.
|
||||
|
||||
Returns:
|
||||
TechniqueEntity: The persisted entity reflecting the current DB state.
|
||||
"""
|
||||
# Assign existing = (
|
||||
existing = (
|
||||
self._session.query(Technique)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.id == technique.id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Check: existing
|
||||
if existing:
|
||||
# Call technique.apply_to()
|
||||
technique.apply_to(existing)
|
||||
# Assign existing.mitre_id = technique.mitre_id
|
||||
existing.mitre_id = technique.mitre_id
|
||||
# Assign existing.name = technique.name
|
||||
existing.name = technique.name
|
||||
# Assign existing.tactic = technique.tactic
|
||||
existing.tactic = technique.tactic
|
||||
# Assign existing.description = technique.description
|
||||
existing.description = technique.description
|
||||
# Assign existing.platforms = technique.platforms
|
||||
existing.platforms = technique.platforms
|
||||
# Assign existing.is_subtechnique = technique.is_subtechnique
|
||||
existing.is_subtechnique = technique.is_subtechnique
|
||||
# Assign existing.parent_mitre_id = technique.parent_mitre_id
|
||||
existing.parent_mitre_id = technique.parent_mitre_id
|
||||
# Assign existing.mitre_version = technique.mitre_version
|
||||
existing.mitre_version = technique.mitre_version
|
||||
# Assign existing.mitre_last_modified = technique.mitre_last_modified
|
||||
existing.mitre_last_modified = technique.mitre_last_modified
|
||||
# Call self._session.flush()
|
||||
self._session.flush()
|
||||
# Return TechniqueMapper.to_entity(existing)
|
||||
return TechniqueMapper.to_entity(existing)
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign model = Technique(
|
||||
model = Technique(
|
||||
# Keyword argument: id
|
||||
id=technique.id,
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=technique.mitre_id,
|
||||
# Keyword argument: name
|
||||
name=technique.name,
|
||||
# Keyword argument: tactic
|
||||
tactic=technique.tactic,
|
||||
# Keyword argument: description
|
||||
description=technique.description,
|
||||
# Keyword argument: platforms
|
||||
platforms=technique.platforms,
|
||||
# Keyword argument: is_subtechnique
|
||||
is_subtechnique=technique.is_subtechnique,
|
||||
# Keyword argument: parent_mitre_id
|
||||
parent_mitre_id=technique.parent_mitre_id,
|
||||
# Keyword argument: status_global
|
||||
status_global=technique.status_global,
|
||||
# Keyword argument: review_required
|
||||
review_required=technique.review_required,
|
||||
# Keyword argument: last_review_date
|
||||
last_review_date=technique.last_review_date,
|
||||
# Keyword argument: mitre_version
|
||||
mitre_version=technique.mitre_version,
|
||||
# Keyword argument: mitre_last_modified
|
||||
mitre_last_modified=technique.mitre_last_modified,
|
||||
)
|
||||
# Call self._session.add()
|
||||
self._session.add(model)
|
||||
# Call self._session.flush()
|
||||
self._session.flush()
|
||||
# Return TechniqueMapper.to_entity(model)
|
||||
return TechniqueMapper.to_entity(model)
|
||||
|
||||
# Define function exists_by_mitre_id
|
||||
def exists_by_mitre_id(self, mitre_id: str) -> bool:
|
||||
"""Check whether a technique with the given MITRE ID already exists.
|
||||
|
||||
Args:
|
||||
mitre_id (str): The MITRE ATT&CK identifier to look up.
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if the technique exists, ``False`` otherwise.
|
||||
"""
|
||||
# Return (
|
||||
return (
|
||||
self._session.query(Technique.id)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
) is not None
|
||||
|
||||
# -- Internal ----------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
# Define function _int_type
|
||||
def _int_type() -> type:
|
||||
"""Return an Integer type for CAST expressions (SQLite-compatible)."""
|
||||
# Import Integer from sqlalchemy
|
||||
from sqlalchemy import Integer
|
||||
# Return Integer
|
||||
return Integer
|
||||
@@ -0,0 +1,171 @@
|
||||
"""SQLAlchemy implementation of TestRepository."""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import func from sqlalchemy
|
||||
from sqlalchemy import func
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import TestState from app.domain.enums
|
||||
from app.domain.enums import TestState
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
|
||||
# Define class SATestRepository
|
||||
class SATestRepository:
|
||||
"""Concrete test repository backed by SQLAlchemy."""
|
||||
|
||||
# Define function __init__
|
||||
def __init__(self, session: Session) -> None:
|
||||
"""Initialise the repository with a caller-provided session.
|
||||
|
||||
Args:
|
||||
session (Session): The SQLAlchemy session to use for all queries.
|
||||
"""
|
||||
# Assign self._session = session
|
||||
self._session = session
|
||||
|
||||
# Define function find_by_id
|
||||
def find_by_id(self, test_id: uuid.UUID) -> Test | None:
|
||||
"""Return a single test by its primary key.
|
||||
|
||||
Args:
|
||||
test_id (uuid.UUID): The UUID primary key of the test.
|
||||
|
||||
Returns:
|
||||
Test | None: The ORM model instance, or ``None`` if not found.
|
||||
"""
|
||||
# Return (
|
||||
return (
|
||||
self._session.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(Test.id == test_id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
|
||||
# Define function list_by_technique
|
||||
def list_by_technique(self, technique_id: uuid.UUID) -> list[Test]:
|
||||
"""Return all tests for a given technique, ordered by creation date.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): The UUID of the parent technique.
|
||||
|
||||
Returns:
|
||||
list[Test]: ORM model instances ordered by ``created_at`` ascending.
|
||||
"""
|
||||
# Return (
|
||||
return (
|
||||
self._session.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id == technique_id)
|
||||
# Chain .order_by() call
|
||||
.order_by(Test.created_at)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Define function list_by_state
|
||||
def list_by_state(self, state: TestState) -> list[Test]:
|
||||
"""Return all tests that are currently in the given workflow state.
|
||||
|
||||
Args:
|
||||
state (TestState): The workflow state to filter on.
|
||||
|
||||
Returns:
|
||||
list[Test]: All ORM model instances with the specified state.
|
||||
"""
|
||||
# Return (
|
||||
return (
|
||||
self._session.query(Test)
|
||||
# Chain .filter() call
|
||||
.filter(Test.state == state)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Define function count_by_technique_and_state
|
||||
def count_by_technique_and_state(
|
||||
self,
|
||||
# Entry: technique_id
|
||||
technique_id: uuid.UUID,
|
||||
) -> dict[TestState, int]:
|
||||
"""Return per-state test counts for a specific technique.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): The UUID of the technique to aggregate for.
|
||||
|
||||
Returns:
|
||||
dict[TestState, int]: Mapping of each state to the number of tests in that state.
|
||||
"""
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
self._session.query(Test.state, func.count(Test.id))
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id == technique_id)
|
||||
# Chain .group_by() call
|
||||
.group_by(Test.state)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Assign result = {}
|
||||
result: dict[TestState, int] = {}
|
||||
# Iterate over rows
|
||||
for state_val, count in rows:
|
||||
# Assign key = (
|
||||
key = (
|
||||
state_val
|
||||
if isinstance(state_val, TestState)
|
||||
else TestState(state_val)
|
||||
)
|
||||
# Assign result[key] = count
|
||||
result[key] = count
|
||||
# Return result
|
||||
return result
|
||||
|
||||
# Define function get_states_and_results_for_technique
|
||||
def get_states_and_results_for_technique(
|
||||
self,
|
||||
# Entry: technique_id
|
||||
technique_id: uuid.UUID,
|
||||
) -> list[tuple[str, str | None]]:
|
||||
"""Return lightweight ``(state, detection_result)`` pairs for a technique.
|
||||
|
||||
Used by ``TechniqueEntity.recalculate_status()`` to avoid loading full
|
||||
``Test`` models.
|
||||
|
||||
Args:
|
||||
technique_id (uuid.UUID): The UUID of the technique to query.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, str | None]]: Each tuple contains the state string
|
||||
and the detection result string (or ``None``).
|
||||
"""
|
||||
# Assign rows = (
|
||||
rows = (
|
||||
self._session.query(Test.state, Test.detection_result)
|
||||
# Chain .filter() call
|
||||
.filter(Test.technique_id == technique_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
# Return [
|
||||
return [
|
||||
(
|
||||
r.state.value if hasattr(r.state, "value") else str(r.state),
|
||||
(
|
||||
r.detection_result.value
|
||||
if hasattr(r.detection_result, "value")
|
||||
else r.detection_result
|
||||
),
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Redis client factories.
|
||||
|
||||
``settings.REDIS_URL`` selects the default logical database (usually ``0``).
|
||||
Token blacklist and application cache use separate logical DBs on the same
|
||||
Redis instance (``REDIS_TOKEN_BLACKLIST_DB``, ``REDIS_CACHE_DB``) so keys never
|
||||
collide and TTL policies can differ per workload.
|
||||
|
||||
Usage::
|
||||
|
||||
from app.infrastructure.redis_client import get_redis, get_redis_blacklist
|
||||
|
||||
get_redis().set("key", "value", ex=300)
|
||||
get_redis_blacklist().setex("blacklist:…", ttl, "1")
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import urlparse, urlunparse from urllib.parse
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
# Import redis
|
||||
import redis
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign _clients = {}
|
||||
_clients: dict[str, redis.Redis] = {}
|
||||
|
||||
|
||||
# Define function _redis_url_with_db
|
||||
def _redis_url_with_db(base_url: str, db_index: int) -> str:
|
||||
"""Return *base_url* with its path replaced by ``/{db_index}``."""
|
||||
# Assign parsed = urlparse(base_url)
|
||||
parsed = urlparse(base_url)
|
||||
# Assign path = f"/{db_index}"
|
||||
path = f"/{db_index}"
|
||||
# Return urlunparse(
|
||||
return urlunparse(
|
||||
(parsed.scheme, parsed.netloc, path, "", "", ""),
|
||||
)
|
||||
|
||||
|
||||
# Define function _get_client
|
||||
def _get_client(url: str) -> redis.Redis:
|
||||
# Check: url not in _clients
|
||||
if url not in _clients:
|
||||
# Assign _clients[url] = redis.from_url(url, decode_responses=True)
|
||||
_clients[url] = redis.from_url(url, decode_responses=True)
|
||||
# Log info: "Redis client connected to %s", url
|
||||
logger.info("Redis client connected to %s", url)
|
||||
# Return _clients[url]
|
||||
return _clients[url]
|
||||
|
||||
|
||||
# Define function get_redis
|
||||
def get_redis() -> redis.Redis:
|
||||
"""Default Redis connection (URL from ``settings.REDIS_URL``)."""
|
||||
# Return _get_client(settings.REDIS_URL)
|
||||
return _get_client(settings.REDIS_URL)
|
||||
|
||||
|
||||
# Define function get_redis_blacklist
|
||||
def get_redis_blacklist() -> redis.Redis:
|
||||
"""Redis DB used for JWT revocation (``jti`` keys with TTL)."""
|
||||
# Assign url = _redis_url_with_db(
|
||||
url = _redis_url_with_db(
|
||||
settings.REDIS_URL,
|
||||
settings.REDIS_TOKEN_BLACKLIST_DB,
|
||||
)
|
||||
# Return _get_client(url)
|
||||
return _get_client(url)
|
||||
|
||||
|
||||
# Define function get_redis_cache
|
||||
def get_redis_cache() -> redis.Redis:
|
||||
"""Redis DB reserved for shared cache (scores, queues, etc.)."""
|
||||
# Assign url = _redis_url_with_db(
|
||||
url = _redis_url_with_db(
|
||||
settings.REDIS_URL,
|
||||
settings.REDIS_CACHE_DB,
|
||||
)
|
||||
# Return _get_client(url)
|
||||
return _get_client(url)
|
||||
@@ -0,0 +1 @@
|
||||
"""Background scheduler jobs (MITRE sync, Jira sync, data retention)."""
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Scheduled job — syncs all Jira links hourly."""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import SessionLocal from app.database
|
||||
from app.database import SessionLocal
|
||||
|
||||
# Import JiraLink from app.models.jira_link
|
||||
from app.models.jira_link import JiraLink
|
||||
|
||||
# Import jira_service from app.services
|
||||
from app.services import jira_service
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define function sync_all_jira_links
|
||||
def sync_all_jira_links() -> None:
|
||||
"""Pull latest status from Jira for every stored link.
|
||||
|
||||
Silently skips if ``JIRA_ENABLED`` is ``False``. Individual link
|
||||
failures are logged but do not abort the rest of the batch.
|
||||
"""
|
||||
# Check: not settings.JIRA_ENABLED
|
||||
if not settings.JIRA_ENABLED:
|
||||
# Return control to caller
|
||||
return
|
||||
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign links = db.query(JiraLink).all()
|
||||
links = db.query(JiraLink).all()
|
||||
# Assign synced = 0
|
||||
synced = 0
|
||||
# Iterate over links
|
||||
for link in links:
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Call jira_service.sync_jira_to_aegis()
|
||||
jira_service.sync_jira_to_aegis(db, link)
|
||||
# Assign synced = 1
|
||||
synced += 1
|
||||
# Handle Exception
|
||||
except Exception as e:
|
||||
# Log warning: "Jira sync failed for link %s: %s", link.id, e
|
||||
logger.warning("Jira sync failed for link %s: %s", link.id, e)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Log info: "Jira sync completed: %d/%d links updated", synced
|
||||
logger.info("Jira sync completed: %d/%d links updated", synced, len(links))
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log exception: "Jira sync batch job failed"
|
||||
logger.exception("Jira sync batch job failed")
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
@@ -10,17 +10,43 @@ Each job manages its own database session (created on entry, closed in
|
||||
sessions.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import BackgroundScheduler from apscheduler.schedulers.background
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
|
||||
# Import SessionLocal from app.database
|
||||
from app.database import SessionLocal
|
||||
from app.services.mitre_sync_service import sync_mitre
|
||||
from app.services.intel_service import scan_intel
|
||||
from app.services.notification_service import cleanup_old_notifications
|
||||
from app.services.snapshot_service import create_snapshot, cleanup_old_snapshots
|
||||
|
||||
# Import sync_all_jira_links from app.jobs.jira_sync_job
|
||||
from app.jobs.jira_sync_job import sync_all_jira_links
|
||||
|
||||
# Import run_retention_job from app.jobs.retention_job
|
||||
from app.jobs.retention_job import run_retention_job
|
||||
|
||||
# Import check_and_run_recurring_campaigns from app.services.campaign_scheduler_service
|
||||
from app.services.campaign_scheduler_service import check_and_run_recurring_campaigns
|
||||
|
||||
# Import scan_intel from app.services.intel_service
|
||||
from app.services.intel_service import scan_intel
|
||||
|
||||
# Import sync_mitre from app.services.mitre_sync_service
|
||||
from app.services.mitre_sync_service import sync_mitre
|
||||
|
||||
# Import cleanup_old_notifications from app.services.notification_service
|
||||
from app.services.notification_service import cleanup_old_notifications
|
||||
|
||||
# Import enrich_all_techniques from app.services.osint_enrichment_service
|
||||
from app.services.osint_enrichment_service import enrich_all_techniques
|
||||
|
||||
# Import cleanup_old_snapshots, create_snapshot from app.services.snapshot_service
|
||||
from app.services.snapshot_service import cleanup_old_snapshots, create_snapshot
|
||||
|
||||
# Import detect_stale_coverage from app.services.stale_detection_service
|
||||
from app.services.stale_detection_service import detect_stale_coverage
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -37,73 +63,172 @@ scheduler = BackgroundScheduler()
|
||||
|
||||
def _run_mitre_sync() -> None:
|
||||
"""Execute a MITRE sync inside its own DB session."""
|
||||
# Log info: "Scheduled MITRE sync job starting..."
|
||||
logger.info("Scheduled MITRE sync job starting...")
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign summary = sync_mitre(db)
|
||||
summary = sync_mitre(db)
|
||||
# Log info: "Scheduled MITRE sync job finished — %s", summary
|
||||
logger.info("Scheduled MITRE sync job finished — %s", summary)
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log exception: "Scheduled MITRE sync job failed"
|
||||
logger.exception("Scheduled MITRE sync job failed")
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
|
||||
|
||||
# Define function _run_notification_cleanup
|
||||
def _run_notification_cleanup() -> None:
|
||||
"""Clean up old read notifications."""
|
||||
# Log info: "Scheduled notification cleanup job starting..."
|
||||
logger.info("Scheduled notification cleanup job starting...")
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign deleted = cleanup_old_notifications(db, days=90)
|
||||
deleted = cleanup_old_notifications(db, days=90)
|
||||
# Log info: "Notification cleanup finished — deleted %d old no
|
||||
logger.info("Notification cleanup finished — deleted %d old notifications", deleted)
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log exception: "Notification cleanup job failed"
|
||||
logger.exception("Notification cleanup job failed")
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
|
||||
|
||||
# Define function _run_weekly_snapshot
|
||||
def _run_weekly_snapshot() -> None:
|
||||
"""Create a weekly coverage snapshot and clean up old ones."""
|
||||
# Log info: "Scheduled weekly snapshot job starting..."
|
||||
logger.info("Scheduled weekly snapshot job starting...")
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign snapshot = create_snapshot(db, name="Auto-weekly")
|
||||
snapshot = create_snapshot(db, name="Auto-weekly")
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Weekly snapshot created — score %.1f, %d techniques",
|
||||
snapshot.organization_score,
|
||||
snapshot.total_techniques,
|
||||
)
|
||||
# Assign deleted = cleanup_old_snapshots(db, keep_last=52)
|
||||
deleted = cleanup_old_snapshots(db, keep_last=52)
|
||||
# Check: deleted
|
||||
if deleted:
|
||||
# Log info: "Cleaned up %d old snapshots", deleted
|
||||
logger.info("Cleaned up %d old snapshots", deleted)
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log exception: "Weekly snapshot job failed"
|
||||
logger.exception("Weekly snapshot job failed")
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
|
||||
|
||||
# Define function _run_recurring_campaigns
|
||||
def _run_recurring_campaigns() -> None:
|
||||
"""Check and run any due recurring campaigns."""
|
||||
# Log info: "Scheduled recurring campaigns check starting..."
|
||||
logger.info("Scheduled recurring campaigns check starting...")
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign spawned = check_and_run_recurring_campaigns(db)
|
||||
spawned = check_and_run_recurring_campaigns(db)
|
||||
# Log info: "Recurring campaigns check finished — spawned %d c
|
||||
logger.info("Recurring campaigns check finished — spawned %d campaigns", spawned)
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log exception: "Recurring campaigns check failed"
|
||||
logger.exception("Recurring campaigns check failed")
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
|
||||
|
||||
# Define function _run_intel_scan
|
||||
def _run_intel_scan() -> None:
|
||||
"""Execute an intel scan inside its own DB session."""
|
||||
# Log info: "Scheduled intel scan job starting..."
|
||||
logger.info("Scheduled intel scan job starting...")
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign summary = scan_intel(db)
|
||||
summary = scan_intel(db)
|
||||
# Log info: "Scheduled intel scan job finished — %s", summary
|
||||
logger.info("Scheduled intel scan job finished — %s", summary)
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log exception: "Scheduled intel scan job failed"
|
||||
logger.exception("Scheduled intel scan job failed")
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
|
||||
|
||||
# Define function _run_osint_enrichment
|
||||
def _run_osint_enrichment() -> None:
|
||||
"""Execute weekly OSINT enrichment inside its own DB session."""
|
||||
# Log info: "Scheduled OSINT enrichment job starting..."
|
||||
logger.info("Scheduled OSINT enrichment job starting...")
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign total = enrich_all_techniques(db)
|
||||
total = enrich_all_techniques(db)
|
||||
# Log info: "OSINT enrichment finished — %d new items", total
|
||||
logger.info("OSINT enrichment finished — %d new items", total)
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log exception: "OSINT enrichment job failed"
|
||||
logger.exception("OSINT enrichment job failed")
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
|
||||
|
||||
# Define function _run_stale_detection
|
||||
def _run_stale_detection() -> None:
|
||||
"""Execute daily stale coverage detection inside its own DB session."""
|
||||
# Log info: "Scheduled stale coverage detection starting..."
|
||||
logger.info("Scheduled stale coverage detection starting...")
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign count = detect_stale_coverage(db)
|
||||
count = detect_stale_coverage(db)
|
||||
# Log info: "Stale detection finished — %d techniques flagged"
|
||||
logger.info("Stale detection finished — %d techniques flagged", count)
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log exception: "Stale coverage detection job failed"
|
||||
logger.exception("Stale coverage detection job failed")
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
|
||||
|
||||
@@ -122,51 +247,148 @@ def start_scheduler() -> None:
|
||||
|
||||
Neither job fires immediately on startup.
|
||||
"""
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
_run_mitre_sync,
|
||||
# Keyword argument: trigger
|
||||
trigger="interval",
|
||||
# Keyword argument: hours
|
||||
hours=24,
|
||||
# Keyword argument: id
|
||||
id="mitre_sync",
|
||||
# Keyword argument: name
|
||||
name="MITRE ATT&CK sync (every 24h)",
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
_run_intel_scan,
|
||||
# Keyword argument: trigger
|
||||
trigger="interval",
|
||||
# Keyword argument: weeks
|
||||
weeks=1,
|
||||
# Keyword argument: id
|
||||
id="intel_scan",
|
||||
# Keyword argument: name
|
||||
name="Intel scan (every 7d)",
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
_run_notification_cleanup,
|
||||
# Keyword argument: trigger
|
||||
trigger="interval",
|
||||
# Keyword argument: hours
|
||||
hours=24,
|
||||
# Keyword argument: id
|
||||
id="notification_cleanup",
|
||||
# Keyword argument: name
|
||||
name="Notification cleanup (daily)",
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
_run_weekly_snapshot,
|
||||
# Keyword argument: trigger
|
||||
trigger="cron",
|
||||
# Keyword argument: day_of_week
|
||||
day_of_week="sun",
|
||||
# Keyword argument: hour
|
||||
hour=0,
|
||||
# Keyword argument: minute
|
||||
minute=0,
|
||||
# Keyword argument: id
|
||||
id="weekly_snapshot",
|
||||
# Keyword argument: name
|
||||
name="Weekly coverage snapshot (Sundays 00:00)",
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
_run_recurring_campaigns,
|
||||
# Keyword argument: trigger
|
||||
trigger="interval",
|
||||
# Keyword argument: hours
|
||||
hours=24,
|
||||
# Keyword argument: id
|
||||
id="recurring_campaigns",
|
||||
# Keyword argument: name
|
||||
name="Recurring campaigns check (daily)",
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info(
|
||||
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
||||
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
|
||||
"recurring_campaigns (daily)"
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
sync_all_jira_links,
|
||||
# Keyword argument: trigger
|
||||
trigger="interval",
|
||||
# Keyword argument: hours
|
||||
hours=1,
|
||||
# Keyword argument: id
|
||||
id="jira_sync",
|
||||
# Keyword argument: name
|
||||
name="Jira link sync (hourly)",
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
_run_osint_enrichment,
|
||||
# Keyword argument: trigger
|
||||
trigger="interval",
|
||||
# Keyword argument: weeks
|
||||
weeks=1,
|
||||
# Keyword argument: id
|
||||
id="osint_enrichment",
|
||||
# Keyword argument: name
|
||||
name="OSINT enrichment (weekly)",
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
_run_stale_detection,
|
||||
# Keyword argument: trigger
|
||||
trigger="interval",
|
||||
# Keyword argument: hours
|
||||
hours=24,
|
||||
# Keyword argument: id
|
||||
id="stale_detection",
|
||||
# Keyword argument: name
|
||||
name="Stale coverage detection (daily)",
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.add_job()
|
||||
scheduler.add_job(
|
||||
run_retention_job,
|
||||
# Keyword argument: trigger
|
||||
trigger="interval",
|
||||
# Keyword argument: hours
|
||||
hours=24,
|
||||
# Keyword argument: id
|
||||
id="retention_policies",
|
||||
# Keyword argument: name
|
||||
name="Data retention policies (daily)",
|
||||
# Keyword argument: replace_existing
|
||||
replace_existing=True,
|
||||
)
|
||||
# Call scheduler.start()
|
||||
scheduler.start()
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Background scheduler started — mitre_sync (24h), intel_scan (7d), "
|
||||
# Literal argument value
|
||||
"notification_cleanup (24h), weekly_snapshot (Sundays 00:00), "
|
||||
# Literal argument value
|
||||
"recurring_campaigns (daily), jira_sync (1h), "
|
||||
# Literal argument value
|
||||
"osint_enrichment (weekly), stale_detection (daily), "
|
||||
# Literal argument value
|
||||
"retention_policies (daily)"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Data retention policies — scheduled cleanup of aged records."""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import datetime, timedelta, timezone from datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import SessionLocal from app.database
|
||||
from app.database import SessionLocal
|
||||
|
||||
# Import AuditLog from app.models.audit
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
# Import cleanup_old_notifications from app.services.notification_service
|
||||
from app.services.notification_service import cleanup_old_notifications
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign AUDIT_LOG_RETENTION_DAYS = 730
|
||||
AUDIT_LOG_RETENTION_DAYS = 730
|
||||
|
||||
|
||||
# Define function apply_retention_policies
|
||||
def apply_retention_policies(db: Session) -> dict[str, int]:
|
||||
"""Apply retention rules. Commits the session before returning."""
|
||||
# Assign cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=AUDIT_LOG_RETENTION_DAYS)
|
||||
# Assign deleted_audit = (
|
||||
deleted_audit = (
|
||||
db.query(AuditLog)
|
||||
# Chain .filter() call
|
||||
.filter(AuditLog.timestamp < cutoff)
|
||||
# Chain .delete() call
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
# Check: deleted_audit
|
||||
if deleted_audit:
|
||||
# Log info:
|
||||
logger.info(
|
||||
# Literal argument value
|
||||
"Retention: deleted %d audit logs older than %d days",
|
||||
deleted_audit,
|
||||
AUDIT_LOG_RETENTION_DAYS,
|
||||
)
|
||||
|
||||
# Assign deleted_notifications = cleanup_old_notifications(db, days=90)
|
||||
deleted_notifications = cleanup_old_notifications(db, days=90)
|
||||
# Commit all pending changes to the database
|
||||
db.commit()
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"audit_logs_deleted": deleted_audit,
|
||||
# Literal argument value
|
||||
"notifications_deleted": deleted_notifications,
|
||||
}
|
||||
|
||||
|
||||
# Define function run_retention_job
|
||||
def run_retention_job() -> None:
|
||||
"""Entry point for the daily retention scheduler job."""
|
||||
# Log info: "Scheduled retention job starting..."
|
||||
logger.info("Scheduled retention job starting...")
|
||||
# Assign db = SessionLocal()
|
||||
db = SessionLocal()
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign summary = apply_retention_policies(db)
|
||||
summary = apply_retention_policies(db)
|
||||
# Log info: "Retention job finished — %s", summary
|
||||
logger.info("Retention job finished — %s", summary)
|
||||
# Handle Exception
|
||||
except Exception:
|
||||
# Log exception: "Retention job failed"
|
||||
logger.exception("Retention job failed")
|
||||
# Roll back all uncommitted changes
|
||||
db.rollback()
|
||||
# Always execute this cleanup block
|
||||
finally:
|
||||
# Close the database session
|
||||
db.close()
|
||||
@@ -0,0 +1,10 @@
|
||||
"""Shared SlowAPI rate limiter for all routers."""
|
||||
|
||||
# Import Limiter from slowapi
|
||||
from slowapi import Limiter
|
||||
|
||||
# Import get_remote_address from slowapi.util
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
# Assign limiter = Limiter(key_func=get_remote_address)
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Structured JSON logging configuration.
|
||||
|
||||
In **production** (``AEGIS_ENV=production``), emits one JSON object per
|
||||
line so that log aggregators (ELK, CloudWatch, Datadog) can ingest them
|
||||
without custom parsing.
|
||||
|
||||
In **development** (default), uses a human-readable text format for
|
||||
comfortable local work.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import json
|
||||
import json
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import os
|
||||
import os
|
||||
|
||||
# Import sys
|
||||
import sys
|
||||
|
||||
# Import datetime, timezone from datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
# Define class _JSONFormatter
|
||||
class _JSONFormatter(logging.Formatter):
|
||||
"""Emit each log record as a single-line JSON object."""
|
||||
|
||||
# Define function format
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
# Assign payload = {
|
||||
payload: dict = {
|
||||
# Literal argument value
|
||||
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||
# Literal argument value
|
||||
"level": record.levelname,
|
||||
# Literal argument value
|
||||
"logger": record.name,
|
||||
# Literal argument value
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
|
||||
# Check: record.exc_info and record.exc_info[1] is not None
|
||||
if record.exc_info and record.exc_info[1] is not None:
|
||||
# Assign payload["exception"] = self.formatException(record.exc_info)
|
||||
payload["exception"] = self.formatException(record.exc_info)
|
||||
|
||||
# Assign extra = getattr(record, "_extra", None)
|
||||
extra = getattr(record, "_extra", None)
|
||||
# Check: extra
|
||||
if extra:
|
||||
# Call payload.update()
|
||||
payload.update(extra)
|
||||
|
||||
# Return json.dumps(payload, default=str)
|
||||
return json.dumps(payload, default=str)
|
||||
|
||||
|
||||
# Assign _DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
|
||||
_DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
|
||||
|
||||
|
||||
# Define function setup_logging
|
||||
def setup_logging() -> None:
|
||||
"""Configure the root logger based on the environment."""
|
||||
# Assign is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
# Assign level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||
level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||
# Assign level = getattr(logging, level_name, logging.INFO)
|
||||
level = getattr(logging, level_name, logging.INFO)
|
||||
|
||||
# Assign root = logging.getLogger()
|
||||
root = logging.getLogger()
|
||||
# Call root.setLevel()
|
||||
root.setLevel(level)
|
||||
|
||||
# Check: root.handlers
|
||||
if root.handlers:
|
||||
# Call root.handlers.clear()
|
||||
root.handlers.clear()
|
||||
|
||||
# Assign handler = logging.StreamHandler(sys.stdout)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
# Call handler.setLevel()
|
||||
handler.setLevel(level)
|
||||
|
||||
# Check: is_production
|
||||
if is_production:
|
||||
# Call handler.setFormatter()
|
||||
handler.setFormatter(_JSONFormatter())
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Call handler.setFormatter()
|
||||
handler.setFormatter(logging.Formatter(_DEV_FORMAT))
|
||||
|
||||
# Call root.addHandler()
|
||||
root.addHandler(handler)
|
||||
|
||||
# Call logging.getLogger()
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
# Call logging.getLogger()
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
+318
-41
@@ -1,129 +1,406 @@
|
||||
"""FastAPI application factory and global middleware/exception configuration.
|
||||
|
||||
Builds the ``app`` instance, wires up CORS, rate limiting, domain-error
|
||||
mapping, all API routers, and async lifespan hooks (MinIO bucket creation,
|
||||
APScheduler startup/shutdown).
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import os
|
||||
import os
|
||||
|
||||
# Import AsyncGenerator from collections.abc
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
# Import asynccontextmanager from contextlib
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# Import FastAPI, Request, status from fastapi
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
# Import RequestValidationError from fastapi.exceptions
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
||||
# Import CORSMiddleware from fastapi.middleware.cors
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# Import JSONResponse from fastapi.responses
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
# Import _rate_limit_exceeded_handler from slowapi
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
|
||||
# Import RateLimitExceeded from slowapi.errors
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
# Import SQLAlchemyError from sqlalchemy.exc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.routers import auth as auth_router
|
||||
from app.routers import techniques as techniques_router
|
||||
from app.routers import tests as tests_router
|
||||
from app.routers import evidence as evidence_router
|
||||
from app.routers import test_templates as test_templates_router
|
||||
from app.routers import system as system_router
|
||||
from app.routers import metrics as metrics_router
|
||||
from app.routers import users as users_router
|
||||
# Import settings as _settings from app.config
|
||||
from app.config import settings as _settings
|
||||
|
||||
# Import DomainError from app.domain.errors
|
||||
from app.domain.errors import DomainError
|
||||
|
||||
# Import scheduler, start_scheduler from app.jobs.mitre_sync_job
|
||||
from app.jobs.mitre_sync_job import scheduler, start_scheduler
|
||||
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import setup_logging from app.logging_config
|
||||
from app.logging_config import setup_logging
|
||||
|
||||
# Import domain_exception_handler from app.middleware.error_handler
|
||||
from app.middleware.error_handler import domain_exception_handler
|
||||
|
||||
# Import RequestContextMiddleware from app.middleware.request_context
|
||||
from app.middleware.request_context import RequestContextMiddleware
|
||||
|
||||
# Import advanced_metrics as advanced_metrics_router from app.routers
|
||||
from app.routers import advanced_metrics as advanced_metrics_router
|
||||
|
||||
# Import analytics as analytics_router from app.routers
|
||||
from app.routers import analytics as analytics_router
|
||||
|
||||
# Import audit as audit_router from app.routers
|
||||
from app.routers import audit as audit_router
|
||||
from app.routers import notifications as notifications_router
|
||||
from app.routers import reports as reports_router
|
||||
from app.routers import data_sources as data_sources_router
|
||||
from app.routers import threat_actors as threat_actors_router
|
||||
from app.routers import d3fend as d3fend_router
|
||||
from app.routers import detection_rules as detection_rules_router
|
||||
|
||||
# Import auth as auth_router from app.routers
|
||||
from app.routers import auth as auth_router
|
||||
|
||||
# Import campaigns as campaigns_router from app.routers
|
||||
from app.routers import campaigns as campaigns_router
|
||||
from app.routers import heatmap as heatmap_router
|
||||
from app.routers import scores as scores_router
|
||||
from app.routers import operational_metrics as operational_metrics_router
|
||||
|
||||
# Import compliance as compliance_router from app.routers
|
||||
from app.routers import compliance as compliance_router
|
||||
|
||||
# Import d3fend as d3fend_router from app.routers
|
||||
from app.routers import d3fend as d3fend_router
|
||||
|
||||
# Import data_sources as data_sources_router from app.routers
|
||||
from app.routers import data_sources as data_sources_router
|
||||
|
||||
# Import detection_rules as detection_rules_router from app.routers
|
||||
from app.routers import detection_rules as detection_rules_router
|
||||
|
||||
# Import evidence as evidence_router from app.routers
|
||||
from app.routers import evidence as evidence_router
|
||||
|
||||
# Import heatmap as heatmap_router from app.routers
|
||||
from app.routers import heatmap as heatmap_router
|
||||
|
||||
# Import jira as jira_router from app.routers
|
||||
from app.routers import jira as jira_router
|
||||
|
||||
# Import metrics as metrics_router from app.routers
|
||||
from app.routers import metrics as metrics_router
|
||||
|
||||
# Import notifications as notifications_router from app.routers
|
||||
from app.routers import notifications as notifications_router
|
||||
|
||||
# Import operational_metrics as operational_metrics_router from app.routers
|
||||
from app.routers import operational_metrics as operational_metrics_router
|
||||
|
||||
# Import osint as osint_router from app.routers
|
||||
from app.routers import osint as osint_router
|
||||
|
||||
# Import professional_reports as professional_reports_ro... from app.routers
|
||||
from app.routers import professional_reports as professional_reports_router
|
||||
|
||||
# Import reports as reports_router from app.routers
|
||||
from app.routers import reports as reports_router
|
||||
|
||||
# Import scores as scores_router from app.routers
|
||||
from app.routers import scores as scores_router
|
||||
|
||||
# Import snapshots as snapshots_router from app.routers
|
||||
from app.routers import snapshots as snapshots_router
|
||||
|
||||
# Import system as system_router from app.routers
|
||||
from app.routers import system as system_router
|
||||
|
||||
# Import techniques as techniques_router from app.routers
|
||||
from app.routers import techniques as techniques_router
|
||||
|
||||
# Import test_templates as test_templates_router from app.routers
|
||||
from app.routers import test_templates as test_templates_router
|
||||
|
||||
# Import tests as tests_router from app.routers
|
||||
from app.routers import tests as tests_router
|
||||
|
||||
# Import threat_actors as threat_actors_router from app.routers
|
||||
from app.routers import threat_actors as threat_actors_router
|
||||
|
||||
# Import users as users_router from app.routers
|
||||
from app.routers import users as users_router
|
||||
|
||||
# Import worklogs as worklogs_router from app.routers
|
||||
from app.routers import worklogs as worklogs_router
|
||||
|
||||
# Import ensure_bucket_exists from app.storage
|
||||
from app.storage import ensure_bucket_exists
|
||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||
|
||||
# ── Logging ───────────────────────────────────────────────────────────────
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)-8s %(name)s — %(message)s",
|
||||
)
|
||||
# Configure structured logging before any module initialises its own logger
|
||||
setup_logging()
|
||||
|
||||
# ── Environment detection ─────────────────────────────────────────────────
|
||||
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
|
||||
# Apply the @asynccontextmanager decorator
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Startup / shutdown logic."""
|
||||
# Define async function lifespan
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Manage application startup and shutdown lifecycle.
|
||||
|
||||
Args:
|
||||
app (FastAPI): The FastAPI application instance.
|
||||
|
||||
Yields:
|
||||
None: Control is yielded to the running application.
|
||||
"""
|
||||
# Call ensure_bucket_exists()
|
||||
ensure_bucket_exists()
|
||||
# Call start_scheduler()
|
||||
start_scheduler()
|
||||
# Yield value
|
||||
yield
|
||||
# Graceful shutdown of the background scheduler
|
||||
scheduler.shutdown(wait=False)
|
||||
|
||||
|
||||
app = FastAPI(title="Attack Coverage Platform", lifespan=lifespan)
|
||||
# ── In production, disable Swagger UI and ReDoc to hide API surface ──────
|
||||
app = FastAPI(
|
||||
# Keyword argument: title
|
||||
title="Attack Coverage Platform",
|
||||
# Keyword argument: lifespan
|
||||
lifespan=lifespan,
|
||||
# Keyword argument: docs_url
|
||||
docs_url=None if _IS_PRODUCTION else "/docs",
|
||||
# Keyword argument: redoc_url
|
||||
redoc_url=None if _IS_PRODUCTION else "/redoc",
|
||||
# Keyword argument: openapi_url
|
||||
openapi_url=None if _IS_PRODUCTION else "/openapi.json",
|
||||
)
|
||||
|
||||
# ── Rate Limiter ──────────────────────────────────────────────────────────
|
||||
app.state.limiter = limiter
|
||||
# Call app.add_exception_handler()
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# Call app.add_middleware()
|
||||
app.add_middleware(RequestContextMiddleware)
|
||||
|
||||
# ── Domain exception → HTTP mapping ──────────────────────────────────────
|
||||
app.add_exception_handler(DomainError, domain_exception_handler)
|
||||
|
||||
# ── CORS ──────────────────────────────────────────────────────────────────
|
||||
_cors_origins: list[str] = [
|
||||
o.strip() for o in _settings.CORS_ORIGINS.split(",") if o.strip()
|
||||
]
|
||||
|
||||
# Call app.add_middleware()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:3000", "http://localhost:5173"],
|
||||
# Keyword argument: allow_origins
|
||||
allow_origins=_cors_origins,
|
||||
# Keyword argument: allow_credentials
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
# Keyword argument: allow_methods
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
# Keyword argument: allow_headers
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
# ── Routers ──────────────────────────────────────────────────────────────
|
||||
app.include_router(auth_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(techniques_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(tests_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(evidence_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(test_templates_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(system_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(metrics_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(users_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(audit_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(notifications_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(reports_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(data_sources_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(threat_actors_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(d3fend_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(detection_rules_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(campaigns_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(heatmap_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(scores_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(operational_metrics_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(compliance_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(snapshots_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(jira_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(worklogs_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(professional_reports_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(analytics_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(advanced_metrics_router.router, prefix="/api/v1")
|
||||
# Call app.include_router()
|
||||
app.include_router(osint_router.router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
# Apply the @app.get decorator
|
||||
@app.get("/health", include_in_schema=False)
|
||||
# Define function health
|
||||
def health() -> dict[str, str]:
|
||||
"""Return a minimal liveness probe response.
|
||||
|
||||
Access is restricted to internal networks at the Nginx level
|
||||
(see ``frontend/nginx.conf``).
|
||||
|
||||
Returns:
|
||||
dict[str, str]: A dict with ``{"status": "ok"}``.
|
||||
"""
|
||||
# Return {"status": "ok"}
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ── Exception Handlers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _serialize_validation_errors(exc: RequestValidationError) -> list[dict]:
|
||||
"""Return validation errors safe for JSON serialization.
|
||||
|
||||
Converts non-serializable values inside ``ctx`` dictionaries to strings
|
||||
so the response body can be safely encoded.
|
||||
|
||||
Args:
|
||||
exc (RequestValidationError): The Pydantic validation exception.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of sanitised error detail dictionaries.
|
||||
"""
|
||||
# Assign serialized = []
|
||||
serialized: list[dict] = []
|
||||
# Iterate over exc.errors()
|
||||
for err in exc.errors():
|
||||
# Assign item = dict(err)
|
||||
item = dict(err)
|
||||
# Assign ctx = item.get("ctx")
|
||||
ctx = item.get("ctx")
|
||||
# Check: isinstance(ctx, dict)
|
||||
if isinstance(ctx, dict):
|
||||
# Assign item["ctx"] = {key: str(value) for key, value in ctx.items()}
|
||||
item["ctx"] = {key: str(value) for key, value in ctx.items()}
|
||||
# Call serialized.append()
|
||||
serialized.append(item)
|
||||
# Return serialized
|
||||
return serialized
|
||||
|
||||
|
||||
# Apply the @app.exception_handler decorator
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""Handle validation errors with consistent format."""
|
||||
# Define async function validation_exception_handler
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
"""Handle Pydantic validation errors and return a structured 422 response.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming HTTP request.
|
||||
exc (RequestValidationError): The caught validation exception.
|
||||
|
||||
Returns:
|
||||
JSONResponse: A 422 response with a ``VALIDATION_ERROR`` code and error details.
|
||||
"""
|
||||
# Return JSONResponse(
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
# Keyword argument: content
|
||||
content={
|
||||
# Literal argument value
|
||||
"detail": "Validation error",
|
||||
# Literal argument value
|
||||
"code": "VALIDATION_ERROR",
|
||||
"errors": exc.errors(),
|
||||
# Literal argument value
|
||||
"errors": _serialize_validation_errors(exc),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Apply the @app.exception_handler decorator
|
||||
@app.exception_handler(SQLAlchemyError)
|
||||
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError):
|
||||
"""Handle database errors."""
|
||||
# Define async function sqlalchemy_exception_handler
|
||||
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError) -> JSONResponse:
|
||||
"""Handle SQLAlchemy database errors and return a structured 500 response.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming HTTP request.
|
||||
exc (SQLAlchemyError): The caught SQLAlchemy exception.
|
||||
|
||||
Returns:
|
||||
JSONResponse: A 500 response with a ``DATABASE_ERROR`` code.
|
||||
"""
|
||||
# Log error: f"Database error: {exc}"
|
||||
logging.error(f"Database error: {exc}")
|
||||
# Return JSONResponse(
|
||||
return JSONResponse(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
# Keyword argument: content
|
||||
content={
|
||||
# Literal argument value
|
||||
"detail": "Database error occurred",
|
||||
# Literal argument value
|
||||
"code": "DATABASE_ERROR",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Apply the @app.exception_handler decorator
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""Handle all unhandled exceptions."""
|
||||
# Define async function general_exception_handler
|
||||
async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Handle all otherwise-unhandled exceptions and return a structured 500 response.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming HTTP request.
|
||||
exc (Exception): The unhandled exception.
|
||||
|
||||
Returns:
|
||||
JSONResponse: A 500 response with an ``INTERNAL_ERROR`` code.
|
||||
"""
|
||||
# Log error: f"Unhandled exception: {exc}"
|
||||
logging.error(f"Unhandled exception: {exc}")
|
||||
# Return JSONResponse(
|
||||
return JSONResponse(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
# Keyword argument: content
|
||||
content={
|
||||
# Literal argument value
|
||||
"detail": "An internal server error occurred",
|
||||
# Literal argument value
|
||||
"code": "INTERNAL_ERROR",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""ASGI middleware components for request context, error handling, and rate limiting."""
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Domain error → HTTP response mapping.
|
||||
|
||||
This module provides a single exception handler that converts
|
||||
domain-layer errors into structured JSON responses, keeping
|
||||
the service layer free from FastAPI's ``HTTPException``.
|
||||
"""
|
||||
|
||||
# Import Request from fastapi
|
||||
from fastapi import Request
|
||||
|
||||
# Import JSONResponse from fastapi.responses
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
# Import from app.domain.errors
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
DomainError,
|
||||
DuplicateEntityError,
|
||||
EntityNotFoundError,
|
||||
InvalidOperationError,
|
||||
InvalidStateTransition,
|
||||
PermissionViolation,
|
||||
)
|
||||
|
||||
# Assign EXCEPTION_STATUS_MAP = {
|
||||
EXCEPTION_STATUS_MAP: dict[type[DomainError], int] = {
|
||||
# Entry: EntityNotFoundError
|
||||
EntityNotFoundError: 404,
|
||||
# Entry: DuplicateEntityError
|
||||
DuplicateEntityError: 409,
|
||||
# Entry: InvalidStateTransition
|
||||
InvalidStateTransition: 400,
|
||||
# Entry: InvalidOperationError
|
||||
InvalidOperationError: 400,
|
||||
# Entry: BusinessRuleViolation
|
||||
BusinessRuleViolation: 400,
|
||||
# Entry: PermissionViolation
|
||||
PermissionViolation: 403,
|
||||
}
|
||||
|
||||
|
||||
# Define async function domain_exception_handler
|
||||
async def domain_exception_handler(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: exc
|
||||
exc: DomainError,
|
||||
) -> JSONResponse:
|
||||
"""Convert a :class:`DomainError` into a JSON error response."""
|
||||
# Assign status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400)
|
||||
status_code = EXCEPTION_STATUS_MAP.get(type(exc), 400)
|
||||
|
||||
# Assign content = {"detail": exc.message, "code": exc.code}
|
||||
content: dict = {"detail": exc.message, "code": exc.code}
|
||||
|
||||
# Check: isinstance(exc, InvalidStateTransition)
|
||||
if isinstance(exc, InvalidStateTransition):
|
||||
# Assign content["current_state"] = exc.current_state
|
||||
content["current_state"] = exc.current_state
|
||||
# Assign content["target_state"] = exc.target_state
|
||||
content["target_state"] = exc.target_state
|
||||
# Assign content["valid_transitions"] = exc.valid_transitions
|
||||
content["valid_transitions"] = exc.valid_transitions
|
||||
|
||||
# Return JSONResponse(status_code=status_code, content=content)
|
||||
return JSONResponse(status_code=status_code, content=content)
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Request context middleware — captures client IP and User-Agent per request."""
|
||||
|
||||
# Import Awaitable, Callable from collections.abc
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
# Import ContextVar from contextvars
|
||||
from contextvars import ContextVar
|
||||
|
||||
# Import Request from fastapi
|
||||
from fastapi import Request
|
||||
|
||||
# Import BaseHTTPMiddleware from starlette.middleware.base
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
# Import Response from starlette.responses
|
||||
from starlette.responses import Response
|
||||
|
||||
# Assign request_ip = ContextVar("request_ip", default="")
|
||||
request_ip: ContextVar[str] = ContextVar("request_ip", default="")
|
||||
# Assign request_user_agent = ContextVar("request_user_agent", default="")
|
||||
request_user_agent: ContextVar[str] = ContextVar("request_user_agent", default="")
|
||||
|
||||
|
||||
# Define function resolve_client_ip
|
||||
def resolve_client_ip(request: Request) -> str:
|
||||
"""Extract the real client IP, honouring ``X-Forwarded-For`` when present.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming Starlette/FastAPI request.
|
||||
|
||||
Returns:
|
||||
str: The resolved client IP address, or ``"unknown"`` when unavailable.
|
||||
"""
|
||||
# Assign forwarded = request.headers.get("X-Forwarded-For")
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
# Check: forwarded
|
||||
if forwarded:
|
||||
# Return forwarded.split(",")[0].strip()
|
||||
return forwarded.split(",")[0].strip()
|
||||
# Check: request.client
|
||||
if request.client:
|
||||
# Return request.client.host
|
||||
return request.client.host
|
||||
# Return "unknown"
|
||||
return "unknown"
|
||||
|
||||
|
||||
# Define class RequestContextMiddleware
|
||||
class RequestContextMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that captures client IP and User-Agent into context variables."""
|
||||
|
||||
# Define async function dispatch
|
||||
async def dispatch(
|
||||
self,
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: call_next
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
"""Store client IP and User-Agent in context vars for the current request.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming HTTP request.
|
||||
call_next (Callable[[Request], Awaitable[Response]]): The next middleware or route handler.
|
||||
|
||||
Returns:
|
||||
Response: The HTTP response produced by the downstream handler.
|
||||
"""
|
||||
# Call request_ip.set()
|
||||
request_ip.set(resolve_client_ip(request))
|
||||
# Call request_user_agent.set()
|
||||
request_user_agent.set(request.headers.get("User-Agent", ""))
|
||||
# Return await call_next(request)
|
||||
return await call_next(request)
|
||||
@@ -1,31 +1,96 @@
|
||||
"""SQLAlchemy ORM model definitions for all database tables."""
|
||||
# Import all models here so Alembic can detect them
|
||||
from app.models.user import User
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.evidence import Evidence
|
||||
from app.models.intel import IntelItem
|
||||
from app.models.audit import AuditLog
|
||||
from app.models.notification import Notification
|
||||
from app.models.data_source import DataSource
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.compliance import ComplianceFramework, ComplianceControl, ComplianceControlMapping
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
|
||||
|
||||
# Import Campaign, CampaignTest from app.models.campaign
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
|
||||
# Import from app.models.compliance
|
||||
from app.models.compliance import (
|
||||
ComplianceControl,
|
||||
ComplianceControlMapping,
|
||||
ComplianceFramework,
|
||||
)
|
||||
|
||||
# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
|
||||
# Import DataSource from app.models.data_source
|
||||
from app.models.data_source import DataSource
|
||||
|
||||
# Import DefensiveTechnique, DefensiveTechniqueMapping from app.models.defensive_technique
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
|
||||
# Import DetectionRule from app.models.detection_rule
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
# Import TeamSide, TechniqueStatus, TestResult, TestState from app.models.enums
|
||||
from app.models.enums import TeamSide, TechniqueStatus, TestResult, TestState
|
||||
|
||||
# Import Evidence from app.models.evidence
|
||||
from app.models.evidence import Evidence
|
||||
|
||||
# Import IntelItem from app.models.intel
|
||||
from app.models.intel import IntelItem
|
||||
|
||||
# Import JiraLink, JiraLinkEntityType, JiraSyncDirection from app.models.jira_link
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||
|
||||
# Import Notification from app.models.notification
|
||||
from app.models.notification import Notification
|
||||
|
||||
# Import OsintItem from app.models.osint_item
|
||||
from app.models.osint_item import OsintItem
|
||||
|
||||
# Import ScoringConfig from app.models.scoring_config
|
||||
from app.models.scoring_config import ScoringConfig
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import Test from app.models.test
|
||||
from app.models.test import Test
|
||||
|
||||
# Import TestDetectionResult from app.models.test_detection_result
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import TestTemplateDetectionRule from app.models.test_template_detection_rule
|
||||
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
|
||||
# Import ThreatActor, ThreatActorTechnique from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import Worklog from app.models.worklog
|
||||
from app.models.worklog import Worklog
|
||||
|
||||
# Assign __all__ = [
|
||||
__all__ = [
|
||||
# Literal argument value
|
||||
"User", "Technique", "Test", "TestTemplate", "Evidence",
|
||||
# Literal argument value
|
||||
"IntelItem", "AuditLog", "Notification", "DataSource",
|
||||
# Literal argument value
|
||||
"DetectionRule", "ThreatActor", "ThreatActorTechnique",
|
||||
# Literal argument value
|
||||
"DefensiveTechnique", "DefensiveTechniqueMapping",
|
||||
# Literal argument value
|
||||
"TestTemplateDetectionRule", "TestDetectionResult",
|
||||
# Literal argument value
|
||||
"Campaign", "CampaignTest",
|
||||
# Literal argument value
|
||||
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
|
||||
# Literal argument value
|
||||
"CoverageSnapshot", "SnapshotTechniqueState",
|
||||
# Literal argument value
|
||||
"JiraLink", "JiraLinkEntityType", "JiraSyncDirection",
|
||||
# Literal argument value
|
||||
"Worklog", "OsintItem", "ScoringConfig",
|
||||
# Literal argument value
|
||||
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
||||
]
|
||||
|
||||
@@ -1,29 +1,60 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
"""SQLAlchemy model for the audit log table."""
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Column, DateTime, ForeignKey, Index, String, func from sqlalchemy
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, func
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class AuditLog
|
||||
class AuditLog(Base):
|
||||
"""
|
||||
Audit log model for tracking all system actions.
|
||||
|
||||
"""Audit log model for tracking all system actions.
|
||||
|
||||
Records user actions, entity changes, and system events
|
||||
for security auditing and compliance purposes.
|
||||
"""
|
||||
# Assign __tablename__ = "audit_logs"
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
# Assign action = Column(String, nullable=False)
|
||||
action = Column(String, nullable=False)
|
||||
# Assign entity_type = Column(String, nullable=True)
|
||||
entity_type = Column(String, nullable=True)
|
||||
# Assign entity_id = Column(String, nullable=True)
|
||||
entity_id = Column(String, nullable=True)
|
||||
timestamp = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||
# Assign details = Column(JSONB, nullable=True)
|
||||
details = Column(JSONB, nullable=True)
|
||||
# Assign ip_address = Column(String(45), nullable=True)
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
# Assign user_agent = Column(String(500), nullable=True)
|
||||
user_agent = Column(String(500), nullable=True)
|
||||
# Assign integrity_hash = Column(String(64), nullable=True)
|
||||
integrity_hash = Column(String(64), nullable=True)
|
||||
# Assign session_id = Column(String(100), nullable=True)
|
||||
session_id = Column(String(100), nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User")
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index("ix_audit_logs_entity", "entity_type", "entity_id"),
|
||||
Index("ix_audit_logs_timestamp", "timestamp"),
|
||||
Index("ix_audit_logs_entity_type_entity_id_action", "entity_type", "entity_id", "action"),
|
||||
)
|
||||
|
||||
@@ -4,22 +4,35 @@ Campaigns group multiple tests into a kill chain sequence,
|
||||
enabling simulation of complete attack chains and APT emulations.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Import from sqlalchemy
|
||||
from sqlalchemy import (
|
||||
Column, String, Text, Integer, Boolean, DateTime,
|
||||
ForeignKey, Index,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class Campaign
|
||||
class Campaign(Base):
|
||||
"""
|
||||
A campaign groups multiple tests into a sequenced attack chain.
|
||||
"""A campaign groups multiple tests into a sequenced attack chain.
|
||||
|
||||
Types:
|
||||
- custom: manually created campaign
|
||||
@@ -33,60 +46,97 @@ class Campaign(Base):
|
||||
- completed: all tests done
|
||||
- archived: historical record
|
||||
"""
|
||||
# Assign __tablename__ = "campaigns"
|
||||
__tablename__ = "campaigns"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign name = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign type = Column(String, nullable=False, default="custom") # custom, ap...
|
||||
type = Column(String, nullable=False, default="custom") # custom, apt_emulation, kill_chain, compliance
|
||||
# Assign threat_actor_id = Column(
|
||||
threat_actor_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("threat_actors.id", ondelete="SET NULL"),
|
||||
# Keyword argument: nullable
|
||||
nullable=True,
|
||||
)
|
||||
# Assign status = Column(String, nullable=False, default="draft") # draft, activ...
|
||||
status = Column(String, nullable=False, default="draft") # draft, active, completed, archived
|
||||
# Assign created_by = Column(
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
# Keyword argument: nullable
|
||||
nullable=True,
|
||||
)
|
||||
# Assign scheduled_at = Column(DateTime, nullable=True)
|
||||
scheduled_at = Column(DateTime, nullable=True)
|
||||
# Assign completed_at = Column(DateTime, nullable=True)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
# Assign target_platform = Column(String, nullable=True)
|
||||
target_platform = Column(String, nullable=True)
|
||||
# Assign tags = Column(JSONB, nullable=True, default=[])
|
||||
tags = Column(JSONB, nullable=True, default=[])
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
# Assign data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||
|
||||
# Recurring scheduling fields
|
||||
is_recurring = Column(Boolean, default=False)
|
||||
# Assign recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly
|
||||
recurrence_pattern = Column(String, nullable=True) # weekly, monthly, quarterly
|
||||
# Assign next_run_at = Column(DateTime, nullable=True)
|
||||
next_run_at = Column(DateTime, nullable=True)
|
||||
# Assign last_run_at = Column(DateTime, nullable=True)
|
||||
last_run_at = Column(DateTime, nullable=True)
|
||||
# Assign parent_campaign_id = Column(
|
||||
parent_campaign_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("campaigns.id", ondelete="SET NULL"),
|
||||
# Keyword argument: nullable
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
threat_actor = relationship("ThreatActor")
|
||||
# Assign creator = relationship("User", foreign_keys=[created_by])
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
# Assign campaign_tests = relationship(
|
||||
campaign_tests = relationship(
|
||||
# Literal argument value
|
||||
"CampaignTest",
|
||||
# Keyword argument: back_populates
|
||||
back_populates="campaign",
|
||||
# Keyword argument: cascade
|
||||
cascade="all, delete-orphan",
|
||||
# Keyword argument: order_by
|
||||
order_by="CampaignTest.order_index",
|
||||
)
|
||||
# Assign parent_campaign = relationship(
|
||||
parent_campaign = relationship(
|
||||
# Literal argument value
|
||||
"Campaign",
|
||||
# Keyword argument: remote_side
|
||||
remote_side="Campaign.id",
|
||||
# Keyword argument: foreign_keys
|
||||
foreign_keys=[parent_campaign_id],
|
||||
)
|
||||
# Assign child_campaigns = relationship(
|
||||
child_campaigns = relationship(
|
||||
# Literal argument value
|
||||
"Campaign",
|
||||
# Keyword argument: foreign_keys
|
||||
foreign_keys=[parent_campaign_id],
|
||||
# Keyword argument: back_populates
|
||||
back_populates="parent_campaign",
|
||||
)
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_campaigns_status', 'status'),
|
||||
Index('ix_campaigns_type', 'type'),
|
||||
@@ -98,56 +148,83 @@ class Campaign(Base):
|
||||
|
||||
# Kill chain phases in order (for sorting and validation)
|
||||
KILL_CHAIN_PHASES = [
|
||||
# Literal argument value
|
||||
"reconnaissance",
|
||||
# Literal argument value
|
||||
"resource_development",
|
||||
# Literal argument value
|
||||
"initial_access",
|
||||
# Literal argument value
|
||||
"execution",
|
||||
# Literal argument value
|
||||
"persistence",
|
||||
# Literal argument value
|
||||
"privilege_escalation",
|
||||
# Literal argument value
|
||||
"defense_evasion",
|
||||
# Literal argument value
|
||||
"credential_access",
|
||||
# Literal argument value
|
||||
"discovery",
|
||||
# Literal argument value
|
||||
"lateral_movement",
|
||||
# Literal argument value
|
||||
"collection",
|
||||
# Literal argument value
|
||||
"command_and_control",
|
||||
# Literal argument value
|
||||
"exfiltration",
|
||||
# Literal argument value
|
||||
"impact",
|
||||
]
|
||||
|
||||
|
||||
# Define class CampaignTest
|
||||
class CampaignTest(Base):
|
||||
"""
|
||||
A test within a campaign, with ordering and dependency information.
|
||||
"""A test within a campaign, with ordering and dependency information.
|
||||
|
||||
``depends_on`` creates a self-referential chain (A -> B -> C).
|
||||
Circular dependencies are validated at the service layer.
|
||||
"""
|
||||
# Assign __tablename__ = "campaign_tests"
|
||||
__tablename__ = "campaign_tests"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign campaign_id = Column(
|
||||
campaign_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("campaigns.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign test_id = Column(
|
||||
test_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("tests.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign order_index = Column(Integer, nullable=False, default=0)
|
||||
order_index = Column(Integer, nullable=False, default=0)
|
||||
# Assign depends_on = Column(
|
||||
depends_on = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("campaign_tests.id", ondelete="SET NULL"),
|
||||
# Keyword argument: nullable
|
||||
nullable=True,
|
||||
)
|
||||
# Assign phase = Column(String, nullable=True) # kill chain phase
|
||||
phase = Column(String, nullable=True) # kill chain phase
|
||||
|
||||
# Relationships
|
||||
campaign = relationship("Campaign", back_populates="campaign_tests")
|
||||
# Assign test = relationship("Test")
|
||||
test = relationship("Test")
|
||||
# Assign dependency = relationship("CampaignTest", remote_side="CampaignTest.id")
|
||||
dependency = relationship("CampaignTest", remote_side="CampaignTest.id")
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_campaign_tests_campaign', 'campaign_id'),
|
||||
Index('ix_campaign_tests_test', 'test_id'),
|
||||
|
||||
@@ -4,94 +4,145 @@ Maps compliance frameworks (NIST 800-53, DORA, NIS2, ISO 27001) to
|
||||
MITRE ATT&CK techniques, enabling compliance gap analysis.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Import from sqlalchemy
|
||||
from sqlalchemy import (
|
||||
Column, String, Text, Boolean, DateTime,
|
||||
ForeignKey, Index, UniqueConstraint,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class ComplianceFramework
|
||||
class ComplianceFramework(Base):
|
||||
"""A compliance framework (e.g. NIST 800-53, ISO 27001)."""
|
||||
# Assign __tablename__ = "compliance_frameworks"
|
||||
__tablename__ = "compliance_frameworks"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign name = Column(String, unique=True, nullable=False)
|
||||
name = Column(String, unique=True, nullable=False)
|
||||
# Assign version = Column(String, nullable=True)
|
||||
version = Column(String, nullable=True)
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign url = Column(String, nullable=True)
|
||||
url = Column(String, nullable=True)
|
||||
# Assign is_active = Column(Boolean, default=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
controls = relationship(
|
||||
# Literal argument value
|
||||
"ComplianceControl",
|
||||
# Keyword argument: back_populates
|
||||
back_populates="framework",
|
||||
# Keyword argument: cascade
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
# Define class ComplianceControl
|
||||
class ComplianceControl(Base):
|
||||
"""A control within a compliance framework (e.g. AC-2, PR.AC-1)."""
|
||||
# Assign __tablename__ = "compliance_controls"
|
||||
__tablename__ = "compliance_controls"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign framework_id = Column(
|
||||
framework_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("compliance_frameworks.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign control_id = Column(String, nullable=False) # e.g. "AC-2"
|
||||
control_id = Column(String, nullable=False) # e.g. "AC-2"
|
||||
# Assign title = Column(String, nullable=False)
|
||||
title = Column(String, nullable=False)
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign category = Column(String, nullable=True)
|
||||
category = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
framework = relationship("ComplianceFramework", back_populates="controls")
|
||||
# Assign technique_mappings = relationship(
|
||||
technique_mappings = relationship(
|
||||
# Literal argument value
|
||||
"ComplianceControlMapping",
|
||||
# Keyword argument: back_populates
|
||||
back_populates="compliance_control",
|
||||
# Keyword argument: cascade
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_compliance_controls_framework', 'framework_id'),
|
||||
)
|
||||
|
||||
|
||||
# Define class ComplianceControlMapping
|
||||
class ComplianceControlMapping(Base):
|
||||
"""Maps a compliance control to a MITRE ATT&CK technique."""
|
||||
# Assign __tablename__ = "compliance_control_mappings"
|
||||
__tablename__ = "compliance_control_mappings"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign compliance_control_id = Column(
|
||||
compliance_control_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("compliance_controls.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign technique_id = Column(
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
compliance_control = relationship(
|
||||
# Literal argument value
|
||||
"ComplianceControl", back_populates="technique_mappings"
|
||||
)
|
||||
# Assign technique = relationship("Technique")
|
||||
technique = relationship("Technique")
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_compliance_mappings_control', 'compliance_control_id'),
|
||||
Index('ix_compliance_mappings_technique', 'technique_id'),
|
||||
UniqueConstraint(
|
||||
# Literal argument value
|
||||
'compliance_control_id', 'technique_id',
|
||||
# Keyword argument: name
|
||||
name='uq_control_technique',
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,73 +5,125 @@ SnapshotTechniqueState stores per-technique state (normalized, one row
|
||||
per technique per snapshot) to avoid bloated JSONB fields.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Import from sqlalchemy
|
||||
from sqlalchemy import (
|
||||
Column, String, Float, Integer, DateTime,
|
||||
ForeignKey, Index,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class CoverageSnapshot
|
||||
class CoverageSnapshot(Base):
|
||||
"""A point-in-time snapshot of the organisation's overall coverage."""
|
||||
|
||||
# Assign __tablename__ = "coverage_snapshots"
|
||||
__tablename__ = "coverage_snapshots"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
|
||||
name = Column(String, nullable=True) # e.g. "Pre-remediación Q1"
|
||||
# Assign organization_score = Column(Float, nullable=False)
|
||||
organization_score = Column(Float, nullable=False)
|
||||
# Assign total_techniques = Column(Integer, nullable=False)
|
||||
total_techniques = Column(Integer, nullable=False)
|
||||
# Assign validated_count = Column(Integer, nullable=False)
|
||||
validated_count = Column(Integer, nullable=False)
|
||||
# Assign partial_count = Column(Integer, nullable=False)
|
||||
partial_count = Column(Integer, nullable=False)
|
||||
# Assign not_covered_count = Column(Integer, nullable=False)
|
||||
not_covered_count = Column(Integer, nullable=False)
|
||||
# Assign in_progress_count = Column(Integer, nullable=False)
|
||||
in_progress_count = Column(Integer, nullable=False)
|
||||
# Assign not_evaluated_count = Column(Integer, nullable=False)
|
||||
not_evaluated_count = Column(Integer, nullable=False)
|
||||
# Assign coverage_percentage = Column(Float, nullable=False, default=0.0)
|
||||
coverage_percentage = Column(Float, nullable=False, default=0.0)
|
||||
# Assign by_tactic = Column(JSONB, nullable=False, default=dict)
|
||||
by_tactic = Column(JSONB, nullable=False, default=dict)
|
||||
# Assign by_status = Column(JSONB, nullable=False, default=dict)
|
||||
by_status = Column(JSONB, nullable=False, default=dict)
|
||||
# Assign stale_count = Column(Integer, nullable=False, default=0)
|
||||
stale_count = Column(Integer, nullable=False, default=0)
|
||||
# Assign never_tested_count = Column(Integer, nullable=False, default=0)
|
||||
never_tested_count = Column(Integer, nullable=False, default=0)
|
||||
# Assign created_by = Column(
|
||||
created_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
# Keyword argument: nullable
|
||||
nullable=True,
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
# Assign technique_states = relationship(
|
||||
technique_states = relationship(
|
||||
# Literal argument value
|
||||
"SnapshotTechniqueState",
|
||||
# Keyword argument: back_populates
|
||||
back_populates="snapshot",
|
||||
# Keyword argument: cascade
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
# Define class SnapshotTechniqueState
|
||||
class SnapshotTechniqueState(Base):
|
||||
"""Per-technique state within a snapshot (normalised storage)."""
|
||||
|
||||
# Assign __tablename__ = "snapshot_technique_states"
|
||||
__tablename__ = "snapshot_technique_states"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign snapshot_id = Column(
|
||||
snapshot_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("coverage_snapshots.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign technique_id = Column(
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign mitre_id = Column(String, nullable=False) # denormalised for fast queries
|
||||
mitre_id = Column(String, nullable=False) # denormalised for fast queries
|
||||
# Assign status = Column(String, nullable=False)
|
||||
status = Column(String, nullable=False)
|
||||
# Assign score = Column(Float, nullable=True)
|
||||
score = Column(Float, nullable=True)
|
||||
|
||||
# Relationships
|
||||
snapshot = relationship("CoverageSnapshot", back_populates="technique_states")
|
||||
# Assign technique = relationship("Technique")
|
||||
technique = relationship("Technique")
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index("ix_snapshot_technique_states_snapshot", "snapshot_id"),
|
||||
Index("ix_snapshot_technique_states_technique", "technique_id"),
|
||||
|
||||
@@ -1,38 +1,56 @@
|
||||
"""DataSource model — registry of external data sources for import."""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class DataSource
|
||||
class DataSource(Base):
|
||||
"""
|
||||
Unified registry of all external data sources (attack procedures,
|
||||
detection rules, threat intel, defensive techniques).
|
||||
"""Unified registry of all external data sources.
|
||||
|
||||
Each source can be independently enabled/disabled and tracks its own
|
||||
synchronisation state.
|
||||
Covers attack procedures, detection rules, threat intel, and defensive techniques.
|
||||
Each source can be independently enabled/disabled and tracks its own synchronisation state.
|
||||
"""
|
||||
# Assign __tablename__ = "data_sources"
|
||||
__tablename__ = "data_sources"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign name = Column(String, unique=True, nullable=False) # e.g. "atom...
|
||||
name = Column(String, unique=True, nullable=False) # e.g. "atomic_red_team"
|
||||
# Assign display_name = Column(String, nullable=False) # e.g. "Atomic Red ...
|
||||
display_name = Column(String, nullable=False) # e.g. "Atomic Red Team"
|
||||
type = Column(String, nullable=False) # attack_procedure / detection_rule / threat_intel / defensive_technique
|
||||
# Values: attack_procedure / detection_rule / threat_intel / defensive_technique
|
||||
type = Column(String, nullable=False)
|
||||
# Assign url = Column(String, nullable=True) # URL base...
|
||||
url = Column(String, nullable=True) # URL base of repo/API
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign is_enabled = Column(Boolean, default=True)
|
||||
is_enabled = Column(Boolean, default=True)
|
||||
# Assign last_sync_at = Column(DateTime, nullable=True)
|
||||
last_sync_at = Column(DateTime, nullable=True)
|
||||
# Assign last_sync_status = Column(String, nullable=True) # success / error / in_...
|
||||
last_sync_status = Column(String, nullable=True) # success / error / in_progress
|
||||
# Assign last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "upd...
|
||||
last_sync_stats = Column(JSONB, nullable=True) # {"imported": X, "updated": Y, ...}
|
||||
# Assign sync_frequency = Column(String, nullable=True) # daily / weekly / mo...
|
||||
sync_frequency = Column(String, nullable=True) # daily / weekly / monthly / manual
|
||||
# Assign config = Column(JSONB, nullable=True) # source-spec...
|
||||
config = Column(JSONB, nullable=True) # source-specific configuration
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_data_sources_type', 'type'),
|
||||
Index('ix_data_sources_is_enabled', 'is_enabled'),
|
||||
|
||||
@@ -4,76 +4,108 @@ Stores MITRE D3FEND defensive techniques and their mappings to
|
||||
ATT&CK techniques, enabling recommended countermeasure lookups.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Import from sqlalchemy
|
||||
from sqlalchemy import (
|
||||
Column, String, Text, DateTime,
|
||||
ForeignKey, Index, UniqueConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class DefensiveTechnique
|
||||
class DefensiveTechnique(Base):
|
||||
"""
|
||||
MITRE D3FEND defensive technique.
|
||||
"""MITRE D3FEND defensive technique.
|
||||
|
||||
Represents a countermeasure from the D3FEND framework that can be
|
||||
mapped to one or more ATT&CK techniques via DefensiveTechniqueMapping.
|
||||
"""
|
||||
# Assign __tablename__ = "defensive_techniques"
|
||||
__tablename__ = "defensive_techniques"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL"
|
||||
d3fend_id = Column(String, unique=True, nullable=False) # e.g. "D3-AL"
|
||||
# Assign name = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign tactic = Column(String, nullable=True) # Detect, ...
|
||||
tactic = Column(String, nullable=True) # Detect, Isolate, Deceive, Evict, etc.
|
||||
# Assign d3fend_url = Column(String, nullable=True)
|
||||
d3fend_url = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
attack_mappings = relationship(
|
||||
# Literal argument value
|
||||
"DefensiveTechniqueMapping",
|
||||
# Keyword argument: back_populates
|
||||
back_populates="defensive_technique",
|
||||
# Keyword argument: cascade
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_defensive_techniques_tactic', 'tactic'),
|
||||
)
|
||||
|
||||
|
||||
# Define class DefensiveTechniqueMapping
|
||||
class DefensiveTechniqueMapping(Base):
|
||||
"""
|
||||
Association between a MITRE ATT&CK technique and a D3FEND
|
||||
defensive technique.
|
||||
"""
|
||||
"""Association between a MITRE ATT&CK technique and a D3FEND defensive technique."""
|
||||
# Assign __tablename__ = "defensive_technique_mappings"
|
||||
__tablename__ = "defensive_technique_mappings"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign attack_technique_id = Column(
|
||||
attack_technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign defensive_technique_id = Column(
|
||||
defensive_technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("defensive_techniques.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
attack_technique = relationship("Technique")
|
||||
# Assign defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings")
|
||||
defensive_technique = relationship("DefensiveTechnique", back_populates="attack_mappings")
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_dtm_attack_technique', 'attack_technique_id'),
|
||||
Index('ix_dtm_defensive_technique', 'defensive_technique_id'),
|
||||
UniqueConstraint(
|
||||
# Literal argument value
|
||||
'attack_technique_id', 'defensive_technique_id',
|
||||
# Keyword argument: name
|
||||
name='uq_attack_defensive_technique',
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,40 +1,61 @@
|
||||
"""DetectionRule model — detection rules from multiple sources."""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class DetectionRule
|
||||
class DetectionRule(Base):
|
||||
"""
|
||||
Detection rule from an external source (Sigma, Elastic, Splunk, custom).
|
||||
"""Detection rule from an external source (Sigma, Elastic, Splunk, custom).
|
||||
|
||||
Each rule is mapped to one MITRE ATT&CK technique via
|
||||
``mitre_technique_id`` and stores the complete rule content in
|
||||
``rule_content``.
|
||||
"""
|
||||
# Assign __tablename__ = "detection_rules"
|
||||
__tablename__ = "detection_rules"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
||||
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
||||
# Assign title = Column(String, nullable=False)
|
||||
title = Column(String, nullable=False)
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign source = Column(String, nullable=False) # sigma / ela...
|
||||
source = Column(String, nullable=False) # sigma / elastic / splunk / custom
|
||||
# Assign source_id = Column(String, nullable=True) # ID in the sour...
|
||||
source_id = Column(String, nullable=True) # ID in the source repo (for dedup)
|
||||
# Assign source_url = Column(String, nullable=True)
|
||||
source_url = Column(String, nullable=True)
|
||||
# Assign rule_content = Column(Text, nullable=False) # YAML / KQL / SPL ...
|
||||
rule_content = Column(Text, nullable=False) # YAML / KQL / SPL content
|
||||
# Assign rule_format = Column(String, nullable=False) # sigma_yaml / kql...
|
||||
rule_format = Column(String, nullable=False) # sigma_yaml / kql / spl / custom
|
||||
# Assign severity = Column(String, nullable=True) # informational...
|
||||
severity = Column(String, nullable=True) # informational / low / medium / high / critical
|
||||
# Assign platforms = Column(JSONB, nullable=True, default=[])
|
||||
platforms = Column(JSONB, nullable=True, default=[])
|
||||
# Assign log_sources = Column(JSONB, nullable=True) # e.g. {"product":...
|
||||
log_sources = Column(JSONB, nullable=True) # e.g. {"product": "windows", "service": "sysmon"}
|
||||
# Assign false_positive_rate = Column(String, nullable=True) # low / medium / high
|
||||
false_positive_rate = Column(String, nullable=True) # low / medium / high
|
||||
# Assign is_active = Column(Boolean, default=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_detection_rules_mitre_technique_id', 'mitre_technique_id'),
|
||||
Index('ix_detection_rules_source', 'source'),
|
||||
|
||||
+13
-28
@@ -1,30 +1,15 @@
|
||||
import enum
|
||||
"""ORM-level re-exports of the canonical domain enums.
|
||||
|
||||
The single source of truth lives in ``app.domain.enums``. This module
|
||||
re-exports every enum so that existing model and router code keeps
|
||||
working with ``from app.models.enums import ...``.
|
||||
"""
|
||||
|
||||
class TechniqueStatus(str, enum.Enum):
|
||||
not_evaluated = "not_evaluated"
|
||||
in_progress = "in_progress"
|
||||
validated = "validated"
|
||||
partial = "partial"
|
||||
not_covered = "not_covered"
|
||||
review_required = "review_required"
|
||||
|
||||
|
||||
class TestState(str, enum.Enum):
|
||||
draft = "draft"
|
||||
red_executing = "red_executing" # Red Team documenting attack
|
||||
blue_evaluating = "blue_evaluating" # Blue Team evaluating detection
|
||||
in_review = "in_review"
|
||||
validated = "validated"
|
||||
rejected = "rejected"
|
||||
|
||||
|
||||
class TeamSide(str, enum.Enum):
|
||||
red = "red"
|
||||
blue = "blue"
|
||||
|
||||
|
||||
class TestResult(str, enum.Enum):
|
||||
detected = "detected"
|
||||
not_detected = "not_detected"
|
||||
partially_detected = "partially_detected"
|
||||
# Import # noqa: F401 from app.domain.enums
|
||||
from app.domain.enums import ( # noqa: F401
|
||||
DataClassification,
|
||||
TeamSide,
|
||||
TechniqueStatus,
|
||||
TestResult,
|
||||
TestState,
|
||||
)
|
||||
|
||||
@@ -1,36 +1,59 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
"""SQLAlchemy model for the evidence table."""
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Column, DateTime, Enum, ForeignKey, String, Tex... from sqlalchemy
|
||||
from sqlalchemy import Column, DateTime, Enum, ForeignKey, String, Text, func
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
# Import TeamSide from app.models.enums
|
||||
from app.models.enums import TeamSide
|
||||
|
||||
|
||||
# Define class Evidence
|
||||
class Evidence(Base):
|
||||
"""
|
||||
Evidence model for storing file metadata associated with tests.
|
||||
|
||||
"""Evidence model for storing file metadata associated with tests.
|
||||
|
||||
Files are stored in MinIO, and this model tracks the file location,
|
||||
integrity hash, and upload metadata.
|
||||
|
||||
|
||||
The ``team`` field distinguishes whether this evidence was uploaded by
|
||||
Red Team (attack evidence) or Blue Team (detection evidence).
|
||||
"""
|
||||
# Assign __tablename__ = "evidences"
|
||||
__tablename__ = "evidences"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False)
|
||||
test_id = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=False)
|
||||
# Assign file_name = Column(String, nullable=False)
|
||||
file_name = Column(String, nullable=False)
|
||||
# Assign file_path = Column(String, nullable=False) # Path in MinIO
|
||||
file_path = Column(String, nullable=False) # Path in MinIO
|
||||
# Assign sha256_hash = Column(String, nullable=False)
|
||||
sha256_hash = Column(String, nullable=False)
|
||||
# Assign uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
uploaded_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
uploaded_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
uploaded_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
# Assign team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=Tea...
|
||||
team = Column(Enum(TeamSide, name="teamside"), nullable=False, default=TeamSide.red)
|
||||
# Assign notes = Column(Text, nullable=True)
|
||||
notes = Column(Text, nullable=True)
|
||||
# Assign data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||
|
||||
# Relationships
|
||||
test = relationship("Test", back_populates="evidences")
|
||||
# Assign uploader = relationship("User", foreign_keys=[uploaded_by])
|
||||
uploader = relationship("User", foreign_keys=[uploaded_by])
|
||||
|
||||
@@ -1,28 +1,44 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
"""SQLAlchemy model for the intel_items table."""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Boolean, Column, DateTime, ForeignKey, String, ... from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, func
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class IntelItem
|
||||
class IntelItem(Base):
|
||||
"""
|
||||
Intelligence item model for tracking threat intelligence related to techniques.
|
||||
|
||||
"""Intelligence item model for tracking threat intelligence related to techniques.
|
||||
|
||||
Stores URLs and metadata from automated intel scans that may indicate
|
||||
new attack variations or detection bypasses for specific techniques.
|
||||
"""
|
||||
# Assign __tablename__ = "intel_items"
|
||||
__tablename__ = "intel_items"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True)
|
||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=True)
|
||||
# Assign url = Column(String, nullable=False)
|
||||
url = Column(String, nullable=False)
|
||||
# Assign title = Column(String, nullable=True)
|
||||
title = Column(String, nullable=True)
|
||||
# Assign source = Column(String, nullable=True)
|
||||
source = Column(String, nullable=True)
|
||||
detected_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign detected_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
detected_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
# Assign reviewed = Column(Boolean, default=False)
|
||||
reviewed = Column(Boolean, default=False)
|
||||
|
||||
# Relationships
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Jira integration models — link Aegis entities to Jira issues."""
|
||||
|
||||
# Import enum
|
||||
import enum
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Column, DateTime, ForeignKey, Index, String, func from sqlalchemy
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, func
|
||||
|
||||
# Import Enum as SQLEnum from sqlalchemy
|
||||
from sqlalchemy import Enum as SQLEnum
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class JiraLinkEntityType
|
||||
class JiraLinkEntityType(str, enum.Enum):
|
||||
"""Aegis entity types that can be linked to a Jira issue."""
|
||||
|
||||
# Assign test = "test"
|
||||
test = "test"
|
||||
# Assign technique = "technique"
|
||||
technique = "technique"
|
||||
# Assign campaign = "campaign"
|
||||
campaign = "campaign"
|
||||
# Assign evidence = "evidence"
|
||||
evidence = "evidence"
|
||||
|
||||
|
||||
# Define class JiraSyncDirection
|
||||
class JiraSyncDirection(str, enum.Enum):
|
||||
"""Direction of synchronisation between Aegis and Jira."""
|
||||
|
||||
# Assign aegis_to_jira = "aegis_to_jira"
|
||||
aegis_to_jira = "aegis_to_jira"
|
||||
# Assign jira_to_aegis = "jira_to_aegis"
|
||||
jira_to_aegis = "jira_to_aegis"
|
||||
# Assign bidirectional = "bidirectional"
|
||||
bidirectional = "bidirectional"
|
||||
|
||||
|
||||
# Define class JiraLink
|
||||
class JiraLink(Base):
|
||||
"""Associates an Aegis entity with a Jira issue for bidirectional sync."""
|
||||
|
||||
# Assign __tablename__ = "jira_links"
|
||||
__tablename__ = "jira_links"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False)
|
||||
entity_type = Column(SQLEnum(JiraLinkEntityType), nullable=False)
|
||||
# Assign entity_id = Column(UUID(as_uuid=True), nullable=False)
|
||||
entity_id = Column(UUID(as_uuid=True), nullable=False)
|
||||
# Assign jira_issue_key = Column(String(50), nullable=False)
|
||||
jira_issue_key = Column(String(50), nullable=False)
|
||||
# Assign jira_issue_id = Column(String(50))
|
||||
jira_issue_id = Column(String(50))
|
||||
# Assign jira_project_key = Column(String(20))
|
||||
jira_project_key = Column(String(20))
|
||||
# Assign jira_status = Column(String(100))
|
||||
jira_status = Column(String(100))
|
||||
# Assign jira_priority = Column(String(50))
|
||||
jira_priority = Column(String(50))
|
||||
# Assign jira_assignee = Column(String(255))
|
||||
jira_assignee = Column(String(255))
|
||||
# Assign jira_story_points = Column(String(10))
|
||||
jira_story_points = Column(String(10))
|
||||
# Assign sync_direction = Column(
|
||||
sync_direction = Column(
|
||||
SQLEnum(JiraSyncDirection), default=JiraSyncDirection.bidirectional
|
||||
)
|
||||
# Assign last_synced_at = Column(DateTime)
|
||||
last_synced_at = Column(DateTime)
|
||||
# Assign sync_metadata = Column(JSONB, default={})
|
||||
sync_metadata = Column(JSONB, default={})
|
||||
# Assign created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
|
||||
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
# Assign updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate...
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
|
||||
# Assign creator = relationship("User", foreign_keys=[created_by])
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index("ix_jira_links_entity_id", "entity_id"),
|
||||
Index("ix_jira_links_issue_key", "jira_issue_key"),
|
||||
Index("ix_jira_links_entity_type_entity_id", "entity_type", "entity_id"),
|
||||
)
|
||||
@@ -1,37 +1,54 @@
|
||||
"""Notification model — in-app notifications for user actions."""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Index
|
||||
# Import Boolean, Column, DateTime, ForeignKey, Index, S... from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String, Text, func
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class Notification
|
||||
class Notification(Base):
|
||||
"""
|
||||
In-app notification for alerting users when they need to act.
|
||||
"""In-app notification for alerting users when they need to act.
|
||||
|
||||
Types include: test_assigned, validation_needed, test_rejected,
|
||||
test_validated, test_state_changed, etc.
|
||||
"""
|
||||
# Assign __tablename__ = "notifications"
|
||||
__tablename__ = "notifications"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||
# Assign type = Column(String, nullable=False)
|
||||
type = Column(String, nullable=False)
|
||||
# Assign title = Column(String, nullable=False)
|
||||
title = Column(String, nullable=False)
|
||||
# Assign message = Column(Text, nullable=True)
|
||||
message = Column(Text, nullable=True)
|
||||
# Assign entity_type = Column(String, nullable=True)
|
||||
entity_type = Column(String, nullable=True)
|
||||
# Assign entity_id = Column(UUID(as_uuid=True), nullable=True)
|
||||
entity_id = Column(UUID(as_uuid=True), nullable=True)
|
||||
# Assign read = Column(Boolean, default=False)
|
||||
read = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
user = relationship("User")
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index("ix_notifications_user_id", "user_id"),
|
||||
Index("ix_notifications_read", "read"),
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
"""OSINT enrichment items — CVEs, blogs, PoCs, and advisories linked to techniques."""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Boolean, Column, DateTime, ForeignKey, String, ... from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, Text, func
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class OsintItem
|
||||
class OsintItem(Base):
|
||||
"""Represents an OSINT data point (CVE, blog, PoC, advisory) associated with a MITRE ATT&CK technique.
|
||||
|
||||
Used by the enrichment pipeline to surface relevant threat intelligence
|
||||
for each technique, flagging those that need review.
|
||||
"""
|
||||
|
||||
# Assign __tablename__ = "osint_items"
|
||||
__tablename__ = "osint_items"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign technique_id = Column(
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
# Keyword argument: index
|
||||
index=True,
|
||||
)
|
||||
# Assign source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory"
|
||||
source_type = Column(String(50), nullable=False) # "cve", "blog", "poc", "advisory"
|
||||
# Assign source_url = Column(Text, nullable=False)
|
||||
source_url = Column(Text, nullable=False)
|
||||
# Assign title = Column(String(500), nullable=False)
|
||||
title = Column(String(500), nullable=False)
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, U...
|
||||
severity = Column(String(20), nullable=True) # CRITICAL, HIGH, MEDIUM, LOW, UNKNOWN
|
||||
# Assign discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable...
|
||||
discovered_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
# Assign reviewed = Column(Boolean, default=False)
|
||||
reviewed = Column(Boolean, default=False)
|
||||
# Assign metadata_ = Column("metadata", JSONB, default={})
|
||||
metadata_ = Column("metadata", JSONB, default={})
|
||||
|
||||
# ── Relationships ─────────────────────────────────────────────────
|
||||
technique = relationship("Technique", backref="osint_items")
|
||||
@@ -0,0 +1,43 @@
|
||||
"""ScoringConfig — single-row table for persisted scoring weights."""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Column, DateTime, Float, ForeignKey, func from sqlalchemy
|
||||
from sqlalchemy import Column, DateTime, Float, ForeignKey, func
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class ScoringConfig
|
||||
class ScoringConfig(Base):
|
||||
"""Single-row table persisting the active scoring weight configuration."""
|
||||
|
||||
# Assign __tablename__ = "scoring_config"
|
||||
__tablename__ = "scoring_config"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign weight_tests = Column(Float, nullable=False, default=40.0)
|
||||
weight_tests = Column(Float, nullable=False, default=40.0)
|
||||
# Assign weight_detection_rules = Column(Float, nullable=False, default=25.0)
|
||||
weight_detection_rules = Column(Float, nullable=False, default=25.0)
|
||||
# Assign weight_d3fend = Column(Float, nullable=False, default=15.0)
|
||||
weight_d3fend = Column(Float, nullable=False, default=15.0)
|
||||
# Assign weight_recency = Column(Float, nullable=False, default=10.0)
|
||||
weight_recency = Column(Float, nullable=False, default=10.0)
|
||||
# Assign weight_severity = Column(Float, nullable=False, default=10.0)
|
||||
weight_severity = Column(Float, nullable=False, default=10.0)
|
||||
# Assign updated_by = Column(
|
||||
updated_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
# Keyword argument: nullable
|
||||
nullable=True,
|
||||
)
|
||||
# Assign updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate...
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
@@ -1,38 +1,63 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
"""SQLAlchemy model for the techniques table."""
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Enum
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Boolean, Column, DateTime, Enum, String, Text from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, DateTime, Enum, String, Text
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
# Import TechniqueStatus from app.models.enums
|
||||
from app.models.enums import TechniqueStatus
|
||||
|
||||
|
||||
# Define class Technique
|
||||
class Technique(Base):
|
||||
"""
|
||||
MITRE ATT&CK Technique model.
|
||||
|
||||
"""MITRE ATT&CK Technique model.
|
||||
|
||||
Represents an attack technique from the MITRE ATT&CK framework,
|
||||
including its coverage status and associated tests.
|
||||
"""
|
||||
# Assign __tablename__ = "techniques"
|
||||
__tablename__ = "techniques"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001"
|
||||
mitre_id = Column(String, unique=True, nullable=False) # e.g., "T1059.001"
|
||||
# Assign name = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign tactic = Column(String, nullable=True)
|
||||
tactic = Column(String, nullable=True)
|
||||
# Assign platforms = Column(JSONB, nullable=True, default=[])
|
||||
platforms = Column(JSONB, nullable=True, default=[])
|
||||
# Assign mitre_version = Column(String, nullable=True)
|
||||
mitre_version = Column(String, nullable=True)
|
||||
# Assign mitre_last_modified = Column(DateTime, nullable=True)
|
||||
mitre_last_modified = Column(DateTime, nullable=True)
|
||||
# Assign is_subtechnique = Column(Boolean, default=False)
|
||||
is_subtechnique = Column(Boolean, default=False)
|
||||
# Assign parent_mitre_id = Column(String, nullable=True)
|
||||
parent_mitre_id = Column(String, nullable=True)
|
||||
# Assign status_global = Column(
|
||||
status_global = Column(
|
||||
Enum(TechniqueStatus, name="techniquestatus"),
|
||||
# Keyword argument: default
|
||||
default=TechniqueStatus.not_evaluated
|
||||
)
|
||||
# Assign review_required = Column(Boolean, default=False)
|
||||
review_required = Column(Boolean, default=False)
|
||||
# Assign last_review_date = Column(DateTime, nullable=True)
|
||||
last_review_date = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
|
||||
@@ -1,69 +1,144 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
"""SQLAlchemy model for the tests table."""
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, Integer, DateTime, ForeignKey, Enum
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import from sqlalchemy
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
from app.models.enums import TestState, TestResult
|
||||
|
||||
# Import TestResult, TestState from app.models.enums
|
||||
from app.models.enums import TestResult, TestState
|
||||
|
||||
|
||||
# Define class Test
|
||||
class Test(Base):
|
||||
"""
|
||||
Test model representing a security test for a MITRE ATT&CK technique.
|
||||
"""Test model representing a security test for a MITRE ATT&CK technique.
|
||||
|
||||
Each test documents an attempt to validate coverage of a specific technique,
|
||||
including the procedure, tools used, and outcome. V2 introduces dual
|
||||
validation: Red Lead and Blue Lead must each approve independently.
|
||||
"""
|
||||
# Assign __tablename__ = "tests"
|
||||
__tablename__ = "tests"
|
||||
|
||||
# ── Core fields ─────────────────────────────────────────────────
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=Fa...
|
||||
technique_id = Column(UUID(as_uuid=True), ForeignKey("techniques.id"), nullable=False)
|
||||
# Assign name = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign platform = Column(String, nullable=True)
|
||||
platform = Column(String, nullable=True)
|
||||
# Assign procedure_text = Column(Text, nullable=True)
|
||||
procedure_text = Column(Text, nullable=True)
|
||||
# Assign tool_used = Column(String, nullable=True)
|
||||
tool_used = Column(String, nullable=True)
|
||||
# Assign execution_date = Column(DateTime, nullable=True)
|
||||
execution_date = Column(DateTime, nullable=True)
|
||||
# Assign created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
# Assign result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
||||
result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
||||
# Assign state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
|
||||
state = Column(Enum(TestState, name="teststate"), default=TestState.draft)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# ── Red Team fields ─────────────────────────────────────────────
|
||||
red_summary = Column(Text, nullable=True)
|
||||
# Assign attack_success = Column(Boolean, nullable=True)
|
||||
attack_success = Column(Boolean, nullable=True)
|
||||
# Assign red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
red_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
# Assign red_validated_at = Column(DateTime, nullable=True)
|
||||
red_validated_at = Column(DateTime, nullable=True)
|
||||
# Assign red_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
||||
red_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
||||
# Assign red_validation_notes = Column(Text, nullable=True)
|
||||
red_validation_notes = Column(Text, nullable=True)
|
||||
|
||||
# ── Blue Team fields ────────────────────────────────────────────
|
||||
blue_summary = Column(Text, nullable=True)
|
||||
# Assign detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
||||
detection_result = Column(Enum(TestResult, name="testresult"), nullable=True)
|
||||
# Assign blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
blue_validated_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
# Assign blue_validated_at = Column(DateTime, nullable=True)
|
||||
blue_validated_at = Column(DateTime, nullable=True)
|
||||
# Assign blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
||||
blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
||||
# Assign blue_validation_notes = Column(Text, nullable=True)
|
||||
blue_validation_notes = Column(Text, nullable=True)
|
||||
|
||||
# ── Phase timing fields (for automatic Tempo worklogs) ──────────
|
||||
red_started_at = Column(DateTime, nullable=True)
|
||||
# Assign blue_started_at = Column(DateTime, nullable=True)
|
||||
blue_started_at = Column(DateTime, nullable=True)
|
||||
# Assign paused_at = Column(DateTime, nullable=True)
|
||||
paused_at = Column(DateTime, nullable=True)
|
||||
# Assign red_paused_seconds = Column(Integer, default=0)
|
||||
red_paused_seconds = Column(Integer, default=0)
|
||||
# Assign blue_paused_seconds = Column(Integer, default=0)
|
||||
blue_paused_seconds = Column(Integer, default=0)
|
||||
|
||||
# ── Remediation fields ───────────────────────────────────────────
|
||||
remediation_steps = Column(Text, nullable=True)
|
||||
# Assign remediation_status = Column(String, nullable=True) # pending / in_progress / completed ...
|
||||
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable
|
||||
# Assign remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
|
||||
# ── Re-test fields ────────────────────────────────────────────
|
||||
retest_of = Column(UUID(as_uuid=True), ForeignKey("tests.id"), nullable=True)
|
||||
# Assign retest_count = Column(Integer, default=0)
|
||||
retest_count = Column(Integer, default=0)
|
||||
# Assign data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||
data_classification = Column(String(20), nullable=False, server_default="internal")
|
||||
|
||||
# ── Relationships ───────────────────────────────────────────────
|
||||
technique = relationship("Technique", back_populates="tests")
|
||||
# Assign evidences = relationship("Evidence", back_populates="test")
|
||||
evidences = relationship("Evidence", back_populates="test")
|
||||
# Assign creator = relationship("User", foreign_keys=[created_by])
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
# Assign red_validator = relationship("User", foreign_keys=[red_validated_by])
|
||||
red_validator = relationship("User", foreign_keys=[red_validated_by])
|
||||
# Assign blue_validator = relationship("User", foreign_keys=[blue_validated_by])
|
||||
blue_validator = relationship("User", foreign_keys=[blue_validated_by])
|
||||
# Assign remediation_user = relationship("User", foreign_keys=[remediation_assignee])
|
||||
remediation_user = relationship("User", foreign_keys=[remediation_assignee])
|
||||
# Assign original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
|
||||
original_test = relationship("Test", remote_side="Test.id", foreign_keys=[retest_of])
|
||||
# Assign retests = relationship("Test", foreign_keys=[retest_of], back_populates="orig...
|
||||
retests = relationship("Test", foreign_keys=[retest_of], back_populates="original_test")
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index("ix_tests_technique_id", "technique_id"),
|
||||
Index("ix_tests_state", "state"),
|
||||
Index("ix_tests_created_at", "created_at"),
|
||||
Index("ix_tests_technique_state", "technique_id", "state"),
|
||||
Index("ix_tests_state_created_at", "state", "created_at"),
|
||||
)
|
||||
|
||||
@@ -4,52 +4,81 @@ When the Blue Team evaluates a test, they mark each associated detection
|
||||
rule as triggered / not triggered / not applicable, along with notes.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Index
|
||||
# Import from sqlalchemy
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class TestDetectionResult
|
||||
class TestDetectionResult(Base):
|
||||
"""
|
||||
Per-test, per-rule evaluation result.
|
||||
"""Per-test, per-rule evaluation result.
|
||||
|
||||
- ``triggered`` = True: rule detected the attack
|
||||
- ``triggered`` = False: rule did NOT detect the attack
|
||||
- ``triggered`` = None: not yet evaluated
|
||||
"""
|
||||
# Assign __tablename__ = "test_detection_results"
|
||||
__tablename__ = "test_detection_results"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign test_id = Column(
|
||||
test_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("tests.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign detection_rule_id = Column(
|
||||
detection_rule_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("detection_rules.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign triggered = Column(Boolean, nullable=True) # None = not evaluated
|
||||
triggered = Column(Boolean, nullable=True) # None = not evaluated
|
||||
# Assign notes = Column(Text, nullable=True)
|
||||
notes = Column(Text, nullable=True)
|
||||
# Assign evaluated_by = Column(
|
||||
evaluated_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
# Keyword argument: nullable
|
||||
nullable=True,
|
||||
)
|
||||
# Assign evaluated_at = Column(DateTime, nullable=True)
|
||||
evaluated_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
test = relationship("Test")
|
||||
# Assign detection_rule = relationship("DetectionRule")
|
||||
detection_rule = relationship("DetectionRule")
|
||||
# Assign evaluator = relationship("User", foreign_keys=[evaluated_by])
|
||||
evaluator = relationship("User", foreign_keys=[evaluated_by])
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_tdr_test', 'test_id'),
|
||||
Index('ix_tdr_rule', 'detection_rule_id'),
|
||||
UniqueConstraint('test_id', 'detection_rule_id', name='uq_tdr_test_rule'),
|
||||
)
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
"""TestTemplate model — predefined test catalog entries."""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Index
|
||||
# Import Boolean, Column, DateTime, Index, String, Text,... from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, DateTime, Index, String, Text, func
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class TestTemplate
|
||||
class TestTemplate(Base):
|
||||
"""
|
||||
Predefined test template mapped to a MITRE ATT&CK technique.
|
||||
"""Predefined test template mapped to a MITRE ATT&CK technique.
|
||||
|
||||
Templates come from several sources:
|
||||
- **atomic_red_team**: Atomic Red Team by Red Canary
|
||||
@@ -20,24 +24,41 @@ class TestTemplate(Base):
|
||||
|
||||
Users can instantiate a real Test from a template.
|
||||
"""
|
||||
# Assign __tablename__ = "test_templates"
|
||||
__tablename__ = "test_templates"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
||||
mitre_technique_id = Column(String, nullable=False) # e.g. "T1059.001"
|
||||
# Assign name = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign source = Column(String, nullable=False) # atomic_red_te...
|
||||
source = Column(String, nullable=False) # atomic_red_team / mitre / custom
|
||||
# Assign source_url = Column(String, nullable=True)
|
||||
source_url = Column(String, nullable=True)
|
||||
# Assign attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
|
||||
attack_procedure = Column(Text, nullable=True) # Suggested attack procedure
|
||||
# Assign expected_detection = Column(Text, nullable=True) # What blue team should detect
|
||||
expected_detection = Column(Text, nullable=True) # What blue team should detect
|
||||
# Assign platform = Column(String, nullable=True) # windows / linux...
|
||||
platform = Column(String, nullable=True) # windows / linux / macos
|
||||
# Assign tool_suggested = Column(String, nullable=True)
|
||||
tool_suggested = Column(String, nullable=True)
|
||||
# Assign severity = Column(String, nullable=True) # low / medium / ...
|
||||
severity = Column(String, nullable=True) # low / medium / high / critical
|
||||
# Assign atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team...
|
||||
atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo
|
||||
# Assign suggested_remediation = Column(Text, nullable=True)
|
||||
suggested_remediation = Column(Text, nullable=True)
|
||||
# Assign is_active = Column(Boolean, default=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_test_templates_mitre_technique_id', 'mitre_technique_id'),
|
||||
Index('ix_test_templates_source', 'source'),
|
||||
|
||||
@@ -4,47 +4,64 @@ Enables the Blue Team to see which detection rules should fire
|
||||
for a given test template / attack procedure.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Boolean, ForeignKey, Index, UniqueConstraint
|
||||
# Import Boolean, Column, ForeignKey, Index, UniqueConst... from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, Index, UniqueConstraint
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class TestTemplateDetectionRule
|
||||
class TestTemplateDetectionRule(Base):
|
||||
"""
|
||||
Association between a test template and a detection rule.
|
||||
"""Association between a test template and a detection rule.
|
||||
|
||||
Auto-generated by matching mitre_technique_id, or manually curated.
|
||||
``is_primary`` marks rules with severity >= high as primary detections.
|
||||
"""
|
||||
# Assign __tablename__ = "test_template_detection_rules"
|
||||
__tablename__ = "test_template_detection_rules"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign test_template_id = Column(
|
||||
test_template_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("test_templates.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=True,
|
||||
)
|
||||
# Assign detection_rule_id = Column(
|
||||
detection_rule_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("detection_rules.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign is_primary = Column(Boolean, default=False)
|
||||
is_primary = Column(Boolean, default=False)
|
||||
|
||||
# Relationships
|
||||
test_template = relationship("TestTemplate")
|
||||
# Assign detection_rule = relationship("DetectionRule")
|
||||
detection_rule = relationship("DetectionRule")
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_ttdr_template', 'test_template_id'),
|
||||
Index('ix_ttdr_rule', 'detection_rule_id'),
|
||||
UniqueConstraint(
|
||||
# Literal argument value
|
||||
'test_template_id', 'detection_rule_id',
|
||||
# Keyword argument: name
|
||||
name='uq_template_detection_rule',
|
||||
),
|
||||
)
|
||||
|
||||
@@ -4,89 +4,135 @@ Stores profiles of APT groups and their associated MITRE ATT&CK
|
||||
techniques, imported from MITRE CTI (STIX 2.0).
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Import from sqlalchemy
|
||||
from sqlalchemy import (
|
||||
Column, String, Text, Boolean, DateTime,
|
||||
ForeignKey, Index, UniqueConstraint,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class ThreatActor
|
||||
class ThreatActor(Base):
|
||||
"""
|
||||
Threat actor / APT group profile.
|
||||
"""Threat actor / APT group profile.
|
||||
|
||||
Imported from MITRE CTI ``intrusion-set`` STIX objects.
|
||||
"""
|
||||
# Assign __tablename__ = "threat_actors"
|
||||
__tablename__ = "threat_actors"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign mitre_id = Column(String, unique=True, nullable=True) # e.g. "G00...
|
||||
mitre_id = Column(String, unique=True, nullable=True) # e.g. "G0016" (APT29)
|
||||
# Assign name = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
# Assign aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy ...
|
||||
aliases = Column(JSONB, nullable=True, default=[]) # ["Cozy Bear", "The Dukes", ...]
|
||||
# Assign description = Column(Text, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
# Assign country = Column(String, nullable=True)
|
||||
country = Column(String, nullable=True)
|
||||
# Assign target_sectors = Column(JSONB, nullable=True, default=[]) # ["government",...
|
||||
target_sectors = Column(JSONB, nullable=True, default=[]) # ["government", "defense", ...]
|
||||
# Assign target_regions = Column(JSONB, nullable=True, default=[]) # ["north-americ...
|
||||
target_regions = Column(JSONB, nullable=True, default=[]) # ["north-america", "europe", ...]
|
||||
# Assign motivation = Column(String, nullable=True) # espionage ...
|
||||
motivation = Column(String, nullable=True) # espionage / financial / destruction / ...
|
||||
# Assign sophistication = Column(String, nullable=True) # low / medium /...
|
||||
sophistication = Column(String, nullable=True) # low / medium / high / advanced
|
||||
# Assign first_seen = Column(String, nullable=True)
|
||||
first_seen = Column(String, nullable=True)
|
||||
# Assign last_seen = Column(String, nullable=True)
|
||||
last_seen = Column(String, nullable=True)
|
||||
# Assign references = Column(JSONB, nullable=True, default=[]) # [{"url": "...
|
||||
references = Column(JSONB, nullable=True, default=[]) # [{"url": "...", "description": "..."}]
|
||||
# Assign mitre_url = Column(String, nullable=True)
|
||||
mitre_url = Column(String, nullable=True)
|
||||
# Assign is_active = Column(Boolean, default=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# Relationships
|
||||
techniques = relationship(
|
||||
# Literal argument value
|
||||
"ThreatActorTechnique",
|
||||
# Keyword argument: back_populates
|
||||
back_populates="threat_actor",
|
||||
# Keyword argument: cascade
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_threat_actors_country', 'country'),
|
||||
Index('ix_threat_actors_motivation', 'motivation'),
|
||||
)
|
||||
|
||||
|
||||
# Define class ThreatActorTechnique
|
||||
class ThreatActorTechnique(Base):
|
||||
"""
|
||||
Association between a threat actor and a MITRE ATT&CK technique.
|
||||
"""Association between a threat actor and a MITRE ATT&CK technique.
|
||||
|
||||
Stores additional context about how the actor uses the technique
|
||||
(from the STIX ``relationship`` ``uses`` objects).
|
||||
"""
|
||||
# Assign __tablename__ = "threat_actor_techniques"
|
||||
__tablename__ = "threat_actor_techniques"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign threat_actor_id = Column(
|
||||
threat_actor_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("threat_actors.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign technique_id = Column(
|
||||
technique_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("techniques.id", ondelete="CASCADE"),
|
||||
# Keyword argument: nullable
|
||||
nullable=False,
|
||||
)
|
||||
# Assign usage_description = Column(Text, nullable=True)
|
||||
usage_description = Column(Text, nullable=True)
|
||||
# Assign first_seen_using = Column(String, nullable=True)
|
||||
first_seen_using = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
threat_actor = relationship("ThreatActor", back_populates="techniques")
|
||||
# Assign technique = relationship("Technique")
|
||||
technique = relationship("Technique")
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index('ix_threat_actor_techniques_actor', 'threat_actor_id'),
|
||||
Index('ix_threat_actor_techniques_technique', 'technique_id'),
|
||||
UniqueConstraint(
|
||||
# Literal argument value
|
||||
'threat_actor_id', 'technique_id',
|
||||
# Keyword argument: name
|
||||
name='uq_actor_technique',
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
"""SQLAlchemy model for the users table."""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Boolean, Column, DateTime, String, func from sqlalchemy
|
||||
from sqlalchemy import Boolean, Column, DateTime, String, func
|
||||
|
||||
# Import UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class User
|
||||
class User(Base):
|
||||
"""
|
||||
User model for authentication and authorization.
|
||||
|
||||
"""User model for authentication and authorization.
|
||||
|
||||
Possible roles:
|
||||
- admin: Full system access
|
||||
- red_tech: Red team technician - can create and edit tests
|
||||
@@ -19,13 +25,24 @@ class User(Base):
|
||||
- blue_lead: Blue team lead - can validate tests
|
||||
- viewer: Read-only access (default)
|
||||
"""
|
||||
# Assign __tablename__ = "users"
|
||||
__tablename__ = "users"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign username = Column(String, unique=True, nullable=False)
|
||||
username = Column(String, unique=True, nullable=False)
|
||||
# Assign email = Column(String, nullable=True)
|
||||
email = Column(String, nullable=True)
|
||||
# Assign hashed_password = Column(String, nullable=False)
|
||||
hashed_password = Column(String, nullable=False)
|
||||
# Assign role = Column(String, nullable=False, default="viewer")
|
||||
role = Column(String, nullable=False, default="viewer")
|
||||
# Assign is_active = Column(Boolean, default=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
# Assign must_change_password = Column(Boolean, default=True)
|
||||
must_change_password = Column(Boolean, default=True)
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
# Assign last_login = Column(DateTime, nullable=True)
|
||||
last_login = Column(DateTime, nullable=True)
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Worklog model — immutable internal time-tracking records."""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Column, DateTime, ForeignKey, Index, Integer, S... from sqlalchemy
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text, func
|
||||
|
||||
# Import JSONB, UUID from sqlalchemy.dialects.postgresql
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# Import relationship from sqlalchemy.orm
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
# Import Base from app.database
|
||||
from app.database import Base
|
||||
|
||||
|
||||
# Define class Worklog
|
||||
class Worklog(Base):
|
||||
"""Internal worklog entry with integrity hash for audit compliance.
|
||||
|
||||
Each worklog is tied to an Aegis entity (test, campaign, etc.) and
|
||||
optionally synced to Tempo. The ``integrity_hash`` is a SHA-256 of
|
||||
the immutable fields so tampering can be detected.
|
||||
"""
|
||||
|
||||
# Assign __tablename__ = "worklogs"
|
||||
__tablename__ = "worklogs"
|
||||
|
||||
# Assign id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# Assign entity_type = Column(String(50), nullable=False)
|
||||
entity_type = Column(String(50), nullable=False)
|
||||
# Assign entity_id = Column(UUID(as_uuid=True), nullable=False)
|
||||
entity_id = Column(UUID(as_uuid=True), nullable=False)
|
||||
# Assign user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||
# Assign activity_type = Column(String(100), nullable=False)
|
||||
activity_type = Column(String(100), nullable=False)
|
||||
# Assign started_at = Column(DateTime, nullable=False)
|
||||
started_at = Column(DateTime, nullable=False)
|
||||
# Assign ended_at = Column(DateTime)
|
||||
ended_at = Column(DateTime)
|
||||
# Assign duration_seconds = Column(Integer, nullable=False)
|
||||
duration_seconds = Column(Integer, nullable=False)
|
||||
# Assign description = Column(Text)
|
||||
description = Column(Text)
|
||||
# Assign tempo_synced = Column(DateTime)
|
||||
tempo_synced = Column(DateTime)
|
||||
# Assign tempo_worklog_id = Column(String(100))
|
||||
tempo_worklog_id = Column(String(100))
|
||||
# Assign integrity_hash = Column(String(64))
|
||||
integrity_hash = Column(String(64))
|
||||
# Assign created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
# Assign extra_metadata = Column("metadata", JSONB, default={})
|
||||
extra_metadata = Column("metadata", JSONB, default={})
|
||||
|
||||
# Assign user = relationship("User", foreign_keys=[user_id])
|
||||
user = relationship("User", foreign_keys=[user_id])
|
||||
|
||||
# Assign __table_args__ = (
|
||||
__table_args__ = (
|
||||
Index("ix_worklogs_entity_id", "entity_id"),
|
||||
Index("ix_worklogs_user_id", "user_id"),
|
||||
Index("ix_worklogs_entity_type_entity_id", "entity_type", "entity_id"),
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""FastAPI router modules — one router per feature domain."""
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
|
||||
|
||||
# Import APIRouter, Depends from fastapi
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import advanced_metrics_service from app.services
|
||||
from app.services import advanced_metrics_service
|
||||
|
||||
# Assign router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
||||
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/coverage-by-tactic")
|
||||
# Define function coverage_by_tactic
|
||||
def coverage_by_tactic(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
||||
# Return advanced_metrics_service.get_coverage_by_tactic(db)
|
||||
return advanced_metrics_service.get_coverage_by_tactic(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/never-tested")
|
||||
# Define function never_tested_techniques
|
||||
def never_tested_techniques(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
"""Techniques that have never had a test created."""
|
||||
# Return advanced_metrics_service.get_never_tested_techniques(db)
|
||||
return advanced_metrics_service.get_never_tested_techniques(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/avg-validation-time")
|
||||
# Define function avg_validation_time
|
||||
def avg_validation_time(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Average time from test creation to validation, computed from audit logs.
|
||||
|
||||
Returns overall average and per-phase averages where data is available.
|
||||
"""
|
||||
# Return advanced_metrics_service.get_avg_validation_time(db)
|
||||
return advanced_metrics_service.get_avg_validation_time(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/detection-rate-trend")
|
||||
# Define function detection_rate_trend
|
||||
def detection_rate_trend(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
"""Monthly detection rate trend for the last 12 months."""
|
||||
# Return advanced_metrics_service.get_detection_rate_trend(db)
|
||||
return advanced_metrics_service.get_detection_rate_trend(db)
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Analytics endpoints — flat JSON optimized for PowerBI / BI tools.
|
||||
|
||||
Returns complete datasets without pagination so BI tools can ingest
|
||||
directly from URL. All endpoints require authentication.
|
||||
"""
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import analytics_service from app.services
|
||||
from app.services import analytics_service
|
||||
|
||||
# Assign router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/coverage")
|
||||
# Define function analytics_coverage
|
||||
def analytics_coverage(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
"""Coverage per technique — flat format for BI dashboards."""
|
||||
# Return analytics_service.get_coverage_analytics(db)
|
||||
return analytics_service.get_coverage_analytics(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/tests")
|
||||
# Define function analytics_tests
|
||||
def analytics_tests(
|
||||
# Entry: date_from
|
||||
date_from: str = Query(None, description="ISO date filter (>=)"),
|
||||
# Entry: date_to
|
||||
date_to: str = Query(None, description="ISO date filter (<=)"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
"""All tests with timestamps — flat format for BI dashboards."""
|
||||
# Return analytics_service.get_tests_analytics(
|
||||
return analytics_service.get_tests_analytics(
|
||||
db, date_from=date_from, date_to=date_to
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/trends")
|
||||
# Define function analytics_trends
|
||||
def analytics_trends(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
"""Historical coverage snapshots for trend visualization."""
|
||||
# Return analytics_service.get_trends_analytics(db)
|
||||
return analytics_service.get_trends_analytics(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/operators")
|
||||
# Define function analytics_operators
|
||||
def analytics_operators(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(require_role("admin")),
|
||||
) -> list:
|
||||
"""Per-operator metrics — for workload management dashboards."""
|
||||
# Return analytics_service.get_operators_analytics(db)
|
||||
return analytics_service.get_operators_analytics(db)
|
||||
@@ -1,118 +1,127 @@
|
||||
"""Audit log viewer router (admin only)."""
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import require_role
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import AuditLogOut, AuditLogPage from app.schemas.audit
|
||||
from app.schemas.audit import AuditLogOut, AuditLogPage
|
||||
|
||||
# Import from app.services.audit_query_service
|
||||
from app.services.audit_query_service import (
|
||||
list_distinct_actions,
|
||||
list_distinct_entity_types,
|
||||
list_logs,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("", response_model=AuditLogPage)
|
||||
# Define function list_audit_logs
|
||||
def list_audit_logs(
|
||||
# Entry: user_id
|
||||
user_id: Optional[str] = Query(None, description="Filter by user ID"),
|
||||
# Entry: action
|
||||
action: Optional[str] = Query(None, description="Filter by action type"),
|
||||
# Entry: entity_type
|
||||
entity_type: Optional[str] = Query(None, description="Filter by entity type"),
|
||||
# Entry: start_date
|
||||
start_date: Optional[datetime] = Query(None, description="Filter by start date"),
|
||||
# Entry: end_date
|
||||
end_date: Optional[datetime] = Query(None, description="Filter by end date"),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=100, description="Max records to return"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> AuditLogPage:
|
||||
"""Return paginated audit logs with optional filters.
|
||||
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
||||
|
||||
# Apply filters
|
||||
if user_id:
|
||||
query = query.filter(AuditLog.user_id == user_id)
|
||||
if action:
|
||||
query = query.filter(AuditLog.action == action)
|
||||
if entity_type:
|
||||
query = query.filter(AuditLog.entity_type == entity_type)
|
||||
if start_date:
|
||||
query = query.filter(AuditLog.timestamp >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(AuditLog.timestamp <= end_date)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get paginated results
|
||||
logs = (
|
||||
query
|
||||
.order_by(AuditLog.timestamp.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Convert to response format with username
|
||||
items = []
|
||||
for log in logs:
|
||||
item = AuditLogOut(
|
||||
id=log.id,
|
||||
user_id=log.user_id,
|
||||
username=log.user.username if log.user else None,
|
||||
action=log.action,
|
||||
entity_type=log.entity_type,
|
||||
entity_id=log.entity_id,
|
||||
timestamp=log.timestamp,
|
||||
details=log.details,
|
||||
)
|
||||
items.append(item)
|
||||
|
||||
return AuditLogPage(
|
||||
items=items,
|
||||
total=total,
|
||||
# Assign result = list_logs(
|
||||
result = list_logs(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user_id,
|
||||
# Keyword argument: action
|
||||
action=action,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: start_date
|
||||
start_date=start_date,
|
||||
# Keyword argument: end_date
|
||||
end_date=end_date,
|
||||
# Keyword argument: offset
|
||||
offset=offset,
|
||||
# Keyword argument: limit
|
||||
limit=limit,
|
||||
)
|
||||
# Return AuditLogPage(
|
||||
return AuditLogPage(
|
||||
# Keyword argument: items
|
||||
items=[AuditLogOut(**item) for item in result["items"]],
|
||||
# Keyword argument: total
|
||||
total=result["total"],
|
||||
# Keyword argument: offset
|
||||
offset=result["offset"],
|
||||
# Keyword argument: limit
|
||||
limit=result["limit"],
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/actions", response_model=list[str])
|
||||
# Define function list_actions
|
||||
def list_actions(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> list[str]:
|
||||
"""Return a list of distinct action types in the audit log.
|
||||
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
actions = (
|
||||
db.query(AuditLog.action)
|
||||
.distinct()
|
||||
.order_by(AuditLog.action)
|
||||
.all()
|
||||
)
|
||||
return [a[0] for a in actions]
|
||||
# Return list_distinct_actions(db)
|
||||
return list_distinct_actions(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/entity-types", response_model=list[str])
|
||||
# Define function list_entity_types
|
||||
def list_entity_types(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> list[str]:
|
||||
"""Return a list of distinct entity types in the audit log.
|
||||
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
types = (
|
||||
db.query(AuditLog.entity_type)
|
||||
.filter(AuditLog.entity_type.isnot(None))
|
||||
.distinct()
|
||||
.order_by(AuditLog.entity_type)
|
||||
.all()
|
||||
)
|
||||
return [t[0] for t in types]
|
||||
# Return list_distinct_entity_types(db)
|
||||
return list_distinct_entity_types(db)
|
||||
|
||||
+266
-18
@@ -1,47 +1,295 @@
|
||||
"""Authentication router: login and current-user endpoints."""
|
||||
"""Authentication router: login, logout and current-user endpoints.
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
The JWT access token is delivered as an **HttpOnly** cookie
|
||||
(``aegis_token``) so it is inaccessible to client-side JavaScript,
|
||||
mitigating XSS token-theft attacks. The JSON response also includes
|
||||
the token in the body for backwards compatibility and for clients that
|
||||
cannot use cookies (e.g. Swagger UI).
|
||||
"""
|
||||
|
||||
# Import os
|
||||
import os
|
||||
|
||||
# Import APIRouter, Cookie, Depends, Request, Response from fastapi
|
||||
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
||||
|
||||
# Import OAuth2PasswordRequestForm from fastapi.security
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
# Import jwt (PyJWT)
|
||||
import jwt
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth import verify_password, create_access_token
|
||||
# Import blacklist_token, create_access_token, verify_pa... from app.auth
|
||||
from app.auth import blacklist_token, create_access_token, verify_password
|
||||
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import BusinessRuleViolation, PermissionViolation from app.domain.errors
|
||||
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import resolve_client_ip from app.middleware.request_context
|
||||
from app.middleware.request_context import resolve_client_ip
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import TokenResponse, UserOut from app.schemas.auth
|
||||
from app.schemas.auth import TokenResponse, UserOut
|
||||
|
||||
# Import PasswordChange from app.schemas.user
|
||||
from app.schemas.user import PasswordChange
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.auth_service
|
||||
from app.services.auth_service import (
|
||||
_DUMMY_HASH,
|
||||
)
|
||||
|
||||
# Import from app.services.auth_service
|
||||
from app.services.auth_service import (
|
||||
change_password as auth_change_password,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/login
|
||||
# ---------------------------------------------------------------------------
|
||||
# Assign _IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
_IS_HTTPS = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
# Assign _COOKIE_NAME = "aegis_token"
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("5/minute")
|
||||
# Define function login
|
||||
def login(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: response
|
||||
response: Response,
|
||||
# Entry: form_data
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Authenticate a user and return a JWT access token."""
|
||||
) -> TokenResponse:
|
||||
"""Authenticate a user and return a JWT access token.
|
||||
|
||||
Rate-limited to **5 attempts per minute per IP**. Failed and successful
|
||||
logins are recorded in the audit log (SEC-009).
|
||||
"""
|
||||
# Assign user = db.query(User).filter(User.username == form_data.username).first()
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
# Assign target_hash = user.hashed_password if user else _DUMMY_HASH
|
||||
target_hash = user.hashed_password if user else _DUMMY_HASH
|
||||
# Assign password_valid = verify_password(form_data.password, target_hash)
|
||||
password_valid = verify_password(form_data.password, target_hash)
|
||||
# Assign ip = resolve_client_ip(request)
|
||||
ip = resolve_client_ip(request)
|
||||
|
||||
if user is None or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Incorrect username or password",
|
||||
)
|
||||
# Check: user is None or not password_valid
|
||||
if user is None or not password_valid:
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
user.id if user else None,
|
||||
# Literal argument value
|
||||
"LOGIN_FAILED",
|
||||
# Literal argument value
|
||||
"auth",
|
||||
# Literal argument value
|
||||
None,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"username": form_data.username,
|
||||
# Literal argument value
|
||||
"ip": ip,
|
||||
# Literal argument value
|
||||
"reason": "invalid_credentials",
|
||||
},
|
||||
# Keyword argument: ip_address
|
||||
ip_address=ip,
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Incorrect username or password")
|
||||
|
||||
# Check: not user.is_active
|
||||
if not user.is_active:
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation("Account is disabled. Contact an administrator.")
|
||||
|
||||
# Assign access_token = create_access_token(data={"sub": user.username})
|
||||
access_token = create_access_token(data={"sub": user.username})
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
user.id,
|
||||
# Literal argument value
|
||||
"LOGIN_SUCCESS",
|
||||
# Literal argument value
|
||||
"auth",
|
||||
str(user.id),
|
||||
# Keyword argument: details
|
||||
details={"username": user.username, "ip": ip},
|
||||
# Keyword argument: ip_address
|
||||
ip_address=ip,
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Call response.set_cookie()
|
||||
response.set_cookie(
|
||||
# Keyword argument: key
|
||||
key=_COOKIE_NAME,
|
||||
# Keyword argument: value
|
||||
value=access_token,
|
||||
# Keyword argument: httponly
|
||||
httponly=True,
|
||||
# Keyword argument: secure
|
||||
secure=_IS_HTTPS,
|
||||
# Keyword argument: samesite
|
||||
samesite="strict",
|
||||
# Keyword argument: max_age
|
||||
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
# Keyword argument: path
|
||||
path="/",
|
||||
)
|
||||
|
||||
# Return TokenResponse(access_token=access_token)
|
||||
return TokenResponse(access_token=access_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/me
|
||||
# ---------------------------------------------------------------------------
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/logout")
|
||||
# Define function logout
|
||||
def logout(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: response
|
||||
response: Response,
|
||||
# Entry: aegis_token
|
||||
aegis_token: str | None = Cookie(None),
|
||||
) -> dict:
|
||||
"""Clear the authentication cookie and revoke the current token."""
|
||||
# Assign bearer = (
|
||||
bearer = (
|
||||
request.headers.get("Authorization")
|
||||
or request.headers.get("authorization")
|
||||
or ""
|
||||
)
|
||||
# Assign bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip()
|
||||
bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip()
|
||||
|
||||
# Assign seen = set()
|
||||
seen: set[str] = set()
|
||||
# Iterate over (aegis_token, bearer)
|
||||
for raw in (aegis_token, bearer):
|
||||
# Check: not raw or raw in seen
|
||||
if not raw or raw in seen:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Call seen.add()
|
||||
seen.add(raw)
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign payload = jwt.decode(
|
||||
payload = jwt.decode(
|
||||
raw,
|
||||
settings.SECRET_KEY,
|
||||
# Keyword argument: algorithms
|
||||
algorithms=[settings.ALGORITHM],
|
||||
)
|
||||
# Assign jti = payload.get("jti")
|
||||
jti = payload.get("jti")
|
||||
# Assign exp = payload.get("exp", 0)
|
||||
exp = payload.get("exp", 0)
|
||||
# Check: jti
|
||||
if jti:
|
||||
# Call blacklist_token()
|
||||
blacklist_token(jti, float(exp))
|
||||
# Handle any JWT validation error during logout (token may be expired or malformed)
|
||||
except jwt.exceptions.InvalidTokenError:
|
||||
# Intentional no-op placeholder
|
||||
pass
|
||||
|
||||
# Call response.delete_cookie()
|
||||
response.delete_cookie(
|
||||
# Keyword argument: key
|
||||
key=_COOKIE_NAME,
|
||||
# Keyword argument: httponly
|
||||
httponly=True,
|
||||
# Keyword argument: secure
|
||||
secure=_IS_HTTPS,
|
||||
# Keyword argument: samesite
|
||||
samesite="strict",
|
||||
# Keyword argument: path
|
||||
path="/",
|
||||
)
|
||||
# Return {"detail": "Logged out"}
|
||||
return {"detail": "Logged out"}
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def read_current_user(current_user: User = Depends(get_current_user)):
|
||||
# Define function read_current_user
|
||||
def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut:
|
||||
"""Return the profile of the currently authenticated user."""
|
||||
# Return current_user
|
||||
return current_user
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/change-password")
|
||||
# Define function change_password
|
||||
def change_password(
|
||||
# Entry: body
|
||||
body: PasswordChange,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Change the current user's password."""
|
||||
# Call auth_change_password()
|
||||
auth_change_password(
|
||||
db,
|
||||
current_user,
|
||||
# Keyword argument: current_password
|
||||
current_password=body.current_password,
|
||||
# Keyword argument: new_password
|
||||
new_password=body.new_password,
|
||||
)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return {"detail": "Password changed successfully"}
|
||||
return {"detail": "Password changed successfully"}
|
||||
|
||||
+549
-407
File diff suppressed because it is too large
Load Diff
+137
-296
@@ -1,292 +1,157 @@
|
||||
"""Compliance endpoints — framework status, reports, and gap analysis.
|
||||
|
||||
Thin HTTP adapter that delegates all data logic to compliance_service.
|
||||
Provides compliance posture assessment by mapping MITRE ATT&CK technique
|
||||
coverage to compliance framework controls.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
from typing import Optional
|
||||
# Import APIRouter, Depends from fastapi
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
# Import StreamingResponse from fastapi.responses
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.models.compliance import (
|
||||
ComplianceFramework,
|
||||
ComplianceControl,
|
||||
ComplianceControlMapping,
|
||||
)
|
||||
from app.models.technique import Technique
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.threat_actor import ThreatActorTechnique
|
||||
from app.services.scoring_service import calculate_technique_score
|
||||
|
||||
# Import from app.services.compliance_import_service
|
||||
from app.services.compliance_import_service import (
|
||||
import_nist_800_53_mappings,
|
||||
import_cis_controls_v8_mappings,
|
||||
import_nist_800_53_mappings,
|
||||
)
|
||||
|
||||
# Import from app.services.compliance_service
|
||||
from app.services.compliance_service import (
|
||||
build_framework_report_csv,
|
||||
get_framework_gaps,
|
||||
get_framework_status,
|
||||
list_frameworks,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||
router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _classify_control(technique_scores: list[float]) -> str:
|
||||
"""Classify a control status based on its technique scores."""
|
||||
if not technique_scores:
|
||||
return "not_evaluated"
|
||||
|
||||
all_above_70 = all(s >= 70 for s in technique_scores)
|
||||
any_above_30 = any(s >= 30 for s in technique_scores)
|
||||
all_below_30 = all(s < 30 for s in technique_scores)
|
||||
all_zero = all(s == 0 for s in technique_scores)
|
||||
|
||||
if all_zero:
|
||||
return "not_evaluated"
|
||||
if all_above_70:
|
||||
return "covered"
|
||||
if all_below_30:
|
||||
return "not_covered"
|
||||
if any_above_30:
|
||||
return "partially_covered"
|
||||
return "not_covered"
|
||||
|
||||
|
||||
def _get_control_status(control: ComplianceControl, db: Session) -> dict:
|
||||
"""Compute the status and score for a single control."""
|
||||
mappings = (
|
||||
db.query(ComplianceControlMapping)
|
||||
.filter(ComplianceControlMapping.compliance_control_id == control.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not mappings:
|
||||
return {
|
||||
"control_id": control.control_id,
|
||||
"title": control.title,
|
||||
"category": control.category,
|
||||
"status": "not_evaluated",
|
||||
"score": 0,
|
||||
"techniques_count": 0,
|
||||
"techniques_covered": 0,
|
||||
"techniques": [],
|
||||
}
|
||||
|
||||
technique_ids = [m.technique_id for m in mappings]
|
||||
techniques = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.id.in_(technique_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
tech_details = []
|
||||
scores = []
|
||||
covered_count = 0
|
||||
|
||||
for tech in techniques:
|
||||
result = calculate_technique_score(tech, db)
|
||||
score = result["total_score"]
|
||||
scores.append(score)
|
||||
if score >= 50:
|
||||
covered_count += 1
|
||||
|
||||
tech_details.append({
|
||||
"mitre_id": tech.mitre_id,
|
||||
"name": tech.name,
|
||||
"score": score,
|
||||
"status": tech.status_global.value if tech.status_global else "not_evaluated",
|
||||
})
|
||||
|
||||
# Sort techniques by score ascending (worst first for priority)
|
||||
tech_details.sort(key=lambda t: t["score"])
|
||||
|
||||
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
|
||||
status = _classify_control(scores)
|
||||
|
||||
return {
|
||||
"control_id": control.control_id,
|
||||
"title": control.title,
|
||||
"category": control.category,
|
||||
"status": status,
|
||||
"score": avg_score,
|
||||
"techniques_count": len(techniques),
|
||||
"techniques_covered": covered_count,
|
||||
"techniques": tech_details,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /compliance/frameworks ────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/frameworks")
|
||||
def list_frameworks(
|
||||
# Define function list_frameworks_endpoint
|
||||
def list_frameworks_endpoint(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all available compliance frameworks."""
|
||||
frameworks = (
|
||||
db.query(ComplianceFramework)
|
||||
.filter(ComplianceFramework.is_active == True)
|
||||
.all()
|
||||
)
|
||||
) -> list:
|
||||
"""List all available compliance frameworks.
|
||||
|
||||
result = []
|
||||
for fw in frameworks:
|
||||
control_count = (
|
||||
db.query(ComplianceControl)
|
||||
.filter(ComplianceControl.framework_id == fw.id)
|
||||
.count()
|
||||
)
|
||||
result.append({
|
||||
"id": str(fw.id),
|
||||
"name": fw.name,
|
||||
"version": fw.version,
|
||||
"description": fw.description,
|
||||
"url": fw.url,
|
||||
"is_active": fw.is_active,
|
||||
"controls_count": control_count,
|
||||
})
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
return result
|
||||
Returns:
|
||||
list: List of framework summary dicts containing id, name, and control counts.
|
||||
"""
|
||||
# Return list_frameworks(db)
|
||||
return list_frameworks(db)
|
||||
|
||||
|
||||
# ── GET /compliance/frameworks/{id}/status ────────────────────────────
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}/status")
|
||||
# Define function framework_status
|
||||
def framework_status(
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get compliance status for each control in a framework."""
|
||||
framework = (
|
||||
db.query(ComplianceFramework)
|
||||
.filter(ComplianceFramework.id == framework_id)
|
||||
.first()
|
||||
)
|
||||
if not framework:
|
||||
raise HTTPException(status_code=404, detail="Framework not found")
|
||||
) -> dict:
|
||||
"""Get compliance status for each control in a framework.
|
||||
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
.order_by(ComplianceControl.control_id)
|
||||
.all()
|
||||
)
|
||||
Args:
|
||||
framework_id (str): Identifier of the compliance framework (e.g. ``nist-800-53``).
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
control_statuses = []
|
||||
summary = {
|
||||
"total_controls": len(controls),
|
||||
"covered": 0,
|
||||
"partially_covered": 0,
|
||||
"not_covered": 0,
|
||||
"not_evaluated": 0,
|
||||
}
|
||||
|
||||
for control in controls:
|
||||
status_data = _get_control_status(control, db)
|
||||
control_statuses.append(status_data)
|
||||
|
||||
status = status_data["status"]
|
||||
if status in summary:
|
||||
summary[status] += 1
|
||||
|
||||
# Compliance percentage: (covered + partially_covered*0.5) / total * 100
|
||||
total = summary["total_controls"]
|
||||
if total > 0:
|
||||
compliance_pct = round(
|
||||
(summary["covered"] + summary["partially_covered"] * 0.5) / total * 100,
|
||||
1,
|
||||
)
|
||||
else:
|
||||
compliance_pct = 0
|
||||
|
||||
summary["compliance_percentage"] = compliance_pct
|
||||
|
||||
return {
|
||||
"framework": {"id": str(framework.id), "name": framework.name},
|
||||
"summary": summary,
|
||||
"controls": control_statuses,
|
||||
}
|
||||
Returns:
|
||||
dict: Mapping of control IDs to their coverage status and linked techniques.
|
||||
"""
|
||||
# Return get_framework_status(db, framework_id)
|
||||
return get_framework_status(db, framework_id)
|
||||
|
||||
|
||||
# ── GET /compliance/frameworks/{id}/report ────────────────────────────
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}/report")
|
||||
# Define function framework_report
|
||||
def framework_report(
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get the full compliance report (same as status but marked as report)."""
|
||||
return framework_status(framework_id, db=db, current_user=current_user)
|
||||
) -> dict:
|
||||
"""Get the full compliance report (same as status but marked as report).
|
||||
|
||||
Args:
|
||||
framework_id (str): Identifier of the compliance framework.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Full compliance report with per-control coverage details.
|
||||
"""
|
||||
# Return get_framework_status(db, framework_id)
|
||||
return get_framework_status(db, framework_id)
|
||||
|
||||
|
||||
# ── GET /compliance/frameworks/{id}/report/csv ────────────────────────
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}/report/csv")
|
||||
# Define function framework_report_csv
|
||||
def framework_report_csv(
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Export compliance report as CSV."""
|
||||
framework = (
|
||||
db.query(ComplianceFramework)
|
||||
.filter(ComplianceFramework.id == framework_id)
|
||||
.first()
|
||||
)
|
||||
if not framework:
|
||||
raise HTTPException(status_code=404, detail="Framework not found")
|
||||
) -> StreamingResponse:
|
||||
"""Export compliance report as CSV.
|
||||
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
.order_by(ComplianceControl.control_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow([
|
||||
"control_id",
|
||||
"title",
|
||||
"category",
|
||||
"status",
|
||||
"score",
|
||||
"techniques_total",
|
||||
"techniques_covered",
|
||||
"technique_ids",
|
||||
])
|
||||
|
||||
for control in controls:
|
||||
status_data = _get_control_status(control, db)
|
||||
technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
|
||||
writer.writerow([
|
||||
status_data["control_id"],
|
||||
status_data["title"],
|
||||
status_data["category"] or "",
|
||||
status_data["status"],
|
||||
status_data["score"],
|
||||
status_data["techniques_count"],
|
||||
status_data["techniques_covered"],
|
||||
technique_ids,
|
||||
])
|
||||
|
||||
output.seek(0)
|
||||
filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
|
||||
Args:
|
||||
framework_id (str): Identifier of the compliance framework to export.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
StreamingResponse: CSV file attachment with compliance coverage data.
|
||||
"""
|
||||
# csv_bytes, filename = build_framework_report_csv(db, framework_id)
|
||||
csv_bytes, filename = build_framework_report_csv(db, framework_id)
|
||||
# Return StreamingResponse(
|
||||
return StreamingResponse(
|
||||
io.BytesIO(output.getvalue().encode("utf-8")),
|
||||
iter([csv_bytes]),
|
||||
# Keyword argument: media_type
|
||||
media_type="text/csv",
|
||||
# Keyword argument: headers
|
||||
headers={
|
||||
# Literal argument value
|
||||
"Content-Disposition": f"attachment; filename={filename}",
|
||||
},
|
||||
)
|
||||
@@ -296,98 +161,74 @@ def framework_report_csv(
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}/gaps")
|
||||
# Define function framework_gaps
|
||||
def framework_gaps(
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get controls with techniques that are not adequately covered."""
|
||||
framework = (
|
||||
db.query(ComplianceFramework)
|
||||
.filter(ComplianceFramework.id == framework_id)
|
||||
.first()
|
||||
)
|
||||
if not framework:
|
||||
raise HTTPException(status_code=404, detail="Framework not found")
|
||||
) -> dict:
|
||||
"""Get controls with techniques that are not adequately covered.
|
||||
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
.order_by(ComplianceControl.control_id)
|
||||
.all()
|
||||
)
|
||||
Args:
|
||||
framework_id (str): Identifier of the compliance framework to analyse.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
gaps = []
|
||||
for control in controls:
|
||||
status_data = _get_control_status(control, db)
|
||||
|
||||
if status_data["status"] in ("not_covered", "partially_covered"):
|
||||
# Find uncovered techniques
|
||||
uncovered_techniques = []
|
||||
for tech_info in status_data["techniques"]:
|
||||
if tech_info["score"] < 70:
|
||||
# Count available templates
|
||||
template_count = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"])
|
||||
.count()
|
||||
)
|
||||
|
||||
# Count threat actors using this technique
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == tech_info["mitre_id"])
|
||||
.first()
|
||||
)
|
||||
actor_count = 0
|
||||
if technique:
|
||||
actor_count = (
|
||||
db.query(ThreatActorTechnique)
|
||||
.filter(ThreatActorTechnique.technique_id == technique.id)
|
||||
.count()
|
||||
)
|
||||
|
||||
uncovered_techniques.append({
|
||||
**tech_info,
|
||||
"templates_available": template_count,
|
||||
"threat_actors_using": actor_count,
|
||||
})
|
||||
|
||||
if uncovered_techniques:
|
||||
gaps.append({
|
||||
"control_id": status_data["control_id"],
|
||||
"title": status_data["title"],
|
||||
"category": status_data["category"],
|
||||
"status": status_data["status"],
|
||||
"score": status_data["score"],
|
||||
"uncovered_techniques": uncovered_techniques,
|
||||
})
|
||||
|
||||
return {
|
||||
"framework": {"id": str(framework.id), "name": framework.name},
|
||||
"total_gaps": len(gaps),
|
||||
"gaps": gaps,
|
||||
}
|
||||
Returns:
|
||||
dict: Controls flagged as gaps, with linked technique IDs and coverage ratios.
|
||||
"""
|
||||
# Return get_framework_gaps(db, framework_id)
|
||||
return get_framework_gaps(db, framework_id)
|
||||
|
||||
|
||||
# ── POST /compliance/import/nist-800-53 ──────────────────────────────
|
||||
# ── POST /compliance/import/... ────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/import/nist-800-53")
|
||||
# Define function import_nist
|
||||
def import_nist(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import NIST 800-53 Rev 5 mappings (admin only)."""
|
||||
) -> dict:
|
||||
"""Import NIST 800-53 Rev 5 mappings (admin only).
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated admin user.
|
||||
|
||||
Returns:
|
||||
dict: Import result with counts of created and updated control mappings.
|
||||
"""
|
||||
# Assign result = import_nist_800_53_mappings(db)
|
||||
result = import_nist_800_53_mappings(db)
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/import/cis-controls-v8")
|
||||
# Define function import_cis
|
||||
def import_cis(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import CIS Controls v8 mappings (admin only)."""
|
||||
) -> dict:
|
||||
"""Import CIS Controls v8 mappings (admin only).
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated admin user.
|
||||
|
||||
Returns:
|
||||
dict: Import result with counts of created and updated control mappings.
|
||||
"""
|
||||
# Assign result = import_cis_controls_v8_mappings(db)
|
||||
result = import_cis_controls_v8_mappings(db)
|
||||
# Return result
|
||||
return result
|
||||
|
||||
@@ -1,24 +1,47 @@
|
||||
"""D3FEND endpoints — defensive technique listings, mappings, and import trigger."""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.models.technique import Technique
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
|
||||
# Import from app.services.d3fend_import_service
|
||||
from app.services.d3fend_import_service import (
|
||||
import_d3fend_techniques,
|
||||
import_d3fend_mappings,
|
||||
get_defenses_for_technique,
|
||||
import_d3fend_techniques,
|
||||
)
|
||||
|
||||
# Import from app.services.d3fend_query_service
|
||||
from app.services.d3fend_query_service import (
|
||||
get_defenses_for_attack_technique,
|
||||
list_d3fend_tactics,
|
||||
)
|
||||
|
||||
# Import from app.services.d3fend_query_service
|
||||
from app.services.d3fend_query_service import (
|
||||
list_defensive_techniques as list_defensive_techniques_svc,
|
||||
)
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
||||
router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
||||
|
||||
|
||||
@@ -27,46 +50,26 @@ router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("")
|
||||
# Define function list_defensive_techniques
|
||||
def list_defensive_techniques(
|
||||
# Entry: tactic
|
||||
tactic: Optional[str] = Query(None),
|
||||
# Entry: search
|
||||
search: Optional[str] = Query(None),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""List all D3FEND defensive techniques with optional filters."""
|
||||
query = db.query(DefensiveTechnique)
|
||||
|
||||
if tactic:
|
||||
query = query.filter(DefensiveTechnique.tactic == tactic)
|
||||
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
query = query.filter(
|
||||
DefensiveTechnique.name.ilike(pattern)
|
||||
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [
|
||||
{
|
||||
"id": str(dt.id),
|
||||
"d3fend_id": dt.d3fend_id,
|
||||
"name": dt.name,
|
||||
"description": dt.description,
|
||||
"tactic": dt.tactic,
|
||||
"d3fend_url": dt.d3fend_url,
|
||||
}
|
||||
for dt in items
|
||||
],
|
||||
}
|
||||
# Return list_defensive_techniques_svc(
|
||||
return list_defensive_techniques_svc(
|
||||
db, tactic=tactic, search=search, offset=offset, limit=limit
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -74,21 +77,16 @@ def list_defensive_techniques(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/tactics")
|
||||
def list_d3fend_tactics(
|
||||
# Define function list_d3fend_tactics_endpoint
|
||||
def list_d3fend_tactics_endpoint(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Return a list of all D3FEND tactics with counts."""
|
||||
from sqlalchemy import func
|
||||
|
||||
rows = (
|
||||
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
|
||||
.group_by(DefensiveTechnique.tactic)
|
||||
.order_by(DefensiveTechnique.tactic)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
|
||||
# Return list_d3fend_tactics(db)
|
||||
return list_d3fend_tactics(db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -96,24 +94,18 @@ def list_d3fend_tactics(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/for-technique/{mitre_id}")
|
||||
def get_defenses_for_attack_technique(
|
||||
# Define function get_defenses_for_attack_technique_endpoint
|
||||
def get_defenses_for_attack_technique_endpoint(
|
||||
# Entry: mitre_id
|
||||
mitre_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||
if not technique:
|
||||
raise HTTPException(status_code=404, detail=f"Technique {mitre_id} not found")
|
||||
|
||||
defenses = get_defenses_for_technique(db, technique.id)
|
||||
|
||||
return {
|
||||
"mitre_id": mitre_id,
|
||||
"technique_name": technique.name,
|
||||
"defenses": defenses,
|
||||
"total": len(defenses),
|
||||
}
|
||||
# Return get_defenses_for_attack_technique(db, mitre_id)
|
||||
return get_defenses_for_attack_technique(db, mitre_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -121,15 +113,23 @@ def get_defenses_for_attack_technique(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/import")
|
||||
# Define function trigger_d3fend_import
|
||||
def trigger_d3fend_import(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Import D3FEND techniques and ATT&CK mappings. Admin only."""
|
||||
# Assign tech_result = import_d3fend_techniques(db)
|
||||
tech_result = import_d3fend_techniques(db)
|
||||
# Assign mapping_result = import_d3fend_mappings(db)
|
||||
mapping_result = import_d3fend_mappings(db)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"techniques": tech_result,
|
||||
# Literal argument value
|
||||
"mappings": mapping_result,
|
||||
}
|
||||
|
||||
+122
-214
@@ -5,289 +5,197 @@ Provides a centralized panel for managing all external data sources
|
||||
including sync triggers, enable/disable toggles, and statistics.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
# Import APIRouter, Depends from fastapi
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
# Import BaseModel from pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import require_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.models.data_source import DataSource
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Import from app.services.data_source_service
|
||||
from app.services.data_source_service import (
|
||||
get_source_stats,
|
||||
list_sources,
|
||||
sync_all_sources,
|
||||
sync_source,
|
||||
update_source,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic schemas for request validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DataSourceUpdate(BaseModel):
|
||||
"""Payload for updating a data source — only allowed fields."""
|
||||
# Assign is_enabled = None
|
||||
is_enabled: Optional[bool] = None
|
||||
# Assign sync_frequency = None
|
||||
sync_frequency: Optional[str] = None
|
||||
# Assign config = None
|
||||
config: Optional[dict] = None
|
||||
|
||||
|
||||
# Assign router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync dispatcher — maps source name → import function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_sync_handler(source_name: str):
|
||||
"""Lazily import and return the sync function for *source_name*.
|
||||
|
||||
We import lazily to avoid circular imports and to only load the
|
||||
modules that are actually needed.
|
||||
"""
|
||||
handlers = {
|
||||
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
|
||||
"sigma": ("app.services.sigma_import_service", "sync"),
|
||||
"lolbas": ("app.services.lolbas_import_service", "sync"),
|
||||
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
|
||||
"caldera": ("app.services.caldera_import_service", "sync"),
|
||||
"elastic_rules": ("app.services.elastic_import_service", "sync"),
|
||||
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
|
||||
"d3fend": ("app.services.d3fend_import_service", "sync"),
|
||||
}
|
||||
|
||||
if source_name not in handlers:
|
||||
return None
|
||||
|
||||
module_path, func_name = handlers[source_name]
|
||||
import importlib
|
||||
mod = importlib.import_module(module_path)
|
||||
return getattr(mod, func_name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("")
|
||||
# Define function list_data_sources
|
||||
def list_data_sources(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> list:
|
||||
"""List all registered data sources.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
sources = db.query(DataSource).order_by(DataSource.name).all()
|
||||
return [
|
||||
{
|
||||
"id": str(s.id),
|
||||
"name": s.name,
|
||||
"display_name": s.display_name,
|
||||
"type": s.type,
|
||||
"url": s.url,
|
||||
"description": s.description,
|
||||
"is_enabled": s.is_enabled,
|
||||
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
|
||||
"last_sync_status": s.last_sync_status,
|
||||
"last_sync_stats": s.last_sync_stats,
|
||||
"sync_frequency": s.sync_frequency,
|
||||
"config": s.config,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in sources
|
||||
]
|
||||
# Return list_sources(db)
|
||||
return list_sources(db)
|
||||
|
||||
|
||||
# Apply the @router.patch decorator
|
||||
@router.patch("/{source_id}")
|
||||
# Define function update_data_source
|
||||
def update_data_source(
|
||||
# Entry: source_id
|
||||
source_id: str,
|
||||
body: dict,
|
||||
# Entry: body
|
||||
body: DataSourceUpdate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Update a data source (enable/disable, change config).
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
|
||||
Body fields (all optional):
|
||||
- ``is_enabled`` (bool)
|
||||
- ``sync_frequency`` (str)
|
||||
- ``config`` (dict)
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
# Assign update_data = body.model_dump(exclude_unset=True)
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call update_source()
|
||||
update_source(db, source_id, **update_data)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="update_data_source",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="data_source",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=source_id,
|
||||
# Keyword argument: details
|
||||
details={"updates": update_data},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
if "is_enabled" in body:
|
||||
ds.is_enabled = bool(body["is_enabled"])
|
||||
if "sync_frequency" in body:
|
||||
ds.sync_frequency = body["sync_frequency"]
|
||||
if "config" in body:
|
||||
ds.config = body["config"]
|
||||
|
||||
db.commit()
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_data_source",
|
||||
entity_type="data_source",
|
||||
entity_id=str(ds.id),
|
||||
details={"updates": body},
|
||||
)
|
||||
|
||||
return {"message": "Data source updated", "id": str(ds.id)}
|
||||
# Return {"message": "Data source updated", "id": source_id}
|
||||
return {"message": "Data source updated", "id": source_id}
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/{source_id}/sync")
|
||||
# Define function sync_data_source
|
||||
def sync_data_source(
|
||||
# Entry: source_id
|
||||
source_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Trigger sync/import for a specific data source.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
handler = _get_sync_handler(ds.name)
|
||||
if handler is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No sync handler available for '{ds.name}'",
|
||||
)
|
||||
|
||||
# Mark as in_progress
|
||||
ds.last_sync_status = "in_progress"
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
summary = handler(db)
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
db.commit()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Sync failed: {str(exc)}",
|
||||
)
|
||||
|
||||
# Update DS record (the handler may already have done this,
|
||||
# but we ensure it here as well)
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_status = "success"
|
||||
ds.last_sync_stats = summary
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"message": f"Sync complete for {ds.display_name}",
|
||||
"source": ds.name,
|
||||
"stats": summary,
|
||||
}
|
||||
# Return sync_source(db, source_id)
|
||||
return sync_source(db, source_id)
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/sync-all")
|
||||
# Define function sync_all_data_sources
|
||||
def sync_all_data_sources(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Trigger sync for all enabled data sources (sequentially).
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
enabled_sources = (
|
||||
db.query(DataSource)
|
||||
.filter(DataSource.is_enabled == True)
|
||||
.order_by(DataSource.name)
|
||||
.all()
|
||||
)
|
||||
# Assign results = sync_all_sources(db)
|
||||
results = sync_all_sources(db)
|
||||
|
||||
results = []
|
||||
for ds in enabled_sources:
|
||||
handler = _get_sync_handler(ds.name)
|
||||
if handler is None:
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "skipped",
|
||||
"detail": "No sync handler available",
|
||||
})
|
||||
continue
|
||||
|
||||
ds.last_sync_status = "in_progress"
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
summary = handler(db)
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_status = "success"
|
||||
ds.last_sync_stats = summary
|
||||
db.commit()
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "success",
|
||||
"stats": summary,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
db.commit()
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "error",
|
||||
"detail": str(exc),
|
||||
})
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="sync_all_data_sources",
|
||||
entity_type="data_source",
|
||||
entity_id=None,
|
||||
details={"results": results},
|
||||
)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="sync_all_data_sources",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="data_source",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=None,
|
||||
# Keyword argument: details
|
||||
details={"results": results},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return {"message": "Sync all complete", "results": results}
|
||||
return {"message": "Sync all complete", "results": results}
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/{source_id}/stats")
|
||||
# Define function get_data_source_stats
|
||||
def get_data_source_stats(
|
||||
# Entry: source_id
|
||||
source_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Get detailed statistics for a specific data source.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
# Count items from this source
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
template_count = 0
|
||||
rule_count = 0
|
||||
|
||||
if ds.type == "attack_procedure":
|
||||
template_count = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.source == ds.name)
|
||||
.count()
|
||||
)
|
||||
elif ds.type == "detection_rule":
|
||||
rule_count = (
|
||||
db.query(DetectionRule)
|
||||
.filter(DetectionRule.source == ds.name)
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(ds.id),
|
||||
"name": ds.name,
|
||||
"display_name": ds.display_name,
|
||||
"type": ds.type,
|
||||
"is_enabled": ds.is_enabled,
|
||||
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
|
||||
"last_sync_status": ds.last_sync_status,
|
||||
"last_sync_stats": ds.last_sync_stats,
|
||||
"total_templates": template_count,
|
||||
"total_rules": rule_count,
|
||||
}
|
||||
# Return get_source_stats(db, source_id)
|
||||
return get_source_stats(db, source_id)
|
||||
|
||||
@@ -1,370 +1,191 @@
|
||||
"""Detection rules endpoints — listing, filtering, and template association.
|
||||
|
||||
Thin HTTP adapter: delegates all query and business logic to detection_rule_service.
|
||||
|
||||
Provides endpoints for browsing detection rules, querying rules by technique,
|
||||
and managing the template ↔ detection rule associations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import func
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import BaseModel from pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
||||
|
||||
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.test_template_detection_rule import TestTemplateDetectionRule
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Import from app.services.detection_rule_service
|
||||
from app.services.detection_rule_service import (
|
||||
auto_associate_rules,
|
||||
evaluate_rule,
|
||||
get_rules_for_template,
|
||||
get_rules_for_test,
|
||||
list_rules,
|
||||
)
|
||||
|
||||
# ── Pydantic schemas for request validation ────────────────────────────
|
||||
|
||||
|
||||
class DetectionRuleEvaluate(BaseModel):
|
||||
"""Payload for evaluating a detection rule against a test."""
|
||||
# test_id: uuid.UUID
|
||||
test_id: uuid.UUID
|
||||
# detection_rule_id: uuid.UUID
|
||||
detection_rule_id: uuid.UUID
|
||||
# Assign triggered = None
|
||||
triggered: Optional[bool] = None
|
||||
# Assign notes = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
# Assign router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
||||
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /detection-rules — List with filters
|
||||
# ---------------------------------------------------------------------------
|
||||
# ── GET /detection-rules — List with filters ───────────────────────────
|
||||
|
||||
|
||||
@router.get("")
|
||||
# Define function list_detection_rules
|
||||
def list_detection_rules(
|
||||
# Entry: technique
|
||||
technique: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
|
||||
# Entry: source
|
||||
source: Optional[str] = Query(None, description="Filter by source (sigma, elastic, splunk, custom)"),
|
||||
# Entry: severity
|
||||
severity: Optional[str] = Query(None),
|
||||
# Entry: search
|
||||
search: Optional[str] = Query(None),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""List detection rules with optional filters and pagination."""
|
||||
query = db.query(DetectionRule).filter(DetectionRule.is_active == True) # noqa: E712
|
||||
|
||||
if technique:
|
||||
query = query.filter(DetectionRule.mitre_technique_id == technique)
|
||||
|
||||
if source:
|
||||
query = query.filter(DetectionRule.source == source)
|
||||
|
||||
if severity:
|
||||
query = query.filter(DetectionRule.severity == severity)
|
||||
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
query = query.filter(
|
||||
DetectionRule.title.ilike(pattern)
|
||||
| DetectionRule.description.ilike(pattern)
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
items = query.order_by(DetectionRule.mitre_technique_id, DetectionRule.title).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"mitre_technique_id": r.mitre_technique_id,
|
||||
"title": r.title,
|
||||
"description": r.description,
|
||||
"source": r.source,
|
||||
"source_url": r.source_url,
|
||||
"rule_format": r.rule_format,
|
||||
"severity": r.severity,
|
||||
"platforms": r.platforms or [],
|
||||
"log_sources": r.log_sources,
|
||||
"is_active": r.is_active,
|
||||
}
|
||||
for r in items
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /test-templates/{id}/detection-rules — Rules for a template
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/for-template/{template_id}")
|
||||
def get_detection_rules_for_template(
|
||||
template_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detection rules associated with a test template."""
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Test template not found")
|
||||
|
||||
associations = (
|
||||
db.query(TestTemplateDetectionRule)
|
||||
.filter(TestTemplateDetectionRule.test_template_id == template_id)
|
||||
.all()
|
||||
# Return list_rules(
|
||||
return list_rules(
|
||||
db,
|
||||
# Keyword argument: technique
|
||||
technique=technique,
|
||||
# Keyword argument: source
|
||||
source=source,
|
||||
# Keyword argument: severity
|
||||
severity=severity,
|
||||
# Keyword argument: search
|
||||
search=search,
|
||||
# Keyword argument: offset
|
||||
offset=offset,
|
||||
# Keyword argument: limit
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
rules = []
|
||||
for assoc in associations:
|
||||
r = assoc.detection_rule
|
||||
rules.append({
|
||||
"id": str(r.id),
|
||||
"mitre_technique_id": r.mitre_technique_id,
|
||||
"title": r.title,
|
||||
"description": r.description,
|
||||
"source": r.source,
|
||||
"source_url": r.source_url,
|
||||
"rule_content": r.rule_content,
|
||||
"rule_format": r.rule_format,
|
||||
"severity": r.severity,
|
||||
"platforms": r.platforms or [],
|
||||
"log_sources": r.log_sources,
|
||||
"is_primary": assoc.is_primary,
|
||||
})
|
||||
|
||||
return {
|
||||
"template_id": str(template.id),
|
||||
"template_name": template.name,
|
||||
"mitre_technique_id": template.mitre_technique_id,
|
||||
"rules": rules,
|
||||
"total": len(rules),
|
||||
}
|
||||
# ── GET /detection-rules/for-template/{template_id} ────────────────────
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /detection-rules/auto-associate — Auto-link templates ↔ rules
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/for-template/{template_id}")
|
||||
# Define function get_detection_rules_for_template
|
||||
def get_detection_rules_for_template(
|
||||
# Entry: template_id
|
||||
template_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list:
|
||||
"""Get detection rules associated with a test template."""
|
||||
# Return get_rules_for_template(db, template_id)
|
||||
return get_rules_for_template(db, template_id)
|
||||
|
||||
|
||||
# ── POST /detection-rules/auto-associate ────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/auto-associate")
|
||||
# Define function auto_associate_detection_rules
|
||||
def auto_associate_detection_rules(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Auto-associate test templates with detection rules by MITRE technique ID.
|
||||
|
||||
For each active template, find all active detection rules for the same
|
||||
technique and create associations. Rules with severity >= high are marked
|
||||
as primary.
|
||||
"""
|
||||
templates = db.query(TestTemplate).filter(TestTemplate.is_active == True).all() # noqa: E712
|
||||
rules = db.query(DetectionRule).filter(DetectionRule.is_active == True).all() # noqa: E712
|
||||
|
||||
# Index rules by technique
|
||||
rules_by_technique: dict[str, list] = {}
|
||||
for rule in rules:
|
||||
tid = rule.mitre_technique_id
|
||||
if tid not in rules_by_technique:
|
||||
rules_by_technique[tid] = []
|
||||
rules_by_technique[tid].append(rule)
|
||||
|
||||
created = 0
|
||||
skipped = 0
|
||||
high_severities = {"high", "critical"}
|
||||
|
||||
for template in templates:
|
||||
matching_rules = rules_by_technique.get(template.mitre_technique_id, [])
|
||||
for rule in matching_rules:
|
||||
# Check if association already exists
|
||||
existing = (
|
||||
db.query(TestTemplateDetectionRule)
|
||||
.filter(
|
||||
TestTemplateDetectionRule.test_template_id == template.id,
|
||||
TestTemplateDetectionRule.detection_rule_id == rule.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
is_primary = (rule.severity or "").lower() in high_severities
|
||||
|
||||
assoc = TestTemplateDetectionRule(
|
||||
test_template_id=template.id,
|
||||
detection_rule_id=rule.id,
|
||||
is_primary=is_primary,
|
||||
)
|
||||
db.add(assoc)
|
||||
created += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
total = db.query(TestTemplateDetectionRule).count()
|
||||
return {
|
||||
"created": created,
|
||||
"skipped": skipped,
|
||||
"total_associations": total,
|
||||
}
|
||||
# Return auto_associate_rules(db)
|
||||
return auto_associate_rules(db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /detection-rules/for-test/{test_id} — Rules + results for a test
|
||||
# ---------------------------------------------------------------------------
|
||||
# ── GET /detection-rules/for-test/{test_id} ──────────────────────────────
|
||||
|
||||
|
||||
@router.get("/for-test/{test_id}")
|
||||
# Define function get_detection_rules_for_test
|
||||
def get_detection_rules_for_test(
|
||||
# Entry: test_id
|
||||
test_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Get detection rules relevant to a test, along with their evaluation results.
|
||||
|
||||
Finds rules by matching the test's technique_id to detection rules,
|
||||
and returns any existing evaluation results.
|
||||
"""
|
||||
from app.models.test import Test
|
||||
from app.models.technique import Technique
|
||||
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if not test:
|
||||
raise HTTPException(status_code=404, detail="Test not found")
|
||||
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
if not technique:
|
||||
raise HTTPException(status_code=404, detail="Technique not found")
|
||||
|
||||
# Get detection rules for this technique
|
||||
rules = (
|
||||
db.query(DetectionRule)
|
||||
.filter(
|
||||
DetectionRule.mitre_technique_id == technique.mitre_id,
|
||||
DetectionRule.is_active == True, # noqa: E712
|
||||
)
|
||||
.order_by(DetectionRule.severity.desc(), DetectionRule.title)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get existing results for this test
|
||||
existing_results = (
|
||||
db.query(TestDetectionResult)
|
||||
.filter(TestDetectionResult.test_id == test_id)
|
||||
.all()
|
||||
)
|
||||
results_map = {str(r.detection_rule_id): r for r in existing_results}
|
||||
|
||||
items = []
|
||||
triggered_count = 0
|
||||
evaluated_count = 0
|
||||
|
||||
for rule in rules:
|
||||
result = results_map.get(str(rule.id))
|
||||
triggered = result.triggered if result else None
|
||||
notes = result.notes if result else None
|
||||
evaluated_at = result.evaluated_at.isoformat() if result and result.evaluated_at else None
|
||||
|
||||
if triggered is not None:
|
||||
evaluated_count += 1
|
||||
if triggered:
|
||||
triggered_count += 1
|
||||
|
||||
items.append({
|
||||
"id": str(rule.id),
|
||||
"mitre_technique_id": rule.mitre_technique_id,
|
||||
"title": rule.title,
|
||||
"description": rule.description,
|
||||
"source": rule.source,
|
||||
"source_url": rule.source_url,
|
||||
"rule_content": rule.rule_content,
|
||||
"rule_format": rule.rule_format,
|
||||
"severity": rule.severity,
|
||||
"platforms": rule.platforms or [],
|
||||
"log_sources": rule.log_sources,
|
||||
"triggered": triggered,
|
||||
"notes": notes,
|
||||
"evaluated_at": evaluated_at,
|
||||
"result_id": str(result.id) if result else None,
|
||||
})
|
||||
|
||||
return {
|
||||
"test_id": str(test.id),
|
||||
"mitre_technique_id": technique.mitre_id,
|
||||
"rules": items,
|
||||
"total": len(items),
|
||||
"evaluated": evaluated_count,
|
||||
"triggered": triggered_count,
|
||||
"detection_rate": round(triggered_count / evaluated_count * 100, 1) if evaluated_count > 0 else 0,
|
||||
}
|
||||
# Return get_rules_for_test(db, test_id)
|
||||
return get_rules_for_test(db, test_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /detection-rules/evaluate — Save detection result for a rule
|
||||
# ---------------------------------------------------------------------------
|
||||
# ── POST /detection-rules/evaluate ──────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/evaluate")
|
||||
# Define function evaluate_detection_rule
|
||||
def evaluate_detection_rule(
|
||||
payload: dict,
|
||||
# Entry: payload
|
||||
payload: DetectionRuleEvaluate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||
):
|
||||
"""Save or update the evaluation result for a detection rule on a test.
|
||||
|
||||
Body:
|
||||
{
|
||||
"test_id": "...",
|
||||
"detection_rule_id": "...",
|
||||
"triggered": true | false | null,
|
||||
"notes": "optional notes"
|
||||
}
|
||||
"""
|
||||
test_id = payload.get("test_id")
|
||||
detection_rule_id = payload.get("detection_rule_id")
|
||||
triggered = payload.get("triggered")
|
||||
notes = payload.get("notes")
|
||||
|
||||
if not test_id or not detection_rule_id:
|
||||
raise HTTPException(status_code=400, detail="test_id and detection_rule_id are required")
|
||||
|
||||
# Check test exists
|
||||
from app.models.test import Test
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if not test:
|
||||
raise HTTPException(status_code=404, detail="Test not found")
|
||||
|
||||
# Check rule exists
|
||||
rule = db.query(DetectionRule).filter(DetectionRule.id == detection_rule_id).first()
|
||||
if not rule:
|
||||
raise HTTPException(status_code=404, detail="Detection rule not found")
|
||||
|
||||
# Upsert result
|
||||
existing = (
|
||||
db.query(TestDetectionResult)
|
||||
.filter(
|
||||
TestDetectionResult.test_id == test_id,
|
||||
TestDetectionResult.detection_rule_id == detection_rule_id,
|
||||
)
|
||||
.first()
|
||||
) -> dict:
|
||||
"""Save or update the evaluation result for a detection rule on a test."""
|
||||
# Return evaluate_rule(
|
||||
return evaluate_rule(
|
||||
db,
|
||||
# Keyword argument: test_id
|
||||
test_id=payload.test_id,
|
||||
# Keyword argument: detection_rule_id
|
||||
detection_rule_id=payload.detection_rule_id,
|
||||
# Keyword argument: triggered
|
||||
triggered=payload.triggered,
|
||||
# Keyword argument: notes
|
||||
notes=payload.notes,
|
||||
# Keyword argument: evaluator_id
|
||||
evaluator_id=current_user.id,
|
||||
)
|
||||
|
||||
if existing:
|
||||
existing.triggered = triggered
|
||||
existing.notes = notes
|
||||
existing.evaluated_by = current_user.id
|
||||
existing.evaluated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
return {
|
||||
"id": str(existing.id),
|
||||
"triggered": existing.triggered,
|
||||
"notes": existing.notes,
|
||||
"evaluated_at": existing.evaluated_at.isoformat() if existing.evaluated_at else None,
|
||||
}
|
||||
else:
|
||||
result = TestDetectionResult(
|
||||
test_id=test_id,
|
||||
detection_rule_id=detection_rule_id,
|
||||
triggered=triggered,
|
||||
notes=notes,
|
||||
evaluated_by=current_user.id,
|
||||
evaluated_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
return {
|
||||
"id": str(result.id),
|
||||
"triggered": result.triggered,
|
||||
"notes": result.notes,
|
||||
"evaluated_at": result.evaluated_at.isoformat() if result.evaluated_at else None,
|
||||
}
|
||||
|
||||
+196
-187
@@ -19,204 +19,210 @@ Access Control
|
||||
``validated``, or ``rejected``.
|
||||
"""
|
||||
|
||||
# Import hashlib
|
||||
import hashlib
|
||||
|
||||
# Import os
|
||||
import os
|
||||
|
||||
# Import uuid
|
||||
import uuid as _uuid
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
|
||||
# Import APIRouter, Depends, File, Form, Query, Request,... from fastapi
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.enums import TeamSide, TestState
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import TeamSide from app.models.enums
|
||||
from app.models.enums import TeamSide
|
||||
|
||||
# Import Evidence from app.models.evidence
|
||||
from app.models.evidence import Evidence
|
||||
from app.models.test import Test
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import EvidenceOut from app.schemas.evidence
|
||||
from app.schemas.evidence import EvidenceOut
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.evidence_service
|
||||
from app.services.evidence_service import (
|
||||
MAX_UPLOAD_SIZE,
|
||||
get_evidence_or_raise,
|
||||
get_test_or_raise,
|
||||
list_evidence_for_test,
|
||||
validate_delete_permission,
|
||||
validate_file,
|
||||
validate_upload_permission,
|
||||
)
|
||||
|
||||
# Import get_presigned_url, upload_file from app.storage
|
||||
from app.storage import get_presigned_url, upload_file
|
||||
|
||||
# Assign router = APIRouter(tags=["evidence"])
|
||||
router = APIRouter(tags=["evidence"])
|
||||
|
||||
# States where red evidence can be uploaded / deleted
|
||||
_RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
|
||||
# States where blue evidence can be uploaded / deleted
|
||||
_BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# Helpers (router-specific: infrastructure / HTTP concerns)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
"""Convert an ORM ``Evidence`` to the API schema, injecting a presigned URL."""
|
||||
# Return EvidenceOut(
|
||||
return EvidenceOut(
|
||||
# Keyword argument: id
|
||||
id=evidence.id,
|
||||
# Keyword argument: test_id
|
||||
test_id=evidence.test_id,
|
||||
# Keyword argument: file_name
|
||||
file_name=evidence.file_name,
|
||||
# Keyword argument: sha256_hash
|
||||
sha256_hash=evidence.sha256_hash,
|
||||
# Keyword argument: uploaded_by
|
||||
uploaded_by=evidence.uploaded_by,
|
||||
# Keyword argument: uploaded_at
|
||||
uploaded_at=evidence.uploaded_at,
|
||||
# Keyword argument: team
|
||||
team=evidence.team,
|
||||
# Keyword argument: notes
|
||||
notes=evidence.notes,
|
||||
# Keyword argument: download_url
|
||||
download_url=get_presigned_url(evidence.file_path),
|
||||
)
|
||||
|
||||
|
||||
def _validate_upload_permission(
|
||||
test: Test,
|
||||
team: TeamSide,
|
||||
user: User,
|
||||
) -> None:
|
||||
"""Raise 403 if the user/team combination is not allowed in the current state."""
|
||||
# Admins bypass all checks
|
||||
if user.role == "admin":
|
||||
return
|
||||
|
||||
if team == TeamSide.red:
|
||||
# Only red_tech can upload red evidence
|
||||
if user.role != "red_tech":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only red_tech or admin can upload red evidence",
|
||||
)
|
||||
if test.state not in _RED_EDITABLE_STATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Cannot upload red evidence in '{test.state.value}' state "
|
||||
f"(allowed in: draft, red_executing)",
|
||||
)
|
||||
elif team == TeamSide.blue:
|
||||
# Only blue_tech can upload blue evidence
|
||||
if user.role != "blue_tech":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only blue_tech or admin can upload blue evidence",
|
||||
)
|
||||
if test.state not in _BLUE_EDITABLE_STATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Cannot upload blue evidence in '{test.state.value}' state "
|
||||
f"(allowed in: blue_evaluating)",
|
||||
)
|
||||
|
||||
|
||||
def _validate_delete_permission(
|
||||
test: Test,
|
||||
evidence: Evidence,
|
||||
user: User,
|
||||
) -> None:
|
||||
"""Raise 403 if the user cannot delete this evidence in the current state."""
|
||||
# No deletions in review / validated / rejected
|
||||
if test.state in (TestState.in_review, TestState.validated, TestState.rejected):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Cannot delete evidence when test is in '{test.state.value}' state",
|
||||
)
|
||||
|
||||
# Admin can delete in editable states
|
||||
if user.role == "admin":
|
||||
return
|
||||
|
||||
ev_team = evidence.team
|
||||
|
||||
if ev_team == TeamSide.red:
|
||||
if test.state not in _RED_EDITABLE_STATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot delete red evidence outside draft/red_executing",
|
||||
)
|
||||
if user.role != "red_tech" and evidence.uploaded_by != user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions to delete this evidence",
|
||||
)
|
||||
elif ev_team == TeamSide.blue:
|
||||
if test.state not in _BLUE_EDITABLE_STATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot delete blue evidence outside blue_evaluating",
|
||||
)
|
||||
if user.role != "blue_tech" and evidence.uploaded_by != user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions to delete this evidence",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{test_id}/evidence — upload with team
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
|
||||
# Literal argument value
|
||||
"/tests/{test_id}/evidence",
|
||||
# Keyword argument: response_model
|
||||
response_model=EvidenceOut,
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("10/minute")
|
||||
# Define async function upload_evidence
|
||||
async def upload_evidence(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: test_id
|
||||
test_id: _uuid.UUID,
|
||||
# Entry: file
|
||||
file: UploadFile = File(...),
|
||||
# Entry: team
|
||||
team: TeamSide = Form(TeamSide.red),
|
||||
# Entry: notes
|
||||
notes: Optional[str] = Form(None),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> EvidenceOut:
|
||||
"""Upload a file as evidence for the given test.
|
||||
|
||||
The ``team`` field (sent as form data) determines whether this is
|
||||
Red Team (attack) or Blue Team (detection) evidence.
|
||||
"""
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Test not found",
|
||||
)
|
||||
# Assign test = get_test_or_raise(db, test_id)
|
||||
test = get_test_or_raise(db, test_id)
|
||||
# Call validate_upload_permission()
|
||||
validate_upload_permission(test, team, current_user.role)
|
||||
|
||||
# Validate permissions
|
||||
_validate_upload_permission(test, team, current_user)
|
||||
# Assign file_name = file.filename or "unnamed"
|
||||
file_name = file.filename or "unnamed"
|
||||
# Assign content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||
# Call validate_file()
|
||||
validate_file(file_name, len(content))
|
||||
|
||||
# 1. Read content + hash
|
||||
content = await file.read()
|
||||
# Hash
|
||||
sha256 = hashlib.sha256(content).hexdigest()
|
||||
|
||||
# 2. Object key
|
||||
file_name = file.filename or "unnamed"
|
||||
key = f"{test_id}/{_uuid.uuid4()}_{file_name}"
|
||||
# 4. Object key (sanitise filename to prevent path traversal in storage)
|
||||
safe_name = os.path.basename(file_name)
|
||||
# Assign key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
|
||||
key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
|
||||
|
||||
# 3. Upload to MinIO
|
||||
# 5. Upload to MinIO
|
||||
upload_file(content, key)
|
||||
|
||||
# 4. Persist metadata
|
||||
evidence = Evidence(
|
||||
test_id=test_id,
|
||||
file_name=file_name,
|
||||
file_path=key,
|
||||
sha256_hash=sha256,
|
||||
uploaded_by=current_user.id,
|
||||
team=team,
|
||||
notes=notes,
|
||||
)
|
||||
db.add(evidence)
|
||||
db.commit()
|
||||
# 6. Persist metadata and audit
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign evidence = Evidence(
|
||||
evidence = Evidence(
|
||||
# Keyword argument: test_id
|
||||
test_id=test_id,
|
||||
# Keyword argument: file_name
|
||||
file_name=safe_name,
|
||||
# Keyword argument: file_path
|
||||
file_path=key,
|
||||
# Keyword argument: sha256_hash
|
||||
sha256_hash=sha256,
|
||||
# Keyword argument: uploaded_by
|
||||
uploaded_by=current_user.id,
|
||||
# Keyword argument: team
|
||||
team=team,
|
||||
# Keyword argument: notes
|
||||
notes=notes,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(evidence)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush() # Get evidence.id for audit
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="upload_evidence",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="evidence",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=evidence.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"file_name": safe_name,
|
||||
# Literal argument value
|
||||
"sha256": sha256,
|
||||
# Literal argument value
|
||||
"test_id": str(test_id),
|
||||
# Literal argument value
|
||||
"team": team.value,
|
||||
},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(evidence)
|
||||
|
||||
# 5. Audit
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="upload_evidence",
|
||||
entity_type="evidence",
|
||||
entity_id=evidence.id,
|
||||
details={
|
||||
"file_name": file_name,
|
||||
"sha256": sha256,
|
||||
"test_id": str(test_id),
|
||||
"team": team.value,
|
||||
},
|
||||
)
|
||||
|
||||
# Return _evidence_to_out(evidence)
|
||||
return _evidence_to_out(evidence)
|
||||
|
||||
|
||||
@@ -226,26 +232,23 @@ async def upload_evidence(
|
||||
|
||||
|
||||
@router.get("/tests/{test_id}/evidence", response_model=list[EvidenceOut])
|
||||
# Define function list_evidence
|
||||
def list_evidence(
|
||||
# Entry: test_id
|
||||
test_id: _uuid.UUID,
|
||||
# Entry: team
|
||||
team: Optional[str] = Query(None, description="Filter by team: red or blue"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list[EvidenceOut]:
|
||||
"""List all evidences for a test, optionally filtered by team."""
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Test not found",
|
||||
)
|
||||
|
||||
query = db.query(Evidence).filter(Evidence.test_id == test_id)
|
||||
|
||||
if team:
|
||||
query = query.filter(Evidence.team == team)
|
||||
|
||||
evidences = query.order_by(Evidence.uploaded_at.desc()).all()
|
||||
# Call get_test_or_raise()
|
||||
get_test_or_raise(db, test_id)
|
||||
# Assign evidences = list_evidence_for_test(db, test_id, team=team)
|
||||
evidences = list_evidence_for_test(db, test_id, team=team)
|
||||
# Return [_evidence_to_out(e) for e in evidences]
|
||||
return [_evidence_to_out(e) for e in evidences]
|
||||
|
||||
|
||||
@@ -255,19 +258,19 @@ def list_evidence(
|
||||
|
||||
|
||||
@router.get("/evidence/{evidence_id}", response_model=EvidenceOut)
|
||||
# Define function get_evidence
|
||||
def get_evidence(
|
||||
# Entry: evidence_id
|
||||
evidence_id: _uuid.UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> EvidenceOut:
|
||||
"""Return evidence metadata together with a presigned download URL."""
|
||||
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
||||
if evidence is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Evidence not found",
|
||||
)
|
||||
|
||||
# Assign evidence = get_evidence_or_raise(db, evidence_id)
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
# Return _evidence_to_out(evidence)
|
||||
return _evidence_to_out(evidence)
|
||||
|
||||
|
||||
@@ -277,11 +280,15 @@ def get_evidence(
|
||||
|
||||
|
||||
@router.delete("/evidence/{evidence_id}", status_code=status.HTTP_200_OK)
|
||||
# Define function delete_evidence
|
||||
def delete_evidence(
|
||||
# Entry: evidence_id
|
||||
evidence_id: _uuid.UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Delete an evidence record.
|
||||
|
||||
Only allowed in editable states:
|
||||
@@ -289,38 +296,40 @@ def delete_evidence(
|
||||
- Blue evidence: ``blue_evaluating``
|
||||
- No deletions in ``in_review``, ``validated``, ``rejected``
|
||||
"""
|
||||
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
||||
if evidence is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Evidence not found",
|
||||
# Assign evidence = get_evidence_or_raise(db, evidence_id)
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
# Assign test = get_test_or_raise(db, evidence.test_id)
|
||||
test = get_test_or_raise(db, evidence.test_id)
|
||||
# Call validate_delete_permission()
|
||||
validate_delete_permission(test, evidence, current_user.role, current_user.id)
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="delete_evidence",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="evidence",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=evidence.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"file_name": evidence.file_name,
|
||||
# Literal argument value
|
||||
"test_id": str(evidence.test_id),
|
||||
# Literal argument value
|
||||
"team": evidence.team.value if evidence.team else None,
|
||||
},
|
||||
)
|
||||
# Mark record for deletion on next commit
|
||||
db.delete(evidence)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
test = db.query(Test).filter(Test.id == evidence.test_id).first()
|
||||
if test is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Parent test not found",
|
||||
)
|
||||
|
||||
# Permission checks
|
||||
_validate_delete_permission(test, evidence, current_user)
|
||||
|
||||
# Audit before deletion
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="delete_evidence",
|
||||
entity_type="evidence",
|
||||
entity_id=evidence.id,
|
||||
details={
|
||||
"file_name": evidence.file_name,
|
||||
"test_id": str(evidence.test_id),
|
||||
"team": evidence.team.value if evidence.team else None,
|
||||
},
|
||||
)
|
||||
|
||||
db.delete(evidence)
|
||||
db.commit()
|
||||
|
||||
# Return {"detail": "Evidence deleted"}
|
||||
return {"detail": "Evidence deleted"}
|
||||
|
||||
+99
-452
@@ -1,526 +1,173 @@
|
||||
"""Heatmap endpoints — ATT&CK Navigator-compatible layer generation.
|
||||
|
||||
Provides multiple layer types (coverage, threat actor, detection rules,
|
||||
campaign) and an export endpoint that produces a JSON file importable
|
||||
by the official MITRE ATT&CK Navigator.
|
||||
Thin router that delegates entirely to :mod:`app.services.heatmap_service`.
|
||||
No business logic lives here — only request validation and response
|
||||
formatting.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import json
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.detection_rule import DetectionRule
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.defensive_technique import DefensiveTechniqueMapping
|
||||
from app.models.enums import TechniqueStatus, TestState
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import StreamingResponse from fastapi.responses
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import heatmap_service from app.services
|
||||
from app.services import heatmap_service
|
||||
|
||||
# Assign router = APIRouter(prefix="/heatmap", tags=["heatmap"])
|
||||
router = APIRouter(prefix="/heatmap", tags=["heatmap"])
|
||||
|
||||
# ── Constants ─────────────────────────────────────────────────────────
|
||||
|
||||
ATTACK_VERSION = "15"
|
||||
NAVIGATOR_VERSION = "5.0"
|
||||
LAYER_VERSION = "4.5"
|
||||
DOMAIN = "enterprise-attack"
|
||||
|
||||
# Score mapping for technique status_global
|
||||
STATUS_SCORE_MAP = {
|
||||
TechniqueStatus.validated: 100,
|
||||
TechniqueStatus.partial: 60,
|
||||
TechniqueStatus.in_progress: 30,
|
||||
TechniqueStatus.not_covered: 10,
|
||||
TechniqueStatus.not_evaluated: 0,
|
||||
TechniqueStatus.review_required: 10,
|
||||
}
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _score_to_color(score: int) -> str:
|
||||
"""Map a 0-100 score to a red → yellow → green color hex."""
|
||||
if score <= 0:
|
||||
return "#d3d3d3" # gray for not evaluated
|
||||
if score <= 25:
|
||||
return "#ff6666" # red
|
||||
if score <= 50:
|
||||
return "#ff9933" # orange
|
||||
if score <= 75:
|
||||
return "#ffff66" # yellow
|
||||
return "#66ff66" # green
|
||||
|
||||
|
||||
def _build_layer_skeleton(
|
||||
name: str,
|
||||
description: str,
|
||||
gradient_colors: List[str] | None = None,
|
||||
) -> dict:
|
||||
"""Return a base layer dict compatible with ATT&CK Navigator."""
|
||||
return {
|
||||
"name": name,
|
||||
"versions": {
|
||||
"attack": ATTACK_VERSION,
|
||||
"navigator": NAVIGATOR_VERSION,
|
||||
"layer": LAYER_VERSION,
|
||||
},
|
||||
"domain": DOMAIN,
|
||||
"description": description,
|
||||
"filters": {"platforms": ["windows", "linux", "macos"]},
|
||||
"gradient": {
|
||||
"colors": gradient_colors or ["#ff6666", "#ffff66", "#66ff66"],
|
||||
"minValue": 0,
|
||||
"maxValue": 100,
|
||||
},
|
||||
"techniques": [],
|
||||
}
|
||||
|
||||
|
||||
def _apply_filters(
|
||||
query,
|
||||
model,
|
||||
platforms: Optional[List[str]] = None,
|
||||
tactics: Optional[List[str]] = None,
|
||||
):
|
||||
"""Apply common platform and tactic filters to a technique query."""
|
||||
if platforms:
|
||||
from sqlalchemy import or_, cast, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
# Filter techniques that have any of the specified platforms
|
||||
platform_filters = []
|
||||
for platform in platforms:
|
||||
platform_filters.append(
|
||||
model.platforms.op("@>")(json.dumps([platform]))
|
||||
)
|
||||
if platform_filters:
|
||||
query = query.filter(or_(*platform_filters))
|
||||
if tactics:
|
||||
from sqlalchemy import or_
|
||||
tactic_filters = []
|
||||
for tactic in tactics:
|
||||
tactic_filters.append(model.tactic.ilike(f"%{tactic}%"))
|
||||
query = query.filter(or_(*tactic_filters))
|
||||
return query
|
||||
|
||||
|
||||
def _format_tactic(tactic_str: str | None) -> str:
|
||||
"""Normalize tactic string to ATT&CK Navigator format (kebab-case)."""
|
||||
if not tactic_str:
|
||||
return ""
|
||||
# Take first tactic if comma-separated
|
||||
first = tactic_str.split(",")[0].strip().lower()
|
||||
return first
|
||||
|
||||
|
||||
def _get_technique_metadata(technique, db: Session) -> list:
|
||||
"""Build metadata array for a technique."""
|
||||
# Count validated tests
|
||||
test_count = (
|
||||
db.query(func.count(Test.id))
|
||||
.filter(Test.technique_id == technique.id, Test.state == TestState.validated)
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# Count detection rules
|
||||
rule_count = (
|
||||
db.query(func.count(DetectionRule.id))
|
||||
.filter(DetectionRule.mitre_technique_id == technique.mitre_id)
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
metadata = [
|
||||
{"name": "tests_count", "value": str(test_count)},
|
||||
{"name": "detection_rules", "value": str(rule_count)},
|
||||
]
|
||||
|
||||
if technique.last_review_date:
|
||||
metadata.append(
|
||||
{"name": "last_validated", "value": technique.last_review_date.strftime("%Y-%m-%d")}
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
# ── GET /heatmap/coverage ─────────────────────────────────────────────
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/coverage")
|
||||
# Define function heatmap_coverage
|
||||
def heatmap_coverage(
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None, description="Comma-separated platforms"),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None, description="Comma-separated tactics"),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Coverage layer — score based on status_global of each technique."""
|
||||
layer = _build_layer_skeleton("Aegis Coverage", "Coverage layer generated by Aegis")
|
||||
|
||||
query = db.query(Technique)
|
||||
|
||||
platform_list = [p.strip() for p in platforms.split(",")] if platforms else None
|
||||
tactic_list = [t.strip() for t in tactics.split(",")] if tactics else None
|
||||
query = _apply_filters(query, Technique, platform_list, tactic_list)
|
||||
|
||||
techniques = query.all()
|
||||
|
||||
for tech in techniques:
|
||||
score = STATUS_SCORE_MAP.get(tech.status_global, 0)
|
||||
if score < min_score:
|
||||
continue
|
||||
|
||||
comment_parts = [f"Status: {tech.status_global.value}"]
|
||||
metadata = _get_technique_metadata(tech, db)
|
||||
|
||||
# Enrich comment with test/rule info
|
||||
tests_info = next((m for m in metadata if m["name"] == "tests_count"), None)
|
||||
rules_info = next((m for m in metadata if m["name"] == "detection_rules"), None)
|
||||
if tests_info:
|
||||
comment_parts.append(f"{tests_info['value']} tests validated")
|
||||
if rules_info:
|
||||
comment_parts.append(f"{rules_info['value']} detection rules")
|
||||
|
||||
layer["techniques"].append({
|
||||
"techniqueID": tech.mitre_id,
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
"color": _score_to_color(score),
|
||||
"score": score,
|
||||
"comment": " - ".join(comment_parts),
|
||||
"enabled": True,
|
||||
"metadata": metadata,
|
||||
})
|
||||
|
||||
return layer
|
||||
|
||||
|
||||
# ── GET /heatmap/threat-actor/{actor_id} ──────────────────────────────
|
||||
# Return heatmap_service.build_coverage_layer(
|
||||
return heatmap_service.build_coverage_layer(
|
||||
db, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/threat-actor/{actor_id}")
|
||||
# Define function heatmap_threat_actor
|
||||
def heatmap_threat_actor(
|
||||
# Entry: actor_id
|
||||
actor_id: str,
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Threat actor layer — techniques used by an actor with coverage color."""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
||||
|
||||
layer = _build_layer_skeleton(
|
||||
f"Threat Actor: {actor.name}",
|
||||
f"Techniques used by {actor.name} with coverage overlay",
|
||||
gradient_colors=["#808080", "#ff6666", "#66ff66"],
|
||||
# Return heatmap_service.build_threat_actor_layer(
|
||||
return heatmap_service.build_threat_actor_layer(
|
||||
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
# Get actor's technique IDs
|
||||
actor_technique_rows = (
|
||||
db.query(ThreatActorTechnique)
|
||||
.filter(ThreatActorTechnique.threat_actor_id == actor.id)
|
||||
.all()
|
||||
)
|
||||
actor_technique_ids = {row.technique_id for row in actor_technique_rows}
|
||||
|
||||
if not actor_technique_ids:
|
||||
return layer
|
||||
|
||||
query = db.query(Technique)
|
||||
platform_list = [p.strip() for p in platforms.split(",")] if platforms else None
|
||||
tactic_list = [t.strip() for t in tactics.split(",")] if tactics else None
|
||||
query = _apply_filters(query, Technique, platform_list, tactic_list)
|
||||
techniques = query.all()
|
||||
|
||||
for tech in techniques:
|
||||
is_actor_technique = tech.id in actor_technique_ids
|
||||
score = STATUS_SCORE_MAP.get(tech.status_global, 0) if is_actor_technique else 0
|
||||
|
||||
if is_actor_technique and score < min_score:
|
||||
continue
|
||||
|
||||
if is_actor_technique:
|
||||
metadata = _get_technique_metadata(tech, db)
|
||||
layer["techniques"].append({
|
||||
"techniqueID": tech.mitre_id,
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
"color": _score_to_color(score),
|
||||
"score": score,
|
||||
"comment": f"Used by {actor.name} - Coverage: {tech.status_global.value}",
|
||||
"enabled": True,
|
||||
"metadata": metadata,
|
||||
})
|
||||
else:
|
||||
layer["techniques"].append({
|
||||
"techniqueID": tech.mitre_id,
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
"color": "",
|
||||
"score": 0,
|
||||
"comment": "",
|
||||
"enabled": False,
|
||||
"metadata": [],
|
||||
})
|
||||
|
||||
return layer
|
||||
|
||||
|
||||
# ── GET /heatmap/detection-rules ──────────────────────────────────────
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/detection-rules")
|
||||
# Define function heatmap_detection_rules
|
||||
def heatmap_detection_rules(
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Detection rules layer — score based on ratio of rules available vs total."""
|
||||
layer = _build_layer_skeleton(
|
||||
"Detection Rules Coverage",
|
||||
"Coverage of detection rules per technique",
|
||||
# Return heatmap_service.build_detection_rules_layer(
|
||||
return heatmap_service.build_detection_rules_layer(
|
||||
db, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
query = db.query(Technique)
|
||||
platform_list = [p.strip() for p in platforms.split(",")] if platforms else None
|
||||
tactic_list = [t.strip() for t in tactics.split(",")] if tactics else None
|
||||
query = _apply_filters(query, Technique, platform_list, tactic_list)
|
||||
techniques = query.all()
|
||||
|
||||
# Get rule counts per technique_mitre_id in one query
|
||||
rule_counts = dict(
|
||||
db.query(
|
||||
DetectionRule.mitre_technique_id,
|
||||
func.count(DetectionRule.id),
|
||||
)
|
||||
.filter(DetectionRule.is_active == True)
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Find the max rule count for normalization
|
||||
max_rules = max(rule_counts.values()) if rule_counts else 1
|
||||
|
||||
from app.models.test_detection_result import TestDetectionResult
|
||||
|
||||
# Get evaluated rule counts per technique
|
||||
evaluated_counts_raw = (
|
||||
db.query(
|
||||
DetectionRule.mitre_technique_id,
|
||||
func.count(TestDetectionResult.id),
|
||||
)
|
||||
.join(TestDetectionResult, TestDetectionResult.detection_rule_id == DetectionRule.id)
|
||||
.filter(TestDetectionResult.triggered.isnot(None))
|
||||
.group_by(DetectionRule.mitre_technique_id)
|
||||
.all()
|
||||
)
|
||||
evaluated_counts = dict(evaluated_counts_raw)
|
||||
|
||||
for tech in techniques:
|
||||
total_rules = rule_counts.get(tech.mitre_id, 0)
|
||||
evaluated_rules = evaluated_counts.get(tech.mitre_id, 0)
|
||||
|
||||
if total_rules > 0:
|
||||
# Score based on rule availability (normalized) and evaluation ratio
|
||||
availability_score = min((total_rules / max_rules) * 50, 50)
|
||||
evaluation_score = (evaluated_rules / total_rules) * 50 if total_rules > 0 else 0
|
||||
score = int(min(availability_score + evaluation_score, 100))
|
||||
else:
|
||||
score = 0
|
||||
|
||||
if score < min_score:
|
||||
continue
|
||||
|
||||
layer["techniques"].append({
|
||||
"techniqueID": tech.mitre_id,
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
"color": _score_to_color(score),
|
||||
"score": score,
|
||||
"comment": f"{total_rules} rules available, {evaluated_rules} evaluated",
|
||||
"enabled": True,
|
||||
"metadata": [
|
||||
{"name": "total_rules", "value": str(total_rules)},
|
||||
{"name": "evaluated_rules", "value": str(evaluated_rules)},
|
||||
],
|
||||
})
|
||||
|
||||
return layer
|
||||
|
||||
|
||||
# ── GET /heatmap/campaign/{campaign_id} ───────────────────────────────
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/campaign/{campaign_id}")
|
||||
# Define function heatmap_campaign
|
||||
def heatmap_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Campaign layer — only techniques in the campaign, colored by test state."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
layer = _build_layer_skeleton(
|
||||
f"Campaign: {campaign.name}",
|
||||
f"Progress of campaign '{campaign.name}'",
|
||||
# Return heatmap_service.build_campaign_layer(
|
||||
return heatmap_service.build_campaign_layer(
|
||||
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
# Get campaign tests with their associated techniques
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == campaign.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not campaign_tests:
|
||||
return layer
|
||||
|
||||
# Map test_id -> test for all tests in campaign
|
||||
test_ids = [ct.test_id for ct in campaign_tests]
|
||||
tests = db.query(Test).filter(Test.id.in_(test_ids)).all()
|
||||
test_map = {t.id: t for t in tests}
|
||||
|
||||
# Map technique_id -> technique
|
||||
technique_ids = {t.technique_id for t in tests if t.technique_id}
|
||||
techniques = db.query(Technique).filter(Technique.id.in_(technique_ids)).all()
|
||||
tech_map = {t.id: t for t in techniques}
|
||||
|
||||
# Score mapping for test states
|
||||
test_state_score = {
|
||||
TestState.validated: 100,
|
||||
TestState.in_review: 70,
|
||||
TestState.blue_evaluating: 50,
|
||||
TestState.red_executing: 30,
|
||||
TestState.draft: 10,
|
||||
TestState.rejected: 5,
|
||||
}
|
||||
|
||||
# Group by technique (a technique may have multiple tests in a campaign)
|
||||
tech_scores: dict = {}
|
||||
for ct in campaign_tests:
|
||||
test = test_map.get(ct.test_id)
|
||||
if not test:
|
||||
continue
|
||||
tech = tech_map.get(test.technique_id)
|
||||
if not tech:
|
||||
continue
|
||||
|
||||
state_score = test_state_score.get(test.state, 0)
|
||||
if tech.mitre_id not in tech_scores:
|
||||
tech_scores[tech.mitre_id] = {
|
||||
"technique": tech,
|
||||
"max_score": state_score,
|
||||
"tests": [],
|
||||
}
|
||||
else:
|
||||
tech_scores[tech.mitre_id]["max_score"] = max(
|
||||
tech_scores[tech.mitre_id]["max_score"], state_score
|
||||
)
|
||||
tech_scores[tech.mitre_id]["tests"].append(test)
|
||||
|
||||
platform_list = [p.strip() for p in platforms.split(",")] if platforms else None
|
||||
tactic_list = [t.strip() for t in tactics.split(",")] if tactics else None
|
||||
|
||||
for mitre_id, info in tech_scores.items():
|
||||
tech = info["technique"]
|
||||
score = info["max_score"]
|
||||
|
||||
# Apply filters
|
||||
if platform_list:
|
||||
tech_platforms = tech.platforms or []
|
||||
if not any(p in tech_platforms for p in platform_list):
|
||||
continue
|
||||
if tactic_list:
|
||||
tech_tactics = (tech.tactic or "").lower().split(",")
|
||||
tech_tactics = [t.strip() for t in tech_tactics]
|
||||
if not any(t in tech_tactics for t in tactic_list):
|
||||
continue
|
||||
if score < min_score:
|
||||
continue
|
||||
|
||||
test_states = [t.state.value for t in info["tests"]]
|
||||
layer["techniques"].append({
|
||||
"techniqueID": mitre_id,
|
||||
"tactic": _format_tactic(tech.tactic),
|
||||
"color": _score_to_color(score),
|
||||
"score": score,
|
||||
"comment": f"Campaign tests: {', '.join(test_states)}",
|
||||
"enabled": True,
|
||||
"metadata": [
|
||||
{"name": "campaign_tests", "value": str(len(info["tests"]))},
|
||||
{"name": "best_state", "value": max(test_states) if test_states else "none"},
|
||||
],
|
||||
})
|
||||
|
||||
return layer
|
||||
|
||||
|
||||
# ── GET /heatmap/export-navigator ─────────────────────────────────────
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/export-navigator")
|
||||
# Define function export_navigator
|
||||
def export_navigator(
|
||||
# Entry: layer
|
||||
layer: str = Query(..., description="Layer type: coverage, threat-actor, detection-rules, campaign"),
|
||||
# Entry: layer_id
|
||||
layer_id: Optional[str] = Query(None, description="Actor ID or Campaign ID (if applicable)"),
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> StreamingResponse:
|
||||
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
|
||||
# Delegate to the appropriate layer endpoint
|
||||
if layer == "coverage":
|
||||
data = heatmap_coverage(
|
||||
platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
db=db, current_user=current_user,
|
||||
)
|
||||
elif layer == "threat-actor":
|
||||
if not layer_id:
|
||||
raise HTTPException(status_code=400, detail="layer_id required for threat-actor layer")
|
||||
data = heatmap_threat_actor(
|
||||
actor_id=layer_id, platforms=platforms, tactics=tactics,
|
||||
min_score=min_score, db=db, current_user=current_user,
|
||||
)
|
||||
elif layer == "detection-rules":
|
||||
data = heatmap_detection_rules(
|
||||
platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
db=db, current_user=current_user,
|
||||
)
|
||||
elif layer == "campaign":
|
||||
if not layer_id:
|
||||
raise HTTPException(status_code=400, detail="layer_id required for campaign layer")
|
||||
data = heatmap_campaign(
|
||||
campaign_id=layer_id, platforms=platforms, tactics=tactics,
|
||||
min_score=min_score, db=db, current_user=current_user,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown layer type: {layer}")
|
||||
# Assign data = heatmap_service.build_navigator_export(
|
||||
data = heatmap_service.build_navigator_export(
|
||||
db, layer, layer_id=layer_id,
|
||||
# Keyword argument: platforms
|
||||
platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
# Convert to JSON and return as downloadable file
|
||||
# Assign json_content = json.dumps(data, indent=2, default=str)
|
||||
json_content = json.dumps(data, indent=2, default=str)
|
||||
# Assign buffer = io.BytesIO(json_content.encode("utf-8"))
|
||||
buffer = io.BytesIO(json_content.encode("utf-8"))
|
||||
filename = f"aegis_{layer}_layer.json"
|
||||
|
||||
# Return StreamingResponse(
|
||||
return StreamingResponse(
|
||||
buffer,
|
||||
# Keyword argument: media_type
|
||||
media_type="application/json",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename={filename}",
|
||||
},
|
||||
# Keyword argument: headers
|
||||
headers={"Content-Disposition": f"attachment; filename=aegis_{layer}_layer.json"},
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user